Refactor: Format Code with Ruff and Update Deprecated G2PW Link (#2255)

* ruff check --fix

* ruff format --line-length 120 --target-version py39

* Change the link for G2PW Model

* update pytorch version and colab
This commit is contained in:
XXXXRT666
2025-04-07 09:42:47 +01:00
committed by GitHub
parent 9da7e17efe
commit 53cac93589
132 changed files with 8185 additions and 6648 deletions

View File

@@ -1,9 +1,7 @@
import warnings
warnings.filterwarnings("ignore")
import copy
import math
import os
import pdb
import torch
from torch import nn
@@ -13,16 +11,18 @@ from module import commons
from module import modules
from module import attentions
from f5_tts.model import DiT
from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
from torch.nn import Conv1d, ConvTranspose1d, Conv2d
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
from module.commons import init_weights, get_padding
from module.mrte_model import MRTE
from module.quantize import ResidualVectorQuantizer
# from text import symbols
from text import symbols as symbols_v1
from text import symbols2 as symbols_v2
from torch.cuda.amp import autocast
import contextlib,random
import contextlib
import random
class StochasticDurationPredictor(nn.Module):
@@ -48,29 +48,21 @@ class StochasticDurationPredictor(nn.Module):
self.flows = nn.ModuleList()
self.flows.append(modules.ElementwiseAffine(2))
for i in range(n_flows):
self.flows.append(
modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)
)
self.flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3))
self.flows.append(modules.Flip())
self.post_pre = nn.Conv1d(1, filter_channels, 1)
self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1)
self.post_convs = modules.DDSConv(
filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
)
self.post_convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout)
self.post_flows = nn.ModuleList()
self.post_flows.append(modules.ElementwiseAffine(2))
for i in range(4):
self.post_flows.append(
modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)
)
self.post_flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3))
self.post_flows.append(modules.Flip())
self.pre = nn.Conv1d(in_channels, filter_channels, 1)
self.proj = nn.Conv1d(filter_channels, filter_channels, 1)
self.convs = modules.DDSConv(
filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
)
self.convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout)
if gin_channels != 0:
self.cond = nn.Conv1d(gin_channels, filter_channels, 1)
@@ -91,10 +83,7 @@ class StochasticDurationPredictor(nn.Module):
h_w = self.post_pre(w)
h_w = self.post_convs(h_w, x_mask)
h_w = self.post_proj(h_w) * x_mask
e_q = (
torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype)
* x_mask
)
e_q = torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype) * x_mask
z_q = e_q
for flow in self.post_flows:
z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w))
@@ -102,13 +91,8 @@ class StochasticDurationPredictor(nn.Module):
z_u, z1 = torch.split(z_q, [1, 1], 1)
u = torch.sigmoid(z_u) * x_mask
z0 = (w - u) * x_mask
logdet_tot_q += torch.sum(
(F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2]
)
logq = (
torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q**2)) * x_mask, [1, 2])
- logdet_tot_q
)
logdet_tot_q += torch.sum((F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2])
logq = torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q**2)) * x_mask, [1, 2]) - logdet_tot_q
logdet_tot = 0
z0, logdet = self.log_flow(z0, x_mask)
@@ -117,18 +101,12 @@ class StochasticDurationPredictor(nn.Module):
for flow in flows:
z, logdet = flow(z, x_mask, g=x, reverse=reverse)
logdet_tot = logdet_tot + logdet
nll = (
torch.sum(0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask, [1, 2])
- logdet_tot
)
nll = torch.sum(0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask, [1, 2]) - logdet_tot
return nll + logq # [b]
else:
flows = list(reversed(self.flows))
flows = flows[:-2] + [flows[-1]] # remove a useless vflow
z = (
torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype)
* noise_scale
)
z = torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype) * noise_scale
for flow in flows:
z = flow(z, x_mask, g=x, reverse=reverse)
z0, z1 = torch.split(z, [1, 1], 1)
@@ -137,9 +115,7 @@ class StochasticDurationPredictor(nn.Module):
class DurationPredictor(nn.Module):
def __init__(
self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0
):
def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0):
super().__init__()
self.in_channels = in_channels
@@ -149,13 +125,9 @@ class DurationPredictor(nn.Module):
self.gin_channels = gin_channels
self.drop = nn.Dropout(p_dropout)
self.conv_1 = nn.Conv1d(
in_channels, filter_channels, kernel_size, padding=kernel_size // 2
)
self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
self.norm_1 = modules.LayerNorm(filter_channels)
self.conv_2 = nn.Conv1d(
filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
)
self.conv_2 = nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size // 2)
self.norm_2 = modules.LayerNorm(filter_channels)
self.proj = nn.Conv1d(filter_channels, 1, 1)
@@ -190,7 +162,7 @@ class TextEncoder(nn.Module):
kernel_size,
p_dropout,
latent_channels=192,
version = "v2",
version="v2",
):
super().__init__()
self.out_channels = out_channels
@@ -237,26 +209,22 @@ class TextEncoder(nn.Module):
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
def forward(self, y, y_lengths, text, text_lengths, ge, speed=1,test=None):
y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(
y.dtype
)
def forward(self, y, y_lengths, text, text_lengths, ge, speed=1, test=None):
y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(y.dtype)
y = self.ssl_proj(y * y_mask) * y_mask
y = self.encoder_ssl(y * y_mask, y_mask)
text_mask = torch.unsqueeze(
commons.sequence_mask(text_lengths, text.size(1)), 1
).to(y.dtype)
text_mask = torch.unsqueeze(commons.sequence_mask(text_lengths, text.size(1)), 1).to(y.dtype)
if test == 1:
text[:, :] = 0
text = self.text_embedding(text).transpose(1, 2)
text = self.encoder_text(text * text_mask, text_mask)
y = self.mrte(y, y_mask, text, text_mask, ge)
y = self.encoder2(y * y_mask, y_mask)
if(speed!=1):
y = F.interpolate(y, size=int(y.shape[-1] / speed)+1, mode="linear")
if speed != 1:
y = F.interpolate(y, size=int(y.shape[-1] / speed) + 1, mode="linear")
y_mask = F.interpolate(y_mask, size=y.shape[-1], mode="nearest")
stats = self.proj(y) * y_mask
m, logs = torch.split(stats, self.out_channels, dim=1)
@@ -360,9 +328,7 @@ class PosteriorEncoder(nn.Module):
def forward(self, x, x_lengths, g=None):
if g != None:
g = g.detach()
x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
x.dtype
)
x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
x = self.pre(x) * x_mask
x = self.enc(x, x_mask, g=g)
stats = self.proj(x) * x_mask
@@ -372,14 +338,9 @@ class PosteriorEncoder(nn.Module):
class Encoder(nn.Module):
def __init__(self,
in_channels,
out_channels,
hidden_channels,
kernel_size,
dilation_rate,
n_layers,
gin_channels=0):
def __init__(
self, in_channels, out_channels, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
@@ -394,7 +355,7 @@ class Encoder(nn.Module):
self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
def forward(self, x, x_lengths, g=None):
if(g!=None):
if g != None:
g = g.detach()
x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
x = self.pre(x) * x_mask
@@ -402,6 +363,7 @@ class Encoder(nn.Module):
stats = self.proj(x) * x_mask
return stats, x_mask
class WNEncoder(nn.Module):
def __init__(
self,
@@ -434,9 +396,7 @@ class WNEncoder(nn.Module):
self.norm = modules.LayerNorm(out_channels)
def forward(self, x, x_lengths, g=None):
x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
x.dtype
)
x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
x = self.pre(x) * x_mask
x = self.enc(x, x_mask, g=g)
out = self.proj(x) * x_mask
@@ -459,9 +419,7 @@ class Generator(torch.nn.Module):
super(Generator, self).__init__()
self.num_kernels = len(resblock_kernel_sizes)
self.num_upsamples = len(upsample_rates)
self.conv_pre = Conv1d(
initial_channel, upsample_initial_channel, 7, 1, padding=3
)
self.conv_pre = Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3)
resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2
self.ups = nn.ModuleList()
@@ -481,9 +439,7 @@ class Generator(torch.nn.Module):
self.resblocks = nn.ModuleList()
for i in range(len(self.ups)):
ch = upsample_initial_channel // (2 ** (i + 1))
for j, (k, d) in enumerate(
zip(resblock_kernel_sizes, resblock_dilation_sizes)
):
for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
self.resblocks.append(resblock(ch, k, d))
self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
@@ -636,9 +592,7 @@ class MultiPeriodDiscriminator(torch.nn.Module):
periods = [2, 3, 5, 7, 11]
discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
discs = discs + [
DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods
]
discs = discs + [DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods]
self.discriminators = nn.ModuleList(discs)
def forward(self, y, y_hat):
@@ -738,10 +692,7 @@ class Quantizer(torch.nn.Module):
super(Quantizer, self).__init__()
assert embed_dim % n_code_groups == 0
self.quantizer_modules = nn.ModuleList(
[
Quantizer_module(n_codes, embed_dim // n_code_groups)
for _ in range(n_code_groups)
]
[Quantizer_module(n_codes, embed_dim // n_code_groups) for _ in range(n_code_groups)]
)
self.n_code_groups = n_code_groups
self.embed_dim = embed_dim
@@ -759,9 +710,7 @@ class Quantizer(torch.nn.Module):
z_q.append(_z_q)
min_indicies.append(_min_indicies) # B * T,
z_q = torch.cat(z_q, -1).reshape(xin.shape)
loss = 0.25 * torch.mean((z_q.detach() - xin) ** 2) + torch.mean(
(z_q - xin.detach()) ** 2
)
loss = 0.25 * torch.mean((z_q.detach() - xin) ** 2) + torch.mean((z_q - xin.detach()) ** 2)
z_q = xin + (z_q - xin).detach()
z_q = z_q.transpose(1, 2)
codes = torch.stack(min_indicies, -1).reshape(B, T, self.n_code_groups)
@@ -801,13 +750,9 @@ class CodePredictor(nn.Module):
self.p_dropout = p_dropout
self.vq_proj = nn.Conv1d(ssl_dim, hidden_channels, 1)
self.ref_enc = modules.MelStyleEncoder(
ssl_dim, style_vector_dim=hidden_channels
)
self.ref_enc = modules.MelStyleEncoder(ssl_dim, style_vector_dim=hidden_channels)
self.encoder = attentions.Encoder(
hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
)
self.encoder = attentions.Encoder(hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout)
self.out_proj = nn.Conv1d(hidden_channels, (n_q - 1) * dims, 1)
self.n_q = n_q
@@ -820,9 +765,7 @@ class CodePredictor(nn.Module):
x = x + g
x = self.encoder(x * x_mask, x_mask)
x = self.out_proj(x * x_mask) * x_mask
logits = x.reshape(x.shape[0], self.n_q - 1, self.dims, x.shape[-1]).transpose(
2, 3
)
logits = x.reshape(x.shape[0], self.n_q - 1, self.dims, x.shape[-1]).transpose(2, 3)
target = codes[1:].transpose(0, 1)
if not infer:
logits = logits.reshape(-1, self.dims)
@@ -870,8 +813,8 @@ class SynthesizerTrn(nn.Module):
use_sdp=True,
semantic_frame_rate=None,
freeze_quantizer=None,
version = "v2",
**kwargs
version="v2",
**kwargs,
):
super().__init__()
self.spec_channels = spec_channels
@@ -902,7 +845,7 @@ class SynthesizerTrn(nn.Module):
n_layers,
kernel_size,
p_dropout,
version = version,
version=version,
)
self.dec = Generator(
inter_channels,
@@ -923,12 +866,10 @@ class SynthesizerTrn(nn.Module):
16,
gin_channels=gin_channels,
)
self.flow = ResidualCouplingBlock(
inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels
)
self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels)
# self.version=os.environ.get("version","v1")
if(self.version=="v1"):
if self.version == "v1":
self.ref_enc = modules.MelStyleEncoder(spec_channels, style_vector_dim=gin_channels)
else:
self.ref_enc = modules.MelStyleEncoder(704, style_vector_dim=gin_channels)
@@ -945,13 +886,11 @@ class SynthesizerTrn(nn.Module):
self.freeze_quantizer = freeze_quantizer
def forward(self, ssl, y, y_lengths, text, text_lengths):
y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(
y.dtype
)
if(self.version=="v1"):
y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(y.dtype)
if self.version == "v1":
ge = self.ref_enc(y * y_mask, y_mask)
else:
ge = self.ref_enc(y[:,:704] * y_mask, y_mask)
ge = self.ref_enc(y[:, :704] * y_mask, y_mask)
with autocast(enabled=False):
maybe_no_grad = torch.no_grad() if self.freeze_quantizer else contextlib.nullcontext()
with maybe_no_grad:
@@ -959,24 +898,16 @@ class SynthesizerTrn(nn.Module):
self.ssl_proj.eval()
self.quantizer.eval()
ssl = self.ssl_proj(ssl)
quantized, codes, commit_loss, quantized_list = self.quantizer(
ssl, layers=[0]
)
quantized, codes, commit_loss, quantized_list = self.quantizer(ssl, layers=[0])
if self.semantic_frame_rate == "25hz":
quantized = F.interpolate(
quantized, size=int(quantized.shape[-1] * 2), mode="nearest"
)
quantized = F.interpolate(quantized, size=int(quantized.shape[-1] * 2), mode="nearest")
x, m_p, logs_p, y_mask = self.enc_p(
quantized, y_lengths, text, text_lengths, ge
)
x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge)
z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=ge)
z_p = self.flow(z, y_mask, g=ge)
z_slice, ids_slice = commons.rand_slice_segments(
z, y_lengths, self.segment_size
)
z_slice, ids_slice = commons.rand_slice_segments(z, y_lengths, self.segment_size)
o = self.dec(z_slice, g=ge)
return (
o,
@@ -989,24 +920,18 @@ class SynthesizerTrn(nn.Module):
)
def infer(self, ssl, y, y_lengths, text, text_lengths, test=None, noise_scale=0.5):
y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(
y.dtype
)
if(self.version=="v1"):
y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(y.dtype)
if self.version == "v1":
ge = self.ref_enc(y * y_mask, y_mask)
else:
ge = self.ref_enc(y[:,:704] * y_mask, y_mask)
ge = self.ref_enc(y[:, :704] * y_mask, y_mask)
ssl = self.ssl_proj(ssl)
quantized, codes, commit_loss, _ = self.quantizer(ssl, layers=[0])
if self.semantic_frame_rate == "25hz":
quantized = F.interpolate(
quantized, size=int(quantized.shape[-1] * 2), mode="nearest"
)
quantized = F.interpolate(quantized, size=int(quantized.shape[-1] * 2), mode="nearest")
x, m_p, logs_p, y_mask = self.enc_p(
quantized, y_lengths, text, text_lengths, ge, test=test
)
x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge, test=test)
z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
z = self.flow(z_p, y_mask, g=ge, reverse=True)
@@ -1015,39 +940,34 @@ class SynthesizerTrn(nn.Module):
return o, y_mask, (z, z_p, m_p, logs_p)
@torch.no_grad()
def decode(self, codes, text, refer, noise_scale=0.5,speed=1):
def decode(self, codes, text, refer, noise_scale=0.5, speed=1):
def get_ge(refer):
ge = None
if refer is not None:
refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device)
refer_mask = torch.unsqueeze(
commons.sequence_mask(refer_lengths, refer.size(2)), 1
).to(refer.dtype)
if (self.version == "v1"):
refer_mask = torch.unsqueeze(commons.sequence_mask(refer_lengths, refer.size(2)), 1).to(refer.dtype)
if self.version == "v1":
ge = self.ref_enc(refer * refer_mask, refer_mask)
else:
ge = self.ref_enc(refer[:, :704] * refer_mask, refer_mask)
return ge
if(type(refer)==list):
ges=[]
if type(refer) == list:
ges = []
for _refer in refer:
ge=get_ge(_refer)
ge = get_ge(_refer)
ges.append(ge)
ge=torch.stack(ges,0).mean(0)
ge = torch.stack(ges, 0).mean(0)
else:
ge=get_ge(refer)
ge = get_ge(refer)
y_lengths = torch.LongTensor([codes.size(2) * 2]).to(codes.device)
text_lengths = torch.LongTensor([text.size(-1)]).to(text.device)
quantized = self.quantizer.decode(codes)
if self.semantic_frame_rate == "25hz":
quantized = F.interpolate(
quantized, size=int(quantized.shape[-1] * 2), mode="nearest"
)
x, m_p, logs_p, y_mask = self.enc_p(
quantized, y_lengths, text, text_lengths, ge,speed
)
quantized = F.interpolate(quantized, size=int(quantized.shape[-1] * 2), mode="nearest")
x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge, speed)
z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
z = self.flow(z_p, y_mask, g=ge, reverse=True)
@@ -1059,11 +979,10 @@ class SynthesizerTrn(nn.Module):
ssl = self.ssl_proj(x)
quantized, codes, commit_loss, quantized_list = self.quantizer(ssl)
return codes.transpose(0, 1)
class CFM(torch.nn.Module):
def __init__(
self,
in_channels,dit
):
def __init__(self, in_channels, dit):
super().__init__()
self.sigma_min = 1e-6
@@ -1077,41 +996,54 @@ class CFM(torch.nn.Module):
def inference(self, mu, x_lens, prompt, n_timesteps, temperature=1.0, inference_cfg_rate=0):
"""Forward diffusion"""
B, T = mu.size(0), mu.size(1)
x = torch.randn([B, self.in_channels, T], device=mu.device,dtype=mu.dtype) * temperature
x = torch.randn([B, self.in_channels, T], device=mu.device, dtype=mu.dtype) * temperature
prompt_len = prompt.size(-1)
prompt_x = torch.zeros_like(x,dtype=mu.dtype)
prompt_x = torch.zeros_like(x, dtype=mu.dtype)
prompt_x[..., :prompt_len] = prompt[..., :prompt_len]
x[..., :prompt_len] = 0
mu=mu.transpose(2,1)
mu = mu.transpose(2, 1)
t = 0
d = 1 / n_timesteps
for j in range(n_timesteps):
t_tensor = torch.ones(x.shape[0], device=x.device,dtype=mu.dtype) * t
d_tensor = torch.ones(x.shape[0], device=x.device,dtype=mu.dtype) * d
t_tensor = torch.ones(x.shape[0], device=x.device, dtype=mu.dtype) * t
d_tensor = torch.ones(x.shape[0], device=x.device, dtype=mu.dtype) * d
# v_pred = model(x, t_tensor, d_tensor, **extra_args)
v_pred = self.estimator(x, prompt_x, x_lens, t_tensor,d_tensor, mu, use_grad_ckpt=False,drop_audio_cond=False,drop_text=False).transpose(2, 1)
if inference_cfg_rate>1e-5:
neg = self.estimator(x, prompt_x, x_lens, t_tensor, d_tensor, mu, use_grad_ckpt=False, drop_audio_cond=True, drop_text=True).transpose(2, 1)
v_pred=v_pred+(v_pred-neg)*inference_cfg_rate
v_pred = self.estimator(
x, prompt_x, x_lens, t_tensor, d_tensor, mu, use_grad_ckpt=False, drop_audio_cond=False, drop_text=False
).transpose(2, 1)
if inference_cfg_rate > 1e-5:
neg = self.estimator(
x,
prompt_x,
x_lens,
t_tensor,
d_tensor,
mu,
use_grad_ckpt=False,
drop_audio_cond=True,
drop_text=True,
).transpose(2, 1)
v_pred = v_pred + (v_pred - neg) * inference_cfg_rate
x = x + d * v_pred
t = t + d
x[:, :, :prompt_len] = 0
return x
def forward(self, x1, x_lens, prompt_lens, mu, use_grad_ckpt):
b, _, t = x1.shape
t = torch.rand([b], device=mu.device, dtype=x1.dtype)
x0 = torch.randn_like(x1,device=mu.device)
x0 = torch.randn_like(x1, device=mu.device)
vt = x1 - x0
xt = x0 + t[:, None, None] * vt
dt = torch.zeros_like(t,device=mu.device)
dt = torch.zeros_like(t, device=mu.device)
prompt = torch.zeros_like(x1)
for i in range(b):
prompt[i, :, :prompt_lens[i]] = x1[i, :, :prompt_lens[i]]
xt[i, :, :prompt_lens[i]] = 0
gailv=0.3# if ttime()>1736250488 else 0.1
prompt[i, :, : prompt_lens[i]] = x1[i, :, : prompt_lens[i]]
xt[i, :, : prompt_lens[i]] = 0
gailv = 0.3 # if ttime()>1736250488 else 0.1
if random.random() < gailv:
base = torch.randint(2, 8, (t.shape[0],), device=mu.device)
d = 1/torch.pow(2, base)
d = 1 / torch.pow(2, base)
d_input = d.clone()
d_input[d_input < 1e-2] = 0
# with torch.no_grad():
@@ -1119,52 +1051,55 @@ class CFM(torch.nn.Module):
# v_pred_1 = self.diffusion(xt, t, d_input, cond=conditioning).detach()
x_mid = xt + d[:, None, None] * v_pred_1
# v_pred_2 = self.diffusion(x_mid, t+d, d_input, cond=conditioning).detach()
v_pred_2 = self.estimator(x_mid, prompt, x_lens, t+d, d_input, mu, use_grad_ckpt).transpose(2, 1).detach()
v_pred_2 = self.estimator(x_mid, prompt, x_lens, t + d, d_input, mu, use_grad_ckpt).transpose(2, 1).detach()
vt = (v_pred_1 + v_pred_2) / 2
vt = vt.detach()
dt = 2*d
dt = 2 * d
vt_pred = self.estimator(xt, prompt, x_lens, t,dt, mu, use_grad_ckpt).transpose(2,1)
vt_pred = self.estimator(xt, prompt, x_lens, t, dt, mu, use_grad_ckpt).transpose(2, 1)
loss = 0
for i in range(b):
loss += self.criterion(vt_pred[i, :, prompt_lens[i]:x_lens[i]], vt[i, :, prompt_lens[i]:x_lens[i]])
loss += self.criterion(vt_pred[i, :, prompt_lens[i] : x_lens[i]], vt[i, :, prompt_lens[i] : x_lens[i]])
loss /= b
return loss
def set_no_grad(net_g):
for name, param in net_g.named_parameters():
param.requires_grad=False
param.requires_grad = False
class SynthesizerTrnV3(nn.Module):
"""
Synthesizer for Training
"""
def __init__(self,
spec_channels,
segment_size,
inter_channels,
hidden_channels,
filter_channels,
n_heads,
n_layers,
kernel_size,
p_dropout,
resblock,
resblock_kernel_sizes,
resblock_dilation_sizes,
upsample_rates,
upsample_initial_channel,
upsample_kernel_sizes,
n_speakers=0,
gin_channels=0,
use_sdp=True,
semantic_frame_rate=None,
freeze_quantizer=None,
version="v3",
**kwargs):
def __init__(
self,
spec_channels,
segment_size,
inter_channels,
hidden_channels,
filter_channels,
n_heads,
n_layers,
kernel_size,
p_dropout,
resblock,
resblock_kernel_sizes,
resblock_dilation_sizes,
upsample_rates,
upsample_initial_channel,
upsample_kernel_sizes,
n_speakers=0,
gin_channels=0,
use_sdp=True,
semantic_frame_rate=None,
freeze_quantizer=None,
version="v3",
**kwargs,
):
super().__init__()
self.spec_channels = spec_channels
self.inter_channels = inter_channels
@@ -1185,132 +1120,133 @@ class SynthesizerTrnV3(nn.Module):
self.gin_channels = gin_channels
self.version = version
self.model_dim=512
self.model_dim = 512
self.use_sdp = use_sdp
self.enc_p = TextEncoder(inter_channels,hidden_channels,filter_channels,n_heads,n_layers,kernel_size,p_dropout)
self.enc_p = TextEncoder(
inter_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
)
# self.ref_enc = modules.MelStyleEncoder(spec_channels, style_vector_dim=gin_channels)###Rollback
self.ref_enc = modules.MelStyleEncoder(704, style_vector_dim=gin_channels)###Rollback
self.ref_enc = modules.MelStyleEncoder(704, style_vector_dim=gin_channels) ###Rollback
# self.dec = Generator(inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates,
# upsample_initial_channel, upsample_kernel_sizes, gin_channels=gin_channels)
# self.enc_q = PosteriorEncoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16,
# gin_channels=gin_channels)
# self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels)
ssl_dim = 768
assert semantic_frame_rate in ['25hz', "50hz"]
assert semantic_frame_rate in ["25hz", "50hz"]
self.semantic_frame_rate = semantic_frame_rate
if semantic_frame_rate == '25hz':
if semantic_frame_rate == "25hz":
self.ssl_proj = nn.Conv1d(ssl_dim, ssl_dim, 2, stride=2)
else:
self.ssl_proj = nn.Conv1d(ssl_dim, ssl_dim, 1, stride=1)
self.quantizer = ResidualVectorQuantizer(
dimension=ssl_dim,
n_q=1,
bins=1024
)
self.freeze_quantizer=freeze_quantizer
inter_channels2=512
self.bridge=nn.Sequential(
nn.Conv1d(inter_channels, inter_channels2, 1, stride=1),
nn.LeakyReLU()
)
self.wns1=Encoder(inter_channels2, inter_channels2, inter_channels2, 5, 1, 8,gin_channels=gin_channels)
self.linear_mel=nn.Conv1d(inter_channels2,100,1,stride=1)
self.cfm = CFM(100,DiT(**dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=inter_channels2, conv_layers=4)),)#text_dim is condition feature dim
if self.freeze_quantizer==True:
self.quantizer = ResidualVectorQuantizer(dimension=ssl_dim, n_q=1, bins=1024)
self.freeze_quantizer = freeze_quantizer
inter_channels2 = 512
self.bridge = nn.Sequential(nn.Conv1d(inter_channels, inter_channels2, 1, stride=1), nn.LeakyReLU())
self.wns1 = Encoder(inter_channels2, inter_channels2, inter_channels2, 5, 1, 8, gin_channels=gin_channels)
self.linear_mel = nn.Conv1d(inter_channels2, 100, 1, stride=1)
self.cfm = CFM(
100,
DiT(**dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=inter_channels2, conv_layers=4)),
) # text_dim is condition feature dim
if self.freeze_quantizer == True:
set_no_grad(self.ssl_proj)
set_no_grad(self.quantizer)
set_no_grad(self.enc_p)
def forward(self, ssl, y, mel,ssl_lengths,y_lengths, text, text_lengths,mel_lengths, use_grad_ckpt):#ssl_lengths no need now
def forward(
self, ssl, y, mel, ssl_lengths, y_lengths, text, text_lengths, mel_lengths, use_grad_ckpt
): # ssl_lengths no need now
with autocast(enabled=False):
y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(y.dtype)
ge = self.ref_enc(y[:,:704] * y_mask, y_mask)
ge = self.ref_enc(y[:, :704] * y_mask, y_mask)
maybe_no_grad = torch.no_grad() if self.freeze_quantizer else contextlib.nullcontext()
with maybe_no_grad:
if self.freeze_quantizer:
self.ssl_proj.eval()#
self.ssl_proj.eval() #
self.quantizer.eval()
self.enc_p.eval()
ssl = self.ssl_proj(ssl)
quantized, codes, commit_loss, quantized_list = self.quantizer(
ssl, layers=[0]
)
quantized = F.interpolate(quantized, scale_factor=2, mode="nearest")##BCT
quantized, codes, commit_loss, quantized_list = self.quantizer(ssl, layers=[0])
quantized = F.interpolate(quantized, scale_factor=2, mode="nearest") ##BCT
x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge)
fea=self.bridge(x)
fea = F.interpolate(fea, scale_factor=1.875, mode="nearest")##BCT
fea, y_mask_ = self.wns1(fea, mel_lengths, ge)##If the 1-minute fine-tuning works fine, no need to manually adjust the learning rate.
B=ssl.shape[0]
prompt_len_max = mel_lengths*2/3
fea = self.bridge(x)
fea = F.interpolate(fea, scale_factor=1.875, mode="nearest") ##BCT
fea, y_mask_ = self.wns1(
fea, mel_lengths, ge
) ##If the 1-minute fine-tuning works fine, no need to manually adjust the learning rate.
B = ssl.shape[0]
prompt_len_max = mel_lengths * 2 / 3
prompt_len = (torch.rand([B], device=fea.device) * prompt_len_max).floor().to(dtype=torch.long)
minn=min(mel.shape[-1],fea.shape[-1])
mel=mel[:,:,:minn]
fea=fea[:,:,:minn]
cfm_loss= self.cfm(mel, mel_lengths, prompt_len, fea, use_grad_ckpt)
minn = min(mel.shape[-1], fea.shape[-1])
mel = mel[:, :, :minn]
fea = fea[:, :, :minn]
cfm_loss = self.cfm(mel, mel_lengths, prompt_len, fea, use_grad_ckpt)
return cfm_loss
@torch.no_grad()
def decode_encp(self, codes,text, refer,ge=None,speed=1):
def decode_encp(self, codes, text, refer, ge=None, speed=1):
# print(2333333,refer.shape)
# ge=None
if(ge==None):
if ge == None:
refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device)
refer_mask = torch.unsqueeze(commons.sequence_mask(refer_lengths, refer.size(2)), 1).to(refer.dtype)
ge = self.ref_enc(refer[:,:704] * refer_mask, refer_mask)
y_lengths = torch.LongTensor([int(codes.size(2)*2)]).to(codes.device)
if speed==1:
sizee=int(codes.size(2)*2.5*1.5)
ge = self.ref_enc(refer[:, :704] * refer_mask, refer_mask)
y_lengths = torch.LongTensor([int(codes.size(2) * 2)]).to(codes.device)
if speed == 1:
sizee = int(codes.size(2) * 2.5 * 1.5)
else:
sizee=int(codes.size(2)*2.5*1.5/speed)+1
sizee = int(codes.size(2) * 2.5 * 1.5 / speed) + 1
y_lengths1 = torch.LongTensor([sizee]).to(codes.device)
text_lengths = torch.LongTensor([text.size(-1)]).to(text.device)
quantized = self.quantizer.decode(codes)
if self.semantic_frame_rate == '25hz':
quantized = F.interpolate(quantized, scale_factor=2, mode="nearest")##BCT
x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge,speed)
fea=self.bridge(x)
fea = F.interpolate(fea, scale_factor=1.875, mode="nearest")##BCT
if self.semantic_frame_rate == "25hz":
quantized = F.interpolate(quantized, scale_factor=2, mode="nearest") ##BCT
x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge, speed)
fea = self.bridge(x)
fea = F.interpolate(fea, scale_factor=1.875, mode="nearest") ##BCT
####more wn paramter to learn mel
fea, y_mask_ = self.wns1(fea, y_lengths1, ge)
return fea,ge
return fea, ge
def extract_latent(self, x):
ssl = self.ssl_proj(x)
ssl = self.ssl_proj(x)
quantized, codes, commit_loss, quantized_list = self.quantizer(ssl)
return codes.transpose(0,1)
return codes.transpose(0, 1)
class SynthesizerTrnV3b(nn.Module):
"""
Synthesizer for Training
"""
def __init__(self,
spec_channels,
segment_size,
inter_channels,
hidden_channels,
filter_channels,
n_heads,
n_layers,
kernel_size,
p_dropout,
resblock,
resblock_kernel_sizes,
resblock_dilation_sizes,
upsample_rates,
upsample_initial_channel,
upsample_kernel_sizes,
n_speakers=0,
gin_channels=0,
use_sdp=True,
semantic_frame_rate=None,
freeze_quantizer=None,
**kwargs):
def __init__(
self,
spec_channels,
segment_size,
inter_channels,
hidden_channels,
filter_channels,
n_heads,
n_layers,
kernel_size,
p_dropout,
resblock,
resblock_kernel_sizes,
resblock_dilation_sizes,
upsample_rates,
upsample_initial_channel,
upsample_kernel_sizes,
n_speakers=0,
gin_channels=0,
use_sdp=True,
semantic_frame_rate=None,
freeze_quantizer=None,
**kwargs,
):
super().__init__()
self.spec_channels = spec_channels
self.inter_channels = inter_channels
@@ -1330,47 +1266,52 @@ class SynthesizerTrnV3b(nn.Module):
self.n_speakers = n_speakers
self.gin_channels = gin_channels
self.model_dim=512
self.model_dim = 512
self.use_sdp = use_sdp
self.enc_p = TextEncoder(inter_channels,hidden_channels,filter_channels,n_heads,n_layers,kernel_size,p_dropout)
self.enc_p = TextEncoder(
inter_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
)
# self.ref_enc = modules.MelStyleEncoder(spec_channels, style_vector_dim=gin_channels)###Rollback
self.ref_enc = modules.MelStyleEncoder(704, style_vector_dim=gin_channels)###Rollback
self.dec = Generator(inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates,
upsample_initial_channel, upsample_kernel_sizes, gin_channels=gin_channels)
self.enc_q = PosteriorEncoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16,
gin_channels=gin_channels)
self.ref_enc = modules.MelStyleEncoder(704, style_vector_dim=gin_channels) ###Rollback
self.dec = Generator(
inter_channels,
resblock,
resblock_kernel_sizes,
resblock_dilation_sizes,
upsample_rates,
upsample_initial_channel,
upsample_kernel_sizes,
gin_channels=gin_channels,
)
self.enc_q = PosteriorEncoder(
spec_channels, inter_channels, hidden_channels, 5, 1, 16, gin_channels=gin_channels
)
self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels)
ssl_dim = 768
assert semantic_frame_rate in ['25hz', "50hz"]
assert semantic_frame_rate in ["25hz", "50hz"]
self.semantic_frame_rate = semantic_frame_rate
if semantic_frame_rate == '25hz':
if semantic_frame_rate == "25hz":
self.ssl_proj = nn.Conv1d(ssl_dim, ssl_dim, 2, stride=2)
else:
self.ssl_proj = nn.Conv1d(ssl_dim, ssl_dim, 1, stride=1)
self.quantizer = ResidualVectorQuantizer(
dimension=ssl_dim,
n_q=1,
bins=1024
)
self.freeze_quantizer=freeze_quantizer
self.quantizer = ResidualVectorQuantizer(dimension=ssl_dim, n_q=1, bins=1024)
self.freeze_quantizer = freeze_quantizer
inter_channels2=512
self.bridge=nn.Sequential(
nn.Conv1d(inter_channels, inter_channels2, 1, stride=1),
nn.LeakyReLU()
)
self.wns1=Encoder(inter_channels2, inter_channels2, inter_channels2, 5, 1, 8,gin_channels=gin_channels)
self.linear_mel=nn.Conv1d(inter_channels2,100,1,stride=1)
self.cfm = CFM(100,DiT(**dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=inter_channels2, conv_layers=4)),)#text_dim is condition feature dim
inter_channels2 = 512
self.bridge = nn.Sequential(nn.Conv1d(inter_channels, inter_channels2, 1, stride=1), nn.LeakyReLU())
self.wns1 = Encoder(inter_channels2, inter_channels2, inter_channels2, 5, 1, 8, gin_channels=gin_channels)
self.linear_mel = nn.Conv1d(inter_channels2, 100, 1, stride=1)
self.cfm = CFM(
100,
DiT(**dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=inter_channels2, conv_layers=4)),
) # text_dim is condition feature dim
def forward(self, ssl, y, mel,ssl_lengths,y_lengths, text, text_lengths,mel_lengths):#ssl_lengths no need now
def forward(self, ssl, y, mel, ssl_lengths, y_lengths, text, text_lengths, mel_lengths): # ssl_lengths no need now
with autocast(enabled=False):
y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(y.dtype)
ge = self.ref_enc(y[:,:704] * y_mask, y_mask)
ge = self.ref_enc(y[:, :704] * y_mask, y_mask)
# ge = self.ref_enc(y * y_mask, y_mask)#change back, new spec setting is whole 24k
# ge=None
maybe_no_grad = torch.no_grad() if self.freeze_quantizer else contextlib.nullcontext()
@@ -1379,51 +1320,59 @@ class SynthesizerTrnV3b(nn.Module):
self.ssl_proj.eval()
self.quantizer.eval()
ssl = self.ssl_proj(ssl)
quantized, codes, commit_loss, quantized_list = self.quantizer(
ssl, layers=[0]
)
quantized = F.interpolate(quantized, scale_factor=2, mode="nearest")##BCT
quantized, codes, commit_loss, quantized_list = self.quantizer(ssl, layers=[0])
quantized = F.interpolate(quantized, scale_factor=2, mode="nearest") ##BCT
x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge)
z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=ge)
z_p = self.flow(z, y_mask, g=ge)
z_slice, ids_slice = commons.rand_slice_segments(z, y_lengths, self.segment_size)
o = self.dec(z_slice, g=ge)
fea=self.bridge(x)
fea = F.interpolate(fea, scale_factor=1.875, mode="nearest")##BCT
fea = self.bridge(x)
fea = F.interpolate(fea, scale_factor=1.875, mode="nearest") ##BCT
fea, y_mask_ = self.wns1(fea, mel_lengths, ge)
learned_mel = self.linear_mel(fea)
B=ssl.shape[0]
prompt_len_max = mel_lengths*2/3
prompt_len = (torch.rand([B], device=fea.device) * prompt_len_max).floor().to(dtype=torch.long)#
minn=min(mel.shape[-1],fea.shape[-1])
mel=mel[:,:,:minn]
fea=fea[:,:,:minn]
cfm_loss= self.cfm(mel, mel_lengths, prompt_len, fea)#fea==cond,y_lengths==target_mel_lengths#ge not need
return commit_loss,cfm_loss,F.mse_loss(learned_mel, mel),o, ids_slice, y_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q), quantized
B = ssl.shape[0]
prompt_len_max = mel_lengths * 2 / 3
prompt_len = (torch.rand([B], device=fea.device) * prompt_len_max).floor().to(dtype=torch.long) #
minn = min(mel.shape[-1], fea.shape[-1])
mel = mel[:, :, :minn]
fea = fea[:, :, :minn]
cfm_loss = self.cfm(mel, mel_lengths, prompt_len, fea) # fea==cond,y_lengths==target_mel_lengths#ge not need
return (
commit_loss,
cfm_loss,
F.mse_loss(learned_mel, mel),
o,
ids_slice,
y_mask,
y_mask,
(z, z_p, m_p, logs_p, m_q, logs_q),
quantized,
)
@torch.no_grad()
def decode_encp(self, codes,text, refer,ge=None):
def decode_encp(self, codes, text, refer, ge=None):
# print(2333333,refer.shape)
# ge=None
if(ge==None):
if ge == None:
refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device)
refer_mask = torch.unsqueeze(commons.sequence_mask(refer_lengths, refer.size(2)), 1).to(refer.dtype)
ge = self.ref_enc(refer[:,:704] * refer_mask, refer_mask)
y_lengths = torch.LongTensor([int(codes.size(2)*2)]).to(codes.device)
y_lengths1 = torch.LongTensor([int(codes.size(2)*2.5*1.5)]).to(codes.device)
ge = self.ref_enc(refer[:, :704] * refer_mask, refer_mask)
y_lengths = torch.LongTensor([int(codes.size(2) * 2)]).to(codes.device)
y_lengths1 = torch.LongTensor([int(codes.size(2) * 2.5 * 1.5)]).to(codes.device)
text_lengths = torch.LongTensor([text.size(-1)]).to(text.device)
quantized = self.quantizer.decode(codes)
if self.semantic_frame_rate == '25hz':
quantized = F.interpolate(quantized, scale_factor=2, mode="nearest")##BCT
if self.semantic_frame_rate == "25hz":
quantized = F.interpolate(quantized, scale_factor=2, mode="nearest") ##BCT
x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge)
fea=self.bridge(x)
fea = F.interpolate(fea, scale_factor=1.875, mode="nearest")##BCT
fea = self.bridge(x)
fea = F.interpolate(fea, scale_factor=1.875, mode="nearest") ##BCT
####more wn paramter to learn mel
fea, y_mask_ = self.wns1(fea, y_lengths1, ge)
return fea,ge
return fea, ge
def extract_latent(self, x):
ssl = self.ssl_proj(x)
ssl = self.ssl_proj(x)
quantized, codes, commit_loss, quantized_list = self.quantizer(ssl)
return codes.transpose(0,1)
return codes.transpose(0, 1)