Version Check (#1390)

* version check

* fix webui and symbols

* fix v1 language map
This commit is contained in:
KamioRinn
2024-08-05 17:24:42 +08:00
committed by GitHub
parent 0c25e57959
commit 4e34814c70
8 changed files with 157 additions and 78 deletions

View File

@@ -15,7 +15,9 @@ from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
from module.commons import init_weights, get_padding
from module.mrte_model import MRTE
from module.quantize import ResidualVectorQuantizer
from text import symbols
# from text import symbols
from text import symbols as symbols_v1
from text import symbols2 as symbols_v2
from torch.cuda.amp import autocast
import contextlib
@@ -185,6 +187,7 @@ class TextEncoder(nn.Module):
kernel_size,
p_dropout,
latent_channels=192,
version = "v2",
):
super().__init__()
self.out_channels = out_channels
@@ -195,6 +198,7 @@ class TextEncoder(nn.Module):
self.kernel_size = kernel_size
self.p_dropout = p_dropout
self.latent_channels = latent_channels
self.version = version
self.ssl_proj = nn.Conv1d(768, hidden_channels, 1)
@@ -210,6 +214,11 @@ class TextEncoder(nn.Module):
self.encoder_text = attentions.Encoder(
hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
)
if self.version == "v1":
symbols = symbols_v1.symbols
else:
symbols = symbols_v2.symbols
self.text_embedding = nn.Embedding(len(symbols), hidden_channels)
self.mrte = MRTE()
@@ -827,6 +836,7 @@ class SynthesizerTrn(nn.Module):
use_sdp=True,
semantic_frame_rate=None,
freeze_quantizer=None,
version = "v2",
**kwargs
):
super().__init__()
@@ -847,6 +857,7 @@ class SynthesizerTrn(nn.Module):
self.segment_size = segment_size
self.n_speakers = n_speakers
self.gin_channels = gin_channels
self.version = version
self.use_sdp = use_sdp
self.enc_p = TextEncoder(
@@ -857,6 +868,7 @@ class SynthesizerTrn(nn.Module):
n_layers,
kernel_size,
p_dropout,
version = version,
)
self.dec = Generator(
inter_channels,
@@ -881,7 +893,7 @@ class SynthesizerTrn(nn.Module):
inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels
)
self.version=os.environ.get("version","v1")
# self.version=os.environ.get("version","v1")
if(self.version=="v1"):
self.ref_enc = modules.MelStyleEncoder(spec_channels, style_vector_dim=gin_channels)
else: