Refactor: Format Code with Ruff and Update Deprecated G2PW Link (#2255)
* ruff check --fix * ruff format --line-length 120 --target-version py39 * Change the link for G2PW Model * update pytorch version and colab
This commit is contained in:
@@ -1,20 +1,26 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.nn as nn
|
||||
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
|
||||
from torch.nn.utils import weight_norm, spectral_norm
|
||||
|
||||
|
||||
# from utils import init_weights, get_padding
|
||||
def get_padding(kernel_size, dilation=1):
|
||||
return int((kernel_size*dilation - dilation)/2)
|
||||
return int((kernel_size * dilation - dilation) / 2)
|
||||
|
||||
|
||||
def init_weights(m, mean=0.0, std=0.01):
|
||||
classname = m.__class__.__name__
|
||||
if classname.find("Conv") != -1:
|
||||
m.weight.data.normal_(mean, std)
|
||||
|
||||
|
||||
import numpy as np
|
||||
from typing import Tuple, List
|
||||
|
||||
LRELU_SLOPE = 0.1
|
||||
|
||||
|
||||
class ConvNeXtBlock(nn.Module):
|
||||
"""ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal.
|
||||
|
||||
@@ -30,24 +36,24 @@ class ConvNeXtBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
layer_scale_init_value= None,
|
||||
adanorm_num_embeddings = None,
|
||||
layer_scale_init_value=None,
|
||||
adanorm_num_embeddings=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.dwconv = nn.Conv1d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv
|
||||
self.adanorm = adanorm_num_embeddings is not None
|
||||
|
||||
|
||||
self.norm = nn.LayerNorm(dim, eps=1e-6)
|
||||
self.pwconv1 = nn.Linear(dim, dim*3) # pointwise/1x1 convs, implemented with linear layers
|
||||
self.pwconv1 = nn.Linear(dim, dim * 3) # pointwise/1x1 convs, implemented with linear layers
|
||||
self.act = nn.GELU()
|
||||
self.pwconv2 = nn.Linear(dim*3, dim)
|
||||
self.pwconv2 = nn.Linear(dim * 3, dim)
|
||||
self.gamma = (
|
||||
nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
|
||||
if layer_scale_init_value > 0
|
||||
else None
|
||||
)
|
||||
|
||||
def forward(self, x, cond_embedding_id = None) :
|
||||
def forward(self, x, cond_embedding_id=None):
|
||||
residual = x
|
||||
x = self.dwconv(x)
|
||||
x = x.transpose(1, 2) # (B, C, T) -> (B, T, C)
|
||||
@@ -72,11 +78,11 @@ class APNet_BWE_Model(torch.nn.Module):
|
||||
super(APNet_BWE_Model, self).__init__()
|
||||
self.h = h
|
||||
self.adanorm_num_embeddings = None
|
||||
layer_scale_init_value = 1 / h.ConvNeXt_layers
|
||||
layer_scale_init_value = 1 / h.ConvNeXt_layers
|
||||
|
||||
self.conv_pre_mag = nn.Conv1d(h.n_fft//2+1, h.ConvNeXt_channels, 7, 1, padding=get_padding(7, 1))
|
||||
self.conv_pre_mag = nn.Conv1d(h.n_fft // 2 + 1, h.ConvNeXt_channels, 7, 1, padding=get_padding(7, 1))
|
||||
self.norm_pre_mag = nn.LayerNorm(h.ConvNeXt_channels, eps=1e-6)
|
||||
self.conv_pre_pha = nn.Conv1d(h.n_fft//2+1, h.ConvNeXt_channels, 7, 1, padding=get_padding(7, 1))
|
||||
self.conv_pre_pha = nn.Conv1d(h.n_fft // 2 + 1, h.ConvNeXt_channels, 7, 1, padding=get_padding(7, 1))
|
||||
self.norm_pre_pha = nn.LayerNorm(h.ConvNeXt_channels, eps=1e-6)
|
||||
|
||||
self.convnext_mag = nn.ModuleList(
|
||||
@@ -104,9 +110,9 @@ class APNet_BWE_Model(torch.nn.Module):
|
||||
self.norm_post_mag = nn.LayerNorm(h.ConvNeXt_channels, eps=1e-6)
|
||||
self.norm_post_pha = nn.LayerNorm(h.ConvNeXt_channels, eps=1e-6)
|
||||
self.apply(self._init_weights)
|
||||
self.linear_post_mag = nn.Linear(h.ConvNeXt_channels, h.n_fft//2+1)
|
||||
self.linear_post_pha_r = nn.Linear(h.ConvNeXt_channels, h.n_fft//2+1)
|
||||
self.linear_post_pha_i = nn.Linear(h.ConvNeXt_channels, h.n_fft//2+1)
|
||||
self.linear_post_mag = nn.Linear(h.ConvNeXt_channels, h.n_fft // 2 + 1)
|
||||
self.linear_post_pha_r = nn.Linear(h.ConvNeXt_channels, h.n_fft // 2 + 1)
|
||||
self.linear_post_pha_i = nn.Linear(h.ConvNeXt_channels, h.n_fft // 2 + 1)
|
||||
|
||||
def _init_weights(self, m):
|
||||
if isinstance(m, (nn.Conv1d, nn.Linear)):
|
||||
@@ -114,7 +120,6 @@ class APNet_BWE_Model(torch.nn.Module):
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
def forward(self, mag_nb, pha_nb):
|
||||
|
||||
x_mag = self.conv_pre_mag(mag_nb)
|
||||
x_pha = self.conv_pre_pha(pha_nb)
|
||||
x_mag = self.norm_pre_mag(x_mag.transpose(1, 2)).transpose(1, 2)
|
||||
@@ -134,11 +139,9 @@ class APNet_BWE_Model(torch.nn.Module):
|
||||
x_pha_i = self.linear_post_pha_i(x_pha)
|
||||
pha_wb = torch.atan2(x_pha_i, x_pha_r).transpose(1, 2)
|
||||
|
||||
com_wb = torch.stack((torch.exp(mag_wb)*torch.cos(pha_wb),
|
||||
torch.exp(mag_wb)*torch.sin(pha_wb)), dim=-1)
|
||||
|
||||
return mag_wb, pha_wb, com_wb
|
||||
com_wb = torch.stack((torch.exp(mag_wb) * torch.cos(pha_wb), torch.exp(mag_wb) * torch.sin(pha_wb)), dim=-1)
|
||||
|
||||
return mag_wb, pha_wb, com_wb
|
||||
|
||||
|
||||
class DiscriminatorP(torch.nn.Module):
|
||||
@@ -146,13 +149,15 @@ class DiscriminatorP(torch.nn.Module):
|
||||
super(DiscriminatorP, self).__init__()
|
||||
self.period = period
|
||||
norm_f = weight_norm if use_spectral_norm == False else spectral_norm
|
||||
self.convs = nn.ModuleList([
|
||||
norm_f(nn.Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
|
||||
norm_f(nn.Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
|
||||
norm_f(nn.Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
|
||||
norm_f(nn.Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
|
||||
norm_f(nn.Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),
|
||||
])
|
||||
self.convs = nn.ModuleList(
|
||||
[
|
||||
norm_f(nn.Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
|
||||
norm_f(nn.Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
|
||||
norm_f(nn.Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
|
||||
norm_f(nn.Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
|
||||
norm_f(nn.Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),
|
||||
]
|
||||
)
|
||||
self.conv_post = norm_f(nn.Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
|
||||
|
||||
def forward(self, x):
|
||||
@@ -160,13 +165,13 @@ class DiscriminatorP(torch.nn.Module):
|
||||
|
||||
# 1d to 2d
|
||||
b, c, t = x.shape
|
||||
if t % self.period != 0: # pad first
|
||||
if t % self.period != 0: # pad first
|
||||
n_pad = self.period - (t % self.period)
|
||||
x = F.pad(x, (0, n_pad), "reflect")
|
||||
t = t + n_pad
|
||||
x = x.view(b, c, t // self.period, self.period)
|
||||
|
||||
for i,l in enumerate(self.convs):
|
||||
for i, l in enumerate(self.convs):
|
||||
x = l(x)
|
||||
x = F.leaky_relu(x, LRELU_SLOPE)
|
||||
if i > 0:
|
||||
@@ -181,13 +186,15 @@ class DiscriminatorP(torch.nn.Module):
|
||||
class MultiPeriodDiscriminator(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super(MultiPeriodDiscriminator, self).__init__()
|
||||
self.discriminators = nn.ModuleList([
|
||||
DiscriminatorP(2),
|
||||
DiscriminatorP(3),
|
||||
DiscriminatorP(5),
|
||||
DiscriminatorP(7),
|
||||
DiscriminatorP(11),
|
||||
])
|
||||
self.discriminators = nn.ModuleList(
|
||||
[
|
||||
DiscriminatorP(2),
|
||||
DiscriminatorP(3),
|
||||
DiscriminatorP(5),
|
||||
DiscriminatorP(7),
|
||||
DiscriminatorP(11),
|
||||
]
|
||||
)
|
||||
|
||||
def forward(self, y, y_hat):
|
||||
y_d_rs = []
|
||||
@@ -264,8 +271,8 @@ class DiscriminatorAR(nn.Module):
|
||||
self, x: torch.Tensor, cond_embedding_id: torch.Tensor = None
|
||||
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
|
||||
fmap = []
|
||||
x=x.squeeze(1)
|
||||
|
||||
x = x.squeeze(1)
|
||||
|
||||
x = self.spectrogram(x)
|
||||
x = x.unsqueeze(1)
|
||||
for l in self.convs:
|
||||
@@ -358,8 +365,8 @@ class DiscriminatorPR(nn.Module):
|
||||
self, x: torch.Tensor, cond_embedding_id: torch.Tensor = None
|
||||
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
|
||||
fmap = []
|
||||
x=x.squeeze(1)
|
||||
|
||||
x = x.squeeze(1)
|
||||
|
||||
x = self.spectrogram(x)
|
||||
x = x.unsqueeze(1)
|
||||
for l in self.convs:
|
||||
@@ -407,11 +414,11 @@ def discriminator_loss(disc_real_outputs, disc_generated_outputs):
|
||||
r_losses = []
|
||||
g_losses = []
|
||||
for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
|
||||
r_loss = torch.mean(torch.clamp(1 - dr, min=0))
|
||||
g_loss = torch.mean(torch.clamp(1 + dg, min=0))
|
||||
loss += r_loss + g_loss
|
||||
r_losses.append(r_loss.item())
|
||||
g_losses.append(g_loss.item())
|
||||
r_loss = torch.mean(torch.clamp(1 - dr, min=0))
|
||||
g_loss = torch.mean(torch.clamp(1 + dg, min=0))
|
||||
loss += r_loss + g_loss
|
||||
r_losses.append(r_loss.item())
|
||||
g_losses.append(g_loss.item())
|
||||
|
||||
return loss, r_losses, g_losses
|
||||
|
||||
@@ -420,35 +427,37 @@ def generator_loss(disc_outputs):
|
||||
loss = 0
|
||||
gen_losses = []
|
||||
for dg in disc_outputs:
|
||||
l = torch.mean(torch.clamp(1 - dg, min=0))
|
||||
gen_losses.append(l)
|
||||
loss += l
|
||||
l = torch.mean(torch.clamp(1 - dg, min=0))
|
||||
gen_losses.append(l)
|
||||
loss += l
|
||||
|
||||
return loss, gen_losses
|
||||
|
||||
|
||||
def phase_losses(phase_r, phase_g):
|
||||
|
||||
ip_loss = torch.mean(anti_wrapping_function(phase_r - phase_g))
|
||||
gd_loss = torch.mean(anti_wrapping_function(torch.diff(phase_r, dim=1) - torch.diff(phase_g, dim=1)))
|
||||
iaf_loss = torch.mean(anti_wrapping_function(torch.diff(phase_r, dim=2) - torch.diff(phase_g, dim=2)))
|
||||
|
||||
return ip_loss, gd_loss, iaf_loss
|
||||
|
||||
def anti_wrapping_function(x):
|
||||
|
||||
def anti_wrapping_function(x):
|
||||
return torch.abs(x - torch.round(x / (2 * np.pi)) * 2 * np.pi)
|
||||
|
||||
|
||||
def stft_mag(audio, n_fft=2048, hop_length=512):
|
||||
hann_window = torch.hann_window(n_fft).to(audio.device)
|
||||
stft_spec = torch.stft(audio, n_fft, hop_length, window=hann_window, return_complex=True)
|
||||
stft_mag = torch.abs(stft_spec)
|
||||
return(stft_mag)
|
||||
return stft_mag
|
||||
|
||||
|
||||
def cal_snr(pred, target):
|
||||
snr = (20 * torch.log10(torch.norm(target, dim=-1) / torch.norm(pred - target, dim=-1).clamp(min=1e-8))).mean()
|
||||
return snr
|
||||
|
||||
|
||||
def cal_lsd(pred, target):
|
||||
sp = torch.log10(stft_mag(pred).square().clamp(1e-8))
|
||||
st = torch.log10(stft_mag(target).square().clamp(1e-8))
|
||||
|
||||
Reference in New Issue
Block a user