Version Check (#1390)

* version check

* fix webui and symbols

* fix v1 language map
This commit is contained in:
KamioRinn
2024-08-05 17:24:42 +08:00
committed by GitHub
parent 0c25e57959
commit 4e34814c70
8 changed files with 157 additions and 78 deletions

100
api.py
View File

@@ -11,7 +11,7 @@
调用请求缺少参考音频时使用
`-dr` - `默认参考音频路径`
`-dt` - `默认参考音频文本`
`-dl` - `默认参考音频语种, "中文","英文","日文","zh","en","ja"`
`-dl` - `默认参考音频语种, "中文","英文","日文","韩文","粤语,"zh","en","ja","ko","yue"`
`-d` - `推理设备, "cuda","cpu"`
`-a` - `绑定地址, 默认"127.0.0.1"`
@@ -201,6 +201,11 @@ def change_sovits_weights(sovits_path):
hps = dict_s2["config"]
hps = DictToAttrRecursive(hps)
hps.model.semantic_frame_rate = "25hz"
if dict_s2['weight']['enc_p.text_embedding.weight'].shape[0] == 322:
hps.model.version = "v1"
else:
hps.model.version = "v2"
print("sovits版本:",hps.model.version)
model_params_dict = vars(hps.model)
vq_model = SynthesizerTrn(
hps.data.filter_length // 2 + 1,
@@ -251,9 +256,9 @@ def get_bert_feature(text, word2ph):
return phone_level_feature.T
def clean_text_inf(text, language):
phones, word2ph, norm_text = clean_text(text, language)
phones = cleaned_text_to_sequence(phones)
def clean_text_inf(text, language, version):
phones, word2ph, norm_text = clean_text(text, language, version)
phones = cleaned_text_to_sequence(phones, version)
return phones, word2ph, norm_text
@@ -269,54 +274,64 @@ def get_bert_inf(phones, word2ph, norm_text, language):
return bert
def get_phones_and_bert(text,language):
if language in {"en","all_zh","all_ja"}:
from text import chinese
def get_phones_and_bert(text,language,version):
if language in {"en", "all_zh", "all_ja", "all_ko", "all_yue"}:
language = language.replace("all_","")
if language == "en":
LangSegment.setfilters(["en"])
formattext = " ".join(tmp["text"] for tmp in LangSegment.getTexts(text))
else:
# 因无法区别中日文汉字,以用户输入为准
# 因无法区别中日文汉字,以用户输入为准
formattext = text
while " " in formattext:
formattext = formattext.replace(" ", " ")
phones, word2ph, norm_text = clean_text_inf(formattext, language)
if language == "zh":
bert = get_bert_feature(norm_text, word2ph).to(device)
if re.search(r'[A-Za-z]', formattext):
formattext = re.sub(r'[a-z]', lambda x: x.group(0).upper(), formattext)
formattext = chinese.text_normalize(formattext)
return get_phones_and_bert(formattext,"zh",version)
else:
phones, word2ph, norm_text = clean_text_inf(formattext, language, version)
bert = get_bert_feature(norm_text, word2ph).to(device)
elif language == "yue" and re.search(r'[A-Za-z]', formattext):
formattext = re.sub(r'[a-z]', lambda x: x.group(0).upper(), formattext)
formattext = chinese.text_normalize(formattext)
return get_phones_and_bert(formattext,"yue",version)
else:
phones, word2ph, norm_text = clean_text_inf(formattext, language, version)
bert = torch.zeros(
(1024, len(phones)),
dtype=torch.float16 if is_half == True else torch.float32,
).to(device)
elif language in {"zh", "ja","auto"}:
elif language in {"zh", "ja", "ko", "yue", "auto", "auto_yue"}:
textlist=[]
langlist=[]
LangSegment.setfilters(["zh","ja","en","ko"])
if language == "auto":
for tmp in LangSegment.getTexts(text):
if tmp["lang"] == "ko":
langlist.append("zh")
textlist.append(tmp["text"])
else:
langlist.append(tmp["lang"])
textlist.append(tmp["text"])
langlist.append(tmp["lang"])
textlist.append(tmp["text"])
elif language == "auto_yue":
for tmp in LangSegment.getTexts(text):
if tmp["lang"] == "zh":
tmp["lang"] = "yue"
langlist.append(tmp["lang"])
textlist.append(tmp["text"])
else:
for tmp in LangSegment.getTexts(text):
if tmp["lang"] == "en":
langlist.append(tmp["lang"])
else:
# 因无法区别中日文汉字,以用户输入为准
# 因无法区别中日文汉字,以用户输入为准
langlist.append(language)
textlist.append(tmp["text"])
# logger.info(textlist)
# logger.info(langlist)
phones_list = []
bert_list = []
norm_text_list = []
for i in range(len(textlist)):
lang = langlist[i]
phones, word2ph, norm_text = clean_text_inf(textlist[i], lang)
phones, word2ph, norm_text = clean_text_inf(textlist[i], lang, version)
bert = get_bert_inf(phones, word2ph, norm_text, lang)
phones_list.append(phones)
norm_text_list.append(norm_text)
@@ -328,14 +343,32 @@ def get_phones_and_bert(text,language):
return phones,bert.to(torch.float16 if is_half == True else torch.float32),norm_text
class DictToAttrRecursive:
class DictToAttrRecursive(dict):
def __init__(self, input_dict):
super().__init__(input_dict)
for key, value in input_dict.items():
if isinstance(value, dict):
# 如果值是字典,递归调用构造函数
setattr(self, key, DictToAttrRecursive(value))
else:
setattr(self, key, value)
value = DictToAttrRecursive(value)
self[key] = value
setattr(self, key, value)
def __getattr__(self, item):
try:
return self[item]
except KeyError:
raise AttributeError(f"Attribute {item} not found")
def __setattr__(self, key, value):
if isinstance(value, dict):
value = DictToAttrRecursive(value)
super(DictToAttrRecursive, self).__setitem__(key, value)
super().__setattr__(key, value)
def __delattr__(self, item):
try:
del self[item]
except KeyError:
raise AttributeError(f"Attribute {item} not found")
def get_spepc(hps, filename):
@@ -488,9 +521,10 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
codes = vq_model.extract_latent(ssl_content)
prompt_semantic = codes[0, 0]
t1 = ttime()
version = vq_model.version
prompt_language = dict_language[prompt_language.lower()]
text_language = dict_language[text_language.lower()]
phones1, bert1, norm_text1 = get_phones_and_bert(prompt_text, prompt_language)
phones1, bert1, norm_text1 = get_phones_and_bert(prompt_text, prompt_language, version)
texts = text.split("\n")
audio_bytes = BytesIO()
@@ -500,7 +534,7 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
continue
audio_opt = []
phones2, bert2, norm_text2 = get_phones_and_bert(text, text_language)
phones2, bert2, norm_text2 = get_phones_and_bert(text, text_language, version)
bert = torch.cat([bert1, bert2], 1)
all_phoneme_ids = torch.LongTensor(phones1 + phones2).to(device).unsqueeze(0)
@@ -606,17 +640,27 @@ def handle(refer_wav_path, prompt_text, prompt_language, text, text_language, cu
# --------------------------------
dict_language = {
"中文": "all_zh",
"粤语": "all_yue",
"英文": "en",
"日文": "all_ja",
"韩文": "all_ko",
"中英混合": "zh",
"粤英混合": "yue",
"日英混合": "ja",
"韩英混合": "ko",
"多语种混合": "auto", #多语种启动切分识别语种
"多语种混合(粤语)": "auto_yue",
"all_zh": "all_zh",
"all_yue": "all_yue",
"en": "en",
"all_ja": "all_ja",
"all_ko": "all_ko",
"zh": "zh",
"yue": "yue",
"ja": "ja",
"ko": "ko",
"auto": "auto",
"auto_yue": "auto_yue",
}
# logger