Refactor: Format Code with Ruff and Update Deprecated G2PW Link (#2255)

* ruff check --fix

* ruff format --line-length 120 --target-version py39

* Change the link for G2PW Model

* update pytorch version and colab
This commit is contained in:
XXXXRT666
2025-04-07 09:42:47 +01:00
committed by GitHub
parent 9da7e17efe
commit 53cac93589
132 changed files with 8185 additions and 6648 deletions

View File

@@ -4,14 +4,11 @@ import itertools
import math
import random
from random import shuffle
from typing import Iterator
from typing import Optional
from typing import TypeVar
from typing import Iterator, Optional, TypeVar
import torch
import torch.distributed as dist
from torch.utils.data import Dataset
from torch.utils.data import Sampler
from torch.utils.data import Dataset, Sampler
__all__ = [
"DistributedBucketSampler",
@@ -50,10 +47,7 @@ class DistributedBucketSampler(Sampler[T_co]):
if torch.cuda.is_available():
torch.cuda.set_device(rank)
if rank >= num_replicas or rank < 0:
raise ValueError(
"Invalid rank {}, rank should be in the interval"
" [0, {}]".format(rank, num_replicas - 1)
)
raise ValueError("Invalid rank {}, rank should be in the interval [0, {}]".format(rank, num_replicas - 1))
self.dataset = dataset
self.num_replicas = num_replicas
self.rank = rank
@@ -61,19 +55,16 @@ class DistributedBucketSampler(Sampler[T_co]):
self.drop_last = drop_last
# If the dataset length is evenly divisible by # of replicas, then there
# is no need to drop any data, since the dataset will be split equally.
if (
self.drop_last and len(self.dataset) % self.num_replicas != 0
): # type: ignore[arg-type]
if self.drop_last and len(self.dataset) % self.num_replicas != 0: # type: ignore[arg-type]
# Split to nearest available length that is evenly divisible.
# This is to ensure each rank receives the same amount of data when
# using this Sampler.
self.num_samples = math.ceil(
(len(self.dataset) - self.num_replicas)
/ self.num_replicas # type: ignore[arg-type]
(len(self.dataset) - self.num_replicas) / self.num_replicas, # type: ignore[arg-type]
)
else:
self.num_samples = math.ceil(
len(self.dataset) / self.num_replicas
len(self.dataset) / self.num_replicas,
) # type: ignore[arg-type]
self.total_size = self.num_samples * self.num_replicas
self.shuffle = shuffle
@@ -118,10 +109,7 @@ class DistributedBucketSampler(Sampler[T_co]):
grouped_batch_size = self.batch_size * self.num_replicas
shuffled_bucket = list(itertools.chain(*shuffled_bucket))
n_batch = int(math.ceil(len(shuffled_bucket) / grouped_batch_size))
batches = [
shuffled_bucket[b * grouped_batch_size : (b + 1) * grouped_batch_size]
for b in range(n_batch)
]
batches = [shuffled_bucket[b * grouped_batch_size : (b + 1) * grouped_batch_size] for b in range(n_batch)]
shuffle(batches)
indices = list(itertools.chain(*batches))
else:
@@ -134,9 +122,7 @@ class DistributedBucketSampler(Sampler[T_co]):
if padding_size <= len(indices):
indices += indices[:padding_size]
else:
indices += (indices * math.ceil(padding_size / len(indices)))[
:padding_size
]
indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
else:
# remove tail of data to make it evenly divisible.
indices = indices[: self.total_size]

View File

@@ -1,9 +1,10 @@
# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/data/data_module.py
# reference: https://github.com/lifeiteng/vall-e
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader
from AR.data.bucket_sampler import DistributedBucketSampler
from AR.data.dataset import Text2SemanticDataset
from torch.utils.data import DataLoader
class Text2SemanticDataModule(LightningDataModule):
@@ -42,8 +43,12 @@ class Text2SemanticDataModule(LightningDataModule):
# pad_val=self.config['data']['pad_val'])
def train_dataloader(self):
batch_size=self.config["train"]["batch_size"]//2 if self.config["train"].get("if_dpo",False)==True else self.config["train"]["batch_size"]
batch_size = max(min(batch_size,len(self._train_dataset)//4),1)#防止不保存
batch_size = (
self.config["train"]["batch_size"] // 2
if self.config["train"].get("if_dpo", False) is True
else self.config["train"]["batch_size"]
)
batch_size = max(min(batch_size, len(self._train_dataset) // 4), 1) # 防止不保存
sampler = DistributedBucketSampler(self._train_dataset, batch_size=batch_size)
return DataLoader(
self._train_dataset,

View File

@@ -1,21 +1,17 @@
# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/data/dataset.py
# reference: https://github.com/lifeiteng/vall-e
import pdb
import sys
# sys.path.append("/data/docker/liujing04/gpt-vits/mq-vits-s1bert_no_bert")
import traceback, os
from typing import Dict
from typing import List
import os
import traceback
from typing import Dict, List
import numpy as np
import pandas as pd
import torch, json
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from transformers import AutoTokenizer
import torch
from torch.utils.data import DataLoader, Dataset
version = os.environ.get('version',None)
version = os.environ.get("version", None)
from text import cleaned_text_to_sequence
@@ -34,9 +30,7 @@ def batch_sequences(sequences: List[np.array], axis: int = 0, pad_value: int = 0
padded_sequences = []
for seq, length in zip(sequences, seq_lengths):
padding = (
[(0, 0)] * axis + [(0, max_length - length)] + [(0, 0)] * (ndim - axis - 1)
)
padding = [(0, 0)] * axis + [(0, max_length - length)] + [(0, 0)] * (ndim - axis - 1)
padded_seq = np.pad(seq, padding, mode="constant", constant_values=pad_value)
padded_sequences.append(padded_seq)
batch = np.stack(padded_sequences)
@@ -61,12 +55,16 @@ class Text2SemanticDataset(Dataset):
super().__init__()
self.semantic_data = pd.read_csv(
semantic_path, delimiter="\t", encoding="utf-8"
semantic_path,
delimiter="\t",
encoding="utf-8",
)
# get dict
self.path2 = phoneme_path # "%s/2-name2text.txt"%exp_dir#phoneme_path
self.path3 = "%s/3-bert" % (
os.path.dirname(phoneme_path)
os.path.dirname(
phoneme_path,
)
) # "%s/3-bert"%exp_dir#bert_dir
self.path6 = semantic_path # "%s/6-name2semantic.tsv"%exp_dir#semantic_path
assert os.path.exists(self.path2)
@@ -127,7 +125,7 @@ class Text2SemanticDataset(Dataset):
for i in range(semantic_data_len):
# 先依次遍历
# get str
item_name = self.semantic_data.iloc[i,0]
item_name = self.semantic_data.iloc[i, 0]
# print(self.phoneme_data)
try:
phoneme, word2ph, text = self.phoneme_data[item_name]
@@ -137,7 +135,7 @@ class Text2SemanticDataset(Dataset):
num_not_in += 1
continue
semantic_str = self.semantic_data.iloc[i,1]
semantic_str = self.semantic_data.iloc[i, 1]
# get token list
semantic_ids = [int(idx) for idx in semantic_str.split(" ")]
# (T), 是否需要变成 (1, T) -> 不需要,因为需要求 len
@@ -158,9 +156,7 @@ class Text2SemanticDataset(Dataset):
num_not_in += 1
continue
# if len(phoneme_ids) >400:###########2改为恒定限制为semantic/2.5就行
if (
len(phoneme_ids) > self.max_sec * self.hz / 2.5
): ###########2改为恒定限制为semantic/2.5就行
if len(phoneme_ids) > self.max_sec * self.hz / 2.5: ###########2改为恒定限制为semantic/2.5就行
num_deleted_ps += 1
continue
# if len(semantic_ids) > 1000:###########3
@@ -169,9 +165,7 @@ class Text2SemanticDataset(Dataset):
ps_ratio = len(phoneme_ids) / (len(semantic_ids) / self.hz)
if (
ps_ratio > self.max_ps_ratio or ps_ratio < self.min_ps_ratio
): ##########4#3~25#每秒多少个phone
if ps_ratio > self.max_ps_ratio or ps_ratio < self.min_ps_ratio: ##########4#3~25#每秒多少个phone
num_deleted_ps += 1
# print(item_name)
continue
@@ -194,12 +188,12 @@ class Text2SemanticDataset(Dataset):
print(f"there are {num_not_in} semantic datas not in phoneme datas")
if num_deleted_bigger > 0:
print(
f"deleted {num_deleted_bigger} audios who's duration are bigger than {self.max_sec} seconds"
f"deleted {num_deleted_bigger} audios who's duration are bigger than {self.max_sec} seconds",
)
if num_deleted_ps > 0:
# 4702 for LibriTTS, LirbriTTS 是标注数据, 是否需要筛?=> 需要,有值为 100 的极端值
print(
f"deleted {num_deleted_ps} audios who's phoneme/sec are bigger than {self.max_ps_ratio} or smaller than {self.min_ps_ratio}"
f"deleted {num_deleted_ps} audios who's phoneme/sec are bigger than {self.max_ps_ratio} or smaller than {self.min_ps_ratio}",
)
"""
there are 31 semantic datas not in phoneme datas
@@ -306,7 +300,10 @@ if __name__ == "__main__":
batch_size = 12
dataloader = DataLoader(
dataset, batch_size=batch_size, collate_fn=dataset.collate, shuffle=False
dataset,
batch_size=batch_size,
collate_fn=dataset.collate,
shuffle=False,
)
for i, batch in enumerate(dataloader):
if i % 1000 == 0:

View File

@@ -1,6 +1,7 @@
# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_lightning_module.py
# reference: https://github.com/lifeiteng/vall-e
import os, sys
import os
import sys
now_dir = os.getcwd()
sys.path.append(now_dir)
@@ -8,10 +9,12 @@ from typing import Dict
import torch
from pytorch_lightning import LightningModule
from AR.models.t2s_model import Text2SemanticDecoder
from AR.modules.lr_schedulers import WarmupCosineLRSchedule
from AR.modules.optim import ScaledAdam
class Text2SemanticLightningModule(LightningModule):
def __init__(self, config, output_dir, is_train=True):
super().__init__()
@@ -23,7 +26,10 @@ class Text2SemanticLightningModule(LightningModule):
# print(self.load_state_dict(torch.load(pretrained_s1,map_location="cpu")["state_dict"]))
print(
self.load_state_dict(
torch.load(pretrained_s1, map_location="cpu")["weight"]
torch.load(
pretrained_s1,
map_location="cpu",
)["weight"],
)
)
if is_train:
@@ -35,7 +41,7 @@ class Text2SemanticLightningModule(LightningModule):
def training_step(self, batch: Dict, batch_idx: int):
opt = self.optimizers()
scheduler = self.lr_schedulers()
forward=self.model.forward if self.config["train"].get("if_dpo",False)==True else self.model.forward_old
forward = self.model.forward if self.config["train"].get("if_dpo", False) == True else self.model.forward_old
loss, acc = forward(
batch["phoneme_ids"],
batch["phoneme_ids_len"],
@@ -113,9 +119,7 @@ class Text2SemanticLightningModule(LightningModule):
def configure_optimizers(self):
model_parameters = self.model.parameters()
parameters_names = []
parameters_names.append(
[name_param_pair[0] for name_param_pair in self.model.named_parameters()]
)
parameters_names.append([name_param_pair[0] for name_param_pair in self.model.named_parameters()])
lm_opt = ScaledAdam(
model_parameters,
lr=0.01,

View File

@@ -1,6 +1,7 @@
# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_lightning_module.py
# reference: https://github.com/lifeiteng/vall-e
import os, sys
import os
import sys
now_dir = os.getcwd()
sys.path.append(now_dir)
@@ -8,6 +9,7 @@ from typing import Dict
import torch
from pytorch_lightning import LightningModule
from AR.models.t2s_model_onnx import Text2SemanticDecoder
from AR.modules.lr_schedulers import WarmupCosineLRSchedule
from AR.modules.optim import ScaledAdam
@@ -24,8 +26,11 @@ class Text2SemanticLightningModule(LightningModule):
# print(self.load_state_dict(torch.load(pretrained_s1,map_location="cpu")["state_dict"]))
print(
self.load_state_dict(
torch.load(pretrained_s1, map_location="cpu")["weight"]
)
torch.load(
pretrained_s1,
map_location="cpu",
)["weight"],
),
)
if is_train:
self.automatic_optimization = False
@@ -79,9 +84,7 @@ class Text2SemanticLightningModule(LightningModule):
def configure_optimizers(self):
model_parameters = self.model.parameters()
parameters_names = []
parameters_names.append(
[name_param_pair[0] for name_param_pair in self.model.named_parameters()]
)
parameters_names.append([name_param_pair[0] for name_param_pair in self.model.named_parameters()])
lm_opt = ScaledAdam(
model_parameters,
lr=0.01,

View File

@@ -2,27 +2,24 @@
# reference: https://github.com/lifeiteng/vall-e
import math
from typing import List, Optional
import torch
from tqdm import tqdm
from AR.models.utils import make_pad_mask, make_pad_mask_left
from AR.models.utils import (
topk_sampling,
sample,
logits_to_probs,
multinomial_sample_one_no_sync,
dpo_loss,
make_reject_y,
get_batch_logps
)
from AR.modules.embedding import SinePositionalEmbedding
from AR.modules.embedding import TokenEmbedding
from AR.modules.transformer import LayerNorm
from AR.modules.transformer import TransformerEncoder
from AR.modules.transformer import TransformerEncoderLayer
import torch
from torch import nn
from torch.nn import functional as F
from torchmetrics.classification import MulticlassAccuracy
from tqdm import tqdm
from AR.models.utils import (
dpo_loss,
get_batch_logps,
make_pad_mask,
make_pad_mask_left,
make_reject_y,
sample,
topk_sampling,
)
from AR.modules.embedding import SinePositionalEmbedding, TokenEmbedding
from AR.modules.transformer import LayerNorm, TransformerEncoder, TransformerEncoderLayer
default_config = {
"embedding_dim": 512,
@@ -36,10 +33,17 @@ default_config = {
"EOS": 1024,
}
# @torch.jit.script ## 使用的话首次推理会非常慢,而且推理速度不稳定
# Efficient implementation equivalent to the following:
def scaled_dot_product_attention(query:torch.Tensor, key:torch.Tensor, value:torch.Tensor, attn_mask:Optional[torch.Tensor]=None, scale:Optional[torch.Tensor]=None) -> torch.Tensor:
B, H, L, S =query.size(0), query.size(1), query.size(-2), key.size(-2)
def scaled_dot_product_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
B, H, L, S = query.size(0), query.size(1), query.size(-2), key.size(-2)
if scale is None:
scale_factor = torch.tensor(1 / math.sqrt(query.size(-1)))
else:
@@ -59,12 +63,13 @@ def scaled_dot_product_attention(query:torch.Tensor, key:torch.Tensor, value:tor
if attn_mask.dtype == torch.bool:
attn_weight.masked_fill_(attn_mask, 0)
else:
attn_mask[attn_mask!=float("-inf")] =0
attn_mask[attn_mask==float("-inf")] =1
attn_mask[attn_mask != float("-inf")] = 0
attn_mask[attn_mask == float("-inf")] = 1
attn_weight.masked_fill_(attn_mask, 0)
return attn_weight @ value
@torch.jit.script
class T2SMLP:
def __init__(self, w1, b1, w2, b2):
@@ -82,20 +87,20 @@ class T2SMLP:
@torch.jit.script
class T2SBlock:
def __init__(
self,
num_heads,
hidden_dim: int,
mlp: T2SMLP,
qkv_w,
qkv_b,
out_w,
out_b,
norm_w1,
norm_b1,
norm_eps1,
norm_w2,
norm_b2,
norm_eps2,
self,
num_heads,
hidden_dim: int,
mlp: T2SMLP,
qkv_w,
qkv_b,
out_w,
out_b,
norm_w1,
norm_b1,
norm_eps1,
norm_w2,
norm_b2,
norm_eps2,
):
self.num_heads = num_heads
self.mlp = mlp
@@ -114,24 +119,32 @@ class T2SBlock:
self.false = torch.tensor(False, dtype=torch.bool)
@torch.jit.ignore
def to_mask(self, x:torch.Tensor, padding_mask:Optional[torch.Tensor]):
def to_mask(
self,
x: torch.Tensor,
padding_mask: Optional[torch.Tensor],
):
if padding_mask is None:
return x
if padding_mask.dtype == torch.bool:
return x.masked_fill(padding_mask, 0)
else:
return x * padding_mask
def process_prompt(self, x:torch.Tensor, attn_mask : torch.Tensor, padding_mask:Optional[torch.Tensor]=None, torch_sdpa:bool=True):
def process_prompt(
self,
x: torch.Tensor,
attn_mask: torch.Tensor,
padding_mask: Optional[torch.Tensor] = None,
torch_sdpa: bool = True,
):
q, k, v = F.linear(self.to_mask(x, padding_mask), self.qkv_w, self.qkv_b).chunk(3, dim=-1)
batch_size = q.shape[0]
q_len = q.shape[1]
kv_len = k.shape[1]
q = self.to_mask(q, padding_mask)
k_cache = self.to_mask(k, padding_mask)
v_cache = self.to_mask(v, padding_mask)
@@ -149,9 +162,7 @@ class T2SBlock:
attn = F.linear(self.to_mask(attn, padding_mask), self.out_w, self.out_b)
x = x + attn
x = F.layer_norm(
x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1
)
x = F.layer_norm(x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1)
x = x + self.mlp.forward(x)
x = F.layer_norm(
x,
@@ -161,13 +172,20 @@ class T2SBlock:
self.norm_eps2,
)
return x, k_cache, v_cache
def decode_next_token(self, x:torch.Tensor, k_cache:torch.Tensor, v_cache:torch.Tensor, attn_mask:torch.Tensor=None, torch_sdpa:bool=True):
def decode_next_token(
self,
x: torch.Tensor,
k_cache: torch.Tensor,
v_cache: torch.Tensor,
attn_mask: torch.Tensor = None,
torch_sdpa: bool = True,
):
q, k, v = F.linear(x, self.qkv_w, self.qkv_b).chunk(3, dim=-1)
k_cache = torch.cat([k_cache, k], dim=1)
v_cache = torch.cat([v_cache, v], dim=1)
batch_size = q.shape[0]
q_len = q.shape[1]
kv_len = k_cache.shape[1]
@@ -176,7 +194,6 @@ class T2SBlock:
k = k_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
v = v_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
if torch_sdpa:
attn = F.scaled_dot_product_attention(q, k, v, (~attn_mask) if attn_mask is not None else None)
else:
@@ -187,7 +204,11 @@ class T2SBlock:
x = x + attn
x = F.layer_norm(
x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1
x,
[self.hidden_dim],
self.norm_w1,
self.norm_b1,
self.norm_eps1,
)
x = x + self.mlp.forward(x)
x = F.layer_norm(
@@ -202,17 +223,19 @@ class T2SBlock:
@torch.jit.script
class T2STransformer:
def __init__(self, num_blocks : int, blocks: List[T2SBlock]):
self.num_blocks : int = num_blocks
def __init__(self, num_blocks: int, blocks: List[T2SBlock]):
self.num_blocks: int = num_blocks
self.blocks = blocks
def process_prompt(
self, x:torch.Tensor, attn_mask : torch.Tensor,
padding_mask : Optional[torch.Tensor]=None,
torch_sdpa:bool=True
):
k_cache : List[torch.Tensor] = []
v_cache : List[torch.Tensor] = []
self,
x: torch.Tensor,
attn_mask: torch.Tensor,
padding_mask: Optional[torch.Tensor] = None,
torch_sdpa: bool = True,
):
k_cache: List[torch.Tensor] = []
v_cache: List[torch.Tensor] = []
for i in range(self.num_blocks):
x, k_cache_, v_cache_ = self.blocks[i].process_prompt(x, attn_mask, padding_mask, torch_sdpa)
k_cache.append(k_cache_)
@@ -220,14 +243,17 @@ class T2STransformer:
return x, k_cache, v_cache
def decode_next_token(
self, x:torch.Tensor,
k_cache: List[torch.Tensor],
v_cache: List[torch.Tensor],
attn_mask : torch.Tensor=None,
torch_sdpa:bool=True
self,
x: torch.Tensor,
k_cache: List[torch.Tensor],
v_cache: List[torch.Tensor],
attn_mask: torch.Tensor = None,
torch_sdpa: bool = True,
):
for i in range(self.num_blocks):
x, k_cache[i], v_cache[i] = self.blocks[i].decode_next_token(x, k_cache[i], v_cache[i], attn_mask, torch_sdpa)
x, k_cache[i], v_cache[i] = self.blocks[i].decode_next_token(
x, k_cache[i], v_cache[i], attn_mask, torch_sdpa
)
return x, k_cache, v_cache
@@ -249,16 +275,26 @@ class Text2SemanticDecoder(nn.Module):
# assert self.EOS == 1024
self.bert_proj = nn.Linear(1024, self.embedding_dim)
self.ar_text_embedding = TokenEmbedding(
self.embedding_dim, self.phoneme_vocab_size, self.p_dropout
self.embedding_dim,
self.phoneme_vocab_size,
self.p_dropout,
)
self.ar_text_position = SinePositionalEmbedding(
self.embedding_dim, dropout=0.1, scale=False, alpha=True
self.embedding_dim,
dropout=0.1,
scale=False,
alpha=True,
)
self.ar_audio_embedding = TokenEmbedding(
self.embedding_dim, self.vocab_size, self.p_dropout
self.embedding_dim,
self.vocab_size,
self.p_dropout,
)
self.ar_audio_position = SinePositionalEmbedding(
self.embedding_dim, dropout=0.1, scale=False, alpha=True
self.embedding_dim,
dropout=0.1,
scale=False,
alpha=True,
)
self.h = TransformerEncoder(
@@ -293,7 +329,7 @@ class Text2SemanticDecoder(nn.Module):
layer.linear1.weight,
layer.linear1.bias,
layer.linear2.weight,
layer.linear2.bias
layer.linear2.bias,
)
block = T2SBlock(
@@ -309,11 +345,11 @@ class Text2SemanticDecoder(nn.Module):
layer.norm1.eps,
layer.norm2.weight,
layer.norm2.bias,
layer.norm2.eps
layer.norm2.eps,
)
blocks.append(block)
self.t2s_transformer = T2STransformer(self.num_layers, blocks)
def make_input_data(self, x, x_lens, y, y_lens, bert_feature):
@@ -387,7 +423,9 @@ class Text2SemanticDecoder(nn.Module):
logits = self.ar_predict_layer(xy_dec[:, x_len:])
###### DPO #############
reject_xy_pos, reject_xy_attn_mask, reject_targets = self.make_input_data(x, x_lens, reject_y, reject_y_lens, bert_feature)
reject_xy_pos, reject_xy_attn_mask, reject_targets = self.make_input_data(
x, x_lens, reject_y, reject_y_lens, bert_feature
)
reject_xy_dec, _ = self.h(
(reject_xy_pos, None),
@@ -404,7 +442,7 @@ class Text2SemanticDecoder(nn.Module):
A_logits, R_logits = get_batch_logps(logits, reject_logits, targets, reject_targets)
loss_2, _, _ = dpo_loss(A_logits, R_logits, 0, 0, 0.2, reference_free=True)
loss = loss_1 + loss_2
return loss, acc
@@ -473,14 +511,14 @@ class Text2SemanticDecoder(nn.Module):
# 需要看下这个函数和 forward 的区别以及没有 semantic 的时候 prompts 输入什么
def infer(
self,
x,
x_lens,
prompts,
bert_feature,
top_k: int = -100,
early_stop_num: int = -1,
temperature: float = 1.0,
self,
x,
x_lens,
prompts,
bert_feature,
top_k: int = -100,
early_stop_num: int = -1,
temperature: float = 1.0,
):
x = self.ar_text_embedding(x)
x = x + self.bert_proj(bert_feature.transpose(1, 2))
@@ -508,18 +546,14 @@ class Text2SemanticDecoder(nn.Module):
(x_len, 0),
value=False,
)
xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0).to(
y.device
)
xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0).to(y.device)
xy_dec, _ = self.h(
(xy_pos, None),
mask=xy_attn_mask,
)
logits = self.ar_predict_layer(xy_dec[:, -1])
samples = topk_sampling(
logits, top_k=top_k, top_p=1.0, temperature=temperature
)
samples = topk_sampling(logits, top_k=top_k, top_p=1.0, temperature=temperature)
if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
print("use early stop num:", early_stop_num)
@@ -542,18 +576,16 @@ class Text2SemanticDecoder(nn.Module):
return y
def pad_y_eos(self, y, y_mask_int, eos_id):
targets = F.pad(y, (0, 1), value=0) + eos_id * F.pad(
y_mask_int, (0, 1), value=1
)
targets = F.pad(y, (0, 1), value=0) + eos_id * F.pad(y_mask_int, (0, 1), value=1)
# 错位
return targets[:, :-1], targets[:, 1:]
def infer_panel_batch_infer(
self,
x:List[torch.LongTensor], #####全部文本token
x_lens:torch.LongTensor,
prompts:torch.LongTensor, ####参考音频token
bert_feature:List[torch.LongTensor],
x: List[torch.LongTensor], #####全部文本token
x_lens: torch.LongTensor,
prompts: torch.LongTensor, ####参考音频token
bert_feature: List[torch.LongTensor],
top_k: int = -100,
top_p: int = 100,
early_stop_num: int = -1,
@@ -563,10 +595,19 @@ class Text2SemanticDecoder(nn.Module):
):
if prompts is None:
print("Warning: Prompt free is not supported batch_infer! switch to naive_infer")
return self.infer_panel_naive_batched(x, x_lens, prompts, bert_feature, top_k=top_k, top_p=top_p, early_stop_num=early_stop_num, temperature=temperature, **kwargs)
return self.infer_panel_naive_batched(
x,
x_lens,
prompts,
bert_feature,
top_k=top_k,
top_p=top_p,
early_stop_num=early_stop_num,
temperature=temperature,
**kwargs,
)
max_len = kwargs.get("max_len",x_lens.max())
max_len = kwargs.get("max_len", x_lens.max())
x_list = []
for x_item, bert_item in zip(x, bert_feature):
# max_len = max(max_len, x_item.shape[0], bert_item.shape[1])
@@ -574,14 +615,15 @@ class Text2SemanticDecoder(nn.Module):
x_item = x_item + self.bert_proj(bert_item.transpose(0, 1).unsqueeze(0))
x_item = self.ar_text_position(x_item).squeeze(0)
# x_item = F.pad(x_item,(0,0,0,max_len-x_item.shape[0]),value=0) if x_item.shape[0]<max_len else x_item ### padding right
x_item = F.pad(x_item,(0,0,max_len-x_item.shape[0],0),value=0) if x_item.shape[0]<max_len else x_item ### padding left
x_item = (
F.pad(x_item, (0, 0, max_len - x_item.shape[0], 0), value=0) if x_item.shape[0] < max_len else x_item
) ### padding left
x_list.append(x_item)
x:torch.Tensor = torch.stack(x_list, dim=0)
x: torch.Tensor = torch.stack(x_list, dim=0)
# AR Decoder
y = prompts
x_len = x.shape[1]
stop = False
@@ -594,34 +636,32 @@ class Text2SemanticDecoder(nn.Module):
y_emb = self.ar_audio_embedding(y)
y_len = y_emb.shape[1]
prefix_len = y.shape[1]
y_lens = torch.LongTensor([y_emb.shape[1]]*y_emb.shape[0]).to(x.device)
y_lens = torch.LongTensor([y_emb.shape[1]] * y_emb.shape[0]).to(x.device)
y_pos = self.ar_audio_position(y_emb)
xy_pos = torch.concat([x, y_pos], dim=1)
##### create mask #####
bsz = x.shape[0]
src_len = x_len + y_len
y_paddind_mask = make_pad_mask_left(y_lens, y_len)
x_paddind_mask = make_pad_mask_left(x_lens, max_len)
# (bsz, x_len + y_len)
padding_mask = torch.concat([x_paddind_mask, y_paddind_mask], dim=1)
x_mask = F.pad(
torch.zeros(x_len, x_len, dtype=torch.bool, device=x.device),
(0, y_len),
value=True,
)
x_mask = F.pad(
torch.zeros(x_len, x_len, dtype=torch.bool, device=x.device),
(0, y_len),
value=True,
)
y_mask = F.pad( ###yy的右上1扩展到左边xy的0,(y,x+y)
torch.triu(torch.ones(y_len, y_len, dtype=torch.bool, device=x.device), diagonal=1),
torch.triu(torch.ones(y_len, y_len, dtype=torch.bool, device=x.device), diagonal=1),
(x_len, 0),
value=False,
)
causal_mask = torch.concat([x_mask, y_mask], dim=0).view(1 , src_len, src_len).repeat(bsz, 1, 1).to(x.device)
causal_mask = torch.concat([x_mask, y_mask], dim=0).view(1, src_len, src_len).repeat(bsz, 1, 1).to(x.device)
# padding_mask = padding_mask.unsqueeze(1) * padding_mask.unsqueeze(2) ### [b, x+y, x+y]
### 上面是错误的会导致padding的token被"看见"
@@ -639,10 +679,9 @@ class Text2SemanticDecoder(nn.Module):
padding_mask = padding_mask.view(bsz, 1, src_len).repeat(1, src_len, 1)
attn_mask:torch.Tensor = causal_mask.logical_or(padding_mask)
attn_mask: torch.Tensor = causal_mask.logical_or(padding_mask)
attn_mask = attn_mask.unsqueeze(1).expand(-1, self.num_head, -1, -1).bool()
# 正确的attn_mask应该是这样的
# | pad_len | x_len | y_len |
# [[PAD, PAD, PAD, 1, 2, 3, EOS, EOS, EOS],
@@ -655,74 +694,69 @@ class Text2SemanticDecoder(nn.Module):
# [PAD, PAD, PAD, 1, 2, 3, 4, 5, EOS],
# [PAD, PAD, PAD, 1, 2, 3, 4, 5, 6]]
###### decode #####
y_list = [None]*y.shape[0]
y_list = [None] * y.shape[0]
batch_idx_map = list(range(y.shape[0]))
idx_list = [None]*y.shape[0]
idx_list = [None] * y.shape[0]
for idx in tqdm(range(1500)):
if idx == 0:
xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, attn_mask, None)
else:
xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(xy_pos, k_cache, v_cache, attn_mask)
logits = self.ar_predict_layer(
xy_dec[:, -1]
)
logits = self.ar_predict_layer(xy_dec[:, -1])
if idx == 0:
attn_mask = F.pad(attn_mask[:,:,-1].unsqueeze(-2),(0,1),value=False)
attn_mask = F.pad(attn_mask[:, :, -1].unsqueeze(-2), (0, 1), value=False)
logits = logits[:, :-1]
else:
attn_mask = F.pad(attn_mask,(0,1),value=False)
attn_mask = F.pad(attn_mask, (0, 1), value=False)
samples = sample(
logits, y, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature
)[0]
logits, y, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature
)[0]
y = torch.concat([y, samples], dim=1)
####### 移除batch中已经生成完毕的序列,进一步优化计算量
tokens = torch.argmax(logits, dim=-1)
reserved_idx_of_batch_for_y = None
if (self.EOS in samples[:, 0]) or \
(self.EOS in tokens): ###如果生成到EOS则停止
l1 = samples[:, 0]==self.EOS
l2 = tokens==self.EOS
l = l1.logical_or(l2)
removed_idx_of_batch_for_y = torch.where(l==True)[0].tolist()
reserved_idx_of_batch_for_y = torch.where(l==False)[0]
# batch_indexs = torch.tensor(batch_idx_map, device=y.device)[removed_idx_of_batch_for_y]
for i in removed_idx_of_batch_for_y:
batch_index = batch_idx_map[i]
idx_list[batch_index] = idx
y_list[batch_index] = y[i, :-1]
batch_idx_map = [batch_idx_map[i] for i in reserved_idx_of_batch_for_y.tolist()]
# 只保留batch中未生成完毕的序列
if (self.EOS in samples[:, 0]) or (self.EOS in tokens): ###如果生成到EOS则停止
l1 = samples[:, 0] == self.EOS
l2 = tokens == self.EOS
l = l1.logical_or(l2)
removed_idx_of_batch_for_y = torch.where(l == True)[0].tolist()
reserved_idx_of_batch_for_y = torch.where(l == False)[0]
# batch_indexs = torch.tensor(batch_idx_map, device=y.device)[removed_idx_of_batch_for_y]
for i in removed_idx_of_batch_for_y:
batch_index = batch_idx_map[i]
idx_list[batch_index] = idx
y_list[batch_index] = y[i, :-1]
batch_idx_map = [batch_idx_map[i] for i in reserved_idx_of_batch_for_y.tolist()]
# 只保留batch中未生成完毕的序列
if reserved_idx_of_batch_for_y is not None:
# index = torch.LongTensor(batch_idx_map).to(y.device)
y = torch.index_select(y, dim=0, index=reserved_idx_of_batch_for_y)
attn_mask = torch.index_select(attn_mask, dim=0, index=reserved_idx_of_batch_for_y)
if k_cache is not None :
if k_cache is not None:
for i in range(len(k_cache)):
k_cache[i] = torch.index_select(k_cache[i], dim=0, index=reserved_idx_of_batch_for_y)
v_cache[i] = torch.index_select(v_cache[i], dim=0, index=reserved_idx_of_batch_for_y)
if (early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num) or idx==1499:
if (early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num) or idx == 1499:
print("use early stop num:", early_stop_num)
stop = True
for i, batch_index in enumerate(batch_idx_map):
batch_index = batch_idx_map[i]
idx_list[batch_index] = idx
y_list[batch_index] = y[i, :-1]
if not (None in idx_list):
if None not in idx_list:
stop = True
if stop:
if y.shape[1]==0:
if y.shape[1] == 0:
y = torch.concat([y, torch.zeros_like(samples)], dim=1)
print("bad zero prediction")
print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]")
@@ -730,60 +764,65 @@ class Text2SemanticDecoder(nn.Module):
####################### update next step ###################################
y_emb = self.ar_audio_embedding(y[:, -1:])
xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[:, y_len + idx].to( dtype= y_emb.dtype,device=y_emb.device)
xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[
:, y_len + idx
].to(dtype=y_emb.dtype, device=y_emb.device)
if (None in idx_list):
if None in idx_list:
for i in range(x.shape[0]):
if idx_list[i] is None:
idx_list[i] = 1500-1 ###如果没有生成到EOS就用最大长度代替
idx_list[i] = 1500 - 1 ###如果没有生成到EOS就用最大长度代替
if ref_free:
return y_list, [0]*x.shape[0]
return y_list, [0] * x.shape[0]
# print(idx_list)
return y_list, idx_list
def infer_panel_naive_batched(self,
x:List[torch.LongTensor], #####全部文本token
x_lens:torch.LongTensor,
prompts:torch.LongTensor, ####参考音频token
bert_feature:List[torch.LongTensor],
def infer_panel_naive_batched(
self,
x: List[torch.LongTensor], #####全部文本token
x_lens: torch.LongTensor,
prompts: torch.LongTensor, ####参考音频token
bert_feature: List[torch.LongTensor],
top_k: int = -100,
top_p: int = 100,
early_stop_num: int = -1,
temperature: float = 1.0,
repetition_penalty: float = 1.35,
**kwargs
):
**kwargs,
):
y_list = []
idx_list = []
for i in range(len(x)):
y, idx = self.infer_panel_naive(x[i].unsqueeze(0),
x_lens[i],
prompts[i].unsqueeze(0) if prompts is not None else None,
bert_feature[i].unsqueeze(0),
top_k,
top_p,
early_stop_num,
temperature,
repetition_penalty,
**kwargs)
y, idx = self.infer_panel_naive(
x[i].unsqueeze(0),
x_lens[i],
prompts[i].unsqueeze(0) if prompts is not None else None,
bert_feature[i].unsqueeze(0),
top_k,
top_p,
early_stop_num,
temperature,
repetition_penalty,
**kwargs,
)
y_list.append(y[0])
idx_list.append(idx)
return y_list, idx_list
def infer_panel_naive(
self,
x:torch.LongTensor, #####全部文本token
x_lens:torch.LongTensor,
prompts:torch.LongTensor, ####参考音频token
bert_feature:torch.LongTensor,
x: torch.LongTensor, #####全部文本token
x_lens: torch.LongTensor,
prompts: torch.LongTensor, ####参考音频token
bert_feature: torch.LongTensor,
top_k: int = -100,
top_p: int = 100,
early_stop_num: int = -1,
temperature: float = 1.0,
repetition_penalty: float = 1.35,
**kwargs
**kwargs,
):
x = self.ar_text_embedding(x)
x = x + self.bert_proj(bert_feature.transpose(1, 2))
@@ -828,11 +867,13 @@ class Text2SemanticDecoder(nn.Module):
(x_len, 0),
value=False,
)
xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0)\
.unsqueeze(0)\
.expand(bsz*self.num_head, -1, -1)\
.view(bsz, self.num_head, src_len, src_len)\
.to(device=x.device, dtype=torch.bool)
xy_attn_mask = (
torch.concat([x_attn_mask_pad, y_attn_mask], dim=0)
.unsqueeze(0)
.expand(bsz * self.num_head, -1, -1)
.view(bsz, self.num_head, src_len, src_len)
.to(device=x.device, dtype=torch.bool)
)
for idx in tqdm(range(1500)):
if xy_attn_mask is not None:
@@ -840,13 +881,11 @@ class Text2SemanticDecoder(nn.Module):
else:
xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(xy_pos, k_cache, v_cache)
logits = self.ar_predict_layer(
xy_dec[:, -1]
)
logits = self.ar_predict_layer(xy_dec[:, -1])
if idx == 0:
xy_attn_mask = None
if(idx<11):###至少预测出10个token不然不给停止0.4s
if idx < 11: ###至少预测出10个token不然不给停止0.4s
logits = logits[:, :-1]
samples = sample(
@@ -870,24 +909,27 @@ class Text2SemanticDecoder(nn.Module):
####################### update next step ###################################
y_emb = self.ar_audio_embedding(y[:, -1:])
xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[:, y_len + idx].to(dtype=y_emb.dtype,device=y_emb.device)
xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[
:, y_len + idx
].to(dtype=y_emb.dtype, device=y_emb.device)
if ref_free:
return y[:, :-1], 0
return y[:, :-1], idx
def infer_panel(
self,
x:torch.LongTensor, #####全部文本token
x_lens:torch.LongTensor,
prompts:torch.LongTensor, ####参考音频token
bert_feature:torch.LongTensor,
x: torch.LongTensor, #####全部文本token
x_lens: torch.LongTensor,
prompts: torch.LongTensor, ####参考音频token
bert_feature: torch.LongTensor,
top_k: int = -100,
top_p: int = 100,
early_stop_num: int = -1,
temperature: float = 1.0,
repetition_penalty: float = 1.35,
**kwargs
**kwargs,
):
return self.infer_panel_naive(x, x_lens, prompts, bert_feature, top_k, top_p, early_stop_num, temperature, repetition_penalty, **kwargs)
return self.infer_panel_naive(
x, x_lens, prompts, bert_feature, top_k, top_p, early_stop_num, temperature, repetition_penalty, **kwargs
)

View File

@@ -1,17 +1,13 @@
# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_model.py
# reference: https://github.com/lifeiteng/vall-e
import torch
from tqdm import tqdm
from AR.modules.embedding_onnx import SinePositionalEmbedding
from AR.modules.embedding_onnx import TokenEmbedding
from AR.modules.transformer_onnx import LayerNorm
from AR.modules.transformer_onnx import TransformerEncoder
from AR.modules.transformer_onnx import TransformerEncoderLayer
from torch import nn
from torch.nn import functional as F
from torchmetrics.classification import MulticlassAccuracy
from AR.modules.embedding_onnx import SinePositionalEmbedding, TokenEmbedding
from AR.modules.transformer_onnx import LayerNorm, TransformerEncoder, TransformerEncoderLayer
default_config = {
"embedding_dim": 512,
"hidden_dim": 512,
@@ -26,12 +22,13 @@ default_config = {
inf_tensor_value = torch.FloatTensor([-float("Inf")]).float()
def logits_to_probs(
logits,
previous_tokens = None,
previous_tokens=None,
temperature: float = 1.0,
top_k = None,
top_p = None,
top_k=None,
top_p=None,
repetition_penalty: float = 1.0,
):
previous_tokens = previous_tokens.squeeze()
@@ -39,19 +36,27 @@ def logits_to_probs(
previous_tokens = previous_tokens.long()
score = torch.gather(logits, dim=0, index=previous_tokens)
score = torch.where(
score < 0, score * repetition_penalty, score / repetition_penalty
score < 0,
score * repetition_penalty,
score / repetition_penalty,
)
logits.scatter_(dim=0, index=previous_tokens, src=score)
if top_p is not None and top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cum_probs = torch.cumsum(
torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1
torch.nn.functional.softmax(
sorted_logits,
dim=-1,
),
dim=-1,
)
sorted_indices_to_remove = cum_probs > top_p
sorted_indices_to_remove[0] = False # keep at least one option
indices_to_remove = sorted_indices_to_remove.scatter(
dim=0, index=sorted_indices, src=sorted_indices_to_remove
dim=0,
index=sorted_indices,
src=sorted_indices_to_remove,
)
logits = logits.masked_fill(indices_to_remove, -float("Inf"))
@@ -67,7 +72,7 @@ def logits_to_probs(
def multinomial_sample_one_no_sync(
probs_sort
probs_sort,
): # Does multinomial sampling without a cuda synchronization
q = torch.randn_like(probs_sort)
return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
@@ -79,7 +84,9 @@ def sample(
**sampling_kwargs,
):
probs = logits_to_probs(
logits=logits, previous_tokens=previous_tokens, **sampling_kwargs
logits=logits,
previous_tokens=previous_tokens,
**sampling_kwargs,
)
idx_next = multinomial_sample_one_no_sync(probs)
return idx_next, probs
@@ -91,7 +98,7 @@ class OnnxEncoder(nn.Module):
self.ar_text_embedding = ar_text_embedding
self.bert_proj = bert_proj
self.ar_text_position = ar_text_position
def forward(self, x, bert_feature):
x = self.ar_text_embedding(x)
x = x + self.bert_proj(bert_feature.transpose(1, 2))
@@ -99,8 +106,18 @@ class OnnxEncoder(nn.Module):
class T2SFirstStageDecoder(nn.Module):
def __init__(self, ar_audio_embedding, ar_audio_position, h, ar_predict_layer, loss_fct, ar_accuracy_metric,
top_k, early_stop_num, num_layers):
def __init__(
self,
ar_audio_embedding,
ar_audio_position,
h,
ar_predict_layer,
loss_fct,
ar_accuracy_metric,
top_k,
early_stop_num,
num_layers,
):
super().__init__()
self.ar_audio_embedding = ar_audio_embedding
self.ar_audio_position = ar_audio_position
@@ -111,11 +128,11 @@ class T2SFirstStageDecoder(nn.Module):
self.top_k = top_k
self.early_stop_num = early_stop_num
self.num_layers = num_layers
def forward(self, x, prompt):
y = prompt
x_example = x[:,:,0] * 0.0
#N, 1, 512
x_example = x[:, :, 0] * 0.0
# N, 1, 512
cache = {
"all_stage": self.num_layers,
"k": None,
@@ -132,11 +149,15 @@ class T2SFirstStageDecoder(nn.Module):
xy_pos = torch.concat([x, y_pos], dim=1)
y_example = y_pos[:,:,0] * 0.0
x_attn_mask = torch.matmul(x_example.transpose(0, 1) , x_example).bool()
y_example = y_pos[:, :, 0] * 0.0
x_attn_mask = torch.matmul(x_example.transpose(0, 1), x_example).bool()
y_attn_mask = torch.ones_like(torch.matmul(y_example.transpose(0, 1), y_example), dtype=torch.int64)
y_attn_mask = torch.cumsum(y_attn_mask, dim=1) - torch.cumsum(
torch.ones_like(y_example.transpose(0, 1), dtype=torch.int64), dim=0
torch.ones_like(
y_example.transpose(0, 1),
dtype=torch.int64,
),
dim=0,
)
y_attn_mask = y_attn_mask > 0
@@ -145,10 +166,16 @@ class T2SFirstStageDecoder(nn.Module):
x_attn_mask_pad = torch.cat([x_attn_mask, torch.ones_like(x_y_pad)], dim=1)
y_attn_mask = torch.cat([y_x_pad, y_attn_mask], dim=1)
xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0)
cache["k"] = torch.matmul(x_attn_mask_pad[0].float().unsqueeze(-1), torch.zeros((1, 512)))\
.unsqueeze(1).repeat(self.num_layers, 1, 1, 1)
cache["v"] = torch.matmul(x_attn_mask_pad[0].float().unsqueeze(-1), torch.zeros((1, 512)))\
.unsqueeze(1).repeat(self.num_layers, 1, 1, 1)
cache["k"] = (
torch.matmul(x_attn_mask_pad[0].float().unsqueeze(-1), torch.zeros((1, 512)))
.unsqueeze(1)
.repeat(self.num_layers, 1, 1, 1)
)
cache["v"] = (
torch.matmul(x_attn_mask_pad[0].float().unsqueeze(-1), torch.zeros((1, 512)))
.unsqueeze(1)
.repeat(self.num_layers, 1, 1, 1)
)
xy_dec = self.h(xy_pos, mask=xy_attn_mask, cache=cache)
logits = self.ar_predict_layer(xy_dec[:, -1])
@@ -160,8 +187,18 @@ class T2SFirstStageDecoder(nn.Module):
class T2SStageDecoder(nn.Module):
def __init__(self, ar_audio_embedding, ar_audio_position, h, ar_predict_layer, loss_fct, ar_accuracy_metric,
top_k, early_stop_num, num_layers):
def __init__(
self,
ar_audio_embedding,
ar_audio_position,
h,
ar_predict_layer,
loss_fct,
ar_accuracy_metric,
top_k,
early_stop_num,
num_layers,
):
super().__init__()
self.ar_audio_embedding = ar_audio_embedding
self.ar_audio_position = ar_audio_position
@@ -184,14 +221,18 @@ class T2SStageDecoder(nn.Module):
}
y_emb = torch.cat(
[cache["y_emb"], self.ar_audio_embedding(y[:, -1:])], 1
[
cache["y_emb"],
self.ar_audio_embedding(y[:, -1:]),
],
1,
)
cache["y_emb"] = y_emb
y_pos = self.ar_audio_position(y_emb)
xy_pos = y_pos[:, -1:]
y_example = y_pos[:,:,0] * 0.0
y_example = y_pos[:, :, 0] * 0.0
xy_attn_mask = torch.cat([x_example, y_example], dim=1)
xy_attn_mask = torch.zeros_like(xy_attn_mask, dtype=torch.bool)
@@ -250,12 +291,28 @@ class Text2SemanticDecoder(nn.Module):
def init_onnx(self):
self.onnx_encoder = OnnxEncoder(self.ar_text_embedding, self.bert_proj, self.ar_text_position)
self.first_stage_decoder = T2SFirstStageDecoder(self.ar_audio_embedding, self.ar_audio_position, self.h,
self.ar_predict_layer, self.loss_fct, self.ar_accuracy_metric, self.top_k, self.early_stop_num,
self.num_layers)
self.stage_decoder = T2SStageDecoder(self.ar_audio_embedding, self.ar_audio_position, self.h,
self.ar_predict_layer, self.loss_fct, self.ar_accuracy_metric, self.top_k, self.early_stop_num,
self.num_layers)
self.first_stage_decoder = T2SFirstStageDecoder(
self.ar_audio_embedding,
self.ar_audio_position,
self.h,
self.ar_predict_layer,
self.loss_fct,
self.ar_accuracy_metric,
self.top_k,
self.early_stop_num,
self.num_layers,
)
self.stage_decoder = T2SStageDecoder(
self.ar_audio_embedding,
self.ar_audio_position,
self.h,
self.ar_predict_layer,
self.loss_fct,
self.ar_accuracy_metric,
self.top_k,
self.early_stop_num,
self.num_layers,
)
def forward(self, x, prompts, bert_feature):
early_stop_num = self.early_stop_num
@@ -286,7 +343,7 @@ class Text2SemanticDecoder(nn.Module):
y = prompts
prefix_len = y.shape[1]
x_len = x.shape[1]
x_example = x[:,:,0] * 0.0
x_example = x[:, :, 0] * 0.0
x_attn_mask = torch.matmul(x_example.transpose(0, 1), x_example)
x_attn_mask = torch.zeros_like(x_attn_mask, dtype=torch.bool)
@@ -303,9 +360,7 @@ class Text2SemanticDecoder(nn.Module):
if cache["first_infer"] == 1:
y_emb = self.ar_audio_embedding(y)
else:
y_emb = torch.cat(
[cache["y_emb"], self.ar_audio_embedding(y[:, -1:])], 1
)
y_emb = torch.cat([cache["y_emb"], self.ar_audio_embedding(y[:, -1:])], 1)
cache["y_emb"] = y_emb
y_pos = self.ar_audio_position(y_emb)
if cache["first_infer"] == 1:
@@ -317,7 +372,8 @@ class Text2SemanticDecoder(nn.Module):
x_attn_mask_pad = F.pad(x_attn_mask, (0, y_len), value=True)
y_attn_mask = F.pad(
torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
(x_len, 0), value=False
(x_len, 0),
value=False,
)
xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0)
else:
@@ -335,4 +391,4 @@ class Text2SemanticDecoder(nn.Module):
break
y = torch.concat([y, samples], dim=1)
cache["first_infer"] = 0
return y, idx
return y, idx

View File

@@ -1,8 +1,10 @@
# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/utils.py
# reference: https://github.com/lifeiteng/vall-e
from typing import Tuple
import torch
import torch.nn.functional as F
from typing import Tuple
def sequence_mask(length, max_length=None):
if max_length is None:
@@ -67,14 +69,18 @@ def make_pad_mask_left(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
n = lengths.size(0)
seq_range = torch.arange(0, max_len, device=lengths.device)
expaned_lengths = seq_range.unsqueeze(0).repeat(n, 1)
expaned_lengths -= (max_len-lengths).unsqueeze(-1)
expaned_lengths -= (max_len - lengths).unsqueeze(-1)
return expaned_lengths<0
return expaned_lengths < 0
# https://github.com/microsoft/unilm/blob/master/xtune/src/transformers/modeling_utils.py
def top_k_top_p_filtering(
logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1
logits,
top_k=0,
top_p=1.0,
filter_value=-float("Inf"),
min_tokens_to_keep=1,
):
"""Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
Args:
@@ -105,9 +111,7 @@ def top_k_top_p_filtering(
sorted_indices_to_remove[..., 0] = 0
# scatter sorted tensors to original indexing
indices_to_remove = sorted_indices_to_remove.scatter(
1, sorted_indices, sorted_indices_to_remove
)
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
logits[indices_to_remove] = filter_value
return logits
@@ -130,7 +134,7 @@ def topk_sampling(logits, top_k=10, top_p=1.0, temperature=1.0):
return token
from typing import Optional, Tuple
from typing import Optional
def multinomial_sample_one_no_sync(
@@ -156,19 +160,21 @@ def logits_to_probs(
previous_tokens = previous_tokens.long()
score = torch.gather(logits, dim=1, index=previous_tokens)
score = torch.where(
score < 0, score * repetition_penalty, score / repetition_penalty
score < 0,
score * repetition_penalty,
score / repetition_penalty,
)
logits.scatter_(dim=1, index=previous_tokens, src=score)
if top_p is not None and top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cum_probs = torch.cumsum(
torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1
)
cum_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
sorted_indices_to_remove = cum_probs > top_p
sorted_indices_to_remove[:, 0] = False # keep at least one option
indices_to_remove = sorted_indices_to_remove.scatter(
dim=1, index=sorted_indices, src=sorted_indices_to_remove
dim=1,
index=sorted_indices,
src=sorted_indices_to_remove,
)
logits = logits.masked_fill(indices_to_remove, -float("Inf"))
@@ -176,7 +182,7 @@ def logits_to_probs(
if top_k is not None:
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
pivot = v[: , -1].unsqueeze(-1)
pivot = v[:, -1].unsqueeze(-1)
logits = torch.where(logits < pivot, -float("Inf"), logits)
probs = torch.nn.functional.softmax(logits, dim=-1)
@@ -188,18 +194,19 @@ def sample(
previous_tokens: Optional[torch.Tensor] = None,
**sampling_kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
probs = logits_to_probs(
logits=logits, previous_tokens=previous_tokens, **sampling_kwargs
)
probs = logits_to_probs(logits=logits, previous_tokens=previous_tokens, **sampling_kwargs)
idx_next = multinomial_sample_one_no_sync(probs)
return idx_next, probs
def dpo_loss(policy_chosen_logps: torch.FloatTensor,
policy_rejected_logps: torch.FloatTensor,
reference_chosen_logps: torch.FloatTensor,
reference_rejected_logps: torch.FloatTensor,
beta: float,
reference_free: bool = False) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
def dpo_loss(
policy_chosen_logps: torch.FloatTensor,
policy_rejected_logps: torch.FloatTensor,
reference_chosen_logps: torch.FloatTensor,
reference_rejected_logps: torch.FloatTensor,
beta: float,
reference_free: bool = False,
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
pi_logratios = policy_chosen_logps - policy_rejected_logps
ref_logratios = reference_chosen_logps - reference_rejected_logps
@@ -214,40 +221,53 @@ def dpo_loss(policy_chosen_logps: torch.FloatTensor,
return losses.mean(), chosen_rewards, rejected_rewards
def get_batch_logps(logits_target: torch.FloatTensor, logits_reject: torch.FloatTensor, labels_target: torch.LongTensor, labels_reject: torch.LongTensor, average_log_prob: bool = False) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
def get_batch_logps(
logits_target: torch.FloatTensor,
logits_reject: torch.FloatTensor,
labels_target: torch.LongTensor,
labels_reject: torch.LongTensor,
average_log_prob: bool = False,
) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
# dummy token; we'll ignore the losses on these tokens later
per_token_logps_target = torch.gather(logits_target.log_softmax(-1), dim=2, index=labels_target.unsqueeze(2)).squeeze(2)
per_token_logps_reject = torch.gather(logits_reject.log_softmax(-1), dim=2, index=labels_reject.unsqueeze(2)).squeeze(2)
per_token_logps_target = torch.gather(
logits_target.log_softmax(-1), dim=2, index=labels_target.unsqueeze(2)
).squeeze(2)
per_token_logps_reject = torch.gather(
logits_reject.log_softmax(-1), dim=2, index=labels_reject.unsqueeze(2)
).squeeze(2)
return per_token_logps_target.sum(-1), per_token_logps_reject.sum(-1)
def make_reject_y(y_o, y_lens):
def repeat_P(y):
range_idx, _ = torch.randint(0, len(y), size=(2,)).sort()
pre = y[:range_idx[0]]
shf = y[range_idx[1]:]
range_text = y[range_idx[0]:range_idx[1]]
pre = y[: range_idx[0]]
shf = y[range_idx[1] :]
range_text = y[range_idx[0] : range_idx[1]]
new_y = torch.cat([pre, range_text, range_text, shf])
return new_y
def lost_P(y):
range_idx, _ = torch.randint(0, len(y), size=(2,)).sort()
pre = y[:range_idx[0]]
shf = y[range_idx[1]:]
range_text = y[range_idx[0]:range_idx[1]]
pre = y[: range_idx[0]]
shf = y[range_idx[1] :]
range_text = y[range_idx[0] : range_idx[1]]
new_y = torch.cat([pre, shf])
return new_y
bs = len(y_lens)
reject_y = []
reject_y_lens = []
for b in range(bs):
process_item_idx = torch.randint(0, 1, size=(1, ))[0]
process_item_idx = torch.randint(0, 1, size=(1,))[0]
if process_item_idx == 0:
new_y = repeat_P(y_o[b])
reject_y.append(new_y)
reject_y_lens.append(len(new_y))
elif process_item_idx==1:
elif process_item_idx == 1:
new_y = lost_P(y_o[b])
reject_y.append(new_y)
reject_y_lens.append(len(new_y))
@@ -256,7 +276,7 @@ def make_reject_y(y_o, y_lens):
pad_length = max_length - reject_y_lens[b]
reject_y[b] = torch.cat([reject_y[b], torch.zeros(pad_length, dtype=y_o.dtype, device=y_o.device)], dim=0)
reject_y = torch.stack(reject_y, dim = 0)
reject_y = torch.stack(reject_y, dim=0)
reject_y_lens = torch.tensor(reject_y_lens, device=y_lens.device)
return reject_y, reject_y_lens

View File

@@ -1,17 +1,14 @@
# modified from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/activation.py
from typing import Optional
from typing import Tuple
from typing import Optional, Tuple
import torch
from torch import Tensor
from torch.nn import Linear
from torch.nn import Module
from torch.nn.init import constant_
from torch.nn.init import xavier_normal_
from torch.nn.init import xavier_uniform_
from torch.nn import Linear, Module
from torch.nn import functional as F
from torch.nn.init import constant_, xavier_normal_, xavier_uniform_
from torch.nn.modules.linear import NonDynamicallyQuantizableLinear
from torch.nn.parameter import Parameter
from torch.nn import functional as F
from AR.modules.patched_mha_with_cache import multi_head_attention_forward_patched
F.multi_head_attention_forward = multi_head_attention_forward_patched
@@ -73,6 +70,7 @@ class MultiheadAttention(Module):
>>> attn_output, attn_output_weights = multihead_attn(query, key, value)
"""
__constants__ = ["batch_first"]
bias_k: Optional[torch.Tensor]
bias_v: Optional[torch.Tensor]
@@ -104,9 +102,7 @@ class MultiheadAttention(Module):
self.dropout = dropout
self.batch_first = batch_first
self.head_dim = embed_dim // num_heads
assert (
self.head_dim * num_heads == self.embed_dim
), "embed_dim must be divisible by num_heads"
assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
if add_bias_kv:
self.bias_k = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
@@ -117,31 +113,32 @@ class MultiheadAttention(Module):
if linear1_cls == Linear:
if not self._qkv_same_embed_dim:
self.q_proj_weight = Parameter(
torch.empty((embed_dim, embed_dim), **factory_kwargs)
torch.empty((embed_dim, embed_dim), **factory_kwargs),
)
self.k_proj_weight = Parameter(
torch.empty((embed_dim, self.kdim), **factory_kwargs)
torch.empty((embed_dim, self.kdim), **factory_kwargs),
)
self.v_proj_weight = Parameter(
torch.empty((embed_dim, self.vdim), **factory_kwargs)
torch.empty((embed_dim, self.vdim), **factory_kwargs),
)
self.register_parameter("in_proj_weight", None)
else:
self.in_proj_weight = Parameter(
torch.empty((3 * embed_dim, embed_dim), **factory_kwargs)
torch.empty((3 * embed_dim, embed_dim), **factory_kwargs),
)
self.register_parameter("q_proj_weight", None)
self.register_parameter("k_proj_weight", None)
self.register_parameter("v_proj_weight", None)
if bias:
self.in_proj_bias = Parameter(
torch.empty(3 * embed_dim, **factory_kwargs)
)
self.in_proj_bias = Parameter(torch.empty(3 * embed_dim, **factory_kwargs))
else:
self.register_parameter("in_proj_bias", None)
self.out_proj = NonDynamicallyQuantizableLinear(
embed_dim, embed_dim, bias=bias, **factory_kwargs
embed_dim,
embed_dim,
bias=bias,
**factory_kwargs,
)
self._reset_parameters()
@@ -150,7 +147,10 @@ class MultiheadAttention(Module):
raise NotImplementedError
else:
self.in_proj_linear = linear1_cls(
embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs
embed_dim,
3 * embed_dim,
bias=bias,
**factory_kwargs,
)
self.in_proj_weight = self.in_proj_linear.weight
@@ -164,7 +164,10 @@ class MultiheadAttention(Module):
self.register_parameter("in_proj_bias", None)
self.out_proj = linear2_cls(
embed_dim, embed_dim, bias=bias, **factory_kwargs
embed_dim,
embed_dim,
bias=bias,
**factory_kwargs,
)
if self.bias_k is not None:
@@ -261,28 +264,26 @@ class MultiheadAttention(Module):
if key_padding_mask is not None:
_kpm_dtype = key_padding_mask.dtype
if _kpm_dtype != torch.bool and not torch.is_floating_point(
key_padding_mask
key_padding_mask,
):
raise AssertionError(
"only bool and floating types of key_padding_mask are supported"
)
raise AssertionError("only bool and floating types of key_padding_mask are supported")
why_not_fast_path = ""
if not is_batched:
why_not_fast_path = (
f"input not batched; expected query.dim() of 3 but got {query.dim()}"
)
why_not_fast_path = f"input not batched; expected query.dim() of 3 but got {query.dim()}"
elif query is not key or key is not value:
# When lifting this restriction, don't forget to either
# enforce that the dtypes all match or test cases where
# they don't!
why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)"
elif self.in_proj_bias is not None and query.dtype != self.in_proj_bias.dtype:
why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match"
elif (
self.in_proj_weight is not None and query.dtype != self.in_proj_weight.dtype
):
why_not_fast_path = (
f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match"
)
elif self.in_proj_weight is not None and query.dtype != self.in_proj_weight.dtype:
# this case will fail anyway, but at least they'll get a useful error message.
why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match"
why_not_fast_path = (
f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match"
)
elif self.training:
why_not_fast_path = "training is enabled"
elif not self.batch_first:
@@ -300,9 +301,7 @@ class MultiheadAttention(Module):
elif attn_mask is not None:
why_not_fast_path = "attn_mask was not None"
elif query.is_nested and key_padding_mask is not None:
why_not_fast_path = (
"key_padding_mask is not supported with NestedTensor input"
)
why_not_fast_path = "key_padding_mask is not supported with NestedTensor input"
elif self.num_heads % 2 == 1:
why_not_fast_path = "num_heads is odd"
elif torch.is_autocast_enabled():
@@ -322,20 +321,10 @@ class MultiheadAttention(Module):
# generator expressions.
if torch.overrides.has_torch_function(tensor_args):
why_not_fast_path = "some Tensor argument has_torch_function"
elif not all(
[
(x is None or x.is_cuda or "cpu" in str(x.device))
for x in tensor_args
]
):
elif not all([(x is None or x.is_cuda or "cpu" in str(x.device)) for x in tensor_args]):
why_not_fast_path = "some Tensor argument is neither CUDA nor CPU"
elif torch.is_grad_enabled() and any(
[x is not None and x.requires_grad for x in tensor_args]
):
why_not_fast_path = (
"grad is enabled and at least one of query or the "
"input/output projection weights or biases requires_grad"
)
elif torch.is_grad_enabled() and any([x is not None and x.requires_grad for x in tensor_args]):
why_not_fast_path = "grad is enabled and at least one of query or the input/output projection weights or biases requires_grad"
if not why_not_fast_path:
return torch._native_multi_head_attention(
query,
@@ -350,11 +339,7 @@ class MultiheadAttention(Module):
key_padding_mask if key_padding_mask is not None else attn_mask,
need_weights,
average_attn_weights,
1
if key_padding_mask is not None
else 0
if attn_mask is not None
else None,
1 if key_padding_mask is not None else 0 if attn_mask is not None else None,
)
any_nested = query.is_nested or key.is_nested or value.is_nested

View File

@@ -1,17 +1,13 @@
# modified from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/activation.py
from typing import Optional
from typing import Tuple
from typing import Optional, Tuple
import torch
from torch import Tensor
from torch.nn import Linear
from torch.nn import Module
from torch.nn.init import constant_
from torch.nn.init import xavier_normal_
from torch.nn.init import xavier_uniform_
from torch.nn import Linear, Module
from torch.nn.init import constant_, xavier_normal_, xavier_uniform_
from torch.nn.modules.linear import NonDynamicallyQuantizableLinear
from torch.nn.parameter import Parameter
from torch.nn import functional as F
from AR.modules.patched_mha_with_cache_onnx import multi_head_attention_forward_patched
@@ -47,9 +43,7 @@ class MultiheadAttention(Module):
self.dropout = dropout
self.batch_first = batch_first
self.head_dim = embed_dim // num_heads
assert (
self.head_dim * num_heads == self.embed_dim
), "embed_dim must be divisible by num_heads"
assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
if add_bias_kv:
self.bias_k = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
@@ -60,18 +54,30 @@ class MultiheadAttention(Module):
if linear1_cls == Linear:
if not self._qkv_same_embed_dim:
self.q_proj_weight = Parameter(
torch.empty((embed_dim, embed_dim), **factory_kwargs)
torch.empty(
(embed_dim, embed_dim),
**factory_kwargs,
)
)
self.k_proj_weight = Parameter(
torch.empty((embed_dim, self.kdim), **factory_kwargs)
torch.empty(
(embed_dim, self.kdim),
**factory_kwargs,
)
)
self.v_proj_weight = Parameter(
torch.empty((embed_dim, self.vdim), **factory_kwargs)
torch.empty(
(embed_dim, self.vdim),
**factory_kwargs,
)
)
self.register_parameter("in_proj_weight", None)
else:
self.in_proj_weight = Parameter(
torch.empty((3 * embed_dim, embed_dim), **factory_kwargs)
torch.empty(
(3 * embed_dim, embed_dim),
**factory_kwargs,
)
)
self.register_parameter("q_proj_weight", None)
self.register_parameter("k_proj_weight", None)
@@ -79,13 +85,11 @@ class MultiheadAttention(Module):
if bias:
self.in_proj_bias = Parameter(
torch.empty(3 * embed_dim, **factory_kwargs)
torch.empty(3 * embed_dim, **factory_kwargs),
)
else:
self.register_parameter("in_proj_bias", None)
self.out_proj = NonDynamicallyQuantizableLinear(
embed_dim, embed_dim, bias=bias, **factory_kwargs
)
self.out_proj = NonDynamicallyQuantizableLinear(embed_dim, embed_dim, bias=bias, **factory_kwargs)
self._reset_parameters()
else:
@@ -93,7 +97,10 @@ class MultiheadAttention(Module):
raise NotImplementedError
else:
self.in_proj_linear = linear1_cls(
embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs
embed_dim,
3 * embed_dim,
bias=bias,
**factory_kwargs,
)
self.in_proj_weight = self.in_proj_linear.weight
@@ -107,7 +114,10 @@ class MultiheadAttention(Module):
self.register_parameter("in_proj_bias", None)
self.out_proj = linear2_cls(
embed_dim, embed_dim, bias=bias, **factory_kwargs
embed_dim,
embed_dim,
bias=bias,
**factory_kwargs,
)
if self.bias_k is not None:

View File

@@ -60,14 +60,11 @@ class SinePositionalEmbedding(nn.Module):
return
pe = torch.zeros(x.size(1), self.embedding_dim)
if self.reverse:
position = torch.arange(
x.size(1) - 1, -1, -1.0, dtype=torch.float32
).unsqueeze(1)
position = torch.arange(x.size(1) - 1, -1, -1.0, dtype=torch.float32).unsqueeze(1)
else:
position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
div_term = torch.exp(
torch.arange(0, self.embedding_dim, 2, dtype=torch.float32)
* -(math.log(10000.0) / self.embedding_dim)
torch.arange(0, self.embedding_dim, 2, dtype=torch.float32) * -(math.log(10000.0) / self.embedding_dim)
)
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)

View File

@@ -50,7 +50,7 @@ class SinePositionalEmbedding(nn.Module):
self.div_term = torch.exp(torch.arange(0, self.embedding_dim, 2) * -(math.log(10000.0) / self.embedding_dim))
def extend_pe(self, x):
position = torch.cumsum(torch.ones_like(x[:,:,0]), dim=1).transpose(0, 1)
position = torch.cumsum(torch.ones_like(x[:, :, 0]), dim=1).transpose(0, 1)
scpe = (position * self.div_term).unsqueeze(0)
pe = torch.cat([torch.sin(scpe), torch.cos(scpe)]).permute(1, 2, 0)
pe = pe.contiguous().view(1, -1, self.embedding_dim)

View File

@@ -49,13 +49,9 @@ class WarmupCosineLRSchedule(torch.optim.lr_scheduler._LRScheduler):
lr = self.end_lr
else:
decay_ratio = (self._current_step - self.warmup_steps) / (
self.total_steps - self.warmup_steps
)
decay_ratio = (self._current_step - self.warmup_steps) / (self.total_steps - self.warmup_steps)
if decay_ratio < 0.0 or decay_ratio > 1.0:
raise RuntimeError(
"Decay ratio must be in [0.0, 1.0]. Fix LR scheduler settings."
)
raise RuntimeError("Decay ratio must be in [0.0, 1.0]. Fix LR scheduler settings.")
coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
lr = self.end_lr + coeff * (self.peak_lr - self.end_lr)
@@ -70,7 +66,13 @@ if __name__ == "__main__":
m = nn.Linear(10, 10)
opt = Adam(m.parameters(), lr=1e-4)
s = WarmupCosineLRSchedule(
opt, 1e-6, 2e-4, 1e-6, warmup_steps=2000, total_steps=20000, current_step=0
opt,
1e-6,
2e-4,
1e-6,
warmup_steps=2000,
total_steps=20000,
current_step=0,
)
lrs = []
for i in range(25000):

View File

@@ -16,8 +16,7 @@
import contextlib
import logging
from collections import defaultdict
from typing import List
from typing import Tuple
from typing import List, Tuple
import torch
from torch import Tensor
@@ -71,12 +70,8 @@ class BatchedOptimizer(Optimizer):
group_params_names: name for each parameter in group,
which is List[str].
"""
batches = defaultdict(
list
) # `batches` maps from tuple (dtype_as_str,*shape) to list of nn.Parameter
batches_names = defaultdict(
list
) # `batches` maps from tuple (dtype_as_str,*shape) to list of str
batches = defaultdict(list) # `batches` maps from tuple (dtype_as_str,*shape) to list of nn.Parameter
batches_names = defaultdict(list) # `batches` maps from tuple (dtype_as_str,*shape) to list of str
assert len(param_group) == len(group_params_names)
for p, named_p in zip(param_group, group_params_names):
@@ -85,11 +80,8 @@ class BatchedOptimizer(Optimizer):
batches_names[key].append(named_p)
batches_names_keys = list(batches_names.keys())
sorted_idx = sorted(
range(len(batches_names)), key=lambda i: batches_names_keys[i])
batches_names = [
batches_names[batches_names_keys[idx]] for idx in sorted_idx
]
sorted_idx = sorted(range(len(batches_names)), key=lambda i: batches_names_keys[i])
batches_names = [batches_names[batches_names_keys[idx]] for idx in sorted_idx]
batches = [batches[batches_names_keys[idx]] for idx in sorted_idx]
stacked_params_dict = dict()
@@ -106,16 +98,14 @@ class BatchedOptimizer(Optimizer):
# group. class Optimizer will take care of saving/loading state.
state = self.state[p]
p_stacked = torch.stack(batch)
grad = torch.stack([
torch.zeros_like(p) if p.grad is None else p.grad for p in batch
])
grad = torch.stack([torch.zeros_like(p) if p.grad is None else p.grad for p in batch])
p_stacked.grad = grad
stacked_params_dict[key] = p_stacked
tuples.append((p_stacked, state, batch_names))
yield tuples # <-- calling code will do the actual optimization here!
for ((stacked_params, _state, _names), batch) in zip(tuples, batches):
for (stacked_params, _state, _names), batch in zip(tuples, batches):
for i, p in enumerate(batch): # batch is list of Parameter
p.copy_(stacked_params[i])
@@ -164,25 +154,24 @@ class ScaledAdam(BatchedOptimizer):
"""
def __init__(
self,
params,
lr=3e-02,
clipping_scale=None,
betas=(0.9, 0.98),
scalar_lr_scale=0.1,
eps=1.0e-08,
param_min_rms=1.0e-05,
param_max_rms=3.0,
scalar_max=10.0,
size_update_period=4,
clipping_update_period=100,
parameters_names=None,
show_dominant_parameters=True, ):
self,
params,
lr=3e-02,
clipping_scale=None,
betas=(0.9, 0.98),
scalar_lr_scale=0.1,
eps=1.0e-08,
param_min_rms=1.0e-05,
param_max_rms=3.0,
scalar_max=10.0,
size_update_period=4,
clipping_update_period=100,
parameters_names=None,
show_dominant_parameters=True,
):
assert parameters_names is not None, (
"Please prepare parameters_names,"
"which is a List[List[str]]. Each List[str] is for a group"
"and each str is for a parameter")
"Please prepare parameters_names,which is a List[List[str]]. Each List[str] is for a groupand each str is for a parameter"
)
defaults = dict(
lr=lr,
clipping_scale=clipping_scale,
@@ -193,7 +182,8 @@ class ScaledAdam(BatchedOptimizer):
param_max_rms=param_max_rms,
scalar_max=scalar_max,
size_update_period=size_update_period,
clipping_update_period=clipping_update_period, )
clipping_update_period=clipping_update_period,
)
super(ScaledAdam, self).__init__(params, defaults)
assert len(self.param_groups) == len(parameters_names)
@@ -218,18 +208,13 @@ class ScaledAdam(BatchedOptimizer):
batch = True
for group, group_params_names in zip(self.param_groups,
self.parameters_names):
with self.batched_params(group["params"],
group_params_names) as batches:
for group, group_params_names in zip(self.param_groups, self.parameters_names):
with self.batched_params(group["params"], group_params_names) as batches:
# batches is list of pairs (stacked_param, state). stacked_param is like
# a regular parameter, and will have a .grad, but the 1st dim corresponds to
# a stacking dim, it is not a real dim.
if (len(batches[0][1]) ==
0): # if len(first state) == 0: not yet initialized
if len(batches[0][1]) == 0: # if len(first state) == 0: not yet initialized
clipping_scale = 1
else:
clipping_scale = self._get_clipping_scale(group, batches)
@@ -239,9 +224,7 @@ class ScaledAdam(BatchedOptimizer):
# grad is not going to be None, we handled that when creating the batches.
grad = p.grad
if grad.is_sparse:
raise RuntimeError(
"ScaledAdam optimizer does not support sparse gradients"
)
raise RuntimeError("ScaledAdam optimizer does not support sparse gradients")
# State initialization
if len(state) == 0:
self._init_state(group, p, state)
@@ -274,8 +257,7 @@ class ScaledAdam(BatchedOptimizer):
# parameter-change "delta", which combines all forms of
# update. this is equivalent to how it's done in Adam,
# except for the first few steps.
state["delta"] = torch.zeros_like(
p, memory_format=torch.preserve_format)
state["delta"] = torch.zeros_like(p, memory_format=torch.preserve_format)
batch_size = p.shape[0]
numel = p.numel() // batch_size
@@ -285,22 +267,16 @@ class ScaledAdam(BatchedOptimizer):
# "param_rms" just periodically records the scalar root-mean-square value of
# the parameter tensor.
# it has a shape like (batch_size, 1, 1, 1, 1)
param_rms = (
(p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt())
param_rms = (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt()
state["param_rms"] = param_rms
state["scale_exp_avg_sq"] = torch.zeros_like(param_rms)
state["scale_grads"] = torch.zeros(size_update_period,
*param_rms.shape, **kwargs)
state["scale_grads"] = torch.zeros(size_update_period, *param_rms.shape, **kwargs)
# exp_avg_sq is the weighted sum of scaled gradients. as in Adam.
state["exp_avg_sq"] = torch.zeros_like(
p, memory_format=torch.preserve_format)
state["exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format)
def _get_clipping_scale(self,
group: dict,
tuples: List[Tuple[Tensor, dict, List[str]]]
) -> float:
def _get_clipping_scale(self, group: dict, tuples: List[Tuple[Tensor, dict, List[str]]]) -> float:
"""
Returns a scalar factor <= 1.0 that dictates gradient clipping, i.e. we will scale the gradients
by this amount before applying the rest of the update.
@@ -325,20 +301,18 @@ class ScaledAdam(BatchedOptimizer):
clipping_update_period = group["clipping_update_period"]
tot_sumsq = torch.tensor(0.0, device=first_p.device)
for (p, state, param_names) in tuples:
for p, state, param_names in tuples:
grad = p.grad
if grad.is_sparse:
raise RuntimeError(
"ScaledAdam optimizer does not support sparse gradients")
raise RuntimeError("ScaledAdam optimizer does not support sparse gradients")
if p.numel() == p.shape[0]: # a batch of scalars
tot_sumsq += (grad**2).sum() # sum() to change shape [1] to []
else:
tot_sumsq += ((grad * state["param_rms"])**2).sum()
tot_sumsq += ((grad * state["param_rms"]) ** 2).sum()
tot_norm = tot_sumsq.sqrt()
if "model_norms" not in first_state:
first_state["model_norms"] = torch.zeros(
clipping_update_period, device=p.device)
first_state["model_norms"] = torch.zeros(clipping_update_period, device=p.device)
first_state["model_norms"][step % clipping_update_period] = tot_norm
if step % clipping_update_period == 0:
@@ -350,20 +324,20 @@ class ScaledAdam(BatchedOptimizer):
for n in range(0, 5):
index = min(
clipping_update_period - 1,
(clipping_update_period // 4) * n, )
(clipping_update_period // 4) * n,
)
quartiles.append(sorted_norms[index].item())
median = quartiles[2]
threshold = clipping_scale * median
first_state["model_norm_threshold"] = threshold
percent_clipped = (first_state["num_clipped"] * 100.0 /
clipping_update_period
if "num_clipped" in first_state else 0.0)
percent_clipped = (
first_state["num_clipped"] * 100.0 / clipping_update_period if "num_clipped" in first_state else 0.0
)
first_state["num_clipped"] = 0
quartiles = " ".join(["%.3e" % x for x in quartiles])
logging.info(
f"Clipping_scale={clipping_scale}, grad-norm quartiles {quartiles}, "
f"threshold={threshold:.3e}, percent-clipped={percent_clipped:.1f}"
f"Clipping_scale={clipping_scale}, grad-norm quartiles {quartiles}, threshold={threshold:.3e}, percent-clipped={percent_clipped:.1f}"
)
if step < clipping_update_period:
@@ -373,25 +347,20 @@ class ScaledAdam(BatchedOptimizer):
model_norm_threshold = first_state["model_norm_threshold"]
except KeyError:
logging.info(
"Warning: model_norm_threshold not in state: possibly "
"you changed config when restarting, adding clipping_scale option?"
"Warning: model_norm_threshold not in state: possibly you changed config when restarting, adding clipping_scale option?"
)
return 1.0
ans = min(1.0, (model_norm_threshold / (tot_norm + 1.0e-20)).item())
if ans < 1.0:
first_state["num_clipped"] += 1
if ans < 0.1:
logging.warn(
f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}"
)
logging.warn(f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}")
if self.show_dominant_parameters:
assert p.shape[0] == len(param_names)
self._show_gradient_dominating_parameter(tuples, tot_sumsq)
return ans
def _show_gradient_dominating_parameter(
self, tuples: List[Tuple[Tensor, dict, List[str]]],
tot_sumsq: Tensor):
def _show_gradient_dominating_parameter(self, tuples: List[Tuple[Tensor, dict, List[str]]], tot_sumsq: Tensor):
"""
Show information of parameter wihch dominanting tot_sumsq.
@@ -406,7 +375,7 @@ class ScaledAdam(BatchedOptimizer):
from tuples, we still pass it to save some time.
"""
all_sumsq_orig = {}
for (p, state, batch_param_names) in tuples:
for p, state, batch_param_names in tuples:
# p is a stacked batch parameters.
batch_grad = p.grad
if p.numel() == p.shape[0]: # a batch of scalars
@@ -415,41 +384,46 @@ class ScaledAdam(BatchedOptimizer):
batch_rms_orig = torch.ones(p.shape[0])
else:
batch_rms_orig = state["param_rms"]
batch_sumsq_orig = ((batch_grad * batch_rms_orig)**2).sum(
dim=list(range(1, batch_grad.ndim)))
for name, sumsq_orig, rms, grad in zip(batch_param_names,
batch_sumsq_orig,
batch_rms_orig, batch_grad):
batch_sumsq_orig = ((batch_grad * batch_rms_orig) ** 2).sum(dim=list(range(1, batch_grad.ndim)))
for name, sumsq_orig, rms, grad in zip(
batch_param_names,
batch_sumsq_orig,
batch_rms_orig,
batch_grad,
):
proportion_orig = sumsq_orig / tot_sumsq
all_sumsq_orig[name] = (proportion_orig, sumsq_orig, rms, grad)
assert torch.isclose(
sum([value[0] for value in all_sumsq_orig.values()]).cpu(),
torch.tensor(1.0), )
torch.tensor(1.0),
)
sorted_by_proportion = {
k: v
for k, v in sorted(
all_sumsq_orig.items(),
key=lambda item: item[1][0],
reverse=True, )
reverse=True,
)
}
dominant_param_name = next(iter(sorted_by_proportion))
(dominant_proportion, dominant_sumsq, dominant_rms,
dominant_grad, ) = sorted_by_proportion[dominant_param_name]
logging.info(f"Parameter Dominanting tot_sumsq {dominant_param_name}"
f" with proportion {dominant_proportion:.2f},"
f" where dominant_sumsq=(grad_sumsq*orig_rms_sq)"
f"={dominant_sumsq:.3e},"
f" grad_sumsq = {(dominant_grad**2).sum():.3e},"
f" orig_rms_sq={(dominant_rms**2).item():.3e}")
(
dominant_proportion,
dominant_sumsq,
dominant_rms,
dominant_grad,
) = sorted_by_proportion[dominant_param_name]
logging.info(
f"Parameter Dominanting tot_sumsq {dominant_param_name}"
f" with proportion {dominant_proportion:.2f},"
f" where dominant_sumsq=(grad_sumsq*orig_rms_sq)"
f"={dominant_sumsq:.3e},"
f" grad_sumsq = {(dominant_grad**2).sum():.3e},"
f" orig_rms_sq={(dominant_rms**2).item():.3e}"
)
def _step_one_batch(self,
group: dict,
p: Tensor,
state: dict,
clipping_scale: float):
def _step_one_batch(self, group: dict, p: Tensor, state: dict, clipping_scale: float):
"""
Do the step for one parameter, which is actually going to be a batch of
`real` parameters, with dim 0 as the batch dim.
@@ -475,13 +449,10 @@ class ScaledAdam(BatchedOptimizer):
if numel > 1:
# Update the size/scale of p, and set param_rms
scale_grads = state["scale_grads"]
scale_grads[step % size_update_period] = (p * grad).sum(
dim=list(range(1, p.ndim)), keepdim=True)
scale_grads[step % size_update_period] = (p * grad).sum(dim=list(range(1, p.ndim)), keepdim=True)
if step % size_update_period == size_update_period - 1:
param_rms = state["param_rms"] # shape: (batch_size, 1, 1, ..)
param_rms.copy_((p**2)
.mean(dim=list(range(1, p.ndim)), keepdim=True)
.sqrt())
param_rms.copy_((p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt())
if step > 0:
# self._size_update() learns the overall scale on the
# parameter, by shrinking or expanding it.
@@ -496,11 +467,13 @@ class ScaledAdam(BatchedOptimizer):
state["step"] = step + 1
def _size_update(self,
group: dict,
scale_grads: Tensor,
p: Tensor,
state: dict) -> None:
def _size_update(
self,
group: dict,
scale_grads: Tensor,
p: Tensor,
state: dict,
) -> None:
"""
Called only where p.numel() > 1, this updates the scale of the parameter.
If we imagine: p = underlying_param * scale.exp(), and we are doing
@@ -529,11 +502,11 @@ class ScaledAdam(BatchedOptimizer):
# faster decay at this level.
beta2_corr = beta2**size_update_period
scale_exp_avg_sq = state[
"scale_exp_avg_sq"] # shape: (batch_size, 1, 1, ..)
scale_exp_avg_sq = state["scale_exp_avg_sq"] # shape: (batch_size, 1, 1, ..)
scale_exp_avg_sq.mul_(beta2_corr).add_(
(scale_grads**2).mean(dim=0), # mean over dim `size_update_period`
alpha=1 - beta2_corr, ) # shape is (batch_size, 1, 1, ...)
alpha=1 - beta2_corr,
) # shape is (batch_size, 1, 1, ...)
# The 1st time we reach here is when size_step == 1.
size_step = (step + 1) // size_update_period
@@ -543,8 +516,7 @@ class ScaledAdam(BatchedOptimizer):
denom = scale_exp_avg_sq.sqrt() + eps
scale_step = (-size_lr * (bias_correction2**0.5) *
scale_grads.sum(dim=0) / denom)
scale_step = -size_lr * (bias_correction2**0.5) * scale_grads.sum(dim=0) / denom
is_too_small = param_rms < param_min_rms
is_too_large = param_rms > param_max_rms
@@ -580,9 +552,8 @@ class ScaledAdam(BatchedOptimizer):
exp_avg_sq = state["exp_avg_sq"]
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=(1 - beta2))
this_step = state["step"] - (state["zero_step"]
if "zero_step" in state else 0)
bias_correction2 = 1 - beta2**(this_step + 1)
this_step = state["step"] - (state["zero_step"] if "zero_step" in state else 0)
bias_correction2 = 1 - beta2 ** (this_step + 1)
if bias_correction2 < 0.99:
# note: not in-place.
exp_avg_sq = exp_avg_sq * (1.0 / bias_correction2)
@@ -613,7 +584,7 @@ class ScaledAdam(BatchedOptimizer):
# bias_correction2 is like in Adam. Don't bother with bias_correction1;
# slower update at the start will help stability anyway.
bias_correction2 = 1 - beta2**(state["step"] + 1)
bias_correction2 = 1 - beta2 ** (state["step"] + 1)
denom = (exp_avg_sq / bias_correction2).sqrt() + eps
delta = state["delta"]

View File

@@ -5,7 +5,6 @@ from torch.nn.functional import (
_none_or_dtype,
_in_projection_packed,
)
from torch.nn import functional as F
import torch
# Tensor = torch.Tensor
# from typing import Callable, List, Optional, Tuple, Union
@@ -25,18 +24,18 @@ def multi_head_attention_forward_patched(
dropout_p: float,
out_proj_weight,
out_proj_bias,
training = True,
key_padding_mask = None,
need_weights = True,
attn_mask = None,
use_separate_proj_weight = False,
q_proj_weight = None,
k_proj_weight = None,
v_proj_weight = None,
static_k = None,
static_v = None,
average_attn_weights = True,
is_causal = False,
training=True,
key_padding_mask=None,
need_weights=True,
attn_mask=None,
use_separate_proj_weight=False,
q_proj_weight=None,
k_proj_weight=None,
v_proj_weight=None,
static_k=None,
static_v=None,
average_attn_weights=True,
is_causal=False,
cache=None,
):
r"""
@@ -156,9 +155,7 @@ def multi_head_attention_forward_patched(
cache=cache,
)
is_batched = _mha_shape_check(
query, key, value, key_padding_mask, attn_mask, num_heads
)
is_batched = _mha_shape_check(query, key, value, key_padding_mask, attn_mask, num_heads)
# For unbatched input, we unsqueeze at the expected batch-dim to pretend that the input
# is batched, run the computation and before returning squeeze the
@@ -211,45 +208,33 @@ def multi_head_attention_forward_patched(
# longer causal.
is_causal = False
assert (
embed_dim == embed_dim_to_check
), f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}"
assert embed_dim == embed_dim_to_check, (
f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}"
)
if isinstance(embed_dim, torch.Tensor):
# embed_dim can be a tensor when JIT tracing
head_dim = embed_dim.div(num_heads, rounding_mode="trunc")
else:
head_dim = embed_dim // num_heads
assert (
head_dim * num_heads == embed_dim
), f"embed_dim {embed_dim} not divisible by num_heads {num_heads}"
assert head_dim * num_heads == embed_dim, f"embed_dim {embed_dim} not divisible by num_heads {num_heads}"
if use_separate_proj_weight:
# allow MHA to have different embedding dimensions when separate projection weights are used
assert (
key.shape[:2] == value.shape[:2]
), f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}"
assert key.shape[:2] == value.shape[:2], (
f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}"
)
else:
assert (
key.shape == value.shape
), f"key shape {key.shape} does not match value shape {value.shape}"
assert key.shape == value.shape, f"key shape {key.shape} does not match value shape {value.shape}"
#
# compute in-projection
#
if not use_separate_proj_weight:
assert (
in_proj_weight is not None
), "use_separate_proj_weight is False but in_proj_weight is None"
assert in_proj_weight is not None, "use_separate_proj_weight is False but in_proj_weight is None"
q, k, v = _in_projection_packed(query, key, value, in_proj_weight, in_proj_bias)
else:
assert (
q_proj_weight is not None
), "use_separate_proj_weight is True but q_proj_weight is None"
assert (
k_proj_weight is not None
), "use_separate_proj_weight is True but k_proj_weight is None"
assert (
v_proj_weight is not None
), "use_separate_proj_weight is True but v_proj_weight is None"
assert q_proj_weight is not None, "use_separate_proj_weight is True but q_proj_weight is None"
assert k_proj_weight is not None, "use_separate_proj_weight is True but k_proj_weight is None"
assert v_proj_weight is not None, "use_separate_proj_weight is True but v_proj_weight is None"
if in_proj_bias is None:
b_q = b_k = b_v = None
else:
@@ -312,9 +297,7 @@ def multi_head_attention_forward_patched(
f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}."
)
else:
raise RuntimeError(
f"attn_mask's dimension {attn_mask.dim()} is not supported"
)
raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported")
# add bias along batch dimension (currently second)
if bias_k is not None and bias_v is not None:
@@ -338,34 +321,26 @@ def multi_head_attention_forward_patched(
k = k.view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
else:
# TODO finish disentangling control flow so we don't do in-projections when statics are passed
assert (
static_k.size(0) == bsz * num_heads
), f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.size(0)}"
assert (
static_k.size(2) == head_dim
), f"expecting static_k.size(2) of {head_dim}, but got {static_k.size(2)}"
assert static_k.size(0) == bsz * num_heads, (
f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.size(0)}"
)
assert static_k.size(2) == head_dim, f"expecting static_k.size(2) of {head_dim}, but got {static_k.size(2)}"
k = static_k
if static_v is None:
v = v.view(v.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
else:
# TODO finish disentangling control flow so we don't do in-projections when statics are passed
assert (
static_v.size(0) == bsz * num_heads
), f"expecting static_v.size(0) of {bsz * num_heads}, but got {static_v.size(0)}"
assert (
static_v.size(2) == head_dim
), f"expecting static_v.size(2) of {head_dim}, but got {static_v.size(2)}"
assert static_v.size(0) == bsz * num_heads, (
f"expecting static_v.size(0) of {bsz * num_heads}, but got {static_v.size(0)}"
)
assert static_v.size(2) == head_dim, f"expecting static_v.size(2) of {head_dim}, but got {static_v.size(2)}"
v = static_v
# add zero attention along batch dimension (now first)
if add_zero_attn:
zero_attn_shape = (bsz * num_heads, 1, head_dim)
k = torch.cat(
[k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=1
)
v = torch.cat(
[v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=1
)
k = torch.cat([k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=1)
v = torch.cat([v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=1)
if attn_mask is not None:
attn_mask = pad(attn_mask, (0, 1))
if key_padding_mask is not None:
@@ -381,9 +356,7 @@ def multi_head_attention_forward_patched(
src_len,
), f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}"
key_padding_mask = (
key_padding_mask.view(bsz, 1, 1, src_len)
.expand(-1, num_heads, -1, -1)
.reshape(bsz * num_heads, 1, src_len)
key_padding_mask.view(bsz, 1, 1, src_len).expand(-1, num_heads, -1, -1).reshape(bsz * num_heads, 1, src_len)
)
if attn_mask is None:
attn_mask = key_padding_mask
@@ -402,14 +375,10 @@ def multi_head_attention_forward_patched(
B, Nt, E = q.shape
q_scaled = q / math.sqrt(E)
assert not (
is_causal and attn_mask is None
), "FIXME: is_causal not implemented for need_weights"
assert not (is_causal and attn_mask is None), "FIXME: is_causal not implemented for need_weights"
if attn_mask is not None:
attn_output_weights = torch.baddbmm(
attn_mask, q_scaled, k.transpose(-2, -1)
)
attn_output_weights = torch.baddbmm(attn_mask, q_scaled, k.transpose(-2, -1))
else:
attn_output_weights = torch.bmm(q_scaled, k.transpose(-2, -1))
attn_output_weights = softmax(attn_output_weights, dim=-1)
@@ -418,9 +387,7 @@ def multi_head_attention_forward_patched(
attn_output = torch.bmm(attn_output_weights, v)
attn_output = (
attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim)
)
attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim)
attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
@@ -449,13 +416,9 @@ def multi_head_attention_forward_patched(
v = v.view(bsz, num_heads, src_len, head_dim)
# with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=True):
attn_output = scaled_dot_product_attention(
q, k, v, attn_mask, dropout_p, is_causal
)
attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)
attn_output = (
attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim)
)
attn_output = attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim)
attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))

View File

@@ -1,11 +1,9 @@
from torch.nn.functional import *
from torch.nn.functional import (
_mha_shape_check,
_canonical_mask,
_none_or_dtype,
_in_projection_packed,
)
def multi_head_attention_forward_patched(
query,
key,
@@ -34,7 +32,6 @@ def multi_head_attention_forward_patched(
is_causal: bool = False,
cache=None,
) -> Tuple[Tensor, Optional[Tensor]]:
# set up shape vars
_, _, embed_dim = query.shape
attn_mask = _canonical_mask(
@@ -80,12 +77,8 @@ def multi_head_attention_forward_patched(
q = q.view(num_heads, -1, head_dim).unsqueeze(0)
k = k.view(num_heads, -1, head_dim).unsqueeze(0)
v = v.view(num_heads, -1, head_dim).unsqueeze(0)
attn_output = scaled_dot_product_attention(
q, k, v, attn_mask, dropout_p, is_causal
)
attn_output = (
attn_output.permute(2, 0, 1, 3).contiguous().view(-1, embed_dim)
)
attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)
attn_output = attn_output.permute(2, 0, 1, 3).contiguous().view(-1, embed_dim)
attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
attn_output = attn_output.view(-1, 1, attn_output.size(1))

View File

@@ -13,12 +13,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import math
import random
from typing import Optional
from typing import Tuple
from typing import Union
import torch
import torch.nn as nn
@@ -61,9 +58,7 @@ class DoubleSwishFunction(torch.autograd.Function):
# floors), should be expectation-preserving.
floor = -0.043637
ceil = 1.2
d_scaled = (deriv - floor) * (255.0 / (ceil - floor)) + torch.rand_like(
deriv
)
d_scaled = (deriv - floor) * (255.0 / (ceil - floor)) + torch.rand_like(deriv)
if __name__ == "__main__":
# for self-testing only.
assert d_scaled.min() >= 0.0
@@ -153,13 +148,9 @@ def _compute_scale_factor(
else:
# below_threshold is 0 if x_abs_mean > min_abs, can be at most max_factor if
# x_abs)_mean , min_abs.
below_threshold = ((min_abs - x_abs_mean) * (gain_factor / min_abs)).clamp(
min=0, max=max_factor
)
below_threshold = ((min_abs - x_abs_mean) * (gain_factor / min_abs)).clamp(min=0, max=max_factor)
above_threshold = ((x_abs_mean - max_abs) * (gain_factor / max_abs)).clamp(
min=0, max=max_factor
)
above_threshold = ((x_abs_mean - max_abs) * (gain_factor / max_abs)).clamp(min=0, max=max_factor)
return below_threshold - above_threshold
@@ -181,18 +172,16 @@ def _compute_sign_factor(
else:
# 0 if proportion_positive >= min_positive, else can be
# as large as max_factor.
factor1 = (
(min_positive - proportion_positive) * (gain_factor / min_positive)
).clamp_(min=0, max=max_factor)
factor1 = ((min_positive - proportion_positive) * (gain_factor / min_positive)).clamp_(min=0, max=max_factor)
if max_positive == 1.0:
factor2 = 0.0
else:
# 0 if self.proportion_positive <= max_positive, else can be
# as large as -max_factor.
factor2 = (
(proportion_positive - max_positive) * (gain_factor / (1.0 - max_positive))
).clamp_(min=0, max=max_factor)
factor2 = ((proportion_positive - max_positive) * (gain_factor / (1.0 - max_positive))).clamp_(
min=0, max=max_factor
)
sign_factor = factor1 - factor2
# require min_positive != 0 or max_positive != 1:
assert not isinstance(sign_factor, float)
@@ -320,15 +309,11 @@ class ActivationBalancer(torch.nn.Module):
return _no_op(x)
def BalancedDoubleSwish(
d_model, channel_dim=-1, max_abs=10.0, min_prob=0.25
) -> nn.Sequential:
def BalancedDoubleSwish(d_model, channel_dim=-1, max_abs=10.0, min_prob=0.25) -> nn.Sequential:
"""
ActivationBalancer -> DoubleSwish
"""
balancer = ActivationBalancer(
d_model, channel_dim=channel_dim, max_abs=max_abs, min_prob=min_prob
)
balancer = ActivationBalancer(d_model, channel_dim=channel_dim, max_abs=max_abs, min_prob=min_prob)
return nn.Sequential(
balancer,
DoubleSwish(),

View File

@@ -42,12 +42,8 @@ class LayerNorm(nn.Module):
self.eps = eps
self.elementwise_affine = elementwise_affine
if self.elementwise_affine:
self.weight = nn.Parameter(
torch.empty(self.normalized_shape, **factory_kwargs)
)
self.bias = nn.Parameter(
torch.empty(self.normalized_shape, **factory_kwargs)
)
self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
self.bias = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
else:
self.register_parameter("weight", None)
self.register_parameter("bias", None)
@@ -74,15 +70,10 @@ class LayerNorm(nn.Module):
)
assert embedding is None
return F.layer_norm(
input, self.normalized_shape, self.weight, self.bias, self.eps
)
return F.layer_norm(input, self.normalized_shape, self.weight, self.bias, self.eps)
def extra_repr(self) -> str:
return (
"{normalized_shape}, eps={eps}, "
"elementwise_affine={elementwise_affine}".format(**self.__dict__)
)
return "{normalized_shape}, eps={eps}, elementwise_affine={elementwise_affine}".format(**self.__dict__)
class IdentityNorm(nn.Module):
@@ -121,6 +112,7 @@ class TransformerEncoder(nn.Module):
>>> src = torch.rand(10, 32, 512)
>>> out = transformer_encoder(src)
"""
__constants__ = ["norm"]
def __init__(self, encoder_layer, num_layers, norm=None):
@@ -218,13 +210,9 @@ class TransformerEncoderLayer(nn.Module):
)
# Implementation of Feedforward model
self.linear1 = linear1_feedforward_cls(
d_model, dim_feedforward, **factory_kwargs
)
self.linear1 = linear1_feedforward_cls(d_model, dim_feedforward, **factory_kwargs)
self.dropout = nn.Dropout(dropout)
self.linear2 = linear2_feedforward_cls(
dim_feedforward, d_model, **factory_kwargs
)
self.linear2 = linear2_feedforward_cls(dim_feedforward, d_model, **factory_kwargs)
self.norm_first = norm_first
self.dropout1 = nn.Dropout(dropout)
@@ -291,12 +279,8 @@ class TransformerEncoderLayer(nn.Module):
if src_key_padding_mask is not None:
_skpm_dtype = src_key_padding_mask.dtype
if _skpm_dtype != torch.bool and not torch.is_floating_point(
src_key_padding_mask
):
raise AssertionError(
"only bool and floating types of key_padding_mask are supported"
)
if _skpm_dtype != torch.bool and not torch.is_floating_point(src_key_padding_mask):
raise AssertionError("only bool and floating types of key_padding_mask are supported")
if self.norm_first:
x = x + self._sa_block(

View File

@@ -42,12 +42,8 @@ class LayerNorm(nn.Module):
self.eps = eps
self.elementwise_affine = elementwise_affine
if self.elementwise_affine:
self.weight = nn.Parameter(
torch.empty(self.normalized_shape, **factory_kwargs)
)
self.bias = nn.Parameter(
torch.empty(self.normalized_shape, **factory_kwargs)
)
self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
self.bias = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
else:
self.register_parameter("weight", None)
self.register_parameter("bias", None)
@@ -74,15 +70,10 @@ class LayerNorm(nn.Module):
)
assert embedding is None
return F.layer_norm(
input, self.normalized_shape, self.weight, self.bias, self.eps
)
return F.layer_norm(input, self.normalized_shape, self.weight, self.bias, self.eps)
def extra_repr(self) -> str:
return (
"{normalized_shape}, eps={eps}, "
"elementwise_affine={elementwise_affine}".format(**self.__dict__)
)
return "{normalized_shape}, eps={eps}, elementwise_affine={elementwise_affine}".format(**self.__dict__)
class IdentityNorm(nn.Module):
@@ -121,6 +112,7 @@ class TransformerEncoder(nn.Module):
>>> src = torch.rand(10, 32, 512)
>>> out = transformer_encoder(src)
"""
__constants__ = ["norm"]
def __init__(self, encoder_layer, num_layers, norm=None):
@@ -154,6 +146,7 @@ class TransformerEncoder(nn.Module):
class TransformerEncoderLayer(nn.Module):
__constants__ = ["batch_first", "norm_first"]
def __init__(
self,
d_model: int,
@@ -184,13 +177,9 @@ class TransformerEncoderLayer(nn.Module):
linear2_cls=linear2_self_attention_cls,
**factory_kwargs,
)
self.linear1 = linear1_feedforward_cls(
d_model, dim_feedforward, **factory_kwargs
)
self.linear1 = linear1_feedforward_cls(d_model, dim_feedforward, **factory_kwargs)
self.dropout = nn.Dropout(dropout)
self.linear2 = linear2_feedforward_cls(
dim_feedforward, d_model, **factory_kwargs
)
self.linear2 = linear2_feedforward_cls(dim_feedforward, d_model, **factory_kwargs)
self.norm_first = norm_first
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)

View File

@@ -30,9 +30,7 @@ class GruutPhonemizer:
"«": "«",
"»": "»",
}
self._punctuation_regexp: str = (
rf"([{''.join(self._special_cases_dict.keys())}])"
)
self._punctuation_regexp: str = rf"([{''.join(self._special_cases_dict.keys())}])"
def _normalize_punctuation(self, text: str) -> str:
text = regex.sub(rf"\pZ+{self._punctuation_regexp}", r"\1", text)
@@ -53,13 +51,8 @@ class GruutPhonemizer:
def phonemize(self, text: str, espeak: bool = False) -> str:
text_to_phonemize: str = self._normalize_punctuation(text)
sents: List[Sentence] = [
sent
for sent in self._phonemizer(text_to_phonemize, lang="en-us", espeak=espeak)
]
words: List[str] = [
self._convert_punctuation(word) for word in itertools.chain(*sents)
]
sents: List[Sentence] = [sent for sent in self._phonemizer(text_to_phonemize, lang="en-us", espeak=espeak)]
words: List[str] = [self._convert_punctuation(word) for word in itertools.chain(*sents)]
return " ".join(words)
def transform(self, phonemes):

View File

@@ -3,7 +3,9 @@
PAD = "_"
PUNCTUATION = ';:,.!?¡¿—…"«»“” '
LETTERS = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
IPA_LETTERS = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'"
IPA_LETTERS = (
"ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'"
)
SYMBOLS = [PAD] + list(PUNCTUATION) + list(LETTERS) + list(IPA_LETTERS)
SPACE_ID = SYMBOLS.index(" ")
SYMBOL_TO_ID = {s: i for i, s in enumerate(SYMBOLS)}

View File

@@ -2,12 +2,12 @@ import re
def str2bool(str):
return True if str.lower() == 'true' else False
return True if str.lower() == "true" else False
def get_newest_ckpt(string_list):
# 定义一个正则表达式模式,用于匹配字符串中的数字
pattern = r'epoch=(\d+)-step=(\d+)\.ckpt'
pattern = r"epoch=(\d+)-step=(\d+)\.ckpt"
# 使用正则表达式提取每个字符串中的数字信息,并创建一个包含元组的列表
extracted_info = []
@@ -18,8 +18,7 @@ def get_newest_ckpt(string_list):
step = int(match.group(2))
extracted_info.append((epoch, step, string))
# 按照 epoch 后面的数字和 step 后面的数字进行排序
sorted_info = sorted(
extracted_info, key=lambda x: (x[0], x[1]), reverse=True)
sorted_info = sorted(extracted_info, key=lambda x: (x[0], x[1]), reverse=True)
# 获取最新的 ckpt 文件名
newest_ckpt = sorted_info[0][2]
return newest_ckpt
@@ -28,9 +27,9 @@ def get_newest_ckpt(string_list):
# 文本存在且不为空时 return True
def check_txt_file(file_path):
try:
with open(file_path, 'r') as file:
with open(file_path, "r") as file:
text = file.readline().strip()
assert text.strip() != ''
assert text.strip() != ""
return text
except Exception:
return False

View File

@@ -1,5 +1,6 @@
#!/usr/bin/env python3
"""Initialize modules for espnet2 neural networks."""
import torch
from typeguard import check_argument_types

View File

@@ -18,14 +18,10 @@ def save_config_to_yaml(config, path):
def write_args(args, path):
args_dict = dict(
(name, getattr(args, name)) for name in dir(args) if not name.startswith("_")
)
args_dict = dict((name, getattr(args, name)) for name in dir(args) if not name.startswith("_"))
with open(path, "a") as args_file:
args_file.write("==> torch version: {}\n".format(torch.__version__))
args_file.write(
"==> cudnn version: {}\n".format(torch.backends.cudnn.version())
)
args_file.write("==> cudnn version: {}\n".format(torch.backends.cudnn.version()))
args_file.write("==> Cmd:\n")
args_file.write(str(sys.argv))
args_file.write("\n==> args:\n")