Refactor: Format Code with Ruff and Update Deprecated G2PW Link (#2255)

* ruff check --fix

* ruff format --line-length 120 --target-version py39

* Change the link for G2PW Model

* Update PyTorch version and Colab
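For reference, the two ruff passes above can be reproduced locally. A minimal sketch (not part of this commit; the helper name and target path are illustrative) that invokes the same commands via Python's `subprocess`:

```python
# Minimal sketch: re-run the two ruff passes used in this commit.
# The target path ("." = repository root) is an assumption; adjust as needed.
import subprocess


def run_ruff(path: str = ".") -> None:
    # Lint pass with autofixes (first step of this commit).
    subprocess.run(["ruff", "check", "--fix", path], check=True)
    # Formatting pass with the options used in this commit.
    subprocess.run(
        ["ruff", "format", "--line-length", "120", "--target-version", "py39", path],
        check=True,
    )


if __name__ == "__main__":
    run_ruff()
```

Keeping the line length and target version on the command line (or in the `[tool.ruff]` section of `pyproject.toml`) helps later formatting runs stay consistent with this commit.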
Author: XXXXRT666
Date: 2025-04-07 09:42:47 +01:00
Committed by: GitHub
Parent: 9da7e17efe
Commit: 53cac93589
132 changed files with 8185 additions and 6648 deletions


@@ -1,36 +1,41 @@
 import warnings
 warnings.filterwarnings("ignore")
-import utils, os
+import os
+import utils
 hps = utils.get_hparams(stage=2)
 os.environ["CUDA_VISIBLE_DEVICES"] = hps.train.gpu_numbers.replace("-", ",")
+import logging
 import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+from torch.cuda.amp import GradScaler, autocast
 from torch.nn import functional as F
+from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.data import DataLoader
 from torch.utils.tensorboard import SummaryWriter
-import torch.multiprocessing as mp
-import torch.distributed as dist, traceback
-from torch.nn.parallel import DistributedDataParallel as DDP
-from torch.cuda.amp import autocast, GradScaler
 from tqdm import tqdm
-import logging, traceback
 logging.getLogger("matplotlib").setLevel(logging.INFO)
 logging.getLogger("h5py").setLevel(logging.INFO)
 logging.getLogger("numba").setLevel(logging.INFO)
 from random import randint
-from module import commons
+from module import commons
 from module.data_utils import (
-    TextAudioSpeakerLoader,
-    TextAudioSpeakerCollate,
     DistributedBucketSampler,
+    TextAudioSpeakerCollate,
+    TextAudioSpeakerLoader,
 )
-from module.models import (
-    SynthesizerTrn,
-    MultiPeriodDiscriminator,
-)
-from module.losses import generator_loss, discriminator_loss, feature_loss, kl_loss
+from module.losses import discriminator_loss, feature_loss, generator_loss, kl_loss
 from module.mel_processing import mel_spectrogram_torch, spec_to_mel_torch
+from module.models import (
+    MultiPeriodDiscriminator,
+    SynthesizerTrn,
+)
 from process_ckpt import savee
 torch.backends.cudnn.benchmark = False
@@ -46,7 +51,6 @@ device = "cpu"  # non-CUDA devices, to be added once MPS support is optimized
 def main():
     if torch.cuda.is_available():
         n_gpus = torch.cuda.device_count()
     else:
@@ -74,7 +78,7 @@ def run(rank, n_gpus, hps):
         writer_eval = SummaryWriter(log_dir=os.path.join(hps.s2_ckpt_dir, "eval"))
     dist.init_process_group(
-        backend = "gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl",
+        backend="gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl",
         init_method="env://?use_libuv=False",
         world_size=n_gpus,
         rank=rank,
@@ -128,19 +132,27 @@ def run(rank, n_gpus, hps):
     # batch_size=1, pin_memory=True,
     # drop_last=False, collate_fn=collate_fn)
-    net_g = SynthesizerTrn(
-        hps.data.filter_length // 2 + 1,
-        hps.train.segment_size // hps.data.hop_length,
-        n_speakers=hps.data.n_speakers,
-        **hps.model,
-    ).cuda(rank) if torch.cuda.is_available() else SynthesizerTrn(
-        hps.data.filter_length // 2 + 1,
-        hps.train.segment_size // hps.data.hop_length,
-        n_speakers=hps.data.n_speakers,
-        **hps.model,
-    ).to(device)
+    net_g = (
+        SynthesizerTrn(
+            hps.data.filter_length // 2 + 1,
+            hps.train.segment_size // hps.data.hop_length,
+            n_speakers=hps.data.n_speakers,
+            **hps.model,
+        ).cuda(rank)
+        if torch.cuda.is_available()
+        else SynthesizerTrn(
+            hps.data.filter_length // 2 + 1,
+            hps.train.segment_size // hps.data.hop_length,
+            n_speakers=hps.data.n_speakers,
+            **hps.model,
+        ).to(device)
+    )
-    net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm).cuda(rank) if torch.cuda.is_available() else MultiPeriodDiscriminator(hps.model.use_spectral_norm).to(device)
+    net_d = (
+        MultiPeriodDiscriminator(hps.model.use_spectral_norm).cuda(rank)
+        if torch.cuda.is_available()
+        else MultiPeriodDiscriminator(hps.model.use_spectral_norm).to(device)
+    )
     for name, param in net_g.named_parameters():
         if not param.requires_grad:
             print(name, "not requires_grad")
@@ -193,7 +205,7 @@ def run(rank, n_gpus, hps):
     try:  # auto-resume if checkpoints can be loaded
         _, _, _, epoch_str = utils.load_checkpoint(
-            utils.latest_checkpoint_path("%s/logs_s2_%s" % (hps.data.exp_dir,hps.model.version), "D_*.pth"),
+            utils.latest_checkpoint_path("%s/logs_s2_%s" % (hps.data.exp_dir, hps.model.version), "D_*.pth"),
             net_d,
             optim_d,
         )  # loading D is usually fine
@@ -201,11 +213,11 @@ def run(rank, n_gpus, hps):
             logger.info("loaded D")
         # _, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g, optim_g,load_opt=0)
         _, _, _, epoch_str = utils.load_checkpoint(
-            utils.latest_checkpoint_path("%s/logs_s2_%s" % (hps.data.exp_dir,hps.model.version), "G_*.pth"),
+            utils.latest_checkpoint_path("%s/logs_s2_%s" % (hps.data.exp_dir, hps.model.version), "G_*.pth"),
             net_g,
             optim_g,
         )
-        epoch_str+=1
+        epoch_str += 1
         global_step = (epoch_str - 1) * len(train_loader)
         # epoch_str = 1
         # global_step = 0
@@ -213,37 +225,55 @@ def run(rank, n_gpus, hps):
         # traceback.print_exc()
         epoch_str = 1
         global_step = 0
-        if hps.train.pretrained_s2G != ""and hps.train.pretrained_s2G != None and os.path.exists(hps.train.pretrained_s2G):
+        if (
+            hps.train.pretrained_s2G != ""
+            and hps.train.pretrained_s2G != None
+            and os.path.exists(hps.train.pretrained_s2G)
+        ):
             if rank == 0:
                 logger.info("loaded pretrained %s" % hps.train.pretrained_s2G)
-            print("loaded pretrained %s" % hps.train.pretrained_s2G,
+            print(
+                "loaded pretrained %s" % hps.train.pretrained_s2G,
                 net_g.module.load_state_dict(
                     torch.load(hps.train.pretrained_s2G, map_location="cpu")["weight"],
                     strict=False,
-                ) if torch.cuda.is_available() else net_g.load_state_dict(
+                )
+                if torch.cuda.is_available()
+                else net_g.load_state_dict(
                     torch.load(hps.train.pretrained_s2G, map_location="cpu")["weight"],
                     strict=False,
-                )
+                ),
             )  ## test: do not load the optimizer
-        if hps.train.pretrained_s2D != ""and hps.train.pretrained_s2D != None and os.path.exists(hps.train.pretrained_s2D):
+        if (
+            hps.train.pretrained_s2D != ""
+            and hps.train.pretrained_s2D != None
+            and os.path.exists(hps.train.pretrained_s2D)
+        ):
             if rank == 0:
                 logger.info("loaded pretrained %s" % hps.train.pretrained_s2D)
-            print("loaded pretrained %s" % hps.train.pretrained_s2D,
+            print(
+                "loaded pretrained %s" % hps.train.pretrained_s2D,
                 net_d.module.load_state_dict(
-                    torch.load(hps.train.pretrained_s2D, map_location="cpu")["weight"]
-                ) if torch.cuda.is_available() else net_d.load_state_dict(
-                    torch.load(hps.train.pretrained_s2D, map_location="cpu")["weight"]
+                    torch.load(hps.train.pretrained_s2D, map_location="cpu")["weight"],
                 )
+                if torch.cuda.is_available()
+                else net_d.load_state_dict(
+                    torch.load(hps.train.pretrained_s2D, map_location="cpu")["weight"],
+                ),
             )
     # scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2)
     # scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2)
     scheduler_g = torch.optim.lr_scheduler.ExponentialLR(
-        optim_g, gamma=hps.train.lr_decay, last_epoch=-1
+        optim_g,
+        gamma=hps.train.lr_decay,
+        last_epoch=-1,
     )
     scheduler_d = torch.optim.lr_scheduler.ExponentialLR(
-        optim_d, gamma=hps.train.lr_decay, last_epoch=-1
+        optim_d,
+        gamma=hps.train.lr_decay,
+        last_epoch=-1,
     )
     for _ in range(epoch_str):
         scheduler_g.step()
@@ -285,9 +315,7 @@ def run(rank, n_gpus, hps):
print("training done")
def train_and_evaluate(
rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers
):
def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers):
net_g, net_d = nets
optim_g, optim_d = optims
# scheduler_g, scheduler_d = schedulers
@@ -311,17 +339,38 @@ def train_and_evaluate(
         text_lengths,
     ) in enumerate(tqdm(train_loader)):
         if torch.cuda.is_available():
-            spec, spec_lengths = spec.cuda(rank, non_blocking=True), spec_lengths.cuda(
-                rank, non_blocking=True
+            spec, spec_lengths = (
+                spec.cuda(
+                    rank,
+                    non_blocking=True,
+                ),
+                spec_lengths.cuda(
+                    rank,
+                    non_blocking=True,
+                ),
             )
-            y, y_lengths = y.cuda(rank, non_blocking=True), y_lengths.cuda(
-                rank, non_blocking=True
+            y, y_lengths = (
+                y.cuda(
+                    rank,
+                    non_blocking=True,
+                ),
+                y_lengths.cuda(
+                    rank,
+                    non_blocking=True,
+                ),
             )
             ssl = ssl.cuda(rank, non_blocking=True)
             ssl.requires_grad = False
             # ssl_lengths = ssl_lengths.cuda(rank, non_blocking=True)
-            text, text_lengths = text.cuda(rank, non_blocking=True), text_lengths.cuda(
-                rank, non_blocking=True
+            text, text_lengths = (
+                text.cuda(
+                    rank,
+                    non_blocking=True,
+                ),
+                text_lengths.cuda(
+                    rank,
+                    non_blocking=True,
+                ),
             )
         else:
             spec, spec_lengths = spec.to(device), spec_lengths.to(device)
@@ -350,9 +399,7 @@ def train_and_evaluate(
                 hps.data.mel_fmin,
                 hps.data.mel_fmax,
             )
-            y_mel = commons.slice_segments(
-                mel, ids_slice, hps.train.segment_size // hps.data.hop_length
-            )
+            y_mel = commons.slice_segments(mel, ids_slice, hps.train.segment_size // hps.data.hop_length)
             y_hat_mel = mel_spectrogram_torch(
                 y_hat.squeeze(1),
                 hps.data.filter_length,
@@ -364,15 +411,14 @@ def train_and_evaluate(
                 hps.data.mel_fmax,
             )
-            y = commons.slice_segments(
-                y, ids_slice * hps.data.hop_length, hps.train.segment_size
-            ) # slice
+            y = commons.slice_segments(y, ids_slice * hps.data.hop_length, hps.train.segment_size)  # slice
             # Discriminator
             y_d_hat_r, y_d_hat_g, _, _ = net_d(y, y_hat.detach())
             with autocast(enabled=False):
                 loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(
-                    y_d_hat_r, y_d_hat_g
+                    y_d_hat_r,
+                    y_d_hat_g,
                 )
                 loss_disc_all = loss_disc
         optim_d.zero_grad()
@@ -405,7 +451,8 @@ def train_and_evaluate(
                 losses = [loss_disc, loss_gen, loss_fm, loss_mel, kl_ssl, loss_kl]
                 logger.info(
                     "Train Epoch: {} [{:.0f}%]".format(
-                        epoch, 100.0 * batch_idx / len(train_loader)
+                        epoch,
+                        100.0 * batch_idx / len(train_loader),
                     )
                 )
                 logger.info([x.item() for x in losses] + [global_step, lr])
@@ -429,25 +476,37 @@ def train_and_evaluate(
                 # scalar_dict.update({"loss/g/{}".format(i): v for i, v in enumerate(losses_gen)})
                 # scalar_dict.update({"loss/d_r/{}".format(i): v for i, v in enumerate(losses_disc_r)})
                 # scalar_dict.update({"loss/d_g/{}".format(i): v for i, v in enumerate(losses_disc_g)})
-                image_dict=None
-                try:###Some people installed the wrong version of matplotlib.
+                image_dict = None
+                try:  ###Some people installed the wrong version of matplotlib.
                     image_dict = {
                         "slice/mel_org": utils.plot_spectrogram_to_numpy(
-                            y_mel[0].data.cpu().numpy()
+                            y_mel[0].data.cpu().numpy(),
                         ),
                         "slice/mel_gen": utils.plot_spectrogram_to_numpy(
-                            y_hat_mel[0].data.cpu().numpy()
+                            y_hat_mel[0].data.cpu().numpy(),
                         ),
                         "all/mel": utils.plot_spectrogram_to_numpy(
-                            mel[0].data.cpu().numpy()
+                            mel[0].data.cpu().numpy(),
                         ),
                         "all/stats_ssl": utils.plot_spectrogram_to_numpy(
-                            stats_ssl[0].data.cpu().numpy()
+                            stats_ssl[0].data.cpu().numpy(),
                         ),
                     }
-                except:pass
-                if image_dict:utils.summarize(writer=writer,global_step=global_step,images=image_dict,scalars=scalar_dict,)
-                else:utils.summarize(writer=writer,global_step=global_step,scalars=scalar_dict,)
+                except:
+                    pass
+                if image_dict:
+                    utils.summarize(
+                        writer=writer,
+                        global_step=global_step,
+                        images=image_dict,
+                        scalars=scalar_dict,
+                    )
+                else:
+                    utils.summarize(
+                        writer=writer,
+                        global_step=global_step,
+                        scalars=scalar_dict,
+                    )
         global_step += 1
         if epoch % hps.train.save_every_epoch == 0 and rank == 0:
             if hps.train.if_save_latest == 0:
@@ -457,7 +516,8 @@ def train_and_evaluate(
                     hps.train.learning_rate,
                     epoch,
                     os.path.join(
-                        "%s/logs_s2_%s" % (hps.data.exp_dir,hps.model.version), "G_{}.pth".format(global_step)
+                        "%s/logs_s2_%s" % (hps.data.exp_dir, hps.model.version),
+                        "G_{}.pth".format(global_step),
                     ),
                 )
                 utils.save_checkpoint(
@@ -466,7 +526,8 @@ def train_and_evaluate(
                     hps.train.learning_rate,
                     epoch,
                     os.path.join(
-                        "%s/logs_s2_%s" % (hps.data.exp_dir,hps.model.version), "D_{}.pth".format(global_step)
+                        "%s/logs_s2_%s" % (hps.data.exp_dir, hps.model.version),
+                        "D_{}.pth".format(global_step),
                     ),
                 )
             else:
@@ -476,7 +537,8 @@ def train_and_evaluate(
                     hps.train.learning_rate,
                     epoch,
                     os.path.join(
-                        "%s/logs_s2_%s" % (hps.data.exp_dir,hps.model.version), "G_{}.pth".format(233333333333)
+                        "%s/logs_s2_%s" % (hps.data.exp_dir, hps.model.version),
+                        "G_{}.pth".format(233333333333),
                     ),
                 )
                 utils.save_checkpoint(
@@ -485,7 +547,8 @@ def train_and_evaluate(
                     hps.train.learning_rate,
                     epoch,
                     os.path.join(
-                        "%s/logs_s2_%s" % (hps.data.exp_dir,hps.model.version), "D_{}.pth".format(233333333333)
+                        "%s/logs_s2_%s" % (hps.data.exp_dir, hps.model.version),
+                        "D_{}.pth".format(233333333333),
                     ),
                 )
         if rank == 0 and hps.train.if_save_every_weights == True:
@@ -540,10 +603,24 @@ def evaluate(hps, generator, eval_loader, writer_eval):
                 ssl = ssl.to(device)
                 text, text_lengths = text.to(device), text_lengths.to(device)
             for test in [0, 1]:
-                y_hat, mask, *_ = generator.module.infer(
-                    ssl, spec, spec_lengths, text, text_lengths, test=test
-                ) if torch.cuda.is_available() else generator.infer(
-                    ssl, spec, spec_lengths, text, text_lengths, test=test
+                y_hat, mask, *_ = (
+                    generator.module.infer(
+                        ssl,
+                        spec,
+                        spec_lengths,
+                        text,
+                        text_lengths,
+                        test=test,
+                    )
+                    if torch.cuda.is_available()
+                    else generator.infer(
+                        ssl,
+                        spec,
+                        spec_lengths,
+                        text,
+                        text_lengths,
+                        test=test,
+                    )
                 )
                 y_hat_lengths = mask.sum([1, 2]).long() * hps.data.hop_length
@@ -568,19 +645,19 @@ def evaluate(hps, generator, eval_loader, writer_eval):
                 image_dict.update(
                     {
                         f"gen/mel_{batch_idx}_{test}": utils.plot_spectrogram_to_numpy(
-                            y_hat_mel[0].cpu().numpy()
-                        )
+                            y_hat_mel[0].cpu().numpy(),
+                        ),
                     }
                 )
                 audio_dict.update(
-                    {f"gen/audio_{batch_idx}_{test}": y_hat[0, :, : y_hat_lengths[0]]}
+                    {
+                        f"gen/audio_{batch_idx}_{test}": y_hat[0, :, : y_hat_lengths[0]],
+                    },
                 )
             image_dict.update(
                 {
-                    f"gt/mel_{batch_idx}": utils.plot_spectrogram_to_numpy(
-                        mel[0].cpu().numpy()
-                    )
-                }
+                    f"gt/mel_{batch_idx}": utils.plot_spectrogram_to_numpy(mel[0].cpu().numpy()),
+                },
             )
             audio_dict.update({f"gt/audio_{batch_idx}": y[0, :, : y_lengths[0]]})