改变训练和推理时的mask策略,以修复当batch_size>1时,产生的复读现象 (#966)
This commit is contained in:
@@ -515,16 +515,16 @@ class TTS:
|
||||
all_bert_features_batch = all_bert_features_list
|
||||
|
||||
|
||||
# max_len = max(bert_max_len, phones_max_len)
|
||||
max_len = max(bert_max_len, phones_max_len)
|
||||
# phones_batch = self.batch_sequences(phones_list, axis=0, pad_value=0, max_length=max_len)
|
||||
#### 直接对phones和bert_features进行pad,会增大复读概率。
|
||||
# all_phones_batch = self.batch_sequences(all_phones_list, axis=0, pad_value=0, max_length=max_len)
|
||||
# all_bert_features_batch = all_bert_features_list
|
||||
# all_bert_features_batch = torch.zeros(len(item_list), 1024, max_len, dtype=precision, device=device)
|
||||
# for idx, item in enumerate(all_bert_features_list):
|
||||
# all_bert_features_batch[idx, :, : item.shape[-1]] = item
|
||||
#### 直接对phones和bert_features进行pad。(padding策略会影响T2S模型生成的结果,但不直接影响复读概率。影响复读概率的主要因素是mask的策略)
|
||||
all_phones_batch = self.batch_sequences(all_phones_list, axis=0, pad_value=0, max_length=max_len)
|
||||
all_bert_features_batch = all_bert_features_list
|
||||
all_bert_features_batch = torch.zeros(len(item_list), 1024, max_len, dtype=precision, device=device)
|
||||
for idx, item in enumerate(all_bert_features_list):
|
||||
all_bert_features_batch[idx, :, : item.shape[-1]] = item
|
||||
|
||||
# #### 先对phones进行embedding、对bert_features进行project,再pad到相同长度,以缓解复读问题。(可能还有其他因素导致复读)
|
||||
# #### 先对phones进行embedding、对bert_features进行project,再pad到相同长度,(padding策略会影响T2S模型生成的结果,但不直接影响复读概率。影响复读概率的主要因素是mask的策略)
|
||||
# all_phones_list = [self.t2s_model.model.ar_text_embedding(item.to(self.t2s_model.device)) for item in all_phones_list]
|
||||
# all_phones_list = [F.pad(item,(0,0,0,max_len-item.shape[0]),value=0) for item in all_phones_list]
|
||||
# all_phones_batch = torch.stack(all_phones_list, dim=0)
|
||||
@@ -734,17 +734,18 @@ class TTS:
|
||||
continue
|
||||
|
||||
batch_phones:List[torch.LongTensor] = item["phones"]
|
||||
# batch_phones:torch.LongTensor = item["phones"]
|
||||
batch_phones_len:torch.LongTensor = item["phones_len"]
|
||||
all_phoneme_ids:List[torch.LongTensor] = item["all_phones"]
|
||||
all_phoneme_ids:torch.LongTensor = item["all_phones"]
|
||||
all_phoneme_lens:torch.LongTensor = item["all_phones_len"]
|
||||
all_bert_features:List[torch.LongTensor] = item["all_bert_features"]
|
||||
all_bert_features:torch.LongTensor = item["all_bert_features"]
|
||||
norm_text:str = item["norm_text"]
|
||||
|
||||
print(i18n("前端处理后的文本(每句):"), norm_text)
|
||||
if no_prompt_text :
|
||||
prompt = None
|
||||
else:
|
||||
prompt = self.prompt_cache["prompt_semantic"].expand(len(all_phoneme_ids), -1).to(self.configs.device)
|
||||
prompt = self.prompt_cache["prompt_semantic"].expand(all_phoneme_ids.shape[0], -1).to(self.configs.device)
|
||||
|
||||
|
||||
pred_semantic_list, idx_list = self.t2s_model.model.infer_panel(
|
||||
|
||||
Reference in New Issue
Block a user