Update Documentation (#2032)
* add exception handling * Fill in the missing content * remove GB code * Simplify i18n text and remove trailing spaces * ignore v3 model dir * Update Changelog * Fix encoding
This commit is contained in:
@@ -49,18 +49,18 @@ def merge_short_text_in_array(texts:str, threshold:int) -> list:
|
||||
|
||||
|
||||
class TextPreprocessor:
|
||||
def __init__(self, bert_model:AutoModelForMaskedLM,
|
||||
def __init__(self, bert_model:AutoModelForMaskedLM,
|
||||
tokenizer:AutoTokenizer, device:torch.device):
|
||||
self.bert_model = bert_model
|
||||
self.tokenizer = tokenizer
|
||||
self.device = device
|
||||
|
||||
|
||||
def preprocess(self, text:str, lang:str, text_split_method:str, version:str="v2")->List[Dict]:
|
||||
print(i18n("############ 切分文本 ############"))
|
||||
print(f'############ {i18n("切分文本")} ############')
|
||||
text = self.replace_consecutive_punctuation(text)
|
||||
texts = self.pre_seg_text(text, lang, text_split_method)
|
||||
result = []
|
||||
print(i18n("############ 提取文本Bert特征 ############"))
|
||||
print(f'############ {i18n("提取文本Bert特征")} ############')
|
||||
for text in tqdm(texts):
|
||||
phones, bert_features, norm_text = self.segment_and_extract_feature_for_text(text, lang, version)
|
||||
if phones is None or norm_text=="":
|
||||
@@ -77,14 +77,14 @@ class TextPreprocessor:
|
||||
text = text.strip("\n")
|
||||
if len(text) == 0:
|
||||
return []
|
||||
if (text[0] not in splits and len(get_first(text)) < 4):
|
||||
if (text[0] not in splits and len(get_first(text)) < 4):
|
||||
text = "。" + text if lang != "en" else "." + text
|
||||
print(i18n("实际输入的目标文本:"))
|
||||
print(text)
|
||||
|
||||
|
||||
seg_method = get_seg_method(text_split_method)
|
||||
text = seg_method(text)
|
||||
|
||||
|
||||
while "\n\n" in text:
|
||||
text = text.replace("\n\n", "\n")
|
||||
|
||||
@@ -93,29 +93,29 @@ class TextPreprocessor:
|
||||
_texts = merge_short_text_in_array(_texts, 5)
|
||||
texts = []
|
||||
|
||||
|
||||
|
||||
for text in _texts:
|
||||
# 解决输入目标文本的空行导致报错的问题
|
||||
if (len(text.strip()) == 0):
|
||||
continue
|
||||
if not re.sub("\W+", "", text):
|
||||
if not re.sub("\W+", "", text):
|
||||
# 检测一下,如果是纯符号,就跳过。
|
||||
continue
|
||||
if (text[-1] not in splits): text += "。" if lang != "en" else "."
|
||||
|
||||
|
||||
# 解决句子过长导致Bert报错的问题
|
||||
if (len(text) > 510):
|
||||
texts.extend(split_big_text(text))
|
||||
else:
|
||||
texts.append(text)
|
||||
|
||||
|
||||
print(i18n("实际输入的目标文本(切句后):"))
|
||||
print(texts)
|
||||
return texts
|
||||
|
||||
|
||||
def segment_and_extract_feature_for_text(self, text:str, language:str, version:str="v1")->Tuple[list, torch.Tensor, str]:
|
||||
return self.get_phones_and_bert(text, language, version)
|
||||
|
||||
|
||||
def get_phones_and_bert(self, text:str, language:str, version:str, final:bool=False):
|
||||
if language in {"en", "all_zh", "all_ja", "all_ko", "all_yue"}:
|
||||
language = language.replace("all_","")
|
||||
@@ -203,7 +203,7 @@ class TextPreprocessor:
|
||||
phone_level_feature.append(repeat_feature)
|
||||
phone_level_feature = torch.cat(phone_level_feature, dim=0)
|
||||
return phone_level_feature.T
|
||||
|
||||
|
||||
def clean_text_inf(self, text:str, language:str, version:str="v2"):
|
||||
phones, word2ph, norm_text = clean_text(text, language, version)
|
||||
phones = cleaned_text_to_sequence(phones, version)
|
||||
@@ -232,13 +232,10 @@ class TextPreprocessor:
|
||||
else:
|
||||
_text.append(text)
|
||||
return _text
|
||||
|
||||
|
||||
|
||||
def replace_consecutive_punctuation(self,text):
|
||||
punctuations = ''.join(re.escape(p) for p in punctuation)
|
||||
pattern = f'([{punctuations}])([{punctuations}])+'
|
||||
result = re.sub(pattern, r'\1', text)
|
||||
return result
|
||||
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user