优化 export_torch_script.py (#1720)

* export_torch_script 从命令行获取参数

* export 支持语速设置
This commit is contained in:
zzz
2024-10-26 16:14:39 +08:00
committed by GitHub
parent 5d126f98b2
commit 98cc47699c
2 changed files with 89 additions and 59 deletions

View File

@@ -231,7 +231,7 @@ class TextEncoder(nn.Module):
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
def forward(self, y, text, ge):
def forward(self, y, text, ge, speed=1):
y_mask = torch.ones_like(y[:1,:1,:])
y = self.ssl_proj(y * y_mask) * y_mask
@@ -244,6 +244,9 @@ class TextEncoder(nn.Module):
y = self.mrte(y, y_mask, text, text_mask, ge)
y = self.encoder2(y * y_mask, y_mask)
if(speed!=1):
y = F.interpolate(y, size=int(y.shape[-1] / speed)+1, mode="linear")
y_mask = F.interpolate(y_mask, size=y.shape[-1], mode="nearest")
stats = self.proj(y) * y_mask
m, logs = torch.split(stats, self.out_channels, dim=1)
@@ -887,7 +890,7 @@ class SynthesizerTrn(nn.Module):
# self.enc_p.encoder_text.requires_grad_(False)
# self.enc_p.mrte.requires_grad_(False)
def forward(self, codes, text, refer):
def forward(self, codes, text, refer,noise_scale=0.5, speed=1):
refer_mask = torch.ones_like(refer[:1,:1,:])
if (self.version == "v1"):
ge = self.ref_enc(refer * refer_mask, refer_mask)
@@ -900,10 +903,10 @@ class SynthesizerTrn(nn.Module):
quantized = dquantized.contiguous().view(1, self.ssl_dim, -1)
x, m_p, logs_p, y_mask = self.enc_p(
quantized, text, ge
quantized, text, ge, speed
)
z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p)
z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
z = self.flow(z_p, y_mask, g=ge, reverse=True)