Refactor: Format Code with Ruff and Update Deprecated G2PW Link (#2255)

* ruff check --fix

* ruff format --line-length 120 --target-version py39

* Change the link for G2PW Model

* Update PyTorch version and Colab
XXXXRT666 authored on 2025-04-07 09:42:47 +01:00; committed by GitHub
parent 9da7e17efe
commit 53cac93589
132 changed files with 8185 additions and 6648 deletions


@@ -1,38 +1,45 @@
 import warnings
 warnings.filterwarnings("ignore")
-import utils, os
+import os
+import utils
 hps = utils.get_hparams(stage=2)
 os.environ["CUDA_VISIBLE_DEVICES"] = hps.train.gpu_numbers.replace("-", ",")
+import logging
 import torch
 from torch.nn import functional as F
+import torch.distributed as dist
+import torch.multiprocessing as mp
+from torch.cuda.amp import GradScaler, autocast
+from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.data import DataLoader
 from torch.utils.tensorboard import SummaryWriter
-import torch.multiprocessing as mp
-import torch.distributed as dist, traceback
-from torch.nn.parallel import DistributedDataParallel as DDP
-from torch.cuda.amp import autocast, GradScaler
 from tqdm import tqdm
-import logging, traceback
 logging.getLogger("matplotlib").setLevel(logging.INFO)
 logging.getLogger("h5py").setLevel(logging.INFO)
 logging.getLogger("numba").setLevel(logging.INFO)
+from collections import OrderedDict as od
 from random import randint
 from module import commons
-from peft import LoraConfig, PeftModel, get_peft_model
 from module.data_utils import (
+    DistributedBucketSampler,
+)
+from module.data_utils import (
+    TextAudioSpeakerCollateV3 as TextAudioSpeakerCollate,
+)
+from module.data_utils import (
     TextAudioSpeakerLoaderV3 as TextAudioSpeakerLoader,
-    TextAudioSpeakerCollateV3 as TextAudioSpeakerCollate,
-    DistributedBucketSampler,
 )
 from module.models import (
     SynthesizerTrnV3 as SynthesizerTrn,
     MultiPeriodDiscriminator,
 )
 from module.losses import generator_loss, discriminator_loss, feature_loss, kl_loss
 from module.mel_processing import mel_spectrogram_torch, spec_to_mel_torch
+from peft import LoraConfig, get_peft_model
 from process_ckpt import savee
-from collections import OrderedDict as od
 torch.backends.cudnn.benchmark = False
 torch.backends.cudnn.deterministic = False
 ### fp32 is already fast on the A100 anyway, so let's try tf32
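The trailing comment refers to TF32 on Ampere-class GPUs. A minimal sketch of how TF32 is typically toggled in PyTorch (the standard torch.backends flags; not necessarily the exact lines that follow this hunk in the file):

import torch

# Allow TensorFloat-32 kernels on Ampere and newer GPUs: faster matmul/conv at slightly reduced mantissa precision.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True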
@@ -46,7 +53,6 @@ device = "cpu"  # devices other than cuda will be added once mps is better optimized
 def main():
     if torch.cuda.is_available():
         n_gpus = torch.cuda.device_count()
     else:
@@ -65,7 +71,7 @@ def main():
 def run(rank, n_gpus, hps):
-    global global_step,no_grad_names,save_root,lora_rank
+    global global_step, no_grad_names, save_root, lora_rank
     if rank == 0:
         logger = utils.get_logger(hps.data.exp_dir)
         logger.info(hps)
@@ -74,7 +80,7 @@ def run(rank, n_gpus, hps):
         writer_eval = SummaryWriter(log_dir=os.path.join(hps.s2_ckpt_dir, "eval"))
     dist.init_process_group(
-        backend = "gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl",
+        backend="gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl",
         init_method="env://?use_libuv=False",
         world_size=n_gpus,
         rank=rank,
@@ -122,21 +128,24 @@ def run(rank, n_gpus, hps):
         persistent_workers=True,
         prefetch_factor=4,
     )
-    save_root="%s/logs_s2_%s_lora_%s" % (hps.data.exp_dir,hps.model.version,hps.train.lora_rank)
-    os.makedirs(save_root,exist_ok=True)
-    lora_rank=int(hps.train.lora_rank)
+    save_root = "%s/logs_s2_%s_lora_%s" % (hps.data.exp_dir, hps.model.version, hps.train.lora_rank)
+    os.makedirs(save_root, exist_ok=True)
+    lora_rank = int(hps.train.lora_rank)
     lora_config = LoraConfig(
         target_modules=["to_k", "to_q", "to_v", "to_out.0"],
         r=lora_rank,
         lora_alpha=lora_rank,
         init_lora_weights=True,
     )
-    def get_model(hps):return SynthesizerTrn(
-        hps.data.filter_length // 2 + 1,
-        hps.train.segment_size // hps.data.hop_length,
-        n_speakers=hps.data.n_speakers,
-        **hps.model,
-    )
+    def get_model(hps):
+        return SynthesizerTrn(
+            hps.data.filter_length // 2 + 1,
+            hps.train.segment_size // hps.data.hop_length,
+            n_speakers=hps.data.n_speakers,
+            **hps.model,
+        )
     def get_optim(net_g):
         return torch.optim.AdamW(
             filter(lambda p: p.requires_grad, net_g.parameters()),  ### all layers share the same lr by default
@@ -144,61 +153,66 @@ def run(rank, n_gpus, hps):
             betas=hps.train.betas,
             eps=hps.train.eps,
         )
-    def model2cuda(net_g,rank):
+    def model2cuda(net_g, rank):
         if torch.cuda.is_available():
             net_g = DDP(net_g.cuda(rank), device_ids=[rank], find_unused_parameters=True)
         else:
             net_g = net_g.to(device)
         return net_g
-    try:# resume automatically if a checkpoint can be loaded
+    try:  # resume automatically if a checkpoint can be loaded
         net_g = get_model(hps)
         net_g.cfm = get_peft_model(net_g.cfm, lora_config)
-        net_g=model2cuda(net_g,rank)
-        optim_g=get_optim(net_g)
+        net_g = model2cuda(net_g, rank)
+        optim_g = get_optim(net_g)
         # _, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g, optim_g,load_opt=0)
         _, _, _, epoch_str = utils.load_checkpoint(
             utils.latest_checkpoint_path(save_root, "G_*.pth"),
             net_g,
             optim_g,
         )
-        epoch_str+=1
+        epoch_str += 1
         global_step = (epoch_str - 1) * len(train_loader)
     except:  # if loading fails on the first run, load the pretrained weights instead
         # traceback.print_exc()
         epoch_str = 1
         global_step = 0
         net_g = get_model(hps)
-        if hps.train.pretrained_s2G != ""and hps.train.pretrained_s2G != None and os.path.exists(hps.train.pretrained_s2G):
+        if (
+            hps.train.pretrained_s2G != ""
+            and hps.train.pretrained_s2G != None
+            and os.path.exists(hps.train.pretrained_s2G)
+        ):
             if rank == 0:
                 logger.info("loaded pretrained %s" % hps.train.pretrained_s2G)
-            print("loaded pretrained %s" % hps.train.pretrained_s2G,
+            print(
+                "loaded pretrained %s" % hps.train.pretrained_s2G,
                 net_g.load_state_dict(
                     torch.load(hps.train.pretrained_s2G, map_location="cpu")["weight"],
                     strict=False,
-                )
+                ),
             )
         net_g.cfm = get_peft_model(net_g.cfm, lora_config)
-        net_g=model2cuda(net_g,rank)
+        net_g = model2cuda(net_g, rank)
         optim_g = get_optim(net_g)
-    no_grad_names=set()
+    no_grad_names = set()
     for name, param in net_g.named_parameters():
         if not param.requires_grad:
-            no_grad_names.add(name.replace("module.",""))
+            no_grad_names.add(name.replace("module.", ""))
             # print(name, "not requires_grad")
     # print(no_grad_names)
     # os._exit(233333)
-    scheduler_g = torch.optim.lr_scheduler.ExponentialLR(
-        optim_g, gamma=hps.train.lr_decay, last_epoch=-1
-    )
+    scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=hps.train.lr_decay, last_epoch=-1)
     for _ in range(epoch_str):
         scheduler_g.step()
     scaler = GradScaler(enabled=hps.train.fp16_run)
-    net_d=optim_d=scheduler_d=None
-    print("start training from epoch %s"%epoch_str)
+    net_d = optim_d = scheduler_d = None
+    print("start training from epoch %s" % epoch_str)
     for epoch in range(epoch_str, hps.train.epochs + 1):
         if rank == 0:
             train_and_evaluate(
@@ -230,9 +244,8 @@ def run(rank, n_gpus, hps):
         scheduler_g.step()
     print("training done")
-def train_and_evaluate(
-    rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers
-):
+def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers):
     net_g, net_d = nets
     optim_g, optim_d = optims
     # scheduler_g, scheduler_d = schedulers
@@ -244,18 +257,32 @@ def train_and_evaluate(
     global global_step
     net_g.train()
-    for batch_idx, (ssl, spec, mel, ssl_lengths, spec_lengths, text, text_lengths, mel_lengths) in enumerate(tqdm(train_loader)):
+    for batch_idx, (ssl, spec, mel, ssl_lengths, spec_lengths, text, text_lengths, mel_lengths) in enumerate(
+        tqdm(train_loader)
+    ):
         if torch.cuda.is_available():
-            spec, spec_lengths = spec.cuda(rank, non_blocking=True), spec_lengths.cuda(
-                rank, non_blocking=True
-            )
-            mel, mel_lengths = mel.cuda(rank, non_blocking=True), mel_lengths.cuda(
-                rank, non_blocking=True
+            spec, spec_lengths = (
+                spec.cuda(
+                    rank,
+                    non_blocking=True,
+                ),
+                spec_lengths.cuda(
+                    rank,
+                    non_blocking=True,
+                ),
             )
+            mel, mel_lengths = mel.cuda(rank, non_blocking=True), mel_lengths.cuda(rank, non_blocking=True)
             ssl = ssl.cuda(rank, non_blocking=True)
             ssl.requires_grad = False
-            text, text_lengths = text.cuda(rank, non_blocking=True), text_lengths.cuda(
-                rank, non_blocking=True
+            text, text_lengths = (
+                text.cuda(
+                    rank,
+                    non_blocking=True,
+                ),
+                text_lengths.cuda(
+                    rank,
+                    non_blocking=True,
+                ),
             )
         else:
             spec, spec_lengths = spec.to(device), spec_lengths.to(device)
@@ -265,8 +292,18 @@ def train_and_evaluate(
             text, text_lengths = text.to(device), text_lengths.to(device)
         with autocast(enabled=hps.train.fp16_run):
-            cfm_loss = net_g(ssl, spec, mel,ssl_lengths,spec_lengths, text, text_lengths,mel_lengths, use_grad_ckpt=hps.train.grad_ckpt)
-            loss_gen_all=cfm_loss
+            cfm_loss = net_g(
+                ssl,
+                spec,
+                mel,
+                ssl_lengths,
+                spec_lengths,
+                text,
+                text_lengths,
+                mel_lengths,
+                use_grad_ckpt=hps.train.grad_ckpt,
+            )
+            loss_gen_all = cfm_loss
         optim_g.zero_grad()
         scaler.scale(loss_gen_all).backward()
         scaler.unscale_(optim_g)
@@ -276,18 +313,17 @@ def train_and_evaluate(
         if rank == 0:
             if global_step % hps.train.log_interval == 0:
-                lr = optim_g.param_groups[0]['lr']
+                lr = optim_g.param_groups[0]["lr"]
                 losses = [cfm_loss]
-                logger.info('Train Epoch: {} [{:.0f}%]'.format(
-                    epoch,
-                    100. * batch_idx / len(train_loader)))
+                logger.info("Train Epoch: {} [{:.0f}%]".format(epoch, 100.0 * batch_idx / len(train_loader)))
                 logger.info([x.item() for x in losses] + [global_step, lr])
                 scalar_dict = {"loss/g/total": loss_gen_all, "learning_rate": lr, "grad_norm_g": grad_norm_g}
                 utils.summarize(
                     writer=writer,
                     global_step=global_step,
-                    scalars=scalar_dict)
+                    scalars=scalar_dict,
+                )
         global_step += 1
     if epoch % hps.train.save_every_epoch == 0 and rank == 0:
@@ -297,9 +333,7 @@ def train_and_evaluate(
                 optim_g,
                 hps.train.learning_rate,
                 epoch,
-                os.path.join(
-                    save_root, "G_{}.pth".format(global_step)
-                ),
+                os.path.join(save_root, "G_{}.pth".format(global_step)),
             )
         else:
             utils.save_checkpoint(
@@ -307,21 +341,19 @@ def train_and_evaluate(
                 optim_g,
                 hps.train.learning_rate,
                 epoch,
-                os.path.join(
-                    save_root, "G_{}.pth".format(233333333333)
-                ),
+                os.path.join(save_root, "G_{}.pth".format(233333333333)),
             )
     if rank == 0 and hps.train.if_save_every_weights == True:
         if hasattr(net_g, "module"):
             ckpt = net_g.module.state_dict()
         else:
             ckpt = net_g.state_dict()
-        sim_ckpt=od()
+        sim_ckpt = od()
         for key in ckpt:
             # if "cfm"not in key:
             # print(key)
             if key not in no_grad_names:
-                sim_ckpt[key]=ckpt[key].half().cpu()
+                sim_ckpt[key] = ckpt[key].half().cpu()
         logger.info(
             "saving ckpt %s_e%s:%s"
             % (
@@ -329,10 +361,11 @@ def train_and_evaluate(
                 epoch,
                 savee(
                     sim_ckpt,
-                    hps.name + "_e%s_s%s_l%s" % (epoch, global_step,lora_rank),
+                    hps.name + "_e%s_s%s_l%s" % (epoch, global_step, lora_rank),
                     epoch,
                     global_step,
-                    hps,lora_rank=lora_rank
+                    hps,
+                    lora_rank=lora_rank,
                 ),
             )
         )
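
The last two hunks build a slimmed checkpoint: only parameters that remained trainable under LoRA (those not collected in no_grad_names) are kept, cast to fp16 on the CPU, and handed to the repository's savee helper. A minimal, self-contained sketch of that filtering step, using a hypothetical helper name (export_trainable_half) and leaving out the savee wrapper itself:

from collections import OrderedDict

import torch


def export_trainable_half(model: torch.nn.Module) -> "OrderedDict[str, torch.Tensor]":
    # Unwrap DDP if present, mirroring the hasattr(net_g, "module") check above.
    state = model.module.state_dict() if hasattr(model, "module") else model.state_dict()
    # Names of frozen parameters, with the DDP "module." prefix stripped.
    frozen = {n.replace("module.", "") for n, p in model.named_parameters() if not p.requires_grad}
    out = OrderedDict()
    for key, tensor in state.items():
        if key not in frozen:  # keep only weights that were actually trained, e.g. the LoRA adapters
            out[key] = tensor.half().cpu()
    return out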