Refactor: Format Code with Ruff and Update Deprecated G2PW Link (#2255)

* ruff check --fix

* ruff format --line-length 120 --target-version py39

* Change the link for G2PW Model

* Update PyTorch version and Colab
Authored by XXXXRT666 on 2025-04-07 09:42:47 +01:00 · committed by GitHub
parent 9da7e17efe · commit 53cac93589
132 changed files with 8185 additions and 6648 deletions
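
The two ruff invocations listed in the commit message can also be pinned in the project configuration so contributors and CI apply the same rules. Below is a minimal pyproject.toml sketch matching those flags; it is illustrative only and assumes the repository does not already ship an equivalent [tool.ruff] section:

[tool.ruff]
# Mirror the flags used for this commit: ruff format --line-length 120 --target-version py39
line-length = 120
target-version = "py39"

[tool.ruff.lint]
# `ruff check --fix` was run with the default rule set; a `select` list could be added here
# if stricter linting is wanted, e.g. select = ["E", "F", "I"] (hypothetical, not part of this commit)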


@@ -23,9 +23,7 @@ class Snake(nn.Module):
>>> x = a1(x)
"""
def __init__(
self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False
):
def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
"""
Initialization.
INPUT:
@@ -80,9 +78,7 @@ class SnakeBeta(nn.Module):
>>> x = a1(x)
"""
def __init__(
self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False
):
def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
"""
Initialization.
INPUT:


@@ -20,9 +20,7 @@ class FusedAntiAliasActivation(torch.autograd.Function):
@staticmethod
def forward(ctx, inputs, up_ftr, down_ftr, alpha, beta):
activation_results = anti_alias_activation_cuda.forward(
inputs, up_ftr, down_ftr, alpha, beta
)
activation_results = anti_alias_activation_cuda.forward(inputs, up_ftr, down_ftr, alpha, beta)
return activation_results
@@ -61,17 +59,11 @@ class Activation1d(nn.Module):
if self.act.__class__.__name__ == "Snake":
beta = self.act.alpha.data # Snake uses same params for alpha and beta
else:
beta = (
self.act.beta.data
) # Snakebeta uses different params for alpha and beta
beta = self.act.beta.data # Snakebeta uses different params for alpha and beta
alpha = self.act.alpha.data
if (
not self.act.alpha_logscale
): # Exp baked into cuda kernel, cancel it out with a log
if not self.act.alpha_logscale: # Exp baked into cuda kernel, cancel it out with a log
alpha = torch.log(alpha)
beta = torch.log(beta)
x = FusedAntiAliasActivation.apply(
x, self.upsample.filter, self.downsample.lowpass.filter, alpha, beta
)
x = FusedAntiAliasActivation.apply(x, self.upsample.filter, self.downsample.lowpass.filter, alpha, beta)
return x


@@ -58,17 +58,13 @@ def load():
srcpath / "anti_alias_activation.cpp",
srcpath / "anti_alias_activation_cuda.cu",
]
anti_alias_activation_cuda = _cpp_extention_load_helper(
"anti_alias_activation_cuda", sources, extra_cuda_flags
)
anti_alias_activation_cuda = _cpp_extention_load_helper("anti_alias_activation_cuda", sources, extra_cuda_flags)
return anti_alias_activation_cuda
def _get_cuda_bare_metal_version(cuda_dir):
raw_output = subprocess.check_output(
[cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True
)
raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
output = raw_output.split()
release_idx = output.index("release") + 1
release = output[release_idx].split(".")


@@ -27,9 +27,7 @@ else:
# This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License
# https://adefossez.github.io/julius/julius/lowpass.html
# LICENSE is in incl_licenses directory.
def kaiser_sinc_filter1d(
cutoff, half_width, kernel_size
): # return filter [1,1,kernel_size]
def kaiser_sinc_filter1d(cutoff, half_width, kernel_size): # return filter [1,1,kernel_size]
even = kernel_size % 2 == 0
half_size = kernel_size // 2


@@ -11,18 +11,12 @@ class UpSample1d(nn.Module):
def __init__(self, ratio=2, kernel_size=None):
super().__init__()
self.ratio = ratio
self.kernel_size = (
int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
)
self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
self.stride = ratio
self.pad = self.kernel_size // ratio - 1
self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
self.pad_right = (
self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
)
filter = kaiser_sinc_filter1d(
cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size
)
self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
filter = kaiser_sinc_filter1d(cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size)
self.register_buffer("filter", filter)
# x: [B, C, T]
@@ -30,9 +24,7 @@ class UpSample1d(nn.Module):
_, C, _ = x.shape
x = F.pad(x, (self.pad, self.pad), mode="replicate")
x = self.ratio * F.conv_transpose1d(
x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C
)
x = self.ratio * F.conv_transpose1d(x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)
x = x[..., self.pad_left : -self.pad_right]
return x
@@ -42,9 +34,7 @@ class DownSample1d(nn.Module):
def __init__(self, ratio=2, kernel_size=None):
super().__init__()
self.ratio = ratio
self.kernel_size = (
int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
)
self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
self.lowpass = LowPassFilter1d(
cutoff=0.5 / ratio,
half_width=0.6 / ratio,


@@ -50,7 +50,7 @@ class AMPBlock1(torch.nn.Module):
activation: str = None,
):
super().__init__()
self.h = h
self.convs1 = nn.ModuleList(
@@ -87,9 +87,7 @@ class AMPBlock1(torch.nn.Module):
)
self.convs2.apply(init_weights)
self.num_layers = len(self.convs1) + len(
self.convs2
) # Total number of conv layers
self.num_layers = len(self.convs1) + len(self.convs2) # Total number of conv layers
# Select which Activation1d, lazy-load cuda version to ensure backward compatibility
if self.h.get("use_cuda_kernel", False):
@@ -105,22 +103,14 @@ class AMPBlock1(torch.nn.Module):
if activation == "snake":
self.activations = nn.ModuleList(
[
Activation1d(
activation=activations.Snake(
channels, alpha_logscale=h.snake_logscale
)
)
Activation1d(activation=activations.Snake(channels, alpha_logscale=h.snake_logscale))
for _ in range(self.num_layers)
]
)
elif activation == "snakebeta":
self.activations = nn.ModuleList(
[
Activation1d(
activation=activations.SnakeBeta(
channels, alpha_logscale=h.snake_logscale
)
)
Activation1d(activation=activations.SnakeBeta(channels, alpha_logscale=h.snake_logscale))
for _ in range(self.num_layers)
]
)
@@ -169,7 +159,7 @@ class AMPBlock2(torch.nn.Module):
activation: str = None,
):
super().__init__()
self.h = h
self.convs = nn.ModuleList(
@@ -205,22 +195,14 @@ class AMPBlock2(torch.nn.Module):
if activation == "snake":
self.activations = nn.ModuleList(
[
Activation1d(
activation=activations.Snake(
channels, alpha_logscale=h.snake_logscale
)
)
Activation1d(activation=activations.Snake(channels, alpha_logscale=h.snake_logscale))
for _ in range(self.num_layers)
]
)
elif activation == "snakebeta":
self.activations = nn.ModuleList(
[
Activation1d(
activation=activations.SnakeBeta(
channels, alpha_logscale=h.snake_logscale
)
)
Activation1d(activation=activations.SnakeBeta(channels, alpha_logscale=h.snake_logscale))
for _ in range(self.num_layers)
]
)
@@ -283,9 +265,7 @@ class BigVGAN(
self.num_upsamples = len(h.upsample_rates)
# Pre-conv
self.conv_pre = weight_norm(
Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3)
)
self.conv_pre = weight_norm(Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3))
# Define which AMPBlock to use. BigVGAN uses AMPBlock1 as default
if h.resblock == "1":
@@ -293,9 +273,7 @@ class BigVGAN(
elif h.resblock == "2":
resblock_class = AMPBlock2
else:
raise ValueError(
f"Incorrect resblock class specified in hyperparameters. Got {h.resblock}"
)
raise ValueError(f"Incorrect resblock class specified in hyperparameters. Got {h.resblock}")
# Transposed conv-based upsamplers. does not apply anti-aliasing
self.ups = nn.ModuleList()
@@ -320,22 +298,14 @@ class BigVGAN(
self.resblocks = nn.ModuleList()
for i in range(len(self.ups)):
ch = h.upsample_initial_channel // (2 ** (i + 1))
for j, (k, d) in enumerate(
zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)
):
self.resblocks.append(
resblock_class(h, ch, k, d, activation=h.activation)
)
for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
self.resblocks.append(resblock_class(h, ch, k, d, activation=h.activation))
# Post-conv
activation_post = (
activations.Snake(ch, alpha_logscale=h.snake_logscale)
if h.activation == "snake"
else (
activations.SnakeBeta(ch, alpha_logscale=h.snake_logscale)
if h.activation == "snakebeta"
else None
)
else (activations.SnakeBeta(ch, alpha_logscale=h.snake_logscale) if h.activation == "snakebeta" else None)
)
if activation_post is None:
raise NotImplementedError(
@@ -346,9 +316,7 @@ class BigVGAN(
# Whether to use bias for the final conv_post. Default to True for backward compatibility
self.use_bias_at_final = h.get("use_bias_at_final", True)
self.conv_post = weight_norm(
Conv1d(ch, 1, 7, 1, padding=3, bias=self.use_bias_at_final)
)
self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3, bias=self.use_bias_at_final))
# Weight initialization
for i in range(len(self.ups)):
@@ -451,13 +419,13 @@ class BigVGAN(
# instantiate BigVGAN using h
if use_cuda_kernel:
print(
f"[WARNING] You have specified use_cuda_kernel=True during BigVGAN.from_pretrained(). Only inference is supported (training is not implemented)!"
"[WARNING] You have specified use_cuda_kernel=True during BigVGAN.from_pretrained(). Only inference is supported (training is not implemented)!"
)
print(
f"[WARNING] You need nvcc and ninja installed in your system that matches your PyTorch build is using to build the kernel. If not, the model will fail to initialize or generate incorrect waveform!"
"[WARNING] You need nvcc and ninja installed in your system that matches your PyTorch build is using to build the kernel. If not, the model will fail to initialize or generate incorrect waveform!"
)
print(
f"[WARNING] For detail, see the official GitHub repository: https://github.com/NVIDIA/BigVGAN?tab=readme-ov-file#using-custom-cuda-kernel-for-synthesis"
"[WARNING] For detail, see the official GitHub repository: https://github.com/NVIDIA/BigVGAN?tab=readme-ov-file#using-custom-cuda-kernel-for-synthesis"
)
model = cls(h, use_cuda_kernel=use_cuda_kernel)
@@ -485,7 +453,7 @@ class BigVGAN(
model.load_state_dict(checkpoint_dict["generator"])
except RuntimeError:
print(
f"[INFO] the pretrained checkpoint does not contain weight norm. Loading the checkpoint after removing weight norm!"
"[INFO] the pretrained checkpoint does not contain weight norm. Loading the checkpoint after removing weight norm!"
)
model.remove_weight_norm()
model.load_state_dict(checkpoint_dict["generator"])


@@ -15,7 +15,7 @@ from torchaudio.transforms import Spectrogram, Resample
from env import AttrDict
from utils import get_padding
import typing
from typing import Optional, List, Union, Dict, Tuple
from typing import List, Tuple
class DiscriminatorP(torch.nn.Module):
@@ -81,9 +81,7 @@ class DiscriminatorP(torch.nn.Module):
),
]
)
self.conv_post = norm_f(
Conv2d(int(1024 * self.d_mult), 1, (3, 1), 1, padding=(1, 0))
)
self.conv_post = norm_f(Conv2d(int(1024 * self.d_mult), 1, (3, 1), 1, padding=(1, 0)))
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
fmap = []
@@ -113,13 +111,12 @@ class MultiPeriodDiscriminator(torch.nn.Module):
self.mpd_reshapes = h.mpd_reshapes
print(f"mpd_reshapes: {self.mpd_reshapes}")
self.discriminators = nn.ModuleList(
[
DiscriminatorP(h, rs, use_spectral_norm=h.use_spectral_norm)
for rs in self.mpd_reshapes
]
[DiscriminatorP(h, rs, use_spectral_norm=h.use_spectral_norm) for rs in self.mpd_reshapes]
)
def forward(self, y: torch.Tensor, y_hat: torch.Tensor) -> Tuple[
def forward(
self, y: torch.Tensor, y_hat: torch.Tensor
) -> Tuple[
List[torch.Tensor],
List[torch.Tensor],
List[List[torch.Tensor]],
@@ -145,19 +142,13 @@ class DiscriminatorR(nn.Module):
super().__init__()
self.resolution = resolution
assert (
len(self.resolution) == 3
), f"MRD layer requires list with len=3, got {self.resolution}"
assert len(self.resolution) == 3, f"MRD layer requires list with len=3, got {self.resolution}"
self.lrelu_slope = 0.1
norm_f = weight_norm if cfg.use_spectral_norm == False else spectral_norm
if hasattr(cfg, "mrd_use_spectral_norm"):
print(
f"[INFO] overriding MRD use_spectral_norm as {cfg.mrd_use_spectral_norm}"
)
norm_f = (
weight_norm if cfg.mrd_use_spectral_norm == False else spectral_norm
)
print(f"[INFO] overriding MRD use_spectral_norm as {cfg.mrd_use_spectral_norm}")
norm_f = weight_norm if cfg.mrd_use_spectral_norm == False else spectral_norm
self.d_mult = cfg.discriminator_channel_mult
if hasattr(cfg, "mrd_channel_mult"):
print(f"[INFO] overriding mrd channel multiplier as {cfg.mrd_channel_mult}")
@@ -203,9 +194,7 @@ class DiscriminatorR(nn.Module):
),
]
)
self.conv_post = norm_f(
nn.Conv2d(int(32 * self.d_mult), 1, (3, 3), padding=(1, 1))
)
self.conv_post = norm_f(nn.Conv2d(int(32 * self.d_mult), 1, (3, 3), padding=(1, 1)))
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
fmap = []
@@ -248,14 +237,14 @@ class MultiResolutionDiscriminator(nn.Module):
def __init__(self, cfg, debug=False):
super().__init__()
self.resolutions = cfg.resolutions
assert (
len(self.resolutions) == 3
), f"MRD requires list of list with len=3, each element having a list with len=3. Got {self.resolutions}"
self.discriminators = nn.ModuleList(
[DiscriminatorR(cfg, resolution) for resolution in self.resolutions]
assert len(self.resolutions) == 3, (
f"MRD requires list of list with len=3, each element having a list with len=3. Got {self.resolutions}"
)
self.discriminators = nn.ModuleList([DiscriminatorR(cfg, resolution) for resolution in self.resolutions])
def forward(self, y: torch.Tensor, y_hat: torch.Tensor) -> Tuple[
def forward(
self, y: torch.Tensor, y_hat: torch.Tensor
) -> Tuple[
List[torch.Tensor],
List[torch.Tensor],
List[List[torch.Tensor]],
@@ -309,25 +298,15 @@ class DiscriminatorB(nn.Module):
convs = lambda: nn.ModuleList(
[
weight_norm(nn.Conv2d(2, channels, (3, 9), (1, 1), padding=(1, 4))),
weight_norm(
nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))
),
weight_norm(
nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))
),
weight_norm(
nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))
),
weight_norm(
nn.Conv2d(channels, channels, (3, 3), (1, 1), padding=(1, 1))
),
weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
weight_norm(nn.Conv2d(channels, channels, (3, 3), (1, 1), padding=(1, 1))),
]
)
self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))])
self.conv_post = weight_norm(
nn.Conv2d(channels, 1, (3, 3), (1, 1), padding=(1, 1))
)
self.conv_post = weight_norm(nn.Conv2d(channels, 1, (3, 3), (1, 1), padding=(1, 1)))
def spectrogram(self, x: torch.Tensor) -> List[torch.Tensor]:
# Remove DC offset
@@ -376,17 +355,16 @@ class MultiBandDiscriminator(nn.Module):
super().__init__()
# fft_sizes (list[int]): Tuple of window lengths for FFT. Defaults to [2048, 1024, 512] if not set in h.
self.fft_sizes = h.get("mbd_fft_sizes", [2048, 1024, 512])
self.discriminators = nn.ModuleList(
[DiscriminatorB(window_length=w) for w in self.fft_sizes]
)
self.discriminators = nn.ModuleList([DiscriminatorB(window_length=w) for w in self.fft_sizes])
def forward(self, y: torch.Tensor, y_hat: torch.Tensor) -> Tuple[
def forward(
self, y: torch.Tensor, y_hat: torch.Tensor
) -> Tuple[
List[torch.Tensor],
List[torch.Tensor],
List[List[torch.Tensor]],
List[List[torch.Tensor]],
]:
y_d_rs = []
y_d_gs = []
fmap_rs = []
@@ -406,7 +384,7 @@ class MultiBandDiscriminator(nn.Module):
# Adapted from https://github.com/open-mmlab/Amphion/blob/main/models/vocoders/gan/discriminator/mssbcqtd.py under the MIT license.
# LICENSE is in incl_licenses directory.
class DiscriminatorCQT(nn.Module):
def __init__(self, cfg: AttrDict, hop_length: int, n_octaves:int, bins_per_octave: int):
def __init__(self, cfg: AttrDict, hop_length: int, n_octaves: int, bins_per_octave: int):
super().__init__()
self.cfg = cfg
@@ -460,9 +438,7 @@ class DiscriminatorCQT(nn.Module):
in_chs = min(self.filters_scale * self.filters, self.max_filters)
for i, dilation in enumerate(self.dilations):
out_chs = min(
(self.filters_scale ** (i + 1)) * self.filters, self.max_filters
)
out_chs = min((self.filters_scale ** (i + 1)) * self.filters, self.max_filters)
self.convs.append(
weight_norm(
nn.Conv2d(
@@ -486,9 +462,7 @@ class DiscriminatorCQT(nn.Module):
in_chs,
out_chs,
kernel_size=(self.kernel_size[0], self.kernel_size[0]),
padding=self.get_2d_padding(
(self.kernel_size[0], self.kernel_size[0])
),
padding=self.get_2d_padding((self.kernel_size[0], self.kernel_size[0])),
)
)
)
@@ -508,7 +482,7 @@ class DiscriminatorCQT(nn.Module):
self.cqtd_normalize_volume = self.cfg.get("cqtd_normalize_volume", False)
if self.cqtd_normalize_volume:
print(
f"[INFO] cqtd_normalize_volume set to True. Will apply DC offset removal & peak volume normalization in CQTD!"
"[INFO] cqtd_normalize_volume set to True. Will apply DC offset removal & peak volume normalization in CQTD!"
)
def get_2d_padding(
@@ -580,9 +554,7 @@ class MultiScaleSubbandCQTDiscriminator(nn.Module):
# Multi-scale params to loop over
self.cfg["cqtd_hop_lengths"] = self.cfg.get("cqtd_hop_lengths", [512, 256, 256])
self.cfg["cqtd_n_octaves"] = self.cfg.get("cqtd_n_octaves", [9, 9, 9])
self.cfg["cqtd_bins_per_octaves"] = self.cfg.get(
"cqtd_bins_per_octaves", [24, 36, 48]
)
self.cfg["cqtd_bins_per_octaves"] = self.cfg.get("cqtd_bins_per_octaves", [24, 36, 48])
self.discriminators = nn.ModuleList(
[
@@ -596,13 +568,14 @@ class MultiScaleSubbandCQTDiscriminator(nn.Module):
]
)
def forward(self, y: torch.Tensor, y_hat: torch.Tensor) -> Tuple[
def forward(
self, y: torch.Tensor, y_hat: torch.Tensor
) -> Tuple[
List[torch.Tensor],
List[torch.Tensor],
List[List[torch.Tensor]],
List[List[torch.Tensor]],
]:
y_d_rs = []
y_d_gs = []
fmap_rs = []
@@ -629,13 +602,14 @@ class CombinedDiscriminator(nn.Module):
super().__init__()
self.discrimiantor = nn.ModuleList(list_discriminator)
def forward(self, y: torch.Tensor, y_hat: torch.Tensor) -> Tuple[
def forward(
self, y: torch.Tensor, y_hat: torch.Tensor
) -> Tuple[
List[torch.Tensor],
List[torch.Tensor],
List[List[torch.Tensor]],
List[List[torch.Tensor]],
]:
y_d_rs = []
y_d_gs = []
fmap_rs = []


@@ -35,9 +35,7 @@ def inference(a, h):
with torch.no_grad():
for i, filname in enumerate(filelist):
# Load the ground truth audio and resample if necessary
wav, sr = librosa.load(
os.path.join(a.input_wavs_dir, filname), sr=h.sampling_rate, mono=True
)
wav, sr = librosa.load(os.path.join(a.input_wavs_dir, filname), sr=h.sampling_rate, mono=True)
wav = torch.FloatTensor(wav).to(device)
# Compute mel spectrogram from the ground truth audio
x = get_mel_spectrogram(wav.unsqueeze(0), generator.h)
@@ -48,9 +46,7 @@ def inference(a, h):
audio = audio * MAX_WAV_VALUE
audio = audio.cpu().numpy().astype("int16")
output_file = os.path.join(
a.output_dir, os.path.splitext(filname)[0] + "_generated.wav"
)
output_file = os.path.join(a.output_dir, os.path.splitext(filname)[0] + "_generated.wav")
write(output_file, h.sampling_rate, audio)
print(output_file)


@@ -61,9 +61,7 @@ def inference(a, h):
audio = audio * MAX_WAV_VALUE
audio = audio.cpu().numpy().astype("int16")
output_file = os.path.join(
a.output_dir, os.path.splitext(filname)[0] + "_generated_e2e.wav"
)
output_file = os.path.join(a.output_dir, os.path.splitext(filname)[0] + "_generated_e2e.wav")
write(output_file, h.sampling_rate, audio)
print(output_file)


@@ -6,13 +6,12 @@
import torch
import torch.nn.functional as F
import torch.nn as nn
from librosa.filters import mel as librosa_mel_fn
from scipy import signal
import typing
from typing import Optional, List, Union, Dict, Tuple
from typing import List, Tuple
from collections import namedtuple
import math
import functools
@@ -117,15 +116,13 @@ class MultiScaleMelSpectrogramLoss(nn.Module):
window_type,
):
"""
Mirrors AudioSignal.mel_spectrogram used by BigVGAN-v2 training from:
Mirrors AudioSignal.mel_spectrogram used by BigVGAN-v2 training from:
https://github.com/descriptinc/audiotools/blob/master/audiotools/core/audio_signal.py
"""
B, C, T = wav.shape
if match_stride:
assert (
hop_length == window_length // 4
), "For match_stride, hop must equal n_fft // 4"
assert hop_length == window_length // 4, "For match_stride, hop must equal n_fft // 4"
right_pad = math.ceil(T / hop_length) * hop_length - T
pad = (window_length - hop_length) // 2
else:
@@ -155,9 +152,7 @@ class MultiScaleMelSpectrogramLoss(nn.Module):
magnitude = torch.abs(stft)
nf = magnitude.shape[2]
mel_basis = self.get_mel_filters(
self.sampling_rate, 2 * (nf - 1), n_mels, fmin, fmax
)
mel_basis = self.get_mel_filters(self.sampling_rate, 2 * (nf - 1), n_mels, fmin, fmax)
mel_basis = torch.from_numpy(mel_basis).to(wav.device)
mel_spectrogram = magnitude.transpose(2, -1) @ mel_basis.T
mel_spectrogram = mel_spectrogram.transpose(-1, 2)
@@ -182,9 +177,7 @@ class MultiScaleMelSpectrogramLoss(nn.Module):
"""
loss = 0.0
for n_mels, fmin, fmax, s in zip(
self.n_mels, self.mel_fmin, self.mel_fmax, self.stft_params
):
for n_mels, fmin, fmax, s in zip(self.n_mels, self.mel_fmin, self.mel_fmax, self.stft_params):
kwargs = {
"n_mels": n_mels,
"fmin": fmin,
@@ -197,12 +190,8 @@ class MultiScaleMelSpectrogramLoss(nn.Module):
x_mels = self.mel_spectrogram(x, **kwargs)
y_mels = self.mel_spectrogram(y, **kwargs)
x_logmels = torch.log(
x_mels.clamp(min=self.clamp_eps).pow(self.pow)
) / torch.log(torch.tensor(10.0))
y_logmels = torch.log(
y_mels.clamp(min=self.clamp_eps).pow(self.pow)
) / torch.log(torch.tensor(10.0))
x_logmels = torch.log(x_mels.clamp(min=self.clamp_eps).pow(self.pow)) / torch.log(torch.tensor(10.0))
y_logmels = torch.log(y_mels.clamp(min=self.clamp_eps).pow(self.pow)) / torch.log(torch.tensor(10.0))
loss += self.log_weight * self.loss_fn(x_logmels, y_logmels)
loss += self.mag_weight * self.loss_fn(x_logmels, y_logmels)
@@ -211,10 +200,7 @@ class MultiScaleMelSpectrogramLoss(nn.Module):
# Loss functions
def feature_loss(
fmap_r: List[List[torch.Tensor]], fmap_g: List[List[torch.Tensor]]
) -> torch.Tensor:
def feature_loss(fmap_r: List[List[torch.Tensor]], fmap_g: List[List[torch.Tensor]]) -> torch.Tensor:
loss = 0
for dr, dg in zip(fmap_r, fmap_g):
for rl, gl in zip(dr, dg):
@@ -226,7 +212,6 @@ def feature_loss(
def discriminator_loss(
disc_real_outputs: List[torch.Tensor], disc_generated_outputs: List[torch.Tensor]
) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]:
loss = 0
r_losses = []
g_losses = []
@@ -243,7 +228,6 @@ def discriminator_loss(
def generator_loss(
disc_outputs: List[torch.Tensor],
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
loss = 0
gen_losses = []
for dg in disc_outputs:


@@ -86,9 +86,7 @@ def mel_spectrogram(
key = f"{n_fft}_{num_mels}_{sampling_rate}_{hop_size}_{win_size}_{fmin}_{fmax}_{device}"
if key not in mel_basis_cache:
mel = librosa_mel_fn(
sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
)
mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
mel_basis_cache[key] = torch.from_numpy(mel).float().to(device)
hann_window_cache[key] = torch.hann_window(win_size).to(device)
@@ -96,9 +94,7 @@ def mel_spectrogram(
hann_window = hann_window_cache[key]
padding = (n_fft - hop_size) // 2
y = torch.nn.functional.pad(
y.unsqueeze(1), (padding, padding), mode="reflect"
).squeeze(1)
y = torch.nn.functional.pad(y.unsqueeze(1), (padding, padding), mode="reflect").squeeze(1)
spec = torch.stft(
y,
@@ -150,17 +146,13 @@ def get_dataset_filelist(a):
with open(a.input_training_file, "r", encoding="utf-8") as fi:
training_files = [
os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav")
for x in fi.read().split("\n")
if len(x) > 0
os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav") for x in fi.read().split("\n") if len(x) > 0
]
print(f"first training file: {training_files[0]}")
with open(a.input_validation_file, "r", encoding="utf-8") as fi:
validation_files = [
os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav")
for x in fi.read().split("\n")
if len(x) > 0
os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav") for x in fi.read().split("\n") if len(x) > 0
]
print(f"first validation file: {validation_files[0]}")
@@ -171,9 +163,7 @@ def get_dataset_filelist(a):
for x in fi.read().split("\n")
if len(x) > 0
]
print(
f"first unseen {i}th validation fileset: {unseen_validation_files[0]}"
)
print(f"first unseen {i}th validation fileset: {unseen_validation_files[0]}")
list_unseen_validation_files.append(unseen_validation_files)
return training_files, validation_files, list_unseen_validation_files
@@ -227,13 +217,9 @@ class MelDataset(torch.utils.data.Dataset):
print("[INFO] checking dataset integrity...")
for i in tqdm(range(len(self.audio_files))):
assert os.path.exists(
self.audio_files[i]
), f"{self.audio_files[i]} not found"
assert os.path.exists(self.audio_files[i]), f"{self.audio_files[i]} not found"
def __getitem__(
self, index: int
) -> Tuple[torch.Tensor, torch.Tensor, str, torch.Tensor]:
def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor, str, torch.Tensor]:
try:
filename = self.audio_files[index]
@@ -248,17 +234,12 @@ class MelDataset(torch.utils.data.Dataset):
# Obtain randomized audio chunk
if source_sampling_rate != self.sampling_rate:
# Adjust segment size to crop if the source sr is different
target_segment_size = math.ceil(
self.segment_size
* (source_sampling_rate / self.sampling_rate)
)
target_segment_size = math.ceil(self.segment_size * (source_sampling_rate / self.sampling_rate))
else:
target_segment_size = self.segment_size
# Compute upper bound index for the random chunk
random_chunk_upper_bound = max(
0, audio.shape[0] - target_segment_size
)
random_chunk_upper_bound = max(0, audio.shape[0] - target_segment_size)
# Crop or pad audio to obtain random chunk with target_segment_size
if audio.shape[0] >= target_segment_size:
@@ -318,9 +299,9 @@ class MelDataset(torch.utils.data.Dataset):
else:
# For fine-tuning, assert that the waveform is in the defined sampling_rate
# Fine-tuning won't support on-the-fly resampling to be fool-proof (the dataset should have been prepared properly)
assert (
source_sampling_rate == self.sampling_rate
), f"For fine_tuning, waveform must be in the spcified sampling rate {self.sampling_rate}, got {source_sampling_rate}"
assert source_sampling_rate == self.sampling_rate, (
f"For fine_tuning, waveform must be in the spcified sampling rate {self.sampling_rate}, got {source_sampling_rate}"
)
# Cast ndarray to torch tensor
audio = torch.FloatTensor(audio)
@@ -346,20 +327,14 @@ class MelDataset(torch.utils.data.Dataset):
mel = mel[:, :, mel_start : mel_start + frames_per_seg]
audio = audio[
:,
mel_start
* self.hop_size : (mel_start + frames_per_seg)
* self.hop_size,
mel_start * self.hop_size : (mel_start + frames_per_seg) * self.hop_size,
]
# Pad pre-computed mel and audio to match length to ensuring fine-tuning without error.
# NOTE: this may introduce a single-frame misalignment of the <pre-computed mel, audio>
# To remove possible misalignment, it is recommended to prepare the <pre-computed mel, audio> pair where the audio length is the integer multiple of self.hop_size
mel = torch.nn.functional.pad(
mel, (0, frames_per_seg - mel.size(2)), "constant"
)
audio = torch.nn.functional.pad(
audio, (0, self.segment_size - audio.size(1)), "constant"
)
mel = torch.nn.functional.pad(mel, (0, frames_per_seg - mel.size(2)), "constant")
audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), "constant")
# Compute mel_loss used by spectral regression objective. Uses self.fmax_loss instead (usually None)
mel_loss = mel_spectrogram(
@@ -376,9 +351,10 @@ class MelDataset(torch.utils.data.Dataset):
# Shape sanity checks
assert (
audio.shape[1] == mel.shape[2] * self.hop_size
and audio.shape[1] == mel_loss.shape[2] * self.hop_size
), f"Audio length must be mel frame length * hop_size. Got audio shape {audio.shape} mel shape {mel.shape} mel_loss shape {mel_loss.shape}"
audio.shape[1] == mel.shape[2] * self.hop_size and audio.shape[1] == mel_loss.shape[2] * self.hop_size
), (
f"Audio length must be mel frame length * hop_size. Got audio shape {audio.shape} mel shape {mel.shape} mel_loss shape {mel_loss.shape}"
)
return (mel.squeeze(), audio.squeeze(0), filename, mel_loss.squeeze())
@@ -387,9 +363,7 @@ class MelDataset(torch.utils.data.Dataset):
if self.fine_tuning:
raise e # Terminate training if it is fine-tuning. The dataset should have been prepared properly.
else:
print(
f"[WARNING] Failed to load waveform, skipping! filename: {filename} Error: {e}"
)
print(f"[WARNING] Failed to load waveform, skipping! filename: {filename} Error: {e}")
return self[random.randrange(len(self))]
def __len__(self):


@@ -3,6 +3,7 @@
import os
import sys
# to import modules from parent_dir
parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.append(parent_dir)
@@ -24,14 +25,10 @@ def test_anti_alias_activation():
data = torch.rand((10, 10, 200), device="cuda")
# Check activations.Snake cuda vs. torch
fused_anti_alias_activation = activation1d.Activation1d(
activation=Snake(10), fused=True
).cuda()
fused_anti_alias_activation = activation1d.Activation1d(activation=Snake(10), fused=True).cuda()
fused_activation_output = fused_anti_alias_activation(data)
torch_anti_alias_activation = activation1d.Activation1d(
activation=Snake(10), fused=False
).cuda()
torch_anti_alias_activation = activation1d.Activation1d(activation=Snake(10), fused=False).cuda()
torch_activation_output = torch_anti_alias_activation(data)
test_result = (fused_activation_output - torch_activation_output).abs()


@@ -3,6 +3,7 @@
import os
import sys
# to import modules from parent_dir
parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.append(parent_dir)
@@ -24,14 +25,10 @@ def test_anti_alias_activation():
data = torch.rand((10, 10, 200), device="cuda")
# Check activations, Snake CUDA vs. Torch
fused_anti_alias_activation = activation1d.Activation1d(
activation=SnakeBeta(10), fused=True
).cuda()
fused_anti_alias_activation = activation1d.Activation1d(activation=SnakeBeta(10), fused=True).cuda()
fused_activation_output = fused_anti_alias_activation(data)
torch_anti_alias_activation = activation1d.Activation1d(
activation=SnakeBeta(10), fused=False
).cuda()
torch_anti_alias_activation = activation1d.Activation1d(activation=SnakeBeta(10), fused=False).cuda()
torch_activation_output = torch_anti_alias_activation(data)
test_result = (fused_activation_output - torch_activation_output).abs()
@@ -57,7 +54,6 @@ def test_anti_alias_activation():
)
if __name__ == "__main__":
from alias_free_activation.cuda import load


@@ -42,9 +42,7 @@ def generate_soundwave(duration=5.0, sr=24000):
def get_mel(x, h):
return mel_spectrogram(
x, h.n_fft, h.num_mels, h.sampling_rate, h.hop_size, h.win_size, h.fmin, h.fmax
)
return mel_spectrogram(x, h.n_fft, h.num_mels, h.sampling_rate, h.hop_size, h.win_size, h.fmin, h.fmax)
def load_checkpoint(filepath, device):
@@ -56,9 +54,7 @@ def load_checkpoint(filepath, device):
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Test script to check CUDA kernel correctness."
)
parser = argparse.ArgumentParser(description="Test script to check CUDA kernel correctness.")
parser.add_argument(
"--checkpoint_file",
type=str,
@@ -91,27 +87,25 @@ if __name__ == "__main__":
# define number of samples and length of mel frame to benchmark
num_sample = 10
num_mel_frame = 16384
# CUDA kernel correctness check
diff = 0.0
for i in tqdm(range(num_sample)):
# Random mel
data = torch.rand((1, h.num_mels, num_mel_frame), device="cuda")
with torch.inference_mode():
audio_original = generator_original(data)
with torch.inference_mode():
audio_cuda_kernel = generator_cuda_kernel(data)
# Both outputs should be (almost) the same
test_result = (audio_original - audio_cuda_kernel).abs()
diff += test_result.mean(dim=-1).item()
diff /= num_sample
if (
diff <= 2e-3
): # We can expect a small difference (~1e-3) which does not affect perceptual quality
if diff <= 2e-3: # We can expect a small difference (~1e-3) which does not affect perceptual quality
print(
f"\n[Success] test CUDA fused vs. plain torch BigVGAN inference"
f"\n > mean_difference={diff}"
@@ -125,9 +119,9 @@ if __name__ == "__main__":
f"\n > fused_values={audio_cuda_kernel[-1][-1][-30:].tolist()}, "
f"\n > torch_values={audio_original[-1][-1][-30:].tolist()}"
)
del data, audio_original, audio_cuda_kernel
# Variables for tracking total time and VRAM usage
toc_total_original = 0
toc_total_cuda_kernel = 0
@@ -145,10 +139,10 @@ if __name__ == "__main__":
audio_original = generator_original(data)
torch.cuda.synchronize()
toc = time() - tic
toc_total_original += toc
toc_total_original += toc
vram_used_original_total += torch.cuda.max_memory_allocated(device="cuda")
del data, audio_original
torch.cuda.empty_cache()
@@ -163,11 +157,11 @@ if __name__ == "__main__":
torch.cuda.synchronize()
toc = time() - tic
toc_total_cuda_kernel += toc
audio_length_total += audio_cuda_kernel.shape[-1]
vram_used_cuda_kernel_total += torch.cuda.max_memory_allocated(device="cuda")
del data, audio_cuda_kernel
torch.cuda.empty_cache()
@@ -175,8 +169,8 @@ if __name__ == "__main__":
audio_second = audio_length_total / h.sampling_rate
khz_original = audio_length_total / toc_total_original / 1000
khz_cuda_kernel = audio_length_total / toc_total_cuda_kernel / 1000
vram_used_original_gb = vram_used_original_total / num_sample / (1024 ** 3)
vram_used_cuda_kernel_gb = vram_used_cuda_kernel_total / num_sample / (1024 ** 3)
vram_used_original_gb = vram_used_original_total / num_sample / (1024**3)
vram_used_cuda_kernel_gb = vram_used_cuda_kernel_total / num_sample / (1024**3)
# Print results
print(


@@ -77,24 +77,18 @@ def train(rank, a, h):
# Define additional discriminators. BigVGAN-v1 uses UnivNet's MRD as default
# New in BigVGAN-v2: option to switch to new discriminators: MultiBandDiscriminator / MultiScaleSubbandCQTDiscriminator
if h.get("use_mbd_instead_of_mrd", False): # Switch to MBD
print(
"[INFO] using MultiBandDiscriminator of BigVGAN-v2 instead of MultiResolutionDiscriminator"
)
print("[INFO] using MultiBandDiscriminator of BigVGAN-v2 instead of MultiResolutionDiscriminator")
# Variable name is kept as "mrd" for backward compatibility & minimal code change
mrd = MultiBandDiscriminator(h).to(device)
elif h.get("use_cqtd_instead_of_mrd", False): # Switch to CQTD
print(
"[INFO] using MultiScaleSubbandCQTDiscriminator of BigVGAN-v2 instead of MultiResolutionDiscriminator"
)
print("[INFO] using MultiScaleSubbandCQTDiscriminator of BigVGAN-v2 instead of MultiResolutionDiscriminator")
mrd = MultiScaleSubbandCQTDiscriminator(h).to(device)
else: # Fallback to original MRD in BigVGAN-v1
mrd = MultiResolutionDiscriminator(h).to(device)
# New in BigVGAN-v2: option to switch to multi-scale L1 mel loss
if h.get("use_multiscale_melloss", False):
print(
"[INFO] using multi-scale Mel l1 loss of BigVGAN-v2 instead of the original single-scale loss"
)
print("[INFO] using multi-scale Mel l1 loss of BigVGAN-v2 instead of the original single-scale loss")
fn_mel_loss_multiscale = MultiScaleMelSpectrogramLoss(
sampling_rate=h.sampling_rate
) # NOTE: accepts waveform as input
@@ -114,9 +108,7 @@ def train(rank, a, h):
if os.path.isdir(a.checkpoint_path):
# New in v2.1: If the step prefix pattern-based checkpoints are not found, also check for renamed files in Hugging Face Hub to resume training
cp_g = scan_checkpoint(
a.checkpoint_path, prefix="g_", renamed_file="bigvgan_generator.pt"
)
cp_g = scan_checkpoint(a.checkpoint_path, prefix="g_", renamed_file="bigvgan_generator.pt")
cp_do = scan_checkpoint(
a.checkpoint_path,
prefix="do_",
@@ -143,9 +135,7 @@ def train(rank, a, h):
mpd = DistributedDataParallel(mpd, device_ids=[rank]).to(device)
mrd = DistributedDataParallel(mrd, device_ids=[rank]).to(device)
optim_g = torch.optim.AdamW(
generator.parameters(), h.learning_rate, betas=[h.adam_b1, h.adam_b2]
)
optim_g = torch.optim.AdamW(generator.parameters(), h.learning_rate, betas=[h.adam_b1, h.adam_b2])
optim_d = torch.optim.AdamW(
itertools.chain(mrd.parameters(), mpd.parameters()),
h.learning_rate,
@@ -156,12 +146,8 @@ def train(rank, a, h):
optim_g.load_state_dict(state_dict_do["optim_g"])
optim_d.load_state_dict(state_dict_do["optim_d"])
scheduler_g = torch.optim.lr_scheduler.ExponentialLR(
optim_g, gamma=h.lr_decay, last_epoch=last_epoch
)
scheduler_d = torch.optim.lr_scheduler.ExponentialLR(
optim_d, gamma=h.lr_decay, last_epoch=last_epoch
)
scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=h.lr_decay, last_epoch=last_epoch)
scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=h.lr_decay, last_epoch=last_epoch)
# Define training and validation datasets
@@ -169,9 +155,7 @@ def train(rank, a, h):
unseen_validation_filelist will contain sample filepaths outside the seen training & validation dataset
Example: trained on LibriTTS, validate on VCTK
"""
training_filelist, validation_filelist, list_unseen_validation_filelist = (
get_dataset_filelist(a)
)
training_filelist, validation_filelist, list_unseen_validation_filelist = get_dataset_filelist(a)
trainset = MelDataset(
training_filelist,
@@ -324,33 +308,26 @@ def train(rank, a, h):
h.fmax_for_loss,
)
min_t = min(y_mel.size(-1), y_g_hat_mel.size(-1))
val_err_tot += F.l1_loss(y_mel[...,:min_t], y_g_hat_mel[...,:min_t]).item()
val_err_tot += F.l1_loss(y_mel[..., :min_t], y_g_hat_mel[..., :min_t]).item()
# PESQ calculation. only evaluate PESQ if it's speech signal (nonspeech PESQ will error out)
if (
not "nonspeech" in mode
): # Skips if the name of dataset (in mode string) contains "nonspeech"
if "nonspeech" not in mode: # Skips if the name of dataset (in mode string) contains "nonspeech"
# Resample to 16000 for pesq
y_16k = pesq_resampler(y)
y_g_hat_16k = pesq_resampler(y_g_hat.squeeze(1))
y_int_16k = (y_16k[0] * MAX_WAV_VALUE).short().cpu().numpy()
y_g_hat_int_16k = (
(y_g_hat_16k[0] * MAX_WAV_VALUE).short().cpu().numpy()
)
y_g_hat_int_16k = (y_g_hat_16k[0] * MAX_WAV_VALUE).short().cpu().numpy()
val_pesq_tot += pesq(16000, y_int_16k, y_g_hat_int_16k, "wb")
# MRSTFT calculation
min_t = min(y.size(-1), y_g_hat.size(-1))
val_mrstft_tot += loss_mrstft(y_g_hat[...,:min_t], y[...,:min_t]).item()
val_mrstft_tot += loss_mrstft(y_g_hat[..., :min_t], y[..., :min_t]).item()
# Log audio and figures to Tensorboard
if j % a.eval_subsample == 0: # Subsample every nth from validation set
if steps >= 0:
sw.add_audio(f"gt_{mode}/y_{j}", y[0], steps, h.sampling_rate)
if (
a.save_audio
): # Also save audio to disk if --save_audio is set to True
if a.save_audio: # Also save audio to disk if --save_audio is set to True
save_audio(
y[0],
os.path.join(
@@ -373,9 +350,7 @@ def train(rank, a, h):
steps,
h.sampling_rate,
)
if (
a.save_audio
): # Also save audio to disk if --save_audio is set to True
if a.save_audio: # Also save audio to disk if --save_audio is set to True
save_audio(
y_g_hat[0, 0],
os.path.join(
@@ -487,15 +462,11 @@ def train(rank, a, h):
# MPD
y_df_hat_r, y_df_hat_g, _, _ = mpd(y, y_g_hat.detach())
loss_disc_f, losses_disc_f_r, losses_disc_f_g = discriminator_loss(
y_df_hat_r, y_df_hat_g
)
loss_disc_f, losses_disc_f_r, losses_disc_f_g = discriminator_loss(y_df_hat_r, y_df_hat_g)
# MRD
y_ds_hat_r, y_ds_hat_g, _, _ = mrd(y, y_g_hat.detach())
loss_disc_s, losses_disc_s_r, losses_disc_s_g = discriminator_loss(
y_ds_hat_r, y_ds_hat_g
)
loss_disc_s, losses_disc_s_r, losses_disc_s_g = discriminator_loss(y_ds_hat_r, y_ds_hat_g)
loss_disc_all = loss_disc_s + loss_disc_f
@@ -505,17 +476,11 @@ def train(rank, a, h):
# Whether to freeze D for initial training steps
if steps >= a.freeze_step:
loss_disc_all.backward()
grad_norm_mpd = torch.nn.utils.clip_grad_norm_(
mpd.parameters(), clip_grad_norm
)
grad_norm_mrd = torch.nn.utils.clip_grad_norm_(
mrd.parameters(), clip_grad_norm
)
grad_norm_mpd = torch.nn.utils.clip_grad_norm_(mpd.parameters(), clip_grad_norm)
grad_norm_mrd = torch.nn.utils.clip_grad_norm_(mrd.parameters(), clip_grad_norm)
optim_d.step()
else:
print(
f"[WARNING] skipping D training for the first {a.freeze_step} steps"
)
print(f"[WARNING] skipping D training for the first {a.freeze_step} steps")
grad_norm_mpd = 0.0
grad_norm_mrd = 0.0
@@ -523,9 +488,7 @@ def train(rank, a, h):
optim_g.zero_grad()
# L1 Mel-Spectrogram Loss
lambda_melloss = h.get(
"lambda_melloss", 45.0
) # Defaults to 45 in BigVGAN-v1 if not set
lambda_melloss = h.get("lambda_melloss", 45.0) # Defaults to 45 in BigVGAN-v1 if not set
if h.get("use_multiscale_melloss", False): # uses wav <y, y_g_hat> for loss
loss_mel = fn_mel_loss_multiscale(y, y_g_hat) * lambda_melloss
else: # Uses mel <y_mel, y_g_hat_mel> for loss
@@ -542,27 +505,19 @@ def train(rank, a, h):
loss_gen_s, losses_gen_s = generator_loss(y_ds_hat_g)
if steps >= a.freeze_step:
loss_gen_all = (
loss_gen_s + loss_gen_f + loss_fm_s + loss_fm_f + loss_mel
)
loss_gen_all = loss_gen_s + loss_gen_f + loss_fm_s + loss_fm_f + loss_mel
else:
print(
f"[WARNING] using regression loss only for G for the first {a.freeze_step} steps"
)
print(f"[WARNING] using regression loss only for G for the first {a.freeze_step} steps")
loss_gen_all = loss_mel
loss_gen_all.backward()
grad_norm_g = torch.nn.utils.clip_grad_norm_(
generator.parameters(), clip_grad_norm
)
grad_norm_g = torch.nn.utils.clip_grad_norm_(generator.parameters(), clip_grad_norm)
optim_g.step()
if rank == 0:
# STDOUT logging
if steps % a.stdout_interval == 0:
mel_error = (
loss_mel.item() / lambda_melloss
) # Log training mel regression loss to stdout
mel_error = loss_mel.item() / lambda_melloss # Log training mel regression loss to stdout
print(
f"Steps: {steps:d}, "
f"Gen Loss Total: {loss_gen_all:4.3f}, "
@@ -577,11 +532,7 @@ def train(rank, a, h):
checkpoint_path = f"{a.checkpoint_path}/g_{steps:08d}"
save_checkpoint(
checkpoint_path,
{
"generator": (
generator.module if h.num_gpus > 1 else generator
).state_dict()
},
{"generator": (generator.module if h.num_gpus > 1 else generator).state_dict()},
)
checkpoint_path = f"{a.checkpoint_path}/do_{steps:08d}"
save_checkpoint(
@@ -598,9 +549,7 @@ def train(rank, a, h):
# Tensorboard summary logging
if steps % a.summary_interval == 0:
mel_error = (
loss_mel.item() / lambda_melloss
) # Log training mel regression loss to tensorboard
mel_error = loss_mel.item() / lambda_melloss # Log training mel regression loss to tensorboard
sw.add_scalar("training/gen_loss_total", loss_gen_all.item(), steps)
sw.add_scalar("training/mel_spec_error", mel_error, steps)
sw.add_scalar("training/fm_loss_mpd", loss_fm_f.item(), steps)
@@ -612,12 +561,8 @@ def train(rank, a, h):
sw.add_scalar("training/disc_loss_mrd", loss_disc_s.item(), steps)
sw.add_scalar("training/grad_norm_mrd", grad_norm_mrd, steps)
sw.add_scalar("training/grad_norm_g", grad_norm_g, steps)
sw.add_scalar(
"training/learning_rate_d", scheduler_d.get_last_lr()[0], steps
)
sw.add_scalar(
"training/learning_rate_g", scheduler_g.get_last_lr()[0], steps
)
sw.add_scalar("training/learning_rate_d", scheduler_d.get_last_lr()[0], steps)
sw.add_scalar("training/learning_rate_g", scheduler_g.get_last_lr()[0], steps)
sw.add_scalar("training/epoch", epoch + 1, steps)
# Validation
@@ -660,9 +605,7 @@ def train(rank, a, h):
scheduler_d.step()
if rank == 0:
print(
f"Time taken for epoch {epoch + 1} is {int(time.time() - start)} sec\n"
)
print(f"Time taken for epoch {epoch + 1} is {int(time.time() - start)} sec\n")
def main():
@@ -674,12 +617,8 @@ def main():
parser.add_argument("--input_wavs_dir", default="LibriTTS")
parser.add_argument("--input_mels_dir", default="ft_dataset")
parser.add_argument(
"--input_training_file", default="tests/LibriTTS/train-full.txt"
)
parser.add_argument(
"--input_validation_file", default="tests/LibriTTS/val-full.txt"
)
parser.add_argument("--input_training_file", default="tests/LibriTTS/train-full.txt")
parser.add_argument("--input_validation_file", default="tests/LibriTTS/val-full.txt")
parser.add_argument(
"--list_input_unseen_wavs_dir",