Refactor: Format Code with Ruff and Update Deprecated G2PW Link (#2255)
* ruff check --fix
* ruff format --line-length 120 --target-version py39
* Change the link for G2PW Model
* update pytorch version and colab
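To reproduce the formatting pass locally, the two ruff commands named in the commit message can be run from the repository root. This is a minimal sketch: it assumes ruff is installed from PyPI and that no conflicting settings are pinned in the project's own configuration.

    pip install ruff                                        # assumption: ruff is not already installed
    ruff check --fix                                        # apply auto-fixable lint corrections
    ruff format --line-length 120 --target-version py39    # reformat with the settings used in this commit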
@@ -1 +1 @@
from text.g2pw.g2pw import *
from text.g2pw.g2pw import *

@@ -15,6 +15,7 @@
Credits
This code is modified from https://github.com/GitYCC/g2pW
"""

from typing import Dict
from typing import List
from typing import Tuple
@@ -23,21 +24,24 @@ import numpy as np

from .utils import tokenize_and_map

ANCHOR_CHAR = '▁'
ANCHOR_CHAR = "▁"


def prepare_onnx_input(tokenizer,
labels: List[str],
char2phonemes: Dict[str, List[int]],
chars: List[str],
texts: List[str],
query_ids: List[int],
use_mask: bool=False,
window_size: int=None,
max_len: int=512) -> Dict[str, np.array]:
def prepare_onnx_input(
tokenizer,
labels: List[str],
char2phonemes: Dict[str, List[int]],
chars: List[str],
texts: List[str],
query_ids: List[int],
use_mask: bool = False,
window_size: int = None,
max_len: int = 512,
) -> Dict[str, np.array]:
if window_size is not None:
truncated_texts, truncated_query_ids = _truncate_texts(
window_size=window_size, texts=texts, query_ids=query_ids)
window_size=window_size, texts=texts, query_ids=query_ids
)
input_ids = []
token_type_ids = []
attention_masks = []
@@ -50,33 +54,27 @@ def prepare_onnx_input(tokenizer,
query_id = (truncated_query_ids if window_size else query_ids)[idx]

try:
tokens, text2token, token2text = tokenize_and_map(
tokenizer=tokenizer, text=text)
tokens, text2token, token2text = tokenize_and_map(tokenizer=tokenizer, text=text)
except Exception:
print(f'warning: text "{text}" is invalid')
return {}

text, query_id, tokens, text2token, token2text = _truncate(
max_len=max_len,
text=text,
query_id=query_id,
tokens=tokens,
text2token=text2token,
token2text=token2text)
max_len=max_len, text=text, query_id=query_id, tokens=tokens, text2token=text2token, token2text=token2text
)

processed_tokens = ['[CLS]'] + tokens + ['[SEP]']
processed_tokens = ["[CLS]"] + tokens + ["[SEP]"]

input_id = list(
np.array(tokenizer.convert_tokens_to_ids(processed_tokens)))
token_type_id = list(np.zeros((len(processed_tokens), ), dtype=int))
attention_mask = list(np.ones((len(processed_tokens), ), dtype=int))
input_id = list(np.array(tokenizer.convert_tokens_to_ids(processed_tokens)))
token_type_id = list(np.zeros((len(processed_tokens),), dtype=int))
attention_mask = list(np.ones((len(processed_tokens),), dtype=int))

query_char = text[query_id]
phoneme_mask = [1 if i in char2phonemes[query_char] else 0 for i in range(len(labels))] \
if use_mask else [1] * len(labels)
phoneme_mask = (
[1 if i in char2phonemes[query_char] else 0 for i in range(len(labels))] if use_mask else [1] * len(labels)
)
char_id = chars.index(query_char)
position_id = text2token[
query_id] + 1 # [CLS] token locate at first place
position_id = text2token[query_id] + 1 # [CLS] token locate at first place

input_ids.append(input_id)
token_type_ids.append(token_type_id)
@@ -86,18 +84,17 @@ def prepare_onnx_input(tokenizer,
position_ids.append(position_id)

outputs = {
'input_ids': np.array(input_ids).astype(np.int64),
'token_type_ids': np.array(token_type_ids).astype(np.int64),
'attention_masks': np.array(attention_masks).astype(np.int64),
'phoneme_masks': np.array(phoneme_masks).astype(np.float32),
'char_ids': np.array(char_ids).astype(np.int64),
'position_ids': np.array(position_ids).astype(np.int64),
"input_ids": np.array(input_ids).astype(np.int64),
"token_type_ids": np.array(token_type_ids).astype(np.int64),
"attention_masks": np.array(attention_masks).astype(np.int64),
"phoneme_masks": np.array(phoneme_masks).astype(np.float32),
"char_ids": np.array(char_ids).astype(np.int64),
"position_ids": np.array(position_ids).astype(np.int64),
}
return outputs


def _truncate_texts(window_size: int, texts: List[str],
query_ids: List[int]) -> Tuple[List[str], List[int]]:
def _truncate_texts(window_size: int, texts: List[str], query_ids: List[int]) -> Tuple[List[str], List[int]]:
truncated_texts = []
truncated_query_ids = []
for text, query_id in zip(texts, query_ids):
@@ -111,12 +108,9 @@ def _truncate_texts(window_size: int, texts: List[str],
return truncated_texts, truncated_query_ids


def _truncate(max_len: int,
text: str,
query_id: int,
tokens: List[str],
text2token: List[int],
token2text: List[Tuple[int]]):
def _truncate(
max_len: int, text: str, query_id: int, tokens: List[str], text2token: List[int], token2text: List[Tuple[int]]
):
truncate_len = max_len - 2
if len(tokens) <= truncate_len:
return (text, query_id, tokens, text2token, token2text)
@@ -137,14 +131,16 @@ def _truncate(max_len: int,
start = token2text[token_start][0]
end = token2text[token_end - 1][1]

return (text[start:end], query_id - start, tokens[token_start:token_end], [
i - token_start if i is not None else None
for i in text2token[start:end]
], [(s - start, e - start) for s, e in token2text[token_start:token_end]])
return (
text[start:end],
query_id - start,
tokens[token_start:token_end],
[i - token_start if i is not None else None for i in text2token[start:end]],
[(s - start, e - start) for s, e in token2text[token_start:token_end]],
)


def get_phoneme_labels(polyphonic_chars: List[List[str]]
) -> Tuple[List[str], Dict[str, List[int]]]:
def get_phoneme_labels(polyphonic_chars: List[List[str]]) -> Tuple[List[str], Dict[str, List[int]]]:
labels = sorted(list(set([phoneme for char, phoneme in polyphonic_chars])))
char2phonemes = {}
for char, phoneme in polyphonic_chars:
@@ -154,13 +150,11 @@ def get_phoneme_labels(polyphonic_chars: List[List[str]]
return labels, char2phonemes


def get_char_phoneme_labels(polyphonic_chars: List[List[str]]
) -> Tuple[List[str], Dict[str, List[int]]]:
labels = sorted(
list(set([f'{char} {phoneme}' for char, phoneme in polyphonic_chars])))
def get_char_phoneme_labels(polyphonic_chars: List[List[str]]) -> Tuple[List[str], Dict[str, List[int]]]:
labels = sorted(list(set([f"{char} {phoneme}" for char, phoneme in polyphonic_chars])))
char2phonemes = {}
for char, phoneme in polyphonic_chars:
if char not in char2phonemes:
char2phonemes[char] = []
char2phonemes[char].append(labels.index(f'{char} {phoneme}'))
char2phonemes[char].append(labels.index(f"{char} {phoneme}"))
return labels, char2phonemes

@@ -17,17 +17,25 @@ PP_FIX_DICT_PATH = os.path.join(current_file_path, "polyphonic-fix.rep")


class G2PWPinyin(Pinyin):
def __init__(self, model_dir='G2PWModel/', model_source=None,
enable_non_tradional_chinese=True,
v_to_u=False, neutral_tone_with_five=False, tone_sandhi=False, **kwargs):
def __init__(
self,
model_dir="G2PWModel/",
model_source=None,
enable_non_tradional_chinese=True,
v_to_u=False,
neutral_tone_with_five=False,
tone_sandhi=False,
**kwargs,
):
self._g2pw = G2PWOnnxConverter(
model_dir=model_dir,
style='pinyin',
style="pinyin",
model_source=model_source,
enable_non_tradional_chinese=enable_non_tradional_chinese,
)
self._converter = Converter(
self._g2pw, v_to_u=v_to_u,
self._g2pw,
v_to_u=v_to_u,
neutral_tone_with_five=neutral_tone_with_five,
tone_sandhi=tone_sandhi,
)
@@ -37,31 +45,25 @@ class G2PWPinyin(Pinyin):


class Converter(UltimateConverter):
def __init__(self, g2pw_instance, v_to_u=False,
neutral_tone_with_five=False,
tone_sandhi=False, **kwargs):
def __init__(self, g2pw_instance, v_to_u=False, neutral_tone_with_five=False, tone_sandhi=False, **kwargs):
super(Converter, self).__init__(
v_to_u=v_to_u,
neutral_tone_with_five=neutral_tone_with_five,
tone_sandhi=tone_sandhi, **kwargs)
v_to_u=v_to_u, neutral_tone_with_five=neutral_tone_with_five, tone_sandhi=tone_sandhi, **kwargs
)

self._g2pw = g2pw_instance

def convert(self, words, style, heteronym, errors, strict, **kwargs):
pys = []
if RE_HANS.match(words):
pys = self._to_pinyin(words, style=style, heteronym=heteronym,
errors=errors, strict=strict)
pys = self._to_pinyin(words, style=style, heteronym=heteronym, errors=errors, strict=strict)
post_data = self.post_pinyin(words, heteronym, pys)
if post_data is not None:
pys = post_data

pys = self.convert_styles(
pys, words, style, heteronym, errors, strict)
pys = self.convert_styles(pys, words, style, heteronym, errors, strict)

else:
py = self.handle_nopinyin(words, style=style, errors=errors,
heteronym=heteronym, strict=strict)
py = self.handle_nopinyin(words, style=style, errors=errors, heteronym=heteronym, strict=strict)
if py:
pys.extend(py)

@@ -73,13 +75,11 @@ class Converter(UltimateConverter):
g2pw_pinyin = self._g2pw(han)

if not g2pw_pinyin: # g2pw 不支持的汉字改为使用 pypinyin 原有逻辑
return super(Converter, self).convert(
han, Style.TONE, heteronym, errors, strict, **kwargs)
return super(Converter, self).convert(han, Style.TONE, heteronym, errors, strict, **kwargs)

for i, item in enumerate(g2pw_pinyin[0]):
if item is None: # g2pw 不支持的汉字改为使用 pypinyin 原有逻辑
py = super(Converter, self).convert(
han[i], Style.TONE, heteronym, errors, strict, **kwargs)
py = super(Converter, self).convert(han[i], Style.TONE, heteronym, errors, strict, **kwargs)
pinyins.extend(py)
else:
pinyins.append([to_tone(item)])
@@ -104,7 +104,7 @@ def _remove_dup_and_empty(lst_list):
if lst:
new_lst_list.append(lst)
else:
new_lst_list.append([''])
new_lst_list.append([""])

return new_lst_list

@@ -127,17 +127,17 @@ def get_dict():

def read_dict():
polyphonic_dict = {}
with open(PP_DICT_PATH,encoding="utf-8") as f:
with open(PP_DICT_PATH, encoding="utf-8") as f:
line = f.readline()
while line:
key, value_str = line.split(':')
key, value_str = line.split(":")
value = eval(value_str.strip())
polyphonic_dict[key.strip()] = value
line = f.readline()
with open(PP_FIX_DICT_PATH,encoding="utf-8") as f:
with open(PP_FIX_DICT_PATH, encoding="utf-8") as f:
line = f.readline()
while line:
key, value_str = line.split(':')
key, value_str = line.split(":")
value = eval(value_str.strip())
polyphonic_dict[key.strip()] = value
line = f.readline()

@@ -2,44 +2,43 @@
# This code is modified from https://github.com/GitYCC/g2pW

import warnings

warnings.filterwarnings("ignore")
import json
import os
import zipfile,requests
from typing import Any
from typing import Dict
from typing import List
from typing import Tuple
import zipfile
from typing import Any, Dict, List, Tuple

import numpy as np
import onnxruntime
import requests

onnxruntime.set_default_logger_severity(3)
from opencc import OpenCC
from pypinyin import Style, pinyin
from transformers import AutoTokenizer
from pypinyin import pinyin
from pypinyin import Style

from .dataset import get_char_phoneme_labels
from .dataset import get_phoneme_labels
from .dataset import prepare_onnx_input
from .utils import load_config
from ..zh_normalization.char_convert import tranditional_to_simplified
from .dataset import get_char_phoneme_labels, get_phoneme_labels, prepare_onnx_input
from .utils import load_config

model_version = '1.1'
model_version = "1.1"


def predict(session, onnx_input: Dict[str, Any],
labels: List[str]) -> Tuple[List[str], List[float]]:
def predict(session, onnx_input: Dict[str, Any], labels: List[str]) -> Tuple[List[str], List[float]]:
all_preds = []
all_confidences = []
probs = session.run([], {
"input_ids": onnx_input['input_ids'],
"token_type_ids": onnx_input['token_type_ids'],
"attention_mask": onnx_input['attention_masks'],
"phoneme_mask": onnx_input['phoneme_masks'],
"char_ids": onnx_input['char_ids'],
"position_ids": onnx_input['position_ids']
})[0]
probs = session.run(
[],
{
"input_ids": onnx_input["input_ids"],
"token_type_ids": onnx_input["token_type_ids"],
"attention_mask": onnx_input["attention_masks"],
"phoneme_mask": onnx_input["phoneme_masks"],
"char_ids": onnx_input["char_ids"],
"position_ids": onnx_input["position_ids"],
},
)[0]

preds = np.argmax(probs, axis=1).tolist()
max_probs = []
@@ -51,17 +50,17 @@ def predict(session, onnx_input: Dict[str, Any],
return all_preds, all_confidences


def download_and_decompress(model_dir: str='G2PWModel/'):
def download_and_decompress(model_dir: str = "G2PWModel/"):
if not os.path.exists(model_dir):
parent_directory = os.path.dirname(model_dir)
zip_dir = os.path.join(parent_directory,"G2PWModel_1.1.zip")
extract_dir = os.path.join(parent_directory,"G2PWModel_1.1")
extract_dir_new = os.path.join(parent_directory,"G2PWModel")
zip_dir = os.path.join(parent_directory, "G2PWModel_1.1.zip")
extract_dir = os.path.join(parent_directory, "G2PWModel_1.1")
extract_dir_new = os.path.join(parent_directory, "G2PWModel")
print("Downloading g2pw model...")
modelscope_url = "https://www.modelscope.cn/models/kamiorinn/g2pw/resolve/master/G2PWModel_1.1.zip"#"https://paddlespeech.cdn.bcebos.com/Parakeet/released_models/g2p/G2PWModel_1.1.zip"
modelscope_url = "https://www.modelscope.cn/models/kamiorinn/g2pw/resolve/master/G2PWModel_1.1.zip" # "https://paddlespeech.cdn.bcebos.com/Parakeet/released_models/g2p/G2PWModel_1.1.zip"
with requests.get(modelscope_url, stream=True) as r:
r.raise_for_status()
with open(zip_dir, 'wb') as f:
with open(zip_dir, "wb") as f:
for chunk in r.iter_content(chunk_size=8192):
if chunk:
f.write(chunk)
@@ -69,17 +68,20 @@ def download_and_decompress(model_dir: str='G2PWModel/'):
print("Extracting g2pw model...")
with zipfile.ZipFile(zip_dir, "r") as zip_ref:
zip_ref.extractall(parent_directory)


os.rename(extract_dir, extract_dir_new)

return model_dir


class G2PWOnnxConverter:
def __init__(self,
model_dir: str='G2PWModel/',
style: str='bopomofo',
model_source: str=None,
enable_non_tradional_chinese: bool=False):
def __init__(
self,
model_dir: str = "G2PWModel/",
style: str = "bopomofo",
model_source: str = None,
enable_non_tradional_chinese: bool = False,
):
uncompress_path = download_and_decompress(model_dir)

sess_options = onnxruntime.SessionOptions()
@@ -87,41 +89,59 @@ class G2PWOnnxConverter:
sess_options.execution_mode = onnxruntime.ExecutionMode.ORT_SEQUENTIAL
sess_options.intra_op_num_threads = 2
try:
self.session_g2pW = onnxruntime.InferenceSession(os.path.join(uncompress_path, 'g2pW.onnx'),sess_options=sess_options, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
self.session_g2pW = onnxruntime.InferenceSession(
os.path.join(uncompress_path, "g2pW.onnx"),
sess_options=sess_options,
providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
)
except:
self.session_g2pW = onnxruntime.InferenceSession(os.path.join(uncompress_path, 'g2pW.onnx'),sess_options=sess_options, providers=['CPUExecutionProvider'])
self.config = load_config(
config_path=os.path.join(uncompress_path, 'config.py'),
use_default=True)
self.session_g2pW = onnxruntime.InferenceSession(
os.path.join(uncompress_path, "g2pW.onnx"),
sess_options=sess_options,
providers=["CPUExecutionProvider"],
)
self.config = load_config(config_path=os.path.join(uncompress_path, "config.py"), use_default=True)

self.model_source = model_source if model_source else self.config.model_source
self.enable_opencc = enable_non_tradional_chinese

self.tokenizer = AutoTokenizer.from_pretrained(self.model_source)

polyphonic_chars_path = os.path.join(uncompress_path,
'POLYPHONIC_CHARS.txt')
monophonic_chars_path = os.path.join(uncompress_path,
'MONOPHONIC_CHARS.txt')
polyphonic_chars_path = os.path.join(uncompress_path, "POLYPHONIC_CHARS.txt")
monophonic_chars_path = os.path.join(uncompress_path, "MONOPHONIC_CHARS.txt")
self.polyphonic_chars = [
line.split('\t')
for line in open(polyphonic_chars_path, encoding='utf-8').read()
.strip().split('\n')
line.split("\t") for line in open(polyphonic_chars_path, encoding="utf-8").read().strip().split("\n")
]
self.non_polyphonic = {
'一', '不', '和', '咋', '嗲', '剖', '差', '攢', '倒', '難', '奔', '勁', '拗',
'肖', '瘙', '誒', '泊', '听', '噢'
"一",
"不",
"和",
"咋",
"嗲",
"剖",
"差",
"攢",
"倒",
"難",
"奔",
"勁",
"拗",
"肖",
"瘙",
"誒",
"泊",
"听",
"噢",
}
self.non_monophonic = {'似', '攢'}
self.non_monophonic = {"似", "攢"}
self.monophonic_chars = [
line.split('\t')
for line in open(monophonic_chars_path, encoding='utf-8').read()
.strip().split('\n')
line.split("\t") for line in open(monophonic_chars_path, encoding="utf-8").read().strip().split("\n")
]
self.labels, self.char2phonemes = get_char_phoneme_labels(
polyphonic_chars=self.polyphonic_chars
) if self.config.use_char_phoneme else get_phoneme_labels(
polyphonic_chars=self.polyphonic_chars)
self.labels, self.char2phonemes = (
get_char_phoneme_labels(polyphonic_chars=self.polyphonic_chars)
if self.config.use_char_phoneme
else get_phoneme_labels(polyphonic_chars=self.polyphonic_chars)
)

self.chars = sorted(list(self.char2phonemes.keys()))

@@ -130,41 +150,29 @@ class G2PWOnnxConverter:
if char in self.polyphonic_chars_new:
self.polyphonic_chars_new.remove(char)

self.monophonic_chars_dict = {
char: phoneme
for char, phoneme in self.monophonic_chars
}
self.monophonic_chars_dict = {char: phoneme for char, phoneme in self.monophonic_chars}
for char in self.non_monophonic:
if char in self.monophonic_chars_dict:
self.monophonic_chars_dict.pop(char)

self.pos_tags = [
'UNK', 'A', 'C', 'D', 'I', 'N', 'P', 'T', 'V', 'DE', 'SHI'
]
self.pos_tags = ["UNK", "A", "C", "D", "I", "N", "P", "T", "V", "DE", "SHI"]

with open(
os.path.join(uncompress_path,
'bopomofo_to_pinyin_wo_tune_dict.json'),
'r',
encoding='utf-8') as fr:
with open(os.path.join(uncompress_path, "bopomofo_to_pinyin_wo_tune_dict.json"), "r", encoding="utf-8") as fr:
self.bopomofo_convert_dict = json.load(fr)
self.style_convert_func = {
'bopomofo': lambda x: x,
'pinyin': self._convert_bopomofo_to_pinyin,
"bopomofo": lambda x: x,
"pinyin": self._convert_bopomofo_to_pinyin,
}[style]

with open(
os.path.join(uncompress_path, 'char_bopomofo_dict.json'),
'r',
encoding='utf-8') as fr:
with open(os.path.join(uncompress_path, "char_bopomofo_dict.json"), "r", encoding="utf-8") as fr:
self.char_bopomofo_dict = json.load(fr)

if self.enable_opencc:
self.cc = OpenCC('s2tw')
self.cc = OpenCC("s2tw")

def _convert_bopomofo_to_pinyin(self, bopomofo: str) -> str:
tone = bopomofo[-1]
assert tone in '12345'
assert tone in "12345"
component = self.bopomofo_convert_dict.get(bopomofo[:-1])
if component:
return component + tone
@@ -184,8 +192,7 @@ class G2PWOnnxConverter:
translated_sentences.append(translated_sent)
sentences = translated_sentences

texts, query_ids, sent_ids, partial_results = self._prepare_data(
sentences=sentences)
texts, query_ids, sent_ids, partial_results = self._prepare_data(sentences=sentences)
if len(texts) == 0:
# sentences no polyphonic words
return partial_results
@@ -198,14 +205,12 @@ class G2PWOnnxConverter:
texts=texts,
query_ids=query_ids,
use_mask=self.config.use_mask,
window_size=None)
window_size=None,
)

preds, confidences = predict(
session=self.session_g2pW,
onnx_input=onnx_input,
labels=self.labels)
preds, confidences = predict(session=self.session_g2pW, onnx_input=onnx_input, labels=self.labels)
if self.config.use_char_phoneme:
preds = [pred.split(' ')[1] for pred in preds]
preds = [pred.split(" ")[1] for pred in preds]

results = partial_results
for sent_id, query_id, pred in zip(sent_ids, query_ids, preds):
@@ -213,15 +218,12 @@ class G2PWOnnxConverter:

return results

def _prepare_data(
self, sentences: List[str]
) -> Tuple[List[str], List[int], List[int], List[List[str]]]:
def _prepare_data(self, sentences: List[str]) -> Tuple[List[str], List[int], List[int], List[List[str]]]:
texts, query_ids, sent_ids, partial_results = [], [], [], []
for sent_id, sent in enumerate(sentences):
# pypinyin works well for Simplified Chinese than Traditional Chinese
sent_s = tranditional_to_simplified(sent)
pypinyin_result = pinyin(
sent_s, neutral_tone_with_five=True, style=Style.TONE3)
pypinyin_result = pinyin(sent_s, neutral_tone_with_five=True, style=Style.TONE3)
partial_result = [None] * len(sent)
for i, char in enumerate(sent):
if char in self.polyphonic_chars_new:
@@ -229,8 +231,7 @@ class G2PWOnnxConverter:
query_ids.append(i)
sent_ids.append(sent_id)
elif char in self.monophonic_chars_dict:
partial_result[i] = self.style_convert_func(
self.monophonic_chars_dict[char])
partial_result[i] = self.style_convert_func(self.monophonic_chars_dict[char])
elif char in self.char_bopomofo_dict:
partial_result[i] = pypinyin_result[i][0]
# partial_result[i] = self.style_convert_func(self.char_bopomofo_dict[char][0])

@@ -15,6 +15,7 @@
Credits
This code is modified from https://github.com/GitYCC/g2pW
"""

import os
import re

@@ -24,14 +25,14 @@ def wordize_and_map(text: str):
index_map_from_text_to_word = []
index_map_from_word_to_text = []
while len(text) > 0:
match_space = re.match(r'^ +', text)
match_space = re.match(r"^ +", text)
if match_space:
space_str = match_space.group(0)
index_map_from_text_to_word += [None] * len(space_str)
text = text[len(space_str):]
text = text[len(space_str) :]
continue

match_en = re.match(r'^[a-zA-Z0-9]+', text)
match_en = re.match(r"^[a-zA-Z0-9]+", text)
if match_en:
en_word = match_en.group(0)

@@ -42,7 +43,7 @@ def wordize_and_map(text: str):
index_map_from_text_to_word += [len(words)] * len(en_word)

words.append(en_word)
text = text[len(en_word):]
text = text[len(en_word) :]
else:
word_start_pos = len(index_map_from_text_to_word)
word_end_pos = word_start_pos + 1
@@ -63,15 +64,14 @@ def tokenize_and_map(tokenizer, text: str):
for word, (word_start, word_end) in zip(words, word2text):
word_tokens = tokenizer.tokenize(word)

if len(word_tokens) == 0 or word_tokens == ['[UNK]']:
if len(word_tokens) == 0 or word_tokens == ["[UNK]"]:
index_map_from_token_to_text.append((word_start, word_end))
tokens.append('[UNK]')
tokens.append("[UNK]")
else:
current_word_start = word_start
for word_token in word_tokens:
word_token_len = len(re.sub(r'^##', '', word_token))
index_map_from_token_to_text.append(
(current_word_start, current_word_start + word_token_len))
word_token_len = len(re.sub(r"^##", "", word_token))
index_map_from_token_to_text.append((current_word_start, current_word_start + word_token_len))
current_word_start = current_word_start + word_token_len
tokens.append(word_token)

@@ -85,53 +85,51 @@ def tokenize_and_map(tokenizer, text: str):

def _load_config(config_path: os.PathLike):
import importlib.util
spec = importlib.util.spec_from_file_location('__init__', config_path)

spec = importlib.util.spec_from_file_location("__init__", config_path)
config = importlib.util.module_from_spec(spec)
spec.loader.exec_module(config)
return config


default_config_dict = {
'manual_seed': 1313,
'model_source': 'bert-base-chinese',
'window_size': 32,
'num_workers': 2,
'use_mask': True,
'use_char_phoneme': False,
'use_conditional': True,
'param_conditional': {
'affect_location': 'softmax',
'bias': True,
'char-linear': True,
'pos-linear': False,
'char+pos-second': True,
'char+pos-second_lowrank': False,
'lowrank_size': 0,
'char+pos-second_fm': False,
'fm_size': 0,
'fix_mode': None,
'count_json': 'train.count.json'
"manual_seed": 1313,
"model_source": "bert-base-chinese",
"window_size": 32,
"num_workers": 2,
"use_mask": True,
"use_char_phoneme": False,
"use_conditional": True,
"param_conditional": {
"affect_location": "softmax",
"bias": True,
"char-linear": True,
"pos-linear": False,
"char+pos-second": True,
"char+pos-second_lowrank": False,
"lowrank_size": 0,
"char+pos-second_fm": False,
"fm_size": 0,
"fix_mode": None,
"count_json": "train.count.json",
},
'lr': 5e-5,
'val_interval': 200,
'num_iter': 10000,
'use_focal': False,
'param_focal': {
'alpha': 0.0,
'gamma': 0.7
"lr": 5e-5,
"val_interval": 200,
"num_iter": 10000,
"use_focal": False,
"param_focal": {"alpha": 0.0, "gamma": 0.7},
"use_pos": True,
"param_pos ": {
"weight": 0.1,
"pos_joint_training": True,
"train_pos_path": "train.pos",
"valid_pos_path": "dev.pos",
"test_pos_path": "test.pos",
},
'use_pos': True,
'param_pos ': {
'weight': 0.1,
'pos_joint_training': True,
'train_pos_path': 'train.pos',
'valid_pos_path': 'dev.pos',
'test_pos_path': 'test.pos'
}
}


def load_config(config_path: os.PathLike, use_default: bool=False):
def load_config(config_path: os.PathLike, use_default: bool = False):
config = _load_config(config_path)
if use_default:
for attr, val in default_config_dict.items():