Refactor: Format Code with Ruff and Update Deprecated G2PW Link (#2255)
* ruff check --fix * ruff format --line-length 120 --target-version py39 * Change the link for G2PW Model * update pytorch version and colab
This commit is contained in:
@@ -2,27 +2,24 @@
|
||||
# reference: https://github.com/lifeiteng/vall-e
|
||||
import math
|
||||
from typing import List, Optional
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from AR.models.utils import make_pad_mask, make_pad_mask_left
|
||||
from AR.models.utils import (
|
||||
topk_sampling,
|
||||
sample,
|
||||
logits_to_probs,
|
||||
multinomial_sample_one_no_sync,
|
||||
dpo_loss,
|
||||
make_reject_y,
|
||||
get_batch_logps
|
||||
)
|
||||
from AR.modules.embedding import SinePositionalEmbedding
|
||||
from AR.modules.embedding import TokenEmbedding
|
||||
from AR.modules.transformer import LayerNorm
|
||||
from AR.modules.transformer import TransformerEncoder
|
||||
from AR.modules.transformer import TransformerEncoderLayer
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
from torchmetrics.classification import MulticlassAccuracy
|
||||
from tqdm import tqdm
|
||||
|
||||
from AR.models.utils import (
|
||||
dpo_loss,
|
||||
get_batch_logps,
|
||||
make_pad_mask,
|
||||
make_pad_mask_left,
|
||||
make_reject_y,
|
||||
sample,
|
||||
topk_sampling,
|
||||
)
|
||||
from AR.modules.embedding import SinePositionalEmbedding, TokenEmbedding
|
||||
from AR.modules.transformer import LayerNorm, TransformerEncoder, TransformerEncoderLayer
|
||||
|
||||
default_config = {
|
||||
"embedding_dim": 512,
|
||||
@@ -36,10 +33,17 @@ default_config = {
|
||||
"EOS": 1024,
|
||||
}
|
||||
|
||||
|
||||
# @torch.jit.script ## 使用的话首次推理会非常慢,而且推理速度不稳定
|
||||
# Efficient implementation equivalent to the following:
|
||||
def scaled_dot_product_attention(query:torch.Tensor, key:torch.Tensor, value:torch.Tensor, attn_mask:Optional[torch.Tensor]=None, scale:Optional[torch.Tensor]=None) -> torch.Tensor:
|
||||
B, H, L, S =query.size(0), query.size(1), query.size(-2), key.size(-2)
|
||||
def scaled_dot_product_attention(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_mask: Optional[torch.Tensor] = None,
|
||||
scale: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
B, H, L, S = query.size(0), query.size(1), query.size(-2), key.size(-2)
|
||||
if scale is None:
|
||||
scale_factor = torch.tensor(1 / math.sqrt(query.size(-1)))
|
||||
else:
|
||||
@@ -59,12 +63,13 @@ def scaled_dot_product_attention(query:torch.Tensor, key:torch.Tensor, value:tor
|
||||
if attn_mask.dtype == torch.bool:
|
||||
attn_weight.masked_fill_(attn_mask, 0)
|
||||
else:
|
||||
attn_mask[attn_mask!=float("-inf")] =0
|
||||
attn_mask[attn_mask==float("-inf")] =1
|
||||
attn_mask[attn_mask != float("-inf")] = 0
|
||||
attn_mask[attn_mask == float("-inf")] = 1
|
||||
attn_weight.masked_fill_(attn_mask, 0)
|
||||
|
||||
return attn_weight @ value
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
class T2SMLP:
|
||||
def __init__(self, w1, b1, w2, b2):
|
||||
@@ -82,20 +87,20 @@ class T2SMLP:
|
||||
@torch.jit.script
|
||||
class T2SBlock:
|
||||
def __init__(
|
||||
self,
|
||||
num_heads,
|
||||
hidden_dim: int,
|
||||
mlp: T2SMLP,
|
||||
qkv_w,
|
||||
qkv_b,
|
||||
out_w,
|
||||
out_b,
|
||||
norm_w1,
|
||||
norm_b1,
|
||||
norm_eps1,
|
||||
norm_w2,
|
||||
norm_b2,
|
||||
norm_eps2,
|
||||
self,
|
||||
num_heads,
|
||||
hidden_dim: int,
|
||||
mlp: T2SMLP,
|
||||
qkv_w,
|
||||
qkv_b,
|
||||
out_w,
|
||||
out_b,
|
||||
norm_w1,
|
||||
norm_b1,
|
||||
norm_eps1,
|
||||
norm_w2,
|
||||
norm_b2,
|
||||
norm_eps2,
|
||||
):
|
||||
self.num_heads = num_heads
|
||||
self.mlp = mlp
|
||||
@@ -114,24 +119,32 @@ class T2SBlock:
|
||||
self.false = torch.tensor(False, dtype=torch.bool)
|
||||
|
||||
@torch.jit.ignore
|
||||
def to_mask(self, x:torch.Tensor, padding_mask:Optional[torch.Tensor]):
|
||||
def to_mask(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
padding_mask: Optional[torch.Tensor],
|
||||
):
|
||||
if padding_mask is None:
|
||||
return x
|
||||
|
||||
|
||||
if padding_mask.dtype == torch.bool:
|
||||
return x.masked_fill(padding_mask, 0)
|
||||
else:
|
||||
return x * padding_mask
|
||||
|
||||
def process_prompt(self, x:torch.Tensor, attn_mask : torch.Tensor, padding_mask:Optional[torch.Tensor]=None, torch_sdpa:bool=True):
|
||||
|
||||
|
||||
def process_prompt(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
attn_mask: torch.Tensor,
|
||||
padding_mask: Optional[torch.Tensor] = None,
|
||||
torch_sdpa: bool = True,
|
||||
):
|
||||
q, k, v = F.linear(self.to_mask(x, padding_mask), self.qkv_w, self.qkv_b).chunk(3, dim=-1)
|
||||
|
||||
batch_size = q.shape[0]
|
||||
q_len = q.shape[1]
|
||||
kv_len = k.shape[1]
|
||||
|
||||
|
||||
q = self.to_mask(q, padding_mask)
|
||||
k_cache = self.to_mask(k, padding_mask)
|
||||
v_cache = self.to_mask(v, padding_mask)
|
||||
@@ -149,9 +162,7 @@ class T2SBlock:
|
||||
attn = F.linear(self.to_mask(attn, padding_mask), self.out_w, self.out_b)
|
||||
|
||||
x = x + attn
|
||||
x = F.layer_norm(
|
||||
x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1
|
||||
)
|
||||
x = F.layer_norm(x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1)
|
||||
x = x + self.mlp.forward(x)
|
||||
x = F.layer_norm(
|
||||
x,
|
||||
@@ -161,13 +172,20 @@ class T2SBlock:
|
||||
self.norm_eps2,
|
||||
)
|
||||
return x, k_cache, v_cache
|
||||
|
||||
def decode_next_token(self, x:torch.Tensor, k_cache:torch.Tensor, v_cache:torch.Tensor, attn_mask:torch.Tensor=None, torch_sdpa:bool=True):
|
||||
|
||||
def decode_next_token(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
k_cache: torch.Tensor,
|
||||
v_cache: torch.Tensor,
|
||||
attn_mask: torch.Tensor = None,
|
||||
torch_sdpa: bool = True,
|
||||
):
|
||||
q, k, v = F.linear(x, self.qkv_w, self.qkv_b).chunk(3, dim=-1)
|
||||
|
||||
k_cache = torch.cat([k_cache, k], dim=1)
|
||||
v_cache = torch.cat([v_cache, v], dim=1)
|
||||
|
||||
|
||||
batch_size = q.shape[0]
|
||||
q_len = q.shape[1]
|
||||
kv_len = k_cache.shape[1]
|
||||
@@ -176,7 +194,6 @@ class T2SBlock:
|
||||
k = k_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
|
||||
v = v_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
|
||||
|
||||
|
||||
if torch_sdpa:
|
||||
attn = F.scaled_dot_product_attention(q, k, v, (~attn_mask) if attn_mask is not None else None)
|
||||
else:
|
||||
@@ -187,7 +204,11 @@ class T2SBlock:
|
||||
|
||||
x = x + attn
|
||||
x = F.layer_norm(
|
||||
x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1
|
||||
x,
|
||||
[self.hidden_dim],
|
||||
self.norm_w1,
|
||||
self.norm_b1,
|
||||
self.norm_eps1,
|
||||
)
|
||||
x = x + self.mlp.forward(x)
|
||||
x = F.layer_norm(
|
||||
@@ -202,17 +223,19 @@ class T2SBlock:
|
||||
|
||||
@torch.jit.script
|
||||
class T2STransformer:
|
||||
def __init__(self, num_blocks : int, blocks: List[T2SBlock]):
|
||||
self.num_blocks : int = num_blocks
|
||||
def __init__(self, num_blocks: int, blocks: List[T2SBlock]):
|
||||
self.num_blocks: int = num_blocks
|
||||
self.blocks = blocks
|
||||
|
||||
def process_prompt(
|
||||
self, x:torch.Tensor, attn_mask : torch.Tensor,
|
||||
padding_mask : Optional[torch.Tensor]=None,
|
||||
torch_sdpa:bool=True
|
||||
):
|
||||
k_cache : List[torch.Tensor] = []
|
||||
v_cache : List[torch.Tensor] = []
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
attn_mask: torch.Tensor,
|
||||
padding_mask: Optional[torch.Tensor] = None,
|
||||
torch_sdpa: bool = True,
|
||||
):
|
||||
k_cache: List[torch.Tensor] = []
|
||||
v_cache: List[torch.Tensor] = []
|
||||
for i in range(self.num_blocks):
|
||||
x, k_cache_, v_cache_ = self.blocks[i].process_prompt(x, attn_mask, padding_mask, torch_sdpa)
|
||||
k_cache.append(k_cache_)
|
||||
@@ -220,14 +243,17 @@ class T2STransformer:
|
||||
return x, k_cache, v_cache
|
||||
|
||||
def decode_next_token(
|
||||
self, x:torch.Tensor,
|
||||
k_cache: List[torch.Tensor],
|
||||
v_cache: List[torch.Tensor],
|
||||
attn_mask : torch.Tensor=None,
|
||||
torch_sdpa:bool=True
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
k_cache: List[torch.Tensor],
|
||||
v_cache: List[torch.Tensor],
|
||||
attn_mask: torch.Tensor = None,
|
||||
torch_sdpa: bool = True,
|
||||
):
|
||||
for i in range(self.num_blocks):
|
||||
x, k_cache[i], v_cache[i] = self.blocks[i].decode_next_token(x, k_cache[i], v_cache[i], attn_mask, torch_sdpa)
|
||||
x, k_cache[i], v_cache[i] = self.blocks[i].decode_next_token(
|
||||
x, k_cache[i], v_cache[i], attn_mask, torch_sdpa
|
||||
)
|
||||
return x, k_cache, v_cache
|
||||
|
||||
|
||||
@@ -249,16 +275,26 @@ class Text2SemanticDecoder(nn.Module):
|
||||
# assert self.EOS == 1024
|
||||
self.bert_proj = nn.Linear(1024, self.embedding_dim)
|
||||
self.ar_text_embedding = TokenEmbedding(
|
||||
self.embedding_dim, self.phoneme_vocab_size, self.p_dropout
|
||||
self.embedding_dim,
|
||||
self.phoneme_vocab_size,
|
||||
self.p_dropout,
|
||||
)
|
||||
self.ar_text_position = SinePositionalEmbedding(
|
||||
self.embedding_dim, dropout=0.1, scale=False, alpha=True
|
||||
self.embedding_dim,
|
||||
dropout=0.1,
|
||||
scale=False,
|
||||
alpha=True,
|
||||
)
|
||||
self.ar_audio_embedding = TokenEmbedding(
|
||||
self.embedding_dim, self.vocab_size, self.p_dropout
|
||||
self.embedding_dim,
|
||||
self.vocab_size,
|
||||
self.p_dropout,
|
||||
)
|
||||
self.ar_audio_position = SinePositionalEmbedding(
|
||||
self.embedding_dim, dropout=0.1, scale=False, alpha=True
|
||||
self.embedding_dim,
|
||||
dropout=0.1,
|
||||
scale=False,
|
||||
alpha=True,
|
||||
)
|
||||
|
||||
self.h = TransformerEncoder(
|
||||
@@ -293,7 +329,7 @@ class Text2SemanticDecoder(nn.Module):
|
||||
layer.linear1.weight,
|
||||
layer.linear1.bias,
|
||||
layer.linear2.weight,
|
||||
layer.linear2.bias
|
||||
layer.linear2.bias,
|
||||
)
|
||||
|
||||
block = T2SBlock(
|
||||
@@ -309,11 +345,11 @@ class Text2SemanticDecoder(nn.Module):
|
||||
layer.norm1.eps,
|
||||
layer.norm2.weight,
|
||||
layer.norm2.bias,
|
||||
layer.norm2.eps
|
||||
layer.norm2.eps,
|
||||
)
|
||||
|
||||
blocks.append(block)
|
||||
|
||||
|
||||
self.t2s_transformer = T2STransformer(self.num_layers, blocks)
|
||||
|
||||
def make_input_data(self, x, x_lens, y, y_lens, bert_feature):
|
||||
@@ -387,7 +423,9 @@ class Text2SemanticDecoder(nn.Module):
|
||||
logits = self.ar_predict_layer(xy_dec[:, x_len:])
|
||||
|
||||
###### DPO #############
|
||||
reject_xy_pos, reject_xy_attn_mask, reject_targets = self.make_input_data(x, x_lens, reject_y, reject_y_lens, bert_feature)
|
||||
reject_xy_pos, reject_xy_attn_mask, reject_targets = self.make_input_data(
|
||||
x, x_lens, reject_y, reject_y_lens, bert_feature
|
||||
)
|
||||
|
||||
reject_xy_dec, _ = self.h(
|
||||
(reject_xy_pos, None),
|
||||
@@ -404,7 +442,7 @@ class Text2SemanticDecoder(nn.Module):
|
||||
|
||||
A_logits, R_logits = get_batch_logps(logits, reject_logits, targets, reject_targets)
|
||||
loss_2, _, _ = dpo_loss(A_logits, R_logits, 0, 0, 0.2, reference_free=True)
|
||||
|
||||
|
||||
loss = loss_1 + loss_2
|
||||
|
||||
return loss, acc
|
||||
@@ -473,14 +511,14 @@ class Text2SemanticDecoder(nn.Module):
|
||||
|
||||
# 需要看下这个函数和 forward 的区别以及没有 semantic 的时候 prompts 输入什么
|
||||
def infer(
|
||||
self,
|
||||
x,
|
||||
x_lens,
|
||||
prompts,
|
||||
bert_feature,
|
||||
top_k: int = -100,
|
||||
early_stop_num: int = -1,
|
||||
temperature: float = 1.0,
|
||||
self,
|
||||
x,
|
||||
x_lens,
|
||||
prompts,
|
||||
bert_feature,
|
||||
top_k: int = -100,
|
||||
early_stop_num: int = -1,
|
||||
temperature: float = 1.0,
|
||||
):
|
||||
x = self.ar_text_embedding(x)
|
||||
x = x + self.bert_proj(bert_feature.transpose(1, 2))
|
||||
@@ -508,18 +546,14 @@ class Text2SemanticDecoder(nn.Module):
|
||||
(x_len, 0),
|
||||
value=False,
|
||||
)
|
||||
xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0).to(
|
||||
y.device
|
||||
)
|
||||
xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0).to(y.device)
|
||||
|
||||
xy_dec, _ = self.h(
|
||||
(xy_pos, None),
|
||||
mask=xy_attn_mask,
|
||||
)
|
||||
logits = self.ar_predict_layer(xy_dec[:, -1])
|
||||
samples = topk_sampling(
|
||||
logits, top_k=top_k, top_p=1.0, temperature=temperature
|
||||
)
|
||||
samples = topk_sampling(logits, top_k=top_k, top_p=1.0, temperature=temperature)
|
||||
|
||||
if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
|
||||
print("use early stop num:", early_stop_num)
|
||||
@@ -542,18 +576,16 @@ class Text2SemanticDecoder(nn.Module):
|
||||
return y
|
||||
|
||||
def pad_y_eos(self, y, y_mask_int, eos_id):
|
||||
targets = F.pad(y, (0, 1), value=0) + eos_id * F.pad(
|
||||
y_mask_int, (0, 1), value=1
|
||||
)
|
||||
targets = F.pad(y, (0, 1), value=0) + eos_id * F.pad(y_mask_int, (0, 1), value=1)
|
||||
# 错位
|
||||
return targets[:, :-1], targets[:, 1:]
|
||||
|
||||
def infer_panel_batch_infer(
|
||||
self,
|
||||
x:List[torch.LongTensor], #####全部文本token
|
||||
x_lens:torch.LongTensor,
|
||||
prompts:torch.LongTensor, ####参考音频token
|
||||
bert_feature:List[torch.LongTensor],
|
||||
x: List[torch.LongTensor], #####全部文本token
|
||||
x_lens: torch.LongTensor,
|
||||
prompts: torch.LongTensor, ####参考音频token
|
||||
bert_feature: List[torch.LongTensor],
|
||||
top_k: int = -100,
|
||||
top_p: int = 100,
|
||||
early_stop_num: int = -1,
|
||||
@@ -563,10 +595,19 @@ class Text2SemanticDecoder(nn.Module):
|
||||
):
|
||||
if prompts is None:
|
||||
print("Warning: Prompt free is not supported batch_infer! switch to naive_infer")
|
||||
return self.infer_panel_naive_batched(x, x_lens, prompts, bert_feature, top_k=top_k, top_p=top_p, early_stop_num=early_stop_num, temperature=temperature, **kwargs)
|
||||
return self.infer_panel_naive_batched(
|
||||
x,
|
||||
x_lens,
|
||||
prompts,
|
||||
bert_feature,
|
||||
top_k=top_k,
|
||||
top_p=top_p,
|
||||
early_stop_num=early_stop_num,
|
||||
temperature=temperature,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
max_len = kwargs.get("max_len",x_lens.max())
|
||||
max_len = kwargs.get("max_len", x_lens.max())
|
||||
x_list = []
|
||||
for x_item, bert_item in zip(x, bert_feature):
|
||||
# max_len = max(max_len, x_item.shape[0], bert_item.shape[1])
|
||||
@@ -574,14 +615,15 @@ class Text2SemanticDecoder(nn.Module):
|
||||
x_item = x_item + self.bert_proj(bert_item.transpose(0, 1).unsqueeze(0))
|
||||
x_item = self.ar_text_position(x_item).squeeze(0)
|
||||
# x_item = F.pad(x_item,(0,0,0,max_len-x_item.shape[0]),value=0) if x_item.shape[0]<max_len else x_item ### padding right
|
||||
x_item = F.pad(x_item,(0,0,max_len-x_item.shape[0],0),value=0) if x_item.shape[0]<max_len else x_item ### padding left
|
||||
x_item = (
|
||||
F.pad(x_item, (0, 0, max_len - x_item.shape[0], 0), value=0) if x_item.shape[0] < max_len else x_item
|
||||
) ### padding left
|
||||
x_list.append(x_item)
|
||||
x:torch.Tensor = torch.stack(x_list, dim=0)
|
||||
|
||||
x: torch.Tensor = torch.stack(x_list, dim=0)
|
||||
|
||||
# AR Decoder
|
||||
y = prompts
|
||||
|
||||
|
||||
x_len = x.shape[1]
|
||||
stop = False
|
||||
|
||||
@@ -594,34 +636,32 @@ class Text2SemanticDecoder(nn.Module):
|
||||
y_emb = self.ar_audio_embedding(y)
|
||||
y_len = y_emb.shape[1]
|
||||
prefix_len = y.shape[1]
|
||||
y_lens = torch.LongTensor([y_emb.shape[1]]*y_emb.shape[0]).to(x.device)
|
||||
y_lens = torch.LongTensor([y_emb.shape[1]] * y_emb.shape[0]).to(x.device)
|
||||
y_pos = self.ar_audio_position(y_emb)
|
||||
xy_pos = torch.concat([x, y_pos], dim=1)
|
||||
|
||||
|
||||
|
||||
##### create mask #####
|
||||
bsz = x.shape[0]
|
||||
src_len = x_len + y_len
|
||||
y_paddind_mask = make_pad_mask_left(y_lens, y_len)
|
||||
x_paddind_mask = make_pad_mask_left(x_lens, max_len)
|
||||
|
||||
|
||||
# (bsz, x_len + y_len)
|
||||
padding_mask = torch.concat([x_paddind_mask, y_paddind_mask], dim=1)
|
||||
|
||||
x_mask = F.pad(
|
||||
torch.zeros(x_len, x_len, dtype=torch.bool, device=x.device),
|
||||
(0, y_len),
|
||||
value=True,
|
||||
)
|
||||
x_mask = F.pad(
|
||||
torch.zeros(x_len, x_len, dtype=torch.bool, device=x.device),
|
||||
(0, y_len),
|
||||
value=True,
|
||||
)
|
||||
|
||||
y_mask = F.pad( ###yy的右上1扩展到左边xy的0,(y,x+y)
|
||||
torch.triu(torch.ones(y_len, y_len, dtype=torch.bool, device=x.device), diagonal=1),
|
||||
torch.triu(torch.ones(y_len, y_len, dtype=torch.bool, device=x.device), diagonal=1),
|
||||
(x_len, 0),
|
||||
value=False,
|
||||
)
|
||||
|
||||
causal_mask = torch.concat([x_mask, y_mask], dim=0).view(1 , src_len, src_len).repeat(bsz, 1, 1).to(x.device)
|
||||
|
||||
causal_mask = torch.concat([x_mask, y_mask], dim=0).view(1, src_len, src_len).repeat(bsz, 1, 1).to(x.device)
|
||||
# padding_mask = padding_mask.unsqueeze(1) * padding_mask.unsqueeze(2) ### [b, x+y, x+y]
|
||||
### 上面是错误的,会导致padding的token被"看见"
|
||||
|
||||
@@ -639,10 +679,9 @@ class Text2SemanticDecoder(nn.Module):
|
||||
|
||||
padding_mask = padding_mask.view(bsz, 1, src_len).repeat(1, src_len, 1)
|
||||
|
||||
attn_mask:torch.Tensor = causal_mask.logical_or(padding_mask)
|
||||
attn_mask: torch.Tensor = causal_mask.logical_or(padding_mask)
|
||||
attn_mask = attn_mask.unsqueeze(1).expand(-1, self.num_head, -1, -1).bool()
|
||||
|
||||
|
||||
# 正确的attn_mask应该是这样的:
|
||||
# | pad_len | x_len | y_len |
|
||||
# [[PAD, PAD, PAD, 1, 2, 3, EOS, EOS, EOS],
|
||||
@@ -655,74 +694,69 @@ class Text2SemanticDecoder(nn.Module):
|
||||
# [PAD, PAD, PAD, 1, 2, 3, 4, 5, EOS],
|
||||
# [PAD, PAD, PAD, 1, 2, 3, 4, 5, 6]]
|
||||
|
||||
|
||||
###### decode #####
|
||||
y_list = [None]*y.shape[0]
|
||||
y_list = [None] * y.shape[0]
|
||||
batch_idx_map = list(range(y.shape[0]))
|
||||
idx_list = [None]*y.shape[0]
|
||||
idx_list = [None] * y.shape[0]
|
||||
for idx in tqdm(range(1500)):
|
||||
if idx == 0:
|
||||
xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, attn_mask, None)
|
||||
else:
|
||||
xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(xy_pos, k_cache, v_cache, attn_mask)
|
||||
logits = self.ar_predict_layer(
|
||||
xy_dec[:, -1]
|
||||
)
|
||||
logits = self.ar_predict_layer(xy_dec[:, -1])
|
||||
|
||||
if idx == 0:
|
||||
attn_mask = F.pad(attn_mask[:,:,-1].unsqueeze(-2),(0,1),value=False)
|
||||
attn_mask = F.pad(attn_mask[:, :, -1].unsqueeze(-2), (0, 1), value=False)
|
||||
logits = logits[:, :-1]
|
||||
else:
|
||||
attn_mask = F.pad(attn_mask,(0,1),value=False)
|
||||
attn_mask = F.pad(attn_mask, (0, 1), value=False)
|
||||
|
||||
samples = sample(
|
||||
logits, y, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature
|
||||
)[0]
|
||||
logits, y, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature
|
||||
)[0]
|
||||
|
||||
y = torch.concat([y, samples], dim=1)
|
||||
|
||||
|
||||
####### 移除batch中已经生成完毕的序列,进一步优化计算量
|
||||
tokens = torch.argmax(logits, dim=-1)
|
||||
reserved_idx_of_batch_for_y = None
|
||||
if (self.EOS in samples[:, 0]) or \
|
||||
(self.EOS in tokens): ###如果生成到EOS,则停止
|
||||
l1 = samples[:, 0]==self.EOS
|
||||
l2 = tokens==self.EOS
|
||||
l = l1.logical_or(l2)
|
||||
removed_idx_of_batch_for_y = torch.where(l==True)[0].tolist()
|
||||
reserved_idx_of_batch_for_y = torch.where(l==False)[0]
|
||||
# batch_indexs = torch.tensor(batch_idx_map, device=y.device)[removed_idx_of_batch_for_y]
|
||||
for i in removed_idx_of_batch_for_y:
|
||||
batch_index = batch_idx_map[i]
|
||||
idx_list[batch_index] = idx
|
||||
y_list[batch_index] = y[i, :-1]
|
||||
|
||||
batch_idx_map = [batch_idx_map[i] for i in reserved_idx_of_batch_for_y.tolist()]
|
||||
|
||||
# 只保留batch中未生成完毕的序列
|
||||
if (self.EOS in samples[:, 0]) or (self.EOS in tokens): ###如果生成到EOS,则停止
|
||||
l1 = samples[:, 0] == self.EOS
|
||||
l2 = tokens == self.EOS
|
||||
l = l1.logical_or(l2)
|
||||
removed_idx_of_batch_for_y = torch.where(l == True)[0].tolist()
|
||||
reserved_idx_of_batch_for_y = torch.where(l == False)[0]
|
||||
# batch_indexs = torch.tensor(batch_idx_map, device=y.device)[removed_idx_of_batch_for_y]
|
||||
for i in removed_idx_of_batch_for_y:
|
||||
batch_index = batch_idx_map[i]
|
||||
idx_list[batch_index] = idx
|
||||
y_list[batch_index] = y[i, :-1]
|
||||
|
||||
batch_idx_map = [batch_idx_map[i] for i in reserved_idx_of_batch_for_y.tolist()]
|
||||
|
||||
# 只保留batch中未生成完毕的序列
|
||||
if reserved_idx_of_batch_for_y is not None:
|
||||
# index = torch.LongTensor(batch_idx_map).to(y.device)
|
||||
y = torch.index_select(y, dim=0, index=reserved_idx_of_batch_for_y)
|
||||
attn_mask = torch.index_select(attn_mask, dim=0, index=reserved_idx_of_batch_for_y)
|
||||
if k_cache is not None :
|
||||
if k_cache is not None:
|
||||
for i in range(len(k_cache)):
|
||||
k_cache[i] = torch.index_select(k_cache[i], dim=0, index=reserved_idx_of_batch_for_y)
|
||||
v_cache[i] = torch.index_select(v_cache[i], dim=0, index=reserved_idx_of_batch_for_y)
|
||||
|
||||
|
||||
if (early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num) or idx==1499:
|
||||
|
||||
if (early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num) or idx == 1499:
|
||||
print("use early stop num:", early_stop_num)
|
||||
stop = True
|
||||
for i, batch_index in enumerate(batch_idx_map):
|
||||
batch_index = batch_idx_map[i]
|
||||
idx_list[batch_index] = idx
|
||||
y_list[batch_index] = y[i, :-1]
|
||||
|
||||
if not (None in idx_list):
|
||||
|
||||
if None not in idx_list:
|
||||
stop = True
|
||||
|
||||
|
||||
if stop:
|
||||
if y.shape[1]==0:
|
||||
if y.shape[1] == 0:
|
||||
y = torch.concat([y, torch.zeros_like(samples)], dim=1)
|
||||
print("bad zero prediction")
|
||||
print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]")
|
||||
@@ -730,60 +764,65 @@ class Text2SemanticDecoder(nn.Module):
|
||||
|
||||
####################### update next step ###################################
|
||||
y_emb = self.ar_audio_embedding(y[:, -1:])
|
||||
xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[:, y_len + idx].to( dtype= y_emb.dtype,device=y_emb.device)
|
||||
xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[
|
||||
:, y_len + idx
|
||||
].to(dtype=y_emb.dtype, device=y_emb.device)
|
||||
|
||||
if (None in idx_list):
|
||||
if None in idx_list:
|
||||
for i in range(x.shape[0]):
|
||||
if idx_list[i] is None:
|
||||
idx_list[i] = 1500-1 ###如果没有生成到EOS,就用最大长度代替
|
||||
|
||||
idx_list[i] = 1500 - 1 ###如果没有生成到EOS,就用最大长度代替
|
||||
|
||||
if ref_free:
|
||||
return y_list, [0]*x.shape[0]
|
||||
return y_list, [0] * x.shape[0]
|
||||
# print(idx_list)
|
||||
return y_list, idx_list
|
||||
|
||||
def infer_panel_naive_batched(self,
|
||||
x:List[torch.LongTensor], #####全部文本token
|
||||
x_lens:torch.LongTensor,
|
||||
prompts:torch.LongTensor, ####参考音频token
|
||||
bert_feature:List[torch.LongTensor],
|
||||
|
||||
def infer_panel_naive_batched(
|
||||
self,
|
||||
x: List[torch.LongTensor], #####全部文本token
|
||||
x_lens: torch.LongTensor,
|
||||
prompts: torch.LongTensor, ####参考音频token
|
||||
bert_feature: List[torch.LongTensor],
|
||||
top_k: int = -100,
|
||||
top_p: int = 100,
|
||||
early_stop_num: int = -1,
|
||||
temperature: float = 1.0,
|
||||
repetition_penalty: float = 1.35,
|
||||
**kwargs
|
||||
):
|
||||
**kwargs,
|
||||
):
|
||||
y_list = []
|
||||
idx_list = []
|
||||
for i in range(len(x)):
|
||||
y, idx = self.infer_panel_naive(x[i].unsqueeze(0),
|
||||
x_lens[i],
|
||||
prompts[i].unsqueeze(0) if prompts is not None else None,
|
||||
bert_feature[i].unsqueeze(0),
|
||||
top_k,
|
||||
top_p,
|
||||
early_stop_num,
|
||||
temperature,
|
||||
repetition_penalty,
|
||||
**kwargs)
|
||||
y, idx = self.infer_panel_naive(
|
||||
x[i].unsqueeze(0),
|
||||
x_lens[i],
|
||||
prompts[i].unsqueeze(0) if prompts is not None else None,
|
||||
bert_feature[i].unsqueeze(0),
|
||||
top_k,
|
||||
top_p,
|
||||
early_stop_num,
|
||||
temperature,
|
||||
repetition_penalty,
|
||||
**kwargs,
|
||||
)
|
||||
y_list.append(y[0])
|
||||
idx_list.append(idx)
|
||||
|
||||
|
||||
return y_list, idx_list
|
||||
|
||||
|
||||
def infer_panel_naive(
|
||||
self,
|
||||
x:torch.LongTensor, #####全部文本token
|
||||
x_lens:torch.LongTensor,
|
||||
prompts:torch.LongTensor, ####参考音频token
|
||||
bert_feature:torch.LongTensor,
|
||||
x: torch.LongTensor, #####全部文本token
|
||||
x_lens: torch.LongTensor,
|
||||
prompts: torch.LongTensor, ####参考音频token
|
||||
bert_feature: torch.LongTensor,
|
||||
top_k: int = -100,
|
||||
top_p: int = 100,
|
||||
early_stop_num: int = -1,
|
||||
temperature: float = 1.0,
|
||||
repetition_penalty: float = 1.35,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
x = self.ar_text_embedding(x)
|
||||
x = x + self.bert_proj(bert_feature.transpose(1, 2))
|
||||
@@ -828,11 +867,13 @@ class Text2SemanticDecoder(nn.Module):
|
||||
(x_len, 0),
|
||||
value=False,
|
||||
)
|
||||
xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0)\
|
||||
.unsqueeze(0)\
|
||||
.expand(bsz*self.num_head, -1, -1)\
|
||||
.view(bsz, self.num_head, src_len, src_len)\
|
||||
.to(device=x.device, dtype=torch.bool)
|
||||
xy_attn_mask = (
|
||||
torch.concat([x_attn_mask_pad, y_attn_mask], dim=0)
|
||||
.unsqueeze(0)
|
||||
.expand(bsz * self.num_head, -1, -1)
|
||||
.view(bsz, self.num_head, src_len, src_len)
|
||||
.to(device=x.device, dtype=torch.bool)
|
||||
)
|
||||
|
||||
for idx in tqdm(range(1500)):
|
||||
if xy_attn_mask is not None:
|
||||
@@ -840,13 +881,11 @@ class Text2SemanticDecoder(nn.Module):
|
||||
else:
|
||||
xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(xy_pos, k_cache, v_cache)
|
||||
|
||||
logits = self.ar_predict_layer(
|
||||
xy_dec[:, -1]
|
||||
)
|
||||
logits = self.ar_predict_layer(xy_dec[:, -1])
|
||||
|
||||
if idx == 0:
|
||||
xy_attn_mask = None
|
||||
if(idx<11):###至少预测出10个token不然不给停止(0.4s)
|
||||
if idx < 11: ###至少预测出10个token不然不给停止(0.4s)
|
||||
logits = logits[:, :-1]
|
||||
|
||||
samples = sample(
|
||||
@@ -870,24 +909,27 @@ class Text2SemanticDecoder(nn.Module):
|
||||
|
||||
####################### update next step ###################################
|
||||
y_emb = self.ar_audio_embedding(y[:, -1:])
|
||||
xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[:, y_len + idx].to(dtype=y_emb.dtype,device=y_emb.device)
|
||||
xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[
|
||||
:, y_len + idx
|
||||
].to(dtype=y_emb.dtype, device=y_emb.device)
|
||||
|
||||
if ref_free:
|
||||
return y[:, :-1], 0
|
||||
return y[:, :-1], idx
|
||||
|
||||
|
||||
|
||||
def infer_panel(
|
||||
self,
|
||||
x:torch.LongTensor, #####全部文本token
|
||||
x_lens:torch.LongTensor,
|
||||
prompts:torch.LongTensor, ####参考音频token
|
||||
bert_feature:torch.LongTensor,
|
||||
x: torch.LongTensor, #####全部文本token
|
||||
x_lens: torch.LongTensor,
|
||||
prompts: torch.LongTensor, ####参考音频token
|
||||
bert_feature: torch.LongTensor,
|
||||
top_k: int = -100,
|
||||
top_p: int = 100,
|
||||
early_stop_num: int = -1,
|
||||
temperature: float = 1.0,
|
||||
repetition_penalty: float = 1.35,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
return self.infer_panel_naive(x, x_lens, prompts, bert_feature, top_k, top_p, early_stop_num, temperature, repetition_penalty, **kwargs)
|
||||
return self.infer_panel_naive(
|
||||
x, x_lens, prompts, bert_feature, top_k, top_p, early_stop_num, temperature, repetition_penalty, **kwargs
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user