改变训练和推理时的mask策略,以修复当batch_size>1时,产生的复读现象 (#966)
This commit is contained in:
@@ -297,7 +297,8 @@ class Text2SemanticDecoder(nn.Module):
|
||||
(0, y_len),
|
||||
value=True,
|
||||
)
|
||||
|
||||
# 取消对y[0]的mask,以防止复读,详见https://github.com/RVC-Boss/GPT-SoVITS/issues/965
|
||||
x_attn_mask[:, x_len]=False
|
||||
y_attn_mask = F.pad(
|
||||
torch.triu(
|
||||
torch.ones(y_len, y_len, dtype=torch.bool, device=x.device),
|
||||
@@ -393,6 +394,8 @@ class Text2SemanticDecoder(nn.Module):
|
||||
(0, y_len),
|
||||
value=True,
|
||||
)
|
||||
# 取消对y[0]的mask,以防止复读,详见https://github.com/RVC-Boss/GPT-SoVITS/issues/965
|
||||
x_attn_mask[:, x_len]=False
|
||||
y_attn_mask = F.pad(
|
||||
torch.triu(
|
||||
torch.ones(y_len, y_len, dtype=torch.bool, device=x.device),
|
||||
@@ -458,7 +461,7 @@ class Text2SemanticDecoder(nn.Module):
|
||||
value=True,
|
||||
)
|
||||
y_attn_mask = F.pad(
|
||||
torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
|
||||
torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=0),# diagonal必须为0,否则会导致batch_size>1时的复读情况
|
||||
(x_len, 0),
|
||||
value=False,
|
||||
)
|
||||
@@ -504,29 +507,29 @@ class Text2SemanticDecoder(nn.Module):
|
||||
|
||||
def infer_panel_batch_infer_with_flash_attn(
|
||||
self,
|
||||
x:List[torch.LongTensor], #####全部文本token
|
||||
x:torch.LongTensor, #####全部文本token
|
||||
x_lens:torch.LongTensor,
|
||||
prompts:torch.LongTensor, ####参考音频token
|
||||
bert_feature:List[torch.LongTensor],
|
||||
bert_feature:torch.LongTensor,
|
||||
top_k: int = -100,
|
||||
top_p: int = 100,
|
||||
early_stop_num: int = -1,
|
||||
temperature: float = 1.0,
|
||||
):
|
||||
# 先对phones进行embedding、对bert_features进行project,再pad到相同长度,以缓解复读问题。(可能还有其他因素导致复读)
|
||||
max_len = 0
|
||||
for x_item, bert_item in zip(x, bert_feature):
|
||||
max_len = max(max_len, x_item.shape[0], bert_item.shape[1])
|
||||
x_list = [self.ar_text_embedding(item) for item in x]
|
||||
x_list = [F.pad(item,(0,0,0,max_len-item.shape[0]),value=0) if item.shape[0]<max_len else item for item in x_list]
|
||||
x = torch.stack(x_list, dim=0)
|
||||
## 先对phones进行embedding、对bert_features进行project,再pad到相同长度(padding策略会影响T2S模型生成的结果,但不直接影响复读概率。影响复读概率的主要因素是mask的策略)
|
||||
# max_len = 0
|
||||
# for x_item, bert_item in zip(x, bert_feature):
|
||||
# max_len = max(max_len, x_item.shape[0], bert_item.shape[1])
|
||||
# x_list = [self.ar_text_embedding(item) for item in x]
|
||||
# x_list = [F.pad(item,(0,0,0,max_len-item.shape[0]),value=0) if item.shape[0]<max_len else item for item in x_list]
|
||||
# x = torch.stack(x_list, dim=0)
|
||||
|
||||
bert_features_list = [self.bert_proj(item.transpose(0, 1)) for item in bert_feature]
|
||||
bert_features_list = [F.pad(item,(0,0,0,max_len-item.shape[0]), value=0) if item.shape[0]<max_len else item for item in bert_features_list]
|
||||
bert_feature = torch.stack(bert_features_list, dim=0)
|
||||
# bert_features_list = [self.bert_proj(item.transpose(0, 1)) for item in bert_feature]
|
||||
# bert_features_list = [F.pad(item,(0,0,0,max_len-item.shape[0]), value=0) if item.shape[0]<max_len else item for item in bert_features_list]
|
||||
# bert_feature = torch.stack(bert_features_list, dim=0)
|
||||
|
||||
# bert_feature = self.bert_proj(bert_feature.transpose(1, 2))
|
||||
# x = self.ar_text_embedding(x)
|
||||
bert_feature = self.bert_proj(bert_feature.transpose(1, 2))
|
||||
x = self.ar_text_embedding(x)
|
||||
x = x + bert_feature
|
||||
x = self.ar_text_position(x)
|
||||
|
||||
@@ -573,8 +576,8 @@ class Text2SemanticDecoder(nn.Module):
|
||||
(0, y_len), ###xx的纯0扩展到xx纯0+xy纯1,(x,x+y)
|
||||
value=True,
|
||||
)
|
||||
y_mask = F.pad( ###yy的右上1扩展到左边xy的0,(y,x+y)
|
||||
torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
|
||||
y_mask = F.pad( ###yy的右上0扩展到左边xy的0,(y,x+y)
|
||||
torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=0), # diagonal必须为0,否则会导致batch_size>1时的复读情况
|
||||
(x_len, 0),
|
||||
value=False,
|
||||
)
|
||||
@@ -669,29 +672,29 @@ class Text2SemanticDecoder(nn.Module):
|
||||
|
||||
def infer_panel_batch_only(
|
||||
self,
|
||||
x:List[torch.LongTensor], #####全部文本token
|
||||
x:torch.LongTensor, #####全部文本token
|
||||
x_lens:torch.LongTensor,
|
||||
prompts:torch.LongTensor, ####参考音频token
|
||||
bert_feature:List[torch.LongTensor],
|
||||
bert_feature:torch.LongTensor,
|
||||
top_k: int = -100,
|
||||
top_p: int = 100,
|
||||
early_stop_num: int = -1,
|
||||
temperature: float = 1.0,
|
||||
):
|
||||
# 先对phones进行embedding、对bert_features进行project,再pad到相同长度,以缓解复读问题。(可能还有其他因素导致复读)
|
||||
max_len = 0
|
||||
for x_item, bert_item in zip(x, bert_feature):
|
||||
max_len = max(max_len, x_item.shape[0], bert_item.shape[1])
|
||||
x_list = [self.ar_text_embedding(item) for item in x]
|
||||
x_list = [F.pad(item,(0,0,0,max_len-item.shape[0]),value=0) if item.shape[0]<max_len else item for item in x_list]
|
||||
x = torch.stack(x_list, dim=0)
|
||||
## 先对phones进行embedding、对bert_features进行project,再pad到相同长度(padding策略会影响T2S模型生成的结果,但不直接影响复读概率。影响复读概率的主要因素是mask的策略)
|
||||
# max_len = 0
|
||||
# for x_item, bert_item in zip(x, bert_feature):
|
||||
# max_len = max(max_len, x_item.shape[0], bert_item.shape[1])
|
||||
# x_list = [self.ar_text_embedding(item) for item in x]
|
||||
# x_list = [F.pad(item,(0,0,0,max_len-item.shape[0]),value=0) if item.shape[0]<max_len else item for item in x_list]
|
||||
# x = torch.stack(x_list, dim=0)
|
||||
|
||||
bert_features_list = [self.bert_proj(item.transpose(0, 1)) for item in bert_feature]
|
||||
bert_features_list = [F.pad(item,(0,0,0,max_len-item.shape[0]), value=0) if item.shape[0]<max_len else item for item in bert_features_list]
|
||||
bert_feature = torch.stack(bert_features_list, dim=0)
|
||||
# bert_features_list = [self.bert_proj(item.transpose(0, 1)) for item in bert_feature]
|
||||
# bert_features_list = [F.pad(item,(0,0,0,max_len-item.shape[0]), value=0) if item.shape[0]<max_len else item for item in bert_features_list]
|
||||
# bert_feature = torch.stack(bert_features_list, dim=0)
|
||||
|
||||
# bert_feature = self.bert_proj(bert_feature.transpose(1, 2))
|
||||
# x = self.ar_text_embedding(x)
|
||||
bert_feature = self.bert_proj(bert_feature.transpose(1, 2))
|
||||
x = self.ar_text_embedding(x)
|
||||
x = x + bert_feature
|
||||
x = self.ar_text_position(x)
|
||||
|
||||
@@ -747,7 +750,7 @@ class Text2SemanticDecoder(nn.Module):
|
||||
value=True,
|
||||
)
|
||||
y_mask = F.pad( ###yy的右上1扩展到左边xy的0,(y,x+y)
|
||||
torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
|
||||
torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=0), # diagonal必须为0,否则会导致batch_size>1时的复读情况
|
||||
(x_len, 0),
|
||||
value=False,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user