修复了一些 bug,优化了一些代码

This commit is contained in:
chasonjiang
2024-03-11 17:16:04 +08:00
parent 3535cfe3b0
commit d23f3a62c4
5 changed files with 72 additions and 51 deletions

View File

@@ -1,5 +1,7 @@
import os, sys
from tqdm import tqdm
now_dir = os.getcwd()
sys.path.append(now_dir)
@@ -12,9 +14,9 @@ from text import cleaned_text_to_sequence
from transformers import AutoModelForMaskedLM, AutoTokenizer
from TTS_infer_pack.text_segmentation_method import split_big_text, splits, get_method as get_seg_method
# from tools.i18n.i18n import I18nAuto
from tools.i18n.i18n import I18nAuto
# i18n = I18nAuto()
i18n = I18nAuto()
def get_first(text:str) -> str:
pattern = "[" + "".join(re.escape(sep) for sep in splits) + "]"
@@ -51,9 +53,11 @@ class TextPreprocessor:
self.device = device
def preprocess(self, text:str, lang:str, text_split_method:str)->List[Dict]:
print(i18n("############ 切分文本 ############"))
texts = self.pre_seg_text(text, lang, text_split_method)
result = []
for text in texts:
print(i18n("############ 提取文本Bert特征 ############"))
for text in tqdm(texts):
phones, bert_features, norm_text = self.segment_and_extract_feature_for_text(text, lang)
res={
"phones": phones,
@@ -67,14 +71,16 @@ class TextPreprocessor:
text = text.strip("\n")
if (text[0] not in splits and len(get_first(text)) < 4):
text = "" + text if lang != "en" else "." + text
# print(i18n("实际输入的目标文本:"), text)
print(i18n("实际输入的目标文本:"))
print(text)
seg_method = get_seg_method(text_split_method)
text = seg_method(text)
while "\n\n" in text:
text = text.replace("\n\n", "\n")
# print(i18n("实际输入的目标文本(切句后):"), text)
print(i18n("实际输入的目标文本(切句后):"))
print(text)
_texts = text.split("\n")
_texts = merge_short_text_in_array(_texts, 5)
texts = []
@@ -105,7 +111,7 @@ class TextPreprocessor:
textlist=[]
langlist=[]
if language in ["auto", "zh", "ja"]:
# LangSegment.setfilters(["zh","ja","en","ko"])
LangSegment.setfilters(["zh","ja","en","ko"])
for tmp in LangSegment.getTexts(text):
if tmp["lang"] == "ko":
langlist.append("zh")
@@ -116,7 +122,7 @@ class TextPreprocessor:
langlist.append(language if language!="auto" else tmp["lang"])
textlist.append(tmp["text"])
elif language == "en":
# LangSegment.setfilters(["en"])
LangSegment.setfilters(["en"])
formattext = " ".join(tmp["text"] for tmp in LangSegment.getTexts(text))
while " " in formattext:
formattext = formattext.replace(" ", " ")
@@ -153,7 +159,7 @@ class TextPreprocessor:
# phones = sum(phones_list, [])
norm_text = ''.join(norm_text_list)
return phones, bert_feature, norm_text
return phones_list, bert_feature, norm_text
def get_bert_feature(self, text:str, word2ph:list)->torch.Tensor: