export_torch_script.py support v2Pro & v2ProPlus

This commit is contained in:
csh
2025-06-12 21:53:14 +08:00
parent ed89a02337
commit 5c91e66d2e
2 changed files with 275 additions and 29 deletions

View File

@@ -762,6 +762,7 @@ class CodePredictor(nn.Module):
return pred_codes.transpose(0, 1)
v2pro_set={"v2Pro","v2ProPlus"}
class SynthesizerTrn(nn.Module):
"""
@@ -867,20 +868,33 @@ class SynthesizerTrn(nn.Module):
# self.enc_p.text_embedding.requires_grad_(False)
# self.enc_p.encoder_text.requires_grad_(False)
# self.enc_p.mrte.requires_grad_(False)
self.is_v2pro=self.version in v2pro_set
if self.is_v2pro:
self.sv_emb = nn.Linear(20480, gin_channels)
self.ge_to512 = nn.Linear(gin_channels, 512)
self.prelu = nn.PReLU(num_parameters=gin_channels)
def forward(self, codes, text, refer, noise_scale=0.5, speed=1):
def forward(self, codes, text, refer, noise_scale=0.5, speed=1, sv_emb=None):
refer_mask = torch.ones_like(refer[:1, :1, :])
if self.version == "v1":
ge = self.ref_enc(refer * refer_mask, refer_mask)
else:
ge = self.ref_enc(refer[:, :704] * refer_mask, refer_mask)
if self.is_v2pro:
sv_emb = self.sv_emb(sv_emb)
ge += sv_emb.unsqueeze(-1)
ge = self.prelu(ge)
quantized = self.quantizer.decode(codes)
if self.semantic_frame_rate == "25hz":
dquantized = torch.cat([quantized, quantized]).permute(1, 2, 0)
quantized = dquantized.contiguous().view(1, self.ssl_dim, -1)
x, m_p, logs_p, y_mask = self.enc_p(quantized, text, ge, speed)
if self.is_v2pro:
ge_ = self.ge_to512(ge.transpose(2,1)).transpose(2,1)
x, m_p, logs_p, y_mask = self.enc_p(quantized, text, ge_, speed)
else:
x, m_p, logs_p, y_mask = self.enc_p(quantized, text, ge, speed)
z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale