Refactor: Format Code with Ruff and Update Deprecated G2PW Link (#2255)
* ruff check --fix * ruff format --line-length 120 --target-version py39 * Change the link for G2PW Model * update pytorch version and colab
This commit is contained in:
@@ -23,9 +23,7 @@ class Snake(nn.Module):
|
||||
>>> x = a1(x)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False
|
||||
):
|
||||
def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
|
||||
"""
|
||||
Initialization.
|
||||
INPUT:
|
||||
@@ -80,9 +78,7 @@ class SnakeBeta(nn.Module):
|
||||
>>> x = a1(x)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False
|
||||
):
|
||||
def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
|
||||
"""
|
||||
Initialization.
|
||||
INPUT:
|
||||
|
||||
@@ -20,9 +20,7 @@ class FusedAntiAliasActivation(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, inputs, up_ftr, down_ftr, alpha, beta):
|
||||
activation_results = anti_alias_activation_cuda.forward(
|
||||
inputs, up_ftr, down_ftr, alpha, beta
|
||||
)
|
||||
activation_results = anti_alias_activation_cuda.forward(inputs, up_ftr, down_ftr, alpha, beta)
|
||||
|
||||
return activation_results
|
||||
|
||||
@@ -61,17 +59,11 @@ class Activation1d(nn.Module):
|
||||
if self.act.__class__.__name__ == "Snake":
|
||||
beta = self.act.alpha.data # Snake uses same params for alpha and beta
|
||||
else:
|
||||
beta = (
|
||||
self.act.beta.data
|
||||
) # Snakebeta uses different params for alpha and beta
|
||||
beta = self.act.beta.data # Snakebeta uses different params for alpha and beta
|
||||
alpha = self.act.alpha.data
|
||||
if (
|
||||
not self.act.alpha_logscale
|
||||
): # Exp baked into cuda kernel, cancel it out with a log
|
||||
if not self.act.alpha_logscale: # Exp baked into cuda kernel, cancel it out with a log
|
||||
alpha = torch.log(alpha)
|
||||
beta = torch.log(beta)
|
||||
|
||||
x = FusedAntiAliasActivation.apply(
|
||||
x, self.upsample.filter, self.downsample.lowpass.filter, alpha, beta
|
||||
)
|
||||
x = FusedAntiAliasActivation.apply(x, self.upsample.filter, self.downsample.lowpass.filter, alpha, beta)
|
||||
return x
|
||||
|
||||
@@ -58,17 +58,13 @@ def load():
|
||||
srcpath / "anti_alias_activation.cpp",
|
||||
srcpath / "anti_alias_activation_cuda.cu",
|
||||
]
|
||||
anti_alias_activation_cuda = _cpp_extention_load_helper(
|
||||
"anti_alias_activation_cuda", sources, extra_cuda_flags
|
||||
)
|
||||
anti_alias_activation_cuda = _cpp_extention_load_helper("anti_alias_activation_cuda", sources, extra_cuda_flags)
|
||||
|
||||
return anti_alias_activation_cuda
|
||||
|
||||
|
||||
def _get_cuda_bare_metal_version(cuda_dir):
|
||||
raw_output = subprocess.check_output(
|
||||
[cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True
|
||||
)
|
||||
raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
|
||||
output = raw_output.split()
|
||||
release_idx = output.index("release") + 1
|
||||
release = output[release_idx].split(".")
|
||||
|
||||
@@ -27,9 +27,7 @@ else:
|
||||
# This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License
|
||||
# https://adefossez.github.io/julius/julius/lowpass.html
|
||||
# LICENSE is in incl_licenses directory.
|
||||
def kaiser_sinc_filter1d(
|
||||
cutoff, half_width, kernel_size
|
||||
): # return filter [1,1,kernel_size]
|
||||
def kaiser_sinc_filter1d(cutoff, half_width, kernel_size): # return filter [1,1,kernel_size]
|
||||
even = kernel_size % 2 == 0
|
||||
half_size = kernel_size // 2
|
||||
|
||||
|
||||
@@ -11,18 +11,12 @@ class UpSample1d(nn.Module):
|
||||
def __init__(self, ratio=2, kernel_size=None):
|
||||
super().__init__()
|
||||
self.ratio = ratio
|
||||
self.kernel_size = (
|
||||
int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
|
||||
)
|
||||
self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
|
||||
self.stride = ratio
|
||||
self.pad = self.kernel_size // ratio - 1
|
||||
self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
|
||||
self.pad_right = (
|
||||
self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
|
||||
)
|
||||
filter = kaiser_sinc_filter1d(
|
||||
cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size
|
||||
)
|
||||
self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
|
||||
filter = kaiser_sinc_filter1d(cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size)
|
||||
self.register_buffer("filter", filter)
|
||||
|
||||
# x: [B, C, T]
|
||||
@@ -30,9 +24,7 @@ class UpSample1d(nn.Module):
|
||||
_, C, _ = x.shape
|
||||
|
||||
x = F.pad(x, (self.pad, self.pad), mode="replicate")
|
||||
x = self.ratio * F.conv_transpose1d(
|
||||
x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C
|
||||
)
|
||||
x = self.ratio * F.conv_transpose1d(x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)
|
||||
x = x[..., self.pad_left : -self.pad_right]
|
||||
|
||||
return x
|
||||
@@ -42,9 +34,7 @@ class DownSample1d(nn.Module):
|
||||
def __init__(self, ratio=2, kernel_size=None):
|
||||
super().__init__()
|
||||
self.ratio = ratio
|
||||
self.kernel_size = (
|
||||
int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
|
||||
)
|
||||
self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
|
||||
self.lowpass = LowPassFilter1d(
|
||||
cutoff=0.5 / ratio,
|
||||
half_width=0.6 / ratio,
|
||||
|
||||
@@ -50,7 +50,7 @@ class AMPBlock1(torch.nn.Module):
|
||||
activation: str = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
|
||||
self.h = h
|
||||
|
||||
self.convs1 = nn.ModuleList(
|
||||
@@ -87,9 +87,7 @@ class AMPBlock1(torch.nn.Module):
|
||||
)
|
||||
self.convs2.apply(init_weights)
|
||||
|
||||
self.num_layers = len(self.convs1) + len(
|
||||
self.convs2
|
||||
) # Total number of conv layers
|
||||
self.num_layers = len(self.convs1) + len(self.convs2) # Total number of conv layers
|
||||
|
||||
# Select which Activation1d, lazy-load cuda version to ensure backward compatibility
|
||||
if self.h.get("use_cuda_kernel", False):
|
||||
@@ -105,22 +103,14 @@ class AMPBlock1(torch.nn.Module):
|
||||
if activation == "snake":
|
||||
self.activations = nn.ModuleList(
|
||||
[
|
||||
Activation1d(
|
||||
activation=activations.Snake(
|
||||
channels, alpha_logscale=h.snake_logscale
|
||||
)
|
||||
)
|
||||
Activation1d(activation=activations.Snake(channels, alpha_logscale=h.snake_logscale))
|
||||
for _ in range(self.num_layers)
|
||||
]
|
||||
)
|
||||
elif activation == "snakebeta":
|
||||
self.activations = nn.ModuleList(
|
||||
[
|
||||
Activation1d(
|
||||
activation=activations.SnakeBeta(
|
||||
channels, alpha_logscale=h.snake_logscale
|
||||
)
|
||||
)
|
||||
Activation1d(activation=activations.SnakeBeta(channels, alpha_logscale=h.snake_logscale))
|
||||
for _ in range(self.num_layers)
|
||||
]
|
||||
)
|
||||
@@ -169,7 +159,7 @@ class AMPBlock2(torch.nn.Module):
|
||||
activation: str = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
|
||||
self.h = h
|
||||
|
||||
self.convs = nn.ModuleList(
|
||||
@@ -205,22 +195,14 @@ class AMPBlock2(torch.nn.Module):
|
||||
if activation == "snake":
|
||||
self.activations = nn.ModuleList(
|
||||
[
|
||||
Activation1d(
|
||||
activation=activations.Snake(
|
||||
channels, alpha_logscale=h.snake_logscale
|
||||
)
|
||||
)
|
||||
Activation1d(activation=activations.Snake(channels, alpha_logscale=h.snake_logscale))
|
||||
for _ in range(self.num_layers)
|
||||
]
|
||||
)
|
||||
elif activation == "snakebeta":
|
||||
self.activations = nn.ModuleList(
|
||||
[
|
||||
Activation1d(
|
||||
activation=activations.SnakeBeta(
|
||||
channels, alpha_logscale=h.snake_logscale
|
||||
)
|
||||
)
|
||||
Activation1d(activation=activations.SnakeBeta(channels, alpha_logscale=h.snake_logscale))
|
||||
for _ in range(self.num_layers)
|
||||
]
|
||||
)
|
||||
@@ -283,9 +265,7 @@ class BigVGAN(
|
||||
self.num_upsamples = len(h.upsample_rates)
|
||||
|
||||
# Pre-conv
|
||||
self.conv_pre = weight_norm(
|
||||
Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3)
|
||||
)
|
||||
self.conv_pre = weight_norm(Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3))
|
||||
|
||||
# Define which AMPBlock to use. BigVGAN uses AMPBlock1 as default
|
||||
if h.resblock == "1":
|
||||
@@ -293,9 +273,7 @@ class BigVGAN(
|
||||
elif h.resblock == "2":
|
||||
resblock_class = AMPBlock2
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Incorrect resblock class specified in hyperparameters. Got {h.resblock}"
|
||||
)
|
||||
raise ValueError(f"Incorrect resblock class specified in hyperparameters. Got {h.resblock}")
|
||||
|
||||
# Transposed conv-based upsamplers. does not apply anti-aliasing
|
||||
self.ups = nn.ModuleList()
|
||||
@@ -320,22 +298,14 @@ class BigVGAN(
|
||||
self.resblocks = nn.ModuleList()
|
||||
for i in range(len(self.ups)):
|
||||
ch = h.upsample_initial_channel // (2 ** (i + 1))
|
||||
for j, (k, d) in enumerate(
|
||||
zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)
|
||||
):
|
||||
self.resblocks.append(
|
||||
resblock_class(h, ch, k, d, activation=h.activation)
|
||||
)
|
||||
for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
|
||||
self.resblocks.append(resblock_class(h, ch, k, d, activation=h.activation))
|
||||
|
||||
# Post-conv
|
||||
activation_post = (
|
||||
activations.Snake(ch, alpha_logscale=h.snake_logscale)
|
||||
if h.activation == "snake"
|
||||
else (
|
||||
activations.SnakeBeta(ch, alpha_logscale=h.snake_logscale)
|
||||
if h.activation == "snakebeta"
|
||||
else None
|
||||
)
|
||||
else (activations.SnakeBeta(ch, alpha_logscale=h.snake_logscale) if h.activation == "snakebeta" else None)
|
||||
)
|
||||
if activation_post is None:
|
||||
raise NotImplementedError(
|
||||
@@ -346,9 +316,7 @@ class BigVGAN(
|
||||
|
||||
# Whether to use bias for the final conv_post. Default to True for backward compatibility
|
||||
self.use_bias_at_final = h.get("use_bias_at_final", True)
|
||||
self.conv_post = weight_norm(
|
||||
Conv1d(ch, 1, 7, 1, padding=3, bias=self.use_bias_at_final)
|
||||
)
|
||||
self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3, bias=self.use_bias_at_final))
|
||||
|
||||
# Weight initialization
|
||||
for i in range(len(self.ups)):
|
||||
@@ -451,13 +419,13 @@ class BigVGAN(
|
||||
# instantiate BigVGAN using h
|
||||
if use_cuda_kernel:
|
||||
print(
|
||||
f"[WARNING] You have specified use_cuda_kernel=True during BigVGAN.from_pretrained(). Only inference is supported (training is not implemented)!"
|
||||
"[WARNING] You have specified use_cuda_kernel=True during BigVGAN.from_pretrained(). Only inference is supported (training is not implemented)!"
|
||||
)
|
||||
print(
|
||||
f"[WARNING] You need nvcc and ninja installed in your system that matches your PyTorch build is using to build the kernel. If not, the model will fail to initialize or generate incorrect waveform!"
|
||||
"[WARNING] You need nvcc and ninja installed in your system that matches your PyTorch build is using to build the kernel. If not, the model will fail to initialize or generate incorrect waveform!"
|
||||
)
|
||||
print(
|
||||
f"[WARNING] For detail, see the official GitHub repository: https://github.com/NVIDIA/BigVGAN?tab=readme-ov-file#using-custom-cuda-kernel-for-synthesis"
|
||||
"[WARNING] For detail, see the official GitHub repository: https://github.com/NVIDIA/BigVGAN?tab=readme-ov-file#using-custom-cuda-kernel-for-synthesis"
|
||||
)
|
||||
model = cls(h, use_cuda_kernel=use_cuda_kernel)
|
||||
|
||||
@@ -485,7 +453,7 @@ class BigVGAN(
|
||||
model.load_state_dict(checkpoint_dict["generator"])
|
||||
except RuntimeError:
|
||||
print(
|
||||
f"[INFO] the pretrained checkpoint does not contain weight norm. Loading the checkpoint after removing weight norm!"
|
||||
"[INFO] the pretrained checkpoint does not contain weight norm. Loading the checkpoint after removing weight norm!"
|
||||
)
|
||||
model.remove_weight_norm()
|
||||
model.load_state_dict(checkpoint_dict["generator"])
|
||||
|
||||
@@ -15,7 +15,7 @@ from torchaudio.transforms import Spectrogram, Resample
|
||||
from env import AttrDict
|
||||
from utils import get_padding
|
||||
import typing
|
||||
from typing import Optional, List, Union, Dict, Tuple
|
||||
from typing import List, Tuple
|
||||
|
||||
|
||||
class DiscriminatorP(torch.nn.Module):
|
||||
@@ -81,9 +81,7 @@ class DiscriminatorP(torch.nn.Module):
|
||||
),
|
||||
]
|
||||
)
|
||||
self.conv_post = norm_f(
|
||||
Conv2d(int(1024 * self.d_mult), 1, (3, 1), 1, padding=(1, 0))
|
||||
)
|
||||
self.conv_post = norm_f(Conv2d(int(1024 * self.d_mult), 1, (3, 1), 1, padding=(1, 0)))
|
||||
|
||||
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
|
||||
fmap = []
|
||||
@@ -113,13 +111,12 @@ class MultiPeriodDiscriminator(torch.nn.Module):
|
||||
self.mpd_reshapes = h.mpd_reshapes
|
||||
print(f"mpd_reshapes: {self.mpd_reshapes}")
|
||||
self.discriminators = nn.ModuleList(
|
||||
[
|
||||
DiscriminatorP(h, rs, use_spectral_norm=h.use_spectral_norm)
|
||||
for rs in self.mpd_reshapes
|
||||
]
|
||||
[DiscriminatorP(h, rs, use_spectral_norm=h.use_spectral_norm) for rs in self.mpd_reshapes]
|
||||
)
|
||||
|
||||
def forward(self, y: torch.Tensor, y_hat: torch.Tensor) -> Tuple[
|
||||
def forward(
|
||||
self, y: torch.Tensor, y_hat: torch.Tensor
|
||||
) -> Tuple[
|
||||
List[torch.Tensor],
|
||||
List[torch.Tensor],
|
||||
List[List[torch.Tensor]],
|
||||
@@ -145,19 +142,13 @@ class DiscriminatorR(nn.Module):
|
||||
super().__init__()
|
||||
|
||||
self.resolution = resolution
|
||||
assert (
|
||||
len(self.resolution) == 3
|
||||
), f"MRD layer requires list with len=3, got {self.resolution}"
|
||||
assert len(self.resolution) == 3, f"MRD layer requires list with len=3, got {self.resolution}"
|
||||
self.lrelu_slope = 0.1
|
||||
|
||||
norm_f = weight_norm if cfg.use_spectral_norm == False else spectral_norm
|
||||
if hasattr(cfg, "mrd_use_spectral_norm"):
|
||||
print(
|
||||
f"[INFO] overriding MRD use_spectral_norm as {cfg.mrd_use_spectral_norm}"
|
||||
)
|
||||
norm_f = (
|
||||
weight_norm if cfg.mrd_use_spectral_norm == False else spectral_norm
|
||||
)
|
||||
print(f"[INFO] overriding MRD use_spectral_norm as {cfg.mrd_use_spectral_norm}")
|
||||
norm_f = weight_norm if cfg.mrd_use_spectral_norm == False else spectral_norm
|
||||
self.d_mult = cfg.discriminator_channel_mult
|
||||
if hasattr(cfg, "mrd_channel_mult"):
|
||||
print(f"[INFO] overriding mrd channel multiplier as {cfg.mrd_channel_mult}")
|
||||
@@ -203,9 +194,7 @@ class DiscriminatorR(nn.Module):
|
||||
),
|
||||
]
|
||||
)
|
||||
self.conv_post = norm_f(
|
||||
nn.Conv2d(int(32 * self.d_mult), 1, (3, 3), padding=(1, 1))
|
||||
)
|
||||
self.conv_post = norm_f(nn.Conv2d(int(32 * self.d_mult), 1, (3, 3), padding=(1, 1)))
|
||||
|
||||
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
|
||||
fmap = []
|
||||
@@ -248,14 +237,14 @@ class MultiResolutionDiscriminator(nn.Module):
|
||||
def __init__(self, cfg, debug=False):
|
||||
super().__init__()
|
||||
self.resolutions = cfg.resolutions
|
||||
assert (
|
||||
len(self.resolutions) == 3
|
||||
), f"MRD requires list of list with len=3, each element having a list with len=3. Got {self.resolutions}"
|
||||
self.discriminators = nn.ModuleList(
|
||||
[DiscriminatorR(cfg, resolution) for resolution in self.resolutions]
|
||||
assert len(self.resolutions) == 3, (
|
||||
f"MRD requires list of list with len=3, each element having a list with len=3. Got {self.resolutions}"
|
||||
)
|
||||
self.discriminators = nn.ModuleList([DiscriminatorR(cfg, resolution) for resolution in self.resolutions])
|
||||
|
||||
def forward(self, y: torch.Tensor, y_hat: torch.Tensor) -> Tuple[
|
||||
def forward(
|
||||
self, y: torch.Tensor, y_hat: torch.Tensor
|
||||
) -> Tuple[
|
||||
List[torch.Tensor],
|
||||
List[torch.Tensor],
|
||||
List[List[torch.Tensor]],
|
||||
@@ -309,25 +298,15 @@ class DiscriminatorB(nn.Module):
|
||||
convs = lambda: nn.ModuleList(
|
||||
[
|
||||
weight_norm(nn.Conv2d(2, channels, (3, 9), (1, 1), padding=(1, 4))),
|
||||
weight_norm(
|
||||
nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))
|
||||
),
|
||||
weight_norm(
|
||||
nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))
|
||||
),
|
||||
weight_norm(
|
||||
nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))
|
||||
),
|
||||
weight_norm(
|
||||
nn.Conv2d(channels, channels, (3, 3), (1, 1), padding=(1, 1))
|
||||
),
|
||||
weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
|
||||
weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
|
||||
weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
|
||||
weight_norm(nn.Conv2d(channels, channels, (3, 3), (1, 1), padding=(1, 1))),
|
||||
]
|
||||
)
|
||||
self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))])
|
||||
|
||||
self.conv_post = weight_norm(
|
||||
nn.Conv2d(channels, 1, (3, 3), (1, 1), padding=(1, 1))
|
||||
)
|
||||
self.conv_post = weight_norm(nn.Conv2d(channels, 1, (3, 3), (1, 1), padding=(1, 1)))
|
||||
|
||||
def spectrogram(self, x: torch.Tensor) -> List[torch.Tensor]:
|
||||
# Remove DC offset
|
||||
@@ -376,17 +355,16 @@ class MultiBandDiscriminator(nn.Module):
|
||||
super().__init__()
|
||||
# fft_sizes (list[int]): Tuple of window lengths for FFT. Defaults to [2048, 1024, 512] if not set in h.
|
||||
self.fft_sizes = h.get("mbd_fft_sizes", [2048, 1024, 512])
|
||||
self.discriminators = nn.ModuleList(
|
||||
[DiscriminatorB(window_length=w) for w in self.fft_sizes]
|
||||
)
|
||||
self.discriminators = nn.ModuleList([DiscriminatorB(window_length=w) for w in self.fft_sizes])
|
||||
|
||||
def forward(self, y: torch.Tensor, y_hat: torch.Tensor) -> Tuple[
|
||||
def forward(
|
||||
self, y: torch.Tensor, y_hat: torch.Tensor
|
||||
) -> Tuple[
|
||||
List[torch.Tensor],
|
||||
List[torch.Tensor],
|
||||
List[List[torch.Tensor]],
|
||||
List[List[torch.Tensor]],
|
||||
]:
|
||||
|
||||
y_d_rs = []
|
||||
y_d_gs = []
|
||||
fmap_rs = []
|
||||
@@ -406,7 +384,7 @@ class MultiBandDiscriminator(nn.Module):
|
||||
# Adapted from https://github.com/open-mmlab/Amphion/blob/main/models/vocoders/gan/discriminator/mssbcqtd.py under the MIT license.
|
||||
# LICENSE is in incl_licenses directory.
|
||||
class DiscriminatorCQT(nn.Module):
|
||||
def __init__(self, cfg: AttrDict, hop_length: int, n_octaves:int, bins_per_octave: int):
|
||||
def __init__(self, cfg: AttrDict, hop_length: int, n_octaves: int, bins_per_octave: int):
|
||||
super().__init__()
|
||||
self.cfg = cfg
|
||||
|
||||
@@ -460,9 +438,7 @@ class DiscriminatorCQT(nn.Module):
|
||||
|
||||
in_chs = min(self.filters_scale * self.filters, self.max_filters)
|
||||
for i, dilation in enumerate(self.dilations):
|
||||
out_chs = min(
|
||||
(self.filters_scale ** (i + 1)) * self.filters, self.max_filters
|
||||
)
|
||||
out_chs = min((self.filters_scale ** (i + 1)) * self.filters, self.max_filters)
|
||||
self.convs.append(
|
||||
weight_norm(
|
||||
nn.Conv2d(
|
||||
@@ -486,9 +462,7 @@ class DiscriminatorCQT(nn.Module):
|
||||
in_chs,
|
||||
out_chs,
|
||||
kernel_size=(self.kernel_size[0], self.kernel_size[0]),
|
||||
padding=self.get_2d_padding(
|
||||
(self.kernel_size[0], self.kernel_size[0])
|
||||
),
|
||||
padding=self.get_2d_padding((self.kernel_size[0], self.kernel_size[0])),
|
||||
)
|
||||
)
|
||||
)
|
||||
@@ -508,7 +482,7 @@ class DiscriminatorCQT(nn.Module):
|
||||
self.cqtd_normalize_volume = self.cfg.get("cqtd_normalize_volume", False)
|
||||
if self.cqtd_normalize_volume:
|
||||
print(
|
||||
f"[INFO] cqtd_normalize_volume set to True. Will apply DC offset removal & peak volume normalization in CQTD!"
|
||||
"[INFO] cqtd_normalize_volume set to True. Will apply DC offset removal & peak volume normalization in CQTD!"
|
||||
)
|
||||
|
||||
def get_2d_padding(
|
||||
@@ -580,9 +554,7 @@ class MultiScaleSubbandCQTDiscriminator(nn.Module):
|
||||
# Multi-scale params to loop over
|
||||
self.cfg["cqtd_hop_lengths"] = self.cfg.get("cqtd_hop_lengths", [512, 256, 256])
|
||||
self.cfg["cqtd_n_octaves"] = self.cfg.get("cqtd_n_octaves", [9, 9, 9])
|
||||
self.cfg["cqtd_bins_per_octaves"] = self.cfg.get(
|
||||
"cqtd_bins_per_octaves", [24, 36, 48]
|
||||
)
|
||||
self.cfg["cqtd_bins_per_octaves"] = self.cfg.get("cqtd_bins_per_octaves", [24, 36, 48])
|
||||
|
||||
self.discriminators = nn.ModuleList(
|
||||
[
|
||||
@@ -596,13 +568,14 @@ class MultiScaleSubbandCQTDiscriminator(nn.Module):
|
||||
]
|
||||
)
|
||||
|
||||
def forward(self, y: torch.Tensor, y_hat: torch.Tensor) -> Tuple[
|
||||
def forward(
|
||||
self, y: torch.Tensor, y_hat: torch.Tensor
|
||||
) -> Tuple[
|
||||
List[torch.Tensor],
|
||||
List[torch.Tensor],
|
||||
List[List[torch.Tensor]],
|
||||
List[List[torch.Tensor]],
|
||||
]:
|
||||
|
||||
y_d_rs = []
|
||||
y_d_gs = []
|
||||
fmap_rs = []
|
||||
@@ -629,13 +602,14 @@ class CombinedDiscriminator(nn.Module):
|
||||
super().__init__()
|
||||
self.discrimiantor = nn.ModuleList(list_discriminator)
|
||||
|
||||
def forward(self, y: torch.Tensor, y_hat: torch.Tensor) -> Tuple[
|
||||
def forward(
|
||||
self, y: torch.Tensor, y_hat: torch.Tensor
|
||||
) -> Tuple[
|
||||
List[torch.Tensor],
|
||||
List[torch.Tensor],
|
||||
List[List[torch.Tensor]],
|
||||
List[List[torch.Tensor]],
|
||||
]:
|
||||
|
||||
y_d_rs = []
|
||||
y_d_gs = []
|
||||
fmap_rs = []
|
||||
|
||||
@@ -35,9 +35,7 @@ def inference(a, h):
|
||||
with torch.no_grad():
|
||||
for i, filname in enumerate(filelist):
|
||||
# Load the ground truth audio and resample if necessary
|
||||
wav, sr = librosa.load(
|
||||
os.path.join(a.input_wavs_dir, filname), sr=h.sampling_rate, mono=True
|
||||
)
|
||||
wav, sr = librosa.load(os.path.join(a.input_wavs_dir, filname), sr=h.sampling_rate, mono=True)
|
||||
wav = torch.FloatTensor(wav).to(device)
|
||||
# Compute mel spectrogram from the ground truth audio
|
||||
x = get_mel_spectrogram(wav.unsqueeze(0), generator.h)
|
||||
@@ -48,9 +46,7 @@ def inference(a, h):
|
||||
audio = audio * MAX_WAV_VALUE
|
||||
audio = audio.cpu().numpy().astype("int16")
|
||||
|
||||
output_file = os.path.join(
|
||||
a.output_dir, os.path.splitext(filname)[0] + "_generated.wav"
|
||||
)
|
||||
output_file = os.path.join(a.output_dir, os.path.splitext(filname)[0] + "_generated.wav")
|
||||
write(output_file, h.sampling_rate, audio)
|
||||
print(output_file)
|
||||
|
||||
|
||||
@@ -61,9 +61,7 @@ def inference(a, h):
|
||||
audio = audio * MAX_WAV_VALUE
|
||||
audio = audio.cpu().numpy().astype("int16")
|
||||
|
||||
output_file = os.path.join(
|
||||
a.output_dir, os.path.splitext(filname)[0] + "_generated_e2e.wav"
|
||||
)
|
||||
output_file = os.path.join(a.output_dir, os.path.splitext(filname)[0] + "_generated_e2e.wav")
|
||||
write(output_file, h.sampling_rate, audio)
|
||||
print(output_file)
|
||||
|
||||
|
||||
@@ -6,13 +6,12 @@
|
||||
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.nn as nn
|
||||
from librosa.filters import mel as librosa_mel_fn
|
||||
from scipy import signal
|
||||
|
||||
import typing
|
||||
from typing import Optional, List, Union, Dict, Tuple
|
||||
from typing import List, Tuple
|
||||
from collections import namedtuple
|
||||
import math
|
||||
import functools
|
||||
@@ -117,15 +116,13 @@ class MultiScaleMelSpectrogramLoss(nn.Module):
|
||||
window_type,
|
||||
):
|
||||
"""
|
||||
Mirrors AudioSignal.mel_spectrogram used by BigVGAN-v2 training from:
|
||||
Mirrors AudioSignal.mel_spectrogram used by BigVGAN-v2 training from:
|
||||
https://github.com/descriptinc/audiotools/blob/master/audiotools/core/audio_signal.py
|
||||
"""
|
||||
B, C, T = wav.shape
|
||||
|
||||
if match_stride:
|
||||
assert (
|
||||
hop_length == window_length // 4
|
||||
), "For match_stride, hop must equal n_fft // 4"
|
||||
assert hop_length == window_length // 4, "For match_stride, hop must equal n_fft // 4"
|
||||
right_pad = math.ceil(T / hop_length) * hop_length - T
|
||||
pad = (window_length - hop_length) // 2
|
||||
else:
|
||||
@@ -155,9 +152,7 @@ class MultiScaleMelSpectrogramLoss(nn.Module):
|
||||
magnitude = torch.abs(stft)
|
||||
|
||||
nf = magnitude.shape[2]
|
||||
mel_basis = self.get_mel_filters(
|
||||
self.sampling_rate, 2 * (nf - 1), n_mels, fmin, fmax
|
||||
)
|
||||
mel_basis = self.get_mel_filters(self.sampling_rate, 2 * (nf - 1), n_mels, fmin, fmax)
|
||||
mel_basis = torch.from_numpy(mel_basis).to(wav.device)
|
||||
mel_spectrogram = magnitude.transpose(2, -1) @ mel_basis.T
|
||||
mel_spectrogram = mel_spectrogram.transpose(-1, 2)
|
||||
@@ -182,9 +177,7 @@ class MultiScaleMelSpectrogramLoss(nn.Module):
|
||||
"""
|
||||
|
||||
loss = 0.0
|
||||
for n_mels, fmin, fmax, s in zip(
|
||||
self.n_mels, self.mel_fmin, self.mel_fmax, self.stft_params
|
||||
):
|
||||
for n_mels, fmin, fmax, s in zip(self.n_mels, self.mel_fmin, self.mel_fmax, self.stft_params):
|
||||
kwargs = {
|
||||
"n_mels": n_mels,
|
||||
"fmin": fmin,
|
||||
@@ -197,12 +190,8 @@ class MultiScaleMelSpectrogramLoss(nn.Module):
|
||||
|
||||
x_mels = self.mel_spectrogram(x, **kwargs)
|
||||
y_mels = self.mel_spectrogram(y, **kwargs)
|
||||
x_logmels = torch.log(
|
||||
x_mels.clamp(min=self.clamp_eps).pow(self.pow)
|
||||
) / torch.log(torch.tensor(10.0))
|
||||
y_logmels = torch.log(
|
||||
y_mels.clamp(min=self.clamp_eps).pow(self.pow)
|
||||
) / torch.log(torch.tensor(10.0))
|
||||
x_logmels = torch.log(x_mels.clamp(min=self.clamp_eps).pow(self.pow)) / torch.log(torch.tensor(10.0))
|
||||
y_logmels = torch.log(y_mels.clamp(min=self.clamp_eps).pow(self.pow)) / torch.log(torch.tensor(10.0))
|
||||
|
||||
loss += self.log_weight * self.loss_fn(x_logmels, y_logmels)
|
||||
loss += self.mag_weight * self.loss_fn(x_logmels, y_logmels)
|
||||
@@ -211,10 +200,7 @@ class MultiScaleMelSpectrogramLoss(nn.Module):
|
||||
|
||||
|
||||
# Loss functions
|
||||
def feature_loss(
|
||||
fmap_r: List[List[torch.Tensor]], fmap_g: List[List[torch.Tensor]]
|
||||
) -> torch.Tensor:
|
||||
|
||||
def feature_loss(fmap_r: List[List[torch.Tensor]], fmap_g: List[List[torch.Tensor]]) -> torch.Tensor:
|
||||
loss = 0
|
||||
for dr, dg in zip(fmap_r, fmap_g):
|
||||
for rl, gl in zip(dr, dg):
|
||||
@@ -226,7 +212,6 @@ def feature_loss(
|
||||
def discriminator_loss(
|
||||
disc_real_outputs: List[torch.Tensor], disc_generated_outputs: List[torch.Tensor]
|
||||
) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]:
|
||||
|
||||
loss = 0
|
||||
r_losses = []
|
||||
g_losses = []
|
||||
@@ -243,7 +228,6 @@ def discriminator_loss(
|
||||
def generator_loss(
|
||||
disc_outputs: List[torch.Tensor],
|
||||
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
|
||||
|
||||
loss = 0
|
||||
gen_losses = []
|
||||
for dg in disc_outputs:
|
||||
|
||||
@@ -86,9 +86,7 @@ def mel_spectrogram(
|
||||
key = f"{n_fft}_{num_mels}_{sampling_rate}_{hop_size}_{win_size}_{fmin}_{fmax}_{device}"
|
||||
|
||||
if key not in mel_basis_cache:
|
||||
mel = librosa_mel_fn(
|
||||
sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
|
||||
)
|
||||
mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
|
||||
mel_basis_cache[key] = torch.from_numpy(mel).float().to(device)
|
||||
hann_window_cache[key] = torch.hann_window(win_size).to(device)
|
||||
|
||||
@@ -96,9 +94,7 @@ def mel_spectrogram(
|
||||
hann_window = hann_window_cache[key]
|
||||
|
||||
padding = (n_fft - hop_size) // 2
|
||||
y = torch.nn.functional.pad(
|
||||
y.unsqueeze(1), (padding, padding), mode="reflect"
|
||||
).squeeze(1)
|
||||
y = torch.nn.functional.pad(y.unsqueeze(1), (padding, padding), mode="reflect").squeeze(1)
|
||||
|
||||
spec = torch.stft(
|
||||
y,
|
||||
@@ -150,17 +146,13 @@ def get_dataset_filelist(a):
|
||||
|
||||
with open(a.input_training_file, "r", encoding="utf-8") as fi:
|
||||
training_files = [
|
||||
os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav")
|
||||
for x in fi.read().split("\n")
|
||||
if len(x) > 0
|
||||
os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav") for x in fi.read().split("\n") if len(x) > 0
|
||||
]
|
||||
print(f"first training file: {training_files[0]}")
|
||||
|
||||
with open(a.input_validation_file, "r", encoding="utf-8") as fi:
|
||||
validation_files = [
|
||||
os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav")
|
||||
for x in fi.read().split("\n")
|
||||
if len(x) > 0
|
||||
os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav") for x in fi.read().split("\n") if len(x) > 0
|
||||
]
|
||||
print(f"first validation file: {validation_files[0]}")
|
||||
|
||||
@@ -171,9 +163,7 @@ def get_dataset_filelist(a):
|
||||
for x in fi.read().split("\n")
|
||||
if len(x) > 0
|
||||
]
|
||||
print(
|
||||
f"first unseen {i}th validation fileset: {unseen_validation_files[0]}"
|
||||
)
|
||||
print(f"first unseen {i}th validation fileset: {unseen_validation_files[0]}")
|
||||
list_unseen_validation_files.append(unseen_validation_files)
|
||||
|
||||
return training_files, validation_files, list_unseen_validation_files
|
||||
@@ -227,13 +217,9 @@ class MelDataset(torch.utils.data.Dataset):
|
||||
|
||||
print("[INFO] checking dataset integrity...")
|
||||
for i in tqdm(range(len(self.audio_files))):
|
||||
assert os.path.exists(
|
||||
self.audio_files[i]
|
||||
), f"{self.audio_files[i]} not found"
|
||||
assert os.path.exists(self.audio_files[i]), f"{self.audio_files[i]} not found"
|
||||
|
||||
def __getitem__(
|
||||
self, index: int
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, str, torch.Tensor]:
|
||||
def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor, str, torch.Tensor]:
|
||||
try:
|
||||
filename = self.audio_files[index]
|
||||
|
||||
@@ -248,17 +234,12 @@ class MelDataset(torch.utils.data.Dataset):
|
||||
# Obtain randomized audio chunk
|
||||
if source_sampling_rate != self.sampling_rate:
|
||||
# Adjust segment size to crop if the source sr is different
|
||||
target_segment_size = math.ceil(
|
||||
self.segment_size
|
||||
* (source_sampling_rate / self.sampling_rate)
|
||||
)
|
||||
target_segment_size = math.ceil(self.segment_size * (source_sampling_rate / self.sampling_rate))
|
||||
else:
|
||||
target_segment_size = self.segment_size
|
||||
|
||||
# Compute upper bound index for the random chunk
|
||||
random_chunk_upper_bound = max(
|
||||
0, audio.shape[0] - target_segment_size
|
||||
)
|
||||
random_chunk_upper_bound = max(0, audio.shape[0] - target_segment_size)
|
||||
|
||||
# Crop or pad audio to obtain random chunk with target_segment_size
|
||||
if audio.shape[0] >= target_segment_size:
|
||||
@@ -318,9 +299,9 @@ class MelDataset(torch.utils.data.Dataset):
|
||||
else:
|
||||
# For fine-tuning, assert that the waveform is in the defined sampling_rate
|
||||
# Fine-tuning won't support on-the-fly resampling to be fool-proof (the dataset should have been prepared properly)
|
||||
assert (
|
||||
source_sampling_rate == self.sampling_rate
|
||||
), f"For fine_tuning, waveform must be in the spcified sampling rate {self.sampling_rate}, got {source_sampling_rate}"
|
||||
assert source_sampling_rate == self.sampling_rate, (
|
||||
f"For fine_tuning, waveform must be in the spcified sampling rate {self.sampling_rate}, got {source_sampling_rate}"
|
||||
)
|
||||
|
||||
# Cast ndarray to torch tensor
|
||||
audio = torch.FloatTensor(audio)
|
||||
@@ -346,20 +327,14 @@ class MelDataset(torch.utils.data.Dataset):
|
||||
mel = mel[:, :, mel_start : mel_start + frames_per_seg]
|
||||
audio = audio[
|
||||
:,
|
||||
mel_start
|
||||
* self.hop_size : (mel_start + frames_per_seg)
|
||||
* self.hop_size,
|
||||
mel_start * self.hop_size : (mel_start + frames_per_seg) * self.hop_size,
|
||||
]
|
||||
|
||||
# Pad pre-computed mel and audio to match length to ensuring fine-tuning without error.
|
||||
# NOTE: this may introduce a single-frame misalignment of the <pre-computed mel, audio>
|
||||
# To remove possible misalignment, it is recommended to prepare the <pre-computed mel, audio> pair where the audio length is the integer multiple of self.hop_size
|
||||
mel = torch.nn.functional.pad(
|
||||
mel, (0, frames_per_seg - mel.size(2)), "constant"
|
||||
)
|
||||
audio = torch.nn.functional.pad(
|
||||
audio, (0, self.segment_size - audio.size(1)), "constant"
|
||||
)
|
||||
mel = torch.nn.functional.pad(mel, (0, frames_per_seg - mel.size(2)), "constant")
|
||||
audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), "constant")
|
||||
|
||||
# Compute mel_loss used by spectral regression objective. Uses self.fmax_loss instead (usually None)
|
||||
mel_loss = mel_spectrogram(
|
||||
@@ -376,9 +351,10 @@ class MelDataset(torch.utils.data.Dataset):
|
||||
|
||||
# Shape sanity checks
|
||||
assert (
|
||||
audio.shape[1] == mel.shape[2] * self.hop_size
|
||||
and audio.shape[1] == mel_loss.shape[2] * self.hop_size
|
||||
), f"Audio length must be mel frame length * hop_size. Got audio shape {audio.shape} mel shape {mel.shape} mel_loss shape {mel_loss.shape}"
|
||||
audio.shape[1] == mel.shape[2] * self.hop_size and audio.shape[1] == mel_loss.shape[2] * self.hop_size
|
||||
), (
|
||||
f"Audio length must be mel frame length * hop_size. Got audio shape {audio.shape} mel shape {mel.shape} mel_loss shape {mel_loss.shape}"
|
||||
)
|
||||
|
||||
return (mel.squeeze(), audio.squeeze(0), filename, mel_loss.squeeze())
|
||||
|
||||
@@ -387,9 +363,7 @@ class MelDataset(torch.utils.data.Dataset):
|
||||
if self.fine_tuning:
|
||||
raise e # Terminate training if it is fine-tuning. The dataset should have been prepared properly.
|
||||
else:
|
||||
print(
|
||||
f"[WARNING] Failed to load waveform, skipping! filename: {filename} Error: {e}"
|
||||
)
|
||||
print(f"[WARNING] Failed to load waveform, skipping! filename: {filename} Error: {e}")
|
||||
return self[random.randrange(len(self))]
|
||||
|
||||
def __len__(self):
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
# to import modules from parent_dir
|
||||
parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||
sys.path.append(parent_dir)
|
||||
@@ -24,14 +25,10 @@ def test_anti_alias_activation():
|
||||
data = torch.rand((10, 10, 200), device="cuda")
|
||||
|
||||
# Check activations.Snake cuda vs. torch
|
||||
fused_anti_alias_activation = activation1d.Activation1d(
|
||||
activation=Snake(10), fused=True
|
||||
).cuda()
|
||||
fused_anti_alias_activation = activation1d.Activation1d(activation=Snake(10), fused=True).cuda()
|
||||
fused_activation_output = fused_anti_alias_activation(data)
|
||||
|
||||
torch_anti_alias_activation = activation1d.Activation1d(
|
||||
activation=Snake(10), fused=False
|
||||
).cuda()
|
||||
torch_anti_alias_activation = activation1d.Activation1d(activation=Snake(10), fused=False).cuda()
|
||||
torch_activation_output = torch_anti_alias_activation(data)
|
||||
|
||||
test_result = (fused_activation_output - torch_activation_output).abs()
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
# to import modules from parent_dir
|
||||
parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||
sys.path.append(parent_dir)
|
||||
@@ -24,14 +25,10 @@ def test_anti_alias_activation():
|
||||
data = torch.rand((10, 10, 200), device="cuda")
|
||||
|
||||
# Check activations, Snake CUDA vs. Torch
|
||||
fused_anti_alias_activation = activation1d.Activation1d(
|
||||
activation=SnakeBeta(10), fused=True
|
||||
).cuda()
|
||||
fused_anti_alias_activation = activation1d.Activation1d(activation=SnakeBeta(10), fused=True).cuda()
|
||||
fused_activation_output = fused_anti_alias_activation(data)
|
||||
|
||||
torch_anti_alias_activation = activation1d.Activation1d(
|
||||
activation=SnakeBeta(10), fused=False
|
||||
).cuda()
|
||||
torch_anti_alias_activation = activation1d.Activation1d(activation=SnakeBeta(10), fused=False).cuda()
|
||||
torch_activation_output = torch_anti_alias_activation(data)
|
||||
|
||||
test_result = (fused_activation_output - torch_activation_output).abs()
|
||||
@@ -57,7 +54,6 @@ def test_anti_alias_activation():
|
||||
)
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from alias_free_activation.cuda import load
|
||||
|
||||
|
||||
@@ -42,9 +42,7 @@ def generate_soundwave(duration=5.0, sr=24000):
|
||||
|
||||
|
||||
def get_mel(x, h):
|
||||
return mel_spectrogram(
|
||||
x, h.n_fft, h.num_mels, h.sampling_rate, h.hop_size, h.win_size, h.fmin, h.fmax
|
||||
)
|
||||
return mel_spectrogram(x, h.n_fft, h.num_mels, h.sampling_rate, h.hop_size, h.win_size, h.fmin, h.fmax)
|
||||
|
||||
|
||||
def load_checkpoint(filepath, device):
|
||||
@@ -56,9 +54,7 @@ def load_checkpoint(filepath, device):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Test script to check CUDA kernel correctness."
|
||||
)
|
||||
parser = argparse.ArgumentParser(description="Test script to check CUDA kernel correctness.")
|
||||
parser.add_argument(
|
||||
"--checkpoint_file",
|
||||
type=str,
|
||||
@@ -91,27 +87,25 @@ if __name__ == "__main__":
|
||||
# define number of samples and length of mel frame to benchmark
|
||||
num_sample = 10
|
||||
num_mel_frame = 16384
|
||||
|
||||
|
||||
# CUDA kernel correctness check
|
||||
diff = 0.0
|
||||
for i in tqdm(range(num_sample)):
|
||||
# Random mel
|
||||
data = torch.rand((1, h.num_mels, num_mel_frame), device="cuda")
|
||||
|
||||
|
||||
with torch.inference_mode():
|
||||
audio_original = generator_original(data)
|
||||
|
||||
|
||||
with torch.inference_mode():
|
||||
audio_cuda_kernel = generator_cuda_kernel(data)
|
||||
|
||||
# Both outputs should be (almost) the same
|
||||
test_result = (audio_original - audio_cuda_kernel).abs()
|
||||
diff += test_result.mean(dim=-1).item()
|
||||
|
||||
|
||||
diff /= num_sample
|
||||
if (
|
||||
diff <= 2e-3
|
||||
): # We can expect a small difference (~1e-3) which does not affect perceptual quality
|
||||
if diff <= 2e-3: # We can expect a small difference (~1e-3) which does not affect perceptual quality
|
||||
print(
|
||||
f"\n[Success] test CUDA fused vs. plain torch BigVGAN inference"
|
||||
f"\n > mean_difference={diff}"
|
||||
@@ -125,9 +119,9 @@ if __name__ == "__main__":
|
||||
f"\n > fused_values={audio_cuda_kernel[-1][-1][-30:].tolist()}, "
|
||||
f"\n > torch_values={audio_original[-1][-1][-30:].tolist()}"
|
||||
)
|
||||
|
||||
|
||||
del data, audio_original, audio_cuda_kernel
|
||||
|
||||
|
||||
# Variables for tracking total time and VRAM usage
|
||||
toc_total_original = 0
|
||||
toc_total_cuda_kernel = 0
|
||||
@@ -145,10 +139,10 @@ if __name__ == "__main__":
|
||||
audio_original = generator_original(data)
|
||||
torch.cuda.synchronize()
|
||||
toc = time() - tic
|
||||
toc_total_original += toc
|
||||
toc_total_original += toc
|
||||
|
||||
vram_used_original_total += torch.cuda.max_memory_allocated(device="cuda")
|
||||
|
||||
|
||||
del data, audio_original
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
@@ -163,11 +157,11 @@ if __name__ == "__main__":
|
||||
torch.cuda.synchronize()
|
||||
toc = time() - tic
|
||||
toc_total_cuda_kernel += toc
|
||||
|
||||
|
||||
audio_length_total += audio_cuda_kernel.shape[-1]
|
||||
|
||||
|
||||
vram_used_cuda_kernel_total += torch.cuda.max_memory_allocated(device="cuda")
|
||||
|
||||
|
||||
del data, audio_cuda_kernel
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
@@ -175,8 +169,8 @@ if __name__ == "__main__":
|
||||
audio_second = audio_length_total / h.sampling_rate
|
||||
khz_original = audio_length_total / toc_total_original / 1000
|
||||
khz_cuda_kernel = audio_length_total / toc_total_cuda_kernel / 1000
|
||||
vram_used_original_gb = vram_used_original_total / num_sample / (1024 ** 3)
|
||||
vram_used_cuda_kernel_gb = vram_used_cuda_kernel_total / num_sample / (1024 ** 3)
|
||||
vram_used_original_gb = vram_used_original_total / num_sample / (1024**3)
|
||||
vram_used_cuda_kernel_gb = vram_used_cuda_kernel_total / num_sample / (1024**3)
|
||||
|
||||
# Print results
|
||||
print(
|
||||
|
||||
@@ -77,24 +77,18 @@ def train(rank, a, h):
|
||||
# Define additional discriminators. BigVGAN-v1 uses UnivNet's MRD as default
|
||||
# New in BigVGAN-v2: option to switch to new discriminators: MultiBandDiscriminator / MultiScaleSubbandCQTDiscriminator
|
||||
if h.get("use_mbd_instead_of_mrd", False): # Switch to MBD
|
||||
print(
|
||||
"[INFO] using MultiBandDiscriminator of BigVGAN-v2 instead of MultiResolutionDiscriminator"
|
||||
)
|
||||
print("[INFO] using MultiBandDiscriminator of BigVGAN-v2 instead of MultiResolutionDiscriminator")
|
||||
# Variable name is kept as "mrd" for backward compatibility & minimal code change
|
||||
mrd = MultiBandDiscriminator(h).to(device)
|
||||
elif h.get("use_cqtd_instead_of_mrd", False): # Switch to CQTD
|
||||
print(
|
||||
"[INFO] using MultiScaleSubbandCQTDiscriminator of BigVGAN-v2 instead of MultiResolutionDiscriminator"
|
||||
)
|
||||
print("[INFO] using MultiScaleSubbandCQTDiscriminator of BigVGAN-v2 instead of MultiResolutionDiscriminator")
|
||||
mrd = MultiScaleSubbandCQTDiscriminator(h).to(device)
|
||||
else: # Fallback to original MRD in BigVGAN-v1
|
||||
mrd = MultiResolutionDiscriminator(h).to(device)
|
||||
|
||||
# New in BigVGAN-v2: option to switch to multi-scale L1 mel loss
|
||||
if h.get("use_multiscale_melloss", False):
|
||||
print(
|
||||
"[INFO] using multi-scale Mel l1 loss of BigVGAN-v2 instead of the original single-scale loss"
|
||||
)
|
||||
print("[INFO] using multi-scale Mel l1 loss of BigVGAN-v2 instead of the original single-scale loss")
|
||||
fn_mel_loss_multiscale = MultiScaleMelSpectrogramLoss(
|
||||
sampling_rate=h.sampling_rate
|
||||
) # NOTE: accepts waveform as input
|
||||
@@ -114,9 +108,7 @@ def train(rank, a, h):
|
||||
|
||||
if os.path.isdir(a.checkpoint_path):
|
||||
# New in v2.1: If the step prefix pattern-based checkpoints are not found, also check for renamed files in Hugging Face Hub to resume training
|
||||
cp_g = scan_checkpoint(
|
||||
a.checkpoint_path, prefix="g_", renamed_file="bigvgan_generator.pt"
|
||||
)
|
||||
cp_g = scan_checkpoint(a.checkpoint_path, prefix="g_", renamed_file="bigvgan_generator.pt")
|
||||
cp_do = scan_checkpoint(
|
||||
a.checkpoint_path,
|
||||
prefix="do_",
|
||||
@@ -143,9 +135,7 @@ def train(rank, a, h):
|
||||
mpd = DistributedDataParallel(mpd, device_ids=[rank]).to(device)
|
||||
mrd = DistributedDataParallel(mrd, device_ids=[rank]).to(device)
|
||||
|
||||
optim_g = torch.optim.AdamW(
|
||||
generator.parameters(), h.learning_rate, betas=[h.adam_b1, h.adam_b2]
|
||||
)
|
||||
optim_g = torch.optim.AdamW(generator.parameters(), h.learning_rate, betas=[h.adam_b1, h.adam_b2])
|
||||
optim_d = torch.optim.AdamW(
|
||||
itertools.chain(mrd.parameters(), mpd.parameters()),
|
||||
h.learning_rate,
|
||||
@@ -156,12 +146,8 @@ def train(rank, a, h):
|
||||
optim_g.load_state_dict(state_dict_do["optim_g"])
|
||||
optim_d.load_state_dict(state_dict_do["optim_d"])
|
||||
|
||||
scheduler_g = torch.optim.lr_scheduler.ExponentialLR(
|
||||
optim_g, gamma=h.lr_decay, last_epoch=last_epoch
|
||||
)
|
||||
scheduler_d = torch.optim.lr_scheduler.ExponentialLR(
|
||||
optim_d, gamma=h.lr_decay, last_epoch=last_epoch
|
||||
)
|
||||
scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=h.lr_decay, last_epoch=last_epoch)
|
||||
scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=h.lr_decay, last_epoch=last_epoch)
|
||||
|
||||
# Define training and validation datasets
|
||||
|
||||
@@ -169,9 +155,7 @@ def train(rank, a, h):
|
||||
unseen_validation_filelist will contain sample filepaths outside the seen training & validation dataset
|
||||
Example: trained on LibriTTS, validate on VCTK
|
||||
"""
|
||||
training_filelist, validation_filelist, list_unseen_validation_filelist = (
|
||||
get_dataset_filelist(a)
|
||||
)
|
||||
training_filelist, validation_filelist, list_unseen_validation_filelist = get_dataset_filelist(a)
|
||||
|
||||
trainset = MelDataset(
|
||||
training_filelist,
|
||||
@@ -324,33 +308,26 @@ def train(rank, a, h):
|
||||
h.fmax_for_loss,
|
||||
)
|
||||
min_t = min(y_mel.size(-1), y_g_hat_mel.size(-1))
|
||||
val_err_tot += F.l1_loss(y_mel[...,:min_t], y_g_hat_mel[...,:min_t]).item()
|
||||
val_err_tot += F.l1_loss(y_mel[..., :min_t], y_g_hat_mel[..., :min_t]).item()
|
||||
|
||||
# PESQ calculation. only evaluate PESQ if it's speech signal (nonspeech PESQ will error out)
|
||||
if (
|
||||
not "nonspeech" in mode
|
||||
): # Skips if the name of dataset (in mode string) contains "nonspeech"
|
||||
|
||||
if "nonspeech" not in mode: # Skips if the name of dataset (in mode string) contains "nonspeech"
|
||||
# Resample to 16000 for pesq
|
||||
y_16k = pesq_resampler(y)
|
||||
y_g_hat_16k = pesq_resampler(y_g_hat.squeeze(1))
|
||||
y_int_16k = (y_16k[0] * MAX_WAV_VALUE).short().cpu().numpy()
|
||||
y_g_hat_int_16k = (
|
||||
(y_g_hat_16k[0] * MAX_WAV_VALUE).short().cpu().numpy()
|
||||
)
|
||||
y_g_hat_int_16k = (y_g_hat_16k[0] * MAX_WAV_VALUE).short().cpu().numpy()
|
||||
val_pesq_tot += pesq(16000, y_int_16k, y_g_hat_int_16k, "wb")
|
||||
|
||||
# MRSTFT calculation
|
||||
min_t = min(y.size(-1), y_g_hat.size(-1))
|
||||
val_mrstft_tot += loss_mrstft(y_g_hat[...,:min_t], y[...,:min_t]).item()
|
||||
val_mrstft_tot += loss_mrstft(y_g_hat[..., :min_t], y[..., :min_t]).item()
|
||||
|
||||
# Log audio and figures to Tensorboard
|
||||
if j % a.eval_subsample == 0: # Subsample every nth from validation set
|
||||
if steps >= 0:
|
||||
sw.add_audio(f"gt_{mode}/y_{j}", y[0], steps, h.sampling_rate)
|
||||
if (
|
||||
a.save_audio
|
||||
): # Also save audio to disk if --save_audio is set to True
|
||||
if a.save_audio: # Also save audio to disk if --save_audio is set to True
|
||||
save_audio(
|
||||
y[0],
|
||||
os.path.join(
|
||||
@@ -373,9 +350,7 @@ def train(rank, a, h):
|
||||
steps,
|
||||
h.sampling_rate,
|
||||
)
|
||||
if (
|
||||
a.save_audio
|
||||
): # Also save audio to disk if --save_audio is set to True
|
||||
if a.save_audio: # Also save audio to disk if --save_audio is set to True
|
||||
save_audio(
|
||||
y_g_hat[0, 0],
|
||||
os.path.join(
|
||||
@@ -487,15 +462,11 @@ def train(rank, a, h):
|
||||
|
||||
# MPD
|
||||
y_df_hat_r, y_df_hat_g, _, _ = mpd(y, y_g_hat.detach())
|
||||
loss_disc_f, losses_disc_f_r, losses_disc_f_g = discriminator_loss(
|
||||
y_df_hat_r, y_df_hat_g
|
||||
)
|
||||
loss_disc_f, losses_disc_f_r, losses_disc_f_g = discriminator_loss(y_df_hat_r, y_df_hat_g)
|
||||
|
||||
# MRD
|
||||
y_ds_hat_r, y_ds_hat_g, _, _ = mrd(y, y_g_hat.detach())
|
||||
loss_disc_s, losses_disc_s_r, losses_disc_s_g = discriminator_loss(
|
||||
y_ds_hat_r, y_ds_hat_g
|
||||
)
|
||||
loss_disc_s, losses_disc_s_r, losses_disc_s_g = discriminator_loss(y_ds_hat_r, y_ds_hat_g)
|
||||
|
||||
loss_disc_all = loss_disc_s + loss_disc_f
|
||||
|
||||
@@ -505,17 +476,11 @@ def train(rank, a, h):
|
||||
# Whether to freeze D for initial training steps
|
||||
if steps >= a.freeze_step:
|
||||
loss_disc_all.backward()
|
||||
grad_norm_mpd = torch.nn.utils.clip_grad_norm_(
|
||||
mpd.parameters(), clip_grad_norm
|
||||
)
|
||||
grad_norm_mrd = torch.nn.utils.clip_grad_norm_(
|
||||
mrd.parameters(), clip_grad_norm
|
||||
)
|
||||
grad_norm_mpd = torch.nn.utils.clip_grad_norm_(mpd.parameters(), clip_grad_norm)
|
||||
grad_norm_mrd = torch.nn.utils.clip_grad_norm_(mrd.parameters(), clip_grad_norm)
|
||||
optim_d.step()
|
||||
else:
|
||||
print(
|
||||
f"[WARNING] skipping D training for the first {a.freeze_step} steps"
|
||||
)
|
||||
print(f"[WARNING] skipping D training for the first {a.freeze_step} steps")
|
||||
grad_norm_mpd = 0.0
|
||||
grad_norm_mrd = 0.0
|
||||
|
||||
@@ -523,9 +488,7 @@ def train(rank, a, h):
|
||||
optim_g.zero_grad()
|
||||
|
||||
# L1 Mel-Spectrogram Loss
|
||||
lambda_melloss = h.get(
|
||||
"lambda_melloss", 45.0
|
||||
) # Defaults to 45 in BigVGAN-v1 if not set
|
||||
lambda_melloss = h.get("lambda_melloss", 45.0) # Defaults to 45 in BigVGAN-v1 if not set
|
||||
if h.get("use_multiscale_melloss", False): # uses wav <y, y_g_hat> for loss
|
||||
loss_mel = fn_mel_loss_multiscale(y, y_g_hat) * lambda_melloss
|
||||
else: # Uses mel <y_mel, y_g_hat_mel> for loss
|
||||
@@ -542,27 +505,19 @@ def train(rank, a, h):
|
||||
loss_gen_s, losses_gen_s = generator_loss(y_ds_hat_g)
|
||||
|
||||
if steps >= a.freeze_step:
|
||||
loss_gen_all = (
|
||||
loss_gen_s + loss_gen_f + loss_fm_s + loss_fm_f + loss_mel
|
||||
)
|
||||
loss_gen_all = loss_gen_s + loss_gen_f + loss_fm_s + loss_fm_f + loss_mel
|
||||
else:
|
||||
print(
|
||||
f"[WARNING] using regression loss only for G for the first {a.freeze_step} steps"
|
||||
)
|
||||
print(f"[WARNING] using regression loss only for G for the first {a.freeze_step} steps")
|
||||
loss_gen_all = loss_mel
|
||||
|
||||
loss_gen_all.backward()
|
||||
grad_norm_g = torch.nn.utils.clip_grad_norm_(
|
||||
generator.parameters(), clip_grad_norm
|
||||
)
|
||||
grad_norm_g = torch.nn.utils.clip_grad_norm_(generator.parameters(), clip_grad_norm)
|
||||
optim_g.step()
|
||||
|
||||
if rank == 0:
|
||||
# STDOUT logging
|
||||
if steps % a.stdout_interval == 0:
|
||||
mel_error = (
|
||||
loss_mel.item() / lambda_melloss
|
||||
) # Log training mel regression loss to stdout
|
||||
mel_error = loss_mel.item() / lambda_melloss # Log training mel regression loss to stdout
|
||||
print(
|
||||
f"Steps: {steps:d}, "
|
||||
f"Gen Loss Total: {loss_gen_all:4.3f}, "
|
||||
@@ -577,11 +532,7 @@ def train(rank, a, h):
|
||||
checkpoint_path = f"{a.checkpoint_path}/g_{steps:08d}"
|
||||
save_checkpoint(
|
||||
checkpoint_path,
|
||||
{
|
||||
"generator": (
|
||||
generator.module if h.num_gpus > 1 else generator
|
||||
).state_dict()
|
||||
},
|
||||
{"generator": (generator.module if h.num_gpus > 1 else generator).state_dict()},
|
||||
)
|
||||
checkpoint_path = f"{a.checkpoint_path}/do_{steps:08d}"
|
||||
save_checkpoint(
|
||||
@@ -598,9 +549,7 @@ def train(rank, a, h):
|
||||
|
||||
# Tensorboard summary logging
|
||||
if steps % a.summary_interval == 0:
|
||||
mel_error = (
|
||||
loss_mel.item() / lambda_melloss
|
||||
) # Log training mel regression loss to tensorboard
|
||||
mel_error = loss_mel.item() / lambda_melloss # Log training mel regression loss to tensorboard
|
||||
sw.add_scalar("training/gen_loss_total", loss_gen_all.item(), steps)
|
||||
sw.add_scalar("training/mel_spec_error", mel_error, steps)
|
||||
sw.add_scalar("training/fm_loss_mpd", loss_fm_f.item(), steps)
|
||||
@@ -612,12 +561,8 @@ def train(rank, a, h):
|
||||
sw.add_scalar("training/disc_loss_mrd", loss_disc_s.item(), steps)
|
||||
sw.add_scalar("training/grad_norm_mrd", grad_norm_mrd, steps)
|
||||
sw.add_scalar("training/grad_norm_g", grad_norm_g, steps)
|
||||
sw.add_scalar(
|
||||
"training/learning_rate_d", scheduler_d.get_last_lr()[0], steps
|
||||
)
|
||||
sw.add_scalar(
|
||||
"training/learning_rate_g", scheduler_g.get_last_lr()[0], steps
|
||||
)
|
||||
sw.add_scalar("training/learning_rate_d", scheduler_d.get_last_lr()[0], steps)
|
||||
sw.add_scalar("training/learning_rate_g", scheduler_g.get_last_lr()[0], steps)
|
||||
sw.add_scalar("training/epoch", epoch + 1, steps)
|
||||
|
||||
# Validation
|
||||
@@ -660,9 +605,7 @@ def train(rank, a, h):
|
||||
scheduler_d.step()
|
||||
|
||||
if rank == 0:
|
||||
print(
|
||||
f"Time taken for epoch {epoch + 1} is {int(time.time() - start)} sec\n"
|
||||
)
|
||||
print(f"Time taken for epoch {epoch + 1} is {int(time.time() - start)} sec\n")
|
||||
|
||||
|
||||
def main():
|
||||
@@ -674,12 +617,8 @@ def main():
|
||||
|
||||
parser.add_argument("--input_wavs_dir", default="LibriTTS")
|
||||
parser.add_argument("--input_mels_dir", default="ft_dataset")
|
||||
parser.add_argument(
|
||||
"--input_training_file", default="tests/LibriTTS/train-full.txt"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--input_validation_file", default="tests/LibriTTS/val-full.txt"
|
||||
)
|
||||
parser.add_argument("--input_training_file", default="tests/LibriTTS/train-full.txt")
|
||||
parser.add_argument("--input_validation_file", default="tests/LibriTTS/val-full.txt")
|
||||
|
||||
parser.add_argument(
|
||||
"--list_input_unseen_wavs_dir",
|
||||
|
||||
Reference in New Issue
Block a user