support sovits v2Pro v2ProPlus

support sovits v2Pro v2ProPlus
This commit is contained in:
RVC-Boss
2025-06-04 15:18:55 +08:00
committed by GitHub
parent 3f46359652
commit 0621259549
2 changed files with 57 additions and 18 deletions

View File

@@ -586,11 +586,12 @@ class DiscriminatorS(torch.nn.Module):
return x, fmap
v2pro_set={"v2Pro","v2ProPlus"}
class MultiPeriodDiscriminator(torch.nn.Module):
def __init__(self, use_spectral_norm=False):
def __init__(self, use_spectral_norm=False,version=None):
super(MultiPeriodDiscriminator, self).__init__()
periods = [2, 3, 5, 7, 11]
if version in v2pro_set:periods = [2, 3, 5, 7, 11,17,23]
else:periods = [2, 3, 5, 7, 11]
discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
discs = discs + [DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods]
@@ -786,7 +787,6 @@ class CodePredictor(nn.Module):
return pred_codes.transpose(0, 1)
class SynthesizerTrn(nn.Module):
"""
Synthesizer for Training
@@ -886,12 +886,23 @@ class SynthesizerTrn(nn.Module):
self.quantizer = ResidualVectorQuantizer(dimension=ssl_dim, n_q=1, bins=1024)
self.freeze_quantizer = freeze_quantizer
def forward(self, ssl, y, y_lengths, text, text_lengths):
self.is_v2pro=self.version in v2pro_set
if self.is_v2pro:
self.sv_emb = nn.Linear(20480, gin_channels)
self.ge_to512 = nn.Linear(gin_channels, 512)
self.prelu = nn.PReLU(num_parameters=gin_channels)
def forward(self, ssl, y, y_lengths, text, text_lengths,sv_emb=None):
y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(y.dtype)
if self.version == "v1":
ge = self.ref_enc(y * y_mask, y_mask)
else:
ge = self.ref_enc(y[:, :704] * y_mask, y_mask)
if self.is_v2pro:
sv_emb = self.sv_emb(sv_emb) # B*20480->B*512
ge += sv_emb.unsqueeze(-1)
ge = self.prelu(ge)
ge512 = self.ge_to512(ge.transpose(2, 1)).transpose(2, 1)
with autocast(enabled=False):
maybe_no_grad = torch.no_grad() if self.freeze_quantizer else contextlib.nullcontext()
with maybe_no_grad:
@@ -904,7 +915,7 @@ class SynthesizerTrn(nn.Module):
if self.semantic_frame_rate == "25hz":
quantized = F.interpolate(quantized, size=int(quantized.shape[-1] * 2), mode="nearest")
x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge)
x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge512 if self.is_v2pro else ge)
z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=ge)
z_p = self.flow(z, y_mask, g=ge)
@@ -941,8 +952,8 @@ class SynthesizerTrn(nn.Module):
return o, y_mask, (z, z_p, m_p, logs_p)
@torch.no_grad()
def decode(self, codes, text, refer, noise_scale=0.5, speed=1):
def get_ge(refer):
def decode(self, codes, text, refer,noise_scale=0.5, speed=1, sv_emb=None):
def get_ge(refer, sv_emb):
ge = None
if refer is not None:
refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device)
@@ -951,16 +962,20 @@ class SynthesizerTrn(nn.Module):
ge = self.ref_enc(refer * refer_mask, refer_mask)
else:
ge = self.ref_enc(refer[:, :704] * refer_mask, refer_mask)
if self.is_v2pro:
sv_emb = self.sv_emb(sv_emb) # B*20480->B*512
ge += sv_emb.unsqueeze(-1)
ge = self.prelu(ge)
return ge
if type(refer) == list:
ges = []
for _refer in refer:
ge = get_ge(_refer)
for idx,_refer in enumerate(refer):
ge = get_ge(_refer, sv_emb[idx]if self.is_v2pro else None)
ges.append(ge)
ge = torch.stack(ges, 0).mean(0)
else:
ge = get_ge(refer)
ge = get_ge(refer, sv_emb)
y_lengths = torch.LongTensor([codes.size(2) * 2]).to(codes.device)
text_lengths = torch.LongTensor([text.size(-1)]).to(text.device)
@@ -968,7 +983,7 @@ class SynthesizerTrn(nn.Module):
quantized = self.quantizer.decode(codes)
if self.semantic_frame_rate == "25hz":
quantized = F.interpolate(quantized, size=int(quantized.shape[-1] * 2), mode="nearest")
x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge, speed)
x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, self.ge_to512(ge.transpose(2,1)).transpose(2,1)if self.is_v2pro else ge, speed)
z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
z = self.flow(z_p, y_mask, g=ge, reverse=True)