Refactor: Format Code with Ruff and Update Deprecated G2PW Link (#2255)

* ruff check --fix

* ruff format --line-length 120 --target-version py39

* Change the link for G2PW Model

* update pytorch version and colab
This commit is contained in:
XXXXRT666
2025-04-07 09:42:47 +01:00
committed by GitHub
parent 9da7e17efe
commit 53cac93589
132 changed files with 8185 additions and 6648 deletions

View File

@@ -1,21 +1,17 @@
# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/data/dataset.py
# reference: https://github.com/lifeiteng/vall-e
import pdb
import sys
# sys.path.append("/data/docker/liujing04/gpt-vits/mq-vits-s1bert_no_bert")
import traceback, os
from typing import Dict
from typing import List
import os
import traceback
from typing import Dict, List
import numpy as np
import pandas as pd
import torch, json
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from transformers import AutoTokenizer
import torch
from torch.utils.data import DataLoader, Dataset
version = os.environ.get('version',None)
version = os.environ.get("version", None)
from text import cleaned_text_to_sequence
@@ -34,9 +30,7 @@ def batch_sequences(sequences: List[np.array], axis: int = 0, pad_value: int = 0
padded_sequences = []
for seq, length in zip(sequences, seq_lengths):
padding = (
[(0, 0)] * axis + [(0, max_length - length)] + [(0, 0)] * (ndim - axis - 1)
)
padding = [(0, 0)] * axis + [(0, max_length - length)] + [(0, 0)] * (ndim - axis - 1)
padded_seq = np.pad(seq, padding, mode="constant", constant_values=pad_value)
padded_sequences.append(padded_seq)
batch = np.stack(padded_sequences)
@@ -61,12 +55,16 @@ class Text2SemanticDataset(Dataset):
super().__init__()
self.semantic_data = pd.read_csv(
semantic_path, delimiter="\t", encoding="utf-8"
semantic_path,
delimiter="\t",
encoding="utf-8",
)
# get dict
self.path2 = phoneme_path # "%s/2-name2text.txt"%exp_dir#phoneme_path
self.path3 = "%s/3-bert" % (
os.path.dirname(phoneme_path)
os.path.dirname(
phoneme_path,
)
) # "%s/3-bert"%exp_dir#bert_dir
self.path6 = semantic_path # "%s/6-name2semantic.tsv"%exp_dir#semantic_path
assert os.path.exists(self.path2)
@@ -127,7 +125,7 @@ class Text2SemanticDataset(Dataset):
for i in range(semantic_data_len):
# 先依次遍历
# get str
item_name = self.semantic_data.iloc[i,0]
item_name = self.semantic_data.iloc[i, 0]
# print(self.phoneme_data)
try:
phoneme, word2ph, text = self.phoneme_data[item_name]
@@ -137,7 +135,7 @@ class Text2SemanticDataset(Dataset):
num_not_in += 1
continue
semantic_str = self.semantic_data.iloc[i,1]
semantic_str = self.semantic_data.iloc[i, 1]
# get token list
semantic_ids = [int(idx) for idx in semantic_str.split(" ")]
# (T), 是否需要变成 (1, T) -> 不需要,因为需要求 len
@@ -158,9 +156,7 @@ class Text2SemanticDataset(Dataset):
num_not_in += 1
continue
# if len(phoneme_ids) >400:###########2改为恒定限制为semantic/2.5就行
if (
len(phoneme_ids) > self.max_sec * self.hz / 2.5
): ###########2改为恒定限制为semantic/2.5就行
if len(phoneme_ids) > self.max_sec * self.hz / 2.5: ###########2改为恒定限制为semantic/2.5就行
num_deleted_ps += 1
continue
# if len(semantic_ids) > 1000:###########3
@@ -169,9 +165,7 @@ class Text2SemanticDataset(Dataset):
ps_ratio = len(phoneme_ids) / (len(semantic_ids) / self.hz)
if (
ps_ratio > self.max_ps_ratio or ps_ratio < self.min_ps_ratio
): ##########4#3~25#每秒多少个phone
if ps_ratio > self.max_ps_ratio or ps_ratio < self.min_ps_ratio: ##########4#3~25#每秒多少个phone
num_deleted_ps += 1
# print(item_name)
continue
@@ -194,12 +188,12 @@ class Text2SemanticDataset(Dataset):
print(f"there are {num_not_in} semantic datas not in phoneme datas")
if num_deleted_bigger > 0:
print(
f"deleted {num_deleted_bigger} audios who's duration are bigger than {self.max_sec} seconds"
f"deleted {num_deleted_bigger} audios who's duration are bigger than {self.max_sec} seconds",
)
if num_deleted_ps > 0:
# 4702 for LibriTTS, LirbriTTS 是标注数据, 是否需要筛?=> 需要,有值为 100 的极端值
print(
f"deleted {num_deleted_ps} audios who's phoneme/sec are bigger than {self.max_ps_ratio} or smaller than {self.min_ps_ratio}"
f"deleted {num_deleted_ps} audios who's phoneme/sec are bigger than {self.max_ps_ratio} or smaller than {self.min_ps_ratio}",
)
"""
there are 31 semantic datas not in phoneme datas
@@ -306,7 +300,10 @@ if __name__ == "__main__":
batch_size = 12
dataloader = DataLoader(
dataset, batch_size=batch_size, collate_fn=dataset.collate, shuffle=False
dataset,
batch_size=batch_size,
collate_fn=dataset.collate,
shuffle=False,
)
for i, batch in enumerate(dataloader):
if i % 1000 == 0: