Refactor: Format Code with Ruff and Update Deprecated G2PW Link (#2255)
* ruff check --fix * ruff format --line-length 120 --target-version py39 * Change the link for G2PW Model * update pytorch version and colab
This commit is contained in:
@@ -6,13 +6,12 @@
|
||||
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.nn as nn
|
||||
from librosa.filters import mel as librosa_mel_fn
|
||||
from scipy import signal
|
||||
|
||||
import typing
|
||||
from typing import Optional, List, Union, Dict, Tuple
|
||||
from typing import List, Tuple
|
||||
from collections import namedtuple
|
||||
import math
|
||||
import functools
|
||||
@@ -117,15 +116,13 @@ class MultiScaleMelSpectrogramLoss(nn.Module):
|
||||
window_type,
|
||||
):
|
||||
"""
|
||||
Mirrors AudioSignal.mel_spectrogram used by BigVGAN-v2 training from:
|
||||
Mirrors AudioSignal.mel_spectrogram used by BigVGAN-v2 training from:
|
||||
https://github.com/descriptinc/audiotools/blob/master/audiotools/core/audio_signal.py
|
||||
"""
|
||||
B, C, T = wav.shape
|
||||
|
||||
if match_stride:
|
||||
assert (
|
||||
hop_length == window_length // 4
|
||||
), "For match_stride, hop must equal n_fft // 4"
|
||||
assert hop_length == window_length // 4, "For match_stride, hop must equal n_fft // 4"
|
||||
right_pad = math.ceil(T / hop_length) * hop_length - T
|
||||
pad = (window_length - hop_length) // 2
|
||||
else:
|
||||
@@ -155,9 +152,7 @@ class MultiScaleMelSpectrogramLoss(nn.Module):
|
||||
magnitude = torch.abs(stft)
|
||||
|
||||
nf = magnitude.shape[2]
|
||||
mel_basis = self.get_mel_filters(
|
||||
self.sampling_rate, 2 * (nf - 1), n_mels, fmin, fmax
|
||||
)
|
||||
mel_basis = self.get_mel_filters(self.sampling_rate, 2 * (nf - 1), n_mels, fmin, fmax)
|
||||
mel_basis = torch.from_numpy(mel_basis).to(wav.device)
|
||||
mel_spectrogram = magnitude.transpose(2, -1) @ mel_basis.T
|
||||
mel_spectrogram = mel_spectrogram.transpose(-1, 2)
|
||||
@@ -182,9 +177,7 @@ class MultiScaleMelSpectrogramLoss(nn.Module):
|
||||
"""
|
||||
|
||||
loss = 0.0
|
||||
for n_mels, fmin, fmax, s in zip(
|
||||
self.n_mels, self.mel_fmin, self.mel_fmax, self.stft_params
|
||||
):
|
||||
for n_mels, fmin, fmax, s in zip(self.n_mels, self.mel_fmin, self.mel_fmax, self.stft_params):
|
||||
kwargs = {
|
||||
"n_mels": n_mels,
|
||||
"fmin": fmin,
|
||||
@@ -197,12 +190,8 @@ class MultiScaleMelSpectrogramLoss(nn.Module):
|
||||
|
||||
x_mels = self.mel_spectrogram(x, **kwargs)
|
||||
y_mels = self.mel_spectrogram(y, **kwargs)
|
||||
x_logmels = torch.log(
|
||||
x_mels.clamp(min=self.clamp_eps).pow(self.pow)
|
||||
) / torch.log(torch.tensor(10.0))
|
||||
y_logmels = torch.log(
|
||||
y_mels.clamp(min=self.clamp_eps).pow(self.pow)
|
||||
) / torch.log(torch.tensor(10.0))
|
||||
x_logmels = torch.log(x_mels.clamp(min=self.clamp_eps).pow(self.pow)) / torch.log(torch.tensor(10.0))
|
||||
y_logmels = torch.log(y_mels.clamp(min=self.clamp_eps).pow(self.pow)) / torch.log(torch.tensor(10.0))
|
||||
|
||||
loss += self.log_weight * self.loss_fn(x_logmels, y_logmels)
|
||||
loss += self.mag_weight * self.loss_fn(x_logmels, y_logmels)
|
||||
@@ -211,10 +200,7 @@ class MultiScaleMelSpectrogramLoss(nn.Module):
|
||||
|
||||
|
||||
# Loss functions
|
||||
def feature_loss(
|
||||
fmap_r: List[List[torch.Tensor]], fmap_g: List[List[torch.Tensor]]
|
||||
) -> torch.Tensor:
|
||||
|
||||
def feature_loss(fmap_r: List[List[torch.Tensor]], fmap_g: List[List[torch.Tensor]]) -> torch.Tensor:
|
||||
loss = 0
|
||||
for dr, dg in zip(fmap_r, fmap_g):
|
||||
for rl, gl in zip(dr, dg):
|
||||
@@ -226,7 +212,6 @@ def feature_loss(
|
||||
def discriminator_loss(
|
||||
disc_real_outputs: List[torch.Tensor], disc_generated_outputs: List[torch.Tensor]
|
||||
) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]:
|
||||
|
||||
loss = 0
|
||||
r_losses = []
|
||||
g_losses = []
|
||||
@@ -243,7 +228,6 @@ def discriminator_loss(
|
||||
def generator_loss(
|
||||
disc_outputs: List[torch.Tensor],
|
||||
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
|
||||
|
||||
loss = 0
|
||||
gen_losses = []
|
||||
for dg in disc_outputs:
|
||||
|
||||
Reference in New Issue
Block a user