Fix onnx_export to support v2 (#1604)

This commit is contained in:
zzz
2024-09-13 11:27:22 +08:00
committed by GitHub
parent 570da092c9
commit 0c000191b3
2 changed files with 58 additions and 31 deletions

View File

@@ -13,7 +13,9 @@ from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
from module.commons import init_weights, get_padding
from module.mrte_model import MRTE
from module.quantize import ResidualVectorQuantizer
from text import symbols
# from text import symbols
from text import symbols as symbols_v1
from text import symbols2 as symbols_v2
from torch.cuda.amp import autocast
@@ -182,6 +184,7 @@ class TextEncoder(nn.Module):
kernel_size,
p_dropout,
latent_channels=192,
version="v2",
):
super().__init__()
self.out_channels = out_channels
@@ -192,6 +195,7 @@ class TextEncoder(nn.Module):
self.kernel_size = kernel_size
self.p_dropout = p_dropout
self.latent_channels = latent_channels
self.version = version
self.ssl_proj = nn.Conv1d(768, hidden_channels, 1)
@@ -207,6 +211,11 @@ class TextEncoder(nn.Module):
self.encoder_text = attentions.Encoder(
hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
)
if self.version == "v1":
symbols = symbols_v1.symbols
else:
symbols = symbols_v2.symbols
self.text_embedding = nn.Embedding(len(symbols), hidden_channels)
self.mrte = MRTE()
@@ -817,6 +826,7 @@ class SynthesizerTrn(nn.Module):
use_sdp=True,
semantic_frame_rate=None,
freeze_quantizer=None,
version="v2",
**kwargs
):
super().__init__()
@@ -837,6 +847,7 @@ class SynthesizerTrn(nn.Module):
self.segment_size = segment_size
self.n_speakers = n_speakers
self.gin_channels = gin_channels
self.version = version
self.use_sdp = use_sdp
self.enc_p = TextEncoder(
@@ -847,6 +858,7 @@ class SynthesizerTrn(nn.Module):
n_layers,
kernel_size,
p_dropout,
version=version,
)
self.dec = Generator(
inter_channels,
@@ -871,9 +883,11 @@ class SynthesizerTrn(nn.Module):
inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels
)
self.ref_enc = modules.MelStyleEncoder(
spec_channels, style_vector_dim=gin_channels
)
# self.version=os.environ.get("version","v1")
if self.version == "v1":
self.ref_enc = modules.MelStyleEncoder(spec_channels, style_vector_dim=gin_channels)
else:
self.ref_enc = modules.MelStyleEncoder(704, style_vector_dim=gin_channels)
ssl_dim = 768
self.ssl_dim = ssl_dim
@@ -894,7 +908,10 @@ class SynthesizerTrn(nn.Module):
def forward(self, codes, text, refer):
refer_mask = torch.ones_like(refer[:1,:1,:])
ge = self.ref_enc(refer * refer_mask, refer_mask)
if (self.version == "v1"):
ge = self.ref_enc(refer * refer_mask, refer_mask)
else:
ge = self.ref_enc(refer[:, :704] * refer_mask, refer_mask)
quantized = self.quantizer.decode(codes)
if self.semantic_frame_rate == "25hz":