Version Check (#1390)
* version check * fix webui and symbols * fix v1 language map
This commit is contained in:
@@ -152,6 +152,11 @@ def change_sovits_weights(sovits_path):
|
||||
hps = dict_s2["config"]
|
||||
hps = DictToAttrRecursive(hps)
|
||||
hps.model.semantic_frame_rate = "25hz"
|
||||
if dict_s2['weight']['enc_p.text_embedding.weight'].shape[0] == 322:
|
||||
hps.model.version = "v1"
|
||||
else:
|
||||
hps.model.version = "v2"
|
||||
# print("sovits版本:",hps.model.version)
|
||||
vq_model = SynthesizerTrn(
|
||||
hps.data.filter_length // 2 + 1,
|
||||
hps.train.segment_size // hps.data.hop_length,
|
||||
@@ -231,9 +236,9 @@ dict_language = {
|
||||
}
|
||||
|
||||
|
||||
def clean_text_inf(text, language):
|
||||
phones, word2ph, norm_text = clean_text(text, language)
|
||||
phones = cleaned_text_to_sequence(phones)
|
||||
def clean_text_inf(text, language, version):
|
||||
phones, word2ph, norm_text = clean_text(text, language, version)
|
||||
phones = cleaned_text_to_sequence(phones, version)
|
||||
return phones, word2ph, norm_text
|
||||
|
||||
dtype=torch.float16 if is_half == True else torch.float32
|
||||
@@ -259,7 +264,7 @@ def get_first(text):
|
||||
return text
|
||||
|
||||
from text import chinese
|
||||
def get_phones_and_bert(text,language):
|
||||
def get_phones_and_bert(text,language,version):
|
||||
if language in {"en", "all_zh", "all_ja", "all_ko", "all_yue"}:
|
||||
language = language.replace("all_","")
|
||||
if language == "en":
|
||||
@@ -274,16 +279,16 @@ def get_phones_and_bert(text,language):
|
||||
if re.search(r'[A-Za-z]', formattext):
|
||||
formattext = re.sub(r'[a-z]', lambda x: x.group(0).upper(), formattext)
|
||||
formattext = chinese.text_normalize(formattext)
|
||||
return get_phones_and_bert(formattext,"zh")
|
||||
return get_phones_and_bert(formattext,"zh",version)
|
||||
else:
|
||||
phones, word2ph, norm_text = clean_text_inf(formattext, language)
|
||||
phones, word2ph, norm_text = clean_text_inf(formattext, language, version)
|
||||
bert = get_bert_feature(norm_text, word2ph).to(device)
|
||||
elif language == "yue" and re.search(r'[A-Za-z]', formattext):
|
||||
formattext = re.sub(r'[a-z]', lambda x: x.group(0).upper(), formattext)
|
||||
formattext = chinese.text_normalize(formattext)
|
||||
return get_phones_and_bert(formattext,"yue")
|
||||
return get_phones_and_bert(formattext,"yue",version)
|
||||
else:
|
||||
phones, word2ph, norm_text = clean_text_inf(formattext, language)
|
||||
phones, word2ph, norm_text = clean_text_inf(formattext, language, version)
|
||||
bert = torch.zeros(
|
||||
(1024, len(phones)),
|
||||
dtype=torch.float16 if is_half == True else torch.float32,
|
||||
@@ -317,7 +322,7 @@ def get_phones_and_bert(text,language):
|
||||
norm_text_list = []
|
||||
for i in range(len(textlist)):
|
||||
lang = langlist[i]
|
||||
phones, word2ph, norm_text = clean_text_inf(textlist[i], lang)
|
||||
phones, word2ph, norm_text = clean_text_inf(textlist[i], lang, version)
|
||||
bert = get_bert_inf(phones, word2ph, norm_text, lang)
|
||||
phones_list.append(phones)
|
||||
norm_text_list.append(norm_text)
|
||||
@@ -357,6 +362,9 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
|
||||
t0 = ttime()
|
||||
prompt_language = dict_language[prompt_language]
|
||||
text_language = dict_language[text_language]
|
||||
|
||||
version = vq_model.version
|
||||
|
||||
if not ref_free:
|
||||
prompt_text = prompt_text.strip("\n")
|
||||
if (prompt_text[-1] not in splits): prompt_text += "。" if prompt_language != "en" else "."
|
||||
@@ -413,7 +421,7 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
|
||||
texts = merge_short_text_in_array(texts, 5)
|
||||
audio_opt = []
|
||||
if not ref_free:
|
||||
phones1,bert1,norm_text1=get_phones_and_bert(prompt_text, prompt_language)
|
||||
phones1,bert1,norm_text1=get_phones_and_bert(prompt_text, prompt_language, version)
|
||||
|
||||
for i_text,text in enumerate(texts):
|
||||
# 解决输入目标文本的空行导致报错的问题
|
||||
@@ -421,7 +429,7 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
|
||||
continue
|
||||
if (text[-1] not in splits): text += "。" if text_language != "en" else "."
|
||||
print(i18n("实际输入的目标文本(每句):"), text)
|
||||
phones2,bert2,norm_text2=get_phones_and_bert(text, text_language)
|
||||
phones2,bert2,norm_text2=get_phones_and_bert(text, text_language, version)
|
||||
print(i18n("前端处理后的文本(每句):"), norm_text2)
|
||||
if not ref_free:
|
||||
bert = torch.cat([bert1, bert2], 1)
|
||||
|
||||
@@ -15,7 +15,9 @@ from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
|
||||
from module.commons import init_weights, get_padding
|
||||
from module.mrte_model import MRTE
|
||||
from module.quantize import ResidualVectorQuantizer
|
||||
from text import symbols
|
||||
# from text import symbols
|
||||
from text import symbols as symbols_v1
|
||||
from text import symbols2 as symbols_v2
|
||||
from torch.cuda.amp import autocast
|
||||
import contextlib
|
||||
|
||||
@@ -185,6 +187,7 @@ class TextEncoder(nn.Module):
|
||||
kernel_size,
|
||||
p_dropout,
|
||||
latent_channels=192,
|
||||
version = "v2",
|
||||
):
|
||||
super().__init__()
|
||||
self.out_channels = out_channels
|
||||
@@ -195,6 +198,7 @@ class TextEncoder(nn.Module):
|
||||
self.kernel_size = kernel_size
|
||||
self.p_dropout = p_dropout
|
||||
self.latent_channels = latent_channels
|
||||
self.version = version
|
||||
|
||||
self.ssl_proj = nn.Conv1d(768, hidden_channels, 1)
|
||||
|
||||
@@ -210,6 +214,11 @@ class TextEncoder(nn.Module):
|
||||
self.encoder_text = attentions.Encoder(
|
||||
hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
|
||||
)
|
||||
|
||||
if self.version == "v1":
|
||||
symbols = symbols_v1.symbols
|
||||
else:
|
||||
symbols = symbols_v2.symbols
|
||||
self.text_embedding = nn.Embedding(len(symbols), hidden_channels)
|
||||
|
||||
self.mrte = MRTE()
|
||||
@@ -827,6 +836,7 @@ class SynthesizerTrn(nn.Module):
|
||||
use_sdp=True,
|
||||
semantic_frame_rate=None,
|
||||
freeze_quantizer=None,
|
||||
version = "v2",
|
||||
**kwargs
|
||||
):
|
||||
super().__init__()
|
||||
@@ -847,6 +857,7 @@ class SynthesizerTrn(nn.Module):
|
||||
self.segment_size = segment_size
|
||||
self.n_speakers = n_speakers
|
||||
self.gin_channels = gin_channels
|
||||
self.version = version
|
||||
|
||||
self.use_sdp = use_sdp
|
||||
self.enc_p = TextEncoder(
|
||||
@@ -857,6 +868,7 @@ class SynthesizerTrn(nn.Module):
|
||||
n_layers,
|
||||
kernel_size,
|
||||
p_dropout,
|
||||
version = version,
|
||||
)
|
||||
self.dec = Generator(
|
||||
inter_channels,
|
||||
@@ -881,7 +893,7 @@ class SynthesizerTrn(nn.Module):
|
||||
inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels
|
||||
)
|
||||
|
||||
self.version=os.environ.get("version","v1")
|
||||
# self.version=os.environ.get("version","v1")
|
||||
if(self.version=="v1"):
|
||||
self.ref_enc = modules.MelStyleEncoder(spec_channels, style_vector_dim=gin_channels)
|
||||
else:
|
||||
|
||||
@@ -1,18 +1,26 @@
|
||||
import os
|
||||
if os.environ.get("version","v1")=="v1":
|
||||
from text.symbols import symbols
|
||||
else:
|
||||
from text.symbols2 import symbols
|
||||
# if os.environ.get("version","v1")=="v1":
|
||||
# from text.symbols import symbols
|
||||
# else:
|
||||
# from text.symbols2 import symbols
|
||||
|
||||
_symbol_to_id = {s: i for i, s in enumerate(symbols)}
|
||||
from text import symbols as symbols_v1
|
||||
from text import symbols2 as symbols_v2
|
||||
|
||||
def cleaned_text_to_sequence(cleaned_text):
|
||||
_symbol_to_id_v1 = {s: i for i, s in enumerate(symbols_v1.symbols)}
|
||||
_symbol_to_id_v2 = {s: i for i, s in enumerate(symbols_v2.symbols)}
|
||||
|
||||
def cleaned_text_to_sequence(cleaned_text, version):
|
||||
'''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
|
||||
Args:
|
||||
text: string to convert to a sequence
|
||||
Returns:
|
||||
List of integers corresponding to the symbols in the text
|
||||
'''
|
||||
phones = [_symbol_to_id[symbol] for symbol in cleaned_text]
|
||||
if version == "v1":
|
||||
phones = [_symbol_to_id_v1[symbol] for symbol in cleaned_text]
|
||||
else:
|
||||
phones = [_symbol_to_id_v2[symbol] for symbol in cleaned_text]
|
||||
|
||||
return phones
|
||||
|
||||
|
||||
@@ -1,13 +1,17 @@
|
||||
from text import japanese, cleaned_text_to_sequence, english,korean,cantonese
|
||||
import os
|
||||
if os.environ.get("version","v1")=="v1":
|
||||
from text import chinese
|
||||
from text.symbols import symbols
|
||||
else:
|
||||
from text import chinese2 as chinese
|
||||
from text.symbols2 import symbols
|
||||
# if os.environ.get("version","v1")=="v1":
|
||||
# from text import chinese
|
||||
# from text.symbols import symbols
|
||||
# else:
|
||||
# from text import chinese2 as chinese
|
||||
# from text.symbols2 import symbols
|
||||
|
||||
from text import symbols as symbols_v1
|
||||
from text import symbols2 as symbols_v2
|
||||
from text import chinese as chinese_v1
|
||||
from text import chinese2 as chinese_v2
|
||||
|
||||
language_module_map = {"zh": chinese, "ja": japanese, "en": english, "ko": korean,"yue":cantonese}
|
||||
special = [
|
||||
# ("%", "zh", "SP"),
|
||||
("¥", "zh", "SP2"),
|
||||
@@ -16,13 +20,20 @@ special = [
|
||||
]
|
||||
|
||||
|
||||
def clean_text(text, language):
|
||||
def clean_text(text, language, version):
|
||||
if version == "v1":
|
||||
symbols = symbols_v1.symbols
|
||||
language_module_map = {"zh": chinese_v1, "ja": japanese, "en": english}
|
||||
else:
|
||||
symbols = symbols_v2.symbols
|
||||
language_module_map = {"zh": chinese_v2, "ja": japanese, "en": english, "ko": korean,"yue":cantonese}
|
||||
|
||||
if(language not in language_module_map):
|
||||
language="en"
|
||||
text=" "
|
||||
for special_s, special_l, target_symbol in special:
|
||||
if special_s in text and language == special_l:
|
||||
return clean_special(text, language, special_s, target_symbol)
|
||||
return clean_special(text, language, special_s, target_symbol, version)
|
||||
language_module = language_module_map[language]
|
||||
if hasattr(language_module,"text_normalize"):
|
||||
norm_text = language_module.text_normalize(text)
|
||||
@@ -42,11 +53,18 @@ def clean_text(text, language):
|
||||
word2ph = None
|
||||
|
||||
for ph in phones:
|
||||
assert ph in symbols
|
||||
phones = ['UNK' if ph not in symbols else ph for ph in phones]
|
||||
return phones, word2ph, norm_text
|
||||
|
||||
|
||||
def clean_special(text, language, special_s, target_symbol):
|
||||
def clean_special(text, language, special_s, target_symbol, version):
|
||||
if version == "v1":
|
||||
symbols = symbols_v1.symbols
|
||||
language_module_map = {"zh": chinese_v1, "ja": japanese, "en": english}
|
||||
else:
|
||||
symbols = symbols_v2.symbols
|
||||
language_module_map = {"zh": chinese_v2, "ja": japanese, "en": english, "ko": korean,"yue":cantonese}
|
||||
|
||||
"""
|
||||
特殊静音段sp符号处理
|
||||
"""
|
||||
|
||||
@@ -6,10 +6,7 @@ from g2p_en import G2p
|
||||
|
||||
from text.symbols import punctuation
|
||||
|
||||
if os.environ.get("version","v1")=="v1":
|
||||
from text.symbols import symbols
|
||||
else:
|
||||
from text.symbols2 import symbols
|
||||
from text.symbols2 import symbols
|
||||
|
||||
import unicodedata
|
||||
from builtins import str as unicode
|
||||
|
||||
@@ -4,12 +4,6 @@ import sys
|
||||
|
||||
import pyopenjtalk
|
||||
|
||||
|
||||
import os
|
||||
if os.environ.get("version","v1")=="v1":
|
||||
from text.symbols import symbols
|
||||
else:
|
||||
from text.symbols2 import symbols
|
||||
from text.symbols import punctuation
|
||||
# Regular expression matching Japanese without punctuation marks:
|
||||
_japanese_characters = re.compile(
|
||||
@@ -61,12 +55,13 @@ def post_replace_ph(ph):
|
||||
"、": ",",
|
||||
"...": "…",
|
||||
}
|
||||
|
||||
if ph in rep_map.keys():
|
||||
ph = rep_map[ph]
|
||||
if ph in symbols:
|
||||
return ph
|
||||
if ph not in symbols:
|
||||
ph = "UNK"
|
||||
# if ph in symbols:
|
||||
# return ph
|
||||
# if ph not in symbols:
|
||||
# ph = "UNK"
|
||||
return ph
|
||||
|
||||
|
||||
|
||||
@@ -2,11 +2,8 @@ import re
|
||||
from jamo import h2j, j2hcj
|
||||
import ko_pron
|
||||
from g2pk2 import G2p
|
||||
import os
|
||||
if os.environ.get("version","v1")=="v1":
|
||||
from text.symbols import symbols
|
||||
else:
|
||||
from text.symbols2 import symbols
|
||||
|
||||
from text.symbols2 import symbols
|
||||
|
||||
# This is a list of Korean classifiers preceded by pure Korean numerals.
|
||||
_korean_classifiers = '군데 권 개 그루 닢 대 두 마리 모 모금 뭇 발 발짝 방 번 벌 보루 살 수 술 시 쌈 움큼 정 짝 채 척 첩 축 켤레 톨 통'
|
||||
|
||||
Reference in New Issue
Block a user