more code refactor

Blaise
2024-01-16 17:14:18 +01:00
parent 0d92575115
commit 0d3d47f3c3
44 changed files with 4516 additions and 2623 deletions

View File

@@ -16,7 +16,7 @@ __all__ = [
"DistributedBucketSampler",
]
T_co = TypeVar('T_co', covariant=True)
T_co = TypeVar("T_co", covariant=True)
class DistributedBucketSampler(Sampler[T_co]):
@@ -28,28 +28,30 @@ class DistributedBucketSampler(Sampler[T_co]):
sort batches
"""
def __init__(self,
dataset: Dataset,
num_replicas: Optional[int]=None,
rank: Optional[int]=None,
shuffle: bool=True,
seed: int=0,
drop_last: bool=False,
batch_size: int=32) -> None:
def __init__(
self,
dataset: Dataset,
num_replicas: Optional[int] = None,
rank: Optional[int] = None,
shuffle: bool = True,
seed: int = 0,
drop_last: bool = False,
batch_size: int = 32,
) -> None:
if num_replicas is None:
if not dist.is_available():
raise RuntimeError(
"Requires distributed package to be available")
raise RuntimeError("Requires distributed package to be available")
num_replicas = dist.get_world_size()
if rank is None:
if not dist.is_available():
raise RuntimeError(
"Requires distributed package to be available")
raise RuntimeError("Requires distributed package to be available")
rank = dist.get_rank()
torch.cuda.set_device(rank)
if rank >= num_replicas or rank < 0:
raise ValueError("Invalid rank {}, rank should be in the interval"
" [0, {}]".format(rank, num_replicas - 1))
raise ValueError(
"Invalid rank {}, rank should be in the interval"
" [0, {}]".format(rank, num_replicas - 1)
)
self.dataset = dataset
self.num_replicas = num_replicas
self.rank = rank
@@ -57,19 +59,20 @@ class DistributedBucketSampler(Sampler[T_co]):
self.drop_last = drop_last
# If the dataset length is evenly divisible by # of replicas, then there
# is no need to drop any data, since the dataset will be split equally.
if self.drop_last and len(
self.
dataset) % self.num_replicas != 0: # type: ignore[arg-type]
if (
self.drop_last and len(self.dataset) % self.num_replicas != 0
): # type: ignore[arg-type]
# Split to nearest available length that is evenly divisible.
# This is to ensure each rank receives the same amount of data when
# using this Sampler.
self.num_samples = math.ceil(
(len(self.dataset) - self.num_replicas) /
self.num_replicas # type: ignore[arg-type]
(len(self.dataset) - self.num_replicas)
/ self.num_replicas # type: ignore[arg-type]
)
else:
self.num_samples = math.ceil(
len(self.dataset) / self.num_replicas) # type: ignore[arg-type]
len(self.dataset) / self.num_replicas
) # type: ignore[arg-type]
self.total_size = self.num_samples * self.num_replicas
self.shuffle = shuffle
self.seed = seed
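# Worked example of the num_samples arithmetic above (illustrative numbers, not part of this commit):
#   len(dataset) = 10, num_replicas = 4
#   drop_last=True:  num_samples = ceil((10 - 4) / 4) = 2, total_size = 8  -> 2 samples are dropped
#   drop_last=False: num_samples = ceil(10 / 4) = 3,       total_size = 12 -> 2 indices are repeated as padding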
@@ -84,7 +87,7 @@ class DistributedBucketSampler(Sampler[T_co]):
id_with_lengths.sort(key=lambda x: x[1])
return id_with_lengths
def make_buckets(self, bucket_width: float=2.0):
def make_buckets(self, bucket_width: float = 2.0):
buckets = []
cur = []
max_sec = bucket_width
@@ -114,8 +117,8 @@ class DistributedBucketSampler(Sampler[T_co]):
shuffled_bucket = list(itertools.chain(*shuffled_bucket))
n_batch = int(math.ceil(len(shuffled_bucket) / grouped_batch_size))
batches = [
shuffled_bucket[b * grouped_batch_size:(b + 1) *
grouped_batch_size] for b in range(n_batch)
shuffled_bucket[b * grouped_batch_size : (b + 1) * grouped_batch_size]
for b in range(n_batch)
]
shuffle(batches)
indices = list(itertools.chain(*batches))
@@ -129,15 +132,16 @@ class DistributedBucketSampler(Sampler[T_co]):
if padding_size <= len(indices):
indices += indices[:padding_size]
else:
indices += (indices * math.ceil(padding_size /
len(indices)))[:padding_size]
indices += (indices * math.ceil(padding_size / len(indices)))[
:padding_size
]
else:
# remove tail of data to make it evenly divisible.
indices = indices[:self.total_size]
indices = indices[: self.total_size]
assert len(indices) == self.total_size
# subsample
indices = indices[self.rank:self.total_size:self.num_replicas]
indices = indices[self.rank : self.total_size : self.num_replicas]
assert len(indices) == self.num_samples
return iter(indices)
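For context, a minimal usage sketch of the refactored sampler; train_dataset and the epoch loop are placeholders, and passing the sampler plus a collate function to a DataLoader mirrors the data module shown below:

# hypothetical sketch, assuming torch.distributed is already initialised (e.g. via torchrun)
sampler = DistributedBucketSampler(train_dataset, shuffle=True, batch_size=32)
loader = DataLoader(
    train_dataset,
    batch_size=32,
    sampler=sampler,  # yields per-rank indices grouped into similar-length buckets
    collate_fn=train_dataset.collate,
)
for epoch in range(3):
    sampler.set_epoch(epoch)  # assuming the sampler keeps DistributedSampler's set_epoch() hook
    for batch in loader:
        pass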

View File

@@ -6,14 +6,21 @@ from torch.utils.data import DataLoader
class Text2SemanticDataModule(LightningDataModule):
def __init__(self, config, train_semantic_path, train_phoneme_path,dev_semantic_path=None, dev_phoneme_path=None):
def __init__(
self,
config,
train_semantic_path,
train_phoneme_path,
dev_semantic_path=None,
dev_phoneme_path=None,
):
super().__init__()
self.config = config
self.train_semantic_path = train_semantic_path
self.train_phoneme_path = train_phoneme_path
self.dev_semantic_path = dev_semantic_path
self.dev_phoneme_path = dev_phoneme_path
self.num_workers = self.config['data']['num_workers']
self.num_workers = self.config["data"]["num_workers"]
def prepare_data(self):
pass
@@ -22,8 +29,9 @@ class Text2SemanticDataModule(LightningDataModule):
self._train_dataset = Text2SemanticDataset(
phoneme_path=self.train_phoneme_path,
semantic_path=self.train_semantic_path,
max_sec=self.config['data']['max_sec'],
pad_val=self.config['data']['pad_val'])
max_sec=self.config["data"]["max_sec"],
pad_val=self.config["data"]["pad_val"],
)
self._dev_dataset = self._train_dataset
# self._dev_dataset = Text2SemanticDataset(
# phoneme_path=self.dev_phoneme_path,
@@ -33,9 +41,8 @@ class Text2SemanticDataModule(LightningDataModule):
# pad_val=self.config['data']['pad_val'])
def train_dataloader(self):
batch_size = self.config['train']['batch_size']
sampler = DistributedBucketSampler(
self._train_dataset, batch_size=batch_size)
batch_size = self.config["train"]["batch_size"]
sampler = DistributedBucketSampler(self._train_dataset, batch_size=batch_size)
return DataLoader(
self._train_dataset,
batch_size=batch_size,
@@ -43,7 +50,7 @@ class Text2SemanticDataModule(LightningDataModule):
collate_fn=self._train_dataset.collate,
num_workers=self.num_workers,
persistent_workers=True,
prefetch_factor=16
prefetch_factor=16,
)
def val_dataloader(self):
@@ -52,9 +59,9 @@ class Text2SemanticDataModule(LightningDataModule):
batch_size=1,
shuffle=False,
collate_fn=self._train_dataset.collate,
num_workers=max(self.num_workers,12),
num_workers=max(self.num_workers, 12),
persistent_workers=True,
prefetch_factor=16
prefetch_factor=16,
)
# will this ever actually be used?
@@ -63,4 +70,5 @@ class Text2SemanticDataModule(LightningDataModule):
self._dev_dataset,
batch_size=1,
shuffle=False,
collate_fn=self._train_dataset.collate)
collate_fn=self._train_dataset.collate,
)
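A minimal instantiation sketch for the data module; the paths and num_workers value are placeholders, while the config keys mirror the ones read above:

config = {
    "data": {"num_workers": 4, "max_sec": 54, "pad_val": 1024},
    "train": {"batch_size": 32},
}
dm = Text2SemanticDataModule(
    config,
    train_semantic_path="exp/6-name2semantic.tsv",  # placeholder path
    train_phoneme_path="exp/2-name2text.txt",       # placeholder path
)
# Lightning's Trainer calls prepare_data()/setup() before requesting dataloaders;
# train_dataloader() then returns a DataLoader driven by DistributedBucketSampler.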

View File

@@ -1,21 +1,24 @@
# modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/t2s_dataset.py
import pdb
import sys
# sys.path.append("/data/docker/liujing04/gpt-vits/mq-vits-s1bert_no_bert")
import traceback,os
import traceback, os
from typing import Dict
from typing import List
import numpy as np
import pandas as pd
import torch,json
import torch, json
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from transformers import AutoTokenizer
from text import cleaned_text_to_sequence
# from config import exp_dir
def batch_sequences(sequences: List[np.array], axis: int = 0, pad_value: int = 0):
seq = sequences[0]
ndim = seq.ndim
@@ -28,44 +31,52 @@ def batch_sequences(sequences: List[np.array], axis: int = 0, pad_value: int = 0
padded_sequences = []
for seq, length in zip(sequences, seq_lengths):
padding = [(0, 0)] * axis + [(0, max_length - length)] + [(0, 0)] * (
ndim - axis - 1)
padded_seq = np.pad(
seq, padding, mode='constant', constant_values=pad_value)
padding = (
[(0, 0)] * axis + [(0, max_length - length)] + [(0, 0)] * (ndim - axis - 1)
)
padded_seq = np.pad(seq, padding, mode="constant", constant_values=pad_value)
padded_sequences.append(padded_seq)
batch = np.stack(padded_sequences)
return batch
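# Illustration (not part of the commit): batch_sequences pads each array along
# `axis` to the longest length in the list and stacks the result, e.g.
#   batch_sequences([np.array([1, 2, 3]), np.array([4, 5])], pad_value=0)
#   -> array([[1, 2, 3],
#             [4, 5, 0]])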
class Text2SemanticDataset(Dataset):
"""dataset class for text tokens to semantic model training."""
def __init__(self,
phoneme_path: str,
semantic_path: str,
max_sample: int = None,
max_sec: int = 100,
pad_val: int = 1024,
# min value of phoneme/sec
min_ps_ratio: int = 3,
# max value of phoneme/sec
max_ps_ratio: int = 25) -> None:
def __init__(
self,
phoneme_path: str,
semantic_path: str,
max_sample: int = None,
max_sec: int = 100,
pad_val: int = 1024,
# min value of phoneme/sec
min_ps_ratio: int = 3,
# max value of phoneme/sec
max_ps_ratio: int = 25,
) -> None:
super().__init__()
self.semantic_data = pd.read_csv(semantic_path, delimiter='\t', encoding="utf-8")
self.semantic_data = pd.read_csv(
semantic_path, delimiter="\t", encoding="utf-8"
)
# get dict
self.path2=phoneme_path#"%s/2-name2text.txt"%exp_dir#phoneme_path
self.path3="%s/3-bert"%(os.path.basename(phoneme_path))#"%s/3-bert"%exp_dir#bert_dir
self.path6=semantic_path#"%s/6-name2semantic.tsv"%exp_dir#semantic_path
self.path2 = phoneme_path # "%s/2-name2text.txt"%exp_dir#phoneme_path
self.path3 = "%s/3-bert" % (
os.path.basename(phoneme_path)
) # "%s/3-bert"%exp_dir#bert_dir
self.path6 = semantic_path # "%s/6-name2semantic.tsv"%exp_dir#semantic_path
assert os.path.exists(self.path2)
assert os.path.exists(self.path6)
self.phoneme_data={}
with open(self.path2,"r",encoding="utf8")as f:
lines=f.read().strip("\n").split("\n")
self.phoneme_data = {}
with open(self.path2, "r", encoding="utf8") as f:
lines = f.read().strip("\n").split("\n")
for line in lines:
tmp=line.split("\t")
if(len(tmp)!=4):continue
self.phoneme_data[tmp[0]]=[tmp[1],tmp[2],tmp[3]]
tmp = line.split("\t")
if len(tmp) != 4:
continue
self.phoneme_data[tmp[0]] = [tmp[1], tmp[2], tmp[3]]
# self.phoneme_data = np.load(phoneme_path, allow_pickle=True).item()
# pad for semantic tokens
@@ -74,7 +85,7 @@ class Text2SemanticDataset(Dataset):
# with open("/data/docker/liujing04/gpt-vits/mq-vits-s1bert_no_bert/configs/s2.json", "r") as f:data = f.read()
# data=json.loads(data)["model"]["semantic_frame_rate"]#50hz
# self.hz=int(data[:-2])#
self.hz=int(os.environ.get("hz","25hz")[:-2])
self.hz = int(os.environ.get("hz", "25hz")[:-2])
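# the "hz" environment variable is a string such as "25hz"; the [:-2] slice strips
# the trailing "hz", so the default "25hz" -> 25 and e.g. "50hz" -> 50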
# max seconds of semantic token
self.max_sec = max_sec
@@ -100,7 +111,6 @@ class Text2SemanticDataset(Dataset):
# self.tokenizer = AutoTokenizer.from_pretrained("hfl/chinese-roberta-wwm-ext-large")
# self.tokenizer = AutoTokenizer.from_pretrained("/data/docker/liujing04/bert-vits2/Bert-VITS2-master20231106/bert/chinese-roberta-wwm-ext-large")
def init_batch(self):
semantic_data_len = len(self.semantic_data)
phoneme_data_len = len(self.phoneme_data.keys())
@@ -113,7 +123,7 @@ class Text2SemanticDataset(Dataset):
for i in range(semantic_data_len):
# iterate through the samples in order first
# get str
item_name = self.semantic_data['item_name'][i]
item_name = self.semantic_data["item_name"][i]
# print(self.phoneme_data)
try:
phoneme, word2ph, text = self.phoneme_data[item_name]
@@ -123,16 +133,18 @@ class Text2SemanticDataset(Dataset):
num_not_in += 1
continue
semantic_str = self.semantic_data['semantic_audio'][i]
semantic_str = self.semantic_data["semantic_audio"][i]
# get token list
semantic_ids = [int(idx) for idx in semantic_str.split(' ')]
semantic_ids = [int(idx) for idx in semantic_str.split(" ")]
# (T), does it need to become (1, T)? -> no, because we still need len()
# filter out samples that are too long
if len(semantic_ids) > self.max_sec * self.hz:#########1### estimate the total duration from the token count and filter out clips longer than max_sec (60s in the config)#40*25=1k
if (
len(semantic_ids) > self.max_sec * self.hz
): #########1### estimate the total duration from the token count and filter out clips longer than max_sec (60s in the config)#40*25=1k
num_deleted_bigger += 1
continue
# (T,), this is not slow, so it can be done once up front instead of per sample in __getitem__ ####
phoneme = phoneme.split(' ')
phoneme = phoneme.split(" ")
try:
phoneme_ids = cleaned_text_to_sequence(phoneme)
@@ -142,7 +154,9 @@ class Text2SemanticDataset(Dataset):
num_not_in += 1
continue
# if len(phoneme_ids) >400:###########2 changed to a fixed limit of semantic length / 2.5, which is enough
if len(phoneme_ids) >self.max_sec * self.hz/2.5:###########2 changed to a fixed limit of semantic length / 2.5, which is enough
if (
len(phoneme_ids) > self.max_sec * self.hz / 2.5
): ###########2 changed to a fixed limit of semantic length / 2.5, which is enough
num_deleted_ps += 1
continue
# if len(semantic_ids) > 1000:###########3
@@ -151,7 +165,9 @@ class Text2SemanticDataset(Dataset):
ps_ratio = len(phoneme_ids) / (len(semantic_ids) / self.hz)
if ps_ratio > self.max_ps_ratio or ps_ratio < self.min_ps_ratio:##########4#3~25# phonemes per second
if (
ps_ratio > self.max_ps_ratio or ps_ratio < self.min_ps_ratio
): ##########4#3~25# phonemes per second
num_deleted_ps += 1
# print(item_name)
continue
@@ -160,16 +176,16 @@ class Text2SemanticDataset(Dataset):
idx += 1
self.item_names.append(item_name)
min_num=100# at 20: skip padding entirely; at 30: padded, but the ckpt is still not saved
leng =len(self.semantic_phoneme)
if(leng<min_num):
tmp1=self.semantic_phoneme
tmp2=self.item_names
self.semantic_phoneme=[]
self.item_names=[]
for _ in range(max(2,int(min_num/leng))):
self.semantic_phoneme+=tmp1
self.item_names+=tmp2
min_num = 100  # at 20: skip padding entirely; at 30: padded, but the ckpt is still not saved
leng = len(self.semantic_phoneme)
if leng < min_num:
tmp1 = self.semantic_phoneme
tmp2 = self.item_names
self.semantic_phoneme = []
self.item_names = []
for _ in range(max(2, int(min_num / leng))):
self.semantic_phoneme += tmp1
self.item_names += tmp2
if num_not_in > 0:
print(f"there are {num_not_in} semantic datas not in phoneme datas")
if num_deleted_bigger > 0:
@@ -181,13 +197,13 @@ class Text2SemanticDataset(Dataset):
print(
f"deleted {num_deleted_ps} audios who's phoneme/sec are bigger than {self.max_ps_ratio} or smaller than {self.min_ps_ratio}"
)
'''
"""
there are 31 semantic data entries with no matching phoneme data
deleted 34 audios whose duration is longer than 54 seconds
deleted 3190 audios whose phoneme/sec is greater than 25 or smaller than 3
dataset.__len__(): 366463
'''
"""
# 345410 for LibriTTS
print("dataset.__len__():", self.__len__())
@@ -204,22 +220,24 @@ class Text2SemanticDataset(Dataset):
# semantic tokens target
semantic_ids_len = len(semantic_ids)
flag=0
flag = 0
path_bert = "%s/%s.pt" % (self.path3, item_name)
if(os.path.exists(path_bert)==True):bert_feature = torch.load(path_bert,map_location="cpu")
else:flag=1
if(flag==1):
if os.path.exists(path_bert) == True:
bert_feature = torch.load(path_bert, map_location="cpu")
else:
flag = 1
if flag == 1:
# bert_feature=torch.zeros_like(phoneme_ids,dtype=torch.float32)
bert_feature=None
bert_feature = None
else:
assert bert_feature.shape[-1] == len(phoneme_ids)
return {
'idx': idx,
'phoneme_ids': phoneme_ids,
'phoneme_ids_len': phoneme_ids_len,
'semantic_ids': semantic_ids,
'semantic_ids_len': semantic_ids_len,
'bert_feature': bert_feature,
"idx": idx,
"phoneme_ids": phoneme_ids,
"phoneme_ids_len": phoneme_ids_len,
"semantic_ids": semantic_ids,
"semantic_ids_len": semantic_ids_len,
"bert_feature": bert_feature,
}
def get_sample_length(self, idx: int):
@@ -235,7 +253,6 @@ class Text2SemanticDataset(Dataset):
semantic_ids_lens: List[int] = []
# return
for item in examples:
sample_index.append(item["idx"])
phoneme_ids.append(np.array(item["phoneme_ids"], dtype=np.int64))
@@ -256,9 +273,9 @@ class Text2SemanticDataset(Dataset):
bert_padded.zero_()
for idx, item in enumerate(examples):
bert = item['bert_feature']
if(bert!=None):
bert_padded[idx, :, :bert.shape[-1]] = bert
bert = item["bert_feature"]
if bert != None:
bert_padded[idx, :, : bert.shape[-1]] = bert
return {
# List[int]
@@ -276,27 +293,27 @@ class Text2SemanticDataset(Dataset):
}
if __name__ == '__main__':
root_dir = '/data/docker/liujing04/gpt-vits/prepare/dump_mix/'
if __name__ == "__main__":
root_dir = "/data/docker/liujing04/gpt-vits/prepare/dump_mix/"
dataset = Text2SemanticDataset(
phoneme_path=root_dir + 'phoneme_train.npy',
semantic_path=root_dir + 'semantic_train.tsv')
phoneme_path=root_dir + "phoneme_train.npy",
semantic_path=root_dir + "semantic_train.tsv",
)
batch_size = 12
dataloader = DataLoader(
dataset,
batch_size=batch_size,
collate_fn=dataset.collate,
shuffle=False)
dataset, batch_size=batch_size, collate_fn=dataset.collate, shuffle=False
)
for i, batch in enumerate(dataloader):
if(i%1000==0):print(i)
if i % 1000 == 0:
print(i)
# if i == 0:
# print('batch["ids"]:', batch["ids"])
# print('batch["phoneme_ids"]:', batch["phoneme_ids"],
# batch["phoneme_ids"].shape)
# print('batch["phoneme_ids_len"]:', batch["phoneme_ids_len"],
# batch["phoneme_ids_len"].shape)
# print('batch["semantic_ids"]:', batch["semantic_ids"],
# batch["semantic_ids"].shape)
# print('batch["semantic_ids_len"]:', batch["semantic_ids_len"],
# batch["semantic_ids_len"].shape)
# print('batch["phoneme_ids"]:', batch["phoneme_ids"],
# batch["phoneme_ids"].shape)
# print('batch["phoneme_ids_len"]:', batch["phoneme_ids_len"],
# batch["phoneme_ids_len"].shape)
# print('batch["semantic_ids"]:', batch["semantic_ids"],
# batch["semantic_ids"].shape)
# print('batch["semantic_ids_len"]:', batch["semantic_ids_len"],
# batch["semantic_ids_len"].shape)