Version Check (#1390)
* version check * fix webui and symbols * fix v1 language map
This commit is contained in:
@@ -152,6 +152,11 @@ def change_sovits_weights(sovits_path):
|
||||
hps = dict_s2["config"]
|
||||
hps = DictToAttrRecursive(hps)
|
||||
hps.model.semantic_frame_rate = "25hz"
|
||||
if dict_s2['weight']['enc_p.text_embedding.weight'].shape[0] == 322:
|
||||
hps.model.version = "v1"
|
||||
else:
|
||||
hps.model.version = "v2"
|
||||
# print("sovits版本:",hps.model.version)
|
||||
vq_model = SynthesizerTrn(
|
||||
hps.data.filter_length // 2 + 1,
|
||||
hps.train.segment_size // hps.data.hop_length,
|
||||
@@ -231,9 +236,9 @@ dict_language = {
|
||||
}
|
||||
|
||||
|
||||
def clean_text_inf(text, language):
|
||||
phones, word2ph, norm_text = clean_text(text, language)
|
||||
phones = cleaned_text_to_sequence(phones)
|
||||
def clean_text_inf(text, language, version):
|
||||
phones, word2ph, norm_text = clean_text(text, language, version)
|
||||
phones = cleaned_text_to_sequence(phones, version)
|
||||
return phones, word2ph, norm_text
|
||||
|
||||
dtype=torch.float16 if is_half == True else torch.float32
|
||||
@@ -259,7 +264,7 @@ def get_first(text):
|
||||
return text
|
||||
|
||||
from text import chinese
|
||||
def get_phones_and_bert(text,language):
|
||||
def get_phones_and_bert(text,language,version):
|
||||
if language in {"en", "all_zh", "all_ja", "all_ko", "all_yue"}:
|
||||
language = language.replace("all_","")
|
||||
if language == "en":
|
||||
@@ -274,16 +279,16 @@ def get_phones_and_bert(text,language):
|
||||
if re.search(r'[A-Za-z]', formattext):
|
||||
formattext = re.sub(r'[a-z]', lambda x: x.group(0).upper(), formattext)
|
||||
formattext = chinese.text_normalize(formattext)
|
||||
return get_phones_and_bert(formattext,"zh")
|
||||
return get_phones_and_bert(formattext,"zh",version)
|
||||
else:
|
||||
phones, word2ph, norm_text = clean_text_inf(formattext, language)
|
||||
phones, word2ph, norm_text = clean_text_inf(formattext, language, version)
|
||||
bert = get_bert_feature(norm_text, word2ph).to(device)
|
||||
elif language == "yue" and re.search(r'[A-Za-z]', formattext):
|
||||
formattext = re.sub(r'[a-z]', lambda x: x.group(0).upper(), formattext)
|
||||
formattext = chinese.text_normalize(formattext)
|
||||
return get_phones_and_bert(formattext,"yue")
|
||||
return get_phones_and_bert(formattext,"yue",version)
|
||||
else:
|
||||
phones, word2ph, norm_text = clean_text_inf(formattext, language)
|
||||
phones, word2ph, norm_text = clean_text_inf(formattext, language, version)
|
||||
bert = torch.zeros(
|
||||
(1024, len(phones)),
|
||||
dtype=torch.float16 if is_half == True else torch.float32,
|
||||
@@ -317,7 +322,7 @@ def get_phones_and_bert(text,language):
|
||||
norm_text_list = []
|
||||
for i in range(len(textlist)):
|
||||
lang = langlist[i]
|
||||
phones, word2ph, norm_text = clean_text_inf(textlist[i], lang)
|
||||
phones, word2ph, norm_text = clean_text_inf(textlist[i], lang, version)
|
||||
bert = get_bert_inf(phones, word2ph, norm_text, lang)
|
||||
phones_list.append(phones)
|
||||
norm_text_list.append(norm_text)
|
||||
@@ -357,6 +362,9 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
|
||||
t0 = ttime()
|
||||
prompt_language = dict_language[prompt_language]
|
||||
text_language = dict_language[text_language]
|
||||
|
||||
version = vq_model.version
|
||||
|
||||
if not ref_free:
|
||||
prompt_text = prompt_text.strip("\n")
|
||||
if (prompt_text[-1] not in splits): prompt_text += "。" if prompt_language != "en" else "."
|
||||
@@ -413,7 +421,7 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
|
||||
texts = merge_short_text_in_array(texts, 5)
|
||||
audio_opt = []
|
||||
if not ref_free:
|
||||
phones1,bert1,norm_text1=get_phones_and_bert(prompt_text, prompt_language)
|
||||
phones1,bert1,norm_text1=get_phones_and_bert(prompt_text, prompt_language, version)
|
||||
|
||||
for i_text,text in enumerate(texts):
|
||||
# 解决输入目标文本的空行导致报错的问题
|
||||
@@ -421,7 +429,7 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
|
||||
continue
|
||||
if (text[-1] not in splits): text += "。" if text_language != "en" else "."
|
||||
print(i18n("实际输入的目标文本(每句):"), text)
|
||||
phones2,bert2,norm_text2=get_phones_and_bert(text, text_language)
|
||||
phones2,bert2,norm_text2=get_phones_and_bert(text, text_language, version)
|
||||
print(i18n("前端处理后的文本(每句):"), norm_text2)
|
||||
if not ref_free:
|
||||
bert = torch.cat([bert1, bert2], 1)
|
||||
|
||||
Reference in New Issue
Block a user