Refactor: Format Code with Ruff and Update Deprecated G2PW Link (#2255)

* ruff check --fix

* ruff format --line-length 120 --target-version py39

* Change the link for G2PW Model

* update pytorch version and colab
This commit is contained in:
XXXXRT666
2025-04-07 09:42:47 +01:00
committed by GitHub
parent 9da7e17efe
commit 53cac93589
132 changed files with 8185 additions and 6648 deletions

File diff suppressed because it is too large Load Diff

View File

@@ -1,8 +1,9 @@
import os, sys
import os
import sys
import threading
from tqdm import tqdm
now_dir = os.getcwd()
sys.path.append(now_dir)
@@ -18,17 +19,19 @@ from TTS_infer_pack.text_segmentation_method import split_big_text, splits, get_
from tools.i18n.i18n import I18nAuto, scan_language_list
language=os.environ.get("language","Auto")
language=sys.argv[-1] if sys.argv[-1] in scan_language_list() else language
language = os.environ.get("language", "Auto")
language = sys.argv[-1] if sys.argv[-1] in scan_language_list() else language
i18n = I18nAuto(language=language)
punctuation = set(['!', '?', '', ',', '.', '-'])
punctuation = set(["!", "?", "", ",", ".", "-"])
def get_first(text:str) -> str:
def get_first(text: str) -> str:
pattern = "[" + "".join(re.escape(sep) for sep in splits) + "]"
text = re.split(pattern, text)[0].strip()
return text
def merge_short_text_in_array(texts:str, threshold:int) -> list:
def merge_short_text_in_array(texts: str, threshold: int) -> list:
if (len(texts)) < 2:
return texts
result = []
@@ -38,7 +41,7 @@ def merge_short_text_in_array(texts:str, threshold:int) -> list:
if len(text) >= threshold:
result.append(text)
text = ""
if (len(text) > 0):
if len(text) > 0:
if len(result) == 0:
result.append(text)
else:
@@ -46,28 +49,24 @@ def merge_short_text_in_array(texts:str, threshold:int) -> list:
return result
class TextPreprocessor:
def __init__(self, bert_model:AutoModelForMaskedLM,
tokenizer:AutoTokenizer, device:torch.device):
def __init__(self, bert_model: AutoModelForMaskedLM, tokenizer: AutoTokenizer, device: torch.device):
self.bert_model = bert_model
self.tokenizer = tokenizer
self.device = device
self.bert_lock = threading.RLock()
def preprocess(self, text:str, lang:str, text_split_method:str, version:str="v2")->List[Dict]:
print(f'############ {i18n("切分文本")} ############')
def preprocess(self, text: str, lang: str, text_split_method: str, version: str = "v2") -> List[Dict]:
print(f"############ {i18n('切分文本')} ############")
text = self.replace_consecutive_punctuation(text)
texts = self.pre_seg_text(text, lang, text_split_method)
result = []
print(f'############ {i18n("提取文本Bert特征")} ############')
print(f"############ {i18n('提取文本Bert特征')} ############")
for text in tqdm(texts):
phones, bert_features, norm_text = self.segment_and_extract_feature_for_text(text, lang, version)
if phones is None or norm_text=="":
if phones is None or norm_text == "":
continue
res={
res = {
"phones": phones,
"bert_features": bert_features,
"norm_text": norm_text,
@@ -75,11 +74,11 @@ class TextPreprocessor:
result.append(res)
return result
def pre_seg_text(self, text:str, lang:str, text_split_method:str):
def pre_seg_text(self, text: str, lang: str, text_split_method: str):
text = text.strip("\n")
if len(text) == 0:
return []
if (text[0] not in splits and len(get_first(text)) < 4):
if text[0] not in splits and len(get_first(text)) < 4:
text = "" + text if lang != "en" else "." + text
print(i18n("实际输入的目标文本:"))
print(text)
@@ -95,18 +94,18 @@ class TextPreprocessor:
_texts = merge_short_text_in_array(_texts, 5)
texts = []
for text in _texts:
# 解决输入目标文本的空行导致报错的问题
if (len(text.strip()) == 0):
continue
if len(text.strip()) == 0:
continue
if not re.sub("\W+", "", text):
# 检测一下,如果是纯符号,就跳过。
continue
if (text[-1] not in splits): text += "" if lang != "en" else "."
if text[-1] not in splits:
text += "" if lang != "en" else "."
# 解决句子过长导致Bert报错的问题
if (len(text) > 510):
if len(text) > 510:
texts.extend(split_big_text(text))
else:
texts.append(text)
@@ -115,78 +114,79 @@ class TextPreprocessor:
print(texts)
return texts
def segment_and_extract_feature_for_text(self, text:str, language:str, version:str="v1")->Tuple[list, torch.Tensor, str]:
def segment_and_extract_feature_for_text(
self, text: str, language: str, version: str = "v1"
) -> Tuple[list, torch.Tensor, str]:
return self.get_phones_and_bert(text, language, version)
def get_phones_and_bert(self, text:str, language:str, version:str, final:bool=False):
def get_phones_and_bert(self, text: str, language: str, version: str, final: bool = False):
with self.bert_lock:
if language in {"en", "all_zh", "all_ja", "all_ko", "all_yue"}:
# language = language.replace("all_","")
formattext = text
while " " in formattext:
formattext = formattext.replace(" ", " ")
if language == "all_zh":
if re.search(r'[A-Za-z]', formattext):
formattext = re.sub(r'[a-z]', lambda x: x.group(0).upper(), formattext)
formattext = chinese.mix_text_normalize(formattext)
return self.get_phones_and_bert(formattext,"zh",version)
else:
phones, word2ph, norm_text = self.clean_text_inf(formattext, language, version)
bert = self.get_bert_feature(norm_text, word2ph).to(self.device)
elif language == "all_yue" and re.search(r'[A-Za-z]', formattext):
formattext = re.sub(r'[a-z]', lambda x: x.group(0).upper(), formattext)
formattext = chinese.mix_text_normalize(formattext)
return self.get_phones_and_bert(formattext,"yue",version)
else:
phones, word2ph, norm_text = self.clean_text_inf(formattext, language, version)
bert = torch.zeros(
(1024, len(phones)),
dtype=torch.float32,
).to(self.device)
elif language in {"zh", "ja", "ko", "yue", "auto", "auto_yue"}:
textlist=[]
langlist=[]
if language == "auto":
for tmp in LangSegmenter.getTexts(text):
langlist.append(tmp["lang"])
textlist.append(tmp["text"])
elif language == "auto_yue":
for tmp in LangSegmenter.getTexts(text):
if tmp["lang"] == "zh":
tmp["lang"] = "yue"
langlist.append(tmp["lang"])
textlist.append(tmp["text"])
else:
for tmp in LangSegmenter.getTexts(text):
if tmp["lang"] == "en":
langlist.append(tmp["lang"])
else:
# 因无法区别中日韩文汉字,以用户输入为准
langlist.append(language)
textlist.append(tmp["text"])
# print(textlist)
# print(langlist)
phones_list = []
bert_list = []
norm_text_list = []
for i in range(len(textlist)):
lang = langlist[i]
phones, word2ph, norm_text = self.clean_text_inf(textlist[i], lang, version)
bert = self.get_bert_inf(phones, word2ph, norm_text, lang)
phones_list.append(phones)
norm_text_list.append(norm_text)
bert_list.append(bert)
bert = torch.cat(bert_list, dim=1)
phones = sum(phones_list, [])
norm_text = ''.join(norm_text_list)
if language in {"en", "all_zh", "all_ja", "all_ko", "all_yue"}:
# language = language.replace("all_","")
formattext = text
while " " in formattext:
formattext = formattext.replace(" ", " ")
if language == "all_zh":
if re.search(r"[A-Za-z]", formattext):
formattext = re.sub(r"[a-z]", lambda x: x.group(0).upper(), formattext)
formattext = chinese.mix_text_normalize(formattext)
return self.get_phones_and_bert(formattext, "zh", version)
else:
phones, word2ph, norm_text = self.clean_text_inf(formattext, language, version)
bert = self.get_bert_feature(norm_text, word2ph).to(self.device)
elif language == "all_yue" and re.search(r"[A-Za-z]", formattext):
formattext = re.sub(r"[a-z]", lambda x: x.group(0).upper(), formattext)
formattext = chinese.mix_text_normalize(formattext)
return self.get_phones_and_bert(formattext, "yue", version)
else:
phones, word2ph, norm_text = self.clean_text_inf(formattext, language, version)
bert = torch.zeros(
(1024, len(phones)),
dtype=torch.float32,
).to(self.device)
elif language in {"zh", "ja", "ko", "yue", "auto", "auto_yue"}:
textlist = []
langlist = []
if language == "auto":
for tmp in LangSegmenter.getTexts(text):
langlist.append(tmp["lang"])
textlist.append(tmp["text"])
elif language == "auto_yue":
for tmp in LangSegmenter.getTexts(text):
if tmp["lang"] == "zh":
tmp["lang"] = "yue"
langlist.append(tmp["lang"])
textlist.append(tmp["text"])
else:
for tmp in LangSegmenter.getTexts(text):
if tmp["lang"] == "en":
langlist.append(tmp["lang"])
else:
# 因无法区别中日韩文汉字,以用户输入为准
langlist.append(language)
textlist.append(tmp["text"])
# print(textlist)
# print(langlist)
phones_list = []
bert_list = []
norm_text_list = []
for i in range(len(textlist)):
lang = langlist[i]
phones, word2ph, norm_text = self.clean_text_inf(textlist[i], lang, version)
bert = self.get_bert_inf(phones, word2ph, norm_text, lang)
phones_list.append(phones)
norm_text_list.append(norm_text)
bert_list.append(bert)
bert = torch.cat(bert_list, dim=1)
phones = sum(phones_list, [])
norm_text = "".join(norm_text_list)
if not final and len(phones) < 6:
return self.get_phones_and_bert("." + text,language,version,final=True)
if not final and len(phones) < 6:
return self.get_phones_and_bert("." + text, language, version, final=True)
return phones, bert, norm_text
return phones, bert, norm_text
def get_bert_feature(self, text:str, word2ph:list)->torch.Tensor:
def get_bert_feature(self, text: str, word2ph: list) -> torch.Tensor:
with torch.no_grad():
inputs = self.tokenizer(text, return_tensors="pt")
for i in inputs:
@@ -201,14 +201,14 @@ class TextPreprocessor:
phone_level_feature = torch.cat(phone_level_feature, dim=0)
return phone_level_feature.T
def clean_text_inf(self, text:str, language:str, version:str="v2"):
language = language.replace("all_","")
def clean_text_inf(self, text: str, language: str, version: str = "v2"):
language = language.replace("all_", "")
phones, word2ph, norm_text = clean_text(text, language, version)
phones = cleaned_text_to_sequence(phones, version)
return phones, word2ph, norm_text
def get_bert_inf(self, phones:list, word2ph:list, norm_text:str, language:str):
language=language.replace("all_","")
def get_bert_inf(self, phones: list, word2ph: list, norm_text: str, language: str):
language = language.replace("all_", "")
if language == "zh":
feature = self.get_bert_feature(norm_text, word2ph).to(self.device)
else:
@@ -219,21 +219,19 @@ class TextPreprocessor:
return feature
def filter_text(self,texts):
_text=[]
if all(text in [None, " ", "\n",""] for text in texts):
def filter_text(self, texts):
_text = []
if all(text in [None, " ", "\n", ""] for text in texts):
raise ValueError(i18n("请输入有效文本"))
for text in texts:
if text in [None, " ", ""]:
if text in [None, " ", ""]:
pass
else:
_text.append(text)
return _text
def replace_consecutive_punctuation(self,text):
punctuations = ''.join(re.escape(p) for p in punctuation)
pattern = f'([{punctuations}])([{punctuations}])+'
result = re.sub(pattern, r'\1', text)
def replace_consecutive_punctuation(self, text):
punctuations = "".join(re.escape(p) for p in punctuation)
pattern = f"([{punctuations}])([{punctuations}])+"
result = re.sub(pattern, r"\1", text)
return result

View File

@@ -1 +1 @@
from . import TTS, text_segmentation_method
from . import TTS, text_segmentation_method

View File

@@ -1,41 +1,57 @@
import re
from typing import Callable
punctuation = set(['!', '?', '', ',', '.', '-'," "])
punctuation = set(["!", "?", "", ",", ".", "-", " "])
METHODS = dict()
def get_method(name:str)->Callable:
def get_method(name: str) -> Callable:
method = METHODS.get(name, None)
if method is None:
raise ValueError(f"Method {name} not found")
return method
def get_method_names()->list:
def get_method_names() -> list:
return list(METHODS.keys())
def register_method(name):
def decorator(func):
METHODS[name] = func
return func
return decorator
splits = {"", "", "", "", ",", ".", "?", "!", "~", ":", "", "", "", }
splits = {
"",
"",
"",
"",
",",
".",
"?",
"!",
"~",
":",
"",
"",
"",
}
def split_big_text(text, max_len=510):
# 定义全角和半角标点符号
punctuation = "".join(splits)
# 切割文本
segments = re.split('([' + punctuation + '])', text)
segments = re.split("([" + punctuation + "])", text)
# 初始化结果列表和当前片段
result = []
current_segment = ''
current_segment = ""
for segment in segments:
# 如果当前片段加上新的片段长度超过max_len就将当前片段加入结果列表并重置当前片段
if len(current_segment + segment) > max_len:
@@ -43,13 +59,12 @@ def split_big_text(text, max_len=510):
current_segment = segment
else:
current_segment += segment
# 将最后一个片段加入结果列表
if current_segment:
result.append(current_segment)
return result
return result
def split(todo_text):
@@ -90,7 +105,7 @@ def cut1(inp):
if len(split_idx) > 1:
opts = []
for idx in range(len(split_idx) - 1):
opts.append("".join(inps[split_idx[idx]: split_idx[idx + 1]]))
opts.append("".join(inps[split_idx[idx] : split_idx[idx + 1]]))
else:
opts = [inp]
opts = [item for item in opts if not set(item).issubset(punctuation)]
@@ -123,6 +138,7 @@ def cut2(inp):
opts = [item for item in opts if not set(item).issubset(punctuation)]
return "\n".join(opts)
# 按中文句号。切
@register_method("cut3")
def cut3(inp):
@@ -131,26 +147,28 @@ def cut3(inp):
opts = [item for item in opts if not set(item).issubset(punctuation)]
return "\n".join(opts)
#按英文句号.切
# 按英文句号.切
@register_method("cut4")
def cut4(inp):
inp = inp.strip("\n")
opts = re.split(r'(?<!\d)\.(?!\d)', inp.strip("."))
opts = re.split(r"(?<!\d)\.(?!\d)", inp.strip("."))
opts = [item for item in opts if not set(item).issubset(punctuation)]
return "\n".join(opts)
# 按标点符号切
# contributed by https://github.com/AI-Hobbyist/GPT-SoVITS/blob/main/GPT_SoVITS/inference_webui.py
@register_method("cut5")
def cut5(inp):
inp = inp.strip("\n")
punds = {',', '.', ';', '?', '!', '', '', '', '', '', ';', '', ''}
punds = {",", ".", ";", "?", "!", "", "", "", "", "", ";", "", ""}
mergeitems = []
items = []
for i, char in enumerate(inp):
if char in punds:
if char == '.' and i > 0 and i < len(inp) - 1 and inp[i - 1].isdigit() and inp[i + 1].isdigit():
if char == "." and i > 0 and i < len(inp) - 1 and inp[i - 1].isdigit() and inp[i + 1].isdigit():
items.append(char)
else:
items.append(char)
@@ -166,8 +184,6 @@ def cut5(inp):
return "\n".join(opt)
if __name__ == '__main__':
if __name__ == "__main__":
method = get_method("cut5")
print(method("你好,我是小明。你好,我是小红。你好,我是小刚。你好,我是小张。"))