Support for mel_band_roformer (#2078)

* Support for mel_band_roformer

* Remove unnecessary audio channel checks

* Remove context manager and fix path

* Update webui.py

* Update README.md
Sucial
2025-02-23 20:28:53 +08:00
committed by GitHub
parent fbb9f21e53
commit e061e9d38e
10 changed files with 941 additions and 176 deletions


@@ -6,6 +6,7 @@ from torch.nn import Module, ModuleList
 import torch.nn.functional as F
 from bs_roformer.attend import Attend
+from torch.utils.checkpoint import checkpoint
 from typing import Tuple, Optional, List, Callable
 # from beartype.typing import Tuple, Optional, List, Callable
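
Note on the new import: torch.utils.checkpoint.checkpoint trades compute for memory by not storing activations inside the wrapped call and re-running its forward during backward. A minimal self-contained sketch of the semantics (module and shapes are illustrative, not from this repo):

    import torch
    from torch import nn
    from torch.utils.checkpoint import checkpoint

    block = nn.Sequential(nn.Linear(64, 64), nn.GELU(), nn.Linear(64, 64))
    x = torch.randn(8, 64, requires_grad=True)
    y = checkpoint(block, x, use_reentrant=False)  # block's forward is re-run during backward
    y.sum().backward()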
@@ -356,13 +357,18 @@ class BSRoformer(Module):
             multi_stft_resolutions_window_sizes: Tuple[int, ...] = (4096, 2048, 1024, 512, 256),
             multi_stft_hop_size=147,
             multi_stft_normalized=False,
-            multi_stft_window_fn: Callable = torch.hann_window
+            multi_stft_window_fn: Callable = torch.hann_window,
+            mlp_expansion_factor=4,
+            use_torch_checkpoint=False,
+            skip_connection=False,
     ):
         super().__init__()

         self.stereo = stereo
         self.audio_channels = 2 if stereo else 1
         self.num_stems = num_stems
+        self.use_torch_checkpoint = use_torch_checkpoint
+        self.skip_connection = skip_connection

         self.layers = ModuleList([])
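
The three new constructor arguments default to off (and 4 for the expansion factor), so existing call sites keep working. A hypothetical instantiation showing the added knobs (the other arguments and all values here are illustrative, not a config from this PR):

    model = BSRoformer(
        dim=512,
        depth=12,
        stereo=True,
        num_stems=1,
        mlp_expansion_factor=4,     # hidden-width multiplier for the MaskEstimator MLP
        use_torch_checkpoint=True,  # recompute activations during backward to save VRAM
        skip_connection=True,       # each block also receives summed outputs of earlier blocks
    )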
@@ -402,7 +408,7 @@ class BSRoformer(Module):
         self.stft_window_fn = partial(default(stft_window_fn, torch.hann_window), stft_win_length)

-        freqs = torch.stft(torch.randn(1, 4096), **self.stft_kwargs, return_complex=True).shape[1]
+        freqs = torch.stft(torch.randn(1, 4096), **self.stft_kwargs, window=torch.ones(stft_win_length), return_complex=True).shape[1]

         assert len(freqs_per_bands) > 1
         assert sum(
@@ -421,7 +427,8 @@ class BSRoformer(Module):
             mask_estimator = MaskEstimator(
                 dim=dim,
                 dim_inputs=freqs_per_bands_with_complex,
-                depth=mask_estimator_depth
+                depth=mask_estimator_depth,
+                mlp_expansion_factor=mlp_expansion_factor,
             )

             self.mask_estimators.append(mask_estimator)
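
mlp_expansion_factor is threaded through to MaskEstimator, where an expansion factor of this kind conventionally sets the hidden width of the estimator's MLP to dim * factor. A sketch of that idea only (the wiring below is illustrative, not the repo's exact MaskEstimator):

    from torch import nn

    def mlp(dim_in, dim_out, expansion_factor=4):
        dim_hidden = dim_in * expansion_factor  # wider hidden layer for larger factors
        return nn.Sequential(
            nn.Linear(dim_in, dim_hidden),
            nn.Tanh(),
            nn.Linear(dim_hidden, dim_out),
        )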
@@ -458,12 +465,14 @@ class BSRoformer(Module):
         device = raw_audio.device

+        # defining whether model is loaded on MPS (MacOS GPU accelerator)
+        x_is_mps = True if device.type == "mps" else False

         if raw_audio.ndim == 2:
             raw_audio = rearrange(raw_audio, 'b t -> b 1 t')

         channels = raw_audio.shape[1]
-        assert (not self.stereo and channels == 1) or (
-                self.stereo and channels == 2), 'stereo needs to be set to True if passing in audio signal that is stereo (channel dimension of 2). also need to be False if mono (channel dimension of 1)'
+        assert (not self.stereo and channels == 1) or (self.stereo and channels == 2), 'stereo needs to be set to True if passing in audio signal that is stereo (channel dimension of 2). also need to be False if mono (channel dimension of 1)'

         # to stft
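
The x_is_mps flag set above drives the CPU fallbacks added in the hunks below: PyTorch's FFT ops are only supported on MPS from macOS 14 onward, so the code retries on CPU and moves the result back to the original device. The shape of that pattern as a standalone sketch (helper name is hypothetical):

    import torch

    def stft_with_mps_fallback(x, window, **stft_kwargs):
        # try the native device first; on failure (e.g. MPS on macOS < 14)
        # compute on CPU and move the result back
        try:
            return torch.stft(x, window=window, return_complex=True, **stft_kwargs)
        except RuntimeError:
            return torch.stft(x.cpu(), window=window.cpu(), return_complex=True,
                              **stft_kwargs).to(x.device)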
@@ -471,53 +480,79 @@ class BSRoformer(Module):
         stft_window = self.stft_window_fn(device=device)

-        stft_repr = torch.stft(raw_audio, **self.stft_kwargs, window=stft_window, return_complex=True)
+        # RuntimeError: FFT operations are only supported on MacOS 14+
+        # Since it's tedious to define whether we're on correct MacOS version - simple try-catch is used
+        try:
+            stft_repr = torch.stft(raw_audio, **self.stft_kwargs, window=stft_window, return_complex=True)
+        except:
+            stft_repr = torch.stft(raw_audio.cpu() if x_is_mps else raw_audio, **self.stft_kwargs, window=stft_window.cpu() if x_is_mps else stft_window, return_complex=True).to(device)
         stft_repr = torch.view_as_real(stft_repr)

         stft_repr = unpack_one(stft_repr, batch_audio_channel_packed_shape, '* f t c')
-        stft_repr = rearrange(stft_repr,
-                              'b s f t c -> b (f s) t c')  # merge stereo / mono into the frequency, with frequency leading dimension, for band splitting
+        # merge stereo / mono into the frequency, with frequency leading dimension, for band splitting
+        stft_repr = rearrange(stft_repr, 'b s f t c -> b (f s) t c')

         x = rearrange(stft_repr, 'b f t c -> b t (f c)')
-        # print("460:", x.dtype)#fp32

-        x = self.band_split(x)
+        if self.use_torch_checkpoint:
+            x = checkpoint(self.band_split, x, use_reentrant=False)
+        else:
+            x = self.band_split(x)

         # axial / hierarchical attention
-        # print("487:",x.dtype)#fp16
-        for transformer_block in self.layers:
+        store = [None] * len(self.layers)
+        for i, transformer_block in enumerate(self.layers):

             if len(transformer_block) == 3:
                 linear_transformer, time_transformer, freq_transformer = transformer_block

                 x, ft_ps = pack([x], 'b * d')
-                # print("494:", x.dtype)#fp16
-                x = linear_transformer(x)
-                # print("496:", x.dtype)#fp16
+                if self.use_torch_checkpoint:
+                    x = checkpoint(linear_transformer, x, use_reentrant=False)
+                else:
+                    x = linear_transformer(x)
                 x, = unpack(x, ft_ps, 'b * d')
             else:
                 time_transformer, freq_transformer = transformer_block
-                # print("501:", x.dtype)#fp16

+            if self.skip_connection:
+                # Sum all previous
+                for j in range(i):
+                    x = x + store[j]

             x = rearrange(x, 'b t f d -> b f t d')
             x, ps = pack([x], '* t d')

-            x = time_transformer(x)
-            # print("505:", x.dtype)#fp16
+            if self.use_torch_checkpoint:
+                x = checkpoint(time_transformer, x, use_reentrant=False)
+            else:
+                x = time_transformer(x)

             x, = unpack(x, ps, '* t d')
             x = rearrange(x, 'b f t d -> b t f d')
             x, ps = pack([x], '* f d')

-            x = freq_transformer(x)
+            if self.use_torch_checkpoint:
+                x = checkpoint(freq_transformer, x, use_reentrant=False)
+            else:
+                x = freq_transformer(x)

             x, = unpack(x, ps, '* f d')
-            # print("515:", x.dtype)######fp16

+            if self.skip_connection:
+                store[i] = x

         x = self.final_norm(x)

         num_stems = len(self.mask_estimators)
-        # print("519:", x.dtype)#fp32
-        mask = torch.stack([fn(x) for fn in self.mask_estimators], dim=1)
+        if self.use_torch_checkpoint:
+            mask = torch.stack([checkpoint(fn, x, use_reentrant=False) for fn in self.mask_estimators], dim=1)
+        else:
+            mask = torch.stack([fn(x) for fn in self.mask_estimators], dim=1)
         mask = rearrange(mask, 'b n t (f c) -> b n f t c', c=2)

         # modulate frequency representation
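
With skip_connection=True every block output is stored, and block i consumes its input plus the summed outputs of all earlier blocks, a densely connected residual pattern. Isolated from the band/time/freq plumbing above, it reduces to this (modules are illustrative):

    import torch
    from torch import nn

    layers = nn.ModuleList(nn.Linear(64, 64) for _ in range(4))
    x = torch.randn(2, 64)

    store = [None] * len(layers)
    for i, layer in enumerate(layers):
        for j in range(i):      # sum outputs of all previous layers into the input
            x = x + store[j]
        x = layer(x)
        store[i] = x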
@@ -535,7 +570,11 @@ class BSRoformer(Module):
         stft_repr = rearrange(stft_repr, 'b n (f s) t -> (b n s) f t', s=self.audio_channels)

-        recon_audio = torch.istft(stft_repr, **self.stft_kwargs, window=stft_window, return_complex=False)
+        # same as torch.stft() fix for MacOS MPS above
+        try:
+            recon_audio = torch.istft(stft_repr, **self.stft_kwargs, window=stft_window, return_complex=False, length=raw_audio.shape[-1])
+        except:
+            recon_audio = torch.istft(stft_repr.cpu() if x_is_mps else stft_repr, **self.stft_kwargs, window=stft_window.cpu() if x_is_mps else stft_window, return_complex=False, length=raw_audio.shape[-1]).to(device)

         recon_audio = rearrange(recon_audio, '(b n s) t -> b n s t', s=self.audio_channels, n=num_stems)
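
For callers the forward path is unchanged apart from the exact-length reconstruction: length=raw_audio.shape[-1] keeps the output as long as the input. A hypothetical invocation (shapes illustrative; the exact output shape depends on num_stems):

    # model: a BSRoformer(..., stereo=True) as sketched above
    raw_audio = torch.randn(1, 2, 44100 * 4)  # (batch, channels, time): 4 s of stereo
    with torch.no_grad():
        stems = model(raw_audio)              # per-stem waveforms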