Refactor: Format Code with Ruff and Update Deprecated G2PW Link (#2255)
* ruff check --fix * ruff format --line-length 120 --target-version py39 * Change the link for G2PW Model * update pytorch version and colab
This commit is contained in:
@@ -2,43 +2,51 @@
|
||||
import re
|
||||
import os
|
||||
import hashlib
|
||||
|
||||
try:
|
||||
import pyopenjtalk
|
||||
|
||||
current_file_path = os.path.dirname(__file__)
|
||||
|
||||
# 防止win下无法读取模型
|
||||
if os.name == 'nt':
|
||||
if os.name == "nt":
|
||||
python_dir = os.getcwd()
|
||||
OPEN_JTALK_DICT_DIR = pyopenjtalk.OPEN_JTALK_DICT_DIR.decode("utf-8")
|
||||
if not (re.match(r'^[A-Za-z0-9_/\\:.\-]*$', OPEN_JTALK_DICT_DIR)):
|
||||
if (OPEN_JTALK_DICT_DIR[:len(python_dir)].upper() == python_dir.upper()):
|
||||
OPEN_JTALK_DICT_DIR = os.path.join(os.path.relpath(OPEN_JTALK_DICT_DIR,python_dir))
|
||||
if not (re.match(r"^[A-Za-z0-9_/\\:.\-]*$", OPEN_JTALK_DICT_DIR)):
|
||||
if OPEN_JTALK_DICT_DIR[: len(python_dir)].upper() == python_dir.upper():
|
||||
OPEN_JTALK_DICT_DIR = os.path.join(os.path.relpath(OPEN_JTALK_DICT_DIR, python_dir))
|
||||
else:
|
||||
import shutil
|
||||
if not os.path.exists('TEMP'):
|
||||
os.mkdir('TEMP')
|
||||
|
||||
if not os.path.exists("TEMP"):
|
||||
os.mkdir("TEMP")
|
||||
if not os.path.exists(os.path.join("TEMP", "ja")):
|
||||
os.mkdir(os.path.join("TEMP", "ja"))
|
||||
if os.path.exists(os.path.join("TEMP", "ja", "open_jtalk_dic")):
|
||||
shutil.rmtree(os.path.join("TEMP", "ja", "open_jtalk_dic"))
|
||||
shutil.copytree(pyopenjtalk.OPEN_JTALK_DICT_DIR.decode("utf-8"), os.path.join("TEMP", "ja", "open_jtalk_dic"), )
|
||||
shutil.copytree(
|
||||
pyopenjtalk.OPEN_JTALK_DICT_DIR.decode("utf-8"),
|
||||
os.path.join("TEMP", "ja", "open_jtalk_dic"),
|
||||
)
|
||||
OPEN_JTALK_DICT_DIR = os.path.join("TEMP", "ja", "open_jtalk_dic")
|
||||
pyopenjtalk.OPEN_JTALK_DICT_DIR = OPEN_JTALK_DICT_DIR.encode("utf-8")
|
||||
|
||||
if not (re.match(r'^[A-Za-z0-9_/\\:.\-]*$', current_file_path)):
|
||||
if (current_file_path[:len(python_dir)].upper() == python_dir.upper()):
|
||||
current_file_path = os.path.join(os.path.relpath(current_file_path,python_dir))
|
||||
if not (re.match(r"^[A-Za-z0-9_/\\:.\-]*$", current_file_path)):
|
||||
if current_file_path[: len(python_dir)].upper() == python_dir.upper():
|
||||
current_file_path = os.path.join(os.path.relpath(current_file_path, python_dir))
|
||||
else:
|
||||
if not os.path.exists('TEMP'):
|
||||
os.mkdir('TEMP')
|
||||
if not os.path.exists("TEMP"):
|
||||
os.mkdir("TEMP")
|
||||
if not os.path.exists(os.path.join("TEMP", "ja")):
|
||||
os.mkdir(os.path.join("TEMP", "ja"))
|
||||
if not os.path.exists(os.path.join("TEMP", "ja", "ja_userdic")):
|
||||
os.mkdir(os.path.join("TEMP", "ja", "ja_userdic"))
|
||||
shutil.copyfile(os.path.join(current_file_path, "ja_userdic", "userdict.csv"),os.path.join("TEMP", "ja", "ja_userdic", "userdict.csv"))
|
||||
shutil.copyfile(
|
||||
os.path.join(current_file_path, "ja_userdic", "userdict.csv"),
|
||||
os.path.join("TEMP", "ja", "ja_userdic", "userdict.csv"),
|
||||
)
|
||||
current_file_path = os.path.join("TEMP", "ja")
|
||||
|
||||
|
||||
def get_hash(fp: str) -> str:
|
||||
hash_md5 = hashlib.md5()
|
||||
with open(fp, "rb") as f:
|
||||
@@ -51,21 +59,26 @@ try:
|
||||
USERDIC_HASH_PATH = os.path.join(current_file_path, "ja_userdic", "userdict.md5")
|
||||
# 如果没有用户词典,就生成一个;如果有,就检查md5,如果不一样,就重新生成
|
||||
if os.path.exists(USERDIC_CSV_PATH):
|
||||
if not os.path.exists(USERDIC_BIN_PATH) or get_hash(USERDIC_CSV_PATH) != open(USERDIC_HASH_PATH, "r",encoding='utf-8').read():
|
||||
if (
|
||||
not os.path.exists(USERDIC_BIN_PATH)
|
||||
or get_hash(USERDIC_CSV_PATH) != open(USERDIC_HASH_PATH, "r", encoding="utf-8").read()
|
||||
):
|
||||
pyopenjtalk.mecab_dict_index(USERDIC_CSV_PATH, USERDIC_BIN_PATH)
|
||||
with open(USERDIC_HASH_PATH, "w", encoding='utf-8') as f:
|
||||
with open(USERDIC_HASH_PATH, "w", encoding="utf-8") as f:
|
||||
f.write(get_hash(USERDIC_CSV_PATH))
|
||||
|
||||
if os.path.exists(USERDIC_BIN_PATH):
|
||||
pyopenjtalk.update_global_jtalk_with_user_dict(USERDIC_BIN_PATH)
|
||||
except Exception as e:
|
||||
pyopenjtalk.update_global_jtalk_with_user_dict(USERDIC_BIN_PATH)
|
||||
except Exception:
|
||||
# print(e)
|
||||
import pyopenjtalk
|
||||
|
||||
# failed to load user dictionary, ignore.
|
||||
pass
|
||||
|
||||
|
||||
from text.symbols import punctuation
|
||||
|
||||
# Regular expression matching Japanese without punctuation marks:
|
||||
_japanese_characters = re.compile(
|
||||
r"[A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]"
|
||||
@@ -123,9 +136,9 @@ def post_replace_ph(ph):
|
||||
|
||||
|
||||
def replace_consecutive_punctuation(text):
|
||||
punctuations = ''.join(re.escape(p) for p in punctuation)
|
||||
pattern = f'([{punctuations}])([{punctuations}])+'
|
||||
result = re.sub(pattern, r'\1', text)
|
||||
punctuations = "".join(re.escape(p) for p in punctuation)
|
||||
pattern = f"([{punctuations}])([{punctuations}])+"
|
||||
result = re.sub(pattern, r"\1", text)
|
||||
return result
|
||||
|
||||
|
||||
@@ -152,7 +165,7 @@ def preprocess_jap(text, with_prosody=False):
|
||||
text += p.split(" ")
|
||||
|
||||
if i < len(marks):
|
||||
if marks[i] == " ":# 防止意外的UNK
|
||||
if marks[i] == " ": # 防止意外的UNK
|
||||
continue
|
||||
text += [marks[i].replace(" ", "")]
|
||||
return text
|
||||
@@ -165,6 +178,7 @@ def text_normalize(text):
|
||||
text = replace_consecutive_punctuation(text)
|
||||
return text
|
||||
|
||||
|
||||
# Copied from espnet https://github.com/espnet/espnet/blob/master/espnet2/text/phoneme_tokenizer.py
|
||||
def pyopenjtalk_g2p_prosody(text, drop_unvoiced_vowels=True):
|
||||
"""Extract phoneme + prosoody symbol sequence from input full-context labels.
|
||||
@@ -241,6 +255,7 @@ def pyopenjtalk_g2p_prosody(text, drop_unvoiced_vowels=True):
|
||||
|
||||
return phones
|
||||
|
||||
|
||||
# Copied from espnet https://github.com/espnet/espnet/blob/master/espnet2/text/phoneme_tokenizer.py
|
||||
def _numeric_feature_by_regex(regex, s):
|
||||
match = re.search(regex, s)
|
||||
@@ -248,6 +263,7 @@ def _numeric_feature_by_regex(regex, s):
|
||||
return -50
|
||||
return int(match.group(1))
|
||||
|
||||
|
||||
def g2p(norm_text, with_prosody=True):
|
||||
phones = preprocess_jap(norm_text, with_prosody)
|
||||
phones = [post_replace_ph(i) for i in phones]
|
||||
|
||||
Reference in New Issue
Block a user