缓解了batch_size>1时的复读问题,缓解方法是:在T2S模型中,先对phones进行embedding、对bert_features进行project,再pad到相同长度。

This commit is contained in:
chasonjiang
2024-03-16 21:04:49 +08:00
parent 3c78539c44
commit 864a148d75
2 changed files with 85 additions and 43 deletions

View File

@@ -504,18 +504,29 @@ class Text2SemanticDecoder(nn.Module):
def infer_panel_batch_infer_with_flash_attn(
self,
x, #####全部文本token
x_lens,
prompts, ####参考音频token
bert_feature,
x:List[torch.LongTensor], #####全部文本token
x_lens:torch.LongTensor,
prompts:torch.LongTensor, ####参考音频token
bert_feature:List[torch.LongTensor],
top_k: int = -100,
top_p: int = 100,
early_stop_num: int = -1,
temperature: float = 1.0,
):
bert_feature = self.bert_proj(bert_feature.transpose(1, 2))
x = self.ar_text_embedding(x)
# 先对phones进行embedding、对bert_features进行project再pad到相同长度以缓解复读问题。可能还有其他因素导致复读
max_len = 0
for x_item, bert_item in zip(x, bert_feature):
max_len = max(max_len, x_item.shape[0], bert_item.shape[1])
x_list = [self.ar_text_embedding(item) for item in x]
x_list = [F.pad(item,(0,0,0,max_len-item.shape[0]),value=0) if item.shape[0]<max_len else item for item in x_list]
x = torch.stack(x_list, dim=0)
bert_features_list = [self.bert_proj(item.transpose(0, 1)) for item in bert_feature]
bert_features_list = [F.pad(item,(0,0,0,max_len-item.shape[0]), value=0) if item.shape[0]<max_len else item for item in bert_features_list]
bert_feature = torch.stack(bert_features_list, dim=0)
# bert_feature = self.bert_proj(bert_feature.transpose(1, 2))
# x = self.ar_text_embedding(x)
x = x + bert_feature
x = self.ar_text_position(x)
@@ -658,17 +669,30 @@ class Text2SemanticDecoder(nn.Module):
def infer_panel_batch_only(
self,
x, #####全部文本token
x_lens,
prompts, ####参考音频token
bert_feature,
x:List[torch.LongTensor], #####全部文本token
x_lens:torch.LongTensor,
prompts:torch.LongTensor, ####参考音频token
bert_feature:List[torch.LongTensor],
top_k: int = -100,
top_p: int = 100,
early_stop_num: int = -1,
temperature: float = 1.0,
):
x = self.ar_text_embedding(x)
x = x + self.bert_proj(bert_feature.transpose(1, 2))
# 先对phones进行embedding、对bert_features进行project再pad到相同长度以缓解复读问题。可能还有其他因素导致复读
max_len = 0
for x_item, bert_item in zip(x, bert_feature):
max_len = max(max_len, x_item.shape[0], bert_item.shape[1])
x_list = [self.ar_text_embedding(item) for item in x]
x_list = [F.pad(item,(0,0,0,max_len-item.shape[0]),value=0) if item.shape[0]<max_len else item for item in x_list]
x = torch.stack(x_list, dim=0)
bert_features_list = [self.bert_proj(item.transpose(0, 1)) for item in bert_feature]
bert_features_list = [F.pad(item,(0,0,0,max_len-item.shape[0]), value=0) if item.shape[0]<max_len else item for item in bert_features_list]
bert_feature = torch.stack(bert_features_list, dim=0)
# bert_feature = self.bert_proj(bert_feature.transpose(1, 2))
# x = self.ar_text_embedding(x)
x = x + bert_feature
x = self.ar_text_position(x)
# AR Decoder