Refactor: Format Code with Ruff and Update Deprecated G2PW Link (#2255)

* ruff check --fix

* ruff format --line-length 120 --target-version py39

* Update the deprecated G2PW model link

* Update the PyTorch version and Colab
Author: XXXXRT666
Date: 2025-04-07 09:42:47 +01:00
Committed by: GitHub
Parent: 9da7e17efe
Commit: 53cac93589
132 changed files with 8185 additions and 6648 deletions


@@ -2,27 +2,24 @@
# reference: https://github.com/lifeiteng/vall-e
import math
from typing import List, Optional
import torch
from tqdm import tqdm
from AR.models.utils import make_pad_mask, make_pad_mask_left
from AR.models.utils import (
topk_sampling,
sample,
logits_to_probs,
multinomial_sample_one_no_sync,
dpo_loss,
make_reject_y,
get_batch_logps
)
from AR.modules.embedding import SinePositionalEmbedding
from AR.modules.embedding import TokenEmbedding
from AR.modules.transformer import LayerNorm
from AR.modules.transformer import TransformerEncoder
from AR.modules.transformer import TransformerEncoderLayer
import torch
from torch import nn
from torch.nn import functional as F
from torchmetrics.classification import MulticlassAccuracy
from tqdm import tqdm
from AR.models.utils import (
dpo_loss,
get_batch_logps,
make_pad_mask,
make_pad_mask_left,
make_reject_y,
sample,
topk_sampling,
)
from AR.modules.embedding import SinePositionalEmbedding, TokenEmbedding
from AR.modules.transformer import LayerNorm, TransformerEncoder, TransformerEncoderLayer
default_config = {
"embedding_dim": 512,
@@ -36,10 +33,17 @@ default_config = {
"EOS": 1024,
}
# @torch.jit.script ## if enabled, the first inference becomes very slow and inference speed is unstable
# Efficient implementation equivalent to the following:
def scaled_dot_product_attention(query:torch.Tensor, key:torch.Tensor, value:torch.Tensor, attn_mask:Optional[torch.Tensor]=None, scale:Optional[torch.Tensor]=None) -> torch.Tensor:
B, H, L, S =query.size(0), query.size(1), query.size(-2), key.size(-2)
def scaled_dot_product_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
B, H, L, S = query.size(0), query.size(1), query.size(-2), key.size(-2)
if scale is None:
scale_factor = torch.tensor(1 / math.sqrt(query.size(-1)))
else:
@@ -59,12 +63,13 @@ def scaled_dot_product_attention(query:torch.Tensor, key:torch.Tensor, value:tor
if attn_mask.dtype == torch.bool:
attn_weight.masked_fill_(attn_mask, 0)
else:
attn_mask[attn_mask!=float("-inf")] =0
attn_mask[attn_mask==float("-inf")] =1
attn_mask[attn_mask != float("-inf")] = 0
attn_mask[attn_mask == float("-inf")] = 1
attn_weight.masked_fill_(attn_mask, 0)
return attn_weight @ value
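A quick sanity check one might run against the routine above (an illustration added for readers, not part of this commit): with no mask supplied it should agree with torch's fused kernel within floating-point tolerance, since it mirrors the reference implementation from the F.scaled_dot_product_attention documentation.

    import torch
    from torch.nn import functional as F

    q = torch.randn(1, 8, 16, 64)  # (batch, heads, seq_len, head_dim)
    k = torch.randn(1, 8, 16, 64)
    v = torch.randn(1, 8, 16, 64)
    # scaled_dot_product_attention is the module-level function defined above
    assert torch.allclose(scaled_dot_product_attention(q, k, v), F.scaled_dot_product_attention(q, k, v), atol=1e-4)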
@torch.jit.script
class T2SMLP:
def __init__(self, w1, b1, w2, b2):
@@ -82,20 +87,20 @@ class T2SMLP:
@torch.jit.script
class T2SBlock:
def __init__(
self,
num_heads,
hidden_dim: int,
mlp: T2SMLP,
qkv_w,
qkv_b,
out_w,
out_b,
norm_w1,
norm_b1,
norm_eps1,
norm_w2,
norm_b2,
norm_eps2,
self,
num_heads,
hidden_dim: int,
mlp: T2SMLP,
qkv_w,
qkv_b,
out_w,
out_b,
norm_w1,
norm_b1,
norm_eps1,
norm_w2,
norm_b2,
norm_eps2,
):
self.num_heads = num_heads
self.mlp = mlp
@@ -114,24 +119,32 @@ class T2SBlock:
self.false = torch.tensor(False, dtype=torch.bool)
@torch.jit.ignore
def to_mask(self, x:torch.Tensor, padding_mask:Optional[torch.Tensor]):
def to_mask(
self,
x: torch.Tensor,
padding_mask: Optional[torch.Tensor],
):
if padding_mask is None:
return x
if padding_mask.dtype == torch.bool:
return x.masked_fill(padding_mask, 0)
else:
return x * padding_mask
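For orientation (an aside, not part of the commit): to_mask accepts either mask convention, a bool mask where True marks padded positions to zero out, or a float mask that is multiplied in so 1.0 keeps and 0.0 drops a position. The two are interchangeable for zeroing padding:

    import torch

    x = torch.ones(1, 3, 4)                                 # (batch, seq, hidden)
    bool_pad = torch.tensor([[[False], [False], [True]]])   # last step is padding
    float_pad = torch.tensor([[[1.0], [1.0], [0.0]]])
    assert torch.equal(x.masked_fill(bool_pad, 0), x * float_pad)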
def process_prompt(self, x:torch.Tensor, attn_mask : torch.Tensor, padding_mask:Optional[torch.Tensor]=None, torch_sdpa:bool=True):
def process_prompt(
self,
x: torch.Tensor,
attn_mask: torch.Tensor,
padding_mask: Optional[torch.Tensor] = None,
torch_sdpa: bool = True,
):
q, k, v = F.linear(self.to_mask(x, padding_mask), self.qkv_w, self.qkv_b).chunk(3, dim=-1)
batch_size = q.shape[0]
q_len = q.shape[1]
kv_len = k.shape[1]
q = self.to_mask(q, padding_mask)
k_cache = self.to_mask(k, padding_mask)
v_cache = self.to_mask(v, padding_mask)
@@ -149,9 +162,7 @@ class T2SBlock:
attn = F.linear(self.to_mask(attn, padding_mask), self.out_w, self.out_b)
x = x + attn
x = F.layer_norm(
x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1
)
x = F.layer_norm(x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1)
x = x + self.mlp.forward(x)
x = F.layer_norm(
x,
@@ -161,13 +172,20 @@ class T2SBlock:
self.norm_eps2,
)
return x, k_cache, v_cache
def decode_next_token(self, x:torch.Tensor, k_cache:torch.Tensor, v_cache:torch.Tensor, attn_mask:torch.Tensor=None, torch_sdpa:bool=True):
def decode_next_token(
self,
x: torch.Tensor,
k_cache: torch.Tensor,
v_cache: torch.Tensor,
attn_mask: torch.Tensor = None,
torch_sdpa: bool = True,
):
q, k, v = F.linear(x, self.qkv_w, self.qkv_b).chunk(3, dim=-1)
k_cache = torch.cat([k_cache, k], dim=1)
v_cache = torch.cat([v_cache, v], dim=1)
batch_size = q.shape[0]
q_len = q.shape[1]
kv_len = k_cache.shape[1]
@@ -176,7 +194,6 @@ class T2SBlock:
k = k_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
v = v_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
if torch_sdpa:
attn = F.scaled_dot_product_attention(q, k, v, (~attn_mask) if attn_mask is not None else None)
else:
@@ -187,7 +204,11 @@ class T2SBlock:
x = x + attn
x = F.layer_norm(
x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1
x,
[self.hidden_dim],
self.norm_w1,
self.norm_b1,
self.norm_eps1,
)
x = x + self.mlp.forward(x)
x = F.layer_norm(
@@ -202,17 +223,19 @@ class T2SBlock:
@torch.jit.script
class T2STransformer:
def __init__(self, num_blocks : int, blocks: List[T2SBlock]):
self.num_blocks : int = num_blocks
def __init__(self, num_blocks: int, blocks: List[T2SBlock]):
self.num_blocks: int = num_blocks
self.blocks = blocks
def process_prompt(
self, x:torch.Tensor, attn_mask : torch.Tensor,
padding_mask : Optional[torch.Tensor]=None,
torch_sdpa:bool=True
):
k_cache : List[torch.Tensor] = []
v_cache : List[torch.Tensor] = []
self,
x: torch.Tensor,
attn_mask: torch.Tensor,
padding_mask: Optional[torch.Tensor] = None,
torch_sdpa: bool = True,
):
k_cache: List[torch.Tensor] = []
v_cache: List[torch.Tensor] = []
for i in range(self.num_blocks):
x, k_cache_, v_cache_ = self.blocks[i].process_prompt(x, attn_mask, padding_mask, torch_sdpa)
k_cache.append(k_cache_)
@@ -220,14 +243,17 @@ class T2STransformer:
return x, k_cache, v_cache
def decode_next_token(
self, x:torch.Tensor,
k_cache: List[torch.Tensor],
v_cache: List[torch.Tensor],
attn_mask : torch.Tensor=None,
torch_sdpa:bool=True
self,
x: torch.Tensor,
k_cache: List[torch.Tensor],
v_cache: List[torch.Tensor],
attn_mask: torch.Tensor = None,
torch_sdpa: bool = True,
):
for i in range(self.num_blocks):
x, k_cache[i], v_cache[i] = self.blocks[i].decode_next_token(x, k_cache[i], v_cache[i], attn_mask, torch_sdpa)
x, k_cache[i], v_cache[i] = self.blocks[i].decode_next_token(
x, k_cache[i], v_cache[i], attn_mask, torch_sdpa
)
return x, k_cache, v_cache
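A rough sketch of how the two entry points above are used downstream, mirroring the decoding loops later in this file (names other than the two methods are placeholders, not code from this commit): the prompt is run once through process_prompt to build per-block k/v caches, and every later token reuses them via decode_next_token.

    x, k_cache, v_cache = t2s_transformer.process_prompt(xy_pos, attn_mask, None)
    for step in range(max_new_tokens):           # max_new_tokens: placeholder budget
        logits = predict_layer(x[:, -1])         # predict_layer: placeholder projection head
        next_xy_pos = embed_next(logits)         # embed_next: placeholder embedding + position update
        x, k_cache, v_cache = t2s_transformer.decode_next_token(next_xy_pos, k_cache, v_cache, attn_mask)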
@@ -249,16 +275,26 @@ class Text2SemanticDecoder(nn.Module):
# assert self.EOS == 1024
self.bert_proj = nn.Linear(1024, self.embedding_dim)
self.ar_text_embedding = TokenEmbedding(
self.embedding_dim, self.phoneme_vocab_size, self.p_dropout
self.embedding_dim,
self.phoneme_vocab_size,
self.p_dropout,
)
self.ar_text_position = SinePositionalEmbedding(
self.embedding_dim, dropout=0.1, scale=False, alpha=True
self.embedding_dim,
dropout=0.1,
scale=False,
alpha=True,
)
self.ar_audio_embedding = TokenEmbedding(
self.embedding_dim, self.vocab_size, self.p_dropout
self.embedding_dim,
self.vocab_size,
self.p_dropout,
)
self.ar_audio_position = SinePositionalEmbedding(
self.embedding_dim, dropout=0.1, scale=False, alpha=True
self.embedding_dim,
dropout=0.1,
scale=False,
alpha=True,
)
self.h = TransformerEncoder(
@@ -293,7 +329,7 @@ class Text2SemanticDecoder(nn.Module):
layer.linear1.weight,
layer.linear1.bias,
layer.linear2.weight,
layer.linear2.bias
layer.linear2.bias,
)
block = T2SBlock(
@@ -309,11 +345,11 @@ class Text2SemanticDecoder(nn.Module):
layer.norm1.eps,
layer.norm2.weight,
layer.norm2.bias,
layer.norm2.eps
layer.norm2.eps,
)
blocks.append(block)
self.t2s_transformer = T2STransformer(self.num_layers, blocks)
def make_input_data(self, x, x_lens, y, y_lens, bert_feature):
@@ -387,7 +423,9 @@ class Text2SemanticDecoder(nn.Module):
logits = self.ar_predict_layer(xy_dec[:, x_len:])
###### DPO #############
reject_xy_pos, reject_xy_attn_mask, reject_targets = self.make_input_data(x, x_lens, reject_y, reject_y_lens, bert_feature)
reject_xy_pos, reject_xy_attn_mask, reject_targets = self.make_input_data(
x, x_lens, reject_y, reject_y_lens, bert_feature
)
reject_xy_dec, _ = self.h(
(reject_xy_pos, None),
@@ -404,7 +442,7 @@ class Text2SemanticDecoder(nn.Module):
A_logits, R_logits = get_batch_logps(logits, reject_logits, targets, reject_targets)
loss_2, _, _ = dpo_loss(A_logits, R_logits, 0, 0, 0.2, reference_free=True)
loss = loss_1 + loss_2
return loss, acc
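For context (a paraphrase of the standard formulation, assuming dpo_loss implements reference-free DPO as its arguments suggest): with the reference log-probabilities fixed at zero and beta = 0.2, the preference term combined into the loss above reduces to

    # loss_2 = -logsigmoid(0.2 * (logp_accepted - logp_rejected)).mean()
    # where logp_accepted / logp_rejected are the per-sequence log-probabilities
    # returned by get_batch_logps for the original and rejected targets.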
@@ -473,14 +511,14 @@ class Text2SemanticDecoder(nn.Module):
# need to check how this function differs from forward, and what to pass as prompts when there is no semantic input
def infer(
self,
x,
x_lens,
prompts,
bert_feature,
top_k: int = -100,
early_stop_num: int = -1,
temperature: float = 1.0,
self,
x,
x_lens,
prompts,
bert_feature,
top_k: int = -100,
early_stop_num: int = -1,
temperature: float = 1.0,
):
x = self.ar_text_embedding(x)
x = x + self.bert_proj(bert_feature.transpose(1, 2))
@@ -508,18 +546,14 @@ class Text2SemanticDecoder(nn.Module):
(x_len, 0),
value=False,
)
xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0).to(
y.device
)
xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0).to(y.device)
xy_dec, _ = self.h(
(xy_pos, None),
mask=xy_attn_mask,
)
logits = self.ar_predict_layer(xy_dec[:, -1])
samples = topk_sampling(
logits, top_k=top_k, top_p=1.0, temperature=temperature
)
samples = topk_sampling(logits, top_k=top_k, top_p=1.0, temperature=temperature)
if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
print("use early stop num:", early_stop_num)
@@ -542,18 +576,16 @@ class Text2SemanticDecoder(nn.Module):
return y
def pad_y_eos(self, y, y_mask_int, eos_id):
targets = F.pad(y, (0, 1), value=0) + eos_id * F.pad(
y_mask_int, (0, 1), value=1
)
targets = F.pad(y, (0, 1), value=0) + eos_id * F.pad(y_mask_int, (0, 1), value=1)
# shift by one position
return targets[:, :-1], targets[:, 1:]
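A tiny worked example of the shift above (illustrative values only): with y = [[1, 2, 3]], y_mask_int = [[0, 0, 0]] and eos_id = 1024,

    # F.pad(y, (0, 1), value=0)                 -> [[1, 2, 3, 0]]
    # 1024 * F.pad(y_mask_int, (0, 1), value=1) -> [[0, 0, 0, 1024]]
    # targets                                   -> [[1, 2, 3, 1024]]
    # returned pair: ([[1, 2, 3]], [[2, 3, 1024]])   # decoder inputs, shifted targets with EOS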
def infer_panel_batch_infer(
self,
x:List[torch.LongTensor], ##### all text tokens
x_lens:torch.LongTensor,
prompts:torch.LongTensor, #### reference-audio tokens
bert_feature:List[torch.LongTensor],
x: List[torch.LongTensor], ##### all text tokens
x_lens: torch.LongTensor,
prompts: torch.LongTensor, #### reference-audio tokens
bert_feature: List[torch.LongTensor],
top_k: int = -100,
top_p: int = 100,
early_stop_num: int = -1,
@@ -563,10 +595,19 @@ class Text2SemanticDecoder(nn.Module):
):
if prompts is None:
print("Warning: Prompt free is not supported batch_infer! switch to naive_infer")
return self.infer_panel_naive_batched(x, x_lens, prompts, bert_feature, top_k=top_k, top_p=top_p, early_stop_num=early_stop_num, temperature=temperature, **kwargs)
return self.infer_panel_naive_batched(
x,
x_lens,
prompts,
bert_feature,
top_k=top_k,
top_p=top_p,
early_stop_num=early_stop_num,
temperature=temperature,
**kwargs,
)
max_len = kwargs.get("max_len",x_lens.max())
max_len = kwargs.get("max_len", x_lens.max())
x_list = []
for x_item, bert_item in zip(x, bert_feature):
# max_len = max(max_len, x_item.shape[0], bert_item.shape[1])
@@ -574,14 +615,15 @@ class Text2SemanticDecoder(nn.Module):
x_item = x_item + self.bert_proj(bert_item.transpose(0, 1).unsqueeze(0))
x_item = self.ar_text_position(x_item).squeeze(0)
# x_item = F.pad(x_item,(0,0,0,max_len-x_item.shape[0]),value=0) if x_item.shape[0]<max_len else x_item ### padding right
x_item = F.pad(x_item,(0,0,max_len-x_item.shape[0],0),value=0) if x_item.shape[0]<max_len else x_item ### padding left
x_item = (
F.pad(x_item, (0, 0, max_len - x_item.shape[0], 0), value=0) if x_item.shape[0] < max_len else x_item
) ### padding left
x_list.append(x_item)
x:torch.Tensor = torch.stack(x_list, dim=0)
x: torch.Tensor = torch.stack(x_list, dim=0)
# AR Decoder
y = prompts
x_len = x.shape[1]
stop = False
@@ -594,34 +636,32 @@ class Text2SemanticDecoder(nn.Module):
y_emb = self.ar_audio_embedding(y)
y_len = y_emb.shape[1]
prefix_len = y.shape[1]
y_lens = torch.LongTensor([y_emb.shape[1]]*y_emb.shape[0]).to(x.device)
y_lens = torch.LongTensor([y_emb.shape[1]] * y_emb.shape[0]).to(x.device)
y_pos = self.ar_audio_position(y_emb)
xy_pos = torch.concat([x, y_pos], dim=1)
##### create mask #####
bsz = x.shape[0]
src_len = x_len + y_len
y_padding_mask = make_pad_mask_left(y_lens, y_len)
x_padding_mask = make_pad_mask_left(x_lens, max_len)
# (bsz, x_len + y_len)
padding_mask = torch.concat([x_padding_mask, y_padding_mask], dim=1)
x_mask = F.pad(
torch.zeros(x_len, x_len, dtype=torch.bool, device=x.device),
(0, y_len),
value=True,
)
y_mask = F.pad(  ### yy's upper-right 1s extended leftward with 0s for xy, shape (y, x+y)
torch.triu(torch.ones(y_len, y_len, dtype=torch.bool, device=x.device), diagonal=1),
(x_len, 0),
value=False,
)
causal_mask = torch.concat([x_mask, y_mask], dim=0).view(1 , src_len, src_len).repeat(bsz, 1, 1).to(x.device)
causal_mask = torch.concat([x_mask, y_mask], dim=0).view(1, src_len, src_len).repeat(bsz, 1, 1).to(x.device)
# padding_mask = padding_mask.unsqueeze(1) * padding_mask.unsqueeze(2) ### [b, x+y, x+y]
### the commented-out line above is wrong: it would let padded tokens be "seen"
@@ -639,10 +679,9 @@ class Text2SemanticDecoder(nn.Module):
padding_mask = padding_mask.view(bsz, 1, src_len).repeat(1, src_len, 1)
attn_mask:torch.Tensor = causal_mask.logical_or(padding_mask)
attn_mask: torch.Tensor = causal_mask.logical_or(padding_mask)
attn_mask = attn_mask.unsqueeze(1).expand(-1, self.num_head, -1, -1).bool()
# the correct attn_mask should look like this:
# | pad_len | x_len | y_len |
# [[PAD, PAD, PAD, 1, 2, 3, EOS, EOS, EOS],
@@ -655,74 +694,69 @@ class Text2SemanticDecoder(nn.Module):
# [PAD, PAD, PAD, 1, 2, 3, 4, 5, EOS],
# [PAD, PAD, PAD, 1, 2, 3, 4, 5, 6]]
###### decode #####
y_list = [None]*y.shape[0]
y_list = [None] * y.shape[0]
batch_idx_map = list(range(y.shape[0]))
idx_list = [None]*y.shape[0]
idx_list = [None] * y.shape[0]
for idx in tqdm(range(1500)):
if idx == 0:
xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, attn_mask, None)
else:
xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(xy_pos, k_cache, v_cache, attn_mask)
logits = self.ar_predict_layer(
xy_dec[:, -1]
)
logits = self.ar_predict_layer(xy_dec[:, -1])
if idx == 0:
attn_mask = F.pad(attn_mask[:,:,-1].unsqueeze(-2),(0,1),value=False)
attn_mask = F.pad(attn_mask[:, :, -1].unsqueeze(-2), (0, 1), value=False)
logits = logits[:, :-1]
else:
attn_mask = F.pad(attn_mask,(0,1),value=False)
attn_mask = F.pad(attn_mask, (0, 1), value=False)
samples = sample(
logits, y, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature
)[0]
y = torch.concat([y, samples], dim=1)
####### drop sequences that have finished generating from the batch, to further cut computation
tokens = torch.argmax(logits, dim=-1)
reserved_idx_of_batch_for_y = None
if (self.EOS in samples[:, 0]) or \
(self.EOS in tokens): ### stop once EOS has been generated
l1 = samples[:, 0]==self.EOS
l2 = tokens==self.EOS
l = l1.logical_or(l2)
removed_idx_of_batch_for_y = torch.where(l==True)[0].tolist()
reserved_idx_of_batch_for_y = torch.where(l==False)[0]
# batch_indexs = torch.tensor(batch_idx_map, device=y.device)[removed_idx_of_batch_for_y]
for i in removed_idx_of_batch_for_y:
batch_index = batch_idx_map[i]
idx_list[batch_index] = idx
y_list[batch_index] = y[i, :-1]
batch_idx_map = [batch_idx_map[i] for i in reserved_idx_of_batch_for_y.tolist()]
# keep only the sequences in the batch that have not finished generating
if (self.EOS in samples[:, 0]) or (self.EOS in tokens): ### stop once EOS has been generated
l1 = samples[:, 0] == self.EOS
l2 = tokens == self.EOS
l = l1.logical_or(l2)
removed_idx_of_batch_for_y = torch.where(l == True)[0].tolist()
reserved_idx_of_batch_for_y = torch.where(l == False)[0]
# batch_indexs = torch.tensor(batch_idx_map, device=y.device)[removed_idx_of_batch_for_y]
for i in removed_idx_of_batch_for_y:
batch_index = batch_idx_map[i]
idx_list[batch_index] = idx
y_list[batch_index] = y[i, :-1]
batch_idx_map = [batch_idx_map[i] for i in reserved_idx_of_batch_for_y.tolist()]
# keep only the sequences in the batch that have not finished generating
if reserved_idx_of_batch_for_y is not None:
# index = torch.LongTensor(batch_idx_map).to(y.device)
y = torch.index_select(y, dim=0, index=reserved_idx_of_batch_for_y)
attn_mask = torch.index_select(attn_mask, dim=0, index=reserved_idx_of_batch_for_y)
if k_cache is not None :
if k_cache is not None:
for i in range(len(k_cache)):
k_cache[i] = torch.index_select(k_cache[i], dim=0, index=reserved_idx_of_batch_for_y)
v_cache[i] = torch.index_select(v_cache[i], dim=0, index=reserved_idx_of_batch_for_y)
if (early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num) or idx==1499:
if (early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num) or idx == 1499:
print("use early stop num:", early_stop_num)
stop = True
for i, batch_index in enumerate(batch_idx_map):
batch_index = batch_idx_map[i]
idx_list[batch_index] = idx
y_list[batch_index] = y[i, :-1]
if not (None in idx_list):
if None not in idx_list:
stop = True
if stop:
if y.shape[1]==0:
if y.shape[1] == 0:
y = torch.concat([y, torch.zeros_like(samples)], dim=1)
print("bad zero prediction")
print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]")
@@ -730,60 +764,65 @@ class Text2SemanticDecoder(nn.Module):
####################### update next step ###################################
y_emb = self.ar_audio_embedding(y[:, -1:])
xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[:, y_len + idx].to( dtype= y_emb.dtype,device=y_emb.device)
xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[
:, y_len + idx
].to(dtype=y_emb.dtype, device=y_emb.device)
if (None in idx_list):
if None in idx_list:
for i in range(x.shape[0]):
if idx_list[i] is None:
idx_list[i] = 1500-1 ### if EOS was never generated, fall back to the maximum length
idx_list[i] = 1500 - 1 ### if EOS was never generated, fall back to the maximum length
if ref_free:
return y_list, [0]*x.shape[0]
return y_list, [0] * x.shape[0]
# print(idx_list)
return y_list, idx_list
def infer_panel_naive_batched(self,
x:List[torch.LongTensor], ##### all text tokens
x_lens:torch.LongTensor,
prompts:torch.LongTensor, #### reference-audio tokens
bert_feature:List[torch.LongTensor],
def infer_panel_naive_batched(
self,
x: List[torch.LongTensor], ##### all text tokens
x_lens: torch.LongTensor,
prompts: torch.LongTensor, #### reference-audio tokens
bert_feature: List[torch.LongTensor],
top_k: int = -100,
top_p: int = 100,
early_stop_num: int = -1,
temperature: float = 1.0,
repetition_penalty: float = 1.35,
**kwargs
):
**kwargs,
):
y_list = []
idx_list = []
for i in range(len(x)):
y, idx = self.infer_panel_naive(x[i].unsqueeze(0),
x_lens[i],
prompts[i].unsqueeze(0) if prompts is not None else None,
bert_feature[i].unsqueeze(0),
top_k,
top_p,
early_stop_num,
temperature,
repetition_penalty,
**kwargs)
y, idx = self.infer_panel_naive(
x[i].unsqueeze(0),
x_lens[i],
prompts[i].unsqueeze(0) if prompts is not None else None,
bert_feature[i].unsqueeze(0),
top_k,
top_p,
early_stop_num,
temperature,
repetition_penalty,
**kwargs,
)
y_list.append(y[0])
idx_list.append(idx)
return y_list, idx_list
def infer_panel_naive(
self,
x:torch.LongTensor, ##### all text tokens
x_lens:torch.LongTensor,
prompts:torch.LongTensor, #### reference-audio tokens
bert_feature:torch.LongTensor,
x: torch.LongTensor, ##### all text tokens
x_lens: torch.LongTensor,
prompts: torch.LongTensor, #### reference-audio tokens
bert_feature: torch.LongTensor,
top_k: int = -100,
top_p: int = 100,
early_stop_num: int = -1,
temperature: float = 1.0,
repetition_penalty: float = 1.35,
**kwargs
**kwargs,
):
x = self.ar_text_embedding(x)
x = x + self.bert_proj(bert_feature.transpose(1, 2))
@@ -828,11 +867,13 @@ class Text2SemanticDecoder(nn.Module):
(x_len, 0),
value=False,
)
xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0)\
.unsqueeze(0)\
.expand(bsz*self.num_head, -1, -1)\
.view(bsz, self.num_head, src_len, src_len)\
.to(device=x.device, dtype=torch.bool)
xy_attn_mask = (
torch.concat([x_attn_mask_pad, y_attn_mask], dim=0)
.unsqueeze(0)
.expand(bsz * self.num_head, -1, -1)
.view(bsz, self.num_head, src_len, src_len)
.to(device=x.device, dtype=torch.bool)
)
for idx in tqdm(range(1500)):
if xy_attn_mask is not None:
@@ -840,13 +881,11 @@ class Text2SemanticDecoder(nn.Module):
else:
xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(xy_pos, k_cache, v_cache)
logits = self.ar_predict_layer(
xy_dec[:, -1]
)
logits = self.ar_predict_layer(xy_dec[:, -1])
if idx == 0:
xy_attn_mask = None
if(idx<11):### require at least 10 predicted tokens (~0.4 s) before allowing a stop
if idx < 11:  ### require at least 10 predicted tokens (~0.4 s) before allowing a stop
logits = logits[:, :-1]
samples = sample(
@@ -870,24 +909,27 @@ class Text2SemanticDecoder(nn.Module):
####################### update next step ###################################
y_emb = self.ar_audio_embedding(y[:, -1:])
xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[:, y_len + idx].to(dtype=y_emb.dtype,device=y_emb.device)
xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[
:, y_len + idx
].to(dtype=y_emb.dtype, device=y_emb.device)
if ref_free:
return y[:, :-1], 0
return y[:, :-1], idx
def infer_panel(
self,
x:torch.LongTensor, ##### all text tokens
x_lens:torch.LongTensor,
prompts:torch.LongTensor, #### reference-audio tokens
bert_feature:torch.LongTensor,
x: torch.LongTensor, ##### all text tokens
x_lens: torch.LongTensor,
prompts: torch.LongTensor, #### reference-audio tokens
bert_feature: torch.LongTensor,
top_k: int = -100,
top_p: int = 100,
early_stop_num: int = -1,
temperature: float = 1.0,
repetition_penalty: float = 1.35,
**kwargs
**kwargs,
):
return self.infer_panel_naive(x, x_lens, prompts, bert_feature, top_k, top_p, early_stop_num, temperature, repetition_penalty, **kwargs)
return self.infer_panel_naive(
x, x_lens, prompts, bert_feature, top_k, top_p, early_stop_num, temperature, repetition_penalty, **kwargs
)
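For orientation, a usage sketch of the public entry point (shapes inferred from the projections and signatures above; the construction and values are placeholders, not from this commit):

    # decoder = Text2SemanticDecoder(config)           # placeholder construction
    # phones:      LongTensor (1, T_text)              # text token ids
    # phone_lens:  LongTensor (1,)                     # text lengths
    # prompt:      LongTensor (1, T_prompt) or None    # reference-audio semantic tokens
    # bert:        FloatTensor (1, 1024, T_text)       # BERT features, projected by bert_proj
    # y, idx = decoder.infer_panel(phones, phone_lens, prompt, bert, top_k=15, top_p=1.0, temperature=1.0)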