Fix onnx_export to support v2 (#1604)
This commit is contained in:
@@ -13,7 +13,9 @@ from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
|
||||
from module.commons import init_weights, get_padding
|
||||
from module.mrte_model import MRTE
|
||||
from module.quantize import ResidualVectorQuantizer
|
||||
from text import symbols
|
||||
# from text import symbols
|
||||
from text import symbols as symbols_v1
|
||||
from text import symbols2 as symbols_v2
|
||||
from torch.cuda.amp import autocast
|
||||
|
||||
|
||||
@@ -182,6 +184,7 @@ class TextEncoder(nn.Module):
|
||||
kernel_size,
|
||||
p_dropout,
|
||||
latent_channels=192,
|
||||
version="v2",
|
||||
):
|
||||
super().__init__()
|
||||
self.out_channels = out_channels
|
||||
@@ -192,6 +195,7 @@ class TextEncoder(nn.Module):
|
||||
self.kernel_size = kernel_size
|
||||
self.p_dropout = p_dropout
|
||||
self.latent_channels = latent_channels
|
||||
self.version = version
|
||||
|
||||
self.ssl_proj = nn.Conv1d(768, hidden_channels, 1)
|
||||
|
||||
@@ -207,6 +211,11 @@ class TextEncoder(nn.Module):
|
||||
self.encoder_text = attentions.Encoder(
|
||||
hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
|
||||
)
|
||||
|
||||
if self.version == "v1":
|
||||
symbols = symbols_v1.symbols
|
||||
else:
|
||||
symbols = symbols_v2.symbols
|
||||
self.text_embedding = nn.Embedding(len(symbols), hidden_channels)
|
||||
|
||||
self.mrte = MRTE()
|
||||
@@ -817,6 +826,7 @@ class SynthesizerTrn(nn.Module):
|
||||
use_sdp=True,
|
||||
semantic_frame_rate=None,
|
||||
freeze_quantizer=None,
|
||||
version="v2",
|
||||
**kwargs
|
||||
):
|
||||
super().__init__()
|
||||
@@ -837,6 +847,7 @@ class SynthesizerTrn(nn.Module):
|
||||
self.segment_size = segment_size
|
||||
self.n_speakers = n_speakers
|
||||
self.gin_channels = gin_channels
|
||||
self.version = version
|
||||
|
||||
self.use_sdp = use_sdp
|
||||
self.enc_p = TextEncoder(
|
||||
@@ -847,6 +858,7 @@ class SynthesizerTrn(nn.Module):
|
||||
n_layers,
|
||||
kernel_size,
|
||||
p_dropout,
|
||||
version=version,
|
||||
)
|
||||
self.dec = Generator(
|
||||
inter_channels,
|
||||
@@ -871,9 +883,11 @@ class SynthesizerTrn(nn.Module):
|
||||
inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels
|
||||
)
|
||||
|
||||
self.ref_enc = modules.MelStyleEncoder(
|
||||
spec_channels, style_vector_dim=gin_channels
|
||||
)
|
||||
# self.version=os.environ.get("version","v1")
|
||||
if self.version == "v1":
|
||||
self.ref_enc = modules.MelStyleEncoder(spec_channels, style_vector_dim=gin_channels)
|
||||
else:
|
||||
self.ref_enc = modules.MelStyleEncoder(704, style_vector_dim=gin_channels)
|
||||
|
||||
ssl_dim = 768
|
||||
self.ssl_dim = ssl_dim
|
||||
@@ -894,7 +908,10 @@ class SynthesizerTrn(nn.Module):
|
||||
|
||||
def forward(self, codes, text, refer):
|
||||
refer_mask = torch.ones_like(refer[:1,:1,:])
|
||||
ge = self.ref_enc(refer * refer_mask, refer_mask)
|
||||
if (self.version == "v1"):
|
||||
ge = self.ref_enc(refer * refer_mask, refer_mask)
|
||||
else:
|
||||
ge = self.ref_enc(refer[:, :704] * refer_mask, refer_mask)
|
||||
|
||||
quantized = self.quantizer.decode(codes)
|
||||
if self.semantic_frame_rate == "25hz":
|
||||
|
||||
Reference in New Issue
Block a user