Version Check (#1390)

* version check

* fix webui and symbols

* fix v1 language map
This commit is contained in:
KamioRinn
2024-08-05 17:24:42 +08:00
committed by GitHub
parent 0c25e57959
commit 4e34814c70
8 changed files with 157 additions and 78 deletions

View File

@@ -152,6 +152,11 @@ def change_sovits_weights(sovits_path):
hps = dict_s2["config"]
hps = DictToAttrRecursive(hps)
hps.model.semantic_frame_rate = "25hz"
if dict_s2['weight']['enc_p.text_embedding.weight'].shape[0] == 322:
hps.model.version = "v1"
else:
hps.model.version = "v2"
# print("sovits版本:",hps.model.version)
vq_model = SynthesizerTrn(
hps.data.filter_length // 2 + 1,
hps.train.segment_size // hps.data.hop_length,
@@ -231,9 +236,9 @@ dict_language = {
}
def clean_text_inf(text, language):
phones, word2ph, norm_text = clean_text(text, language)
phones = cleaned_text_to_sequence(phones)
def clean_text_inf(text, language, version):
    """Run the text front-end and convert the phones to a sequence.

    Args:
        text: Input text handed unchanged to ``clean_text``.
        language: Language code handed unchanged to ``clean_text``.
        version: Model version tag forwarded to both ``clean_text`` and
            ``cleaned_text_to_sequence``.
            NOTE(review): presumably "v1" or "v2", as assigned to
            ``hps.model.version`` when weights are loaded -- confirm.

    Returns:
        A 3-tuple ``(phones, word2ph, norm_text)`` where ``phones`` is the
        result of ``cleaned_text_to_sequence`` and the other two items are
        returned exactly as produced by ``clean_text``.
    """
    raw_phones, word2ph, norm_text = clean_text(text, language, version)
    # The same version is passed to both helpers.
    # NOTE(review): presumably so they agree on the symbol table -- confirm.
    phone_sequence = cleaned_text_to_sequence(raw_phones, version)
    return phone_sequence, word2ph, norm_text
dtype=torch.float16 if is_half == True else torch.float32
@@ -259,7 +264,7 @@ def get_first(text):
return text
from text import chinese
def get_phones_and_bert(text,language):
def get_phones_and_bert(text,language,version):
if language in {"en", "all_zh", "all_ja", "all_ko", "all_yue"}:
language = language.replace("all_","")
if language == "en":
@@ -274,16 +279,16 @@ def get_phones_and_bert(text,language):
if re.search(r'[A-Za-z]', formattext):
formattext = re.sub(r'[a-z]', lambda x: x.group(0).upper(), formattext)
formattext = chinese.text_normalize(formattext)
return get_phones_and_bert(formattext,"zh")
return get_phones_and_bert(formattext,"zh",version)
else:
phones, word2ph, norm_text = clean_text_inf(formattext, language)
phones, word2ph, norm_text = clean_text_inf(formattext, language, version)
bert = get_bert_feature(norm_text, word2ph).to(device)
elif language == "yue" and re.search(r'[A-Za-z]', formattext):
formattext = re.sub(r'[a-z]', lambda x: x.group(0).upper(), formattext)
formattext = chinese.text_normalize(formattext)
return get_phones_and_bert(formattext,"yue")
return get_phones_and_bert(formattext,"yue",version)
else:
phones, word2ph, norm_text = clean_text_inf(formattext, language)
phones, word2ph, norm_text = clean_text_inf(formattext, language, version)
bert = torch.zeros(
(1024, len(phones)),
dtype=torch.float16 if is_half == True else torch.float32,
@@ -317,7 +322,7 @@ def get_phones_and_bert(text,language):
norm_text_list = []
for i in range(len(textlist)):
lang = langlist[i]
phones, word2ph, norm_text = clean_text_inf(textlist[i], lang)
phones, word2ph, norm_text = clean_text_inf(textlist[i], lang, version)
bert = get_bert_inf(phones, word2ph, norm_text, lang)
phones_list.append(phones)
norm_text_list.append(norm_text)
@@ -357,6 +362,9 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
t0 = ttime()
prompt_language = dict_language[prompt_language]
text_language = dict_language[text_language]
version = vq_model.version
if not ref_free:
prompt_text = prompt_text.strip("\n")
if (prompt_text[-1] not in splits): prompt_text += "" if prompt_language != "en" else "."
@@ -413,7 +421,7 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
texts = merge_short_text_in_array(texts, 5)
audio_opt = []
if not ref_free:
phones1,bert1,norm_text1=get_phones_and_bert(prompt_text, prompt_language)
phones1,bert1,norm_text1=get_phones_and_bert(prompt_text, prompt_language, version)
for i_text,text in enumerate(texts):
# 解决输入目标文本的空行导致报错的问题
@@ -421,7 +429,7 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
continue
if (text[-1] not in splits): text += "" if text_language != "en" else "."
print(i18n("实际输入的目标文本(每句):"), text)
phones2,bert2,norm_text2=get_phones_and_bert(text, text_language)
phones2,bert2,norm_text2=get_phones_and_bert(text, text_language, version)
print(i18n("前端处理后的文本(每句):"), norm_text2)
if not ref_free:
bert = torch.cat([bert1, bert2], 1)