export_torch_script.py support v2Pro & v2ProPlus
This commit is contained in:
@@ -762,6 +762,7 @@ class CodePredictor(nn.Module):
|
||||
|
||||
return pred_codes.transpose(0, 1)
|
||||
|
||||
v2pro_set={"v2Pro","v2ProPlus"}
|
||||
|
||||
class SynthesizerTrn(nn.Module):
|
||||
"""
|
||||
@@ -867,20 +868,33 @@ class SynthesizerTrn(nn.Module):
|
||||
# self.enc_p.text_embedding.requires_grad_(False)
|
||||
# self.enc_p.encoder_text.requires_grad_(False)
|
||||
# self.enc_p.mrte.requires_grad_(False)
|
||||
self.is_v2pro=self.version in v2pro_set
|
||||
if self.is_v2pro:
|
||||
self.sv_emb = nn.Linear(20480, gin_channels)
|
||||
self.ge_to512 = nn.Linear(gin_channels, 512)
|
||||
self.prelu = nn.PReLU(num_parameters=gin_channels)
|
||||
|
||||
def forward(self, codes, text, refer, noise_scale=0.5, speed=1):
|
||||
def forward(self, codes, text, refer, noise_scale=0.5, speed=1, sv_emb=None):
|
||||
refer_mask = torch.ones_like(refer[:1, :1, :])
|
||||
if self.version == "v1":
|
||||
ge = self.ref_enc(refer * refer_mask, refer_mask)
|
||||
else:
|
||||
ge = self.ref_enc(refer[:, :704] * refer_mask, refer_mask)
|
||||
if self.is_v2pro:
|
||||
sv_emb = self.sv_emb(sv_emb)
|
||||
ge += sv_emb.unsqueeze(-1)
|
||||
ge = self.prelu(ge)
|
||||
|
||||
quantized = self.quantizer.decode(codes)
|
||||
if self.semantic_frame_rate == "25hz":
|
||||
dquantized = torch.cat([quantized, quantized]).permute(1, 2, 0)
|
||||
quantized = dquantized.contiguous().view(1, self.ssl_dim, -1)
|
||||
|
||||
x, m_p, logs_p, y_mask = self.enc_p(quantized, text, ge, speed)
|
||||
if self.is_v2pro:
|
||||
ge_ = self.ge_to512(ge.transpose(2,1)).transpose(2,1)
|
||||
x, m_p, logs_p, y_mask = self.enc_p(quantized, text, ge_, speed)
|
||||
else:
|
||||
x, m_p, logs_p, y_mask = self.enc_p(quantized, text, ge, speed)
|
||||
|
||||
z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
|
||||
|
||||
|
||||
Reference in New Issue
Block a user