Refactor: Format Code with Ruff and Update Deprecated G2PW Link (#2255)
* ruff check --fix
* ruff format --line-length 120 --target-version py39
* Change the link for G2PW Model
* update pytorch version and colab
This commit is contained in:
@@ -4,14 +4,11 @@ import itertools
|
||||
import math
|
||||
import random
|
||||
from random import shuffle
|
||||
from typing import Iterator
|
||||
from typing import Optional
|
||||
from typing import TypeVar
|
||||
from typing import Iterator, Optional, TypeVar
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.utils.data import Dataset
|
||||
from torch.utils.data import Sampler
|
||||
from torch.utils.data import Dataset, Sampler
|
||||
|
||||
__all__ = [
|
||||
"DistributedBucketSampler",
|
||||
@@ -50,10 +47,7 @@ class DistributedBucketSampler(Sampler[T_co]):
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.set_device(rank)
|
||||
if rank >= num_replicas or rank < 0:
|
||||
raise ValueError(
|
||||
"Invalid rank {}, rank should be in the interval"
|
||||
" [0, {}]".format(rank, num_replicas - 1)
|
||||
)
|
||||
raise ValueError("Invalid rank {}, rank should be in the interval [0, {}]".format(rank, num_replicas - 1))
|
||||
self.dataset = dataset
|
||||
self.num_replicas = num_replicas
|
||||
self.rank = rank
|
||||
@@ -61,19 +55,16 @@ class DistributedBucketSampler(Sampler[T_co]):
|
||||
self.drop_last = drop_last
|
||||
# If the dataset length is evenly divisible by # of replicas, then there
|
||||
# is no need to drop any data, since the dataset will be split equally.
|
||||
if (
|
||||
self.drop_last and len(self.dataset) % self.num_replicas != 0
|
||||
): # type: ignore[arg-type]
|
||||
if self.drop_last and len(self.dataset) % self.num_replicas != 0: # type: ignore[arg-type]
|
||||
# Split to nearest available length that is evenly divisible.
|
||||
# This is to ensure each rank receives the same amount of data when
|
||||
# using this Sampler.
|
||||
self.num_samples = math.ceil(
|
||||
(len(self.dataset) - self.num_replicas)
|
||||
/ self.num_replicas # type: ignore[arg-type]
|
||||
(len(self.dataset) - self.num_replicas) / self.num_replicas, # type: ignore[arg-type]
|
||||
)
|
||||
else:
|
||||
self.num_samples = math.ceil(
|
||||
len(self.dataset) / self.num_replicas
|
||||
len(self.dataset) / self.num_replicas,
|
||||
) # type: ignore[arg-type]
|
||||
self.total_size = self.num_samples * self.num_replicas
|
||||
self.shuffle = shuffle
|
||||
@@ -118,10 +109,7 @@ class DistributedBucketSampler(Sampler[T_co]):
|
||||
grouped_batch_size = self.batch_size * self.num_replicas
|
||||
shuffled_bucket = list(itertools.chain(*shuffled_bucket))
|
||||
n_batch = int(math.ceil(len(shuffled_bucket) / grouped_batch_size))
|
||||
batches = [
|
||||
shuffled_bucket[b * grouped_batch_size : (b + 1) * grouped_batch_size]
|
||||
for b in range(n_batch)
|
||||
]
|
||||
batches = [shuffled_bucket[b * grouped_batch_size : (b + 1) * grouped_batch_size] for b in range(n_batch)]
|
||||
shuffle(batches)
|
||||
indices = list(itertools.chain(*batches))
|
||||
else:
|
||||
@@ -134,9 +122,7 @@ class DistributedBucketSampler(Sampler[T_co]):
|
||||
if padding_size <= len(indices):
|
||||
indices += indices[:padding_size]
|
||||
else:
|
||||
indices += (indices * math.ceil(padding_size / len(indices)))[
|
||||
:padding_size
|
||||
]
|
||||
indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
|
||||
else:
|
||||
# remove tail of data to make it evenly divisible.
|
||||
indices = indices[: self.total_size]
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/data/data_module.py
|
||||
# reference: https://github.com/lifeiteng/vall-e
|
||||
from pytorch_lightning import LightningDataModule
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from AR.data.bucket_sampler import DistributedBucketSampler
|
||||
from AR.data.dataset import Text2SemanticDataset
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
|
||||
class Text2SemanticDataModule(LightningDataModule):
|
||||
@@ -42,8 +43,12 @@ class Text2SemanticDataModule(LightningDataModule):
|
||||
# pad_val=self.config['data']['pad_val'])
|
||||
|
||||
def train_dataloader(self):
|
||||
batch_size=self.config["train"]["batch_size"]//2 if self.config["train"].get("if_dpo",False)==True else self.config["train"]["batch_size"]
|
||||
batch_size = max(min(batch_size,len(self._train_dataset)//4),1)#防止不保存
|
||||
batch_size = (
|
||||
self.config["train"]["batch_size"] // 2
|
||||
if self.config["train"].get("if_dpo", False) is True
|
||||
else self.config["train"]["batch_size"]
|
||||
)
|
||||
batch_size = max(min(batch_size, len(self._train_dataset) // 4), 1) # 防止不保存
|
||||
sampler = DistributedBucketSampler(self._train_dataset, batch_size=batch_size)
|
||||
return DataLoader(
|
||||
self._train_dataset,
|
||||
|
||||
@@ -1,21 +1,17 @@
|
||||
# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/data/dataset.py
|
||||
# reference: https://github.com/lifeiteng/vall-e
|
||||
import pdb
|
||||
import sys
|
||||
|
||||
# sys.path.append("/data/docker/liujing04/gpt-vits/mq-vits-s1bert_no_bert")
|
||||
import traceback, os
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
import os
|
||||
import traceback
|
||||
from typing import Dict, List
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch, json
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.data import Dataset
|
||||
from transformers import AutoTokenizer
|
||||
import torch
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
|
||||
version = os.environ.get('version',None)
|
||||
version = os.environ.get("version", None)
|
||||
|
||||
from text import cleaned_text_to_sequence
|
||||
|
||||
@@ -34,9 +30,7 @@ def batch_sequences(sequences: List[np.array], axis: int = 0, pad_value: int = 0
|
||||
|
||||
padded_sequences = []
|
||||
for seq, length in zip(sequences, seq_lengths):
|
||||
padding = (
|
||||
[(0, 0)] * axis + [(0, max_length - length)] + [(0, 0)] * (ndim - axis - 1)
|
||||
)
|
||||
padding = [(0, 0)] * axis + [(0, max_length - length)] + [(0, 0)] * (ndim - axis - 1)
|
||||
padded_seq = np.pad(seq, padding, mode="constant", constant_values=pad_value)
|
||||
padded_sequences.append(padded_seq)
|
||||
batch = np.stack(padded_sequences)
|
||||
@@ -61,12 +55,16 @@ class Text2SemanticDataset(Dataset):
|
||||
super().__init__()
|
||||
|
||||
self.semantic_data = pd.read_csv(
|
||||
semantic_path, delimiter="\t", encoding="utf-8"
|
||||
semantic_path,
|
||||
delimiter="\t",
|
||||
encoding="utf-8",
|
||||
)
|
||||
# get dict
|
||||
self.path2 = phoneme_path # "%s/2-name2text.txt"%exp_dir#phoneme_path
|
||||
self.path3 = "%s/3-bert" % (
|
||||
os.path.dirname(phoneme_path)
|
||||
os.path.dirname(
|
||||
phoneme_path,
|
||||
)
|
||||
) # "%s/3-bert"%exp_dir#bert_dir
|
||||
self.path6 = semantic_path # "%s/6-name2semantic.tsv"%exp_dir#semantic_path
|
||||
assert os.path.exists(self.path2)
|
||||
@@ -127,7 +125,7 @@ class Text2SemanticDataset(Dataset):
|
||||
for i in range(semantic_data_len):
|
||||
# 先依次遍历
|
||||
# get str
|
||||
item_name = self.semantic_data.iloc[i,0]
|
||||
item_name = self.semantic_data.iloc[i, 0]
|
||||
# print(self.phoneme_data)
|
||||
try:
|
||||
phoneme, word2ph, text = self.phoneme_data[item_name]
|
||||
@@ -137,7 +135,7 @@ class Text2SemanticDataset(Dataset):
|
||||
num_not_in += 1
|
||||
continue
|
||||
|
||||
semantic_str = self.semantic_data.iloc[i,1]
|
||||
semantic_str = self.semantic_data.iloc[i, 1]
|
||||
# get token list
|
||||
semantic_ids = [int(idx) for idx in semantic_str.split(" ")]
|
||||
# (T), 是否需要变成 (1, T) -> 不需要,因为需要求 len
|
||||
@@ -158,9 +156,7 @@ class Text2SemanticDataset(Dataset):
|
||||
num_not_in += 1
|
||||
continue
|
||||
# if len(phoneme_ids) >400:###########2:改为恒定限制为semantic/2.5就行
|
||||
if (
|
||||
len(phoneme_ids) > self.max_sec * self.hz / 2.5
|
||||
): ###########2:改为恒定限制为semantic/2.5就行
|
||||
if len(phoneme_ids) > self.max_sec * self.hz / 2.5: ###########2:改为恒定限制为semantic/2.5就行
|
||||
num_deleted_ps += 1
|
||||
continue
|
||||
# if len(semantic_ids) > 1000:###########3
|
||||
@@ -169,9 +165,7 @@ class Text2SemanticDataset(Dataset):
|
||||
|
||||
ps_ratio = len(phoneme_ids) / (len(semantic_ids) / self.hz)
|
||||
|
||||
if (
|
||||
ps_ratio > self.max_ps_ratio or ps_ratio < self.min_ps_ratio
|
||||
): ##########4#3~25#每秒多少个phone
|
||||
if ps_ratio > self.max_ps_ratio or ps_ratio < self.min_ps_ratio: ##########4#3~25#每秒多少个phone
|
||||
num_deleted_ps += 1
|
||||
# print(item_name)
|
||||
continue
|
||||
@@ -194,12 +188,12 @@ class Text2SemanticDataset(Dataset):
|
||||
print(f"there are {num_not_in} semantic datas not in phoneme datas")
|
||||
if num_deleted_bigger > 0:
|
||||
print(
|
||||
f"deleted {num_deleted_bigger} audios who's duration are bigger than {self.max_sec} seconds"
|
||||
f"deleted {num_deleted_bigger} audios who's duration are bigger than {self.max_sec} seconds",
|
||||
)
|
||||
if num_deleted_ps > 0:
|
||||
# 4702 for LibriTTS, LirbriTTS 是标注数据, 是否需要筛?=> 需要,有值为 100 的极端值
|
||||
print(
|
||||
f"deleted {num_deleted_ps} audios who's phoneme/sec are bigger than {self.max_ps_ratio} or smaller than {self.min_ps_ratio}"
|
||||
f"deleted {num_deleted_ps} audios who's phoneme/sec are bigger than {self.max_ps_ratio} or smaller than {self.min_ps_ratio}",
|
||||
)
|
||||
"""
|
||||
there are 31 semantic datas not in phoneme datas
|
||||
@@ -306,7 +300,10 @@ if __name__ == "__main__":
|
||||
|
||||
batch_size = 12
|
||||
dataloader = DataLoader(
|
||||
dataset, batch_size=batch_size, collate_fn=dataset.collate, shuffle=False
|
||||
dataset,
|
||||
batch_size=batch_size,
|
||||
collate_fn=dataset.collate,
|
||||
shuffle=False,
|
||||
)
|
||||
for i, batch in enumerate(dataloader):
|
||||
if i % 1000 == 0:
|
||||
|
||||
Reference in New Issue
Block a user