[fast_inference] Optimize the mask strategy for batch inference (#1477)

* Optimized the mask strategy for batch inference so that synthesized audio quality is more stable; cleaned up some code logic.

* Removed unused code
Author: ChasonJiang
Date: 2024-08-16 10:49:53 +08:00
Committed by: GitHub
Parent: 7c43b41e6d
Commit: f5a5f1890f
3 changed files with 153 additions and 79 deletions


@@ -81,6 +81,8 @@ class TTS_Config:
"bert_base_path": "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",
}
configs:dict = None
+ languages:list = ["auto", "en", "zh", "ja", "all_zh", "all_ja"]
def __init__(self, configs: Union[dict, str]=None):
# Set the default config file path
@@ -97,7 +99,7 @@ class TTS_Config:
if isinstance(configs, str):
self.configs_path = configs
configs:dict = self._load_configs(self.configs_path)
assert isinstance(configs, dict)
default_configs:dict = configs.get("default", None)
if default_configs is not None:
@@ -138,8 +140,7 @@ class TTS_Config:
self.hop_length:int = 640
self.win_length:int = 2048
self.n_speakers:int = 300
- self.languages:list = ["auto", "en", "zh", "ja", "all_zh", "all_ja"]
def _load_configs(self, configs_path: str)->dict:
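The first hunk adds `languages` as a class-level attribute, and this hunk removes the matching per-instance assignment from `__init__`. One property of this change worth noting (general Python behavior, not something the commit asserts): a class attribute is shared by all instances until one of them rebinds it.

```python
class TTS_Config:
    # Class-level default, shared by every instance until shadowed.
    languages: list = ["auto", "en", "zh", "ja", "all_zh", "all_ja"]

a, b = TTS_Config(), TTS_Config()
assert a.languages is b.languages        # both read the same class-level list
a.languages = ["en"]                     # rebinding shadows only this instance
assert b.languages == ["auto", "en", "zh", "ja", "all_zh", "all_ja"]
# In-place mutation (e.g. b.languages.append("ko")) would instead be
# visible to every instance that has not shadowed the attribute.
```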
@@ -489,8 +490,8 @@ class TTS:
all_phones_len_list = []
all_bert_features_list = []
norm_text_batch = []
- bert_max_len = 0
- phones_max_len = 0
+ all_bert_max_len = 0
+ all_phones_max_len = 0
for item in item_list:
if prompt_data is not None:
all_bert_features = torch.cat([prompt_data["bert_features"], item["bert_features"]], 1)\
@@ -505,8 +506,8 @@ class TTS:
all_phones = phones
# norm_text = item["norm_text"]
- bert_max_len = max(bert_max_len, all_bert_features.shape[-1])
- phones_max_len = max(phones_max_len, phones.shape[-1])
+ all_bert_max_len = max(all_bert_max_len, all_bert_features.shape[-1])
+ all_phones_max_len = max(all_phones_max_len, all_phones.shape[-1])
phones_list.append(phones)
phones_len_list.append(phones.shape[-1])
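The renamed `all_bert_max_len` / `all_phones_max_len` make explicit that the maxima are taken over the concatenated prompt-plus-target sequences (`all_bert_features`, `all_phones`), not over the target text alone. A condensed, standalone sketch of the accumulation (the item keys are assumptions):

```python
import torch

def batch_max_len(item_list: list) -> int:
    all_bert_max_len = 0
    all_phones_max_len = 0
    for item in item_list:
        # all_bert_features: (dim, seq_len); all_phones: (seq_len,)
        all_bert_max_len = max(all_bert_max_len, item["all_bert_features"].shape[-1])
        all_phones_max_len = max(all_phones_max_len, item["all_phones"].shape[-1])
    # One shared target length for padding both modalities.
    return max(all_bert_max_len, all_phones_max_len)
```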
@@ -520,7 +521,7 @@ class TTS:
all_bert_features_batch = all_bert_features_list
- max_len = max(bert_max_len, phones_max_len)
+ max_len = max(all_bert_max_len, all_phones_max_len)
# phones_batch = self.batch_sequences(phones_list, axis=0, pad_value=0, max_length=max_len)
#### Pad phones and bert_features directly. The padding strategy affects what the T2S model generates, but it does not directly affect the repetition probability; the main factor driving repetition is the mask strategy.
# all_phones_batch = self.batch_sequences(all_phones_list, axis=0, pad_value=0, max_length=max_len)
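Per the comment above, padding everything to `max_len` only shifts what the T2S model generates; it is the masking of those padded positions that controls repetition. A minimal right-padding helper in the spirit of the commented-out `batch_sequences` calls (hypothetical; the real `batch_sequences` signature is only partially visible here):

```python
import torch
import torch.nn.functional as F

def pad_and_stack(seqs: list, max_len: int, pad_value: float = 0) -> torch.Tensor:
    # Right-pad each tensor along its last axis to max_len, then stack
    # into one batch tensor. Works for 1-D phones and 2-D BERT features.
    padded = [F.pad(s, (0, max_len - s.shape[-1]), value=pad_value) for s in seqs]
    return torch.stack(padded, dim=0)

phones_batch = pad_and_stack([torch.ones(3), torch.ones(5)], max_len=5)
# -> shape (2, 5); the second row is unpadded, the first has two trailing zeros.
```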
@@ -630,7 +631,7 @@ class TTS:
if parallel_infer:
print(i18n("并行推理模式已开启"))
- self.t2s_model.model.infer_panel = self.t2s_model.model.infer_panel_batch_infer_with_flash_attn
+ self.t2s_model.model.infer_panel = self.t2s_model.model.infer_panel_batch_infer
else:
print(i18n("并行推理模式已关闭"))
self.t2s_model.model.infer_panel = self.t2s_model.model.infer_panel_0307
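The switch between parallel and sequential inference works by rebinding a single attribute at call time rather than branching inside the decoding loop. The same dispatch pattern in isolation (the class body is a placeholder; only the attribute names come from the diff):

```python
class T2SModel:
    def infer_panel_batch_infer(self, batch):
        return [f"parallel:{x}" for x in batch]

    def infer_panel_0307(self, batch):
        return [f"sequential:{x}" for x in batch]

model = T2SModel()
parallel_infer = True
# Mirrors `self.t2s_model.model.infer_panel = ...` in the hunk above:
model.infer_panel = (model.infer_panel_batch_infer if parallel_infer
                     else model.infer_panel_0307)
print(model.infer_panel(["a", "b"]))   # -> ['parallel:a', 'parallel:b']
```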
@@ -942,4 +943,4 @@ def speed_change(input_audio:np.ndarray, speed:float, sr:int):
# Decode the pipe output into a NumPy array
processed_audio = np.frombuffer(out, np.int16)
- return processed_audio
\ No newline at end of file
+ return processed_audio
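For context on this last hunk: `speed_change` pipes raw PCM through ffmpeg and decodes the returned bytes with `np.frombuffer`. A self-contained sketch of that round trip using ffmpeg's `atempo` filter (the filter choice is an assumption; requires ffmpeg on PATH and int16 mono input):

```python
import subprocess
import numpy as np

def speed_change(input_audio: np.ndarray, speed: float, sr: int) -> np.ndarray:
    raw = input_audio.astype(np.int16).tobytes()
    cmd = [
        "ffmpeg", "-f", "s16le", "-ar", str(sr), "-ac", "1", "-i", "pipe:0",
        "-filter:a", f"atempo={speed}",   # atempo accepts roughly 0.5-2.0 per pass
        "-f", "s16le", "-ar", str(sr), "-ac", "1", "pipe:1",
    ]
    out = subprocess.run(cmd, input=raw, capture_output=True, check=True).stdout
    # Decode the pipe output back into a NumPy array, as in the diff above.
    return np.frombuffer(out, np.int16)
```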