Refactor: Format Code with Ruff and Update Deprecated G2PW Link (#2255)

* ruff check --fix

* ruff format --line-length 120 --target-version py39

* Change the link for G2PW Model

* Update PyTorch version and Colab
Author: XXXXRT666
Date: 2025-04-07 09:42:47 +01:00
Committed by: GitHub
Parent: 9da7e17efe
Commit: 53cac93589
132 changed files with 8185 additions and 6648 deletions
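For reference, the formatting pass named in the commit message can be reproduced from a script. A minimal sketch, assuming ruff is installed (pip install ruff) and the script is run from the repository root; the wrapper itself is illustrative and not part of this commit:

import subprocess

# Apply auto-fixable lint rules in place, mirroring `ruff check --fix`.
subprocess.run(["ruff", "check", "--fix", "."], check=True)

# Reformat with the same options recorded in the commit message.
subprocess.run(
    ["ruff", "format", "--line-length", "120", "--target-version", "py39", "."],
    check=True,
)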


@@ -15,7 +15,7 @@ from torchaudio.transforms import Spectrogram, Resample
 from env import AttrDict
 from utils import get_padding
 import typing
-from typing import Optional, List, Union, Dict, Tuple
+from typing import List, Tuple


 class DiscriminatorP(torch.nn.Module):
@@ -81,9 +81,7 @@ class DiscriminatorP(torch.nn.Module):
                 ),
             ]
         )
-        self.conv_post = norm_f(
-            Conv2d(int(1024 * self.d_mult), 1, (3, 1), 1, padding=(1, 0))
-        )
+        self.conv_post = norm_f(Conv2d(int(1024 * self.d_mult), 1, (3, 1), 1, padding=(1, 0)))

     def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
         fmap = []
@@ -113,13 +111,12 @@ class MultiPeriodDiscriminator(torch.nn.Module):
         self.mpd_reshapes = h.mpd_reshapes
         print(f"mpd_reshapes: {self.mpd_reshapes}")
         self.discriminators = nn.ModuleList(
-            [
-                DiscriminatorP(h, rs, use_spectral_norm=h.use_spectral_norm)
-                for rs in self.mpd_reshapes
-            ]
+            [DiscriminatorP(h, rs, use_spectral_norm=h.use_spectral_norm) for rs in self.mpd_reshapes]
         )

-    def forward(self, y: torch.Tensor, y_hat: torch.Tensor) -> Tuple[
+    def forward(
+        self, y: torch.Tensor, y_hat: torch.Tensor
+    ) -> Tuple[
         List[torch.Tensor],
         List[torch.Tensor],
         List[List[torch.Tensor]],
@@ -145,19 +142,13 @@ class DiscriminatorR(nn.Module):
         super().__init__()

         self.resolution = resolution
-        assert (
-            len(self.resolution) == 3
-        ), f"MRD layer requires list with len=3, got {self.resolution}"
+        assert len(self.resolution) == 3, f"MRD layer requires list with len=3, got {self.resolution}"
         self.lrelu_slope = 0.1

         norm_f = weight_norm if cfg.use_spectral_norm == False else spectral_norm
         if hasattr(cfg, "mrd_use_spectral_norm"):
-            print(
-                f"[INFO] overriding MRD use_spectral_norm as {cfg.mrd_use_spectral_norm}"
-            )
-            norm_f = (
-                weight_norm if cfg.mrd_use_spectral_norm == False else spectral_norm
-            )
+            print(f"[INFO] overriding MRD use_spectral_norm as {cfg.mrd_use_spectral_norm}")
+            norm_f = weight_norm if cfg.mrd_use_spectral_norm == False else spectral_norm
         self.d_mult = cfg.discriminator_channel_mult
         if hasattr(cfg, "mrd_channel_mult"):
             print(f"[INFO] overriding mrd channel multiplier as {cfg.mrd_channel_mult}")
@@ -203,9 +194,7 @@ class DiscriminatorR(nn.Module):
                 ),
             ]
         )
-        self.conv_post = norm_f(
-            nn.Conv2d(int(32 * self.d_mult), 1, (3, 3), padding=(1, 1))
-        )
+        self.conv_post = norm_f(nn.Conv2d(int(32 * self.d_mult), 1, (3, 3), padding=(1, 1)))

     def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
         fmap = []
@@ -248,14 +237,14 @@ class MultiResolutionDiscriminator(nn.Module):
     def __init__(self, cfg, debug=False):
         super().__init__()
         self.resolutions = cfg.resolutions
-        assert (
-            len(self.resolutions) == 3
-        ), f"MRD requires list of list with len=3, each element having a list with len=3. Got {self.resolutions}"
-        self.discriminators = nn.ModuleList(
-            [DiscriminatorR(cfg, resolution) for resolution in self.resolutions]
+        assert len(self.resolutions) == 3, (
+            f"MRD requires list of list with len=3, each element having a list with len=3. Got {self.resolutions}"
         )
+        self.discriminators = nn.ModuleList([DiscriminatorR(cfg, resolution) for resolution in self.resolutions])

-    def forward(self, y: torch.Tensor, y_hat: torch.Tensor) -> Tuple[
+    def forward(
+        self, y: torch.Tensor, y_hat: torch.Tensor
+    ) -> Tuple[
         List[torch.Tensor],
         List[torch.Tensor],
         List[List[torch.Tensor]],
@@ -309,25 +298,15 @@ class DiscriminatorB(nn.Module):
         convs = lambda: nn.ModuleList(
             [
                 weight_norm(nn.Conv2d(2, channels, (3, 9), (1, 1), padding=(1, 4))),
-                weight_norm(
-                    nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))
-                ),
-                weight_norm(
-                    nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))
-                ),
-                weight_norm(
-                    nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))
-                ),
-                weight_norm(
-                    nn.Conv2d(channels, channels, (3, 3), (1, 1), padding=(1, 1))
-                ),
+                weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
+                weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
+                weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
+                weight_norm(nn.Conv2d(channels, channels, (3, 3), (1, 1), padding=(1, 1))),
             ]
         )
         self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))])
-        self.conv_post = weight_norm(
-            nn.Conv2d(channels, 1, (3, 3), (1, 1), padding=(1, 1))
-        )
+        self.conv_post = weight_norm(nn.Conv2d(channels, 1, (3, 3), (1, 1), padding=(1, 1)))

     def spectrogram(self, x: torch.Tensor) -> List[torch.Tensor]:
         # Remove DC offset
@@ -376,17 +355,16 @@ class MultiBandDiscriminator(nn.Module):
         super().__init__()

         # fft_sizes (list[int]): Tuple of window lengths for FFT. Defaults to [2048, 1024, 512] if not set in h.
         self.fft_sizes = h.get("mbd_fft_sizes", [2048, 1024, 512])
-        self.discriminators = nn.ModuleList(
-            [DiscriminatorB(window_length=w) for w in self.fft_sizes]
-        )
+        self.discriminators = nn.ModuleList([DiscriminatorB(window_length=w) for w in self.fft_sizes])

-    def forward(self, y: torch.Tensor, y_hat: torch.Tensor) -> Tuple[
+    def forward(
+        self, y: torch.Tensor, y_hat: torch.Tensor
+    ) -> Tuple[
         List[torch.Tensor],
         List[torch.Tensor],
         List[List[torch.Tensor]],
         List[List[torch.Tensor]],
     ]:
         y_d_rs = []
         y_d_gs = []
         fmap_rs = []
@@ -406,7 +384,7 @@ class MultiBandDiscriminator(nn.Module):
 # Adapted from https://github.com/open-mmlab/Amphion/blob/main/models/vocoders/gan/discriminator/mssbcqtd.py under the MIT license.
 # LICENSE is in incl_licenses directory.
 class DiscriminatorCQT(nn.Module):
-    def __init__(self, cfg: AttrDict, hop_length: int, n_octaves:int, bins_per_octave: int):
+    def __init__(self, cfg: AttrDict, hop_length: int, n_octaves: int, bins_per_octave: int):
         super().__init__()
         self.cfg = cfg
@@ -460,9 +438,7 @@ class DiscriminatorCQT(nn.Module):
         in_chs = min(self.filters_scale * self.filters, self.max_filters)
         for i, dilation in enumerate(self.dilations):
-            out_chs = min(
-                (self.filters_scale ** (i + 1)) * self.filters, self.max_filters
-            )
+            out_chs = min((self.filters_scale ** (i + 1)) * self.filters, self.max_filters)
             self.convs.append(
                 weight_norm(
                     nn.Conv2d(
@@ -486,9 +462,7 @@ class DiscriminatorCQT(nn.Module):
                         in_chs,
                         out_chs,
                         kernel_size=(self.kernel_size[0], self.kernel_size[0]),
-                        padding=self.get_2d_padding(
-                            (self.kernel_size[0], self.kernel_size[0])
-                        ),
+                        padding=self.get_2d_padding((self.kernel_size[0], self.kernel_size[0])),
                     )
                 )
             )
@@ -508,7 +482,7 @@ class DiscriminatorCQT(nn.Module):
         self.cqtd_normalize_volume = self.cfg.get("cqtd_normalize_volume", False)
         if self.cqtd_normalize_volume:
             print(
-                f"[INFO] cqtd_normalize_volume set to True. Will apply DC offset removal & peak volume normalization in CQTD!"
+                "[INFO] cqtd_normalize_volume set to True. Will apply DC offset removal & peak volume normalization in CQTD!"
             )

     def get_2d_padding(
@@ -580,9 +554,7 @@ class MultiScaleSubbandCQTDiscriminator(nn.Module):
         # Multi-scale params to loop over
         self.cfg["cqtd_hop_lengths"] = self.cfg.get("cqtd_hop_lengths", [512, 256, 256])
         self.cfg["cqtd_n_octaves"] = self.cfg.get("cqtd_n_octaves", [9, 9, 9])
-        self.cfg["cqtd_bins_per_octaves"] = self.cfg.get(
-            "cqtd_bins_per_octaves", [24, 36, 48]
-        )
+        self.cfg["cqtd_bins_per_octaves"] = self.cfg.get("cqtd_bins_per_octaves", [24, 36, 48])

         self.discriminators = nn.ModuleList(
             [
@@ -596,13 +568,14 @@ class MultiScaleSubbandCQTDiscriminator(nn.Module):
             ]
         )

-    def forward(self, y: torch.Tensor, y_hat: torch.Tensor) -> Tuple[
+    def forward(
+        self, y: torch.Tensor, y_hat: torch.Tensor
+    ) -> Tuple[
         List[torch.Tensor],
         List[torch.Tensor],
         List[List[torch.Tensor]],
         List[List[torch.Tensor]],
     ]:
         y_d_rs = []
         y_d_gs = []
         fmap_rs = []
@@ -629,13 +602,14 @@ class CombinedDiscriminator(nn.Module):
         super().__init__()
         self.discrimiantor = nn.ModuleList(list_discriminator)

-    def forward(self, y: torch.Tensor, y_hat: torch.Tensor) -> Tuple[
+    def forward(
+        self, y: torch.Tensor, y_hat: torch.Tensor
+    ) -> Tuple[
         List[torch.Tensor],
         List[torch.Tensor],
         List[List[torch.Tensor]],
         List[List[torch.Tensor]],
     ]:
         y_d_rs = []
         y_d_gs = []
         fmap_rs = []