Refactor: Format Code with Ruff and Update Deprecated G2PW Link (#2255)

* ruff check --fix

* ruff format --line-length 120 --target-version py39

* Change the link for G2PW Model

* update pytorch version and colab
This commit is contained in:
XXXXRT666
2025-04-07 09:42:47 +01:00
committed by GitHub
parent 9da7e17efe
commit 53cac93589
132 changed files with 8185 additions and 6648 deletions

View File

@@ -6,13 +6,12 @@
import torch
import torch.nn.functional as F
import torch.nn as nn
from librosa.filters import mel as librosa_mel_fn
from scipy import signal
import typing
from typing import Optional, List, Union, Dict, Tuple
from typing import List, Tuple
from collections import namedtuple
import math
import functools
@@ -117,15 +116,13 @@ class MultiScaleMelSpectrogramLoss(nn.Module):
window_type,
):
"""
Mirrors AudioSignal.mel_spectrogram used by BigVGAN-v2 training from:
Mirrors AudioSignal.mel_spectrogram used by BigVGAN-v2 training from:
https://github.com/descriptinc/audiotools/blob/master/audiotools/core/audio_signal.py
"""
B, C, T = wav.shape
if match_stride:
assert (
hop_length == window_length // 4
), "For match_stride, hop must equal n_fft // 4"
assert hop_length == window_length // 4, "For match_stride, hop must equal n_fft // 4"
right_pad = math.ceil(T / hop_length) * hop_length - T
pad = (window_length - hop_length) // 2
else:
@@ -155,9 +152,7 @@ class MultiScaleMelSpectrogramLoss(nn.Module):
magnitude = torch.abs(stft)
nf = magnitude.shape[2]
mel_basis = self.get_mel_filters(
self.sampling_rate, 2 * (nf - 1), n_mels, fmin, fmax
)
mel_basis = self.get_mel_filters(self.sampling_rate, 2 * (nf - 1), n_mels, fmin, fmax)
mel_basis = torch.from_numpy(mel_basis).to(wav.device)
mel_spectrogram = magnitude.transpose(2, -1) @ mel_basis.T
mel_spectrogram = mel_spectrogram.transpose(-1, 2)
@@ -182,9 +177,7 @@ class MultiScaleMelSpectrogramLoss(nn.Module):
"""
loss = 0.0
for n_mels, fmin, fmax, s in zip(
self.n_mels, self.mel_fmin, self.mel_fmax, self.stft_params
):
for n_mels, fmin, fmax, s in zip(self.n_mels, self.mel_fmin, self.mel_fmax, self.stft_params):
kwargs = {
"n_mels": n_mels,
"fmin": fmin,
@@ -197,12 +190,8 @@ class MultiScaleMelSpectrogramLoss(nn.Module):
x_mels = self.mel_spectrogram(x, **kwargs)
y_mels = self.mel_spectrogram(y, **kwargs)
x_logmels = torch.log(
x_mels.clamp(min=self.clamp_eps).pow(self.pow)
) / torch.log(torch.tensor(10.0))
y_logmels = torch.log(
y_mels.clamp(min=self.clamp_eps).pow(self.pow)
) / torch.log(torch.tensor(10.0))
x_logmels = torch.log(x_mels.clamp(min=self.clamp_eps).pow(self.pow)) / torch.log(torch.tensor(10.0))
y_logmels = torch.log(y_mels.clamp(min=self.clamp_eps).pow(self.pow)) / torch.log(torch.tensor(10.0))
loss += self.log_weight * self.loss_fn(x_logmels, y_logmels)
loss += self.mag_weight * self.loss_fn(x_logmels, y_logmels)
@@ -211,10 +200,7 @@ class MultiScaleMelSpectrogramLoss(nn.Module):
# Loss functions
def feature_loss(
fmap_r: List[List[torch.Tensor]], fmap_g: List[List[torch.Tensor]]
) -> torch.Tensor:
def feature_loss(fmap_r: List[List[torch.Tensor]], fmap_g: List[List[torch.Tensor]]) -> torch.Tensor:
loss = 0
for dr, dg in zip(fmap_r, fmap_g):
for rl, gl in zip(dr, dg):
@@ -226,7 +212,6 @@ def feature_loss(
def discriminator_loss(
disc_real_outputs: List[torch.Tensor], disc_generated_outputs: List[torch.Tensor]
) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]:
loss = 0
r_losses = []
g_losses = []
@@ -243,7 +228,6 @@ def discriminator_loss(
def generator_loss(
disc_outputs: List[torch.Tensor],
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
loss = 0
gen_losses = []
for dg in disc_outputs: