support sovits v2Pro v2ProPlus

support sovits v2Pro v2ProPlus
This commit is contained in:
RVC-Boss
2025-06-04 15:16:47 +08:00
committed by GitHub
parent b7c0c5ca87
commit 921ac6c41a
9 changed files with 284 additions and 585 deletions

View File

@@ -36,7 +36,7 @@ from module.models import (
MultiPeriodDiscriminator,
SynthesizerTrn,
)
from process_ckpt import savee
from process_ckpt import savee,my_save2
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = False
@@ -87,38 +87,19 @@ def run(rank, n_gpus, hps):
if torch.cuda.is_available():
torch.cuda.set_device(rank)
train_dataset = TextAudioSpeakerLoader(hps.data) ########
train_dataset = TextAudioSpeakerLoader(hps.data,version=hps.model.version)
train_sampler = DistributedBucketSampler(
train_dataset,
hps.train.batch_size,
[
32,
300,
400,
500,
600,
700,
800,
900,
1000,
1100,
1200,
1300,
1400,
1500,
1600,
1700,
1800,
1900,
],
[32,300,400,500,600,700,800,900,1000,1100,1200,1300,1400,1500,1600,1700,1800,1900,],
num_replicas=n_gpus,
rank=rank,
shuffle=True,
)
collate_fn = TextAudioSpeakerCollate()
collate_fn = TextAudioSpeakerCollate(version=hps.model.version)
train_loader = DataLoader(
train_dataset,
num_workers=6,
num_workers=5,
shuffle=False,
pin_memory=True,
collate_fn=collate_fn,
@@ -149,9 +130,9 @@ def run(rank, n_gpus, hps):
)
net_d = (
MultiPeriodDiscriminator(hps.model.use_spectral_norm).cuda(rank)
MultiPeriodDiscriminator(hps.model.use_spectral_norm,version=hps.model.version).cuda(rank)
if torch.cuda.is_available()
else MultiPeriodDiscriminator(hps.model.use_spectral_norm).to(device)
else MultiPeriodDiscriminator(hps.model.use_spectral_norm,version=hps.model.version).to(device)
)
for name, param in net_g.named_parameters():
if not param.requires_grad:
@@ -235,12 +216,12 @@ def run(rank, n_gpus, hps):
print(
"loaded pretrained %s" % hps.train.pretrained_s2G,
net_g.module.load_state_dict(
torch.load(hps.train.pretrained_s2G, map_location="cpu")["weight"],
torch.load(hps.train.pretrained_s2G, map_location="cpu", weights_only=False)["weight"],
strict=False,
)
if torch.cuda.is_available()
else net_g.load_state_dict(
torch.load(hps.train.pretrained_s2G, map_location="cpu")["weight"],
torch.load(hps.train.pretrained_s2G, map_location="cpu", weights_only=False)["weight"],
strict=False,
),
) ##测试不加载优化器
@@ -254,11 +235,11 @@ def run(rank, n_gpus, hps):
print(
"loaded pretrained %s" % hps.train.pretrained_s2D,
net_d.module.load_state_dict(
torch.load(hps.train.pretrained_s2D, map_location="cpu")["weight"],
torch.load(hps.train.pretrained_s2D, map_location="cpu", weights_only=False)["weight"],strict=False
)
if torch.cuda.is_available()
else net_d.load_state_dict(
torch.load(hps.train.pretrained_s2D, map_location="cpu")["weight"],
torch.load(hps.train.pretrained_s2D, map_location="cpu", weights_only=False)["weight"],
),
)
@@ -328,50 +309,20 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
net_g.train()
net_d.train()
for batch_idx, (
ssl,
ssl_lengths,
spec,
spec_lengths,
y,
y_lengths,
text,
text_lengths,
) in enumerate(tqdm(train_loader)):
for batch_idx, data in enumerate(tqdm(train_loader)):
if hps.model.version in {"v2Pro","v2ProPlus"}:
ssl, ssl_lengths, spec, spec_lengths, y, y_lengths, text, text_lengths,sv_emb=data
else:
ssl, ssl_lengths, spec, spec_lengths, y, y_lengths, text, text_lengths=data
if torch.cuda.is_available():
spec, spec_lengths = (
spec.cuda(
rank,
non_blocking=True,
),
spec_lengths.cuda(
rank,
non_blocking=True,
),
)
y, y_lengths = (
y.cuda(
rank,
non_blocking=True,
),
y_lengths.cuda(
rank,
non_blocking=True,
),
)
spec, spec_lengths = (spec.cuda(rank,non_blocking=True,),spec_lengths.cuda(rank,non_blocking=True,),)
y, y_lengths = (y.cuda(rank,non_blocking=True,),y_lengths.cuda(rank,non_blocking=True,),)
ssl = ssl.cuda(rank, non_blocking=True)
ssl.requires_grad = False
# ssl_lengths = ssl_lengths.cuda(rank, non_blocking=True)
text, text_lengths = (
text.cuda(
rank,
non_blocking=True,
),
text_lengths.cuda(
rank,
non_blocking=True,
),
)
text, text_lengths = (text.cuda(rank,non_blocking=True,),text_lengths.cuda(rank,non_blocking=True,),)
if hps.model.version in {"v2Pro", "v2ProPlus"}:
sv_emb = sv_emb.cuda(rank, non_blocking=True)
else:
spec, spec_lengths = spec.to(device), spec_lengths.to(device)
y, y_lengths = y.to(device), y_lengths.to(device)
@@ -379,17 +330,13 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
ssl.requires_grad = False
# ssl_lengths = ssl_lengths.cuda(rank, non_blocking=True)
text, text_lengths = text.to(device), text_lengths.to(device)
if hps.model.version in {"v2Pro", "v2ProPlus"}:
sv_emb = sv_emb.to(device)
with autocast(enabled=hps.train.fp16_run):
(
y_hat,
kl_ssl,
ids_slice,
x_mask,
z_mask,
(z, z_p, m_p, logs_p, m_q, logs_q),
stats_ssl,
) = net_g(ssl, spec, spec_lengths, text, text_lengths)
if hps.model.version in {"v2Pro", "v2ProPlus"}:
(y_hat,kl_ssl,ids_slice,x_mask,z_mask,(z, z_p, m_p, logs_p, m_q, logs_q),stats_ssl) = net_g(ssl, spec, spec_lengths, text, text_lengths,sv_emb)
else:
(y_hat,kl_ssl,ids_slice,x_mask,z_mask,(z, z_p, m_p, logs_p, m_q, logs_q),stats_ssl,) = net_g(ssl, spec, spec_lengths, text, text_lengths)
mel = spec_to_mel_torch(
spec,
@@ -561,13 +508,7 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
% (
hps.name,
epoch,
savee(
ckpt,
hps.name + "_e%s_s%s" % (epoch, global_step),
epoch,
global_step,
hps,
),
savee(ckpt,hps.name + "_e%s_s%s" % (epoch, global_step),epoch,global_step,hps,model_version=None if hps.model.version not in {"v2Pro","v2ProPlus"}else hps.model.version),
)
)