Refactor: Format Code with Ruff and Update Deprecated G2PW Link (#2255)

* ruff check --fix

* ruff format --line-length 120 --target-version py39

* Change the link for the G2PW model

* Update PyTorch version and Colab
Author: XXXXRT666
Date: 2025-04-07 09:42:47 +01:00
Committed by: GitHub
Parent: 9da7e17efe
Commit: 53cac93589
132 changed files with 8185 additions and 6648 deletions

File: bs_roformer/attend.py

@@ -7,23 +7,22 @@ import torch.nn.functional as F
def exists(val):
return val is not None
def default(v, d):
return v if exists(v) else d
class Attend(nn.Module):
def __init__(
self,
dropout = 0.,
flash = False,
scale = None
):
def __init__(self, dropout=0.0, flash=False, scale=None):
super().__init__()
self.scale = scale
self.dropout = dropout
self.attn_dropout = nn.Dropout(dropout)
self.flash = flash
assert not (flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'
assert not (flash and version.parse(torch.__version__) < version.parse("2.0.0")), (
"in order to use flash attention, you must be using pytorch 2.0 or above"
)
def flash_attn(self, q, k, v):
# _, heads, q_len, _, k_len, is_cuda, device = *q.shape, k.shape[-2], q.is_cuda, q.device
@@ -34,7 +33,7 @@ class Attend(nn.Module):
# pytorch 2.0 flash attn: q, k, v, mask, dropout, softmax_scale
# with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=True):
return F.scaled_dot_product_attention(q, k, v,dropout_p = self.dropout if self.training else 0.)
return F.scaled_dot_product_attention(q, k, v, dropout_p=self.dropout if self.training else 0.0)
def forward(self, q, k, v):
"""
@@ -54,7 +53,7 @@ class Attend(nn.Module):
# similarity
sim = einsum(f"b h i d, b h j d -> b h i j", q, k) * scale
sim = einsum("b h i d, b h j d -> b h i j", q, k) * scale
# attention
@@ -63,6 +62,6 @@ class Attend(nn.Module):
# aggregate values
out = einsum(f"b h i j, b h j d -> b h i d", attn, v)
out = einsum("b h i j, b h j d -> b h i d", attn, v)
return out
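A minimal usage sketch of the reformatted Attend module (shapes and values here are assumed, not part of the diff); per the assert above, flash=True additionally requires PyTorch 2.0 or later:

import torch
from bs_roformer.attend import Attend

attend = Attend(dropout=0.0, flash=False)  # flash=True needs torch >= 2.0
q = torch.randn(1, 8, 64, 64)  # (batch, heads, seq_len, dim_head)
k = torch.randn(1, 8, 64, 64)
v = torch.randn(1, 8, 64, 64)
out = attend(q, k, v)  # -> (1, 8, 64, 64)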

File: bs_roformer/bs_roformer.py

@@ -1,14 +1,14 @@
from functools import partial
import torch
from torch import nn, einsum, Tensor
from torch import nn
from torch.nn import Module, ModuleList
import torch.nn.functional as F
from bs_roformer.attend import Attend
from torch.utils.checkpoint import checkpoint
from typing import Tuple, Optional, List, Callable
from typing import Tuple, Optional, Callable
# from beartype.typing import Tuple, Optional, List, Callable
# from beartype import beartype
@@ -19,6 +19,7 @@ from einops.layers.torch import Rearrange
# helper functions
def exists(val):
return val is not None
@@ -37,14 +38,15 @@ def unpack_one(t, ps, pattern):
# norm
def l2norm(t):
return F.normalize(t, dim = -1, p = 2)
return F.normalize(t, dim=-1, p=2)
class RMSNorm(Module):
def __init__(self, dim):
super().__init__()
self.scale = dim ** 0.5
self.scale = dim**0.5
self.gamma = nn.Parameter(torch.ones(dim))
def forward(self, x):
@@ -53,13 +55,9 @@ class RMSNorm(Module):
# attention
class FeedForward(Module):
def __init__(
self,
dim,
mult=4,
dropout=0.
):
def __init__(self, dim, mult=4, dropout=0.0):
super().__init__()
dim_inner = int(dim * mult)
self.net = nn.Sequential(
@@ -68,7 +66,7 @@ class FeedForward(Module):
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(dim_inner, dim),
nn.Dropout(dropout)
nn.Dropout(dropout),
)
def forward(self, x):
@@ -76,18 +74,10 @@ class FeedForward(Module):
class Attention(Module):
def __init__(
self,
dim,
heads=8,
dim_head=64,
dropout=0.,
rotary_embed=None,
flash=True
):
def __init__(self, dim, heads=8, dim_head=64, dropout=0.0, rotary_embed=None, flash=True):
super().__init__()
self.heads = heads
self.scale = dim_head ** -0.5
self.scale = dim_head**-0.5
dim_inner = heads * dim_head
self.rotary_embed = rotary_embed
@@ -99,15 +89,12 @@ class Attention(Module):
self.to_gates = nn.Linear(dim, heads)
self.to_out = nn.Sequential(
nn.Linear(dim_inner, dim, bias=False),
nn.Dropout(dropout)
)
self.to_out = nn.Sequential(nn.Linear(dim_inner, dim, bias=False), nn.Dropout(dropout))
def forward(self, x):
x = self.norm(x)
q, k, v = rearrange(self.to_qkv(x), 'b n (qkv h d) -> qkv b h n d', qkv=3, h=self.heads)
q, k, v = rearrange(self.to_qkv(x), "b n (qkv h d) -> qkv b h n d", qkv=3, h=self.heads)
if exists(self.rotary_embed):
q = self.rotary_embed.rotate_queries_or_keys(q)
@@ -116,9 +103,9 @@ class Attention(Module):
out = self.attend(q, k, v)
gates = self.to_gates(x)
out = out * rearrange(gates, 'b n h -> b h n 1').sigmoid()
out = out * rearrange(gates, "b n h -> b h n 1").sigmoid()
out = rearrange(out, 'b h n d -> b n (h d)')
out = rearrange(out, "b h n d -> b n (h d)")
return self.to_out(out)
@@ -128,42 +115,22 @@ class LinearAttention(Module):
"""
# @beartype
def __init__(
self,
*,
dim,
dim_head=32,
heads=8,
scale=8,
flash=False,
dropout=0.
):
def __init__(self, *, dim, dim_head=32, heads=8, scale=8, flash=False, dropout=0.0):
super().__init__()
dim_inner = dim_head * heads
self.norm = RMSNorm(dim)
self.to_qkv = nn.Sequential(
nn.Linear(dim, dim_inner * 3, bias=False),
Rearrange('b n (qkv h d) -> qkv b h d n', qkv=3, h=heads)
nn.Linear(dim, dim_inner * 3, bias=False), Rearrange("b n (qkv h d) -> qkv b h d n", qkv=3, h=heads)
)
self.temperature = nn.Parameter(torch.ones(heads, 1, 1))
self.attend = Attend(
scale=scale,
dropout=dropout,
flash=flash
)
self.attend = Attend(scale=scale, dropout=dropout, flash=flash)
self.to_out = nn.Sequential(
Rearrange('b h d n -> b n (h d)'),
nn.Linear(dim_inner, dim, bias=False)
)
self.to_out = nn.Sequential(Rearrange("b h d n -> b n (h d)"), nn.Linear(dim_inner, dim, bias=False))
def forward(
self,
x
):
def forward(self, x):
x = self.norm(x)
q, k, v = self.to_qkv(x)
@@ -178,19 +145,19 @@ class LinearAttention(Module):
class Transformer(Module):
def __init__(
self,
*,
dim,
depth,
dim_head=64,
heads=8,
attn_dropout=0.,
ff_dropout=0.,
ff_mult=4,
norm_output=True,
rotary_embed=None,
flash_attn=True,
linear_attn=False
self,
*,
dim,
depth,
dim_head=64,
heads=8,
attn_dropout=0.0,
ff_dropout=0.0,
ff_mult=4,
norm_output=True,
rotary_embed=None,
flash_attn=True,
linear_attn=False,
):
super().__init__()
self.layers = ModuleList([])
@@ -199,18 +166,20 @@ class Transformer(Module):
if linear_attn:
attn = LinearAttention(dim=dim, dim_head=dim_head, heads=heads, dropout=attn_dropout, flash=flash_attn)
else:
attn = Attention(dim=dim, dim_head=dim_head, heads=heads, dropout=attn_dropout,
rotary_embed=rotary_embed, flash=flash_attn)
attn = Attention(
dim=dim,
dim_head=dim_head,
heads=heads,
dropout=attn_dropout,
rotary_embed=rotary_embed,
flash=flash_attn,
)
self.layers.append(ModuleList([
attn,
FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout)
]))
self.layers.append(ModuleList([attn, FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout)]))
self.norm = RMSNorm(dim) if norm_output else nn.Identity()
def forward(self, x):
for attn, ff in self.layers:
x = attn(x) + x
x = ff(x) + x
@@ -220,22 +189,16 @@ class Transformer(Module):
# bandsplit module
class BandSplit(Module):
# @beartype
def __init__(
self,
dim,
dim_inputs: Tuple[int, ...]
):
def __init__(self, dim, dim_inputs: Tuple[int, ...]):
super().__init__()
self.dim_inputs = dim_inputs
self.to_features = ModuleList([])
for dim_in in dim_inputs:
net = nn.Sequential(
RMSNorm(dim_in),
nn.Linear(dim_in, dim)
)
net = nn.Sequential(RMSNorm(dim_in), nn.Linear(dim_in, dim))
self.to_features.append(net)
@@ -250,13 +213,7 @@ class BandSplit(Module):
return torch.stack(outs, dim=-2)
def MLP(
dim_in,
dim_out,
dim_hidden=None,
depth=1,
activation=nn.Tanh
):
def MLP(dim_in, dim_out, dim_hidden=None, depth=1, activation=nn.Tanh):
dim_hidden = default(dim_hidden, dim_in)
net = []
@@ -277,13 +234,7 @@ def MLP(
class MaskEstimator(Module):
# @beartype
def __init__(
self,
dim,
dim_inputs: Tuple[int, ...],
depth,
mlp_expansion_factor=4
):
def __init__(self, dim, dim_inputs: Tuple[int, ...], depth, mlp_expansion_factor=4):
super().__init__()
self.dim_inputs = dim_inputs
self.to_freqs = ModuleList([])
@@ -292,10 +243,7 @@ class MaskEstimator(Module):
for dim_in in dim_inputs:
net = []
mlp = nn.Sequential(
MLP(dim, dim_in * 2, dim_hidden=dim_hidden, depth=depth),
nn.GLU(dim=-1)
)
mlp = nn.Sequential(MLP(dim, dim_in * 2, dim_hidden=dim_hidden, depth=depth), nn.GLU(dim=-1))
self.to_freqs.append(mlp)
@@ -314,53 +262,106 @@ class MaskEstimator(Module):
# main class
DEFAULT_FREQS_PER_BANDS = (
2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2,
4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
12, 12, 12, 12, 12, 12, 12, 12,
24, 24, 24, 24, 24, 24, 24, 24,
48, 48, 48, 48, 48, 48, 48, 48,
128, 129,
2,
2,
2,
2,
2,
2,
2,
2,
2,
2,
2,
2,
2,
2,
2,
2,
2,
2,
2,
2,
2,
2,
2,
2,
4,
4,
4,
4,
4,
4,
4,
4,
4,
4,
4,
4,
12,
12,
12,
12,
12,
12,
12,
12,
24,
24,
24,
24,
24,
24,
24,
24,
48,
48,
48,
48,
48,
48,
48,
48,
128,
129,
)
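# Aside (not part of the diff): a hedged sanity check that the default bands cover
# every STFT bin. With stft_n_fft=2048 there are 2048 // 2 + 1 = 1025 frequency bins,
# and the tuple above sums to exactly that, which is what the
# `sum(freqs_per_bands) == freqs` assert further down enforces.
assert sum(DEFAULT_FREQS_PER_BANDS) == 2048 // 2 + 1 == 1025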
class BSRoformer(Module):
# @beartype
def __init__(
self,
dim,
*,
depth,
stereo=False,
num_stems=1,
time_transformer_depth=2,
freq_transformer_depth=2,
linear_transformer_depth=0,
freqs_per_bands: Tuple[int, ...] = DEFAULT_FREQS_PER_BANDS,
# in the paper, they divide into ~60 bands, test with 1 for starters
dim_head=64,
heads=8,
attn_dropout=0.,
ff_dropout=0.,
flash_attn=True,
dim_freqs_in=1025,
stft_n_fft=2048,
stft_hop_length=512,
# 10ms at 44100Hz, from sections 4.1, 4.4 in the paper - @faroit recommends // 2 or // 4 for better reconstruction
stft_win_length=2048,
stft_normalized=False,
stft_window_fn: Optional[Callable] = None,
mask_estimator_depth=2,
multi_stft_resolution_loss_weight=1.,
multi_stft_resolutions_window_sizes: Tuple[int, ...] = (4096, 2048, 1024, 512, 256),
multi_stft_hop_size=147,
multi_stft_normalized=False,
multi_stft_window_fn: Callable = torch.hann_window,
mlp_expansion_factor=4,
use_torch_checkpoint=False,
skip_connection=False,
self,
dim,
*,
depth,
stereo=False,
num_stems=1,
time_transformer_depth=2,
freq_transformer_depth=2,
linear_transformer_depth=0,
freqs_per_bands: Tuple[int, ...] = DEFAULT_FREQS_PER_BANDS,
# in the paper, they divide into ~60 bands, test with 1 for starters
dim_head=64,
heads=8,
attn_dropout=0.0,
ff_dropout=0.0,
flash_attn=True,
dim_freqs_in=1025,
stft_n_fft=2048,
stft_hop_length=512,
# 10ms at 44100Hz, from sections 4.1, 4.4 in the paper - @faroit recommends // 2 or // 4 for better reconstruction
stft_win_length=2048,
stft_normalized=False,
stft_window_fn: Optional[Callable] = None,
mask_estimator_depth=2,
multi_stft_resolution_loss_weight=1.0,
multi_stft_resolutions_window_sizes: Tuple[int, ...] = (4096, 2048, 1024, 512, 256),
multi_stft_hop_size=147,
multi_stft_normalized=False,
multi_stft_window_fn: Callable = torch.hann_window,
mlp_expansion_factor=4,
use_torch_checkpoint=False,
skip_connection=False,
):
super().__init__()
@@ -379,7 +380,7 @@ class BSRoformer(Module):
attn_dropout=attn_dropout,
ff_dropout=ff_dropout,
flash_attn=flash_attn,
norm_output=False
norm_output=False,
)
time_rotary_embed = RotaryEmbedding(dim=dim_head)
@@ -400,26 +401,23 @@ class BSRoformer(Module):
self.final_norm = RMSNorm(dim)
self.stft_kwargs = dict(
n_fft=stft_n_fft,
hop_length=stft_hop_length,
win_length=stft_win_length,
normalized=stft_normalized
n_fft=stft_n_fft, hop_length=stft_hop_length, win_length=stft_win_length, normalized=stft_normalized
)
self.stft_window_fn = partial(default(stft_window_fn, torch.hann_window), stft_win_length)
freqs = torch.stft(torch.randn(1, 4096), **self.stft_kwargs, window=torch.ones(stft_win_length), return_complex=True).shape[1]
freqs = torch.stft(
torch.randn(1, 4096), **self.stft_kwargs, window=torch.ones(stft_win_length), return_complex=True
).shape[1]
assert len(freqs_per_bands) > 1
assert sum(
freqs_per_bands) == freqs, f'the number of freqs in the bands must equal {freqs} based on the STFT settings, but got {sum(freqs_per_bands)}'
assert sum(freqs_per_bands) == freqs, (
f"the number of freqs in the bands must equal {freqs} based on the STFT settings, but got {sum(freqs_per_bands)}"
)
freqs_per_bands_with_complex = tuple(2 * f * self.audio_channels for f in freqs_per_bands)
self.band_split = BandSplit(
dim=dim,
dim_inputs=freqs_per_bands_with_complex
)
self.band_split = BandSplit(dim=dim, dim_inputs=freqs_per_bands_with_complex)
self.mask_estimators = nn.ModuleList([])
@@ -440,17 +438,9 @@ class BSRoformer(Module):
self.multi_stft_n_fft = stft_n_fft
self.multi_stft_window_fn = multi_stft_window_fn
self.multi_stft_kwargs = dict(
hop_length=multi_stft_hop_size,
normalized=multi_stft_normalized
)
self.multi_stft_kwargs = dict(hop_length=multi_stft_hop_size, normalized=multi_stft_normalized)
def forward(
self,
raw_audio,
target=None,
return_loss_breakdown=False
):
def forward(self, raw_audio, target=None, return_loss_breakdown=False):
"""
einops
@@ -469,14 +459,16 @@ class BSRoformer(Module):
x_is_mps = True if device.type == "mps" else False
if raw_audio.ndim == 2:
raw_audio = rearrange(raw_audio, 'b t -> b 1 t')
raw_audio = rearrange(raw_audio, "b t -> b 1 t")
channels = raw_audio.shape[1]
assert (not self.stereo and channels == 1) or (self.stereo and channels == 2), 'stereo needs to be set to True if passing in audio signal that is stereo (channel dimension of 2). also need to be False if mono (channel dimension of 1)'
assert (not self.stereo and channels == 1) or (self.stereo and channels == 2), (
"stereo needs to be set to True if passing in audio signal that is stereo (channel dimension of 2). also need to be False if mono (channel dimension of 1)"
)
# to stft
raw_audio, batch_audio_channel_packed_shape = pack_one(raw_audio, '* t')
raw_audio, batch_audio_channel_packed_shape = pack_one(raw_audio, "* t")
stft_window = self.stft_window_fn(device=device)
@@ -485,16 +477,21 @@ class BSRoformer(Module):
try:
stft_repr = torch.stft(raw_audio, **self.stft_kwargs, window=stft_window, return_complex=True)
except:
stft_repr = torch.stft(raw_audio.cpu() if x_is_mps else raw_audio, **self.stft_kwargs, window=stft_window.cpu() if x_is_mps else stft_window, return_complex=True).to(device)
stft_repr = torch.stft(
raw_audio.cpu() if x_is_mps else raw_audio,
**self.stft_kwargs,
window=stft_window.cpu() if x_is_mps else stft_window,
return_complex=True,
).to(device)
stft_repr = torch.view_as_real(stft_repr)
stft_repr = unpack_one(stft_repr, batch_audio_channel_packed_shape, '* f t c')
stft_repr = unpack_one(stft_repr, batch_audio_channel_packed_shape, "* f t c")
# merge stereo / mono into the frequency, with frequency leading dimension, for band splitting
stft_repr = rearrange(stft_repr,'b s f t c -> b (f s) t c')
stft_repr = rearrange(stft_repr, "b s f t c -> b (f s) t c")
x = rearrange(stft_repr, 'b f t c -> b t (f c)')
x = rearrange(stft_repr, "b f t c -> b t (f c)")
if self.use_torch_checkpoint:
x = checkpoint(self.band_split, x, use_reentrant=False)
@@ -505,16 +502,15 @@ class BSRoformer(Module):
store = [None] * len(self.layers)
for i, transformer_block in enumerate(self.layers):
if len(transformer_block) == 3:
linear_transformer, time_transformer, freq_transformer = transformer_block
x, ft_ps = pack([x], 'b * d')
x, ft_ps = pack([x], "b * d")
if self.use_torch_checkpoint:
x = checkpoint(linear_transformer, x, use_reentrant=False)
else:
x = linear_transformer(x)
x, = unpack(x, ft_ps, 'b * d')
(x,) = unpack(x, ft_ps, "b * d")
else:
time_transformer, freq_transformer = transformer_block
@@ -523,24 +519,24 @@ class BSRoformer(Module):
for j in range(i):
x = x + store[j]
x = rearrange(x, 'b t f d -> b f t d')
x, ps = pack([x], '* t d')
x = rearrange(x, "b t f d -> b f t d")
x, ps = pack([x], "* t d")
if self.use_torch_checkpoint:
x = checkpoint(time_transformer, x, use_reentrant=False)
else:
x = time_transformer(x)
x, = unpack(x, ps, '* t d')
x = rearrange(x, 'b f t d -> b t f d')
x, ps = pack([x], '* f d')
(x,) = unpack(x, ps, "* t d")
x = rearrange(x, "b f t d -> b t f d")
x, ps = pack([x], "* f d")
if self.use_torch_checkpoint:
x = checkpoint(freq_transformer, x, use_reentrant=False)
else:
x = freq_transformer(x)
x, = unpack(x, ps, '* f d')
(x,) = unpack(x, ps, "* f d")
if self.skip_connection:
store[i] = x
@@ -553,11 +549,11 @@ class BSRoformer(Module):
mask = torch.stack([checkpoint(fn, x, use_reentrant=False) for fn in self.mask_estimators], dim=1)
else:
mask = torch.stack([fn(x) for fn in self.mask_estimators], dim=1)
mask = rearrange(mask, 'b n t (f c) -> b n f t c', c=2)
mask = rearrange(mask, "b n t (f c) -> b n f t c", c=2)
# modulate frequency representation
stft_repr = rearrange(stft_repr, 'b f t c -> b 1 f t c')
stft_repr = rearrange(stft_repr, "b f t c -> b 1 f t c")
# complex number multiplication
@@ -568,18 +564,26 @@ class BSRoformer(Module):
# istft
stft_repr = rearrange(stft_repr, 'b n (f s) t -> (b n s) f t', s=self.audio_channels)
stft_repr = rearrange(stft_repr, "b n (f s) t -> (b n s) f t", s=self.audio_channels)
# same as torch.stft() fix for MacOS MPS above
try:
recon_audio = torch.istft(stft_repr, **self.stft_kwargs, window=stft_window, return_complex=False, length=raw_audio.shape[-1])
recon_audio = torch.istft(
stft_repr, **self.stft_kwargs, window=stft_window, return_complex=False, length=raw_audio.shape[-1]
)
except:
recon_audio = torch.istft(stft_repr.cpu() if x_is_mps else stft_repr, **self.stft_kwargs, window=stft_window.cpu() if x_is_mps else stft_window, return_complex=False, length=raw_audio.shape[-1]).to(device)
recon_audio = torch.istft(
stft_repr.cpu() if x_is_mps else stft_repr,
**self.stft_kwargs,
window=stft_window.cpu() if x_is_mps else stft_window,
return_complex=False,
length=raw_audio.shape[-1],
).to(device)
recon_audio = rearrange(recon_audio, '(b n s) t -> b n s t', s=self.audio_channels, n=num_stems)
recon_audio = rearrange(recon_audio, "(b n s) t -> b n s t", s=self.audio_channels, n=num_stems)
if num_stems == 1:
recon_audio = rearrange(recon_audio, 'b 1 s t -> b s t')
recon_audio = rearrange(recon_audio, "b 1 s t -> b s t")
# if a target is passed in, calculate loss for learning
@@ -590,13 +594,13 @@ class BSRoformer(Module):
assert target.ndim == 4 and target.shape[1] == self.num_stems
if target.ndim == 2:
target = rearrange(target, '... t -> ... 1 t')
target = rearrange(target, "... t -> ... 1 t")
target = target[..., :recon_audio.shape[-1]] # protect against lost length on istft
target = target[..., : recon_audio.shape[-1]] # protect against lost length on istft
loss = F.l1_loss(recon_audio, target)
multi_stft_resolution_loss = 0.
multi_stft_resolution_loss = 0.0
for window_size in self.multi_stft_resolutions_window_sizes:
res_stft_kwargs = dict(
@@ -607,8 +611,8 @@ class BSRoformer(Module):
**self.multi_stft_kwargs,
)
recon_Y = torch.stft(rearrange(recon_audio, '... s t -> (... s) t'), **res_stft_kwargs)
target_Y = torch.stft(rearrange(target, '... s t -> (... s) t'), **res_stft_kwargs)
recon_Y = torch.stft(rearrange(recon_audio, "... s t -> (... s) t"), **res_stft_kwargs)
target_Y = torch.stft(rearrange(target, "... s t -> (... s) t"), **res_stft_kwargs)
multi_stft_resolution_loss = multi_stft_resolution_loss + F.l1_loss(recon_Y, target_Y)
@@ -619,4 +623,4 @@ class BSRoformer(Module):
if not return_loss_breakdown:
return total_loss
return total_loss, (loss, multi_stft_resolution_loss)
return total_loss, (loss, multi_stft_resolution_loss)
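A hedged usage sketch of the reformatted BSRoformer (dim, depth, and audio length are assumed values; the import path follows the package layout implied by the imports above):

import torch
from bs_roformer.bs_roformer import BSRoformer

model = BSRoformer(dim=192, depth=6, stereo=False, num_stems=1)
audio = torch.randn(1, 44100)  # (batch, time), mono
stems = model(audio)  # inference: returns the separated audio
loss = model(audio, target=torch.randn(1, 44100))  # training: combined L1 + weighted multi-STFT loss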

File: bs_roformer/mel_band_roformer.py

@@ -1,14 +1,14 @@
from functools import partial
import torch
from torch import nn, einsum, Tensor
from torch import nn
from torch.nn import Module, ModuleList
import torch.nn.functional as F
from bs_roformer.attend import Attend
from torch.utils.checkpoint import checkpoint
from typing import Tuple, Optional, List, Callable
from typing import Tuple, Optional, Callable
# from beartype.typing import Tuple, Optional, List, Callable
# from beartype import beartype
@@ -22,6 +22,7 @@ from librosa import filters
# helper functions
def exists(val):
return val is not None
@@ -38,9 +39,9 @@ def unpack_one(t, ps, pattern):
return unpack(t, ps, pattern)[0]
def pad_at_dim(t, pad, dim=-1, value=0.):
dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
zeros = ((0, 0) * dims_from_right)
def pad_at_dim(t, pad, dim=-1, value=0.0):
dims_from_right = (-dim - 1) if dim < 0 else (t.ndim - dim - 1)
zeros = (0, 0) * dims_from_right
return F.pad(t, (*zeros, *pad), value=value)
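# Aside (not part of the diff, values assumed): F.pad consumes (left, right) pairs
# from the last dimension backwards, so padding dim=-2 needs one leading (0, 0)
# pair for the last dim before the requested pad:
assert pad_at_dim(torch.randn(3, 5, 7), (2, 0), dim=-2).shape == (3, 7, 7)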
@@ -50,10 +51,11 @@ def l2norm(t):
# norm
class RMSNorm(Module):
def __init__(self, dim):
super().__init__()
self.scale = dim ** 0.5
self.scale = dim**0.5
self.gamma = nn.Parameter(torch.ones(dim))
def forward(self, x):
@@ -62,13 +64,9 @@ class RMSNorm(Module):
# attention
class FeedForward(Module):
def __init__(
self,
dim,
mult=4,
dropout=0.
):
def __init__(self, dim, mult=4, dropout=0.0):
super().__init__()
dim_inner = int(dim * mult)
self.net = nn.Sequential(
@@ -77,7 +75,7 @@ class FeedForward(Module):
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(dim_inner, dim),
nn.Dropout(dropout)
nn.Dropout(dropout),
)
def forward(self, x):
@@ -85,18 +83,10 @@ class FeedForward(Module):
class Attention(Module):
def __init__(
self,
dim,
heads=8,
dim_head=64,
dropout=0.,
rotary_embed=None,
flash=True
):
def __init__(self, dim, heads=8, dim_head=64, dropout=0.0, rotary_embed=None, flash=True):
super().__init__()
self.heads = heads
self.scale = dim_head ** -0.5
self.scale = dim_head**-0.5
dim_inner = heads * dim_head
self.rotary_embed = rotary_embed
@@ -108,15 +98,12 @@ class Attention(Module):
self.to_gates = nn.Linear(dim, heads)
self.to_out = nn.Sequential(
nn.Linear(dim_inner, dim, bias=False),
nn.Dropout(dropout)
)
self.to_out = nn.Sequential(nn.Linear(dim_inner, dim, bias=False), nn.Dropout(dropout))
def forward(self, x):
x = self.norm(x)
q, k, v = rearrange(self.to_qkv(x), 'b n (qkv h d) -> qkv b h n d', qkv=3, h=self.heads)
q, k, v = rearrange(self.to_qkv(x), "b n (qkv h d) -> qkv b h n d", qkv=3, h=self.heads)
if exists(self.rotary_embed):
q = self.rotary_embed.rotate_queries_or_keys(q)
@@ -125,9 +112,9 @@ class Attention(Module):
out = self.attend(q, k, v)
gates = self.to_gates(x)
out = out * rearrange(gates, 'b n h -> b h n 1').sigmoid()
out = out * rearrange(gates, "b n h -> b h n 1").sigmoid()
out = rearrange(out, 'b h n d -> b n (h d)')
out = rearrange(out, "b h n d -> b n (h d)")
return self.to_out(out)
@@ -137,42 +124,22 @@ class LinearAttention(Module):
"""
# @beartype
def __init__(
self,
*,
dim,
dim_head=32,
heads=8,
scale=8,
flash=False,
dropout=0.
):
def __init__(self, *, dim, dim_head=32, heads=8, scale=8, flash=False, dropout=0.0):
super().__init__()
dim_inner = dim_head * heads
self.norm = RMSNorm(dim)
self.to_qkv = nn.Sequential(
nn.Linear(dim, dim_inner * 3, bias=False),
Rearrange('b n (qkv h d) -> qkv b h d n', qkv=3, h=heads)
nn.Linear(dim, dim_inner * 3, bias=False), Rearrange("b n (qkv h d) -> qkv b h d n", qkv=3, h=heads)
)
self.temperature = nn.Parameter(torch.ones(heads, 1, 1))
self.attend = Attend(
scale=scale,
dropout=dropout,
flash=flash
)
self.attend = Attend(scale=scale, dropout=dropout, flash=flash)
self.to_out = nn.Sequential(
Rearrange('b h d n -> b n (h d)'),
nn.Linear(dim_inner, dim, bias=False)
)
self.to_out = nn.Sequential(Rearrange("b h d n -> b n (h d)"), nn.Linear(dim_inner, dim, bias=False))
def forward(
self,
x
):
def forward(self, x):
x = self.norm(x)
q, k, v = self.to_qkv(x)
@@ -187,19 +154,19 @@ class LinearAttention(Module):
class Transformer(Module):
def __init__(
self,
*,
dim,
depth,
dim_head=64,
heads=8,
attn_dropout=0.,
ff_dropout=0.,
ff_mult=4,
norm_output=True,
rotary_embed=None,
flash_attn=True,
linear_attn=False
self,
*,
dim,
depth,
dim_head=64,
heads=8,
attn_dropout=0.0,
ff_dropout=0.0,
ff_mult=4,
norm_output=True,
rotary_embed=None,
flash_attn=True,
linear_attn=False,
):
super().__init__()
self.layers = ModuleList([])
@@ -208,18 +175,20 @@ class Transformer(Module):
if linear_attn:
attn = LinearAttention(dim=dim, dim_head=dim_head, heads=heads, dropout=attn_dropout, flash=flash_attn)
else:
attn = Attention(dim=dim, dim_head=dim_head, heads=heads, dropout=attn_dropout,
rotary_embed=rotary_embed, flash=flash_attn)
attn = Attention(
dim=dim,
dim_head=dim_head,
heads=heads,
dropout=attn_dropout,
rotary_embed=rotary_embed,
flash=flash_attn,
)
self.layers.append(ModuleList([
attn,
FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout)
]))
self.layers.append(ModuleList([attn, FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout)]))
self.norm = RMSNorm(dim) if norm_output else nn.Identity()
def forward(self, x):
for attn, ff in self.layers:
x = attn(x) + x
x = ff(x) + x
@@ -229,22 +198,16 @@ class Transformer(Module):
# bandsplit module
class BandSplit(Module):
# @beartype
def __init__(
self,
dim,
dim_inputs: Tuple[int, ...]
):
def __init__(self, dim, dim_inputs: Tuple[int, ...]):
super().__init__()
self.dim_inputs = dim_inputs
self.to_features = ModuleList([])
for dim_in in dim_inputs:
net = nn.Sequential(
RMSNorm(dim_in),
nn.Linear(dim_in, dim)
)
net = nn.Sequential(RMSNorm(dim_in), nn.Linear(dim_in, dim))
self.to_features.append(net)
@@ -259,13 +222,7 @@ class BandSplit(Module):
return torch.stack(outs, dim=-2)
def MLP(
dim_in,
dim_out,
dim_hidden=None,
depth=1,
activation=nn.Tanh
):
def MLP(dim_in, dim_out, dim_hidden=None, depth=1, activation=nn.Tanh):
dim_hidden = default(dim_hidden, dim_in)
net = []
@@ -286,13 +243,7 @@ def MLP(
class MaskEstimator(Module):
# @beartype
def __init__(
self,
dim,
dim_inputs: Tuple[int, ...],
depth,
mlp_expansion_factor=4
):
def __init__(self, dim, dim_inputs: Tuple[int, ...], depth, mlp_expansion_factor=4):
super().__init__()
self.dim_inputs = dim_inputs
self.to_freqs = ModuleList([])
@@ -301,10 +252,7 @@ class MaskEstimator(Module):
for dim_in in dim_inputs:
net = []
mlp = nn.Sequential(
MLP(dim, dim_in * 2, dim_hidden=dim_hidden, depth=depth),
nn.GLU(dim=-1)
)
mlp = nn.Sequential(MLP(dim, dim_in * 2, dim_hidden=dim_hidden, depth=depth), nn.GLU(dim=-1))
self.to_freqs.append(mlp)
@@ -322,43 +270,43 @@ class MaskEstimator(Module):
# main class
class MelBandRoformer(Module):
# @beartype
def __init__(
self,
dim,
*,
depth,
stereo=False,
num_stems=1,
time_transformer_depth=2,
freq_transformer_depth=2,
linear_transformer_depth=0,
num_bands=60,
dim_head=64,
heads=8,
attn_dropout=0.1,
ff_dropout=0.1,
flash_attn=True,
dim_freqs_in=1025,
sample_rate=44100, # needed for mel filter bank from librosa
stft_n_fft=2048,
stft_hop_length=512,
# 10ms at 44100Hz, from sections 4.1, 4.4 in the paper - @faroit recommends // 2 or // 4 for better reconstruction
stft_win_length=2048,
stft_normalized=False,
stft_window_fn: Optional[Callable] = None,
mask_estimator_depth=1,
multi_stft_resolution_loss_weight=1.,
multi_stft_resolutions_window_sizes: Tuple[int, ...] = (4096, 2048, 1024, 512, 256),
multi_stft_hop_size=147,
multi_stft_normalized=False,
multi_stft_window_fn: Callable = torch.hann_window,
match_input_audio_length=False, # if True, pad output tensor to match length of input tensor
mlp_expansion_factor=4,
use_torch_checkpoint=False,
skip_connection=False,
self,
dim,
*,
depth,
stereo=False,
num_stems=1,
time_transformer_depth=2,
freq_transformer_depth=2,
linear_transformer_depth=0,
num_bands=60,
dim_head=64,
heads=8,
attn_dropout=0.1,
ff_dropout=0.1,
flash_attn=True,
dim_freqs_in=1025,
sample_rate=44100, # needed for mel filter bank from librosa
stft_n_fft=2048,
stft_hop_length=512,
# 10ms at 44100Hz, from sections 4.1, 4.4 in the paper - @faroit recommends // 2 or // 4 for better reconstruction
stft_win_length=2048,
stft_normalized=False,
stft_window_fn: Optional[Callable] = None,
mask_estimator_depth=1,
multi_stft_resolution_loss_weight=1.0,
multi_stft_resolutions_window_sizes: Tuple[int, ...] = (4096, 2048, 1024, 512, 256),
multi_stft_hop_size=147,
multi_stft_normalized=False,
multi_stft_window_fn: Callable = torch.hann_window,
match_input_audio_length=False, # if True, pad output tensor to match length of input tensor
mlp_expansion_factor=4,
use_torch_checkpoint=False,
skip_connection=False,
):
super().__init__()
@@ -376,7 +324,7 @@ class MelBandRoformer(Module):
dim_head=dim_head,
attn_dropout=attn_dropout,
ff_dropout=ff_dropout,
flash_attn=flash_attn
flash_attn=flash_attn,
)
time_rotary_embed = RotaryEmbedding(dim=dim_head)
@@ -397,13 +345,12 @@ class MelBandRoformer(Module):
self.stft_window_fn = partial(default(stft_window_fn, torch.hann_window), stft_win_length)
self.stft_kwargs = dict(
n_fft=stft_n_fft,
hop_length=stft_hop_length,
win_length=stft_win_length,
normalized=stft_normalized
n_fft=stft_n_fft, hop_length=stft_hop_length, win_length=stft_win_length, normalized=stft_normalized
)
freqs = torch.stft(torch.randn(1, 4096), **self.stft_kwargs, window=torch.ones(stft_n_fft), return_complex=True).shape[1]
freqs = torch.stft(
torch.randn(1, 4096), **self.stft_kwargs, window=torch.ones(stft_n_fft), return_complex=True
).shape[1]
# create mel filter bank
# with librosa.filters.mel as in section 2 of paper
@@ -414,43 +361,40 @@ class MelBandRoformer(Module):
# for some reason, it doesn't include the first freq? just force a value for now
mel_filter_bank[0][0] = 1.
mel_filter_bank[0][0] = 1.0
# In some systems/envs we get 0.0 instead of ~1.9e-18 in the last position,
# so let's force a positive value
mel_filter_bank[-1, -1] = 1.
mel_filter_bank[-1, -1] = 1.0
# binary as in paper (then estimated masks are averaged for overlapping regions)
freqs_per_band = mel_filter_bank > 0
assert freqs_per_band.any(dim=0).all(), 'all frequencies need to be covered by all bands for now'
assert freqs_per_band.any(dim=0).all(), "all frequencies need to be covered by all bands for now"
repeated_freq_indices = repeat(torch.arange(freqs), 'f -> b f', b=num_bands)
repeated_freq_indices = repeat(torch.arange(freqs), "f -> b f", b=num_bands)
freq_indices = repeated_freq_indices[freqs_per_band]
if stereo:
freq_indices = repeat(freq_indices, 'f -> f s', s=2)
freq_indices = repeat(freq_indices, "f -> f s", s=2)
freq_indices = freq_indices * 2 + torch.arange(2)
freq_indices = rearrange(freq_indices, 'f s -> (f s)')
freq_indices = rearrange(freq_indices, "f s -> (f s)")
self.register_buffer('freq_indices', freq_indices, persistent=False)
self.register_buffer('freqs_per_band', freqs_per_band, persistent=False)
self.register_buffer("freq_indices", freq_indices, persistent=False)
self.register_buffer("freqs_per_band", freqs_per_band, persistent=False)
num_freqs_per_band = reduce(freqs_per_band, 'b f -> b', 'sum')
num_bands_per_freq = reduce(freqs_per_band, 'b f -> f', 'sum')
num_freqs_per_band = reduce(freqs_per_band, "b f -> b", "sum")
num_bands_per_freq = reduce(freqs_per_band, "b f -> f", "sum")
self.register_buffer('num_freqs_per_band', num_freqs_per_band, persistent=False)
self.register_buffer('num_bands_per_freq', num_bands_per_freq, persistent=False)
self.register_buffer("num_freqs_per_band", num_freqs_per_band, persistent=False)
self.register_buffer("num_bands_per_freq", num_bands_per_freq, persistent=False)
# band split and mask estimator
freqs_per_bands_with_complex = tuple(2 * f * self.audio_channels for f in num_freqs_per_band.tolist())
self.band_split = BandSplit(
dim=dim,
dim_inputs=freqs_per_bands_with_complex
)
self.band_split = BandSplit(dim=dim, dim_inputs=freqs_per_bands_with_complex)
self.mask_estimators = nn.ModuleList([])
@@ -471,19 +415,11 @@ class MelBandRoformer(Module):
self.multi_stft_n_fft = stft_n_fft
self.multi_stft_window_fn = multi_stft_window_fn
self.multi_stft_kwargs = dict(
hop_length=multi_stft_hop_size,
normalized=multi_stft_normalized
)
self.multi_stft_kwargs = dict(hop_length=multi_stft_hop_size, normalized=multi_stft_normalized)
self.match_input_audio_length = match_input_audio_length
def forward(
self,
raw_audio,
target=None,
return_loss_breakdown=False
):
def forward(self, raw_audio, target=None, return_loss_breakdown=False):
"""
einops
@@ -499,28 +435,29 @@ class MelBandRoformer(Module):
device = raw_audio.device
if raw_audio.ndim == 2:
raw_audio = rearrange(raw_audio, 'b t -> b 1 t')
raw_audio = rearrange(raw_audio, "b t -> b 1 t")
batch, channels, raw_audio_length = raw_audio.shape
istft_length = raw_audio_length if self.match_input_audio_length else None
assert (not self.stereo and channels == 1) or (
self.stereo and channels == 2), 'stereo needs to be set to True if passing in audio signal that is stereo (channel dimension of 2). also need to be False if mono (channel dimension of 1)'
assert (not self.stereo and channels == 1) or (self.stereo and channels == 2), (
"stereo needs to be set to True if passing in audio signal that is stereo (channel dimension of 2). also need to be False if mono (channel dimension of 1)"
)
# to stft
raw_audio, batch_audio_channel_packed_shape = pack_one(raw_audio, '* t')
raw_audio, batch_audio_channel_packed_shape = pack_one(raw_audio, "* t")
stft_window = self.stft_window_fn(device=device)
stft_repr = torch.stft(raw_audio, **self.stft_kwargs, window=stft_window, return_complex=True)
stft_repr = torch.view_as_real(stft_repr)
stft_repr = unpack_one(stft_repr, batch_audio_channel_packed_shape, '* f t c')
stft_repr = unpack_one(stft_repr, batch_audio_channel_packed_shape, "* f t c")
# merge stereo / mono into the frequency, with frequency leading dimension, for band splitting
stft_repr = rearrange(stft_repr,'b s f t c -> b (f s) t c')
stft_repr = rearrange(stft_repr, "b s f t c -> b (f s) t c")
# index out all frequencies for all frequency ranges across bands ascending in one go
@@ -532,7 +469,7 @@ class MelBandRoformer(Module):
# fold the complex (real and imag) into the frequencies dimension
x = rearrange(x, 'b f t c -> b t (f c)')
x = rearrange(x, "b f t c -> b t (f c)")
if self.use_torch_checkpoint:
x = checkpoint(self.band_split, x, use_reentrant=False)
@@ -543,16 +480,15 @@ class MelBandRoformer(Module):
store = [None] * len(self.layers)
for i, transformer_block in enumerate(self.layers):
if len(transformer_block) == 3:
linear_transformer, time_transformer, freq_transformer = transformer_block
x, ft_ps = pack([x], 'b * d')
x, ft_ps = pack([x], "b * d")
if self.use_torch_checkpoint:
x = checkpoint(linear_transformer, x, use_reentrant=False)
else:
x = linear_transformer(x)
x, = unpack(x, ft_ps, 'b * d')
(x,) = unpack(x, ft_ps, "b * d")
else:
time_transformer, freq_transformer = transformer_block
@@ -561,24 +497,24 @@ class MelBandRoformer(Module):
for j in range(i):
x = x + store[j]
x = rearrange(x, 'b t f d -> b f t d')
x, ps = pack([x], '* t d')
x = rearrange(x, "b t f d -> b f t d")
x, ps = pack([x], "* t d")
if self.use_torch_checkpoint:
x = checkpoint(time_transformer, x, use_reentrant=False)
else:
x = time_transformer(x)
x, = unpack(x, ps, '* t d')
x = rearrange(x, 'b f t d -> b t f d')
x, ps = pack([x], '* f d')
(x,) = unpack(x, ps, "* t d")
x = rearrange(x, "b f t d -> b t f d")
x, ps = pack([x], "* f d")
if self.use_torch_checkpoint:
x = checkpoint(freq_transformer, x, use_reentrant=False)
else:
x = freq_transformer(x)
x, = unpack(x, ps, '* f d')
(x,) = unpack(x, ps, "* f d")
if self.skip_connection:
store[i] = x
@@ -588,11 +524,11 @@ class MelBandRoformer(Module):
masks = torch.stack([checkpoint(fn, x, use_reentrant=False) for fn in self.mask_estimators], dim=1)
else:
masks = torch.stack([fn(x) for fn in self.mask_estimators], dim=1)
masks = rearrange(masks, 'b n t (f c) -> b n f t c', c=2)
masks = rearrange(masks, "b n t (f c) -> b n f t c", c=2)
# modulate frequency representation
stft_repr = rearrange(stft_repr, 'b f t c -> b 1 f t c')
stft_repr = rearrange(stft_repr, "b f t c -> b 1 f t c")
# complex number multiplication
@@ -603,12 +539,12 @@ class MelBandRoformer(Module):
# need to average the estimated mask for the overlapped frequencies
scatter_indices = repeat(self.freq_indices, 'f -> b n f t', b=batch, n=num_stems, t=stft_repr.shape[-1])
scatter_indices = repeat(self.freq_indices, "f -> b n f t", b=batch, n=num_stems, t=stft_repr.shape[-1])
stft_repr_expanded_stems = repeat(stft_repr, 'b 1 ... -> b n ...', n=num_stems)
stft_repr_expanded_stems = repeat(stft_repr, "b 1 ... -> b n ...", n=num_stems)
masks_summed = torch.zeros_like(stft_repr_expanded_stems).scatter_add_(2, scatter_indices, masks)
denom = repeat(self.num_bands_per_freq, 'f -> (f r) 1', r=channels)
denom = repeat(self.num_bands_per_freq, "f -> (f r) 1", r=channels)
masks_averaged = masks_summed / denom.clamp(min=1e-8)
@@ -618,15 +554,16 @@ class MelBandRoformer(Module):
# istft
stft_repr = rearrange(stft_repr, 'b n (f s) t -> (b n s) f t', s=self.audio_channels)
stft_repr = rearrange(stft_repr, "b n (f s) t -> (b n s) f t", s=self.audio_channels)
recon_audio = torch.istft(stft_repr, **self.stft_kwargs, window=stft_window, return_complex=False,
length=istft_length)
recon_audio = torch.istft(
stft_repr, **self.stft_kwargs, window=stft_window, return_complex=False, length=istft_length
)
recon_audio = rearrange(recon_audio, '(b n s) t -> b n s t', b=batch, s=self.audio_channels, n=num_stems)
recon_audio = rearrange(recon_audio, "(b n s) t -> b n s t", b=batch, s=self.audio_channels, n=num_stems)
if num_stems == 1:
recon_audio = rearrange(recon_audio, 'b 1 s t -> b s t')
recon_audio = rearrange(recon_audio, "b 1 s t -> b s t")
# if a target is passed in, calculate loss for learning
@@ -637,13 +574,13 @@ class MelBandRoformer(Module):
assert target.ndim == 4 and target.shape[1] == self.num_stems
if target.ndim == 2:
target = rearrange(target, '... t -> ... 1 t')
target = rearrange(target, "... t -> ... 1 t")
target = target[..., :recon_audio.shape[-1]] # protect against lost length on istft
target = target[..., : recon_audio.shape[-1]] # protect against lost length on istft
loss = F.l1_loss(recon_audio, target)
multi_stft_resolution_loss = 0.
multi_stft_resolution_loss = 0.0
for window_size in self.multi_stft_resolutions_window_sizes:
res_stft_kwargs = dict(
@@ -654,8 +591,8 @@ class MelBandRoformer(Module):
**self.multi_stft_kwargs,
)
recon_Y = torch.stft(rearrange(recon_audio, '... s t -> (... s) t'), **res_stft_kwargs)
target_Y = torch.stft(rearrange(target, '... s t -> (... s) t'), **res_stft_kwargs)
recon_Y = torch.stft(rearrange(recon_audio, "... s t -> (... s) t"), **res_stft_kwargs)
target_Y = torch.stft(rearrange(target, "... s t -> (... s) t"), **res_stft_kwargs)
multi_stft_resolution_loss = multi_stft_resolution_loss + F.l1_loss(recon_Y, target_Y)
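To summarize the loop above, a hedged, self-contained sketch of the multi-resolution STFT loss (window sizes, hop, and the n_fft floor follow the constructor defaults shown earlier; everything else is assumed):

import torch
import torch.nn.functional as F

def multi_res_stft_l1(recon, target, window_sizes=(4096, 2048, 1024, 512, 256), hop_length=147, n_fft_floor=2048):
    # recon, target: (batch, time) waveforms; both are re-analysed at each resolution
    loss = 0.0
    for win in window_sizes:
        kwargs = dict(
            n_fft=max(win, n_fft_floor),  # never analyse below the model's stft_n_fft
            win_length=win,
            hop_length=hop_length,
            window=torch.hann_window(win),
            return_complex=True,
        )
        # F.l1_loss accepts complex inputs, matching the complex STFTs in the diffed code
        loss = loss + F.l1_loss(torch.stft(recon, **kwargs), torch.stft(target, **kwargs))
    return loss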