Refactor: Format Code with Ruff and Update Deprecated G2PW Link (#2255)

* ruff check --fix

* ruff format --line-length 120 --target-version py39

* Change the link for G2PW Model

* update pytorch version and colab
XXXXRT666 authored on 2025-04-07 09:42:47 +01:00; committed by GitHub
parent 9da7e17efe
commit 53cac93589
132 changed files with 8185 additions and 6648 deletions
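The formatting pass above can be reproduced locally; a minimal sketch, assuming ruff is installed and the commands are run from the repository root:

# Sketch only: re-runs the same two ruff commands named in the commit message.
import subprocess

subprocess.run(["ruff", "check", "--fix", "."], check=True)
subprocess.run(["ruff", "format", "--line-length", "120", "--target-version", "py39", "."], check=True)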

File: text/g2pw/__init__.py

@@ -1 +1 @@
-from text.g2pw.g2pw import *
+from text.g2pw.g2pw import *

File: text/g2pw/dataset.py

@@ -15,6 +15,7 @@
 Credits
     This code is modified from https://github.com/GitYCC/g2pW
 """
+
 from typing import Dict
 from typing import List
 from typing import Tuple
@@ -23,21 +24,24 @@ import numpy as np
 from .utils import tokenize_and_map

-ANCHOR_CHAR = ''
+ANCHOR_CHAR = ""

-def prepare_onnx_input(tokenizer,
-                       labels: List[str],
-                       char2phonemes: Dict[str, List[int]],
-                       chars: List[str],
-                       texts: List[str],
-                       query_ids: List[int],
-                       use_mask: bool=False,
-                       window_size: int=None,
-                       max_len: int=512) -> Dict[str, np.array]:
+def prepare_onnx_input(
+    tokenizer,
+    labels: List[str],
+    char2phonemes: Dict[str, List[int]],
+    chars: List[str],
+    texts: List[str],
+    query_ids: List[int],
+    use_mask: bool = False,
+    window_size: int = None,
+    max_len: int = 512,
+) -> Dict[str, np.array]:
     if window_size is not None:
         truncated_texts, truncated_query_ids = _truncate_texts(
-            window_size=window_size, texts=texts, query_ids=query_ids)
+            window_size=window_size, texts=texts, query_ids=query_ids
+        )
     input_ids = []
     token_type_ids = []
     attention_masks = []
@@ -50,33 +54,27 @@ def prepare_onnx_input(tokenizer,
         query_id = (truncated_query_ids if window_size else query_ids)[idx]
         try:
-            tokens, text2token, token2text = tokenize_and_map(
-                tokenizer=tokenizer, text=text)
+            tokens, text2token, token2text = tokenize_and_map(tokenizer=tokenizer, text=text)
         except Exception:
             print(f'warning: text "{text}" is invalid')
             return {}

         text, query_id, tokens, text2token, token2text = _truncate(
-            max_len=max_len,
-            text=text,
-            query_id=query_id,
-            tokens=tokens,
-            text2token=text2token,
-            token2text=token2text)
+            max_len=max_len, text=text, query_id=query_id, tokens=tokens, text2token=text2token, token2text=token2text
+        )

-        processed_tokens = ['[CLS]'] + tokens + ['[SEP]']
-        input_id = list(
-            np.array(tokenizer.convert_tokens_to_ids(processed_tokens)))
-        token_type_id = list(np.zeros((len(processed_tokens), ), dtype=int))
-        attention_mask = list(np.ones((len(processed_tokens), ), dtype=int))
+        processed_tokens = ["[CLS]"] + tokens + ["[SEP]"]
+        input_id = list(np.array(tokenizer.convert_tokens_to_ids(processed_tokens)))
+        token_type_id = list(np.zeros((len(processed_tokens),), dtype=int))
+        attention_mask = list(np.ones((len(processed_tokens),), dtype=int))

         query_char = text[query_id]
-        phoneme_mask = [1 if i in char2phonemes[query_char] else 0 for i in range(len(labels))] \
-            if use_mask else [1] * len(labels)
+        phoneme_mask = (
+            [1 if i in char2phonemes[query_char] else 0 for i in range(len(labels))] if use_mask else [1] * len(labels)
+        )
         char_id = chars.index(query_char)
-        position_id = text2token[
-            query_id] + 1  # [CLS] token locate at first place
+        position_id = text2token[query_id] + 1  # [CLS] token locate at first place

         input_ids.append(input_id)
         token_type_ids.append(token_type_id)
@@ -86,18 +84,17 @@ def prepare_onnx_input(tokenizer,
         position_ids.append(position_id)

     outputs = {
-        'input_ids': np.array(input_ids).astype(np.int64),
-        'token_type_ids': np.array(token_type_ids).astype(np.int64),
-        'attention_masks': np.array(attention_masks).astype(np.int64),
-        'phoneme_masks': np.array(phoneme_masks).astype(np.float32),
-        'char_ids': np.array(char_ids).astype(np.int64),
-        'position_ids': np.array(position_ids).astype(np.int64),
+        "input_ids": np.array(input_ids).astype(np.int64),
+        "token_type_ids": np.array(token_type_ids).astype(np.int64),
+        "attention_masks": np.array(attention_masks).astype(np.int64),
+        "phoneme_masks": np.array(phoneme_masks).astype(np.float32),
+        "char_ids": np.array(char_ids).astype(np.int64),
+        "position_ids": np.array(position_ids).astype(np.int64),
     }
     return outputs

-def _truncate_texts(window_size: int, texts: List[str],
-                    query_ids: List[int]) -> Tuple[List[str], List[int]]:
+def _truncate_texts(window_size: int, texts: List[str], query_ids: List[int]) -> Tuple[List[str], List[int]]:
     truncated_texts = []
     truncated_query_ids = []
     for text, query_id in zip(texts, query_ids):
@@ -111,12 +108,9 @@ def _truncate_texts(window_size: int, texts: List[str],
     return truncated_texts, truncated_query_ids

-def _truncate(max_len: int,
-              text: str,
-              query_id: int,
-              tokens: List[str],
-              text2token: List[int],
-              token2text: List[Tuple[int]]):
+def _truncate(
+    max_len: int, text: str, query_id: int, tokens: List[str], text2token: List[int], token2text: List[Tuple[int]]
+):
     truncate_len = max_len - 2
     if len(tokens) <= truncate_len:
         return (text, query_id, tokens, text2token, token2text)
@@ -137,14 +131,16 @@ def _truncate(max_len: int,
     start = token2text[token_start][0]
     end = token2text[token_end - 1][1]

-    return (text[start:end], query_id - start, tokens[token_start:token_end], [
-        i - token_start if i is not None else None
-        for i in text2token[start:end]
-    ], [(s - start, e - start) for s, e in token2text[token_start:token_end]])
+    return (
+        text[start:end],
+        query_id - start,
+        tokens[token_start:token_end],
+        [i - token_start if i is not None else None for i in text2token[start:end]],
+        [(s - start, e - start) for s, e in token2text[token_start:token_end]],
+    )

-def get_phoneme_labels(polyphonic_chars: List[List[str]]
-                       ) -> Tuple[List[str], Dict[str, List[int]]]:
+def get_phoneme_labels(polyphonic_chars: List[List[str]]) -> Tuple[List[str], Dict[str, List[int]]]:
     labels = sorted(list(set([phoneme for char, phoneme in polyphonic_chars])))
     char2phonemes = {}
     for char, phoneme in polyphonic_chars:
@@ -154,13 +150,11 @@ def get_phoneme_labels(polyphonic_chars: List[List[str]]
     return labels, char2phonemes

-def get_char_phoneme_labels(polyphonic_chars: List[List[str]]
-                            ) -> Tuple[List[str], Dict[str, List[int]]]:
-    labels = sorted(
-        list(set([f'{char} {phoneme}' for char, phoneme in polyphonic_chars])))
+def get_char_phoneme_labels(polyphonic_chars: List[List[str]]) -> Tuple[List[str], Dict[str, List[int]]]:
+    labels = sorted(list(set([f"{char} {phoneme}" for char, phoneme in polyphonic_chars])))
     char2phonemes = {}
     for char, phoneme in polyphonic_chars:
         if char not in char2phonemes:
             char2phonemes[char] = []
-        char2phonemes[char].append(labels.index(f'{char} {phoneme}'))
+        char2phonemes[char].append(labels.index(f"{char} {phoneme}"))
     return labels, char2phonemes
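Not part of the diff, but a rough sketch of how the reformatted helpers above fit together; the tokenizer name and the toy dictionary entries are assumptions, not taken from this commit:

# Hypothetical usage sketch for dataset.py.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")  # assumed model source
polyphonic_chars = [["行", "hang2"], ["行", "xing2"]]  # toy polyphonic entries
labels, char2phonemes = get_phoneme_labels(polyphonic_chars=polyphonic_chars)
onnx_input = prepare_onnx_input(
    tokenizer=tokenizer,
    labels=labels,
    char2phonemes=char2phonemes,
    chars=sorted(char2phonemes.keys()),
    texts=["银行"],
    query_ids=[1],  # index of the polyphonic character within the text
)
# onnx_input maps "input_ids", "token_type_ids", "attention_masks", "phoneme_masks",
# "char_ids" and "position_ids" to numpy arrays ready to feed onnxruntime.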

File: text/g2pw/g2pw.py

@@ -17,17 +17,25 @@ PP_FIX_DICT_PATH = os.path.join(current_file_path, "polyphonic-fix.rep")
 class G2PWPinyin(Pinyin):
-    def __init__(self, model_dir='G2PWModel/', model_source=None,
-                 enable_non_tradional_chinese=True,
-                 v_to_u=False, neutral_tone_with_five=False, tone_sandhi=False, **kwargs):
+    def __init__(
+        self,
+        model_dir="G2PWModel/",
+        model_source=None,
+        enable_non_tradional_chinese=True,
+        v_to_u=False,
+        neutral_tone_with_five=False,
+        tone_sandhi=False,
+        **kwargs,
+    ):
         self._g2pw = G2PWOnnxConverter(
             model_dir=model_dir,
-            style='pinyin',
+            style="pinyin",
             model_source=model_source,
             enable_non_tradional_chinese=enable_non_tradional_chinese,
         )
         self._converter = Converter(
-            self._g2pw, v_to_u=v_to_u,
+            self._g2pw,
+            v_to_u=v_to_u,
             neutral_tone_with_five=neutral_tone_with_five,
             tone_sandhi=tone_sandhi,
         )
@@ -37,31 +45,25 @@ class G2PWPinyin(Pinyin):
 class Converter(UltimateConverter):
-    def __init__(self, g2pw_instance, v_to_u=False,
-                 neutral_tone_with_five=False,
-                 tone_sandhi=False, **kwargs):
+    def __init__(self, g2pw_instance, v_to_u=False, neutral_tone_with_five=False, tone_sandhi=False, **kwargs):
         super(Converter, self).__init__(
-            v_to_u=v_to_u,
-            neutral_tone_with_five=neutral_tone_with_five,
-            tone_sandhi=tone_sandhi, **kwargs)
+            v_to_u=v_to_u, neutral_tone_with_five=neutral_tone_with_five, tone_sandhi=tone_sandhi, **kwargs
+        )

         self._g2pw = g2pw_instance

     def convert(self, words, style, heteronym, errors, strict, **kwargs):
         pys = []
         if RE_HANS.match(words):
-            pys = self._to_pinyin(words, style=style, heteronym=heteronym,
-                                  errors=errors, strict=strict)
+            pys = self._to_pinyin(words, style=style, heteronym=heteronym, errors=errors, strict=strict)
             post_data = self.post_pinyin(words, heteronym, pys)
             if post_data is not None:
                 pys = post_data

-            pys = self.convert_styles(
-                pys, words, style, heteronym, errors, strict)
+            pys = self.convert_styles(pys, words, style, heteronym, errors, strict)

         else:
-            py = self.handle_nopinyin(words, style=style, errors=errors,
-                                      heteronym=heteronym, strict=strict)
+            py = self.handle_nopinyin(words, style=style, errors=errors, heteronym=heteronym, strict=strict)
             if py:
                 pys.extend(py)
@@ -73,13 +75,11 @@ class Converter(UltimateConverter):
             g2pw_pinyin = self._g2pw(han)

             if not g2pw_pinyin:  # fall back to pypinyin's original logic for characters g2pw does not support
-                return super(Converter, self).convert(
-                    han, Style.TONE, heteronym, errors, strict, **kwargs)
+                return super(Converter, self).convert(han, Style.TONE, heteronym, errors, strict, **kwargs)

             for i, item in enumerate(g2pw_pinyin[0]):
                 if item is None:  # fall back to pypinyin's original logic for characters g2pw does not support
-                    py = super(Converter, self).convert(
-                        han[i], Style.TONE, heteronym, errors, strict, **kwargs)
+                    py = super(Converter, self).convert(han[i], Style.TONE, heteronym, errors, strict, **kwargs)
                     pinyins.extend(py)
                 else:
                     pinyins.append([to_tone(item)])
@@ -104,7 +104,7 @@ def _remove_dup_and_empty(lst_list):
             if lst:
                 new_lst_list.append(lst)
             else:
-                new_lst_list.append([''])
+                new_lst_list.append([""])

     return new_lst_list
@@ -127,17 +127,17 @@ def get_dict():
 def read_dict():
     polyphonic_dict = {}
-    with open(PP_DICT_PATH,encoding="utf-8") as f:
+    with open(PP_DICT_PATH, encoding="utf-8") as f:
         line = f.readline()
         while line:
-            key, value_str = line.split(':')
+            key, value_str = line.split(":")
             value = eval(value_str.strip())
             polyphonic_dict[key.strip()] = value
             line = f.readline()
-    with open(PP_FIX_DICT_PATH,encoding="utf-8") as f:
+    with open(PP_FIX_DICT_PATH, encoding="utf-8") as f:
         line = f.readline()
         while line:
-            key, value_str = line.split(':')
+            key, value_str = line.split(":")
             value = eval(value_str.strip())
             polyphonic_dict[key.strip()] = value
             line = f.readline()
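A minimal usage sketch for the class reformatted above; the model paths, sample input, and printed output are illustrative assumptions, not taken from this commit:

# Hypothetical usage sketch: G2PWPinyin is a pypinyin.Pinyin subclass backed by g2pW.
from pypinyin import Style

g2pw = G2PWPinyin(
    model_dir="G2PWModel/",  # fetched automatically on first use
    model_source="bert-base-chinese",  # assumed tokenizer/BERT source
    neutral_tone_with_five=True,
)
print(g2pw.lazy_pinyin("银行", style=Style.TONE3))  # e.g. ['yin2', 'hang2']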

File: text/g2pw/onnx_api.py

@@ -2,44 +2,43 @@
 # This code is modified from https://github.com/GitYCC/g2pW

 import warnings

 warnings.filterwarnings("ignore")
 import json
 import os
-import zipfile,requests
-from typing import Any
-from typing import Dict
-from typing import List
-from typing import Tuple
+import zipfile
+from typing import Any, Dict, List, Tuple

 import numpy as np
 import onnxruntime
+import requests

 onnxruntime.set_default_logger_severity(3)
 from opencc import OpenCC
-from pypinyin import pinyin
-from pypinyin import Style
+from pypinyin import Style, pinyin
 from transformers import AutoTokenizer

-from .dataset import get_char_phoneme_labels
-from .dataset import get_phoneme_labels
-from .dataset import prepare_onnx_input
-from .utils import load_config
 from ..zh_normalization.char_convert import tranditional_to_simplified
+from .dataset import get_char_phoneme_labels, get_phoneme_labels, prepare_onnx_input
+from .utils import load_config

-model_version = '1.1'
+model_version = "1.1"

-def predict(session, onnx_input: Dict[str, Any],
-            labels: List[str]) -> Tuple[List[str], List[float]]:
+def predict(session, onnx_input: Dict[str, Any], labels: List[str]) -> Tuple[List[str], List[float]]:
     all_preds = []
     all_confidences = []
-    probs = session.run([], {
-        "input_ids": onnx_input['input_ids'],
-        "token_type_ids": onnx_input['token_type_ids'],
-        "attention_mask": onnx_input['attention_masks'],
-        "phoneme_mask": onnx_input['phoneme_masks'],
-        "char_ids": onnx_input['char_ids'],
-        "position_ids": onnx_input['position_ids']
-    })[0]
+    probs = session.run(
+        [],
+        {
+            "input_ids": onnx_input["input_ids"],
+            "token_type_ids": onnx_input["token_type_ids"],
+            "attention_mask": onnx_input["attention_masks"],
+            "phoneme_mask": onnx_input["phoneme_masks"],
+            "char_ids": onnx_input["char_ids"],
+            "position_ids": onnx_input["position_ids"],
+        },
+    )[0]

     preds = np.argmax(probs, axis=1).tolist()
     max_probs = []
@@ -51,17 +50,17 @@ def predict(session, onnx_input: Dict[str, Any],
     return all_preds, all_confidences

-def download_and_decompress(model_dir: str='G2PWModel/'):
+def download_and_decompress(model_dir: str = "G2PWModel/"):
     if not os.path.exists(model_dir):
         parent_directory = os.path.dirname(model_dir)
-        zip_dir = os.path.join(parent_directory,"G2PWModel_1.1.zip")
-        extract_dir = os.path.join(parent_directory,"G2PWModel_1.1")
-        extract_dir_new = os.path.join(parent_directory,"G2PWModel")
+        zip_dir = os.path.join(parent_directory, "G2PWModel_1.1.zip")
+        extract_dir = os.path.join(parent_directory, "G2PWModel_1.1")
+        extract_dir_new = os.path.join(parent_directory, "G2PWModel")
         print("Downloading g2pw model...")
-        modelscope_url = "https://www.modelscope.cn/models/kamiorinn/g2pw/resolve/master/G2PWModel_1.1.zip"#"https://paddlespeech.cdn.bcebos.com/Parakeet/released_models/g2p/G2PWModel_1.1.zip"
+        modelscope_url = "https://www.modelscope.cn/models/kamiorinn/g2pw/resolve/master/G2PWModel_1.1.zip"  # "https://paddlespeech.cdn.bcebos.com/Parakeet/released_models/g2p/G2PWModel_1.1.zip"
         with requests.get(modelscope_url, stream=True) as r:
             r.raise_for_status()
-            with open(zip_dir, 'wb') as f:
+            with open(zip_dir, "wb") as f:
                 for chunk in r.iter_content(chunk_size=8192):
                     if chunk:
                         f.write(chunk)
@@ -69,17 +68,20 @@ def download_and_decompress(model_dir: str='G2PWModel/'):
         print("Extracting g2pw model...")
         with zipfile.ZipFile(zip_dir, "r") as zip_ref:
             zip_ref.extractall(parent_directory)

         os.rename(extract_dir, extract_dir_new)

     return model_dir

 class G2PWOnnxConverter:
-    def __init__(self,
-                 model_dir: str='G2PWModel/',
-                 style: str='bopomofo',
-                 model_source: str=None,
-                 enable_non_tradional_chinese: bool=False):
+    def __init__(
+        self,
+        model_dir: str = "G2PWModel/",
+        style: str = "bopomofo",
+        model_source: str = None,
+        enable_non_tradional_chinese: bool = False,
+    ):
         uncompress_path = download_and_decompress(model_dir)

         sess_options = onnxruntime.SessionOptions()
@@ -87,41 +89,59 @@ class G2PWOnnxConverter:
         sess_options.execution_mode = onnxruntime.ExecutionMode.ORT_SEQUENTIAL
         sess_options.intra_op_num_threads = 2
         try:
-            self.session_g2pW = onnxruntime.InferenceSession(os.path.join(uncompress_path, 'g2pW.onnx'),sess_options=sess_options, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
+            self.session_g2pW = onnxruntime.InferenceSession(
+                os.path.join(uncompress_path, "g2pW.onnx"),
+                sess_options=sess_options,
+                providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
+            )
         except:
-            self.session_g2pW = onnxruntime.InferenceSession(os.path.join(uncompress_path, 'g2pW.onnx'),sess_options=sess_options, providers=['CPUExecutionProvider'])
-        self.config = load_config(
-            config_path=os.path.join(uncompress_path, 'config.py'),
-            use_default=True)
+            self.session_g2pW = onnxruntime.InferenceSession(
+                os.path.join(uncompress_path, "g2pW.onnx"),
+                sess_options=sess_options,
+                providers=["CPUExecutionProvider"],
+            )
+        self.config = load_config(config_path=os.path.join(uncompress_path, "config.py"), use_default=True)
         self.model_source = model_source if model_source else self.config.model_source
         self.enable_opencc = enable_non_tradional_chinese

         self.tokenizer = AutoTokenizer.from_pretrained(self.model_source)

-        polyphonic_chars_path = os.path.join(uncompress_path,
-                                             'POLYPHONIC_CHARS.txt')
-        monophonic_chars_path = os.path.join(uncompress_path,
-                                             'MONOPHONIC_CHARS.txt')
+        polyphonic_chars_path = os.path.join(uncompress_path, "POLYPHONIC_CHARS.txt")
+        monophonic_chars_path = os.path.join(uncompress_path, "MONOPHONIC_CHARS.txt")
         self.polyphonic_chars = [
-            line.split('\t')
-            for line in open(polyphonic_chars_path, encoding='utf-8').read()
-            .strip().split('\n')
+            line.split("\t") for line in open(polyphonic_chars_path, encoding="utf-8").read().strip().split("\n")
         ]
         self.non_polyphonic = {
-            '一', '不', '和', '咋', '嗲', '剖', '差', '攢', '倒', '難', '奔', '勁', '拗',
-            '肖', '瘙', '誒', '泊', '听', '噢'
+            "一",
+            "不",
+            "和",
+            "咋",
+            "嗲",
+            "剖",
+            "差",
+            "攢",
+            "倒",
+            "難",
+            "奔",
+            "勁",
+            "拗",
+            "肖",
+            "瘙",
+            "誒",
+            "泊",
+            "听",
+            "噢",
         }
-        self.non_monophonic = {'似', '攢'}
+        self.non_monophonic = {"似", "攢"}
         self.monophonic_chars = [
-            line.split('\t')
-            for line in open(monophonic_chars_path, encoding='utf-8').read()
-            .strip().split('\n')
+            line.split("\t") for line in open(monophonic_chars_path, encoding="utf-8").read().strip().split("\n")
         ]
-        self.labels, self.char2phonemes = get_char_phoneme_labels(
-            polyphonic_chars=self.polyphonic_chars
-        ) if self.config.use_char_phoneme else get_phoneme_labels(
-            polyphonic_chars=self.polyphonic_chars)
+        self.labels, self.char2phonemes = (
+            get_char_phoneme_labels(polyphonic_chars=self.polyphonic_chars)
+            if self.config.use_char_phoneme
+            else get_phoneme_labels(polyphonic_chars=self.polyphonic_chars)
+        )

         self.chars = sorted(list(self.char2phonemes.keys()))
@@ -130,41 +150,29 @@ class G2PWOnnxConverter:
             if char in self.polyphonic_chars_new:
                 self.polyphonic_chars_new.remove(char)

-        self.monophonic_chars_dict = {
-            char: phoneme
-            for char, phoneme in self.monophonic_chars
-        }
+        self.monophonic_chars_dict = {char: phoneme for char, phoneme in self.monophonic_chars}
         for char in self.non_monophonic:
             if char in self.monophonic_chars_dict:
                 self.monophonic_chars_dict.pop(char)

-        self.pos_tags = [
-            'UNK', 'A', 'C', 'D', 'I', 'N', 'P', 'T', 'V', 'DE', 'SHI'
-        ]
+        self.pos_tags = ["UNK", "A", "C", "D", "I", "N", "P", "T", "V", "DE", "SHI"]

-        with open(
-                os.path.join(uncompress_path,
-                             'bopomofo_to_pinyin_wo_tune_dict.json'),
-                'r',
-                encoding='utf-8') as fr:
+        with open(os.path.join(uncompress_path, "bopomofo_to_pinyin_wo_tune_dict.json"), "r", encoding="utf-8") as fr:
             self.bopomofo_convert_dict = json.load(fr)
         self.style_convert_func = {
-            'bopomofo': lambda x: x,
-            'pinyin': self._convert_bopomofo_to_pinyin,
+            "bopomofo": lambda x: x,
+            "pinyin": self._convert_bopomofo_to_pinyin,
         }[style]

-        with open(
-                os.path.join(uncompress_path, 'char_bopomofo_dict.json'),
-                'r',
-                encoding='utf-8') as fr:
+        with open(os.path.join(uncompress_path, "char_bopomofo_dict.json"), "r", encoding="utf-8") as fr:
             self.char_bopomofo_dict = json.load(fr)

         if self.enable_opencc:
-            self.cc = OpenCC('s2tw')
+            self.cc = OpenCC("s2tw")

     def _convert_bopomofo_to_pinyin(self, bopomofo: str) -> str:
         tone = bopomofo[-1]
-        assert tone in '12345'
+        assert tone in "12345"
         component = self.bopomofo_convert_dict.get(bopomofo[:-1])
         if component:
             return component + tone
@@ -184,8 +192,7 @@ class G2PWOnnxConverter:
             translated_sentences.append(translated_sent)
         sentences = translated_sentences

-        texts, query_ids, sent_ids, partial_results = self._prepare_data(
-            sentences=sentences)
+        texts, query_ids, sent_ids, partial_results = self._prepare_data(sentences=sentences)
         if len(texts) == 0:
             # sentences no polyphonic words
             return partial_results
@@ -198,14 +205,12 @@ class G2PWOnnxConverter:
             texts=texts,
             query_ids=query_ids,
             use_mask=self.config.use_mask,
-            window_size=None)
+            window_size=None,
+        )

-        preds, confidences = predict(
-            session=self.session_g2pW,
-            onnx_input=onnx_input,
-            labels=self.labels)
+        preds, confidences = predict(session=self.session_g2pW, onnx_input=onnx_input, labels=self.labels)
         if self.config.use_char_phoneme:
-            preds = [pred.split(' ')[1] for pred in preds]
+            preds = [pred.split(" ")[1] for pred in preds]

         results = partial_results
         for sent_id, query_id, pred in zip(sent_ids, query_ids, preds):
@@ -213,15 +218,12 @@ class G2PWOnnxConverter:
         return results

-    def _prepare_data(
-            self, sentences: List[str]
-    ) -> Tuple[List[str], List[int], List[int], List[List[str]]]:
+    def _prepare_data(self, sentences: List[str]) -> Tuple[List[str], List[int], List[int], List[List[str]]]:
         texts, query_ids, sent_ids, partial_results = [], [], [], []
         for sent_id, sent in enumerate(sentences):
             # pypinyin works well for Simplified Chinese than Traditional Chinese
             sent_s = tranditional_to_simplified(sent)
-            pypinyin_result = pinyin(
-                sent_s, neutral_tone_with_five=True, style=Style.TONE3)
+            pypinyin_result = pinyin(sent_s, neutral_tone_with_five=True, style=Style.TONE3)
             partial_result = [None] * len(sent)
             for i, char in enumerate(sent):
                 if char in self.polyphonic_chars_new:
@@ -229,8 +231,7 @@ class G2PWOnnxConverter:
                     query_ids.append(i)
                     sent_ids.append(sent_id)
                 elif char in self.monophonic_chars_dict:
-                    partial_result[i] = self.style_convert_func(
-                        self.monophonic_chars_dict[char])
+                    partial_result[i] = self.style_convert_func(self.monophonic_chars_dict[char])
                 elif char in self.char_bopomofo_dict:
                     partial_result[i] = pypinyin_result[i][0]
                     # partial_result[i] = self.style_convert_func(self.char_bopomofo_dict[char][0])
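For context, a rough end-to-end sketch of the converter defined in this file; the sample sentence and the printed output are illustrative assumptions:

# Hypothetical usage sketch, not part of this commit.
converter = G2PWOnnxConverter(style="pinyin", enable_non_tradional_chinese=True)
# Calling the converter returns one list of readings per input sentence:
# g2pW predictions for polyphonic characters, dictionary lookups for the rest.
print(converter(["银行"]))  # e.g. [['yin2', 'hang2']]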

File: text/g2pw/utils.py

@@ -15,6 +15,7 @@
 Credits
     This code is modified from https://github.com/GitYCC/g2pW
 """
+
 import os
 import re
@@ -24,14 +25,14 @@ def wordize_and_map(text: str):
     index_map_from_text_to_word = []
     index_map_from_word_to_text = []
     while len(text) > 0:
-        match_space = re.match(r'^ +', text)
+        match_space = re.match(r"^ +", text)
         if match_space:
             space_str = match_space.group(0)
             index_map_from_text_to_word += [None] * len(space_str)
-            text = text[len(space_str):]
+            text = text[len(space_str) :]
             continue

-        match_en = re.match(r'^[a-zA-Z0-9]+', text)
+        match_en = re.match(r"^[a-zA-Z0-9]+", text)
         if match_en:
             en_word = match_en.group(0)
@@ -42,7 +43,7 @@
             index_map_from_text_to_word += [len(words)] * len(en_word)

             words.append(en_word)
-            text = text[len(en_word):]
+            text = text[len(en_word) :]
         else:
             word_start_pos = len(index_map_from_text_to_word)
             word_end_pos = word_start_pos + 1
@@ -63,15 +64,14 @@ def tokenize_and_map(tokenizer, text: str):
     for word, (word_start, word_end) in zip(words, word2text):
         word_tokens = tokenizer.tokenize(word)

-        if len(word_tokens) == 0 or word_tokens == ['[UNK]']:
+        if len(word_tokens) == 0 or word_tokens == ["[UNK]"]:
             index_map_from_token_to_text.append((word_start, word_end))
-            tokens.append('[UNK]')
+            tokens.append("[UNK]")
         else:
             current_word_start = word_start
             for word_token in word_tokens:
-                word_token_len = len(re.sub(r'^##', '', word_token))
-                index_map_from_token_to_text.append(
-                    (current_word_start, current_word_start + word_token_len))
+                word_token_len = len(re.sub(r"^##", "", word_token))
+                index_map_from_token_to_text.append((current_word_start, current_word_start + word_token_len))
                 current_word_start = current_word_start + word_token_len
                 tokens.append(word_token)
@@ -85,53 +85,51 @@
 def _load_config(config_path: os.PathLike):
     import importlib.util

-    spec = importlib.util.spec_from_file_location('__init__', config_path)
+    spec = importlib.util.spec_from_file_location("__init__", config_path)
     config = importlib.util.module_from_spec(spec)
     spec.loader.exec_module(config)
     return config

 default_config_dict = {
-    'manual_seed': 1313,
-    'model_source': 'bert-base-chinese',
-    'window_size': 32,
-    'num_workers': 2,
-    'use_mask': True,
-    'use_char_phoneme': False,
-    'use_conditional': True,
-    'param_conditional': {
-        'affect_location': 'softmax',
-        'bias': True,
-        'char-linear': True,
-        'pos-linear': False,
-        'char+pos-second': True,
-        'char+pos-second_lowrank': False,
-        'lowrank_size': 0,
-        'char+pos-second_fm': False,
-        'fm_size': 0,
-        'fix_mode': None,
-        'count_json': 'train.count.json'
-    },
-    'lr': 5e-5,
-    'val_interval': 200,
-    'num_iter': 10000,
-    'use_focal': False,
-    'param_focal': {
-        'alpha': 0.0,
-        'gamma': 0.7
-    },
-    'use_pos': True,
-    'param_pos ': {
-        'weight': 0.1,
-        'pos_joint_training': True,
-        'train_pos_path': 'train.pos',
-        'valid_pos_path': 'dev.pos',
-        'test_pos_path': 'test.pos'
-    }
+    "manual_seed": 1313,
+    "model_source": "bert-base-chinese",
+    "window_size": 32,
+    "num_workers": 2,
+    "use_mask": True,
+    "use_char_phoneme": False,
+    "use_conditional": True,
+    "param_conditional": {
+        "affect_location": "softmax",
+        "bias": True,
+        "char-linear": True,
+        "pos-linear": False,
+        "char+pos-second": True,
+        "char+pos-second_lowrank": False,
+        "lowrank_size": 0,
+        "char+pos-second_fm": False,
+        "fm_size": 0,
+        "fix_mode": None,
+        "count_json": "train.count.json",
+    },
+    "lr": 5e-5,
+    "val_interval": 200,
+    "num_iter": 10000,
+    "use_focal": False,
+    "param_focal": {"alpha": 0.0, "gamma": 0.7},
+    "use_pos": True,
+    "param_pos ": {
+        "weight": 0.1,
+        "pos_joint_training": True,
+        "train_pos_path": "train.pos",
+        "valid_pos_path": "dev.pos",
+        "test_pos_path": "test.pos",
+    },
 }

-def load_config(config_path: os.PathLike, use_default: bool=False):
+def load_config(config_path: os.PathLike, use_default: bool = False):
     config = _load_config(config_path)
     if use_default:
         for attr, val in default_config_dict.items():
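The diff is truncated here, mid-loop. As a hedged sketch of the default-merge pattern the loop implements (reconstructed under that assumption, not the verbatim continuation; the function name is hypothetical):

# Assumed behavior sketch: attributes missing from config.py fall back to the defaults.
def load_config_with_defaults(config_path, use_default=False):
    config = _load_config(config_path)
    if use_default:
        for attr, val in default_config_dict.items():
            if not hasattr(config, attr):
                setattr(config, attr, val)  # fill in defaults the user config omits
    return config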