[fast_inference] Optimize the mask strategy for batch inference (#1477)
* Improved the mask strategy for batch inference so that the quality of synthesized audio is more stable; tidied up some code logic. * Removed unused code
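As background, here is a minimal sketch (an illustration only, not the patch itself) of the kind of padding mask that batch inference relies on: when the sequences in a batch are padded to a common length, a boolean mask marks the padded positions so the model never attends to them, which keeps each sample's result consistent with unbatched inference. The helper name build_padding_mask below is hypothetical.

import torch

def build_padding_mask(lengths: torch.Tensor, max_len: int) -> torch.Tensor:
    # True at padded positions, False at real tokens; shape (batch, max_len)
    positions = torch.arange(max_len, device=lengths.device).unsqueeze(0)
    return positions >= lengths.unsqueeze(1)

lengths = torch.tensor([5, 3, 7])
print(build_padding_mask(lengths, max_len=7))
# tensor([[False, False, False, False, False,  True,  True],
#         [False, False, False,  True,  True,  True,  True],
#         [False, False, False, False, False, False, False]])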
@@ -81,6 +81,8 @@ class TTS_Config:
"bert_base_path": "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",
}
configs:dict = None
languages:list = ["auto", "en", "zh", "ja", "all_zh", "all_ja"]

def __init__(self, configs: Union[dict, str]=None):

    # Set the default config file path
@@ -97,7 +99,7 @@ class TTS_Config:
if isinstance(configs, str):
    self.configs_path = configs
    configs:dict = self._load_configs(self.configs_path)

assert isinstance(configs, dict)
default_configs:dict = configs.get("default", None)
if default_configs is not None:
@@ -138,8 +140,7 @@ class TTS_Config:
self.hop_length:int = 640
self.win_length:int = 2048
self.n_speakers:int = 300

self.languages:list = ["auto", "en", "zh", "ja", "all_zh", "all_ja"]

def _load_configs(self, configs_path: str)->dict:
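As a rough illustration of what _load_configs does, here is a sketch under the assumption that the config file is YAML (like GPT_SoVITS/configs/tts_infer.yaml); the helper names load_configs and merge_with_defaults are hypothetical, not the project's code.

import yaml

def load_configs(configs_path: str) -> dict:
    # Read a YAML config file into a plain dict
    with open(configs_path, "r", encoding="utf-8") as f:
        configs = yaml.safe_load(f)
    return configs if isinstance(configs, dict) else {}

def merge_with_defaults(configs: dict, default_configs: dict) -> dict:
    # Fill in keys missing from the user config without overwriting user values
    merged = dict(default_configs)
    merged.update(configs)
    return merged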
@@ -489,8 +490,8 @@ class TTS:
all_phones_len_list = []
all_bert_features_list = []
norm_text_batch = []
-bert_max_len = 0
-phones_max_len = 0
+all_bert_max_len = 0
+all_phones_max_len = 0
for item in item_list:
    if prompt_data is not None:
        all_bert_features = torch.cat([prompt_data["bert_features"], item["bert_features"]], 1)\
@@ -505,8 +506,8 @@ class TTS:
all_phones = phones
# norm_text = item["norm_text"]

-bert_max_len = max(bert_max_len, all_bert_features.shape[-1])
-phones_max_len = max(phones_max_len, phones.shape[-1])
+all_bert_max_len = max(all_bert_max_len, all_bert_features.shape[-1])
+all_phones_max_len = max(all_phones_max_len, all_phones.shape[-1])

phones_list.append(phones)
phones_len_list.append(phones.shape[-1])
@@ -520,7 +521,7 @@ class TTS:
all_bert_features_batch = all_bert_features_list

-max_len = max(bert_max_len, phones_max_len)
+max_len = max(all_bert_max_len, all_phones_max_len)
# phones_batch = self.batch_sequences(phones_list, axis=0, pad_value=0, max_length=max_len)
#### Pad phones and bert_features directly. (The padding strategy affects what the T2S model generates, but it does not directly affect the repetition probability; the main factor affecting repetition probability is the mask strategy.)
# all_phones_batch = self.batch_sequences(all_phones_list, axis=0, pad_value=0, max_length=max_len)
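To make the padding step above concrete, here is a minimal sketch of padding every sequence to the shared max_len before stacking them into one batch tensor; pad_to_length is a hypothetical stand-in for the project's batch_sequences helper, and the pad_value of 0 is an assumption.

import torch
import torch.nn.functional as F

def pad_to_length(seqs, max_length: int, pad_value: int = 0) -> torch.Tensor:
    # Right-pad each tensor along its last dimension to max_length, then stack
    padded = [F.pad(seq, (0, max_length - seq.shape[-1]), value=pad_value) for seq in seqs]
    return torch.stack(padded, dim=0)

all_phones_batch = pad_to_length([torch.tensor([1, 2, 3]), torch.tensor([4, 5])], max_length=5)
# tensor([[1, 2, 3, 0, 0],
#         [4, 5, 0, 0, 0]])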
@@ -630,7 +631,7 @@ class TTS:

if parallel_infer:
    print(i18n("并行推理模式已开启"))
-    self.t2s_model.model.infer_panel = self.t2s_model.model.infer_panel_batch_infer_with_flash_attn
+    self.t2s_model.model.infer_panel = self.t2s_model.model.infer_panel_batch_infer
else:
    print(i18n("并行推理模式已关闭"))
    self.t2s_model.model.infer_panel = self.t2s_model.model.infer_panel_0307
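The branch above rebinds infer_panel once, so later call sites can invoke self.t2s_model.model.infer_panel without re-checking the flag. A stripped-down sketch of that dispatch pattern (the class and method bodies here are placeholders, not the real T2S model):

class T2SModel:
    def infer_panel_batch_infer(self, *args, **kwargs):
        ...  # batched (parallel) decoding path
    def infer_panel_0307(self, *args, **kwargs):
        ...  # item-by-item decoding path

def select_infer_panel(model, parallel_infer: bool) -> None:
    # Bind the chosen implementation to the single name callers use
    model.infer_panel = model.infer_panel_batch_infer if parallel_infer else model.infer_panel_0307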
@@ -942,4 +943,4 @@ def speed_change(input_audio:np.ndarray, speed:float, sr:int):
# Decode the pipe output into a NumPy array
processed_audio = np.frombuffer(out, np.int16)

-return processed_audio
+return processed_audio
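For context, the tail above belongs to speed_change, which pipes raw PCM through ffmpeg and reads the result back with np.frombuffer. A hedged sketch of that flow, assuming mono 16-bit audio and ffmpeg's atempo filter (the exact flags in the repository may differ):

import subprocess
import numpy as np

def speed_change(input_audio: np.ndarray, speed: float, sr: int) -> np.ndarray:
    # Feed raw s16le PCM to ffmpeg, apply the atempo (speed) filter, read raw PCM back
    cmd = [
        "ffmpeg", "-f", "s16le", "-ar", str(sr), "-ac", "1", "-i", "pipe:0",
        "-filter:a", f"atempo={speed}",
        "-f", "s16le", "-ar", str(sr), "-ac", "1", "pipe:1",
    ]
    proc = subprocess.Popen(cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    out, _ = proc.communicate(input=input_audio.astype(np.int16).tobytes())
    # Decode the pipe output into a NumPy array
    processed_audio = np.frombuffer(out, np.int16)
    return processed_audio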