Version Check (#1390)

* version check

* fix webui and symbols

* fix v1 language map
This commit is contained in:
KamioRinn
2024-08-05 17:24:42 +08:00
committed by GitHub
parent 0c25e57959
commit 4e34814c70
8 changed files with 157 additions and 78 deletions

View File

@@ -152,6 +152,11 @@ def change_sovits_weights(sovits_path):
hps = dict_s2["config"]
hps = DictToAttrRecursive(hps)
hps.model.semantic_frame_rate = "25hz"
if dict_s2['weight']['enc_p.text_embedding.weight'].shape[0] == 322:
hps.model.version = "v1"
else:
hps.model.version = "v2"
# print("sovits版本:",hps.model.version)
vq_model = SynthesizerTrn(
hps.data.filter_length // 2 + 1,
hps.train.segment_size // hps.data.hop_length,
@@ -231,9 +236,9 @@ dict_language = {
}
def clean_text_inf(text, language, version):
    """Normalize *text* for *language* and map its phonemes to symbol IDs.

    Consolidates the duplicated pre-/post-diff definitions into the final,
    version-aware form: the ``version`` parameter is threaded through so the
    correct (v1 or v2) symbol inventory is used downstream.

    Args:
        text: raw input text to clean.
        language: language code understood by ``clean_text``
            (e.g. "zh", "ja", "en").
        version: model version tag ("v1" or "v2") selecting the symbol set.

    Returns:
        Tuple ``(phones, word2ph, norm_text)`` — ``phones`` is a list of
        integer symbol IDs, ``word2ph`` the per-word phoneme counts produced
        by ``clean_text`` (may be None for some languages), and ``norm_text``
        the normalized text.
    """
    phones, word2ph, norm_text = clean_text(text, language, version)
    phones = cleaned_text_to_sequence(phones, version)
    return phones, word2ph, norm_text
dtype=torch.float16 if is_half == True else torch.float32
@@ -259,7 +264,7 @@ def get_first(text):
return text
from text import chinese
def get_phones_and_bert(text,language):
def get_phones_and_bert(text,language,version):
if language in {"en", "all_zh", "all_ja", "all_ko", "all_yue"}:
language = language.replace("all_","")
if language == "en":
@@ -274,16 +279,16 @@ def get_phones_and_bert(text,language):
if re.search(r'[A-Za-z]', formattext):
formattext = re.sub(r'[a-z]', lambda x: x.group(0).upper(), formattext)
formattext = chinese.text_normalize(formattext)
return get_phones_and_bert(formattext,"zh")
return get_phones_and_bert(formattext,"zh",version)
else:
phones, word2ph, norm_text = clean_text_inf(formattext, language)
phones, word2ph, norm_text = clean_text_inf(formattext, language, version)
bert = get_bert_feature(norm_text, word2ph).to(device)
elif language == "yue" and re.search(r'[A-Za-z]', formattext):
formattext = re.sub(r'[a-z]', lambda x: x.group(0).upper(), formattext)
formattext = chinese.text_normalize(formattext)
return get_phones_and_bert(formattext,"yue")
return get_phones_and_bert(formattext,"yue",version)
else:
phones, word2ph, norm_text = clean_text_inf(formattext, language)
phones, word2ph, norm_text = clean_text_inf(formattext, language, version)
bert = torch.zeros(
(1024, len(phones)),
dtype=torch.float16 if is_half == True else torch.float32,
@@ -317,7 +322,7 @@ def get_phones_and_bert(text,language):
norm_text_list = []
for i in range(len(textlist)):
lang = langlist[i]
phones, word2ph, norm_text = clean_text_inf(textlist[i], lang)
phones, word2ph, norm_text = clean_text_inf(textlist[i], lang, version)
bert = get_bert_inf(phones, word2ph, norm_text, lang)
phones_list.append(phones)
norm_text_list.append(norm_text)
@@ -357,6 +362,9 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
t0 = ttime()
prompt_language = dict_language[prompt_language]
text_language = dict_language[text_language]
version = vq_model.version
if not ref_free:
prompt_text = prompt_text.strip("\n")
if (prompt_text[-1] not in splits): prompt_text += "" if prompt_language != "en" else "."
@@ -413,7 +421,7 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
texts = merge_short_text_in_array(texts, 5)
audio_opt = []
if not ref_free:
phones1,bert1,norm_text1=get_phones_and_bert(prompt_text, prompt_language)
phones1,bert1,norm_text1=get_phones_and_bert(prompt_text, prompt_language, version)
for i_text,text in enumerate(texts):
# 解决输入目标文本的空行导致报错的问题
@@ -421,7 +429,7 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
continue
if (text[-1] not in splits): text += "" if text_language != "en" else "."
print(i18n("实际输入的目标文本(每句):"), text)
phones2,bert2,norm_text2=get_phones_and_bert(text, text_language)
phones2,bert2,norm_text2=get_phones_and_bert(text, text_language, version)
print(i18n("前端处理后的文本(每句):"), norm_text2)
if not ref_free:
bert = torch.cat([bert1, bert2], 1)

View File

@@ -15,7 +15,9 @@ from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
from module.commons import init_weights, get_padding
from module.mrte_model import MRTE
from module.quantize import ResidualVectorQuantizer
from text import symbols
# from text import symbols
from text import symbols as symbols_v1
from text import symbols2 as symbols_v2
from torch.cuda.amp import autocast
import contextlib
@@ -185,6 +187,7 @@ class TextEncoder(nn.Module):
kernel_size,
p_dropout,
latent_channels=192,
version = "v2",
):
super().__init__()
self.out_channels = out_channels
@@ -195,6 +198,7 @@ class TextEncoder(nn.Module):
self.kernel_size = kernel_size
self.p_dropout = p_dropout
self.latent_channels = latent_channels
self.version = version
self.ssl_proj = nn.Conv1d(768, hidden_channels, 1)
@@ -210,6 +214,11 @@ class TextEncoder(nn.Module):
self.encoder_text = attentions.Encoder(
hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
)
if self.version == "v1":
symbols = symbols_v1.symbols
else:
symbols = symbols_v2.symbols
self.text_embedding = nn.Embedding(len(symbols), hidden_channels)
self.mrte = MRTE()
@@ -827,6 +836,7 @@ class SynthesizerTrn(nn.Module):
use_sdp=True,
semantic_frame_rate=None,
freeze_quantizer=None,
version = "v2",
**kwargs
):
super().__init__()
@@ -847,6 +857,7 @@ class SynthesizerTrn(nn.Module):
self.segment_size = segment_size
self.n_speakers = n_speakers
self.gin_channels = gin_channels
self.version = version
self.use_sdp = use_sdp
self.enc_p = TextEncoder(
@@ -857,6 +868,7 @@ class SynthesizerTrn(nn.Module):
n_layers,
kernel_size,
p_dropout,
version = version,
)
self.dec = Generator(
inter_channels,
@@ -881,7 +893,7 @@ class SynthesizerTrn(nn.Module):
inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels
)
self.version=os.environ.get("version","v1")
# self.version=os.environ.get("version","v1")
if(self.version=="v1"):
self.ref_enc = modules.MelStyleEncoder(spec_channels, style_vector_dim=gin_channels)
else:

View File

@@ -1,18 +1,26 @@
import os

from text import symbols as symbols_v1
from text import symbols2 as symbols_v2

# Per-version lookup tables mapping each phoneme symbol to its integer ID.
# v1 and v2 models were trained against different symbol inventories, so both
# tables are built once at import time and selected per call.  This replaces
# the old os.environ-gated import of a single `symbols` list (the stale
# `_symbol_to_id = {s: i for i, s in enumerate(symbols)}` line would raise
# NameError now that `symbols` is no longer imported unqualified).
_symbol_to_id_v1 = {s: i for i, s in enumerate(symbols_v1.symbols)}
_symbol_to_id_v2 = {s: i for i, s in enumerate(symbols_v2.symbols)}


def cleaned_text_to_sequence(cleaned_text, version):
    """Convert a sequence of cleaned phoneme symbols to integer IDs.

    Args:
        cleaned_text: iterable of phoneme symbol strings.
        version: "v1" selects the legacy symbol table; any other value
            selects the v2 table.

    Returns:
        List of integers corresponding to the symbols in ``cleaned_text``.

    Raises:
        KeyError: if a symbol is absent from the selected table.
    """
    table = _symbol_to_id_v1 if version == "v1" else _symbol_to_id_v2
    return [table[symbol] for symbol in cleaned_text]

View File

@@ -1,13 +1,17 @@
from text import japanese, cleaned_text_to_sequence, english,korean,cantonese
import os
if os.environ.get("version","v1")=="v1":
from text import chinese
from text.symbols import symbols
else:
from text import chinese2 as chinese
from text.symbols2 import symbols
# if os.environ.get("version","v1")=="v1":
# from text import chinese
# from text.symbols import symbols
# else:
# from text import chinese2 as chinese
# from text.symbols2 import symbols
from text import symbols as symbols_v1
from text import symbols2 as symbols_v2
from text import chinese as chinese_v1
from text import chinese2 as chinese_v2
language_module_map = {"zh": chinese, "ja": japanese, "en": english, "ko": korean,"yue":cantonese}
special = [
# ("%", "zh", "SP"),
("", "zh", "SP2"),
@@ -16,13 +20,20 @@ special = [
]
def clean_text(text, language):
def clean_text(text, language, version):
if version == "v1":
symbols = symbols_v1.symbols
language_module_map = {"zh": chinese_v1, "ja": japanese, "en": english}
else:
symbols = symbols_v2.symbols
language_module_map = {"zh": chinese_v2, "ja": japanese, "en": english, "ko": korean,"yue":cantonese}
if(language not in language_module_map):
language="en"
text=" "
for special_s, special_l, target_symbol in special:
if special_s in text and language == special_l:
return clean_special(text, language, special_s, target_symbol)
return clean_special(text, language, special_s, target_symbol, version)
language_module = language_module_map[language]
if hasattr(language_module,"text_normalize"):
norm_text = language_module.text_normalize(text)
@@ -42,11 +53,18 @@ def clean_text(text, language):
word2ph = None
for ph in phones:
assert ph in symbols
phones = ['UNK' if ph not in symbols else ph for ph in phones]
return phones, word2ph, norm_text
def clean_special(text, language, special_s, target_symbol):
def clean_special(text, language, special_s, target_symbol, version):
if version == "v1":
symbols = symbols_v1.symbols
language_module_map = {"zh": chinese_v1, "ja": japanese, "en": english}
else:
symbols = symbols_v2.symbols
language_module_map = {"zh": chinese_v2, "ja": japanese, "en": english, "ko": korean,"yue":cantonese}
"""
特殊静音段sp符号处理
"""

View File

@@ -6,10 +6,7 @@ from g2p_en import G2p
from text.symbols import punctuation
if os.environ.get("version","v1")=="v1":
from text.symbols import symbols
else:
from text.symbols2 import symbols
from text.symbols2 import symbols
import unicodedata
from builtins import str as unicode

View File

@@ -4,12 +4,6 @@ import sys
import pyopenjtalk
import os
if os.environ.get("version","v1")=="v1":
from text.symbols import symbols
else:
from text.symbols2 import symbols
from text.symbols import punctuation
# Regular expression matching Japanese without punctuation marks:
_japanese_characters = re.compile(
@@ -61,12 +55,13 @@ def post_replace_ph(ph):
"": ",",
"...": "",
}
if ph in rep_map.keys():
ph = rep_map[ph]
if ph in symbols:
return ph
if ph not in symbols:
ph = "UNK"
# if ph in symbols:
# return ph
# if ph not in symbols:
# ph = "UNK"
return ph

View File

@@ -2,11 +2,8 @@ import re
from jamo import h2j, j2hcj
import ko_pron
from g2pk2 import G2p
import os
if os.environ.get("version","v1")=="v1":
from text.symbols import symbols
else:
from text.symbols2 import symbols
from text.symbols2 import symbols
# This is a list of Korean classifiers preceded by pure Korean numerals.
_korean_classifiers = '군데 권 개 그루 닢 대 두 마리 모 모금 뭇 발 발짝 방 번 벌 보루 살 수 술 시 쌈 움큼 정 짝 채 척 첩 축 켤레 톨 통'