Refactor: Format Code with Ruff and Update Deprecated G2PW Link (#2255)

* ruff check --fix

* ruff format --line-length 120 --target-version py39

* Change the link for G2PW Model

* update pytorch version and colab
This commit is contained in:
XXXXRT666
2025-04-07 09:42:47 +01:00
committed by GitHub
parent 9da7e17efe
commit 53cac93589
132 changed files with 8185 additions and 6648 deletions

View File

@@ -1,41 +1,57 @@
import re
from typing import Callable
punctuation = set(['!', '?', '', ',', '.', '-'," "])
punctuation = set(["!", "?", "", ",", ".", "-", " "])
METHODS = dict()
def get_method(name:str)->Callable:
def get_method(name: str) -> Callable:
method = METHODS.get(name, None)
if method is None:
raise ValueError(f"Method {name} not found")
return method
def get_method_names()->list:
def get_method_names() -> list:
return list(METHODS.keys())
def register_method(name):
def decorator(func):
METHODS[name] = func
return func
return decorator
splits = {"", "", "", "", ",", ".", "?", "!", "~", ":", "", "", "", }
splits = {
"",
"",
"",
"",
",",
".",
"?",
"!",
"~",
":",
"",
"",
"",
}
def split_big_text(text, max_len=510):
# 定义全角和半角标点符号
punctuation = "".join(splits)
# 切割文本
segments = re.split('([' + punctuation + '])', text)
segments = re.split("([" + punctuation + "])", text)
# 初始化结果列表和当前片段
result = []
current_segment = ''
current_segment = ""
for segment in segments:
# 如果当前片段加上新的片段长度超过max_len就将当前片段加入结果列表并重置当前片段
if len(current_segment + segment) > max_len:
@@ -43,13 +59,12 @@ def split_big_text(text, max_len=510):
current_segment = segment
else:
current_segment += segment
# 将最后一个片段加入结果列表
if current_segment:
result.append(current_segment)
return result
return result
def split(todo_text):
@@ -90,7 +105,7 @@ def cut1(inp):
if len(split_idx) > 1:
opts = []
for idx in range(len(split_idx) - 1):
opts.append("".join(inps[split_idx[idx]: split_idx[idx + 1]]))
opts.append("".join(inps[split_idx[idx] : split_idx[idx + 1]]))
else:
opts = [inp]
opts = [item for item in opts if not set(item).issubset(punctuation)]
@@ -123,6 +138,7 @@ def cut2(inp):
opts = [item for item in opts if not set(item).issubset(punctuation)]
return "\n".join(opts)
# 按中文句号。切
@register_method("cut3")
def cut3(inp):
@@ -131,26 +147,28 @@ def cut3(inp):
opts = [item for item in opts if not set(item).issubset(punctuation)]
return "\n".join(opts)
#按英文句号.切
# 按英文句号.切
@register_method("cut4")
def cut4(inp):
inp = inp.strip("\n")
opts = re.split(r'(?<!\d)\.(?!\d)', inp.strip("."))
opts = re.split(r"(?<!\d)\.(?!\d)", inp.strip("."))
opts = [item for item in opts if not set(item).issubset(punctuation)]
return "\n".join(opts)
# 按标点符号切
# contributed by https://github.com/AI-Hobbyist/GPT-SoVITS/blob/main/GPT_SoVITS/inference_webui.py
@register_method("cut5")
def cut5(inp):
inp = inp.strip("\n")
punds = {',', '.', ';', '?', '!', '', '', '', '', '', ';', '', ''}
punds = {",", ".", ";", "?", "!", "", "", "", "", "", ";", "", ""}
mergeitems = []
items = []
for i, char in enumerate(inp):
if char in punds:
if char == '.' and i > 0 and i < len(inp) - 1 and inp[i - 1].isdigit() and inp[i + 1].isdigit():
if char == "." and i > 0 and i < len(inp) - 1 and inp[i - 1].isdigit() and inp[i + 1].isdigit():
items.append(char)
else:
items.append(char)
@@ -166,8 +184,6 @@ def cut5(inp):
return "\n".join(opt)
if __name__ == '__main__':
if __name__ == "__main__":
method = get_method("cut5")
print(method("你好,我是小明。你好,我是小红。你好,我是小刚。你好,我是小张。"))