support sovits v2Pro v2ProPlus
@@ -36,7 +36,7 @@ from module.models import (
     MultiPeriodDiscriminator,
     SynthesizerTrn,
 )
-from process_ckpt import savee
+from process_ckpt import savee,my_save2

 torch.backends.cudnn.benchmark = False
 torch.backends.cudnn.deterministic = False
@@ -87,38 +87,19 @@ def run(rank, n_gpus, hps):
     if torch.cuda.is_available():
         torch.cuda.set_device(rank)

-    train_dataset = TextAudioSpeakerLoader(hps.data) ########
+    train_dataset = TextAudioSpeakerLoader(hps.data,version=hps.model.version)
     train_sampler = DistributedBucketSampler(
         train_dataset,
         hps.train.batch_size,
-        [
-            32,
-            300,
-            400,
-            500,
-            600,
-            700,
-            800,
-            900,
-            1000,
-            1100,
-            1200,
-            1300,
-            1400,
-            1500,
-            1600,
-            1700,
-            1800,
-            1900,
-        ],
+        [32,300,400,500,600,700,800,900,1000,1100,1200,1300,1400,1500,1600,1700,1800,1900,],
         num_replicas=n_gpus,
         rank=rank,
         shuffle=True,
     )
-    collate_fn = TextAudioSpeakerCollate()
+    collate_fn = TextAudioSpeakerCollate(version=hps.model.version)
     train_loader = DataLoader(
         train_dataset,
-        num_workers=6,
+        num_workers=5,
         shuffle=False,
         pin_memory=True,
         collate_fn=collate_fn,
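
Note on the hunk above: TextAudioSpeakerLoader and TextAudioSpeakerCollate now take hps.model.version, and the training loop further down expects v2Pro/v2ProPlus batches to carry one extra tensor, a speaker-verification embedding (sv_emb). A minimal sketch of that collate pattern, assuming per-sample tensors that are already padded to equal sizes; the real TextAudioSpeakerCollate in module/data_utils.py may differ:

    import torch

    class VersionAwareCollate:
        """Hypothetical sketch, not the real implementation."""

        def __init__(self, version="v2"):
            self.version = version

        def __call__(self, batch):
            # Assume each sample is a dict of equally sized tensors (assumption).
            keys = ["ssl", "ssl_len", "spec", "spec_len", "wav", "wav_len", "text", "text_len"]
            out = tuple(torch.stack([sample[k] for sample in batch]) for k in keys)
            if self.version in {"v2Pro", "v2ProPlus"}:
                # v2Pro-family batches carry one extra speaker-verification embedding.
                out = out + (torch.stack([sample["sv_emb"] for sample in batch]),)
            return out
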
@@ -149,9 +130,9 @@ def run(rank, n_gpus, hps):
     )

     net_d = (
-        MultiPeriodDiscriminator(hps.model.use_spectral_norm).cuda(rank)
+        MultiPeriodDiscriminator(hps.model.use_spectral_norm,version=hps.model.version).cuda(rank)
         if torch.cuda.is_available()
-        else MultiPeriodDiscriminator(hps.model.use_spectral_norm).to(device)
+        else MultiPeriodDiscriminator(hps.model.use_spectral_norm,version=hps.model.version).to(device)
     )
     for name, param in net_g.named_parameters():
         if not param.requires_grad:
@@ -235,12 +216,12 @@ def run(rank, n_gpus, hps):
             print(
                 "loaded pretrained %s" % hps.train.pretrained_s2G,
                 net_g.module.load_state_dict(
-                    torch.load(hps.train.pretrained_s2G, map_location="cpu")["weight"],
+                    torch.load(hps.train.pretrained_s2G, map_location="cpu", weights_only=False)["weight"],
                     strict=False,
                 )
                 if torch.cuda.is_available()
                 else net_g.load_state_dict(
-                    torch.load(hps.train.pretrained_s2G, map_location="cpu")["weight"],
+                    torch.load(hps.train.pretrained_s2G, map_location="cpu", weights_only=False)["weight"],
                     strict=False,
                 ),
             ) ## test: do not load the optimizer
@@ -254,11 +235,11 @@ def run(rank, n_gpus, hps):
             print(
                 "loaded pretrained %s" % hps.train.pretrained_s2D,
                 net_d.module.load_state_dict(
-                    torch.load(hps.train.pretrained_s2D, map_location="cpu")["weight"],
+                    torch.load(hps.train.pretrained_s2D, map_location="cpu", weights_only=False)["weight"],strict=False
                 )
                 if torch.cuda.is_available()
                 else net_d.load_state_dict(
-                    torch.load(hps.train.pretrained_s2D, map_location="cpu")["weight"],
+                    torch.load(hps.train.pretrained_s2D, map_location="cpu", weights_only=False)["weight"],
                 ),
             )

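
The weights_only=False added in the two pretrained-loading hunks above keeps torch.load behaving as it did before PyTorch 2.6, where the default switched to weights_only=True and loading checkpoints that pickle arbitrary Python objects next to the tensors started to fail. A small self-contained sketch, with a hypothetical checkpoint path and a stand-in module:

    import torch
    from torch import nn

    model = nn.Linear(4, 4)                  # stand-in for net_g / net_d in this sketch
    ckpt_path = "pretrained/s2G.pth"         # hypothetical path, for illustration only
    # weights_only=False allows full unpickling of the checkpoint dict;
    # only use it on checkpoints from a trusted source.
    state = torch.load(ckpt_path, map_location="cpu", weights_only=False)
    model.load_state_dict(state["weight"], strict=False)
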
@@ -328,50 +309,20 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade

     net_g.train()
     net_d.train()
-    for batch_idx, (
-        ssl,
-        ssl_lengths,
-        spec,
-        spec_lengths,
-        y,
-        y_lengths,
-        text,
-        text_lengths,
-    ) in enumerate(tqdm(train_loader)):
+    for batch_idx, data in enumerate(tqdm(train_loader)):
+        if hps.model.version in {"v2Pro","v2ProPlus"}:
+            ssl, ssl_lengths, spec, spec_lengths, y, y_lengths, text, text_lengths,sv_emb=data
+        else:
+            ssl, ssl_lengths, spec, spec_lengths, y, y_lengths, text, text_lengths=data
         if torch.cuda.is_available():
-            spec, spec_lengths = (
-                spec.cuda(
-                    rank,
-                    non_blocking=True,
-                ),
-                spec_lengths.cuda(
-                    rank,
-                    non_blocking=True,
-                ),
-            )
-            y, y_lengths = (
-                y.cuda(
-                    rank,
-                    non_blocking=True,
-                ),
-                y_lengths.cuda(
-                    rank,
-                    non_blocking=True,
-                ),
-            )
+            spec, spec_lengths = (spec.cuda(rank,non_blocking=True,),spec_lengths.cuda(rank,non_blocking=True,),)
+            y, y_lengths = (y.cuda(rank,non_blocking=True,),y_lengths.cuda(rank,non_blocking=True,),)
             ssl = ssl.cuda(rank, non_blocking=True)
             ssl.requires_grad = False
             # ssl_lengths = ssl_lengths.cuda(rank, non_blocking=True)
-            text, text_lengths = (
-                text.cuda(
-                    rank,
-                    non_blocking=True,
-                ),
-                text_lengths.cuda(
-                    rank,
-                    non_blocking=True,
-                ),
-            )
+            text, text_lengths = (text.cuda(rank,non_blocking=True,),text_lengths.cuda(rank,non_blocking=True,),)
+            if hps.model.version in {"v2Pro", "v2ProPlus"}:
+                sv_emb = sv_emb.cuda(rank, non_blocking=True)
         else:
             spec, spec_lengths = spec.to(device), spec_lengths.to(device)
             y, y_lengths = y.to(device), y_lengths.to(device)
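
The loop above now unpacks a version-dependent batch: nine tensors (ending in sv_emb) for v2Pro/v2ProPlus, eight otherwise. The same logic can be written as a small helper so sv_emb is always defined; this is a sketch only, not code from the commit:

    def unpack_batch(data, version):
        """Return nine values, with sv_emb=None for versions that lack it (sketch)."""
        if version in {"v2Pro", "v2ProPlus"}:
            ssl, ssl_len, spec, spec_len, y, y_len, text, text_len, sv_emb = data
        else:
            ssl, ssl_len, spec, spec_len, y, y_len, text, text_len = data
            sv_emb = None
        return ssl, ssl_len, spec, spec_len, y, y_len, text, text_len, sv_emb
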
@@ -379,17 +330,13 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
             ssl.requires_grad = False
             # ssl_lengths = ssl_lengths.cuda(rank, non_blocking=True)
             text, text_lengths = text.to(device), text_lengths.to(device)
-
+            if hps.model.version in {"v2Pro", "v2ProPlus"}:
+                sv_emb = sv_emb.to(device)
         with autocast(enabled=hps.train.fp16_run):
-            (
-                y_hat,
-                kl_ssl,
-                ids_slice,
-                x_mask,
-                z_mask,
-                (z, z_p, m_p, logs_p, m_q, logs_q),
-                stats_ssl,
-            ) = net_g(ssl, spec, spec_lengths, text, text_lengths)
+            if hps.model.version in {"v2Pro", "v2ProPlus"}:
+                (y_hat,kl_ssl,ids_slice,x_mask,z_mask,(z, z_p, m_p, logs_p, m_q, logs_q),stats_ssl) = net_g(ssl, spec, spec_lengths, text, text_lengths,sv_emb)
+            else:
+                (y_hat,kl_ssl,ids_slice,x_mask,z_mask,(z, z_p, m_p, logs_p, m_q, logs_q),stats_ssl,) = net_g(ssl, spec, spec_lengths, text, text_lengths)

             mel = spec_to_mel_torch(
                 spec,
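
In the autocast block above, the v2Pro/v2ProPlus branch passes sv_emb as an extra positional argument to net_g. A hedged sketch of how an optional conditioning input like this is commonly threaded through a forward signature; the real SynthesizerTrn in module/models.py almost certainly differs:

    import torch
    from torch import nn

    class TinySynth(nn.Module):
        """Illustrative stand-in only, not the real SynthesizerTrn."""

        def __init__(self, dim=192, sv_dim=256):  # dimensions chosen arbitrarily
            super().__init__()
            self.proj = nn.Linear(dim, dim)
            self.sv_proj = nn.Linear(sv_dim, dim)

        def forward(self, x, sv_emb=None):
            h = self.proj(x)
            if sv_emb is not None:
                # Fold the speaker-verification embedding in as extra conditioning (assumption).
                h = h + self.sv_proj(sv_emb).unsqueeze(1)
            return h
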
@@ -561,13 +508,7 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
                        % (
                            hps.name,
                            epoch,
-                            savee(
-                                ckpt,
-                                hps.name + "_e%s_s%s" % (epoch, global_step),
-                                epoch,
-                                global_step,
-                                hps,
-                            ),
+                            savee(ckpt,hps.name + "_e%s_s%s" % (epoch, global_step),epoch,global_step,hps,model_version=None if hps.model.version not in {"v2Pro","v2ProPlus"}else hps.model.version),
                        )
                    )
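
The savee call above tags the checkpoint with the model version only for the new variants. Its inline conditional is equivalent to the following standalone sketch:

    def resolve_model_version(version):
        """Tag checkpoints with the version only for v2Pro-family models (sketch)."""
        return version if version in {"v2Pro", "v2ProPlus"} else None

    assert resolve_model_version("v2ProPlus") == "v2ProPlus"
    assert resolve_model_version("v2") is None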