Refactor: Format Code with Ruff and Update Deprecated G2PW Link (#2255)

* ruff check --fix

* ruff format --line-length 120 --target-version py39

* Change the link for G2PW Model

* update pytorch version and colab
This commit is contained in:
XXXXRT666
2025-04-07 09:42:47 +01:00
committed by GitHub
parent 9da7e17efe
commit 53cac93589
132 changed files with 8185 additions and 6648 deletions

View File

@@ -1,6 +1,7 @@
# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_lightning_module.py
# reference: https://github.com/lifeiteng/vall-e
import os, sys
import os
import sys
now_dir = os.getcwd()
sys.path.append(now_dir)
@@ -8,10 +9,12 @@ from typing import Dict
import torch
from pytorch_lightning import LightningModule
from AR.models.t2s_model import Text2SemanticDecoder
from AR.modules.lr_schedulers import WarmupCosineLRSchedule
from AR.modules.optim import ScaledAdam
class Text2SemanticLightningModule(LightningModule):
def __init__(self, config, output_dir, is_train=True):
super().__init__()
@@ -23,7 +26,10 @@ class Text2SemanticLightningModule(LightningModule):
# print(self.load_state_dict(torch.load(pretrained_s1,map_location="cpu")["state_dict"]))
print(
self.load_state_dict(
torch.load(pretrained_s1, map_location="cpu")["weight"]
torch.load(
pretrained_s1,
map_location="cpu",
)["weight"],
)
)
if is_train:
@@ -35,7 +41,7 @@ class Text2SemanticLightningModule(LightningModule):
def training_step(self, batch: Dict, batch_idx: int):
opt = self.optimizers()
scheduler = self.lr_schedulers()
forward=self.model.forward if self.config["train"].get("if_dpo",False)==True else self.model.forward_old
forward = self.model.forward if self.config["train"].get("if_dpo", False) == True else self.model.forward_old
loss, acc = forward(
batch["phoneme_ids"],
batch["phoneme_ids_len"],
@@ -113,9 +119,7 @@ class Text2SemanticLightningModule(LightningModule):
def configure_optimizers(self):
model_parameters = self.model.parameters()
parameters_names = []
parameters_names.append(
[name_param_pair[0] for name_param_pair in self.model.named_parameters()]
)
parameters_names.append([name_param_pair[0] for name_param_pair in self.model.named_parameters()])
lm_opt = ScaledAdam(
model_parameters,
lr=0.01,

View File

@@ -1,6 +1,7 @@
# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_lightning_module.py
# reference: https://github.com/lifeiteng/vall-e
import os, sys
import os
import sys
now_dir = os.getcwd()
sys.path.append(now_dir)
@@ -8,6 +9,7 @@ from typing import Dict
import torch
from pytorch_lightning import LightningModule
from AR.models.t2s_model_onnx import Text2SemanticDecoder
from AR.modules.lr_schedulers import WarmupCosineLRSchedule
from AR.modules.optim import ScaledAdam
@@ -24,8 +26,11 @@ class Text2SemanticLightningModule(LightningModule):
# print(self.load_state_dict(torch.load(pretrained_s1,map_location="cpu")["state_dict"]))
print(
self.load_state_dict(
torch.load(pretrained_s1, map_location="cpu")["weight"]
)
torch.load(
pretrained_s1,
map_location="cpu",
)["weight"],
),
)
if is_train:
self.automatic_optimization = False
@@ -79,9 +84,7 @@ class Text2SemanticLightningModule(LightningModule):
def configure_optimizers(self):
model_parameters = self.model.parameters()
parameters_names = []
parameters_names.append(
[name_param_pair[0] for name_param_pair in self.model.named_parameters()]
)
parameters_names.append([name_param_pair[0] for name_param_pair in self.model.named_parameters()])
lm_opt = ScaledAdam(
model_parameters,
lr=0.01,

View File

@@ -2,27 +2,24 @@
# reference: https://github.com/lifeiteng/vall-e
import math
from typing import List, Optional
import torch
from tqdm import tqdm
from AR.models.utils import make_pad_mask, make_pad_mask_left
from AR.models.utils import (
topk_sampling,
sample,
logits_to_probs,
multinomial_sample_one_no_sync,
dpo_loss,
make_reject_y,
get_batch_logps
)
from AR.modules.embedding import SinePositionalEmbedding
from AR.modules.embedding import TokenEmbedding
from AR.modules.transformer import LayerNorm
from AR.modules.transformer import TransformerEncoder
from AR.modules.transformer import TransformerEncoderLayer
import torch
from torch import nn
from torch.nn import functional as F
from torchmetrics.classification import MulticlassAccuracy
from tqdm import tqdm
from AR.models.utils import (
dpo_loss,
get_batch_logps,
make_pad_mask,
make_pad_mask_left,
make_reject_y,
sample,
topk_sampling,
)
from AR.modules.embedding import SinePositionalEmbedding, TokenEmbedding
from AR.modules.transformer import LayerNorm, TransformerEncoder, TransformerEncoderLayer
default_config = {
"embedding_dim": 512,
@@ -36,10 +33,17 @@ default_config = {
"EOS": 1024,
}
# @torch.jit.script ## 使用的话首次推理会非常慢,而且推理速度不稳定
# Efficient implementation equivalent to the following:
def scaled_dot_product_attention(query:torch.Tensor, key:torch.Tensor, value:torch.Tensor, attn_mask:Optional[torch.Tensor]=None, scale:Optional[torch.Tensor]=None) -> torch.Tensor:
B, H, L, S =query.size(0), query.size(1), query.size(-2), key.size(-2)
def scaled_dot_product_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
B, H, L, S = query.size(0), query.size(1), query.size(-2), key.size(-2)
if scale is None:
scale_factor = torch.tensor(1 / math.sqrt(query.size(-1)))
else:
@@ -59,12 +63,13 @@ def scaled_dot_product_attention(query:torch.Tensor, key:torch.Tensor, value:tor
if attn_mask.dtype == torch.bool:
attn_weight.masked_fill_(attn_mask, 0)
else:
attn_mask[attn_mask!=float("-inf")] =0
attn_mask[attn_mask==float("-inf")] =1
attn_mask[attn_mask != float("-inf")] = 0
attn_mask[attn_mask == float("-inf")] = 1
attn_weight.masked_fill_(attn_mask, 0)
return attn_weight @ value
@torch.jit.script
class T2SMLP:
def __init__(self, w1, b1, w2, b2):
@@ -82,20 +87,20 @@ class T2SMLP:
@torch.jit.script
class T2SBlock:
def __init__(
self,
num_heads,
hidden_dim: int,
mlp: T2SMLP,
qkv_w,
qkv_b,
out_w,
out_b,
norm_w1,
norm_b1,
norm_eps1,
norm_w2,
norm_b2,
norm_eps2,
self,
num_heads,
hidden_dim: int,
mlp: T2SMLP,
qkv_w,
qkv_b,
out_w,
out_b,
norm_w1,
norm_b1,
norm_eps1,
norm_w2,
norm_b2,
norm_eps2,
):
self.num_heads = num_heads
self.mlp = mlp
@@ -114,24 +119,32 @@ class T2SBlock:
self.false = torch.tensor(False, dtype=torch.bool)
@torch.jit.ignore
def to_mask(self, x:torch.Tensor, padding_mask:Optional[torch.Tensor]):
def to_mask(
self,
x: torch.Tensor,
padding_mask: Optional[torch.Tensor],
):
if padding_mask is None:
return x
if padding_mask.dtype == torch.bool:
return x.masked_fill(padding_mask, 0)
else:
return x * padding_mask
def process_prompt(self, x:torch.Tensor, attn_mask : torch.Tensor, padding_mask:Optional[torch.Tensor]=None, torch_sdpa:bool=True):
def process_prompt(
self,
x: torch.Tensor,
attn_mask: torch.Tensor,
padding_mask: Optional[torch.Tensor] = None,
torch_sdpa: bool = True,
):
q, k, v = F.linear(self.to_mask(x, padding_mask), self.qkv_w, self.qkv_b).chunk(3, dim=-1)
batch_size = q.shape[0]
q_len = q.shape[1]
kv_len = k.shape[1]
q = self.to_mask(q, padding_mask)
k_cache = self.to_mask(k, padding_mask)
v_cache = self.to_mask(v, padding_mask)
@@ -149,9 +162,7 @@ class T2SBlock:
attn = F.linear(self.to_mask(attn, padding_mask), self.out_w, self.out_b)
x = x + attn
x = F.layer_norm(
x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1
)
x = F.layer_norm(x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1)
x = x + self.mlp.forward(x)
x = F.layer_norm(
x,
@@ -161,13 +172,20 @@ class T2SBlock:
self.norm_eps2,
)
return x, k_cache, v_cache
def decode_next_token(self, x:torch.Tensor, k_cache:torch.Tensor, v_cache:torch.Tensor, attn_mask:torch.Tensor=None, torch_sdpa:bool=True):
def decode_next_token(
self,
x: torch.Tensor,
k_cache: torch.Tensor,
v_cache: torch.Tensor,
attn_mask: torch.Tensor = None,
torch_sdpa: bool = True,
):
q, k, v = F.linear(x, self.qkv_w, self.qkv_b).chunk(3, dim=-1)
k_cache = torch.cat([k_cache, k], dim=1)
v_cache = torch.cat([v_cache, v], dim=1)
batch_size = q.shape[0]
q_len = q.shape[1]
kv_len = k_cache.shape[1]
@@ -176,7 +194,6 @@ class T2SBlock:
k = k_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
v = v_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
if torch_sdpa:
attn = F.scaled_dot_product_attention(q, k, v, (~attn_mask) if attn_mask is not None else None)
else:
@@ -187,7 +204,11 @@ class T2SBlock:
x = x + attn
x = F.layer_norm(
x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1
x,
[self.hidden_dim],
self.norm_w1,
self.norm_b1,
self.norm_eps1,
)
x = x + self.mlp.forward(x)
x = F.layer_norm(
@@ -202,17 +223,19 @@ class T2SBlock:
@torch.jit.script
class T2STransformer:
def __init__(self, num_blocks : int, blocks: List[T2SBlock]):
self.num_blocks : int = num_blocks
def __init__(self, num_blocks: int, blocks: List[T2SBlock]):
self.num_blocks: int = num_blocks
self.blocks = blocks
def process_prompt(
self, x:torch.Tensor, attn_mask : torch.Tensor,
padding_mask : Optional[torch.Tensor]=None,
torch_sdpa:bool=True
):
k_cache : List[torch.Tensor] = []
v_cache : List[torch.Tensor] = []
self,
x: torch.Tensor,
attn_mask: torch.Tensor,
padding_mask: Optional[torch.Tensor] = None,
torch_sdpa: bool = True,
):
k_cache: List[torch.Tensor] = []
v_cache: List[torch.Tensor] = []
for i in range(self.num_blocks):
x, k_cache_, v_cache_ = self.blocks[i].process_prompt(x, attn_mask, padding_mask, torch_sdpa)
k_cache.append(k_cache_)
@@ -220,14 +243,17 @@ class T2STransformer:
return x, k_cache, v_cache
def decode_next_token(
self, x:torch.Tensor,
k_cache: List[torch.Tensor],
v_cache: List[torch.Tensor],
attn_mask : torch.Tensor=None,
torch_sdpa:bool=True
self,
x: torch.Tensor,
k_cache: List[torch.Tensor],
v_cache: List[torch.Tensor],
attn_mask: torch.Tensor = None,
torch_sdpa: bool = True,
):
for i in range(self.num_blocks):
x, k_cache[i], v_cache[i] = self.blocks[i].decode_next_token(x, k_cache[i], v_cache[i], attn_mask, torch_sdpa)
x, k_cache[i], v_cache[i] = self.blocks[i].decode_next_token(
x, k_cache[i], v_cache[i], attn_mask, torch_sdpa
)
return x, k_cache, v_cache
@@ -249,16 +275,26 @@ class Text2SemanticDecoder(nn.Module):
# assert self.EOS == 1024
self.bert_proj = nn.Linear(1024, self.embedding_dim)
self.ar_text_embedding = TokenEmbedding(
self.embedding_dim, self.phoneme_vocab_size, self.p_dropout
self.embedding_dim,
self.phoneme_vocab_size,
self.p_dropout,
)
self.ar_text_position = SinePositionalEmbedding(
self.embedding_dim, dropout=0.1, scale=False, alpha=True
self.embedding_dim,
dropout=0.1,
scale=False,
alpha=True,
)
self.ar_audio_embedding = TokenEmbedding(
self.embedding_dim, self.vocab_size, self.p_dropout
self.embedding_dim,
self.vocab_size,
self.p_dropout,
)
self.ar_audio_position = SinePositionalEmbedding(
self.embedding_dim, dropout=0.1, scale=False, alpha=True
self.embedding_dim,
dropout=0.1,
scale=False,
alpha=True,
)
self.h = TransformerEncoder(
@@ -293,7 +329,7 @@ class Text2SemanticDecoder(nn.Module):
layer.linear1.weight,
layer.linear1.bias,
layer.linear2.weight,
layer.linear2.bias
layer.linear2.bias,
)
block = T2SBlock(
@@ -309,11 +345,11 @@ class Text2SemanticDecoder(nn.Module):
layer.norm1.eps,
layer.norm2.weight,
layer.norm2.bias,
layer.norm2.eps
layer.norm2.eps,
)
blocks.append(block)
self.t2s_transformer = T2STransformer(self.num_layers, blocks)
def make_input_data(self, x, x_lens, y, y_lens, bert_feature):
@@ -387,7 +423,9 @@ class Text2SemanticDecoder(nn.Module):
logits = self.ar_predict_layer(xy_dec[:, x_len:])
###### DPO #############
reject_xy_pos, reject_xy_attn_mask, reject_targets = self.make_input_data(x, x_lens, reject_y, reject_y_lens, bert_feature)
reject_xy_pos, reject_xy_attn_mask, reject_targets = self.make_input_data(
x, x_lens, reject_y, reject_y_lens, bert_feature
)
reject_xy_dec, _ = self.h(
(reject_xy_pos, None),
@@ -404,7 +442,7 @@ class Text2SemanticDecoder(nn.Module):
A_logits, R_logits = get_batch_logps(logits, reject_logits, targets, reject_targets)
loss_2, _, _ = dpo_loss(A_logits, R_logits, 0, 0, 0.2, reference_free=True)
loss = loss_1 + loss_2
return loss, acc
@@ -473,14 +511,14 @@ class Text2SemanticDecoder(nn.Module):
# 需要看下这个函数和 forward 的区别以及没有 semantic 的时候 prompts 输入什么
def infer(
self,
x,
x_lens,
prompts,
bert_feature,
top_k: int = -100,
early_stop_num: int = -1,
temperature: float = 1.0,
self,
x,
x_lens,
prompts,
bert_feature,
top_k: int = -100,
early_stop_num: int = -1,
temperature: float = 1.0,
):
x = self.ar_text_embedding(x)
x = x + self.bert_proj(bert_feature.transpose(1, 2))
@@ -508,18 +546,14 @@ class Text2SemanticDecoder(nn.Module):
(x_len, 0),
value=False,
)
xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0).to(
y.device
)
xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0).to(y.device)
xy_dec, _ = self.h(
(xy_pos, None),
mask=xy_attn_mask,
)
logits = self.ar_predict_layer(xy_dec[:, -1])
samples = topk_sampling(
logits, top_k=top_k, top_p=1.0, temperature=temperature
)
samples = topk_sampling(logits, top_k=top_k, top_p=1.0, temperature=temperature)
if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
print("use early stop num:", early_stop_num)
@@ -542,18 +576,16 @@ class Text2SemanticDecoder(nn.Module):
return y
def pad_y_eos(self, y, y_mask_int, eos_id):
targets = F.pad(y, (0, 1), value=0) + eos_id * F.pad(
y_mask_int, (0, 1), value=1
)
targets = F.pad(y, (0, 1), value=0) + eos_id * F.pad(y_mask_int, (0, 1), value=1)
# 错位
return targets[:, :-1], targets[:, 1:]
def infer_panel_batch_infer(
self,
x:List[torch.LongTensor], #####全部文本token
x_lens:torch.LongTensor,
prompts:torch.LongTensor, ####参考音频token
bert_feature:List[torch.LongTensor],
x: List[torch.LongTensor], #####全部文本token
x_lens: torch.LongTensor,
prompts: torch.LongTensor, ####参考音频token
bert_feature: List[torch.LongTensor],
top_k: int = -100,
top_p: int = 100,
early_stop_num: int = -1,
@@ -563,10 +595,19 @@ class Text2SemanticDecoder(nn.Module):
):
if prompts is None:
print("Warning: Prompt free is not supported batch_infer! switch to naive_infer")
return self.infer_panel_naive_batched(x, x_lens, prompts, bert_feature, top_k=top_k, top_p=top_p, early_stop_num=early_stop_num, temperature=temperature, **kwargs)
return self.infer_panel_naive_batched(
x,
x_lens,
prompts,
bert_feature,
top_k=top_k,
top_p=top_p,
early_stop_num=early_stop_num,
temperature=temperature,
**kwargs,
)
max_len = kwargs.get("max_len",x_lens.max())
max_len = kwargs.get("max_len", x_lens.max())
x_list = []
for x_item, bert_item in zip(x, bert_feature):
# max_len = max(max_len, x_item.shape[0], bert_item.shape[1])
@@ -574,14 +615,15 @@ class Text2SemanticDecoder(nn.Module):
x_item = x_item + self.bert_proj(bert_item.transpose(0, 1).unsqueeze(0))
x_item = self.ar_text_position(x_item).squeeze(0)
# x_item = F.pad(x_item,(0,0,0,max_len-x_item.shape[0]),value=0) if x_item.shape[0]<max_len else x_item ### padding right
x_item = F.pad(x_item,(0,0,max_len-x_item.shape[0],0),value=0) if x_item.shape[0]<max_len else x_item ### padding left
x_item = (
F.pad(x_item, (0, 0, max_len - x_item.shape[0], 0), value=0) if x_item.shape[0] < max_len else x_item
) ### padding left
x_list.append(x_item)
x:torch.Tensor = torch.stack(x_list, dim=0)
x: torch.Tensor = torch.stack(x_list, dim=0)
# AR Decoder
y = prompts
x_len = x.shape[1]
stop = False
@@ -594,34 +636,32 @@ class Text2SemanticDecoder(nn.Module):
y_emb = self.ar_audio_embedding(y)
y_len = y_emb.shape[1]
prefix_len = y.shape[1]
y_lens = torch.LongTensor([y_emb.shape[1]]*y_emb.shape[0]).to(x.device)
y_lens = torch.LongTensor([y_emb.shape[1]] * y_emb.shape[0]).to(x.device)
y_pos = self.ar_audio_position(y_emb)
xy_pos = torch.concat([x, y_pos], dim=1)
##### create mask #####
bsz = x.shape[0]
src_len = x_len + y_len
y_paddind_mask = make_pad_mask_left(y_lens, y_len)
x_paddind_mask = make_pad_mask_left(x_lens, max_len)
# (bsz, x_len + y_len)
padding_mask = torch.concat([x_paddind_mask, y_paddind_mask], dim=1)
x_mask = F.pad(
torch.zeros(x_len, x_len, dtype=torch.bool, device=x.device),
(0, y_len),
value=True,
)
x_mask = F.pad(
torch.zeros(x_len, x_len, dtype=torch.bool, device=x.device),
(0, y_len),
value=True,
)
y_mask = F.pad( ###yy的右上1扩展到左边xy的0,(y,x+y)
torch.triu(torch.ones(y_len, y_len, dtype=torch.bool, device=x.device), diagonal=1),
torch.triu(torch.ones(y_len, y_len, dtype=torch.bool, device=x.device), diagonal=1),
(x_len, 0),
value=False,
)
causal_mask = torch.concat([x_mask, y_mask], dim=0).view(1 , src_len, src_len).repeat(bsz, 1, 1).to(x.device)
causal_mask = torch.concat([x_mask, y_mask], dim=0).view(1, src_len, src_len).repeat(bsz, 1, 1).to(x.device)
# padding_mask = padding_mask.unsqueeze(1) * padding_mask.unsqueeze(2) ### [b, x+y, x+y]
### 上面是错误的会导致padding的token被"看见"
@@ -639,10 +679,9 @@ class Text2SemanticDecoder(nn.Module):
padding_mask = padding_mask.view(bsz, 1, src_len).repeat(1, src_len, 1)
attn_mask:torch.Tensor = causal_mask.logical_or(padding_mask)
attn_mask: torch.Tensor = causal_mask.logical_or(padding_mask)
attn_mask = attn_mask.unsqueeze(1).expand(-1, self.num_head, -1, -1).bool()
# 正确的attn_mask应该是这样的
# | pad_len | x_len | y_len |
# [[PAD, PAD, PAD, 1, 2, 3, EOS, EOS, EOS],
@@ -655,74 +694,69 @@ class Text2SemanticDecoder(nn.Module):
# [PAD, PAD, PAD, 1, 2, 3, 4, 5, EOS],
# [PAD, PAD, PAD, 1, 2, 3, 4, 5, 6]]
###### decode #####
y_list = [None]*y.shape[0]
y_list = [None] * y.shape[0]
batch_idx_map = list(range(y.shape[0]))
idx_list = [None]*y.shape[0]
idx_list = [None] * y.shape[0]
for idx in tqdm(range(1500)):
if idx == 0:
xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, attn_mask, None)
else:
xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(xy_pos, k_cache, v_cache, attn_mask)
logits = self.ar_predict_layer(
xy_dec[:, -1]
)
logits = self.ar_predict_layer(xy_dec[:, -1])
if idx == 0:
attn_mask = F.pad(attn_mask[:,:,-1].unsqueeze(-2),(0,1),value=False)
attn_mask = F.pad(attn_mask[:, :, -1].unsqueeze(-2), (0, 1), value=False)
logits = logits[:, :-1]
else:
attn_mask = F.pad(attn_mask,(0,1),value=False)
attn_mask = F.pad(attn_mask, (0, 1), value=False)
samples = sample(
logits, y, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature
)[0]
logits, y, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature
)[0]
y = torch.concat([y, samples], dim=1)
####### 移除batch中已经生成完毕的序列,进一步优化计算量
tokens = torch.argmax(logits, dim=-1)
reserved_idx_of_batch_for_y = None
if (self.EOS in samples[:, 0]) or \
(self.EOS in tokens): ###如果生成到EOS则停止
l1 = samples[:, 0]==self.EOS
l2 = tokens==self.EOS
l = l1.logical_or(l2)
removed_idx_of_batch_for_y = torch.where(l==True)[0].tolist()
reserved_idx_of_batch_for_y = torch.where(l==False)[0]
# batch_indexs = torch.tensor(batch_idx_map, device=y.device)[removed_idx_of_batch_for_y]
for i in removed_idx_of_batch_for_y:
batch_index = batch_idx_map[i]
idx_list[batch_index] = idx
y_list[batch_index] = y[i, :-1]
batch_idx_map = [batch_idx_map[i] for i in reserved_idx_of_batch_for_y.tolist()]
# 只保留batch中未生成完毕的序列
if (self.EOS in samples[:, 0]) or (self.EOS in tokens): ###如果生成到EOS则停止
l1 = samples[:, 0] == self.EOS
l2 = tokens == self.EOS
l = l1.logical_or(l2)
removed_idx_of_batch_for_y = torch.where(l == True)[0].tolist()
reserved_idx_of_batch_for_y = torch.where(l == False)[0]
# batch_indexs = torch.tensor(batch_idx_map, device=y.device)[removed_idx_of_batch_for_y]
for i in removed_idx_of_batch_for_y:
batch_index = batch_idx_map[i]
idx_list[batch_index] = idx
y_list[batch_index] = y[i, :-1]
batch_idx_map = [batch_idx_map[i] for i in reserved_idx_of_batch_for_y.tolist()]
# 只保留batch中未生成完毕的序列
if reserved_idx_of_batch_for_y is not None:
# index = torch.LongTensor(batch_idx_map).to(y.device)
y = torch.index_select(y, dim=0, index=reserved_idx_of_batch_for_y)
attn_mask = torch.index_select(attn_mask, dim=0, index=reserved_idx_of_batch_for_y)
if k_cache is not None :
if k_cache is not None:
for i in range(len(k_cache)):
k_cache[i] = torch.index_select(k_cache[i], dim=0, index=reserved_idx_of_batch_for_y)
v_cache[i] = torch.index_select(v_cache[i], dim=0, index=reserved_idx_of_batch_for_y)
if (early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num) or idx==1499:
if (early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num) or idx == 1499:
print("use early stop num:", early_stop_num)
stop = True
for i, batch_index in enumerate(batch_idx_map):
batch_index = batch_idx_map[i]
idx_list[batch_index] = idx
y_list[batch_index] = y[i, :-1]
if not (None in idx_list):
if None not in idx_list:
stop = True
if stop:
if y.shape[1]==0:
if y.shape[1] == 0:
y = torch.concat([y, torch.zeros_like(samples)], dim=1)
print("bad zero prediction")
print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]")
@@ -730,60 +764,65 @@ class Text2SemanticDecoder(nn.Module):
####################### update next step ###################################
y_emb = self.ar_audio_embedding(y[:, -1:])
xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[:, y_len + idx].to( dtype= y_emb.dtype,device=y_emb.device)
xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[
:, y_len + idx
].to(dtype=y_emb.dtype, device=y_emb.device)
if (None in idx_list):
if None in idx_list:
for i in range(x.shape[0]):
if idx_list[i] is None:
idx_list[i] = 1500-1 ###如果没有生成到EOS就用最大长度代替
idx_list[i] = 1500 - 1 ###如果没有生成到EOS就用最大长度代替
if ref_free:
return y_list, [0]*x.shape[0]
return y_list, [0] * x.shape[0]
# print(idx_list)
return y_list, idx_list
def infer_panel_naive_batched(self,
x:List[torch.LongTensor], #####全部文本token
x_lens:torch.LongTensor,
prompts:torch.LongTensor, ####参考音频token
bert_feature:List[torch.LongTensor],
def infer_panel_naive_batched(
self,
x: List[torch.LongTensor], #####全部文本token
x_lens: torch.LongTensor,
prompts: torch.LongTensor, ####参考音频token
bert_feature: List[torch.LongTensor],
top_k: int = -100,
top_p: int = 100,
early_stop_num: int = -1,
temperature: float = 1.0,
repetition_penalty: float = 1.35,
**kwargs
):
**kwargs,
):
y_list = []
idx_list = []
for i in range(len(x)):
y, idx = self.infer_panel_naive(x[i].unsqueeze(0),
x_lens[i],
prompts[i].unsqueeze(0) if prompts is not None else None,
bert_feature[i].unsqueeze(0),
top_k,
top_p,
early_stop_num,
temperature,
repetition_penalty,
**kwargs)
y, idx = self.infer_panel_naive(
x[i].unsqueeze(0),
x_lens[i],
prompts[i].unsqueeze(0) if prompts is not None else None,
bert_feature[i].unsqueeze(0),
top_k,
top_p,
early_stop_num,
temperature,
repetition_penalty,
**kwargs,
)
y_list.append(y[0])
idx_list.append(idx)
return y_list, idx_list
def infer_panel_naive(
self,
x:torch.LongTensor, #####全部文本token
x_lens:torch.LongTensor,
prompts:torch.LongTensor, ####参考音频token
bert_feature:torch.LongTensor,
x: torch.LongTensor, #####全部文本token
x_lens: torch.LongTensor,
prompts: torch.LongTensor, ####参考音频token
bert_feature: torch.LongTensor,
top_k: int = -100,
top_p: int = 100,
early_stop_num: int = -1,
temperature: float = 1.0,
repetition_penalty: float = 1.35,
**kwargs
**kwargs,
):
x = self.ar_text_embedding(x)
x = x + self.bert_proj(bert_feature.transpose(1, 2))
@@ -828,11 +867,13 @@ class Text2SemanticDecoder(nn.Module):
(x_len, 0),
value=False,
)
xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0)\
.unsqueeze(0)\
.expand(bsz*self.num_head, -1, -1)\
.view(bsz, self.num_head, src_len, src_len)\
.to(device=x.device, dtype=torch.bool)
xy_attn_mask = (
torch.concat([x_attn_mask_pad, y_attn_mask], dim=0)
.unsqueeze(0)
.expand(bsz * self.num_head, -1, -1)
.view(bsz, self.num_head, src_len, src_len)
.to(device=x.device, dtype=torch.bool)
)
for idx in tqdm(range(1500)):
if xy_attn_mask is not None:
@@ -840,13 +881,11 @@ class Text2SemanticDecoder(nn.Module):
else:
xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(xy_pos, k_cache, v_cache)
logits = self.ar_predict_layer(
xy_dec[:, -1]
)
logits = self.ar_predict_layer(xy_dec[:, -1])
if idx == 0:
xy_attn_mask = None
if(idx<11):###至少预测出10个token不然不给停止0.4s
if idx < 11: ###至少预测出10个token不然不给停止0.4s
logits = logits[:, :-1]
samples = sample(
@@ -870,24 +909,27 @@ class Text2SemanticDecoder(nn.Module):
####################### update next step ###################################
y_emb = self.ar_audio_embedding(y[:, -1:])
xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[:, y_len + idx].to(dtype=y_emb.dtype,device=y_emb.device)
xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[
:, y_len + idx
].to(dtype=y_emb.dtype, device=y_emb.device)
if ref_free:
return y[:, :-1], 0
return y[:, :-1], idx
def infer_panel(
self,
x:torch.LongTensor, #####全部文本token
x_lens:torch.LongTensor,
prompts:torch.LongTensor, ####参考音频token
bert_feature:torch.LongTensor,
x: torch.LongTensor, #####全部文本token
x_lens: torch.LongTensor,
prompts: torch.LongTensor, ####参考音频token
bert_feature: torch.LongTensor,
top_k: int = -100,
top_p: int = 100,
early_stop_num: int = -1,
temperature: float = 1.0,
repetition_penalty: float = 1.35,
**kwargs
**kwargs,
):
return self.infer_panel_naive(x, x_lens, prompts, bert_feature, top_k, top_p, early_stop_num, temperature, repetition_penalty, **kwargs)
return self.infer_panel_naive(
x, x_lens, prompts, bert_feature, top_k, top_p, early_stop_num, temperature, repetition_penalty, **kwargs
)

View File

@@ -1,17 +1,13 @@
# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_model.py
# reference: https://github.com/lifeiteng/vall-e
import torch
from tqdm import tqdm
from AR.modules.embedding_onnx import SinePositionalEmbedding
from AR.modules.embedding_onnx import TokenEmbedding
from AR.modules.transformer_onnx import LayerNorm
from AR.modules.transformer_onnx import TransformerEncoder
from AR.modules.transformer_onnx import TransformerEncoderLayer
from torch import nn
from torch.nn import functional as F
from torchmetrics.classification import MulticlassAccuracy
from AR.modules.embedding_onnx import SinePositionalEmbedding, TokenEmbedding
from AR.modules.transformer_onnx import LayerNorm, TransformerEncoder, TransformerEncoderLayer
default_config = {
"embedding_dim": 512,
"hidden_dim": 512,
@@ -26,12 +22,13 @@ default_config = {
inf_tensor_value = torch.FloatTensor([-float("Inf")]).float()
def logits_to_probs(
logits,
previous_tokens = None,
previous_tokens=None,
temperature: float = 1.0,
top_k = None,
top_p = None,
top_k=None,
top_p=None,
repetition_penalty: float = 1.0,
):
previous_tokens = previous_tokens.squeeze()
@@ -39,19 +36,27 @@ def logits_to_probs(
previous_tokens = previous_tokens.long()
score = torch.gather(logits, dim=0, index=previous_tokens)
score = torch.where(
score < 0, score * repetition_penalty, score / repetition_penalty
score < 0,
score * repetition_penalty,
score / repetition_penalty,
)
logits.scatter_(dim=0, index=previous_tokens, src=score)
if top_p is not None and top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cum_probs = torch.cumsum(
torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1
torch.nn.functional.softmax(
sorted_logits,
dim=-1,
),
dim=-1,
)
sorted_indices_to_remove = cum_probs > top_p
sorted_indices_to_remove[0] = False # keep at least one option
indices_to_remove = sorted_indices_to_remove.scatter(
dim=0, index=sorted_indices, src=sorted_indices_to_remove
dim=0,
index=sorted_indices,
src=sorted_indices_to_remove,
)
logits = logits.masked_fill(indices_to_remove, -float("Inf"))
@@ -67,7 +72,7 @@ def logits_to_probs(
def multinomial_sample_one_no_sync(
probs_sort
probs_sort,
): # Does multinomial sampling without a cuda synchronization
q = torch.randn_like(probs_sort)
return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
@@ -79,7 +84,9 @@ def sample(
**sampling_kwargs,
):
probs = logits_to_probs(
logits=logits, previous_tokens=previous_tokens, **sampling_kwargs
logits=logits,
previous_tokens=previous_tokens,
**sampling_kwargs,
)
idx_next = multinomial_sample_one_no_sync(probs)
return idx_next, probs
@@ -91,7 +98,7 @@ class OnnxEncoder(nn.Module):
self.ar_text_embedding = ar_text_embedding
self.bert_proj = bert_proj
self.ar_text_position = ar_text_position
def forward(self, x, bert_feature):
x = self.ar_text_embedding(x)
x = x + self.bert_proj(bert_feature.transpose(1, 2))
@@ -99,8 +106,18 @@ class OnnxEncoder(nn.Module):
class T2SFirstStageDecoder(nn.Module):
def __init__(self, ar_audio_embedding, ar_audio_position, h, ar_predict_layer, loss_fct, ar_accuracy_metric,
top_k, early_stop_num, num_layers):
def __init__(
self,
ar_audio_embedding,
ar_audio_position,
h,
ar_predict_layer,
loss_fct,
ar_accuracy_metric,
top_k,
early_stop_num,
num_layers,
):
super().__init__()
self.ar_audio_embedding = ar_audio_embedding
self.ar_audio_position = ar_audio_position
@@ -111,11 +128,11 @@ class T2SFirstStageDecoder(nn.Module):
self.top_k = top_k
self.early_stop_num = early_stop_num
self.num_layers = num_layers
def forward(self, x, prompt):
y = prompt
x_example = x[:,:,0] * 0.0
#N, 1, 512
x_example = x[:, :, 0] * 0.0
# N, 1, 512
cache = {
"all_stage": self.num_layers,
"k": None,
@@ -132,11 +149,15 @@ class T2SFirstStageDecoder(nn.Module):
xy_pos = torch.concat([x, y_pos], dim=1)
y_example = y_pos[:,:,0] * 0.0
x_attn_mask = torch.matmul(x_example.transpose(0, 1) , x_example).bool()
y_example = y_pos[:, :, 0] * 0.0
x_attn_mask = torch.matmul(x_example.transpose(0, 1), x_example).bool()
y_attn_mask = torch.ones_like(torch.matmul(y_example.transpose(0, 1), y_example), dtype=torch.int64)
y_attn_mask = torch.cumsum(y_attn_mask, dim=1) - torch.cumsum(
torch.ones_like(y_example.transpose(0, 1), dtype=torch.int64), dim=0
torch.ones_like(
y_example.transpose(0, 1),
dtype=torch.int64,
),
dim=0,
)
y_attn_mask = y_attn_mask > 0
@@ -145,10 +166,16 @@ class T2SFirstStageDecoder(nn.Module):
x_attn_mask_pad = torch.cat([x_attn_mask, torch.ones_like(x_y_pad)], dim=1)
y_attn_mask = torch.cat([y_x_pad, y_attn_mask], dim=1)
xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0)
cache["k"] = torch.matmul(x_attn_mask_pad[0].float().unsqueeze(-1), torch.zeros((1, 512)))\
.unsqueeze(1).repeat(self.num_layers, 1, 1, 1)
cache["v"] = torch.matmul(x_attn_mask_pad[0].float().unsqueeze(-1), torch.zeros((1, 512)))\
.unsqueeze(1).repeat(self.num_layers, 1, 1, 1)
cache["k"] = (
torch.matmul(x_attn_mask_pad[0].float().unsqueeze(-1), torch.zeros((1, 512)))
.unsqueeze(1)
.repeat(self.num_layers, 1, 1, 1)
)
cache["v"] = (
torch.matmul(x_attn_mask_pad[0].float().unsqueeze(-1), torch.zeros((1, 512)))
.unsqueeze(1)
.repeat(self.num_layers, 1, 1, 1)
)
xy_dec = self.h(xy_pos, mask=xy_attn_mask, cache=cache)
logits = self.ar_predict_layer(xy_dec[:, -1])
@@ -160,8 +187,18 @@ class T2SFirstStageDecoder(nn.Module):
class T2SStageDecoder(nn.Module):
def __init__(self, ar_audio_embedding, ar_audio_position, h, ar_predict_layer, loss_fct, ar_accuracy_metric,
top_k, early_stop_num, num_layers):
def __init__(
self,
ar_audio_embedding,
ar_audio_position,
h,
ar_predict_layer,
loss_fct,
ar_accuracy_metric,
top_k,
early_stop_num,
num_layers,
):
super().__init__()
self.ar_audio_embedding = ar_audio_embedding
self.ar_audio_position = ar_audio_position
@@ -184,14 +221,18 @@ class T2SStageDecoder(nn.Module):
}
y_emb = torch.cat(
[cache["y_emb"], self.ar_audio_embedding(y[:, -1:])], 1
[
cache["y_emb"],
self.ar_audio_embedding(y[:, -1:]),
],
1,
)
cache["y_emb"] = y_emb
y_pos = self.ar_audio_position(y_emb)
xy_pos = y_pos[:, -1:]
y_example = y_pos[:,:,0] * 0.0
y_example = y_pos[:, :, 0] * 0.0
xy_attn_mask = torch.cat([x_example, y_example], dim=1)
xy_attn_mask = torch.zeros_like(xy_attn_mask, dtype=torch.bool)
@@ -250,12 +291,28 @@ class Text2SemanticDecoder(nn.Module):
def init_onnx(self):
self.onnx_encoder = OnnxEncoder(self.ar_text_embedding, self.bert_proj, self.ar_text_position)
self.first_stage_decoder = T2SFirstStageDecoder(self.ar_audio_embedding, self.ar_audio_position, self.h,
self.ar_predict_layer, self.loss_fct, self.ar_accuracy_metric, self.top_k, self.early_stop_num,
self.num_layers)
self.stage_decoder = T2SStageDecoder(self.ar_audio_embedding, self.ar_audio_position, self.h,
self.ar_predict_layer, self.loss_fct, self.ar_accuracy_metric, self.top_k, self.early_stop_num,
self.num_layers)
self.first_stage_decoder = T2SFirstStageDecoder(
self.ar_audio_embedding,
self.ar_audio_position,
self.h,
self.ar_predict_layer,
self.loss_fct,
self.ar_accuracy_metric,
self.top_k,
self.early_stop_num,
self.num_layers,
)
self.stage_decoder = T2SStageDecoder(
self.ar_audio_embedding,
self.ar_audio_position,
self.h,
self.ar_predict_layer,
self.loss_fct,
self.ar_accuracy_metric,
self.top_k,
self.early_stop_num,
self.num_layers,
)
def forward(self, x, prompts, bert_feature):
early_stop_num = self.early_stop_num
@@ -286,7 +343,7 @@ class Text2SemanticDecoder(nn.Module):
y = prompts
prefix_len = y.shape[1]
x_len = x.shape[1]
x_example = x[:,:,0] * 0.0
x_example = x[:, :, 0] * 0.0
x_attn_mask = torch.matmul(x_example.transpose(0, 1), x_example)
x_attn_mask = torch.zeros_like(x_attn_mask, dtype=torch.bool)
@@ -303,9 +360,7 @@ class Text2SemanticDecoder(nn.Module):
if cache["first_infer"] == 1:
y_emb = self.ar_audio_embedding(y)
else:
y_emb = torch.cat(
[cache["y_emb"], self.ar_audio_embedding(y[:, -1:])], 1
)
y_emb = torch.cat([cache["y_emb"], self.ar_audio_embedding(y[:, -1:])], 1)
cache["y_emb"] = y_emb
y_pos = self.ar_audio_position(y_emb)
if cache["first_infer"] == 1:
@@ -317,7 +372,8 @@ class Text2SemanticDecoder(nn.Module):
x_attn_mask_pad = F.pad(x_attn_mask, (0, y_len), value=True)
y_attn_mask = F.pad(
torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
(x_len, 0), value=False
(x_len, 0),
value=False,
)
xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0)
else:
@@ -335,4 +391,4 @@ class Text2SemanticDecoder(nn.Module):
break
y = torch.concat([y, samples], dim=1)
cache["first_infer"] = 0
return y, idx
return y, idx

View File

@@ -1,8 +1,10 @@
# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/utils.py
# reference: https://github.com/lifeiteng/vall-e
from typing import Tuple
import torch
import torch.nn.functional as F
from typing import Tuple
def sequence_mask(length, max_length=None):
if max_length is None:
@@ -67,14 +69,18 @@ def make_pad_mask_left(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
n = lengths.size(0)
seq_range = torch.arange(0, max_len, device=lengths.device)
expaned_lengths = seq_range.unsqueeze(0).repeat(n, 1)
expaned_lengths -= (max_len-lengths).unsqueeze(-1)
expaned_lengths -= (max_len - lengths).unsqueeze(-1)
return expaned_lengths<0
return expaned_lengths < 0
# https://github.com/microsoft/unilm/blob/master/xtune/src/transformers/modeling_utils.py
def top_k_top_p_filtering(
logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1
logits,
top_k=0,
top_p=1.0,
filter_value=-float("Inf"),
min_tokens_to_keep=1,
):
"""Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
Args:
@@ -105,9 +111,7 @@ def top_k_top_p_filtering(
sorted_indices_to_remove[..., 0] = 0
# scatter sorted tensors to original indexing
indices_to_remove = sorted_indices_to_remove.scatter(
1, sorted_indices, sorted_indices_to_remove
)
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
logits[indices_to_remove] = filter_value
return logits
@@ -130,7 +134,7 @@ def topk_sampling(logits, top_k=10, top_p=1.0, temperature=1.0):
return token
from typing import Optional, Tuple
from typing import Optional
def multinomial_sample_one_no_sync(
@@ -156,19 +160,21 @@ def logits_to_probs(
previous_tokens = previous_tokens.long()
score = torch.gather(logits, dim=1, index=previous_tokens)
score = torch.where(
score < 0, score * repetition_penalty, score / repetition_penalty
score < 0,
score * repetition_penalty,
score / repetition_penalty,
)
logits.scatter_(dim=1, index=previous_tokens, src=score)
if top_p is not None and top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cum_probs = torch.cumsum(
torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1
)
cum_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
sorted_indices_to_remove = cum_probs > top_p
sorted_indices_to_remove[:, 0] = False # keep at least one option
indices_to_remove = sorted_indices_to_remove.scatter(
dim=1, index=sorted_indices, src=sorted_indices_to_remove
dim=1,
index=sorted_indices,
src=sorted_indices_to_remove,
)
logits = logits.masked_fill(indices_to_remove, -float("Inf"))
@@ -176,7 +182,7 @@ def logits_to_probs(
if top_k is not None:
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
pivot = v[: , -1].unsqueeze(-1)
pivot = v[:, -1].unsqueeze(-1)
logits = torch.where(logits < pivot, -float("Inf"), logits)
probs = torch.nn.functional.softmax(logits, dim=-1)
@@ -188,18 +194,19 @@ def sample(
previous_tokens: Optional[torch.Tensor] = None,
**sampling_kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
probs = logits_to_probs(
logits=logits, previous_tokens=previous_tokens, **sampling_kwargs
)
probs = logits_to_probs(logits=logits, previous_tokens=previous_tokens, **sampling_kwargs)
idx_next = multinomial_sample_one_no_sync(probs)
return idx_next, probs
def dpo_loss(policy_chosen_logps: torch.FloatTensor,
policy_rejected_logps: torch.FloatTensor,
reference_chosen_logps: torch.FloatTensor,
reference_rejected_logps: torch.FloatTensor,
beta: float,
reference_free: bool = False) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
def dpo_loss(
policy_chosen_logps: torch.FloatTensor,
policy_rejected_logps: torch.FloatTensor,
reference_chosen_logps: torch.FloatTensor,
reference_rejected_logps: torch.FloatTensor,
beta: float,
reference_free: bool = False,
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
pi_logratios = policy_chosen_logps - policy_rejected_logps
ref_logratios = reference_chosen_logps - reference_rejected_logps
@@ -214,40 +221,53 @@ def dpo_loss(policy_chosen_logps: torch.FloatTensor,
return losses.mean(), chosen_rewards, rejected_rewards
def get_batch_logps(logits_target: torch.FloatTensor, logits_reject: torch.FloatTensor, labels_target: torch.LongTensor, labels_reject: torch.LongTensor, average_log_prob: bool = False) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
def get_batch_logps(
logits_target: torch.FloatTensor,
logits_reject: torch.FloatTensor,
labels_target: torch.LongTensor,
labels_reject: torch.LongTensor,
average_log_prob: bool = False,
) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
# dummy token; we'll ignore the losses on these tokens later
per_token_logps_target = torch.gather(logits_target.log_softmax(-1), dim=2, index=labels_target.unsqueeze(2)).squeeze(2)
per_token_logps_reject = torch.gather(logits_reject.log_softmax(-1), dim=2, index=labels_reject.unsqueeze(2)).squeeze(2)
per_token_logps_target = torch.gather(
logits_target.log_softmax(-1), dim=2, index=labels_target.unsqueeze(2)
).squeeze(2)
per_token_logps_reject = torch.gather(
logits_reject.log_softmax(-1), dim=2, index=labels_reject.unsqueeze(2)
).squeeze(2)
return per_token_logps_target.sum(-1), per_token_logps_reject.sum(-1)
def make_reject_y(y_o, y_lens):
def repeat_P(y):
range_idx, _ = torch.randint(0, len(y), size=(2,)).sort()
pre = y[:range_idx[0]]
shf = y[range_idx[1]:]
range_text = y[range_idx[0]:range_idx[1]]
pre = y[: range_idx[0]]
shf = y[range_idx[1] :]
range_text = y[range_idx[0] : range_idx[1]]
new_y = torch.cat([pre, range_text, range_text, shf])
return new_y
def lost_P(y):
range_idx, _ = torch.randint(0, len(y), size=(2,)).sort()
pre = y[:range_idx[0]]
shf = y[range_idx[1]:]
range_text = y[range_idx[0]:range_idx[1]]
pre = y[: range_idx[0]]
shf = y[range_idx[1] :]
range_text = y[range_idx[0] : range_idx[1]]
new_y = torch.cat([pre, shf])
return new_y
bs = len(y_lens)
reject_y = []
reject_y_lens = []
for b in range(bs):
process_item_idx = torch.randint(0, 1, size=(1, ))[0]
process_item_idx = torch.randint(0, 1, size=(1,))[0]
if process_item_idx == 0:
new_y = repeat_P(y_o[b])
reject_y.append(new_y)
reject_y_lens.append(len(new_y))
elif process_item_idx==1:
elif process_item_idx == 1:
new_y = lost_P(y_o[b])
reject_y.append(new_y)
reject_y_lens.append(len(new_y))
@@ -256,7 +276,7 @@ def make_reject_y(y_o, y_lens):
pad_length = max_length - reject_y_lens[b]
reject_y[b] = torch.cat([reject_y[b], torch.zeros(pad_length, dtype=y_o.dtype, device=y_o.device)], dim=0)
reject_y = torch.stack(reject_y, dim = 0)
reject_y = torch.stack(reject_y, dim=0)
reject_y_lens = torch.tensor(reject_y_lens, device=y_lens.device)
return reject_y, reject_y_lens