more code refactor

This commit is contained in:
Blaise
2024-01-16 17:14:18 +01:00
parent 0d92575115
commit 0d3d47f3c3
44 changed files with 4516 additions and 2623 deletions

View File

@@ -1,21 +1,24 @@
# modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/t2s_dataset.py
import pdb
import sys
# sys.path.append("/data/docker/liujing04/gpt-vits/mq-vits-s1bert_no_bert")
import traceback,os
import traceback, os
from typing import Dict
from typing import List
import numpy as np
import pandas as pd
import torch,json
import torch, json
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from transformers import AutoTokenizer
from text import cleaned_text_to_sequence
# from config import exp_dir
def batch_sequences(sequences: List[np.array], axis: int = 0, pad_value: int = 0):
seq = sequences[0]
ndim = seq.ndim
@@ -28,44 +31,52 @@ def batch_sequences(sequences: List[np.array], axis: int = 0, pad_value: int = 0
padded_sequences = []
for seq, length in zip(sequences, seq_lengths):
padding = [(0, 0)] * axis + [(0, max_length - length)] + [(0, 0)] * (
ndim - axis - 1)
padded_seq = np.pad(
seq, padding, mode='constant', constant_values=pad_value)
padding = (
[(0, 0)] * axis + [(0, max_length - length)] + [(0, 0)] * (ndim - axis - 1)
)
padded_seq = np.pad(seq, padding, mode="constant", constant_values=pad_value)
padded_sequences.append(padded_seq)
batch = np.stack(padded_sequences)
return batch
class Text2SemanticDataset(Dataset):
"""dataset class for text tokens to semantic model training."""
def __init__(self,
phoneme_path: str,
semantic_path: str,
max_sample: int = None,
max_sec: int = 100,
pad_val: int = 1024,
# min value of phoneme/sec
min_ps_ratio: int = 3,
# max value of phoneme/sec
max_ps_ratio: int = 25) -> None:
def __init__(
self,
phoneme_path: str,
semantic_path: str,
max_sample: int = None,
max_sec: int = 100,
pad_val: int = 1024,
# min value of phoneme/sec
min_ps_ratio: int = 3,
# max value of phoneme/sec
max_ps_ratio: int = 25,
) -> None:
super().__init__()
self.semantic_data = pd.read_csv(semantic_path, delimiter='\t', encoding="utf-8")
self.semantic_data = pd.read_csv(
semantic_path, delimiter="\t", encoding="utf-8"
)
# get dict
self.path2=phoneme_path#"%s/2-name2text.txt"%exp_dir#phoneme_path
self.path3="%s/3-bert"%(os.path.basename(phoneme_path))#"%s/3-bert"%exp_dir#bert_dir
self.path6=semantic_path#"%s/6-name2semantic.tsv"%exp_dir#semantic_path
self.path2 = phoneme_path # "%s/2-name2text.txt"%exp_dir#phoneme_path
self.path3 = "%s/3-bert" % (
os.path.basename(phoneme_path)
) # "%s/3-bert"%exp_dir#bert_dir
self.path6 = semantic_path # "%s/6-name2semantic.tsv"%exp_dir#semantic_path
assert os.path.exists(self.path2)
assert os.path.exists(self.path6)
self.phoneme_data={}
with open(self.path2,"r",encoding="utf8")as f:
lines=f.read().strip("\n").split("\n")
self.phoneme_data = {}
with open(self.path2, "r", encoding="utf8") as f:
lines = f.read().strip("\n").split("\n")
for line in lines:
tmp=line.split("\t")
if(len(tmp)!=4):continue
self.phoneme_data[tmp[0]]=[tmp[1],tmp[2],tmp[3]]
tmp = line.split("\t")
if len(tmp) != 4:
continue
self.phoneme_data[tmp[0]] = [tmp[1], tmp[2], tmp[3]]
# self.phoneme_data = np.load(phoneme_path, allow_pickle=True).item()
# pad for semantic tokens
@@ -74,7 +85,7 @@ class Text2SemanticDataset(Dataset):
# with open("/data/docker/liujing04/gpt-vits/mq-vits-s1bert_no_bert/configs/s2.json", "r") as f:data = f.read()
# data=json.loads(data)["model"]["semantic_frame_rate"]#50hz
# self.hz=int(data[:-2])#
self.hz=int(os.environ.get("hz","25hz")[:-2])
self.hz = int(os.environ.get("hz", "25hz")[:-2])
# max seconds of semantic token
self.max_sec = max_sec
@@ -100,7 +111,6 @@ class Text2SemanticDataset(Dataset):
# self.tokenizer = AutoTokenizer.from_pretrained("hfl/chinese-roberta-wwm-ext-large")
# self.tokenizer = AutoTokenizer.from_pretrained("/data/docker/liujing04/bert-vits2/Bert-VITS2-master20231106/bert/chinese-roberta-wwm-ext-large")
def init_batch(self):
semantic_data_len = len(self.semantic_data)
phoneme_data_len = len(self.phoneme_data.keys())
@@ -113,7 +123,7 @@ class Text2SemanticDataset(Dataset):
for i in range(semantic_data_len):
# 先依次遍历
# get str
item_name = self.semantic_data['item_name'][i]
item_name = self.semantic_data["item_name"][i]
# print(self.phoneme_data)
try:
phoneme, word2ph, text = self.phoneme_data[item_name]
@@ -123,16 +133,18 @@ class Text2SemanticDataset(Dataset):
num_not_in += 1
continue
semantic_str = self.semantic_data['semantic_audio'][i]
semantic_str = self.semantic_data["semantic_audio"][i]
# get token list
semantic_ids = [int(idx) for idx in semantic_str.split(' ')]
semantic_ids = [int(idx) for idx in semantic_str.split(" ")]
# (T), 是否需要变成 (1, T) -> 不需要,因为需要求 len
# 过滤掉太长的样本
if len(semantic_ids) > self.max_sec * self.hz:#########1###根据token个数推测总时长过滤时长60sconfig里#40*25=1k
if (
len(semantic_ids) > self.max_sec * self.hz
): #########1###根据token个数推测总时长过滤时长60sconfig里#40*25=1k
num_deleted_bigger += 1
continue
# (T, ), 这个速度不会很慢,所以可以在一开始就处理,无需在 __getitem__ 里面单个处理####
phoneme = phoneme.split(' ')
phoneme = phoneme.split(" ")
try:
phoneme_ids = cleaned_text_to_sequence(phoneme)
@@ -142,7 +154,9 @@ class Text2SemanticDataset(Dataset):
num_not_in += 1
continue
# if len(phoneme_ids) >400:###########2改为恒定限制为semantic/2.5就行
if len(phoneme_ids) >self.max_sec * self.hz/2.5:###########2改为恒定限制为semantic/2.5就行
if (
len(phoneme_ids) > self.max_sec * self.hz / 2.5
): ###########2改为恒定限制为semantic/2.5就行
num_deleted_ps += 1
continue
# if len(semantic_ids) > 1000:###########3
@@ -151,7 +165,9 @@ class Text2SemanticDataset(Dataset):
ps_ratio = len(phoneme_ids) / (len(semantic_ids) / self.hz)
if ps_ratio > self.max_ps_ratio or ps_ratio < self.min_ps_ratio:##########4#3~25#每秒多少个phone
if (
ps_ratio > self.max_ps_ratio or ps_ratio < self.min_ps_ratio
): ##########4#3~25#每秒多少个phone
num_deleted_ps += 1
# print(item_name)
continue
@@ -160,16 +176,16 @@ class Text2SemanticDataset(Dataset):
idx += 1
self.item_names.append(item_name)
min_num=100#20直接不补#30补了也不存ckpt
leng =len(self.semantic_phoneme)
if(leng<min_num):
tmp1=self.semantic_phoneme
tmp2=self.item_names
self.semantic_phoneme=[]
self.item_names=[]
for _ in range(max(2,int(min_num/leng))):
self.semantic_phoneme+=tmp1
self.item_names+=tmp2
min_num = 100 # 20直接不补#30补了也不存ckpt
leng = len(self.semantic_phoneme)
if leng < min_num:
tmp1 = self.semantic_phoneme
tmp2 = self.item_names
self.semantic_phoneme = []
self.item_names = []
for _ in range(max(2, int(min_num / leng))):
self.semantic_phoneme += tmp1
self.item_names += tmp2
if num_not_in > 0:
print(f"there are {num_not_in} semantic datas not in phoneme datas")
if num_deleted_bigger > 0:
@@ -181,13 +197,13 @@ class Text2SemanticDataset(Dataset):
print(
f"deleted {num_deleted_ps} audios who's phoneme/sec are bigger than {self.max_ps_ratio} or smaller than {self.min_ps_ratio}"
)
'''
"""
there are 31 semantic datas not in phoneme datas
deleted 34 audios who's duration are bigger than 54 seconds
deleted 3190 audios who's phoneme/sec are bigger than 25 or smaller than 3
dataset.__len__(): 366463
'''
"""
# 345410 for LibriTTS
print("dataset.__len__():", self.__len__())
@@ -204,22 +220,24 @@ class Text2SemanticDataset(Dataset):
# semantic tokens target
semantic_ids_len = len(semantic_ids)
flag=0
flag = 0
path_bert = "%s/%s.pt" % (self.path3, item_name)
if(os.path.exists(path_bert)==True):bert_feature = torch.load(path_bert,map_location="cpu")
else:flag=1
if(flag==1):
if os.path.exists(path_bert) == True:
bert_feature = torch.load(path_bert, map_location="cpu")
else:
flag = 1
if flag == 1:
# bert_feature=torch.zeros_like(phoneme_ids,dtype=torch.float32)
bert_feature=None
bert_feature = None
else:
assert bert_feature.shape[-1] == len(phoneme_ids)
return {
'idx': idx,
'phoneme_ids': phoneme_ids,
'phoneme_ids_len': phoneme_ids_len,
'semantic_ids': semantic_ids,
'semantic_ids_len': semantic_ids_len,
'bert_feature': bert_feature,
"idx": idx,
"phoneme_ids": phoneme_ids,
"phoneme_ids_len": phoneme_ids_len,
"semantic_ids": semantic_ids,
"semantic_ids_len": semantic_ids_len,
"bert_feature": bert_feature,
}
def get_sample_length(self, idx: int):
@@ -235,7 +253,6 @@ class Text2SemanticDataset(Dataset):
semantic_ids_lens: List[int] = []
# return
for item in examples:
sample_index.append(item["idx"])
phoneme_ids.append(np.array(item["phoneme_ids"], dtype=np.int64))
@@ -256,9 +273,9 @@ class Text2SemanticDataset(Dataset):
bert_padded.zero_()
for idx, item in enumerate(examples):
bert = item['bert_feature']
if(bert!=None):
bert_padded[idx, :, :bert.shape[-1]] = bert
bert = item["bert_feature"]
if bert != None:
bert_padded[idx, :, : bert.shape[-1]] = bert
return {
# List[int]
@@ -276,27 +293,27 @@ class Text2SemanticDataset(Dataset):
}
if __name__ == '__main__':
root_dir = '/data/docker/liujing04/gpt-vits/prepare/dump_mix/'
if __name__ == "__main__":
root_dir = "/data/docker/liujing04/gpt-vits/prepare/dump_mix/"
dataset = Text2SemanticDataset(
phoneme_path=root_dir + 'phoneme_train.npy',
semantic_path=root_dir + 'semantic_train.tsv')
phoneme_path=root_dir + "phoneme_train.npy",
semantic_path=root_dir + "semantic_train.tsv",
)
batch_size = 12
dataloader = DataLoader(
dataset,
batch_size=batch_size,
collate_fn=dataset.collate,
shuffle=False)
dataset, batch_size=batch_size, collate_fn=dataset.collate, shuffle=False
)
for i, batch in enumerate(dataloader):
if(i%1000==0):print(i)
if i % 1000 == 0:
print(i)
# if i == 0:
# print('batch["ids"]:', batch["ids"])
# print('batch["phoneme_ids"]:', batch["phoneme_ids"],
# batch["phoneme_ids"].shape)
# print('batch["phoneme_ids_len"]:', batch["phoneme_ids_len"],
# batch["phoneme_ids_len"].shape)
# print('batch["semantic_ids"]:', batch["semantic_ids"],
# batch["semantic_ids"].shape)
# print('batch["semantic_ids_len"]:', batch["semantic_ids_len"],
# batch["semantic_ids_len"].shape)
# print('batch["phoneme_ids"]:', batch["phoneme_ids"],
# batch["phoneme_ids"].shape)
# print('batch["phoneme_ids_len"]:', batch["phoneme_ids_len"],
# batch["phoneme_ids_len"].shape)
# print('batch["semantic_ids"]:', batch["semantic_ids"],
# batch["semantic_ids"].shape)
# print('batch["semantic_ids_len"]:', batch["semantic_ids_len"],
# batch["semantic_ids_len"].shape)