Refactor: Format Code with Ruff and Update Deprecated G2PW Link (#2255)

* ruff check --fix

* ruff format --line-length 120 --target-version py39 (equivalent pyproject.toml config sketched below)

* Change the link for G2PW Model

* Update PyTorch version and Colab
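
The two ruff commands above pin the project's formatting style. The same settings could be recorded in pyproject.toml so later runs do not depend on remembering the flags; a minimal sketch using ruff's standard [tool.ruff] table (whether this repository actually tracks such a file is not shown in this commit):

    [tool.ruff]
    line-length = 120
    target-version = "py39"

With that in place, plain "ruff check --fix" and "ruff format" pick up the same line length and Python 3.9 target automatically.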
Author: XXXXRT666
Date: 2025-04-07 09:42:47 +01:00
Committed by: GitHub
Parent: 9da7e17efe
Commit: 53cac93589
132 changed files with 8185 additions and 6648 deletions


@@ -15,6 +15,7 @@
 Credits
     This code is modified from https://github.com/GitYCC/g2pW
 """
+
 from typing import Dict
 from typing import List
 from typing import Tuple
@@ -23,21 +24,24 @@ import numpy as np
 
 from .utils import tokenize_and_map
 
-ANCHOR_CHAR = ''
+ANCHOR_CHAR = ""
 
 
-def prepare_onnx_input(tokenizer,
-                       labels: List[str],
-                       char2phonemes: Dict[str, List[int]],
-                       chars: List[str],
-                       texts: List[str],
-                       query_ids: List[int],
-                       use_mask: bool=False,
-                       window_size: int=None,
-                       max_len: int=512) -> Dict[str, np.array]:
+def prepare_onnx_input(
+    tokenizer,
+    labels: List[str],
+    char2phonemes: Dict[str, List[int]],
+    chars: List[str],
+    texts: List[str],
+    query_ids: List[int],
+    use_mask: bool = False,
+    window_size: int = None,
+    max_len: int = 512,
+) -> Dict[str, np.array]:
     if window_size is not None:
         truncated_texts, truncated_query_ids = _truncate_texts(
-            window_size=window_size, texts=texts, query_ids=query_ids)
+            window_size=window_size, texts=texts, query_ids=query_ids
+        )
     input_ids = []
     token_type_ids = []
     attention_masks = []
@@ -50,33 +54,27 @@ def prepare_onnx_input(tokenizer,
         query_id = (truncated_query_ids if window_size else query_ids)[idx]
 
         try:
-            tokens, text2token, token2text = tokenize_and_map(
-                tokenizer=tokenizer, text=text)
+            tokens, text2token, token2text = tokenize_and_map(tokenizer=tokenizer, text=text)
         except Exception:
             print(f'warning: text "{text}" is invalid')
             return {}
 
         text, query_id, tokens, text2token, token2text = _truncate(
-            max_len=max_len,
-            text=text,
-            query_id=query_id,
-            tokens=tokens,
-            text2token=text2token,
-            token2text=token2text)
+            max_len=max_len, text=text, query_id=query_id, tokens=tokens, text2token=text2token, token2text=token2text
+        )
 
-        processed_tokens = ['[CLS]'] + tokens + ['[SEP]']
+        processed_tokens = ["[CLS]"] + tokens + ["[SEP]"]
-        input_id = list(
-            np.array(tokenizer.convert_tokens_to_ids(processed_tokens)))
-        token_type_id = list(np.zeros((len(processed_tokens), ), dtype=int))
-        attention_mask = list(np.ones((len(processed_tokens), ), dtype=int))
+        input_id = list(np.array(tokenizer.convert_tokens_to_ids(processed_tokens)))
+        token_type_id = list(np.zeros((len(processed_tokens),), dtype=int))
+        attention_mask = list(np.ones((len(processed_tokens),), dtype=int))
 
         query_char = text[query_id]
-        phoneme_mask = [1 if i in char2phonemes[query_char] else 0 for i in range(len(labels))] \
-            if use_mask else [1] * len(labels)
+        phoneme_mask = (
+            [1 if i in char2phonemes[query_char] else 0 for i in range(len(labels))] if use_mask else [1] * len(labels)
+        )
         char_id = chars.index(query_char)
-        position_id = text2token[
-            query_id] + 1  # [CLS] token locate at first place
+        position_id = text2token[query_id] + 1  # [CLS] token locate at first place
 
         input_ids.append(input_id)
         token_type_ids.append(token_type_id)
@@ -86,18 +84,17 @@ def prepare_onnx_input(tokenizer,
         position_ids.append(position_id)
 
     outputs = {
-        'input_ids': np.array(input_ids).astype(np.int64),
-        'token_type_ids': np.array(token_type_ids).astype(np.int64),
-        'attention_masks': np.array(attention_masks).astype(np.int64),
-        'phoneme_masks': np.array(phoneme_masks).astype(np.float32),
-        'char_ids': np.array(char_ids).astype(np.int64),
-        'position_ids': np.array(position_ids).astype(np.int64),
+        "input_ids": np.array(input_ids).astype(np.int64),
+        "token_type_ids": np.array(token_type_ids).astype(np.int64),
+        "attention_masks": np.array(attention_masks).astype(np.int64),
+        "phoneme_masks": np.array(phoneme_masks).astype(np.float32),
+        "char_ids": np.array(char_ids).astype(np.int64),
+        "position_ids": np.array(position_ids).astype(np.int64),
     }
     return outputs
 
 
-def _truncate_texts(window_size: int, texts: List[str],
-                    query_ids: List[int]) -> Tuple[List[str], List[int]]:
+def _truncate_texts(window_size: int, texts: List[str], query_ids: List[int]) -> Tuple[List[str], List[int]]:
     truncated_texts = []
     truncated_query_ids = []
     for text, query_id in zip(texts, query_ids):
@@ -111,12 +108,9 @@ def _truncate_texts(window_size: int, texts: List[str],
     return truncated_texts, truncated_query_ids
 
 
-def _truncate(max_len: int,
-              text: str,
-              query_id: int,
-              tokens: List[str],
-              text2token: List[int],
-              token2text: List[Tuple[int]]):
+def _truncate(
+    max_len: int, text: str, query_id: int, tokens: List[str], text2token: List[int], token2text: List[Tuple[int]]
+):
     truncate_len = max_len - 2
     if len(tokens) <= truncate_len:
         return (text, query_id, tokens, text2token, token2text)
@@ -137,14 +131,16 @@ def _truncate(max_len: int,
     start = token2text[token_start][0]
     end = token2text[token_end - 1][1]
 
-    return (text[start:end], query_id - start, tokens[token_start:token_end], [
-        i - token_start if i is not None else None
-        for i in text2token[start:end]
-    ], [(s - start, e - start) for s, e in token2text[token_start:token_end]])
+    return (
+        text[start:end],
+        query_id - start,
+        tokens[token_start:token_end],
+        [i - token_start if i is not None else None for i in text2token[start:end]],
+        [(s - start, e - start) for s, e in token2text[token_start:token_end]],
+    )
 
 
-def get_phoneme_labels(polyphonic_chars: List[List[str]]
-                       ) -> Tuple[List[str], Dict[str, List[int]]]:
+def get_phoneme_labels(polyphonic_chars: List[List[str]]) -> Tuple[List[str], Dict[str, List[int]]]:
     labels = sorted(list(set([phoneme for char, phoneme in polyphonic_chars])))
     char2phonemes = {}
     for char, phoneme in polyphonic_chars:
@@ -154,13 +150,11 @@ def get_phoneme_labels(polyphonic_chars: List[List[str]]
     return labels, char2phonemes
 
 
-def get_char_phoneme_labels(polyphonic_chars: List[List[str]]
-                            ) -> Tuple[List[str], Dict[str, List[int]]]:
-    labels = sorted(
-        list(set([f'{char} {phoneme}' for char, phoneme in polyphonic_chars])))
+def get_char_phoneme_labels(polyphonic_chars: List[List[str]]) -> Tuple[List[str], Dict[str, List[int]]]:
+    labels = sorted(list(set([f"{char} {phoneme}" for char, phoneme in polyphonic_chars])))
     char2phonemes = {}
     for char, phoneme in polyphonic_chars:
         if char not in char2phonemes:
             char2phonemes[char] = []
-        char2phonemes[char].append(labels.index(f'{char} {phoneme}'))
+        char2phonemes[char].append(labels.index(f"{char} {phoneme}"))
     return labels, char2phonemes
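
For context, the formatted helpers above compose into a single preprocessing step. The following is a minimal usage sketch, not code from this commit: the polyphonic-character pairs and the BERT checkpoint name are placeholder assumptions (g2pW ships its own character table and tokenizer files).

    from transformers import BertTokenizer

    # Hypothetical two-reading entry for one polyphonic character; the real
    # table ships with the G2PW model assets.
    polyphonic_chars = [["长", "zhang3"], ["长", "chang2"]]
    labels, char2phonemes = get_phoneme_labels(polyphonic_chars)
    chars = sorted({char for char, _ in polyphonic_chars})

    tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")  # placeholder checkpoint
    onnx_input = prepare_onnx_input(
        tokenizer=tokenizer,
        labels=labels,
        char2phonemes=char2phonemes,
        chars=chars,
        texts=["这条路很长"],
        query_ids=[4],  # index of the polyphonic "长" in the text
        use_mask=True,
    )
    # onnx_input maps input_ids, token_type_ids, attention_masks, phoneme_masks,
    # char_ids and position_ids to numpy arrays typed for an onnxruntime session.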