Refactor: Format Code with Ruff and Update Deprecated G2PW Link (#2255)

* ruff check --fix

* ruff format --line-length 120 --target-version py39

* Change the link for G2PW Model

* Update PyTorch version and Colab
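For reference, the flags above can also be pinned in the project's configuration so later runs don't depend on remembering CLI arguments. A minimal sketch, assuming the repo adds a [tool.ruff] table to its pyproject.toml (the keys follow Ruff's documented settings; the placement is an assumption, not part of this commit):

    # pyproject.toml -- equivalent of the CLI flags used in this commit (assumed placement)
    [tool.ruff]
    line-length = 120
    target-version = "py39"

With that table in place, plain "ruff check --fix" and "ruff format" reproduce the formatting applied here.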
Author: XXXXRT666
Date: 2025-04-07 09:42:47 +01:00
Committed by: GitHub
Parent: 9da7e17efe
Commit: 53cac93589
132 changed files with 8185 additions and 6648 deletions


@@ -1,14 +1,14 @@
 from functools import partial
 import torch
-from torch import nn, einsum, Tensor
+from torch import nn
 from torch.nn import Module, ModuleList
 import torch.nn.functional as F
 from bs_roformer.attend import Attend
 from torch.utils.checkpoint import checkpoint
-from typing import Tuple, Optional, List, Callable
+from typing import Tuple, Optional, Callable
 # from beartype.typing import Tuple, Optional, List, Callable
 # from beartype import beartype
@@ -19,6 +19,7 @@ from einops.layers.torch import Rearrange
 # helper functions
 def exists(val):
     return val is not None
@@ -37,14 +38,15 @@ def unpack_one(t, ps, pattern):
 # norm
 def l2norm(t):
-    return F.normalize(t, dim = -1, p = 2)
+    return F.normalize(t, dim=-1, p=2)
 class RMSNorm(Module):
     def __init__(self, dim):
         super().__init__()
-        self.scale = dim ** 0.5
+        self.scale = dim**0.5
         self.gamma = nn.Parameter(torch.ones(dim))
     def forward(self, x):
@@ -53,13 +55,9 @@ class RMSNorm(Module):
 # attention
 class FeedForward(Module):
-    def __init__(
-            self,
-            dim,
-            mult=4,
-            dropout=0.
-    ):
+    def __init__(self, dim, mult=4, dropout=0.0):
         super().__init__()
         dim_inner = int(dim * mult)
         self.net = nn.Sequential(
@@ -68,7 +66,7 @@ class FeedForward(Module):
             nn.GELU(),
             nn.Dropout(dropout),
             nn.Linear(dim_inner, dim),
-            nn.Dropout(dropout)
+            nn.Dropout(dropout),
         )
     def forward(self, x):
@@ -76,18 +74,10 @@ class FeedForward(Module):
 class Attention(Module):
-    def __init__(
-            self,
-            dim,
-            heads=8,
-            dim_head=64,
-            dropout=0.,
-            rotary_embed=None,
-            flash=True
-    ):
+    def __init__(self, dim, heads=8, dim_head=64, dropout=0.0, rotary_embed=None, flash=True):
         super().__init__()
         self.heads = heads
-        self.scale = dim_head ** -0.5
+        self.scale = dim_head**-0.5
         dim_inner = heads * dim_head
         self.rotary_embed = rotary_embed
@@ -99,15 +89,12 @@ class Attention(Module):
         self.to_gates = nn.Linear(dim, heads)
-        self.to_out = nn.Sequential(
-            nn.Linear(dim_inner, dim, bias=False),
-            nn.Dropout(dropout)
-        )
+        self.to_out = nn.Sequential(nn.Linear(dim_inner, dim, bias=False), nn.Dropout(dropout))
     def forward(self, x):
         x = self.norm(x)
-        q, k, v = rearrange(self.to_qkv(x), 'b n (qkv h d) -> qkv b h n d', qkv=3, h=self.heads)
+        q, k, v = rearrange(self.to_qkv(x), "b n (qkv h d) -> qkv b h n d", qkv=3, h=self.heads)
         if exists(self.rotary_embed):
             q = self.rotary_embed.rotate_queries_or_keys(q)
@@ -116,9 +103,9 @@ class Attention(Module):
         out = self.attend(q, k, v)
         gates = self.to_gates(x)
-        out = out * rearrange(gates, 'b n h -> b h n 1').sigmoid()
+        out = out * rearrange(gates, "b n h -> b h n 1").sigmoid()
-        out = rearrange(out, 'b h n d -> b n (h d)')
+        out = rearrange(out, "b h n d -> b n (h d)")
         return self.to_out(out)
@@ -128,42 +115,22 @@ class LinearAttention(Module):
     """
     # @beartype
-    def __init__(
-            self,
-            *,
-            dim,
-            dim_head=32,
-            heads=8,
-            scale=8,
-            flash=False,
-            dropout=0.
-    ):
+    def __init__(self, *, dim, dim_head=32, heads=8, scale=8, flash=False, dropout=0.0):
         super().__init__()
         dim_inner = dim_head * heads
         self.norm = RMSNorm(dim)
         self.to_qkv = nn.Sequential(
-            nn.Linear(dim, dim_inner * 3, bias=False),
-            Rearrange('b n (qkv h d) -> qkv b h d n', qkv=3, h=heads)
+            nn.Linear(dim, dim_inner * 3, bias=False), Rearrange("b n (qkv h d) -> qkv b h d n", qkv=3, h=heads)
         )
         self.temperature = nn.Parameter(torch.ones(heads, 1, 1))
-        self.attend = Attend(
-            scale=scale,
-            dropout=dropout,
-            flash=flash
-        )
+        self.attend = Attend(scale=scale, dropout=dropout, flash=flash)
-        self.to_out = nn.Sequential(
-            Rearrange('b h d n -> b n (h d)'),
-            nn.Linear(dim_inner, dim, bias=False)
-        )
+        self.to_out = nn.Sequential(Rearrange("b h d n -> b n (h d)"), nn.Linear(dim_inner, dim, bias=False))
-    def forward(
-            self,
-            x
-    ):
+    def forward(self, x):
         x = self.norm(x)
         q, k, v = self.to_qkv(x)
@@ -178,19 +145,19 @@ class LinearAttention(Module):
 class Transformer(Module):
     def __init__(
-            self,
-            *,
-            dim,
-            depth,
-            dim_head=64,
-            heads=8,
-            attn_dropout=0.,
-            ff_dropout=0.,
-            ff_mult=4,
-            norm_output=True,
-            rotary_embed=None,
-            flash_attn=True,
-            linear_attn=False
+        self,
+        *,
+        dim,
+        depth,
+        dim_head=64,
+        heads=8,
+        attn_dropout=0.0,
+        ff_dropout=0.0,
+        ff_mult=4,
+        norm_output=True,
+        rotary_embed=None,
+        flash_attn=True,
+        linear_attn=False,
     ):
         super().__init__()
         self.layers = ModuleList([])
@@ -199,18 +166,20 @@ class Transformer(Module):
             if linear_attn:
                 attn = LinearAttention(dim=dim, dim_head=dim_head, heads=heads, dropout=attn_dropout, flash=flash_attn)
             else:
-                attn = Attention(dim=dim, dim_head=dim_head, heads=heads, dropout=attn_dropout,
-                                 rotary_embed=rotary_embed, flash=flash_attn)
+                attn = Attention(
+                    dim=dim,
+                    dim_head=dim_head,
+                    heads=heads,
+                    dropout=attn_dropout,
+                    rotary_embed=rotary_embed,
+                    flash=flash_attn,
+                )
-            self.layers.append(ModuleList([
-                attn,
-                FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout)
-            ]))
+            self.layers.append(ModuleList([attn, FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout)]))
         self.norm = RMSNorm(dim) if norm_output else nn.Identity()
     def forward(self, x):
         for attn, ff in self.layers:
             x = attn(x) + x
             x = ff(x) + x
@@ -220,22 +189,16 @@ class Transformer(Module):
 # bandsplit module
 class BandSplit(Module):
     # @beartype
-    def __init__(
-            self,
-            dim,
-            dim_inputs: Tuple[int, ...]
-    ):
+    def __init__(self, dim, dim_inputs: Tuple[int, ...]):
         super().__init__()
         self.dim_inputs = dim_inputs
         self.to_features = ModuleList([])
         for dim_in in dim_inputs:
-            net = nn.Sequential(
-                RMSNorm(dim_in),
-                nn.Linear(dim_in, dim)
-            )
+            net = nn.Sequential(RMSNorm(dim_in), nn.Linear(dim_in, dim))
             self.to_features.append(net)
@@ -250,13 +213,7 @@ class BandSplit(Module):
         return torch.stack(outs, dim=-2)
-def MLP(
-        dim_in,
-        dim_out,
-        dim_hidden=None,
-        depth=1,
-        activation=nn.Tanh
-):
+def MLP(dim_in, dim_out, dim_hidden=None, depth=1, activation=nn.Tanh):
     dim_hidden = default(dim_hidden, dim_in)
     net = []
@@ -277,13 +234,7 @@ def MLP(
 class MaskEstimator(Module):
     # @beartype
-    def __init__(
-            self,
-            dim,
-            dim_inputs: Tuple[int, ...],
-            depth,
-            mlp_expansion_factor=4
-    ):
+    def __init__(self, dim, dim_inputs: Tuple[int, ...], depth, mlp_expansion_factor=4):
         super().__init__()
         self.dim_inputs = dim_inputs
         self.to_freqs = ModuleList([])
@@ -292,10 +243,7 @@ class MaskEstimator(Module):
         for dim_in in dim_inputs:
             net = []
-            mlp = nn.Sequential(
-                MLP(dim, dim_in * 2, dim_hidden=dim_hidden, depth=depth),
-                nn.GLU(dim=-1)
-            )
+            mlp = nn.Sequential(MLP(dim, dim_in * 2, dim_hidden=dim_hidden, depth=depth), nn.GLU(dim=-1))
             self.to_freqs.append(mlp)
@@ -314,53 +262,106 @@ class MaskEstimator(Module):
 # main class
 DEFAULT_FREQS_PER_BANDS = (
-    2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
-    2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
-    2, 2, 2, 2,
-    4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
-    12, 12, 12, 12, 12, 12, 12, 12,
-    24, 24, 24, 24, 24, 24, 24, 24,
-    48, 48, 48, 48, 48, 48, 48, 48,
-    128, 129,
+    2,
+    2,
+    2,
+    2,
+    2,
+    2,
+    2,
+    2,
+    2,
+    2,
+    2,
+    2,
+    2,
+    2,
+    2,
+    2,
+    2,
+    2,
+    2,
+    2,
+    2,
+    2,
+    2,
+    2,
+    4,
+    4,
+    4,
+    4,
+    4,
+    4,
+    4,
+    4,
+    4,
+    4,
+    4,
+    4,
+    12,
+    12,
+    12,
+    12,
+    12,
+    12,
+    12,
+    12,
+    24,
+    24,
+    24,
+    24,
+    24,
+    24,
+    24,
+    24,
+    48,
+    48,
+    48,
+    48,
+    48,
+    48,
+    48,
+    48,
+    128,
+    129,
 )
 class BSRoformer(Module):
     # @beartype
     def __init__(
-            self,
-            dim,
-            *,
-            depth,
-            stereo=False,
-            num_stems=1,
-            time_transformer_depth=2,
-            freq_transformer_depth=2,
-            linear_transformer_depth=0,
-            freqs_per_bands: Tuple[int, ...] = DEFAULT_FREQS_PER_BANDS,
-            # in the paper, they divide into ~60 bands, test with 1 for starters
-            dim_head=64,
-            heads=8,
-            attn_dropout=0.,
-            ff_dropout=0.,
-            flash_attn=True,
-            dim_freqs_in=1025,
-            stft_n_fft=2048,
-            stft_hop_length=512,
-            # 10ms at 44100Hz, from sections 4.1, 4.4 in the paper - @faroit recommends // 2 or // 4 for better reconstruction
-            stft_win_length=2048,
-            stft_normalized=False,
-            stft_window_fn: Optional[Callable] = None,
-            mask_estimator_depth=2,
-            multi_stft_resolution_loss_weight=1.,
-            multi_stft_resolutions_window_sizes: Tuple[int, ...] = (4096, 2048, 1024, 512, 256),
-            multi_stft_hop_size=147,
-            multi_stft_normalized=False,
-            multi_stft_window_fn: Callable = torch.hann_window,
-            mlp_expansion_factor=4,
-            use_torch_checkpoint=False,
-            skip_connection=False,
+        self,
+        dim,
+        *,
+        depth,
+        stereo=False,
+        num_stems=1,
+        time_transformer_depth=2,
+        freq_transformer_depth=2,
+        linear_transformer_depth=0,
+        freqs_per_bands: Tuple[int, ...] = DEFAULT_FREQS_PER_BANDS,
+        # in the paper, they divide into ~60 bands, test with 1 for starters
+        dim_head=64,
+        heads=8,
+        attn_dropout=0.0,
+        ff_dropout=0.0,
+        flash_attn=True,
+        dim_freqs_in=1025,
+        stft_n_fft=2048,
+        stft_hop_length=512,
+        # 10ms at 44100Hz, from sections 4.1, 4.4 in the paper - @faroit recommends // 2 or // 4 for better reconstruction
+        stft_win_length=2048,
+        stft_normalized=False,
+        stft_window_fn: Optional[Callable] = None,
+        mask_estimator_depth=2,
+        multi_stft_resolution_loss_weight=1.0,
+        multi_stft_resolutions_window_sizes: Tuple[int, ...] = (4096, 2048, 1024, 512, 256),
+        multi_stft_hop_size=147,
+        multi_stft_normalized=False,
+        multi_stft_window_fn: Callable = torch.hann_window,
+        mlp_expansion_factor=4,
+        use_torch_checkpoint=False,
+        skip_connection=False,
     ):
         super().__init__()
@@ -379,7 +380,7 @@ class BSRoformer(Module):
             attn_dropout=attn_dropout,
             ff_dropout=ff_dropout,
             flash_attn=flash_attn,
-            norm_output=False
+            norm_output=False,
         )
         time_rotary_embed = RotaryEmbedding(dim=dim_head)
@@ -400,26 +401,23 @@ class BSRoformer(Module):
         self.final_norm = RMSNorm(dim)
         self.stft_kwargs = dict(
-            n_fft=stft_n_fft,
-            hop_length=stft_hop_length,
-            win_length=stft_win_length,
-            normalized=stft_normalized
+            n_fft=stft_n_fft, hop_length=stft_hop_length, win_length=stft_win_length, normalized=stft_normalized
         )
         self.stft_window_fn = partial(default(stft_window_fn, torch.hann_window), stft_win_length)
-        freqs = torch.stft(torch.randn(1, 4096), **self.stft_kwargs, window=torch.ones(stft_win_length), return_complex=True).shape[1]
+        freqs = torch.stft(
+            torch.randn(1, 4096), **self.stft_kwargs, window=torch.ones(stft_win_length), return_complex=True
+        ).shape[1]
         assert len(freqs_per_bands) > 1
-        assert sum(
-            freqs_per_bands) == freqs, f'the number of freqs in the bands must equal {freqs} based on the STFT settings, but got {sum(freqs_per_bands)}'
+        assert sum(freqs_per_bands) == freqs, (
+            f"the number of freqs in the bands must equal {freqs} based on the STFT settings, but got {sum(freqs_per_bands)}"
+        )
         freqs_per_bands_with_complex = tuple(2 * f * self.audio_channels for f in freqs_per_bands)
-        self.band_split = BandSplit(
-            dim=dim,
-            dim_inputs=freqs_per_bands_with_complex
-        )
+        self.band_split = BandSplit(dim=dim, dim_inputs=freqs_per_bands_with_complex)
         self.mask_estimators = nn.ModuleList([])
@@ -440,17 +438,9 @@ class BSRoformer(Module):
         self.multi_stft_n_fft = stft_n_fft
         self.multi_stft_window_fn = multi_stft_window_fn
-        self.multi_stft_kwargs = dict(
-            hop_length=multi_stft_hop_size,
-            normalized=multi_stft_normalized
-        )
+        self.multi_stft_kwargs = dict(hop_length=multi_stft_hop_size, normalized=multi_stft_normalized)
-    def forward(
-            self,
-            raw_audio,
-            target=None,
-            return_loss_breakdown=False
-    ):
+    def forward(self, raw_audio, target=None, return_loss_breakdown=False):
         """
         einops
@@ -469,14 +459,16 @@ class BSRoformer(Module):
         x_is_mps = True if device.type == "mps" else False
         if raw_audio.ndim == 2:
-            raw_audio = rearrange(raw_audio, 'b t -> b 1 t')
+            raw_audio = rearrange(raw_audio, "b t -> b 1 t")
         channels = raw_audio.shape[1]
-        assert (not self.stereo and channels == 1) or (self.stereo and channels == 2), 'stereo needs to be set to True if passing in audio signal that is stereo (channel dimension of 2). also need to be False if mono (channel dimension of 1)'
+        assert (not self.stereo and channels == 1) or (self.stereo and channels == 2), (
+            "stereo needs to be set to True if passing in audio signal that is stereo (channel dimension of 2). also need to be False if mono (channel dimension of 1)"
+        )
         # to stft
-        raw_audio, batch_audio_channel_packed_shape = pack_one(raw_audio, '* t')
+        raw_audio, batch_audio_channel_packed_shape = pack_one(raw_audio, "* t")
         stft_window = self.stft_window_fn(device=device)
@@ -485,16 +477,21 @@ class BSRoformer(Module):
         try:
             stft_repr = torch.stft(raw_audio, **self.stft_kwargs, window=stft_window, return_complex=True)
         except:
-            stft_repr = torch.stft(raw_audio.cpu() if x_is_mps else raw_audio, **self.stft_kwargs, window=stft_window.cpu() if x_is_mps else stft_window, return_complex=True).to(device)
+            stft_repr = torch.stft(
+                raw_audio.cpu() if x_is_mps else raw_audio,
+                **self.stft_kwargs,
+                window=stft_window.cpu() if x_is_mps else stft_window,
+                return_complex=True,
+            ).to(device)
         stft_repr = torch.view_as_real(stft_repr)
-        stft_repr = unpack_one(stft_repr, batch_audio_channel_packed_shape, '* f t c')
+        stft_repr = unpack_one(stft_repr, batch_audio_channel_packed_shape, "* f t c")
         # merge stereo / mono into the frequency, with frequency leading dimension, for band splitting
-        stft_repr = rearrange(stft_repr,'b s f t c -> b (f s) t c')
+        stft_repr = rearrange(stft_repr, "b s f t c -> b (f s) t c")
-        x = rearrange(stft_repr, 'b f t c -> b t (f c)')
+        x = rearrange(stft_repr, "b f t c -> b t (f c)")
         if self.use_torch_checkpoint:
             x = checkpoint(self.band_split, x, use_reentrant=False)
@@ -505,16 +502,15 @@ class BSRoformer(Module):
             store = [None] * len(self.layers)
         for i, transformer_block in enumerate(self.layers):
             if len(transformer_block) == 3:
                 linear_transformer, time_transformer, freq_transformer = transformer_block
-                x, ft_ps = pack([x], 'b * d')
+                x, ft_ps = pack([x], "b * d")
                 if self.use_torch_checkpoint:
                     x = checkpoint(linear_transformer, x, use_reentrant=False)
                 else:
                     x = linear_transformer(x)
-                x, = unpack(x, ft_ps, 'b * d')
+                (x,) = unpack(x, ft_ps, "b * d")
             else:
                 time_transformer, freq_transformer = transformer_block
@@ -523,24 +519,24 @@ class BSRoformer(Module):
                 for j in range(i):
                     x = x + store[j]
-            x = rearrange(x, 'b t f d -> b f t d')
-            x, ps = pack([x], '* t d')
+            x = rearrange(x, "b t f d -> b f t d")
+            x, ps = pack([x], "* t d")
             if self.use_torch_checkpoint:
                 x = checkpoint(time_transformer, x, use_reentrant=False)
             else:
                 x = time_transformer(x)
-            x, = unpack(x, ps, '* t d')
-            x = rearrange(x, 'b f t d -> b t f d')
-            x, ps = pack([x], '* f d')
+            (x,) = unpack(x, ps, "* t d")
+            x = rearrange(x, "b f t d -> b t f d")
+            x, ps = pack([x], "* f d")
             if self.use_torch_checkpoint:
                 x = checkpoint(freq_transformer, x, use_reentrant=False)
             else:
                 x = freq_transformer(x)
-            x, = unpack(x, ps, '* f d')
+            (x,) = unpack(x, ps, "* f d")
             if self.skip_connection:
                 store[i] = x
@@ -553,11 +549,11 @@ class BSRoformer(Module):
             mask = torch.stack([checkpoint(fn, x, use_reentrant=False) for fn in self.mask_estimators], dim=1)
         else:
            mask = torch.stack([fn(x) for fn in self.mask_estimators], dim=1)
-        mask = rearrange(mask, 'b n t (f c) -> b n f t c', c=2)
+        mask = rearrange(mask, "b n t (f c) -> b n f t c", c=2)
         # modulate frequency representation
-        stft_repr = rearrange(stft_repr, 'b f t c -> b 1 f t c')
+        stft_repr = rearrange(stft_repr, "b f t c -> b 1 f t c")
         # complex number multiplication
@@ -568,18 +564,26 @@ class BSRoformer(Module):
         # istft
-        stft_repr = rearrange(stft_repr, 'b n (f s) t -> (b n s) f t', s=self.audio_channels)
+        stft_repr = rearrange(stft_repr, "b n (f s) t -> (b n s) f t", s=self.audio_channels)
         # same as torch.stft() fix for MacOS MPS above
         try:
-            recon_audio = torch.istft(stft_repr, **self.stft_kwargs, window=stft_window, return_complex=False, length=raw_audio.shape[-1])
+            recon_audio = torch.istft(
+                stft_repr, **self.stft_kwargs, window=stft_window, return_complex=False, length=raw_audio.shape[-1]
+            )
         except:
-            recon_audio = torch.istft(stft_repr.cpu() if x_is_mps else stft_repr, **self.stft_kwargs, window=stft_window.cpu() if x_is_mps else stft_window, return_complex=False, length=raw_audio.shape[-1]).to(device)
+            recon_audio = torch.istft(
+                stft_repr.cpu() if x_is_mps else stft_repr,
+                **self.stft_kwargs,
+                window=stft_window.cpu() if x_is_mps else stft_window,
+                return_complex=False,
+                length=raw_audio.shape[-1],
+            ).to(device)
-        recon_audio = rearrange(recon_audio, '(b n s) t -> b n s t', s=self.audio_channels, n=num_stems)
+        recon_audio = rearrange(recon_audio, "(b n s) t -> b n s t", s=self.audio_channels, n=num_stems)
         if num_stems == 1:
-            recon_audio = rearrange(recon_audio, 'b 1 s t -> b s t')
+            recon_audio = rearrange(recon_audio, "b 1 s t -> b s t")
         # if a target is passed in, calculate loss for learning
@@ -590,13 +594,13 @@ class BSRoformer(Module):
             assert target.ndim == 4 and target.shape[1] == self.num_stems
         if target.ndim == 2:
-            target = rearrange(target, '... t -> ... 1 t')
+            target = rearrange(target, "... t -> ... 1 t")
-        target = target[..., :recon_audio.shape[-1]]  # protect against lost length on istft
+        target = target[..., : recon_audio.shape[-1]]  # protect against lost length on istft
         loss = F.l1_loss(recon_audio, target)
-        multi_stft_resolution_loss = 0.
+        multi_stft_resolution_loss = 0.0
         for window_size in self.multi_stft_resolutions_window_sizes:
             res_stft_kwargs = dict(
@@ -607,8 +611,8 @@ class BSRoformer(Module):
                 **self.multi_stft_kwargs,
             )
-            recon_Y = torch.stft(rearrange(recon_audio, '... s t -> (... s) t'), **res_stft_kwargs)
-            target_Y = torch.stft(rearrange(target, '... s t -> (... s) t'), **res_stft_kwargs)
+            recon_Y = torch.stft(rearrange(recon_audio, "... s t -> (... s) t"), **res_stft_kwargs)
+            target_Y = torch.stft(rearrange(target, "... s t -> (... s) t"), **res_stft_kwargs)
             multi_stft_resolution_loss = multi_stft_resolution_loss + F.l1_loss(recon_Y, target_Y)
@@ -619,4 +623,4 @@ class BSRoformer(Module):
         if not return_loss_breakdown:
             return total_loss
-        return total_loss, (loss, multi_stft_resolution_loss)
+        return total_loss, (loss, multi_stft_resolution_loss)