[fast_inference] 回退策略,减少padding影响,开放选项,同步代码 (#986)
* Update README * Optimize-English-G2P * docs: change akward expression * docs: update Changelog_KO.md * Fix CN punc in EN,add 's match * Adjust normalize and g2p logic * Update zh_CN.json * Update README (#827) Update README.md Update some outdated file paths and commands * 修复英文多音字,调整字典热加载,新增姓名匹配 (#869) * Fix homograph dict * Add JSON in dict * Adjust hot dict to hot reload * Add English name dict * Adjust get name dict logic * Make API Great Again (#894) * Add zh/jp/en mix * Optimize code readability and formatted output. * Try OGG streaming * Add stream mode arg * Add media type arg * Add cut punc arg * Eliminate punc risk * Update README (#895) * Update README * Update README * update README * update README * fix typo s/Licence /License (#904) * fix reformat cmd (#917) Co-authored-by: starylan <starylan@outlook.com> * Update README.md * Normalize chinese arithmetic operations (#947) * 改变训练和推理时的mask策略,以修复当batch_size>1时,产生的复读现象 * 同步main分支代码,增加“保持随机”选项 * 在colab中运行colab_webui.ipynb发生的uvr5模型缺失问题 (#968) 在colab中使用git下载uvr5模型时报错: fatal: destination path 'uvr5_weights' already exists and is not an empty directory. 通过在下载前将原本从本仓库下载的uvr5_weights文件夹删除可以解决问题。 * [ASR] 修复FasterWhisper遍历输入路径失败 (#956) * remove glob * rename * reset mirror pos * 回退mask策略; 回退pad策略; 在T2SBlock中添加padding_mask,以减少pad的影响; 开放repetition_penalty参数,让用户自行调整重复惩罚的强度; 增加parallel_infer参数,用于开启或关闭并行推理,关闭时与0307版本保持一致; 在webui中增加“保持随机”选项; 同步main分支代码。 * 删除无用注释 --------- Co-authored-by: Lion <drain.daters.0p@icloud.com> Co-authored-by: RVC-Boss <129054828+RVC-Boss@users.noreply.github.com> Co-authored-by: KamioRinn <snowsdream@live.com> Co-authored-by: Pengoose <pengoose_dev@naver.com> Co-authored-by: Yuan-Man <68322456+Yuan-ManX@users.noreply.github.com> Co-authored-by: XXXXRT666 <157766680+XXXXRT666@users.noreply.github.com> Co-authored-by: KamioRinn <63162909+KamioRinn@users.noreply.github.com> Co-authored-by: Lion-Wu <130235128+Lion-Wu@users.noreply.github.com> Co-authored-by: digger yu <digger-yu@outlook.com> Co-authored-by: SapphireLab <36986837+SapphireLab@users.noreply.github.com> Co-authored-by: starylan <starylan@outlook.com> Co-authored-by: shadow01a <141255649+shadow01a@users.noreply.github.com>
This commit is contained in:
@@ -37,7 +37,6 @@ default:
|
||||
cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
|
||||
t2s_weights_path: GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt
|
||||
vits_weights_path: GPT_SoVITS/pretrained_models/s2G488k.pth
|
||||
flash_attn_enabled: true
|
||||
|
||||
custom:
|
||||
device: cuda
|
||||
@@ -46,7 +45,6 @@ custom:
|
||||
cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
|
||||
t2s_weights_path: GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt
|
||||
vits_weights_path: GPT_SoVITS/pretrained_models/s2G488k.pth
|
||||
flash_attn_enabled: true
|
||||
|
||||
|
||||
"""
|
||||
@@ -66,6 +64,9 @@ def set_seed(seed:int):
|
||||
# torch.backends.cudnn.deterministic = True
|
||||
# torch.backends.cudnn.benchmark = False
|
||||
# torch.backends.cudnn.enabled = True
|
||||
# 开启后会影响精度
|
||||
torch.backends.cuda.matmul.allow_tf32 = False
|
||||
torch.backends.cudnn.allow_tf32 = False
|
||||
except:
|
||||
pass
|
||||
return seed
|
||||
@@ -78,7 +79,6 @@ class TTS_Config:
|
||||
"vits_weights_path": "GPT_SoVITS/pretrained_models/s2G488k.pth",
|
||||
"cnhuhbert_base_path": "GPT_SoVITS/pretrained_models/chinese-hubert-base",
|
||||
"bert_base_path": "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",
|
||||
"flash_attn_enabled": True
|
||||
}
|
||||
configs:dict = None
|
||||
def __init__(self, configs: Union[dict, str]=None):
|
||||
@@ -108,7 +108,6 @@ class TTS_Config:
|
||||
|
||||
self.device = self.configs.get("device", torch.device("cpu"))
|
||||
self.is_half = self.configs.get("is_half", False)
|
||||
self.flash_attn_enabled = self.configs.get("flash_attn_enabled", True)
|
||||
self.t2s_weights_path = self.configs.get("t2s_weights_path", None)
|
||||
self.vits_weights_path = self.configs.get("vits_weights_path", None)
|
||||
self.bert_base_path = self.configs.get("bert_base_path", None)
|
||||
@@ -141,7 +140,7 @@ class TTS_Config:
|
||||
self.n_speakers:int = 300
|
||||
|
||||
self.languages:list = ["auto", "en", "zh", "ja", "all_zh", "all_ja"]
|
||||
# print(self)
|
||||
|
||||
|
||||
def _load_configs(self, configs_path: str)->dict:
|
||||
with open(configs_path, 'r') as f:
|
||||
@@ -169,7 +168,6 @@ class TTS_Config:
|
||||
"vits_weights_path" : self.vits_weights_path,
|
||||
"bert_base_path" : self.bert_base_path,
|
||||
"cnhuhbert_base_path": self.cnhuhbert_base_path,
|
||||
"flash_attn_enabled" : self.flash_attn_enabled
|
||||
}
|
||||
return self.config
|
||||
|
||||
@@ -289,8 +287,7 @@ class TTS:
|
||||
dict_s1 = torch.load(weights_path, map_location=self.configs.device)
|
||||
config = dict_s1["config"]
|
||||
self.configs.max_sec = config["data"]["max_sec"]
|
||||
t2s_model = Text2SemanticLightningModule(config, "****", is_train=False,
|
||||
flash_attn_enabled=self.configs.flash_attn_enabled)
|
||||
t2s_model = Text2SemanticLightningModule(config, "****", is_train=False)
|
||||
t2s_model.load_state_dict(dict_s1["weight"])
|
||||
t2s_model = t2s_model.to(self.configs.device)
|
||||
t2s_model = t2s_model.eval()
|
||||
@@ -435,8 +432,6 @@ class TTS:
|
||||
device:torch.device=torch.device("cpu"),
|
||||
precision:torch.dtype=torch.float32,
|
||||
):
|
||||
# 但是这里不能套,反而会负优化
|
||||
# with torch.no_grad():
|
||||
_data:list = []
|
||||
index_and_len_list = []
|
||||
for idx, item in enumerate(data):
|
||||
@@ -484,8 +479,6 @@ class TTS:
|
||||
norm_text_batch = []
|
||||
bert_max_len = 0
|
||||
phones_max_len = 0
|
||||
# 但是这里也不能套,反而会负优化
|
||||
# with torch.no_grad():
|
||||
for item in item_list:
|
||||
if prompt_data is not None:
|
||||
all_bert_features = torch.cat([prompt_data["bert_features"], item["bert_features"]], 1)\
|
||||
@@ -518,11 +511,11 @@ class TTS:
|
||||
max_len = max(bert_max_len, phones_max_len)
|
||||
# phones_batch = self.batch_sequences(phones_list, axis=0, pad_value=0, max_length=max_len)
|
||||
#### 直接对phones和bert_features进行pad。(padding策略会影响T2S模型生成的结果,但不直接影响复读概率。影响复读概率的主要因素是mask的策略)
|
||||
all_phones_batch = self.batch_sequences(all_phones_list, axis=0, pad_value=0, max_length=max_len)
|
||||
all_bert_features_batch = all_bert_features_list
|
||||
all_bert_features_batch = torch.zeros(len(item_list), 1024, max_len, dtype=precision, device=device)
|
||||
for idx, item in enumerate(all_bert_features_list):
|
||||
all_bert_features_batch[idx, :, : item.shape[-1]] = item
|
||||
# all_phones_batch = self.batch_sequences(all_phones_list, axis=0, pad_value=0, max_length=max_len)
|
||||
# all_bert_features_batch = all_bert_features_list
|
||||
# all_bert_features_batch = torch.zeros((len(all_bert_features_list), 1024, max_len), dtype=precision, device=device)
|
||||
# for idx, item in enumerate(all_bert_features_list):
|
||||
# all_bert_features_batch[idx, :, : item.shape[-1]] = item
|
||||
|
||||
# #### 先对phones进行embedding、对bert_features进行project,再pad到相同长度,(padding策略会影响T2S模型生成的结果,但不直接影响复读概率。影响复读概率的主要因素是mask的策略)
|
||||
# all_phones_list = [self.t2s_model.model.ar_text_embedding(item.to(self.t2s_model.device)) for item in all_phones_list]
|
||||
@@ -539,7 +532,8 @@ class TTS:
|
||||
"all_phones": all_phones_batch,
|
||||
"all_phones_len": torch.LongTensor(all_phones_len_list).to(device),
|
||||
"all_bert_features": all_bert_features_batch,
|
||||
"norm_text": norm_text_batch
|
||||
"norm_text": norm_text_batch,
|
||||
"max_len": max_len,
|
||||
}
|
||||
_data.append(batch)
|
||||
|
||||
@@ -569,7 +563,6 @@ class TTS:
|
||||
'''
|
||||
self.stop_flag = True
|
||||
|
||||
# 使用装饰器
|
||||
@torch.no_grad()
|
||||
def run(self, inputs:dict):
|
||||
"""
|
||||
@@ -594,6 +587,8 @@ class TTS:
|
||||
"speed_factor":1.0, # float. control the speed of the synthesized audio.
|
||||
"fragment_interval":0.3, # float. to control the interval of the audio fragment.
|
||||
"seed": -1, # int. random seed for reproducibility.
|
||||
"parallel_infer": True, # bool. whether to use parallel inference.
|
||||
"repetition_penalty": 1.35 # float. repetition penalty for T2S model.
|
||||
}
|
||||
returns:
|
||||
tuple[int, np.ndarray]: sampling rate and audio data.
|
||||
@@ -618,9 +613,17 @@ class TTS:
|
||||
seed = inputs.get("seed", -1)
|
||||
seed = -1 if seed in ["", None] else seed
|
||||
actual_seed = set_seed(seed)
|
||||
parallel_infer = inputs.get("parallel_infer", True)
|
||||
repetition_penalty = inputs.get("repetition_penalty", 1.35)
|
||||
|
||||
if parallel_infer:
|
||||
print(i18n("并行推理模式已开启"))
|
||||
self.t2s_model.model.infer_panel = self.t2s_model.model.infer_panel_batch_infer_with_flash_attn
|
||||
else:
|
||||
print(i18n("并行推理模式已关闭"))
|
||||
self.t2s_model.model.infer_panel = self.t2s_model.model.infer_panel_0307
|
||||
|
||||
if return_fragment:
|
||||
# split_bucket = False
|
||||
print(i18n("分段返回模式已开启"))
|
||||
if split_bucket:
|
||||
split_bucket = False
|
||||
@@ -740,6 +743,7 @@ class TTS:
|
||||
all_phoneme_lens:torch.LongTensor = item["all_phones_len"]
|
||||
all_bert_features:torch.LongTensor = item["all_bert_features"]
|
||||
norm_text:str = item["norm_text"]
|
||||
max_len = item["max_len"]
|
||||
|
||||
print(i18n("前端处理后的文本(每句):"), norm_text)
|
||||
if no_prompt_text :
|
||||
@@ -758,6 +762,8 @@ class TTS:
|
||||
top_p=top_p,
|
||||
temperature=temperature,
|
||||
early_stop_num=self.configs.hz * self.configs.max_sec,
|
||||
max_len=max_len,
|
||||
repetition_penalty=repetition_penalty,
|
||||
)
|
||||
t4 = ttime()
|
||||
t_34 += t4 - t3
|
||||
|
||||
Reference in New Issue
Block a user