修复了一些bug,优化了一些代码
This commit is contained in:
@@ -1,5 +1,7 @@
|
||||
|
||||
import os, sys
|
||||
|
||||
from tqdm import tqdm
|
||||
now_dir = os.getcwd()
|
||||
sys.path.append(now_dir)
|
||||
|
||||
@@ -12,9 +14,9 @@ from text import cleaned_text_to_sequence
|
||||
from transformers import AutoModelForMaskedLM, AutoTokenizer
|
||||
from TTS_infer_pack.text_segmentation_method import split_big_text, splits, get_method as get_seg_method
|
||||
|
||||
# from tools.i18n.i18n import I18nAuto
|
||||
from tools.i18n.i18n import I18nAuto
|
||||
|
||||
# i18n = I18nAuto()
|
||||
i18n = I18nAuto()
|
||||
|
||||
def get_first(text:str) -> str:
|
||||
pattern = "[" + "".join(re.escape(sep) for sep in splits) + "]"
|
||||
@@ -51,9 +53,11 @@ class TextPreprocessor:
|
||||
self.device = device
|
||||
|
||||
def preprocess(self, text:str, lang:str, text_split_method:str)->List[Dict]:
|
||||
print(i18n("############ 切分文本 ############"))
|
||||
texts = self.pre_seg_text(text, lang, text_split_method)
|
||||
result = []
|
||||
for text in texts:
|
||||
print(i18n("############ 提取文本Bert特征 ############"))
|
||||
for text in tqdm(texts):
|
||||
phones, bert_features, norm_text = self.segment_and_extract_feature_for_text(text, lang)
|
||||
res={
|
||||
"phones": phones,
|
||||
@@ -67,14 +71,16 @@ class TextPreprocessor:
|
||||
text = text.strip("\n")
|
||||
if (text[0] not in splits and len(get_first(text)) < 4):
|
||||
text = "。" + text if lang != "en" else "." + text
|
||||
# print(i18n("实际输入的目标文本:"), text)
|
||||
print(i18n("实际输入的目标文本:"))
|
||||
print(text)
|
||||
|
||||
seg_method = get_seg_method(text_split_method)
|
||||
text = seg_method(text)
|
||||
|
||||
while "\n\n" in text:
|
||||
text = text.replace("\n\n", "\n")
|
||||
# print(i18n("实际输入的目标文本(切句后):"), text)
|
||||
print(i18n("实际输入的目标文本(切句后):"))
|
||||
print(text)
|
||||
_texts = text.split("\n")
|
||||
_texts = merge_short_text_in_array(_texts, 5)
|
||||
texts = []
|
||||
@@ -105,7 +111,7 @@ class TextPreprocessor:
|
||||
textlist=[]
|
||||
langlist=[]
|
||||
if language in ["auto", "zh", "ja"]:
|
||||
# LangSegment.setfilters(["zh","ja","en","ko"])
|
||||
LangSegment.setfilters(["zh","ja","en","ko"])
|
||||
for tmp in LangSegment.getTexts(text):
|
||||
if tmp["lang"] == "ko":
|
||||
langlist.append("zh")
|
||||
@@ -116,7 +122,7 @@ class TextPreprocessor:
|
||||
langlist.append(language if language!="auto" else tmp["lang"])
|
||||
textlist.append(tmp["text"])
|
||||
elif language == "en":
|
||||
# LangSegment.setfilters(["en"])
|
||||
LangSegment.setfilters(["en"])
|
||||
formattext = " ".join(tmp["text"] for tmp in LangSegment.getTexts(text))
|
||||
while " " in formattext:
|
||||
formattext = formattext.replace(" ", " ")
|
||||
@@ -153,7 +159,7 @@ class TextPreprocessor:
|
||||
# phones = sum(phones_list, [])
|
||||
norm_text = ''.join(norm_text_list)
|
||||
|
||||
return phones, bert_feature, norm_text
|
||||
return phones_list, bert_feature, norm_text
|
||||
|
||||
|
||||
def get_bert_feature(self, text:str, word2ph:list)->torch.Tensor:
|
||||
|
||||
Reference in New Issue
Block a user