more code refactor

Blaise
2024-01-16 17:14:18 +01:00
parent 0d92575115
commit 0d3d47f3c3
44 changed files with 4516 additions and 2623 deletions

View File

@@ -16,7 +16,7 @@ __all__ = [
"DistributedBucketSampler",
]
T_co = TypeVar('T_co', covariant=True)
T_co = TypeVar("T_co", covariant=True)
class DistributedBucketSampler(Sampler[T_co]):
@@ -28,28 +28,30 @@ class DistributedBucketSampler(Sampler[T_co]):
sort batches
"""
def __init__(self,
dataset: Dataset,
num_replicas: Optional[int]=None,
rank: Optional[int]=None,
shuffle: bool=True,
seed: int=0,
drop_last: bool=False,
batch_size: int=32) -> None:
def __init__(
self,
dataset: Dataset,
num_replicas: Optional[int] = None,
rank: Optional[int] = None,
shuffle: bool = True,
seed: int = 0,
drop_last: bool = False,
batch_size: int = 32,
) -> None:
if num_replicas is None:
if not dist.is_available():
raise RuntimeError(
"Requires distributed package to be available")
raise RuntimeError("Requires distributed package to be available")
num_replicas = dist.get_world_size()
if rank is None:
if not dist.is_available():
raise RuntimeError(
"Requires distributed package to be available")
raise RuntimeError("Requires distributed package to be available")
rank = dist.get_rank()
torch.cuda.set_device(rank)
if rank >= num_replicas or rank < 0:
raise ValueError("Invalid rank {}, rank should be in the interval"
" [0, {}]".format(rank, num_replicas - 1))
raise ValueError(
"Invalid rank {}, rank should be in the interval"
" [0, {}]".format(rank, num_replicas - 1)
)
self.dataset = dataset
self.num_replicas = num_replicas
self.rank = rank
@@ -57,19 +59,20 @@ class DistributedBucketSampler(Sampler[T_co]):
self.drop_last = drop_last
# If the dataset length is evenly divisible by # of replicas, then there
# is no need to drop any data, since the dataset will be split equally.
if self.drop_last and len(
self.
dataset) % self.num_replicas != 0: # type: ignore[arg-type]
if (
self.drop_last and len(self.dataset) % self.num_replicas != 0
): # type: ignore[arg-type]
# Split to nearest available length that is evenly divisible.
# This is to ensure each rank receives the same amount of data when
# using this Sampler.
self.num_samples = math.ceil(
(len(self.dataset) - self.num_replicas) /
self.num_replicas # type: ignore[arg-type]
(len(self.dataset) - self.num_replicas)
/ self.num_replicas # type: ignore[arg-type]
)
else:
self.num_samples = math.ceil(
len(self.dataset) / self.num_replicas) # type: ignore[arg-type]
len(self.dataset) / self.num_replicas
) # type: ignore[arg-type]
self.total_size = self.num_samples * self.num_replicas
self.shuffle = shuffle
self.seed = seed
@@ -84,7 +87,7 @@ class DistributedBucketSampler(Sampler[T_co]):
id_with_lengths.sort(key=lambda x: x[1])
return id_with_lengths
def make_buckets(self, bucket_width: float=2.0):
def make_buckets(self, bucket_width: float = 2.0):
buckets = []
cur = []
max_sec = bucket_width
@@ -114,8 +117,8 @@ class DistributedBucketSampler(Sampler[T_co]):
shuffled_bucket = list(itertools.chain(*shuffled_bucket))
n_batch = int(math.ceil(len(shuffled_bucket) / grouped_batch_size))
batches = [
shuffled_bucket[b * grouped_batch_size:(b + 1) *
grouped_batch_size] for b in range(n_batch)
shuffled_bucket[b * grouped_batch_size : (b + 1) * grouped_batch_size]
for b in range(n_batch)
]
shuffle(batches)
indices = list(itertools.chain(*batches))
@@ -129,15 +132,16 @@ class DistributedBucketSampler(Sampler[T_co]):
if padding_size <= len(indices):
indices += indices[:padding_size]
else:
indices += (indices * math.ceil(padding_size /
len(indices)))[:padding_size]
indices += (indices * math.ceil(padding_size / len(indices)))[
:padding_size
]
else:
# remove tail of data to make it evenly divisible.
indices = indices[:self.total_size]
indices = indices[: self.total_size]
assert len(indices) == self.total_size
# subsample
indices = indices[self.rank:self.total_size:self.num_replicas]
indices = indices[self.rank : self.total_size : self.num_replicas]
assert len(indices) == self.num_samples
return iter(indices)
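For reference, the subsampling in __iter__ above reduces to a strided slice over the padded index list; a minimal sketch of how the ranks interleave (illustrative, not part of this commit):

# Each rank takes every num_replicas-th index, starting at its own rank.
indices = list(range(12))            # pretend total_size = 12
num_replicas = 4                     # world size
for rank in range(num_replicas):
    shard = indices[rank : len(indices) : num_replicas]
    print(rank, shard)               # 0 -> [0, 4, 8], 1 -> [1, 5, 9], 2 -> [2, 6, 10], 3 -> [3, 7, 11]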

View File

@@ -6,14 +6,21 @@ from torch.utils.data import DataLoader
class Text2SemanticDataModule(LightningDataModule):
def __init__(self, config, train_semantic_path, train_phoneme_path,dev_semantic_path=None, dev_phoneme_path=None):
def __init__(
self,
config,
train_semantic_path,
train_phoneme_path,
dev_semantic_path=None,
dev_phoneme_path=None,
):
super().__init__()
self.config = config
self.train_semantic_path = train_semantic_path
self.train_phoneme_path = train_phoneme_path
self.dev_semantic_path = dev_semantic_path
self.dev_phoneme_path = dev_phoneme_path
self.num_workers = self.config['data']['num_workers']
self.num_workers = self.config["data"]["num_workers"]
def prepare_data(self):
pass
@@ -22,8 +29,9 @@ class Text2SemanticDataModule(LightningDataModule):
self._train_dataset = Text2SemanticDataset(
phoneme_path=self.train_phoneme_path,
semantic_path=self.train_semantic_path,
max_sec=self.config['data']['max_sec'],
pad_val=self.config['data']['pad_val'])
max_sec=self.config["data"]["max_sec"],
pad_val=self.config["data"]["pad_val"],
)
self._dev_dataset = self._train_dataset
# self._dev_dataset = Text2SemanticDataset(
# phoneme_path=self.dev_phoneme_path,
@@ -33,9 +41,8 @@ class Text2SemanticDataModule(LightningDataModule):
# pad_val=self.config['data']['pad_val'])
def train_dataloader(self):
batch_size = self.config['train']['batch_size']
sampler = DistributedBucketSampler(
self._train_dataset, batch_size=batch_size)
batch_size = self.config["train"]["batch_size"]
sampler = DistributedBucketSampler(self._train_dataset, batch_size=batch_size)
return DataLoader(
self._train_dataset,
batch_size=batch_size,
@@ -43,7 +50,7 @@ class Text2SemanticDataModule(LightningDataModule):
collate_fn=self._train_dataset.collate,
num_workers=self.num_workers,
persistent_workers=True,
prefetch_factor=16
prefetch_factor=16,
)
def val_dataloader(self):
@@ -52,9 +59,9 @@ class Text2SemanticDataModule(LightningDataModule):
batch_size=1,
shuffle=False,
collate_fn=self._train_dataset.collate,
num_workers=max(self.num_workers,12),
num_workers=max(self.num_workers, 12),
persistent_workers=True,
prefetch_factor=16
prefetch_factor=16,
)
# Is this actually used anywhere?
@@ -63,4 +70,5 @@ class Text2SemanticDataModule(LightningDataModule):
self._dev_dataset,
batch_size=1,
shuffle=False,
collate_fn=self._train_dataset.collate)
collate_fn=self._train_dataset.collate,
)
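As a usage note, the data module above only reads a handful of config keys; a sketch of the expected structure (paths and values below are placeholders, not taken from this commit):

config = {
    "data": {"num_workers": 4, "max_sec": 54, "pad_val": 1024},
    "train": {"batch_size": 32},
}
dm = Text2SemanticDataModule(
    config,
    train_semantic_path="exp/6-name2semantic.tsv",   # placeholder path
    train_phoneme_path="exp/2-name2text.txt",        # placeholder path
)
# During training the Lightning Trainer calls setup() and train_dataloader() itself.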

View File

@@ -1,21 +1,24 @@
# modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/t2s_dataset.py
import pdb
import sys
# sys.path.append("/data/docker/liujing04/gpt-vits/mq-vits-s1bert_no_bert")
import traceback,os
import traceback, os
from typing import Dict
from typing import List
import numpy as np
import pandas as pd
import torch,json
import torch, json
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from transformers import AutoTokenizer
from text import cleaned_text_to_sequence
# from config import exp_dir
def batch_sequences(sequences: List[np.array], axis: int = 0, pad_value: int = 0):
seq = sequences[0]
ndim = seq.ndim
@@ -28,44 +31,52 @@ def batch_sequences(sequences: List[np.array], axis: int = 0, pad_value: int = 0
padded_sequences = []
for seq, length in zip(sequences, seq_lengths):
padding = [(0, 0)] * axis + [(0, max_length - length)] + [(0, 0)] * (
ndim - axis - 1)
padded_seq = np.pad(
seq, padding, mode='constant', constant_values=pad_value)
padding = (
[(0, 0)] * axis + [(0, max_length - length)] + [(0, 0)] * (ndim - axis - 1)
)
padded_seq = np.pad(seq, padding, mode="constant", constant_values=pad_value)
padded_sequences.append(padded_seq)
batch = np.stack(padded_sequences)
return batch
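batch_sequences above simply right-pads every array to the longest length along the chosen axis before stacking; a small self-contained example (illustrative only):

import numpy as np

seqs = [np.array([1, 2, 3]), np.array([4, 5])]
max_length = max(len(s) for s in seqs)
padded = np.stack(
    [np.pad(s, (0, max_length - len(s)), mode="constant", constant_values=0) for s in seqs]
)
print(padded)   # [[1 2 3]
                #  [4 5 0]]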
class Text2SemanticDataset(Dataset):
"""dataset class for text tokens to semantic model training."""
def __init__(self,
phoneme_path: str,
semantic_path: str,
max_sample: int = None,
max_sec: int = 100,
pad_val: int = 1024,
# min value of phoneme/sec
min_ps_ratio: int = 3,
# max value of phoneme/sec
max_ps_ratio: int = 25) -> None:
def __init__(
self,
phoneme_path: str,
semantic_path: str,
max_sample: int = None,
max_sec: int = 100,
pad_val: int = 1024,
# min value of phoneme/sec
min_ps_ratio: int = 3,
# max value of phoneme/sec
max_ps_ratio: int = 25,
) -> None:
super().__init__()
self.semantic_data = pd.read_csv(semantic_path, delimiter='\t', encoding="utf-8")
self.semantic_data = pd.read_csv(
semantic_path, delimiter="\t", encoding="utf-8"
)
# get dict
self.path2=phoneme_path#"%s/2-name2text.txt"%exp_dir#phoneme_path
self.path3="%s/3-bert"%(os.path.basename(phoneme_path))#"%s/3-bert"%exp_dir#bert_dir
self.path6=semantic_path#"%s/6-name2semantic.tsv"%exp_dir#semantic_path
self.path2 = phoneme_path # "%s/2-name2text.txt"%exp_dir#phoneme_path
self.path3 = "%s/3-bert" % (
os.path.basename(phoneme_path)
) # "%s/3-bert"%exp_dir#bert_dir
self.path6 = semantic_path # "%s/6-name2semantic.tsv"%exp_dir#semantic_path
assert os.path.exists(self.path2)
assert os.path.exists(self.path6)
self.phoneme_data={}
with open(self.path2,"r",encoding="utf8")as f:
lines=f.read().strip("\n").split("\n")
self.phoneme_data = {}
with open(self.path2, "r", encoding="utf8") as f:
lines = f.read().strip("\n").split("\n")
for line in lines:
tmp=line.split("\t")
if(len(tmp)!=4):continue
self.phoneme_data[tmp[0]]=[tmp[1],tmp[2],tmp[3]]
tmp = line.split("\t")
if len(tmp) != 4:
continue
self.phoneme_data[tmp[0]] = [tmp[1], tmp[2], tmp[3]]
# self.phoneme_data = np.load(phoneme_path, allow_pickle=True).item()
# pad for semantic tokens
@@ -74,7 +85,7 @@ class Text2SemanticDataset(Dataset):
# with open("/data/docker/liujing04/gpt-vits/mq-vits-s1bert_no_bert/configs/s2.json", "r") as f:data = f.read()
# data=json.loads(data)["model"]["semantic_frame_rate"]#50hz
# self.hz=int(data[:-2])#
self.hz=int(os.environ.get("hz","25hz")[:-2])
self.hz = int(os.environ.get("hz", "25hz")[:-2])
# max seconds of semantic token
self.max_sec = max_sec
@@ -100,7 +111,6 @@ class Text2SemanticDataset(Dataset):
# self.tokenizer = AutoTokenizer.from_pretrained("hfl/chinese-roberta-wwm-ext-large")
# self.tokenizer = AutoTokenizer.from_pretrained("/data/docker/liujing04/bert-vits2/Bert-VITS2-master20231106/bert/chinese-roberta-wwm-ext-large")
def init_batch(self):
semantic_data_len = len(self.semantic_data)
phoneme_data_len = len(self.phoneme_data.keys())
@@ -113,7 +123,7 @@ class Text2SemanticDataset(Dataset):
for i in range(semantic_data_len):
# iterate through the entries in order first
# get str
item_name = self.semantic_data['item_name'][i]
item_name = self.semantic_data["item_name"][i]
# print(self.phoneme_data)
try:
phoneme, word2ph, text = self.phoneme_data[item_name]
@@ -123,16 +133,18 @@ class Text2SemanticDataset(Dataset):
num_not_in += 1
continue
semantic_str = self.semantic_data['semantic_audio'][i]
semantic_str = self.semantic_data["semantic_audio"][i]
# get token list
semantic_ids = [int(idx) for idx in semantic_str.split(' ')]
semantic_ids = [int(idx) for idx in semantic_str.split(" ")]
# (T); no need to reshape to (1, T), because we only need len()
# filter out samples that are too long
if len(semantic_ids) > self.max_sec * self.hz:#########1### estimate the total duration from the token count and drop clips longer than max_sec (60s in the config)#40*25=1k
if (
len(semantic_ids) > self.max_sec * self.hz
): #########1### estimate the total duration from the token count and drop clips longer than max_sec (60s in the config)#40*25=1k
num_deleted_bigger += 1
continue
# (T,); this is not slow, so it can all be done up front instead of item by item in __getitem__ ####
phoneme = phoneme.split(' ')
phoneme = phoneme.split(" ")
try:
phoneme_ids = cleaned_text_to_sequence(phoneme)
@@ -142,7 +154,9 @@ class Text2SemanticDataset(Dataset):
num_not_in += 1
continue
# if len(phoneme_ids) >400:###########2 changed to a fixed cap of semantic/2.5
if len(phoneme_ids) >self.max_sec * self.hz/2.5:###########2 changed to a fixed cap of semantic/2.5
if (
len(phoneme_ids) > self.max_sec * self.hz / 2.5
): ###########2 changed to a fixed cap of semantic/2.5
num_deleted_ps += 1
continue
# if len(semantic_ids) > 1000:###########3
@@ -151,7 +165,9 @@ class Text2SemanticDataset(Dataset):
ps_ratio = len(phoneme_ids) / (len(semantic_ids) / self.hz)
if ps_ratio > self.max_ps_ratio or ps_ratio < self.min_ps_ratio:##########4#3~25# phones per second
if (
ps_ratio > self.max_ps_ratio or ps_ratio < self.min_ps_ratio
): ##########4#3~25# phones per second
num_deleted_ps += 1
# print(item_name)
continue
@@ -160,16 +176,16 @@ class Text2SemanticDataset(Dataset):
idx += 1
self.item_names.append(item_name)
min_num=100# at 20, do not duplicate at all; at 30, duplicated but no ckpt is saved
leng =len(self.semantic_phoneme)
if(leng<min_num):
tmp1=self.semantic_phoneme
tmp2=self.item_names
self.semantic_phoneme=[]
self.item_names=[]
for _ in range(max(2,int(min_num/leng))):
self.semantic_phoneme+=tmp1
self.item_names+=tmp2
min_num = 100  # at 20, do not duplicate at all; at 30, duplicated but no ckpt is saved
leng = len(self.semantic_phoneme)
if leng < min_num:
tmp1 = self.semantic_phoneme
tmp2 = self.item_names
self.semantic_phoneme = []
self.item_names = []
for _ in range(max(2, int(min_num / leng))):
self.semantic_phoneme += tmp1
self.item_names += tmp2
if num_not_in > 0:
print(f"there are {num_not_in} semantic datas not in phoneme datas")
if num_deleted_bigger > 0:
@@ -181,13 +197,13 @@ class Text2SemanticDataset(Dataset):
print(
f"deleted {num_deleted_ps} audios who's phoneme/sec are bigger than {self.max_ps_ratio} or smaller than {self.min_ps_ratio}"
)
'''
"""
there are 31 semantic datas not in phoneme datas
deleted 34 audios who's duration are bigger than 54 seconds
deleted 3190 audios who's phoneme/sec are bigger than 25 or smaller than 3
dataset.__len__(): 366463
'''
"""
# 345410 for LibriTTS
print("dataset.__len__():", self.__len__())
@@ -204,22 +220,24 @@ class Text2SemanticDataset(Dataset):
# semantic tokens target
semantic_ids_len = len(semantic_ids)
flag=0
flag = 0
path_bert = "%s/%s.pt" % (self.path3, item_name)
if(os.path.exists(path_bert)==True):bert_feature = torch.load(path_bert,map_location="cpu")
else:flag=1
if(flag==1):
if os.path.exists(path_bert) == True:
bert_feature = torch.load(path_bert, map_location="cpu")
else:
flag = 1
if flag == 1:
# bert_feature=torch.zeros_like(phoneme_ids,dtype=torch.float32)
bert_feature=None
bert_feature = None
else:
assert bert_feature.shape[-1] == len(phoneme_ids)
return {
'idx': idx,
'phoneme_ids': phoneme_ids,
'phoneme_ids_len': phoneme_ids_len,
'semantic_ids': semantic_ids,
'semantic_ids_len': semantic_ids_len,
'bert_feature': bert_feature,
"idx": idx,
"phoneme_ids": phoneme_ids,
"phoneme_ids_len": phoneme_ids_len,
"semantic_ids": semantic_ids,
"semantic_ids_len": semantic_ids_len,
"bert_feature": bert_feature,
}
def get_sample_length(self, idx: int):
@@ -235,7 +253,6 @@ class Text2SemanticDataset(Dataset):
semantic_ids_lens: List[int] = []
# return
for item in examples:
sample_index.append(item["idx"])
phoneme_ids.append(np.array(item["phoneme_ids"], dtype=np.int64))
@@ -256,9 +273,9 @@ class Text2SemanticDataset(Dataset):
bert_padded.zero_()
for idx, item in enumerate(examples):
bert = item['bert_feature']
if(bert!=None):
bert_padded[idx, :, :bert.shape[-1]] = bert
bert = item["bert_feature"]
if bert != None:
bert_padded[idx, :, : bert.shape[-1]] = bert
return {
# List[int]
@@ -276,27 +293,27 @@ class Text2SemanticDataset(Dataset):
}
if __name__ == '__main__':
root_dir = '/data/docker/liujing04/gpt-vits/prepare/dump_mix/'
if __name__ == "__main__":
root_dir = "/data/docker/liujing04/gpt-vits/prepare/dump_mix/"
dataset = Text2SemanticDataset(
phoneme_path=root_dir + 'phoneme_train.npy',
semantic_path=root_dir + 'semantic_train.tsv')
phoneme_path=root_dir + "phoneme_train.npy",
semantic_path=root_dir + "semantic_train.tsv",
)
batch_size = 12
dataloader = DataLoader(
dataset,
batch_size=batch_size,
collate_fn=dataset.collate,
shuffle=False)
dataset, batch_size=batch_size, collate_fn=dataset.collate, shuffle=False
)
for i, batch in enumerate(dataloader):
if(i%1000==0):print(i)
if i % 1000 == 0:
print(i)
# if i == 0:
# print('batch["ids"]:', batch["ids"])
# print('batch["phoneme_ids"]:', batch["phoneme_ids"],
# batch["phoneme_ids"].shape)
# print('batch["phoneme_ids_len"]:', batch["phoneme_ids_len"],
# batch["phoneme_ids_len"].shape)
# print('batch["semantic_ids"]:', batch["semantic_ids"],
# batch["semantic_ids"].shape)
# print('batch["semantic_ids_len"]:', batch["semantic_ids_len"],
# batch["semantic_ids_len"].shape)
# print('batch["phoneme_ids"]:', batch["phoneme_ids"],
# batch["phoneme_ids"].shape)
# print('batch["phoneme_ids_len"]:', batch["phoneme_ids_len"],
# batch["phoneme_ids_len"].shape)
# print('batch["semantic_ids"]:', batch["semantic_ids"],
# batch["semantic_ids"].shape)
# print('batch["semantic_ids_len"]:', batch["semantic_ids_len"],
# batch["semantic_ids_len"].shape)
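For readability, the three sample filters applied in init_batch above can be summarized as one predicate; a sketch using the default constants (illustrative, not part of the commit):

hz, max_sec = 25, 100
min_ps_ratio, max_ps_ratio = 3, 25

def keep_sample(phoneme_ids, semantic_ids):
    if len(semantic_ids) > max_sec * hz:              # estimated duration too long
        return False
    if len(phoneme_ids) > max_sec * hz / 2.5:         # phoneme sequence too long
        return False
    ps_ratio = len(phoneme_ids) / (len(semantic_ids) / hz)   # phonemes per second
    return min_ps_ratio <= ps_ratio <= max_ps_ratio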

View File

@@ -1,5 +1,6 @@
# modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/model/t2s_lightning_module.py
import os,sys
import os, sys
now_dir = os.getcwd()
sys.path.append(now_dir)
from typing import Dict
@@ -12,29 +13,35 @@ from AR.modules.optim import ScaledAdam
class Text2SemanticLightningModule(LightningModule):
def __init__(self, config, output_dir,is_train=True):
def __init__(self, config, output_dir, is_train=True):
super().__init__()
self.config = config
self.top_k = 3
self.model = Text2SemanticDecoder(config=config, top_k=self.top_k)
pretrained_s1=config.get("pretrained_s1")
if(pretrained_s1 and is_train):
pretrained_s1 = config.get("pretrained_s1")
if pretrained_s1 and is_train:
# print(self.load_state_dict(torch.load(pretrained_s1,map_location="cpu")["state_dict"]))
print(self.load_state_dict(torch.load(pretrained_s1,map_location="cpu")["weight"]))
print(
self.load_state_dict(
torch.load(pretrained_s1, map_location="cpu")["weight"]
)
)
if is_train:
self.automatic_optimization = False
self.save_hyperparameters()
self.eval_dir = output_dir / 'eval'
self.eval_dir = output_dir / "eval"
self.eval_dir.mkdir(parents=True, exist_ok=True)
def training_step(self, batch: Dict, batch_idx: int):
opt = self.optimizers()
scheduler = self.lr_schedulers()
loss, acc = self.model.forward(
batch['phoneme_ids'], batch['phoneme_ids_len'],
batch['semantic_ids'], batch['semantic_ids_len'],
batch['bert_feature'])
batch["phoneme_ids"],
batch["phoneme_ids_len"],
batch["semantic_ids"],
batch["semantic_ids_len"],
batch["bert_feature"],
)
self.manual_backward(loss)
if batch_idx > 0 and batch_idx % 4 == 0:
opt.step()
@@ -47,63 +54,67 @@ class Text2SemanticLightningModule(LightningModule):
on_step=True,
on_epoch=True,
prog_bar=True,
sync_dist=True)
sync_dist=True,
)
self.log(
"lr",
scheduler.get_last_lr()[0],
on_epoch=True,
prog_bar=True,
sync_dist=True)
sync_dist=True,
)
self.log(
f"top_{self.top_k}_acc",
acc,
on_step=True,
on_epoch=True,
prog_bar=True,
sync_dist=True)
sync_dist=True,
)
def validation_step(self, batch: Dict, batch_idx: int):return
# # get loss
# loss, acc = self.model.forward(
# batch['phoneme_ids'], batch['phoneme_ids_len'],
# batch['semantic_ids'], batch['semantic_ids_len'],
# batch['bert_feature']
# )
#
# self.log(
# "val_total_loss",
# loss,
# on_step=True,
# on_epoch=True,
# prog_bar=True,
# sync_dist=True)
# self.log(
# f"val_top_{self.top_k}_acc",
# acc,
# on_step=True,
# on_epoch=True,
# prog_bar=True,
# sync_dist=True)
#
# # get infer output
# semantic_len = batch['semantic_ids'].size(1)
# prompt_len = min(int(semantic_len * 0.5), 150)
# prompt = batch['semantic_ids'][:, :prompt_len]
# pred_semantic = self.model.infer(batch['phoneme_ids'],
# batch['phoneme_ids_len'], prompt,
# batch['bert_feature']
# )
# save_name = f'semantic_toks_{batch_idx}.pt'
# save_path = os.path.join(self.eval_dir, save_name)
# torch.save(pred_semantic.detach().cpu(), save_path)
def validation_step(self, batch: Dict, batch_idx: int):
return
# # get loss
# loss, acc = self.model.forward(
# batch['phoneme_ids'], batch['phoneme_ids_len'],
# batch['semantic_ids'], batch['semantic_ids_len'],
# batch['bert_feature']
# )
#
# self.log(
# "val_total_loss",
# loss,
# on_step=True,
# on_epoch=True,
# prog_bar=True,
# sync_dist=True)
# self.log(
# f"val_top_{self.top_k}_acc",
# acc,
# on_step=True,
# on_epoch=True,
# prog_bar=True,
# sync_dist=True)
#
# # get infer output
# semantic_len = batch['semantic_ids'].size(1)
# prompt_len = min(int(semantic_len * 0.5), 150)
# prompt = batch['semantic_ids'][:, :prompt_len]
# pred_semantic = self.model.infer(batch['phoneme_ids'],
# batch['phoneme_ids_len'], prompt,
# batch['bert_feature']
# )
# save_name = f'semantic_toks_{batch_idx}.pt'
# save_path = os.path.join(self.eval_dir, save_name)
# torch.save(pred_semantic.detach().cpu(), save_path)
def configure_optimizers(self):
model_parameters = self.model.parameters()
parameters_names = []
parameters_names.append([
name_param_pair[0]
for name_param_pair in self.model.named_parameters()
])
parameters_names.append(
[name_param_pair[0] for name_param_pair in self.model.named_parameters()]
)
lm_opt = ScaledAdam(
model_parameters,
lr=0.01,
@@ -111,18 +122,19 @@ class Text2SemanticLightningModule(LightningModule):
clipping_scale=2.0,
parameters_names=parameters_names,
show_dominant_parameters=False,
clipping_update_period=1000, )
clipping_update_period=1000,
)
return {
"optimizer": lm_opt,
"lr_scheduler": {
"scheduler":
WarmupCosineLRSchedule(
"scheduler": WarmupCosineLRSchedule(
lm_opt,
init_lr=self.config['optimizer']['lr_init'],
peak_lr=self.config['optimizer']['lr'],
end_lr=self.config['optimizer']['lr_end'],
warmup_steps=self.config['optimizer']['warmup_steps'],
total_steps=self.config['optimizer']['decay_steps'])
}
init_lr=self.config["optimizer"]["lr_init"],
peak_lr=self.config["optimizer"]["lr"],
end_lr=self.config["optimizer"]["lr_end"],
warmup_steps=self.config["optimizer"]["warmup_steps"],
total_steps=self.config["optimizer"]["decay_steps"],
)
},
}

View File

@@ -3,7 +3,12 @@ import torch
from tqdm import tqdm
from AR.models.utils import make_pad_mask
from AR.models.utils import topk_sampling,sample,logits_to_probs,multinomial_sample_one_no_sync
from AR.models.utils import (
topk_sampling,
sample,
logits_to_probs,
multinomial_sample_one_no_sync,
)
from AR.modules.embedding import SinePositionalEmbedding
from AR.modules.embedding import TokenEmbedding
from AR.modules.transformer import LayerNorm
@@ -22,35 +27,39 @@ default_config = {
"p_dropout": 0.0,
"vocab_size": 1024 + 1,
"phoneme_vocab_size": 512,
"EOS": 1024
"EOS": 1024,
}
class Text2SemanticDecoder(nn.Module):
def __init__(self, config, norm_first=False, top_k=3):
super(Text2SemanticDecoder, self).__init__()
self.model_dim = config['model']["hidden_dim"]
self.embedding_dim = config['model']["embedding_dim"]
self.num_head = config['model']["head"]
self.num_layers = config['model']["n_layer"]
self.model_dim = config["model"]["hidden_dim"]
self.embedding_dim = config["model"]["embedding_dim"]
self.num_head = config["model"]["head"]
self.num_layers = config["model"]["n_layer"]
self.norm_first = norm_first
self.vocab_size = config['model']["vocab_size"]
self.phoneme_vocab_size = config['model']["phoneme_vocab_size"]
self.p_dropout = config['model']["dropout"]
self.EOS = config['model']["EOS"]
self.vocab_size = config["model"]["vocab_size"]
self.phoneme_vocab_size = config["model"]["phoneme_vocab_size"]
self.p_dropout = config["model"]["dropout"]
self.EOS = config["model"]["EOS"]
self.norm_first = norm_first
assert self.EOS == self.vocab_size - 1
# should be same as num of kmeans bin
# assert self.EOS == 1024
self.bert_proj = nn.Linear(1024, self.embedding_dim)
self.ar_text_embedding = TokenEmbedding(
self.embedding_dim, self.phoneme_vocab_size, self.p_dropout)
self.embedding_dim, self.phoneme_vocab_size, self.p_dropout
)
self.ar_text_position = SinePositionalEmbedding(
self.embedding_dim, dropout=0.1, scale=False, alpha=True)
self.embedding_dim, dropout=0.1, scale=False, alpha=True
)
self.ar_audio_embedding = TokenEmbedding(
self.embedding_dim, self.vocab_size, self.p_dropout)
self.embedding_dim, self.vocab_size, self.p_dropout
)
self.ar_audio_position = SinePositionalEmbedding(
self.embedding_dim, dropout=0.1, scale=False, alpha=True)
self.embedding_dim, dropout=0.1, scale=False, alpha=True
)
self.h = TransformerEncoder(
TransformerEncoderLayer(
@@ -59,28 +68,30 @@ class Text2SemanticDecoder(nn.Module):
dim_feedforward=self.model_dim * 4,
dropout=0.1,
batch_first=True,
norm_first=norm_first, ),
norm_first=norm_first,
),
num_layers=self.num_layers,
norm=LayerNorm(self.model_dim) if norm_first else None, )
norm=LayerNorm(self.model_dim) if norm_first else None,
)
self.ar_predict_layer = nn.Linear(
self.model_dim, self.vocab_size, bias=False)
self.loss_fct = nn.CrossEntropyLoss(reduction='sum')
self.ar_predict_layer = nn.Linear(self.model_dim, self.vocab_size, bias=False)
self.loss_fct = nn.CrossEntropyLoss(reduction="sum")
self.ar_accuracy_metric = MulticlassAccuracy(
self.vocab_size,
top_k=top_k,
average="micro",
multidim_average="global",
ignore_index=self.EOS, )
ignore_index=self.EOS,
)
def forward(self, x, x_lens, y, y_lens, bert_feature):
'''
"""
x: phoneme_ids
y: semantic_ids
'''
"""
x = self.ar_text_embedding(x)
x = x + self.bert_proj(bert_feature.transpose(1,2))
x = x + self.bert_proj(bert_feature.transpose(1, 2))
x = self.ar_text_position(x)
x_mask = make_pad_mask(x_lens)
@@ -102,18 +113,23 @@ class Text2SemanticDecoder(nn.Module):
x_attn_mask = F.pad(
torch.zeros((x_len, x_len), dtype=torch.bool, device=x.device),
(0, y_len),
value=True, )
value=True,
)
y_attn_mask = F.pad(
torch.triu(
torch.ones(y_len, y_len, dtype=torch.bool, device=x.device),
diagonal=1, ),
diagonal=1,
),
(x_len, 0),
value=False, )
value=False,
)
xy_attn_mask = torch.concat([x_attn_mask, y_attn_mask], dim=0)
bsz, src_len = x.shape[0], x_len + y_len
_xy_padding_mask = (ar_xy_padding_mask.view(bsz, 1, 1, src_len)
.expand(-1, self.num_head, -1, -1)
.reshape(bsz * self.num_head, 1, src_len))
_xy_padding_mask = (
ar_xy_padding_mask.view(bsz, 1, 1, src_len)
.expand(-1, self.num_head, -1, -1)
.reshape(bsz * self.num_head, 1, src_len)
)
xy_attn_mask = xy_attn_mask.logical_or(_xy_padding_mask)
new_attn_mask = torch.zeros_like(xy_attn_mask, dtype=x.dtype)
new_attn_mask.masked_fill_(xy_attn_mask, float("-inf"))
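To make the combined mask above concrete, here is a tiny worked example with x_len=2 and y_len=3, showing the boolean mask before it is merged with the padding mask and converted to additive -inf values (illustrative only):

import torch
import torch.nn.functional as F

x_len, y_len = 2, 3
x_attn_mask = F.pad(torch.zeros((x_len, x_len), dtype=torch.bool), (0, y_len), value=True)
y_attn_mask = F.pad(
    torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
    (x_len, 0),
    value=False,
)
print(torch.concat([x_attn_mask, y_attn_mask], dim=0).int())
# Text rows attend only to text; audio rows attend to all text plus causal audio:
# tensor([[0, 0, 1, 1, 1],
#         [0, 0, 1, 1, 1],
#         [0, 0, 0, 1, 1],
#         [0, 0, 0, 0, 1],
#         [0, 0, 0, 0, 0]], dtype=torch.int32)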
@@ -122,26 +138,28 @@ class Text2SemanticDecoder(nn.Module):
xy_pos = torch.concat([x, y_pos], dim=1)
xy_dec, _ = self.h(
(xy_pos, None),
mask=xy_attn_mask, )
mask=xy_attn_mask,
)
logits = self.ar_predict_layer(xy_dec[:, x_len:]).permute(0, 2, 1)
# loss
# from feiteng: the longer the duration, the larger the gradient update should be, hence sum
loss = F.cross_entropy(logits, targets, reduction='sum')
loss = F.cross_entropy(logits, targets, reduction="sum")
acc = self.ar_accuracy_metric(logits.detach(), targets).item()
return loss, acc
# need to check how this differs from forward, and what to pass as prompts when there are no semantic tokens
def infer(self,
x,
x_lens,
prompts,
bert_feature,
top_k: int=-100,
early_stop_num: int=-1,
temperature: float=1.0):
def infer(
self,
x,
x_lens,
prompts,
bert_feature,
top_k: int = -100,
early_stop_num: int = -1,
temperature: float = 1.0,
):
x = self.ar_text_embedding(x)
x = x + self.bert_proj(bert_feature.transpose(1,2))
x = x + self.bert_proj(bert_feature.transpose(1, 2))
x = self.ar_text_position(x)
# AR Decoder
@@ -159,35 +177,37 @@ class Text2SemanticDecoder(nn.Module):
x_attn_mask_pad = F.pad(
x_attn_mask,
(0, y_len),
value=True, )
value=True,
)
y_attn_mask = F.pad(
torch.triu(
torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
(x_len, 0),
value=False, )
xy_attn_mask = torch.concat(
[x_attn_mask_pad, y_attn_mask], dim=0).to(y.device)
value=False,
)
xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0).to(
y.device
)
xy_dec, _ = self.h(
(xy_pos, None),
mask=xy_attn_mask, )
mask=xy_attn_mask,
)
logits = self.ar_predict_layer(xy_dec[:, -1])
samples = topk_sampling(
logits, top_k=top_k, top_p=1.0, temperature=temperature)
logits, top_k=top_k, top_p=1.0, temperature=temperature
)
if early_stop_num != -1 and (y.shape[1] - prefix_len
) > early_stop_num:
if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
print("use early stop num:", early_stop_num)
stop = True
if torch.argmax(
logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
# print(torch.argmax(logits, dim=-1)[0] == self.EOS, samples[0, 0] == self.EOS)
stop = True
if stop:
if prompts.shape[1] == y.shape[1]:
y = torch.concat([y, torch.zeros_like(samples)], dim=1)
print('bad zero prediction')
print("bad zero prediction")
print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]")
break
# the semantic_ids generated this step are appended to the previous y to form the new y
@@ -198,23 +218,24 @@ class Text2SemanticDecoder(nn.Module):
return y
def pad_y_eos(self, y, y_mask_int, eos_id):
targets = F.pad(
y, (0, 1), value=0) + eos_id * F.pad(
y_mask_int, (0, 1), value=1)
targets = F.pad(y, (0, 1), value=0) + eos_id * F.pad(
y_mask_int, (0, 1), value=1
)
# shift by one position
return targets[:, :-1], targets[:, 1:]
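pad_y_eos above appends the EOS token and returns the usual one-step-shifted (input, target) pair for teacher forcing; a tiny numeric example (illustrative only):

import torch
import torch.nn.functional as F

eos_id = 1024
y = torch.tensor([[7, 8, 9]])
y_mask_int = torch.tensor([[0, 0, 0]])   # 1 would mark padded positions
targets = F.pad(y, (0, 1), value=0) + eos_id * F.pad(y_mask_int, (0, 1), value=1)
inputs, labels = targets[:, :-1], targets[:, 1:]
print(inputs)   # tensor([[7, 8, 9]])
print(labels)   # tensor([[   8,    9, 1024]])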
def infer_panel(self,
x,##### all text tokens
x_lens,
prompts,#### reference-audio tokens
bert_feature,
top_k: int=-100,
early_stop_num: int=-1,
temperature: float=1.0):
def infer_panel(
self,
x, ##### all text tokens
x_lens,
prompts, #### reference-audio tokens
bert_feature,
top_k: int = -100,
early_stop_num: int = -1,
temperature: float = 1.0,
):
x = self.ar_text_embedding(x)
x = x + self.bert_proj(bert_feature.transpose(1,2))
x = x + self.bert_proj(bert_feature.transpose(1, 2))
x = self.ar_text_position(x)
# AR Decoder
@@ -224,75 +245,81 @@ class Text2SemanticDecoder(nn.Module):
x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)
stop = False
# print(1111111,self.num_layers)
cache={
"all_stage":self.num_layers,
"k":[None]*self.num_layers,###根据配置自己手写
"v":[None]*self.num_layers,
cache = {
"all_stage": self.num_layers,
"k": [None] * self.num_layers, ###根据配置自己手写
"v": [None] * self.num_layers,
# "xy_pos":None,##y_pos位置编码每次都不一样的没法缓存每次都要重新拼xy_pos.主要还是写法原因,其实是可以历史统一一样的,但也没啥计算量就不管了
"y_emb":None,##只需要对最新的samples求emb再拼历史的就行
"y_emb": None, ##只需要对最新的samples求emb再拼历史的就行
# "logits":None,###原版就已经只对结尾求再拼接了,不用管
# "xy_dec":None,###不需要本来只需要最后一个做logits
"first_infer":1,
"stage":0
"first_infer": 1,
"stage": 0,
}
for idx in tqdm(range(1500)):
if(cache["first_infer"]==1):
if cache["first_infer"] == 1:
y_emb = self.ar_audio_embedding(y)
else:
y_emb = torch.cat([cache["y_emb"],self.ar_audio_embedding(y[:,-1:])],1)
cache["y_emb"]=y_emb
y_emb = torch.cat(
[cache["y_emb"], self.ar_audio_embedding(y[:, -1:])], 1
)
cache["y_emb"] = y_emb
y_pos = self.ar_audio_position(y_emb)
# feed x together with the steadily growing y into the model
if(cache["first_infer"]==1):
if cache["first_infer"] == 1:
xy_pos = torch.concat([x, y_pos], dim=1)
else:
xy_pos=y_pos[:,-1:]
xy_pos = y_pos[:, -1:]
y_len = y_pos.shape[1]
### the following three are not cached
if (cache["first_infer"] == 1):
if cache["first_infer"] == 1:
x_attn_mask_pad = F.pad(
x_attn_mask,
(0, y_len),### extend the all-zero xx block with all-one xy on the right: (x, x+y)
value=True, )
y_attn_mask = F.pad(### extend yy's upper-triangular ones with zeros for x on the left: (y, x+y)
torch.triu(
torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
x_attn_mask,
(0, y_len), ### extend the all-zero xx block with all-one xy on the right: (x, x+y)
value=True,
)
y_attn_mask = F.pad( ### extend yy's upper-triangular ones with zeros for x on the left: (y, x+y)
torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
(x_len, 0),
value=False, )
xy_attn_mask = torch.concat(
[x_attn_mask_pad, y_attn_mask], dim=0).to(y.device)
value=False,
)
xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0).to(
y.device
)
else:
### only the rightmost column (this was wrong)
# xy_attn_mask=torch.ones((1, x_len+y_len), dtype=torch.bool,device=xy_pos.device)
# xy_attn_mask[:,-1]=False
### only the bottom row (this is correct)
xy_attn_mask = torch.zeros((1, x_len + y_len), dtype=torch.bool, device=xy_pos.device)
xy_attn_mask = torch.zeros(
(1, x_len + y_len), dtype=torch.bool, device=xy_pos.device
)
# pdb.set_trace()
### this is where the cache does the heavy lifting
# print(1111,xy_pos.shape,xy_attn_mask.shape,x_len,y_len)
xy_dec, _ = self.h(
(xy_pos, None),
mask=xy_attn_mask,cache=cache )
logits = self.ar_predict_layer(xy_dec[:, -1])## no change needed: with the cache there is only one frame anyway, so taking the last frame is the same
xy_dec, _ = self.h((xy_pos, None), mask=xy_attn_mask, cache=cache)
logits = self.ar_predict_layer(
xy_dec[:, -1]
) ## no change needed: with the cache there is only one frame anyway, so taking the last frame is the same
# samples = topk_sampling(logits, top_k=top_k, top_p=1.0, temperature=temperature)
samples = sample(logits[0], y, top_k=top_k, top_p=1.0, repetition_penalty=1.35)[0].unsqueeze(0)
if early_stop_num != -1 and (y.shape[1] - prefix_len
) > early_stop_num:
samples = sample(
logits[0], y, top_k=top_k, top_p=1.0, repetition_penalty=1.35
)[0].unsqueeze(0)
if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
print("use early stop num:", early_stop_num)
stop = True
if torch.argmax(
logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
# print(torch.argmax(logits, dim=-1)[0] == self.EOS, samples[0, 0] == self.EOS)
stop = True
if stop:
if prompts.shape[1] == y.shape[1]:
y = torch.concat([y, torch.zeros_like(samples)], dim=1)
print('bad zero prediction')
print("bad zero prediction")
print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]")
break
# the semantic_ids generated this step are appended to the previous y to form the new y
# print(samples.shape)#[1,1]# the first 1 is the batch size
y = torch.concat([y, samples], dim=1)
cache["first_infer"]=0
return y,idx
cache["first_infer"] = 0
return y, idx
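The cache dict above is a simple per-layer KV cache: the first step stores the full k/v, later steps append only the newest frame, and the stage counter cycles through the layers. A schematic sketch in the same spirit (names and shapes illustrative, not the project's API):

import torch

num_layers = 2
cache = {"all_stage": num_layers, "k": [None] * num_layers,
         "v": [None] * num_layers, "first_infer": 1, "stage": 0}

def update_cache(cache, k_new, v_new):
    s = cache["stage"]
    if cache["first_infer"] == 1:
        cache["k"][s], cache["v"][s] = k_new, v_new
    else:
        # time axis is dim 0 here, as in the patched attention below
        cache["k"][s] = torch.cat([cache["k"][s], k_new], 0)
        cache["v"][s] = torch.cat([cache["v"][s], v_new], 0)
    cache["stage"] = (s + 1) % cache["all_stage"]
    return cache["k"][s], cache["v"][s]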

View File

@@ -2,6 +2,7 @@
import torch
import torch.nn.functional as F
def sequence_mask(length, max_length=None):
if max_length is None:
max_length = length.max()
@@ -9,7 +10,7 @@ def sequence_mask(length, max_length=None):
return x.unsqueeze(0) < length.unsqueeze(1)
def make_pad_mask(lengths: torch.Tensor, max_len: int=0) -> torch.Tensor:
def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
"""
Args:
lengths:
@@ -38,11 +39,9 @@ def make_pad_mask(lengths: torch.Tensor, max_len: int=0) -> torch.Tensor:
# https://github.com/microsoft/unilm/blob/master/xtune/src/transformers/modeling_utils.py
def top_k_top_p_filtering(logits,
top_k=0,
top_p=1.0,
filter_value=-float("Inf"),
min_tokens_to_keep=1):
def top_k_top_p_filtering(
logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1
):
"""Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
Args:
logits: logits distribution shape (batch size, vocabulary size)
@@ -53,16 +52,14 @@ def top_k_top_p_filtering(logits,
From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
"""
if top_k > 0:
top_k = min(max(top_k, min_tokens_to_keep),
logits.size(-1)) # Safety check
top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1)) # Safety check
# Remove all tokens with a probability less than the last token of the top-k
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
logits[indices_to_remove] = filter_value
if top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cumulative_probs = torch.cumsum(
F.softmax(sorted_logits, dim=-1), dim=-1)
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
# Remove tokens with cumulative probability above the threshold (token with 0 are kept)
sorted_indices_to_remove = cumulative_probs > top_p
@@ -70,13 +67,13 @@ def top_k_top_p_filtering(logits,
# Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
# Shift the indices to the right to keep also the first token above the threshold
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
..., :-1].clone()
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
# scatter sorted tensors to original indexing
indices_to_remove = sorted_indices_to_remove.scatter(
1, sorted_indices, sorted_indices_to_remove)
1, sorted_indices, sorted_indices_to_remove
)
logits[indices_to_remove] = filter_value
return logits
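A quick usage example of the filter above; the logits are modified in place, so a clone is passed (values illustrative):

import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])
filtered = top_k_top_p_filtering(logits.clone(), top_k=2, top_p=1.0)
probs = F.softmax(filtered, dim=-1)
print(probs)   # only the two largest logits keep any probability mass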
@@ -100,6 +97,8 @@ def topk_sampling(logits, top_k=10, top_p=1.0, temperature=1.0):
from typing import Optional, Tuple
def multinomial_sample_one_no_sync(
probs_sort,
): # Does multinomial sampling without a cuda synchronization
@@ -115,7 +114,7 @@ def logits_to_probs(
top_p: Optional[int] = None,
repetition_penalty: float = 1.0,
):
previous_tokens=previous_tokens.squeeze()
previous_tokens = previous_tokens.squeeze()
# print(logits.shape,previous_tokens.shape)
# pdb.set_trace()
if previous_tokens is not None and repetition_penalty != 1.0:
@@ -159,4 +158,3 @@ def sample(
)
idx_next = multinomial_sample_one_no_sync(probs)
return idx_next, probs

View File

@@ -13,7 +13,9 @@ from torch.nn.parameter import Parameter
from torch.nn import functional as F
from AR.modules.patched_mha_with_cache import multi_head_attention_forward_patched
F.multi_head_attention_forward=multi_head_attention_forward_patched
F.multi_head_attention_forward = multi_head_attention_forward_patched
class MultiheadAttention(Module):
r"""Allows the model to jointly attend to information
@@ -76,66 +78,71 @@ class MultiheadAttention(Module):
bias_v: Optional[torch.Tensor]
def __init__(
self,
embed_dim,
num_heads,
dropout=0.0,
bias=True,
add_bias_kv=False,
add_zero_attn=False,
kdim=None,
vdim=None,
batch_first=False,
linear1_cls=Linear,
linear2_cls=Linear,
device=None,
dtype=None, ) -> None:
self,
embed_dim,
num_heads,
dropout=0.0,
bias=True,
add_bias_kv=False,
add_zero_attn=False,
kdim=None,
vdim=None,
batch_first=False,
linear1_cls=Linear,
linear2_cls=Linear,
device=None,
dtype=None,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
super(MultiheadAttention, self).__init__()
self.embed_dim = embed_dim
self.kdim = kdim if kdim is not None else embed_dim
self.vdim = vdim if vdim is not None else embed_dim
self._qkv_same_embed_dim = (self.kdim == embed_dim and
self.vdim == embed_dim)
self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim
self.num_heads = num_heads
self.dropout = dropout
self.batch_first = batch_first
self.head_dim = embed_dim // num_heads
assert (self.head_dim * num_heads == self.embed_dim
), "embed_dim must be divisible by num_heads"
assert (
self.head_dim * num_heads == self.embed_dim
), "embed_dim must be divisible by num_heads"
if add_bias_kv:
self.bias_k = Parameter(
torch.empty((1, 1, embed_dim), **factory_kwargs))
self.bias_v = Parameter(
torch.empty((1, 1, embed_dim), **factory_kwargs))
self.bias_k = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
self.bias_v = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
else:
self.bias_k = self.bias_v = None
if linear1_cls == Linear:
if not self._qkv_same_embed_dim:
self.q_proj_weight = Parameter(
torch.empty((embed_dim, embed_dim), **factory_kwargs))
torch.empty((embed_dim, embed_dim), **factory_kwargs)
)
self.k_proj_weight = Parameter(
torch.empty((embed_dim, self.kdim), **factory_kwargs))
torch.empty((embed_dim, self.kdim), **factory_kwargs)
)
self.v_proj_weight = Parameter(
torch.empty((embed_dim, self.vdim), **factory_kwargs))
torch.empty((embed_dim, self.vdim), **factory_kwargs)
)
self.register_parameter("in_proj_weight", None)
else:
self.in_proj_weight = Parameter(
torch.empty((3 * embed_dim, embed_dim), **factory_kwargs))
torch.empty((3 * embed_dim, embed_dim), **factory_kwargs)
)
self.register_parameter("q_proj_weight", None)
self.register_parameter("k_proj_weight", None)
self.register_parameter("v_proj_weight", None)
if bias:
self.in_proj_bias = Parameter(
torch.empty(3 * embed_dim, **factory_kwargs))
torch.empty(3 * embed_dim, **factory_kwargs)
)
else:
self.register_parameter("in_proj_bias", None)
self.out_proj = NonDynamicallyQuantizableLinear(
embed_dim, embed_dim, bias=bias, **factory_kwargs)
embed_dim, embed_dim, bias=bias, **factory_kwargs
)
self._reset_parameters()
else:
@@ -143,7 +150,8 @@ class MultiheadAttention(Module):
raise NotImplementedError
else:
self.in_proj_linear = linear1_cls(
embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs)
embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs
)
self.in_proj_weight = self.in_proj_linear.weight
self.register_parameter("q_proj_weight", None)
@@ -156,7 +164,8 @@ class MultiheadAttention(Module):
self.register_parameter("in_proj_bias", None)
self.out_proj = linear2_cls(
embed_dim, embed_dim, bias=bias, **factory_kwargs)
embed_dim, embed_dim, bias=bias, **factory_kwargs
)
if self.bias_k is not None:
xavier_normal_(self.bias_k)
@@ -190,14 +199,15 @@ class MultiheadAttention(Module):
super(MultiheadAttention, self).__setstate__(state)
def forward(
self,
query: Tensor,
key: Tensor,
value: Tensor,
key_padding_mask: Optional[Tensor]=None,
need_weights: bool=True,
attn_mask: Optional[Tensor]=None,
average_attn_weights: bool=True,cache=None
self,
query: Tensor,
key: Tensor,
value: Tensor,
key_padding_mask: Optional[Tensor] = None,
need_weights: bool = True,
attn_mask: Optional[Tensor] = None,
average_attn_weights: bool = True,
cache=None,
) -> Tuple[Tensor, Optional[Tensor]]:
r"""
Args:
@@ -251,23 +261,26 @@ class MultiheadAttention(Module):
if key_padding_mask is not None:
_kpm_dtype = key_padding_mask.dtype
if _kpm_dtype != torch.bool and not torch.is_floating_point(
key_padding_mask):
key_padding_mask
):
raise AssertionError(
"only bool and floating types of key_padding_mask are supported"
)
why_not_fast_path = ""
if not is_batched:
why_not_fast_path = f"input not batched; expected query.dim() of 3 but got {query.dim()}"
why_not_fast_path = (
f"input not batched; expected query.dim() of 3 but got {query.dim()}"
)
elif query is not key or key is not value:
# When lifting this restriction, don't forget to either
# enforce that the dtypes all match or test cases where
# they don't!
why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)"
elif (self.in_proj_bias is not None and
query.dtype != self.in_proj_bias.dtype):
elif self.in_proj_bias is not None and query.dtype != self.in_proj_bias.dtype:
why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match"
elif (self.in_proj_weight is not None and
query.dtype != self.in_proj_weight.dtype):
elif (
self.in_proj_weight is not None and query.dtype != self.in_proj_weight.dtype
):
# this case will fail anyway, but at least they'll get a useful error message.
why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match"
elif self.training:
@@ -288,29 +301,41 @@ class MultiheadAttention(Module):
why_not_fast_path = "attn_mask was not None"
elif query.is_nested and key_padding_mask is not None:
why_not_fast_path = (
"key_padding_mask is not supported with NestedTensor input")
"key_padding_mask is not supported with NestedTensor input"
)
elif self.num_heads % 2 == 1:
why_not_fast_path = "num_heads is odd"
elif torch.is_autocast_enabled():
why_not_fast_path = "autocast is enabled"
if not why_not_fast_path:
tensor_args = (query, key, value, self.in_proj_weight,
self.in_proj_bias, self.out_proj.weight,
self.out_proj.bias, )
tensor_args = (
query,
key,
value,
self.in_proj_weight,
self.in_proj_bias,
self.out_proj.weight,
self.out_proj.bias,
)
# We have to use list comprehensions below because TorchScript does not support
# generator expressions.
if torch.overrides.has_torch_function(tensor_args):
why_not_fast_path = "some Tensor argument has_torch_function"
elif not all([(x is None or x.is_cuda or "cpu" in str(x.device))
for x in tensor_args]):
why_not_fast_path = (
"some Tensor argument is neither CUDA nor CPU")
elif not all(
[
(x is None or x.is_cuda or "cpu" in str(x.device))
for x in tensor_args
]
):
why_not_fast_path = "some Tensor argument is neither CUDA nor CPU"
elif torch.is_grad_enabled() and any(
[x is not None and x.requires_grad for x in tensor_args]):
[x is not None and x.requires_grad for x in tensor_args]
):
why_not_fast_path = (
"grad is enabled and at least one of query or the "
"input/output projection weights or biases requires_grad")
"input/output projection weights or biases requires_grad"
)
if not why_not_fast_path:
return torch._native_multi_head_attention(
query,
@@ -322,17 +347,21 @@ class MultiheadAttention(Module):
self.in_proj_bias,
self.out_proj.weight,
self.out_proj.bias,
key_padding_mask
if key_padding_mask is not None else attn_mask,
key_padding_mask if key_padding_mask is not None else attn_mask,
need_weights,
average_attn_weights,
1 if key_padding_mask is not None else 0
if attn_mask is not None else None, )
1
if key_padding_mask is not None
else 0
if attn_mask is not None
else None,
)
any_nested = query.is_nested or key.is_nested or value.is_nested
assert not any_nested, (
"MultiheadAttention does not support NestedTensor outside of its fast path. "
+ f"The fast path was not hit because {why_not_fast_path}")
+ f"The fast path was not hit because {why_not_fast_path}"
)
if self.batch_first and is_batched:
# make sure that the transpose op does not affect the "is" property
@@ -343,9 +372,7 @@ class MultiheadAttention(Module):
query, key = [x.transpose(1, 0) for x in (query, key)]
value = key
else:
query, key, value = [
x.transpose(1, 0) for x in (query, key, value)
]
query, key, value = [x.transpose(1, 0) for x in (query, key, value)]
if not self._qkv_same_embed_dim:
attn_output, attn_output_weights = F.multi_head_attention_forward(
@@ -370,7 +397,9 @@ class MultiheadAttention(Module):
q_proj_weight=self.q_proj_weight,
k_proj_weight=self.k_proj_weight,
v_proj_weight=self.v_proj_weight,
average_attn_weights=average_attn_weights,cache=cache )
average_attn_weights=average_attn_weights,
cache=cache,
)
else:
attn_output, attn_output_weights = F.multi_head_attention_forward(
query,
@@ -390,7 +419,9 @@ class MultiheadAttention(Module):
key_padding_mask=key_padding_mask,
need_weights=need_weights,
attn_mask=attn_mask,
average_attn_weights=average_attn_weights,cache=cache )
average_attn_weights=average_attn_weights,
cache=cache,
)
if self.batch_first and is_batched:
return attn_output.transpose(1, 0), attn_output_weights
else:

View File

@@ -7,10 +7,11 @@ from torch import nn
class TokenEmbedding(nn.Module):
def __init__(
self,
embedding_dim: int,
vocab_size: int,
dropout: float=0.0, ):
self,
embedding_dim: int,
vocab_size: int,
dropout: float = 0.0,
):
super().__init__()
self.vocab_size = vocab_size
@@ -24,7 +25,7 @@ class TokenEmbedding(nn.Module):
return self.word_embeddings.weight
def embedding(self, index: int) -> torch.Tensor:
return self.word_embeddings.weight[index:index + 1]
return self.word_embeddings.weight[index : index + 1]
def forward(self, x: torch.Tensor):
x = self.word_embeddings(x)
@@ -34,11 +35,12 @@ class TokenEmbedding(nn.Module):
class SinePositionalEmbedding(nn.Module):
def __init__(
self,
embedding_dim: int,
dropout: float=0.0,
scale: bool=False,
alpha: bool=False, ):
self,
embedding_dim: int,
dropout: float = 0.0,
scale: bool = False,
alpha: bool = False,
):
super().__init__()
self.embedding_dim = embedding_dim
self.x_scale = math.sqrt(embedding_dim) if scale else 1.0
@@ -59,13 +61,14 @@ class SinePositionalEmbedding(nn.Module):
pe = torch.zeros(x.size(1), self.embedding_dim)
if self.reverse:
position = torch.arange(
x.size(1) - 1, -1, -1.0, dtype=torch.float32).unsqueeze(1)
x.size(1) - 1, -1, -1.0, dtype=torch.float32
).unsqueeze(1)
else:
position = torch.arange(
0, x.size(1), dtype=torch.float32).unsqueeze(1)
position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
div_term = torch.exp(
torch.arange(0, self.embedding_dim, 2, dtype=torch.float32) *
-(math.log(10000.0) / self.embedding_dim))
torch.arange(0, self.embedding_dim, 2, dtype=torch.float32)
* -(math.log(10000.0) / self.embedding_dim)
)
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0)
@@ -74,5 +77,5 @@ class SinePositionalEmbedding(nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
self.extend_pe(x)
output = x.unsqueeze(-1) if x.ndim == 2 else x
output = output * self.x_scale + self.alpha * self.pe[:, :x.size(1)]
output = output * self.x_scale + self.alpha * self.pe[:, : x.size(1)]
return self.dropout(output)
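The table built by extend_pe above is the standard sinusoidal positional encoding; a compact standalone version for reference (assumes an even embedding_dim, illustrative only):

import math
import torch

def sine_positional_table(seq_len: int, embedding_dim: int) -> torch.Tensor:
    # pe[t, 2i]   = sin(t / 10000^(2i / d))
    # pe[t, 2i+1] = cos(t / 10000^(2i / d))
    position = torch.arange(0, seq_len, dtype=torch.float32).unsqueeze(1)
    div_term = torch.exp(
        torch.arange(0, embedding_dim, 2, dtype=torch.float32)
        * -(math.log(10000.0) / embedding_dim)
    )
    pe = torch.zeros(seq_len, embedding_dim)
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    return pe.unsqueeze(0)   # (1, seq_len, embedding_dim), matching self.pe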

View File

@@ -12,14 +12,16 @@ class WarmupCosineLRSchedule(torch.optim.lr_scheduler._LRScheduler):
Implements Warmup learning rate schedule until 'warmup_steps', going from 'init_lr' to 'peak_lr' for multiple optimizers.
"""
def __init__(self,
optimizer,
init_lr,
peak_lr,
end_lr,
warmup_steps=10000,
total_steps=400000,
current_step=0):
def __init__(
self,
optimizer,
init_lr,
peak_lr,
end_lr,
warmup_steps=10000,
total_steps=400000,
current_step=0,
):
self.init_lr = init_lr
self.peak_lr = peak_lr
self.end_lr = end_lr
@@ -33,10 +35,10 @@ class WarmupCosineLRSchedule(torch.optim.lr_scheduler._LRScheduler):
self._last_lr = [self.lr]
def set_lr(self, lr):
self._last_lr = [g['lr'] for g in self.optimizer.param_groups]
self._last_lr = [g["lr"] for g in self.optimizer.param_groups]
for g in self.optimizer.param_groups:
# g['lr'] = lr
g['lr'] = self.end_lr### locked: use the linear (constant) value
g["lr"] = self.end_lr ### locked: use the linear (constant) value
def step(self):
if self._current_step < self.warmup_steps:
@@ -47,7 +49,8 @@ class WarmupCosineLRSchedule(torch.optim.lr_scheduler._LRScheduler):
else:
decay_ratio = (self._current_step - self.warmup_steps) / (
self.total_steps - self.warmup_steps)
self.total_steps - self.warmup_steps
)
if decay_ratio < 0.0 or decay_ratio > 1.0:
raise RuntimeError(
"Decay ratio must be in [0.0, 1.0]. Fix LR scheduler settings."
@@ -55,25 +58,19 @@ class WarmupCosineLRSchedule(torch.optim.lr_scheduler._LRScheduler):
coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
lr = self.end_lr + coeff * (self.peak_lr - self.end_lr)
self.lr=lr=self.end_lr=0.002### locked: use linear ### it would not behave, so just lock it!
self.lr = lr = self.end_lr = 0.002 ### locked: use linear ### it would not behave, so just lock it!
self.set_lr(lr)
self.lr = lr
self._current_step += 1
return self.lr
if __name__ == '__main__':
if __name__ == "__main__":
m = nn.Linear(10, 10)
opt = Adam(m.parameters(), lr=1e-4)
s = WarmupCosineLRSchedule(
opt,
1e-6,
2e-4,
1e-6,
warmup_steps=2000,
total_steps=20000,
current_step=0)
opt, 1e-6, 2e-4, 1e-6, warmup_steps=2000, total_steps=20000, current_step=0
)
lrs = []
for i in range(25000):
s.step()

View File

@@ -1,9 +1,16 @@
from torch.nn.functional import *
from torch.nn.functional import _mha_shape_check,_canonical_mask,_none_or_dtype,_in_projection_packed
from torch.nn.functional import (
_mha_shape_check,
_canonical_mask,
_none_or_dtype,
_in_projection_packed,
)
# import torch
# Tensor = torch.Tensor
# from typing import Callable, List, Optional, Tuple, Union
def multi_head_attention_forward_patched(
query: Tensor,
key: Tensor,
@@ -29,7 +36,8 @@ def multi_head_attention_forward_patched(
static_k: Optional[Tensor] = None,
static_v: Optional[Tensor] = None,
average_attn_weights: bool = True,
is_causal: bool = False,cache=None
is_causal: bool = False,
cache=None,
) -> Tuple[Tensor, Optional[Tensor]]:
r"""
Args:
@@ -105,7 +113,17 @@ def multi_head_attention_forward_patched(
:math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per
head of shape :math:`(num_heads, L, S)` when input is unbatched or :math:`(N, num_heads, L, S)`.
"""
tens_ops = (query, key, value, in_proj_weight, in_proj_bias, bias_k, bias_v, out_proj_weight, out_proj_bias)
tens_ops = (
query,
key,
value,
in_proj_weight,
in_proj_bias,
bias_k,
bias_v,
out_proj_weight,
out_proj_bias,
)
if has_torch_function(tens_ops):
return handle_torch_function(
multi_head_attention_forward,
@@ -134,10 +152,13 @@ def multi_head_attention_forward_patched(
v_proj_weight=v_proj_weight,
static_k=static_k,
static_v=static_v,
average_attn_weights=average_attn_weights,cache=cache
average_attn_weights=average_attn_weights,
cache=cache,
)
is_batched = _mha_shape_check(query, key, value, key_padding_mask, attn_mask, num_heads)
is_batched = _mha_shape_check(
query, key, value, key_padding_mask, attn_mask, num_heads
)
# For unbatched input, we unsqueeze at the expected batch-dim to pretend that the input
# is batched, run the computation and before returning squeeze the
@@ -159,7 +180,7 @@ def multi_head_attention_forward_patched(
mask_name="key_padding_mask",
other_type=_none_or_dtype(attn_mask),
other_name="attn_mask",
target_type=query.dtype
target_type=query.dtype,
)
if is_causal and attn_mask is None:
@@ -184,59 +205,84 @@ def multi_head_attention_forward_patched(
check_other=False,
)
if key_padding_mask is not None:
# We have the attn_mask, and use that to merge kpm into it.
# Turn off use of is_causal hint, as the merged mask is no
# longer causal.
is_causal = False
assert embed_dim == embed_dim_to_check, \
f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}"
assert (
embed_dim == embed_dim_to_check
), f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}"
if isinstance(embed_dim, torch.Tensor):
# embed_dim can be a tensor when JIT tracing
head_dim = embed_dim.div(num_heads, rounding_mode='trunc')
head_dim = embed_dim.div(num_heads, rounding_mode="trunc")
else:
head_dim = embed_dim // num_heads
assert head_dim * num_heads == embed_dim, f"embed_dim {embed_dim} not divisible by num_heads {num_heads}"
assert (
head_dim * num_heads == embed_dim
), f"embed_dim {embed_dim} not divisible by num_heads {num_heads}"
if use_separate_proj_weight:
# allow MHA to have different embedding dimensions when separate projection weights are used
assert key.shape[:2] == value.shape[:2], \
f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}"
assert (
key.shape[:2] == value.shape[:2]
), f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}"
else:
assert key.shape == value.shape, f"key shape {key.shape} does not match value shape {value.shape}"
assert (
key.shape == value.shape
), f"key shape {key.shape} does not match value shape {value.shape}"
#
# compute in-projection
#
if not use_separate_proj_weight:
assert in_proj_weight is not None, "use_separate_proj_weight is False but in_proj_weight is None"
assert (
in_proj_weight is not None
), "use_separate_proj_weight is False but in_proj_weight is None"
q, k, v = _in_projection_packed(query, key, value, in_proj_weight, in_proj_bias)
else:
assert q_proj_weight is not None, "use_separate_proj_weight is True but q_proj_weight is None"
assert k_proj_weight is not None, "use_separate_proj_weight is True but k_proj_weight is None"
assert v_proj_weight is not None, "use_separate_proj_weight is True but v_proj_weight is None"
assert (
q_proj_weight is not None
), "use_separate_proj_weight is True but q_proj_weight is None"
assert (
k_proj_weight is not None
), "use_separate_proj_weight is True but k_proj_weight is None"
assert (
v_proj_weight is not None
), "use_separate_proj_weight is True but v_proj_weight is None"
if in_proj_bias is None:
b_q = b_k = b_v = None
else:
b_q, b_k, b_v = in_proj_bias.chunk(3)
q, k, v = _in_projection(query, key, value, q_proj_weight, k_proj_weight, v_proj_weight, b_q, b_k, b_v)
if(cache!=None):
if(cache["first_infer"]==1):
cache["k"][cache["stage"]]=k
q, k, v = _in_projection(
query,
key,
value,
q_proj_weight,
k_proj_weight,
v_proj_weight,
b_q,
b_k,
b_v,
)
if cache != None:
if cache["first_infer"] == 1:
cache["k"][cache["stage"]] = k
# print(0,cache["k"].shape)
cache["v"][cache["stage"]]=v
else:### each of the 12 layers keeps its own cache_kv
cache["v"][cache["stage"]] = v
else: ### each of the 12 layers keeps its own cache_kv
# print(1,cache["k"].shape)
cache["k"][cache["stage"]]=torch.cat([cache["k"][cache["stage"]],k],0)## the time axis was originally dim 1, but the projection may have transposed it, so time is now dim 0
cache["v"][cache["stage"]]=torch.cat([cache["v"][cache["stage"]],v],0)
cache["k"][cache["stage"]] = torch.cat(
[cache["k"][cache["stage"]], k], 0
) ## the time axis was originally dim 1, but the projection may have transposed it, so time is now dim 0
cache["v"][cache["stage"]] = torch.cat([cache["v"][cache["stage"]], v], 0)
# print(2, cache["k"].shape)
src_len = cache["k"][cache["stage"]].shape[0]
k=cache["k"][cache["stage"]]
v=cache["v"][cache["stage"]]
k = cache["k"][cache["stage"]]
v = cache["v"][cache["stage"]]
# if attn_mask is not None:
# attn_mask=attn_mask[-1:,]
# print(attn_mask.shape,attn_mask)
# print(attn_mask.shape,attn_mask)
cache["stage"] = (cache["stage"] + 1) % cache["all_stage"]
# print(2333,cache)
# prep attention mask
@@ -255,14 +301,20 @@ def multi_head_attention_forward_patched(
if attn_mask.dim() == 2:
correct_2d_size = (tgt_len, src_len)
if attn_mask.shape != correct_2d_size:
raise RuntimeError(f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}.")
raise RuntimeError(
f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}."
)
attn_mask = attn_mask.unsqueeze(0)
elif attn_mask.dim() == 3:
correct_3d_size = (bsz * num_heads, tgt_len, src_len)
if attn_mask.shape != correct_3d_size:
raise RuntimeError(f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}.")
raise RuntimeError(
f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}."
)
else:
raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported")
raise RuntimeError(
f"attn_mask's dimension {attn_mask.dim()} is not supported"
)
# add bias along batch dimension (currently second)
if bias_k is not None and bias_v is not None:
@@ -286,26 +338,34 @@ def multi_head_attention_forward_patched(
k = k.view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
else:
# TODO finish disentangling control flow so we don't do in-projections when statics are passed
assert static_k.size(0) == bsz * num_heads, \
f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.size(0)}"
assert static_k.size(2) == head_dim, \
f"expecting static_k.size(2) of {head_dim}, but got {static_k.size(2)}"
assert (
static_k.size(0) == bsz * num_heads
), f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.size(0)}"
assert (
static_k.size(2) == head_dim
), f"expecting static_k.size(2) of {head_dim}, but got {static_k.size(2)}"
k = static_k
if static_v is None:
v = v.view(v.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
else:
# TODO finish disentangling control flow so we don't do in-projections when statics are passed
assert static_v.size(0) == bsz * num_heads, \
f"expecting static_v.size(0) of {bsz * num_heads}, but got {static_v.size(0)}"
assert static_v.size(2) == head_dim, \
f"expecting static_v.size(2) of {head_dim}, but got {static_v.size(2)}"
assert (
static_v.size(0) == bsz * num_heads
), f"expecting static_v.size(0) of {bsz * num_heads}, but got {static_v.size(0)}"
assert (
static_v.size(2) == head_dim
), f"expecting static_v.size(2) of {head_dim}, but got {static_v.size(2)}"
v = static_v
# add zero attention along batch dimension (now first)
if add_zero_attn:
zero_attn_shape = (bsz * num_heads, 1, head_dim)
k = torch.cat([k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=1)
v = torch.cat([v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=1)
k = torch.cat(
[k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=1
)
v = torch.cat(
[v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=1
)
if attn_mask is not None:
attn_mask = pad(attn_mask, (0, 1))
if key_padding_mask is not None:
@@ -316,10 +376,15 @@ def multi_head_attention_forward_patched(
# merge key padding and attention masks
if key_padding_mask is not None:
assert key_padding_mask.shape == (bsz, src_len), \
f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}"
key_padding_mask = key_padding_mask.view(bsz, 1, 1, src_len). \
expand(-1, num_heads, -1, -1).reshape(bsz * num_heads, 1, src_len)
assert key_padding_mask.shape == (
bsz,
src_len,
), f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}"
key_padding_mask = (
key_padding_mask.view(bsz, 1, 1, src_len)
.expand(-1, num_heads, -1, -1)
.reshape(bsz * num_heads, 1, src_len)
)
if attn_mask is None:
attn_mask = key_padding_mask
else:
@@ -337,10 +402,14 @@ def multi_head_attention_forward_patched(
B, Nt, E = q.shape
q_scaled = q / math.sqrt(E)
assert not (is_causal and attn_mask is None), "FIXME: is_causal not implemented for need_weights"
assert not (
is_causal and attn_mask is None
), "FIXME: is_causal not implemented for need_weights"
if attn_mask is not None:
attn_output_weights = torch.baddbmm(attn_mask, q_scaled, k.transpose(-2, -1))
attn_output_weights = torch.baddbmm(
attn_mask, q_scaled, k.transpose(-2, -1)
)
else:
attn_output_weights = torch.bmm(q_scaled, k.transpose(-2, -1))
attn_output_weights = softmax(attn_output_weights, dim=-1)
@@ -349,7 +418,9 @@ def multi_head_attention_forward_patched(
attn_output = torch.bmm(attn_output_weights, v)
attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim)
attn_output = (
attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim)
)
attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
@@ -377,8 +448,12 @@ def multi_head_attention_forward_patched(
k = k.view(bsz, num_heads, src_len, head_dim)
v = v.view(bsz, num_heads, src_len, head_dim)
attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)
attn_output = attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim)
attn_output = scaled_dot_product_attention(
q, k, v, attn_mask, dropout_p, is_causal
)
attn_output = (
attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim)
)
attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
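
The cache handling earlier in this file assumes a per-layer dictionary of rolling key/value tensors. A minimal sketch of how such a cache might be initialized and advanced, assuming the keys seen in the code ("first_infer", "stage", "all_stage", "k", "v") and hypothetical tensor sizes; the real calling code manages these fields itself:

import torch

num_layers = 12                       # "all_stage": one k/v slot per layer (assumption)
seq_len, bsz, embed_dim = 1, 1, 512   # illustrative sizes only

cache = {
    "all_stage": num_layers,
    "stage": 0,        # which layer's slot to use next
    "first_infer": 1,  # 1 on the first decoding step, 0 afterwards
    "k": [None] * num_layers,
    "v": [None] * num_layers,
}

def step(cache, k, v):
    # Store or append k/v for the current layer, then advance the stage.
    i = cache["stage"]
    if cache["first_infer"] == 1:
        cache["k"][i] = k
        cache["v"][i] = v
    else:
        # the time axis is dim 0 after the in-projection
        cache["k"][i] = torch.cat([cache["k"][i], k], 0)
        cache["v"][i] = torch.cat([cache["v"][i], v], 0)
    cache["stage"] = (cache["stage"] + 1) % cache["all_stage"]
    return cache["k"][i], cache["v"][i]

k = torch.zeros(seq_len, bsz, embed_dim)
v = torch.zeros(seq_len, bsz, embed_dim)
step(cache, k, v)            # first step stores
cache["first_infer"] = 0
cache["stage"] = 0
step(cache, k, v)            # later steps append along dim 0
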

View File

@@ -61,8 +61,9 @@ class DoubleSwishFunction(torch.autograd.Function):
# floors), should be expectation-preserving.
floor = -0.043637
ceil = 1.2
d_scaled = (deriv - floor) * (255.0 / (ceil - floor)
) + torch.rand_like(deriv)
d_scaled = (deriv - floor) * (255.0 / (ceil - floor)) + torch.rand_like(
deriv
)
if __name__ == "__main__":
# for self-testing only.
assert d_scaled.min() >= 0.0
@@ -75,7 +76,7 @@ class DoubleSwishFunction(torch.autograd.Function):
@staticmethod
def backward(ctx, y_grad: Tensor) -> Tensor:
(d, ) = ctx.saved_tensors
(d,) = ctx.saved_tensors
# the same constants as used in forward pass.
floor = -0.043637
ceil = 1.2
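
These constants come from quantizing the activation's derivative to 8 bits: the derivative of double_swish(x) = x * sigmoid(x - 1) lies within roughly [-0.043637, 1.2], so it can be affinely mapped to [0, 255], dithered, stored as uint8, and mapped back in backward. A standalone numeric sketch of that round trip (not the autograd.Function itself):

import torch

def double_swish(x):
    # DoubleSwish as used here: x * sigmoid(x - 1)
    return x * torch.sigmoid(x - 1.0)

x = torch.linspace(-6.0, 6.0, 101, requires_grad=True)
y = double_swish(x)
(deriv,) = torch.autograd.grad(y.sum(), x)

floor, ceil = -0.043637, 1.2
# forward: affine map to [0, 255], dither, store as uint8
d_scaled = (deriv - floor) * (255.0 / (ceil - floor)) + torch.rand_like(deriv)
d_int = d_scaled.clamp(0, 255).to(torch.uint8)

# backward: undo the affine map; the error is below one quantization step
d_restored = d_int.to(torch.float32) * ((ceil - floor) / 255.0) + floor
assert (d_restored - deriv).abs().max() < (ceil - floor) / 255.0 + 1e-6
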
@@ -96,11 +97,12 @@ class DoubleSwish(torch.nn.Module):
class ActivationBalancerFunction(torch.autograd.Function):
@staticmethod
def forward(
ctx,
x: Tensor,
scale_factor: Tensor,
sign_factor: Optional[Tensor],
channel_dim: int, ) -> Tensor:
ctx,
x: Tensor,
scale_factor: Tensor,
sign_factor: Optional[Tensor],
channel_dim: int,
) -> Tensor:
if channel_dim < 0:
channel_dim += x.ndim
ctx.channel_dim = channel_dim
@@ -125,16 +127,22 @@ class ActivationBalancerFunction(torch.autograd.Function):
scale_factor = scale_factor.unsqueeze(-1)
factor = scale_factor * (xgt0.to(x_grad.dtype) - 0.5)
neg_delta_grad = x_grad.abs() * factor
return (x_grad - neg_delta_grad, None, None, None, )
return (
x_grad - neg_delta_grad,
None,
None,
None,
)
def _compute_scale_factor(
x: Tensor,
channel_dim: int,
min_abs: float,
max_abs: float,
gain_factor: float,
max_factor: float, ) -> Tensor:
x: Tensor,
channel_dim: int,
min_abs: float,
max_abs: float,
gain_factor: float,
max_factor: float,
) -> Tensor:
if channel_dim < 0:
channel_dim += x.ndim
sum_dims = [d for d in range(x.ndim) if d != channel_dim]
@@ -145,23 +153,25 @@ def _compute_scale_factor(
else:
# below_threshold is 0 if x_abs_mean > min_abs, can be at most max_factor if
# x_abs_mean < min_abs.
below_threshold = (
(min_abs - x_abs_mean) * (gain_factor / min_abs)).clamp(
min=0, max=max_factor)
below_threshold = ((min_abs - x_abs_mean) * (gain_factor / min_abs)).clamp(
min=0, max=max_factor
)
above_threshold = ((x_abs_mean - max_abs) * (gain_factor / max_abs)).clamp(
min=0, max=max_factor)
min=0, max=max_factor
)
return below_threshold - above_threshold
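
The scale factor is a per-channel nudge: channels whose mean absolute value falls below min_abs get a positive factor (pushing their magnitude up via the gradient), channels above max_abs get a negative one, both capped at max_factor. A small standalone sketch of the same arithmetic on made-up per-channel statistics:

import torch

min_abs, max_abs = 0.2, 100.0
gain_factor, max_factor = 0.02, 0.04

# pretend per-channel mean absolute values
x_abs_mean = torch.tensor([0.05, 0.2, 1.0, 150.0])

below = ((min_abs - x_abs_mean) * (gain_factor / min_abs)).clamp(min=0, max=max_factor)
above = ((x_abs_mean - max_abs) * (gain_factor / max_abs)).clamp(min=0, max=max_factor)
scale_factor = below - above

# channel 0 is too small -> +0.015; channel 3 is too large -> -0.01; the rest are 0
print(scale_factor)
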
def _compute_sign_factor(
x: Tensor,
channel_dim: int,
min_positive: float,
max_positive: float,
gain_factor: float,
max_factor: float, ) -> Tensor:
x: Tensor,
channel_dim: int,
min_positive: float,
max_positive: float,
gain_factor: float,
max_factor: float,
) -> Tensor:
if channel_dim < 0:
channel_dim += x.ndim
sum_dims = [d for d in range(x.ndim) if d != channel_dim]
@@ -171,18 +181,18 @@ def _compute_sign_factor(
else:
# 0 if proportion_positive >= min_positive, else can be
# as large as max_factor.
factor1 = ((min_positive - proportion_positive) *
(gain_factor / min_positive)).clamp_(
min=0, max=max_factor)
factor1 = (
(min_positive - proportion_positive) * (gain_factor / min_positive)
).clamp_(min=0, max=max_factor)
if max_positive == 1.0:
factor2 = 0.0
else:
# 0 if self.proportion_positive <= max_positive, else can be
# as large as -max_factor.
factor2 = ((proportion_positive - max_positive) *
(gain_factor / (1.0 - max_positive))).clamp_(
min=0, max=max_factor)
factor2 = (
(proportion_positive - max_positive) * (gain_factor / (1.0 - max_positive))
).clamp_(min=0, max=max_factor)
sign_factor = factor1 - factor2
# require min_positive != 0 or max_positive != 1:
assert not isinstance(sign_factor, float)
@@ -230,17 +240,18 @@ class ActivationBalancer(torch.nn.Module):
"""
def __init__(
self,
num_channels: int,
channel_dim: int,
min_positive: float=0.05,
max_positive: float=0.95,
max_factor: float=0.04,
sign_gain_factor: float=0.01,
scale_gain_factor: float=0.02,
min_abs: float=0.2,
max_abs: float=100.0,
min_prob: float=0.1, ):
self,
num_channels: int,
channel_dim: int,
min_positive: float = 0.05,
max_positive: float = 0.95,
max_factor: float = 0.04,
sign_gain_factor: float = 0.01,
scale_gain_factor: float = 0.02,
min_abs: float = 0.2,
max_abs: float = 100.0,
min_prob: float = 0.1,
):
super(ActivationBalancer, self).__init__()
self.num_channels = num_channels
self.channel_dim = channel_dim
@@ -260,8 +271,7 @@ class ActivationBalancer(torch.nn.Module):
self.register_buffer("count", torch.tensor(0, dtype=torch.int64))
def forward(self, x: Tensor) -> Tensor:
if (torch.jit.is_scripting() or not x.requires_grad or
torch.jit.is_tracing()):
if torch.jit.is_scripting() or not x.requires_grad or torch.jit.is_tracing():
return _no_op(x)
count = self.cpu_count
@@ -276,7 +286,7 @@ class ActivationBalancer(torch.nn.Module):
# the prob of doing some work exponentially decreases from 0.5 till it hits
# a floor at min_prob (==0.1, by default)
prob = max(self.min_prob, 0.5**(1 + (count / 4000.0)))
prob = max(self.min_prob, 0.5 ** (1 + (count / 4000.0)))
if random.random() < prob:
sign_gain_factor = 0.5
@@ -287,7 +297,8 @@ class ActivationBalancer(torch.nn.Module):
self.min_positive,
self.max_positive,
gain_factor=self.sign_gain_factor / prob,
max_factor=self.max_factor, )
max_factor=self.max_factor,
)
else:
sign_factor = None
@@ -297,23 +308,28 @@ class ActivationBalancer(torch.nn.Module):
min_abs=self.min_abs,
max_abs=self.max_abs,
gain_factor=self.scale_gain_factor / prob,
max_factor=self.max_factor, )
max_factor=self.max_factor,
)
return ActivationBalancerFunction.apply(
x,
scale_factor,
sign_factor,
self.channel_dim, )
self.channel_dim,
)
else:
return _no_op(x)
def BalancedDoubleSwish(d_model, channel_dim=-1, max_abs=10.0,
min_prob=0.25) -> nn.Sequential:
def BalancedDoubleSwish(
d_model, channel_dim=-1, max_abs=10.0, min_prob=0.25
) -> nn.Sequential:
"""
ActivationBalancer -> DoubleSwish
"""
balancer = ActivationBalancer(
d_model, channel_dim=channel_dim, max_abs=max_abs, min_prob=min_prob)
d_model, channel_dim=channel_dim, max_abs=max_abs, min_prob=min_prob
)
return nn.Sequential(
balancer,
DoubleSwish(), )
DoubleSwish(),
)
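
A minimal usage sketch of the resulting module, assuming BalancedDoubleSwish from this file is in scope; both submodules are shape-preserving:

import torch

act = BalancedDoubleSwish(d_model=512, channel_dim=-1, max_abs=10.0, min_prob=0.25)
x = torch.randn(8, 100, 512, requires_grad=True)  # (batch, time, channels), illustrative
y = act(x)                                        # ActivationBalancer -> DoubleSwish
assert y.shape == x.shape
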

View File

@@ -26,26 +26,28 @@ class LayerNorm(nn.Module):
elementwise_affine: bool
def __init__(
self,
normalized_shape: _shape_t,
eps: float=1e-5,
elementwise_affine: bool=True,
device=None,
dtype=None, ) -> None:
self,
normalized_shape: _shape_t,
eps: float = 1e-5,
elementwise_affine: bool = True,
device=None,
dtype=None,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
super(LayerNorm, self).__init__()
if isinstance(normalized_shape, numbers.Integral):
# mypy error: incompatible types in assignment
normalized_shape = (normalized_shape, ) # type: ignore[assignment]
self.normalized_shape = tuple(
normalized_shape) # type: ignore[arg-type]
normalized_shape = (normalized_shape,) # type: ignore[assignment]
self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type]
self.eps = eps
self.elementwise_affine = elementwise_affine
if self.elementwise_affine:
self.weight = nn.Parameter(
torch.empty(self.normalized_shape, **factory_kwargs))
torch.empty(self.normalized_shape, **factory_kwargs)
)
self.bias = nn.Parameter(
torch.empty(self.normalized_shape, **factory_kwargs))
torch.empty(self.normalized_shape, **factory_kwargs)
)
else:
self.register_parameter("weight", None)
self.register_parameter("bias", None)
@@ -57,36 +59,43 @@ class LayerNorm(nn.Module):
nn.init.ones_(self.weight)
nn.init.zeros_(self.bias)
def forward(self, input: Tensor, embedding: Any=None) -> Tensor:
def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
if isinstance(input, tuple):
input, embedding = input
return (F.layer_norm(
input,
self.normalized_shape,
self.weight,
self.bias,
self.eps, ), embedding, )
return (
F.layer_norm(
input,
self.normalized_shape,
self.weight,
self.bias,
self.eps,
),
embedding,
)
assert embedding is None
return F.layer_norm(input, self.normalized_shape, self.weight,
self.bias, self.eps)
return F.layer_norm(
input, self.normalized_shape, self.weight, self.bias, self.eps
)
def extra_repr(self) -> str:
return (
"{normalized_shape}, eps={eps}, "
"elementwise_affine={elementwise_affine}".format(**self.__dict__))
"elementwise_affine={elementwise_affine}".format(**self.__dict__)
)
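
This forward accepts either a plain tensor or an (input, embedding) tuple and mirrors that in its return type, which is what lets it be chained with AdaptiveLayerNorm below. A quick sketch of both call patterns, using this file's LayerNorm (not torch.nn.LayerNorm):

import torch

norm = LayerNorm(512)
x = torch.randn(10, 2, 512)

out = norm(x)                 # plain tensor in, plain tensor out
assert out.shape == x.shape

emb = torch.randn(2, 512)
out, emb_out = norm((x, emb)) # tuple in, (normalized, embedding) tuple out
assert out.shape == x.shape and emb_out is emb
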
class IdentityNorm(nn.Module):
def __init__(
self,
d_model: int,
eps: float=1e-5,
device=None,
dtype=None, ) -> None:
self,
d_model: int,
eps: float = 1e-5,
device=None,
dtype=None,
) -> None:
super(IdentityNorm, self).__init__()
def forward(self, input: Tensor, embedding: Any=None) -> Tensor:
def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
if isinstance(input, tuple):
return input
@@ -121,11 +130,13 @@ class TransformerEncoder(nn.Module):
self.norm = norm
def forward(
self,
src: Tensor,
mask: Optional[Tensor]=None,
src_key_padding_mask: Optional[Tensor]=None,
return_layer_states: bool=False,cache=None ) -> Tensor:
self,
src: Tensor,
mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
return_layer_states: bool = False,
cache=None,
) -> Tensor:
r"""Pass the input through the encoder layers in turn.
Args:
@@ -144,7 +155,9 @@ class TransformerEncoder(nn.Module):
output = mod(
output,
src_mask=mask,
src_key_padding_mask=src_key_padding_mask, cache=cache)
src_key_padding_mask=src_key_padding_mask,
cache=cache,
)
layer_states.append(output[0])
if self.norm is not None:
@@ -154,9 +167,12 @@ class TransformerEncoder(nn.Module):
output = src
for mod in self.layers:
output = mod(output,
src_mask=mask,
src_key_padding_mask=src_key_padding_mask, cache=cache)
output = mod(
output,
src_mask=mask,
src_key_padding_mask=src_key_padding_mask,
cache=cache,
)
if self.norm is not None:
output = self.norm(output)
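
The only signature change in this refactor is threading a cache kwarg through TransformerEncoder.forward, the layer's forward, and _sa_block. A hedged usage sketch, assuming TransformerEncoder takes (encoder_layer, num_layers, norm=None) like its torch.nn counterpart (that constructor is not shown in this hunk):

import torch

layer = TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True)
encoder = TransformerEncoder(layer, num_layers=6)  # constructor shape is an assumption

src = torch.randn(2, 100, 512)  # (batch, time, channels) with batch_first=True
out = encoder(src, mask=None, src_key_padding_mask=None, cache=None)  # cache=None -> no KV caching
assert out.shape == src.shape
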
@@ -168,43 +184,47 @@ class TransformerEncoderLayer(nn.Module):
__constants__ = ["batch_first", "norm_first"]
def __init__(
self,
d_model: int,
nhead: int,
dim_feedforward: int=2048,
dropout: float=0.1,
activation: Union[str, Callable[[Tensor], Tensor]]=F.relu,
batch_first: bool=False,
norm_first: bool=False,
device=None,
dtype=None,
linear1_self_attention_cls: nn.Module=nn.Linear,
linear2_self_attention_cls: nn.Module=nn.Linear,
linear1_feedforward_cls: nn.Module=nn.Linear,
linear2_feedforward_cls: nn.Module=nn.Linear,
layer_norm_cls: nn.Module=LayerNorm,
layer_norm_eps: float=1e-5,
adaptive_layer_norm=False, ) -> None:
self,
d_model: int,
nhead: int,
dim_feedforward: int = 2048,
dropout: float = 0.1,
activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
batch_first: bool = False,
norm_first: bool = False,
device=None,
dtype=None,
linear1_self_attention_cls: nn.Module = nn.Linear,
linear2_self_attention_cls: nn.Module = nn.Linear,
linear1_feedforward_cls: nn.Module = nn.Linear,
linear2_feedforward_cls: nn.Module = nn.Linear,
layer_norm_cls: nn.Module = LayerNorm,
layer_norm_eps: float = 1e-5,
adaptive_layer_norm=False,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
super(TransformerEncoderLayer, self).__init__()
# print(233333333333,d_model,nhead)
# import os
# os._exit(2333333)
self.self_attn = MultiheadAttention(
d_model,#512 16
d_model, # 512 16
nhead,
dropout=dropout,
batch_first=batch_first,
linear1_cls=linear1_self_attention_cls,
linear2_cls=linear2_self_attention_cls,
**factory_kwargs, )
**factory_kwargs,
)
# Implementation of Feedforward model
self.linear1 = linear1_feedforward_cls(d_model, dim_feedforward,
**factory_kwargs)
self.linear1 = linear1_feedforward_cls(
d_model, dim_feedforward, **factory_kwargs
)
self.dropout = nn.Dropout(dropout)
self.linear2 = linear2_feedforward_cls(dim_feedforward, d_model,
**factory_kwargs)
self.linear2 = linear2_feedforward_cls(
dim_feedforward, d_model, **factory_kwargs
)
self.norm_first = norm_first
self.dropout1 = nn.Dropout(dropout)
@@ -230,11 +250,9 @@ class TransformerEncoderLayer(nn.Module):
norm1 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs)
if layer_norm_cls == IdentityNorm:
norm2 = BalancedBasicNorm(
d_model, eps=layer_norm_eps, **factory_kwargs)
norm2 = BalancedBasicNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
else:
norm2 = layer_norm_cls(
d_model, eps=layer_norm_eps, **factory_kwargs)
norm2 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs)
if adaptive_layer_norm:
self.norm1 = AdaptiveLayerNorm(d_model, norm1)
@@ -249,10 +267,12 @@ class TransformerEncoderLayer(nn.Module):
self.activation = F.relu
def forward(
self,
src: Tensor,
src_mask: Optional[Tensor]=None,
src_key_padding_mask: Optional[Tensor]=None,cache=None ) -> Tensor:
self,
src: Tensor,
src_mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
cache=None,
) -> Tensor:
r"""Pass the input through the encoder layer.
Args:
@@ -272,7 +292,8 @@ class TransformerEncoderLayer(nn.Module):
if src_key_padding_mask is not None:
_skpm_dtype = src_key_padding_mask.dtype
if _skpm_dtype != torch.bool and not torch.is_floating_point(
src_key_padding_mask):
src_key_padding_mask
):
raise AssertionError(
"only bool and floating types of key_padding_mask are supported"
)
@@ -281,12 +302,15 @@ class TransformerEncoderLayer(nn.Module):
x = x + self._sa_block(
self.norm1(x, stage_embedding),
src_mask,
src_key_padding_mask,cache=cache )
src_key_padding_mask,
cache=cache,
)
x = x + self._ff_block(self.norm2(x, stage_embedding))
else:
x = self.norm1(
x + self._sa_block(x, src_mask, src_key_padding_mask,cache=cache),
stage_embedding, )
x + self._sa_block(x, src_mask, src_key_padding_mask, cache=cache),
stage_embedding,
)
x = self.norm2(x + self._ff_block(x), stage_embedding)
if is_src_tuple:
@@ -295,12 +319,14 @@ class TransformerEncoderLayer(nn.Module):
# self-attention block
def _sa_block(
self,
x: Tensor,
attn_mask: Optional[Tensor],
key_padding_mask: Optional[Tensor],cache=None ) -> Tensor:
self,
x: Tensor,
attn_mask: Optional[Tensor],
key_padding_mask: Optional[Tensor],
cache=None,
) -> Tensor:
# print(x.shape,attn_mask.shape,key_padding_mask)
#torch.Size([1, 188, 512]) torch.Size([188, 188]) None
# torch.Size([1, 188, 512]) torch.Size([188, 188]) None
# import os
# os._exit(23333)
x = self.self_attn(
@@ -309,7 +335,9 @@ class TransformerEncoderLayer(nn.Module):
x,
attn_mask=attn_mask,
key_padding_mask=key_padding_mask,
need_weights=False,cache=cache )[0]
need_weights=False,
cache=cache,
)[0]
return self.dropout1(x)
# feed forward block
@@ -328,20 +356,23 @@ class AdaptiveLayerNorm(nn.Module):
self.d_model = d_model
self.eps = self.norm.eps
def forward(self, input: Tensor, embedding: Tensor=None) -> Tensor:
def forward(self, input: Tensor, embedding: Tensor = None) -> Tensor:
if isinstance(input, tuple):
input, embedding = input
weight, bias = torch.split(
self.project_layer(embedding),
split_size_or_sections=self.d_model,
dim=-1, )
dim=-1,
)
return (weight * self.norm(input) + bias, embedding)
weight, bias = torch.split(
self.project_layer(embedding),
split_size_or_sections=self.d_model,
dim=-1, )
dim=-1,
)
return weight * self.norm(input) + bias
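
AdaptiveLayerNorm conditions the normalization on an embedding: project_layer (presumably a Linear mapping d_model -> 2*d_model, not shown in this hunk) produces a per-sample weight and bias that modulate the wrapped norm's output. A hedged sketch of the data flow:

import torch
import torch.nn as nn

d_model = 512
project_layer = nn.Linear(d_model, 2 * d_model)  # assumption: emits concatenated (weight, bias)
norm = nn.LayerNorm(d_model)

x = torch.randn(10, 2, d_model)  # (time, batch, channels), illustrative
emb = torch.randn(2, d_model)

weight, bias = torch.split(
    project_layer(emb), split_size_or_sections=d_model, dim=-1
)
out = weight * norm(x) + bias    # weight/bias broadcast over the time dimension
assert out.shape == x.shape
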
def _get_clones(module, N):
return nn.ModuleList([copy.deepcopy(module) for i in range(N)])

View File

@@ -27,46 +27,44 @@ class GruutPhonemizer:
"": "",
"": "",
"«": "«",
"»": "»"
"»": "»",
}
self._punctuation_regexp: str = rf"([{''.join(self._special_cases_dict.keys())}])"
self._punctuation_regexp: str = (
rf"([{''.join(self._special_cases_dict.keys())}])"
)
def _normalize_punctuation(self, text: str) -> str:
text = regex.sub(fr"\pZ+{self._punctuation_regexp}", r"\1", text)
text = regex.sub(fr"{self._punctuation_regexp}(\pL)", r"\1 \2", text)
text = regex.sub(rf"\pZ+{self._punctuation_regexp}", r"\1", text)
text = regex.sub(rf"{self._punctuation_regexp}(\pL)", r"\1 \2", text)
text = regex.sub(r"\pZ+", r" ", text)
return text.strip()
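
These substitutions rely on the regex module's Unicode properties: \pZ matches any separator and \pL any letter, so spaces before punctuation are dropped and a space is enforced after punctuation glued to a letter. A standalone sketch of the same two rules with a hard-coded punctuation class (the real one is built from _special_cases_dict):

import regex

punctuation_regexp = r"([;:,.!?«»])"  # illustrative subset of the real character class

def normalize_punctuation(text):
    text = regex.sub(rf"\pZ+{punctuation_regexp}", r"\1", text)      # "word ," -> "word,"
    text = regex.sub(rf"{punctuation_regexp}(\pL)", r"\1 \2", text)  # ",word" -> ", word"
    text = regex.sub(r"\pZ+", r" ", text)                            # collapse space runs
    return text.strip()

print(normalize_punctuation("hello ,world  !"))  # -> "hello, world!"
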
def _convert_punctuation(self, word: Word) -> str:
if not word.phonemes:
return ''
if word.phonemes[0] in ['', '|']:
return ""
if word.phonemes[0] in ["", "|"]:
return word.text.strip()
phonemes = ''.join(word.phonemes)
phonemes = "".join(word.phonemes)
# remove modifier characters ˈˌː with regex
phonemes = re.sub(r'[ˈˌː͡]', '', phonemes)
phonemes = re.sub(r"[ˈˌː͡]", "", phonemes)
return phonemes.strip()
def phonemize(self, text: str, espeak: bool=False) -> str:
def phonemize(self, text: str, espeak: bool = False) -> str:
text_to_phonemize: str = self._normalize_punctuation(text)
sents: List[Sentence] = [
sent
for sent in self._phonemizer(
text_to_phonemize, lang="en-us", espeak=espeak)
for sent in self._phonemizer(text_to_phonemize, lang="en-us", espeak=espeak)
]
words: List[str] = [
self._convert_punctuation(word) for word in itertools.chain(*sents)
]
return ' '.join(words)
return " ".join(words)
def transform(self, phonemes):
# convert phonemes to ids
# dictionary is in symbols.py
return [
self.symbol_to_id[p] for p in phonemes
if p in self.symbol_to_id.keys()
]
return [self.symbol_to_id[p] for p in phonemes if p in self.symbol_to_id.keys()]
if __name__ == "__main__":

View File

@@ -1,7 +1,7 @@
# modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/text_processing/symbols.py
PAD = '_'
PAD = "_"
PUNCTUATION = ';:,.!?¡¿—…"«»“” '
LETTERS = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
LETTERS = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
IPA_LETTERS = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'"
SYMBOLS = [PAD] + list(PUNCTUATION) + list(LETTERS) + list(IPA_LETTERS)
SPACE_ID = SYMBOLS.index(" ")

View File

@@ -11,22 +11,24 @@ def load_yaml_config(path):
def save_config_to_yaml(config, path):
assert path.endswith('.yaml')
with open(path, 'w') as f:
assert path.endswith(".yaml")
with open(path, "w") as f:
f.write(yaml.dump(config))
f.close()
def write_args(args, path):
args_dict = dict((name, getattr(args, name)) for name in dir(args)
if not name.startswith('_'))
with open(path, 'a') as args_file:
args_file.write('==> torch version: {}\n'.format(torch.__version__))
args_dict = dict(
(name, getattr(args, name)) for name in dir(args) if not name.startswith("_")
)
with open(path, "a") as args_file:
args_file.write("==> torch version: {}\n".format(torch.__version__))
args_file.write(
'==> cudnn version: {}\n'.format(torch.backends.cudnn.version()))
args_file.write('==> Cmd:\n')
"==> cudnn version: {}\n".format(torch.backends.cudnn.version())
)
args_file.write("==> Cmd:\n")
args_file.write(str(sys.argv))
args_file.write('\n==> args:\n')
args_file.write("\n==> args:\n")
for k, v in sorted(args_dict.items()):
args_file.write(' %s: %s\n' % (str(k), str(v)))
args_file.write(" %s: %s\n" % (str(k), str(v)))
args_file.close()
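
A small end-to-end sketch of these helpers, assuming they are importable from this module and that the target files are writable in the current directory:

import argparse

config = {"train": {"batch_size": 32, "lr": 1.0e-4}}
save_config_to_yaml(config, "config.yaml")  # the assert requires a .yaml suffix

args = argparse.Namespace(exp_dir="exp", seed=1234)
write_args(args, "args.txt")  # appends torch/cudnn versions, sys.argv and all public attrs of args
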