Support for mel_band_roformer (#2078)
* support for mel_band_roformer
* Remove unnecessary audio channel judgments
* remove context manager and fix path
* Update webui.py
* Update README.md
This commit is contained in:
@@ -6,6 +6,7 @@ from torch.nn import Module, ModuleList
|
||||
import torch.nn.functional as F
|
||||
|
||||
from bs_roformer.attend import Attend
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
|
||||
from typing import Tuple, Optional, List, Callable
|
||||
# from beartype.typing import Tuple, Optional, List, Callable
|
||||
@@ -356,13 +357,18 @@ class BSRoformer(Module):
|
||||
multi_stft_resolutions_window_sizes: Tuple[int, ...] = (4096, 2048, 1024, 512, 256),
|
||||
multi_stft_hop_size=147,
|
||||
multi_stft_normalized=False,
|
||||
multi_stft_window_fn: Callable = torch.hann_window
|
||||
multi_stft_window_fn: Callable = torch.hann_window,
|
||||
mlp_expansion_factor=4,
|
||||
use_torch_checkpoint=False,
|
||||
skip_connection=False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.stereo = stereo
|
||||
self.audio_channels = 2 if stereo else 1
|
||||
self.num_stems = num_stems
|
||||
self.use_torch_checkpoint = use_torch_checkpoint
|
||||
self.skip_connection = skip_connection
|
||||
|
||||
self.layers = ModuleList([])
|
||||
|
||||
@@ -402,7 +408,7 @@ class BSRoformer(Module):
|
||||
|
||||
self.stft_window_fn = partial(default(stft_window_fn, torch.hann_window), stft_win_length)
|
||||
|
||||
freqs = torch.stft(torch.randn(1, 4096), **self.stft_kwargs, return_complex=True).shape[1]
|
||||
freqs = torch.stft(torch.randn(1, 4096), **self.stft_kwargs, window=torch.ones(stft_win_length), return_complex=True).shape[1]
|
||||
|
||||
assert len(freqs_per_bands) > 1
|
||||
assert sum(
|
||||
@@ -421,7 +427,8 @@ class BSRoformer(Module):
|
||||
mask_estimator = MaskEstimator(
|
||||
dim=dim,
|
||||
dim_inputs=freqs_per_bands_with_complex,
|
||||
depth=mask_estimator_depth
|
||||
depth=mask_estimator_depth,
|
||||
mlp_expansion_factor=mlp_expansion_factor,
|
||||
)
|
||||
|
||||
self.mask_estimators.append(mask_estimator)
|
||||
@@ -458,12 +465,14 @@ class BSRoformer(Module):
|
||||
|
||||
device = raw_audio.device
|
||||
|
||||
# determine whether the model is running on MPS (the macOS GPU accelerator)
|
||||
x_is_mps = True if device.type == "mps" else False
|
||||
|
||||
if raw_audio.ndim == 2:
|
||||
raw_audio = rearrange(raw_audio, 'b t -> b 1 t')
|
||||
|
||||
channels = raw_audio.shape[1]
|
||||
assert (not self.stereo and channels == 1) or (
|
||||
self.stereo and channels == 2), 'stereo needs to be set to True if passing in audio signal that is stereo (channel dimension of 2). also need to be False if mono (channel dimension of 1)'
|
||||
assert (not self.stereo and channels == 1) or (self.stereo and channels == 2), 'stereo needs to be set to True if passing in audio signal that is stereo (channel dimension of 2). also need to be False if mono (channel dimension of 1)'
|
||||
|
||||
# to stft
|
||||
|
||||
@@ -471,53 +480,79 @@ class BSRoformer(Module):
|
||||
|
||||
stft_window = self.stft_window_fn(device=device)
|
||||
|
||||
stft_repr = torch.stft(raw_audio, **self.stft_kwargs, window=stft_window, return_complex=True)
|
||||
# RuntimeError: FFT operations are only supported on MacOS 14+
|
||||
# Since it is tedious to detect whether we are on a supported macOS version, a simple try/except with a CPU fallback is used
|
||||
try:
|
||||
stft_repr = torch.stft(raw_audio, **self.stft_kwargs, window=stft_window, return_complex=True)
|
||||
except:
|
||||
stft_repr = torch.stft(raw_audio.cpu() if x_is_mps else raw_audio, **self.stft_kwargs, window=stft_window.cpu() if x_is_mps else stft_window, return_complex=True).to(device)
|
||||
|
||||
stft_repr = torch.view_as_real(stft_repr)
|
||||
|
||||
stft_repr = unpack_one(stft_repr, batch_audio_channel_packed_shape, '* f t c')
|
||||
stft_repr = rearrange(stft_repr,
|
||||
'b s f t c -> b (f s) t c') # merge stereo / mono into the frequency, with frequency leading dimension, for band splitting
|
||||
|
||||
# merge stereo / mono into the frequency, with frequency leading dimension, for band splitting
|
||||
stft_repr = rearrange(stft_repr,'b s f t c -> b (f s) t c')
|
||||
|
||||
x = rearrange(stft_repr, 'b f t c -> b t (f c)')
|
||||
# print("460:", x.dtype)#fp32
|
||||
x = self.band_split(x)
|
||||
|
||||
if self.use_torch_checkpoint:
|
||||
x = checkpoint(self.band_split, x, use_reentrant=False)
|
||||
else:
|
||||
x = self.band_split(x)
|
||||
|
||||
# axial / hierarchical attention
|
||||
|
||||
# print("487:",x.dtype)#fp16
|
||||
for transformer_block in self.layers:
|
||||
store = [None] * len(self.layers)
|
||||
for i, transformer_block in enumerate(self.layers):
|
||||
|
||||
if len(transformer_block) == 3:
|
||||
linear_transformer, time_transformer, freq_transformer = transformer_block
|
||||
|
||||
x, ft_ps = pack([x], 'b * d')
|
||||
# print("494:", x.dtype)#fp16
|
||||
x = linear_transformer(x)
|
||||
# print("496:", x.dtype)#fp16
|
||||
if self.use_torch_checkpoint:
|
||||
x = checkpoint(linear_transformer, x, use_reentrant=False)
|
||||
else:
|
||||
x = linear_transformer(x)
|
||||
x, = unpack(x, ft_ps, 'b * d')
|
||||
else:
|
||||
time_transformer, freq_transformer = transformer_block
|
||||
|
||||
# print("501:", x.dtype)#fp16
|
||||
if self.skip_connection:
|
||||
# Sum all previous
|
||||
for j in range(i):
|
||||
x = x + store[j]
|
||||
|
||||
x = rearrange(x, 'b t f d -> b f t d')
|
||||
x, ps = pack([x], '* t d')
|
||||
|
||||
x = time_transformer(x)
|
||||
# print("505:", x.dtype)#fp16
|
||||
if self.use_torch_checkpoint:
|
||||
x = checkpoint(time_transformer, x, use_reentrant=False)
|
||||
else:
|
||||
x = time_transformer(x)
|
||||
|
||||
x, = unpack(x, ps, '* t d')
|
||||
x = rearrange(x, 'b f t d -> b t f d')
|
||||
x, ps = pack([x], '* f d')
|
||||
|
||||
x = freq_transformer(x)
|
||||
if self.use_torch_checkpoint:
|
||||
x = checkpoint(freq_transformer, x, use_reentrant=False)
|
||||
else:
|
||||
x = freq_transformer(x)
|
||||
|
||||
x, = unpack(x, ps, '* f d')
|
||||
|
||||
# print("515:", x.dtype)######fp16
|
||||
if self.skip_connection:
|
||||
store[i] = x
|
||||
|
||||
x = self.final_norm(x)
|
||||
|
||||
num_stems = len(self.mask_estimators)
|
||||
# print("519:", x.dtype)#fp32
|
||||
mask = torch.stack([fn(x) for fn in self.mask_estimators], dim=1)
|
||||
|
||||
if self.use_torch_checkpoint:
|
||||
mask = torch.stack([checkpoint(fn, x, use_reentrant=False) for fn in self.mask_estimators], dim=1)
|
||||
else:
|
||||
mask = torch.stack([fn(x) for fn in self.mask_estimators], dim=1)
|
||||
mask = rearrange(mask, 'b n t (f c) -> b n f t c', c=2)
|
||||
|
||||
# modulate frequency representation
|
||||
@@ -535,7 +570,11 @@ class BSRoformer(Module):
|
||||
|
||||
stft_repr = rearrange(stft_repr, 'b n (f s) t -> (b n s) f t', s=self.audio_channels)
|
||||
|
||||
recon_audio = torch.istft(stft_repr, **self.stft_kwargs, window=stft_window, return_complex=False)
|
||||
# same as torch.stft() fix for MacOS MPS above
|
||||
try:
|
||||
recon_audio = torch.istft(stft_repr, **self.stft_kwargs, window=stft_window, return_complex=False, length=raw_audio.shape[-1])
|
||||
except:
|
||||
recon_audio = torch.istft(stft_repr.cpu() if x_is_mps else stft_repr, **self.stft_kwargs, window=stft_window.cpu() if x_is_mps else stft_window, return_complex=False, length=raw_audio.shape[-1]).to(device)
|
||||
|
||||
recon_audio = rearrange(recon_audio, '(b n s) t -> b n s t', s=self.audio_channels, n=num_stems)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user