Refactor: Format Code with Ruff and Update Deprecated G2PW Link (#2255)

* ruff check --fix

* ruff format --line-length 120 --target-version py39

* Change the link for G2PW Model

* update pytorch version and colab
This commit is contained in:
XXXXRT666
2025-04-07 09:42:47 +01:00
committed by GitHub
parent 9da7e17efe
commit 53cac93589
132 changed files with 8185 additions and 6648 deletions

View File

@@ -7,23 +7,22 @@ import torch.nn.functional as F
def exists(val):
return val is not None
def default(v, d):
return v if exists(v) else d
class Attend(nn.Module):
def __init__(
self,
dropout = 0.,
flash = False,
scale = None
):
def __init__(self, dropout=0.0, flash=False, scale=None):
super().__init__()
self.scale = scale
self.dropout = dropout
self.attn_dropout = nn.Dropout(dropout)
self.flash = flash
assert not (flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'
assert not (flash and version.parse(torch.__version__) < version.parse("2.0.0")), (
"in order to use flash attention, you must be using pytorch 2.0 or above"
)
def flash_attn(self, q, k, v):
# _, heads, q_len, _, k_len, is_cuda, device = *q.shape, k.shape[-2], q.is_cuda, q.device
@@ -34,7 +33,7 @@ class Attend(nn.Module):
# pytorch 2.0 flash attn: q, k, v, mask, dropout, softmax_scale
# with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=True):
return F.scaled_dot_product_attention(q, k, v,dropout_p = self.dropout if self.training else 0.)
return F.scaled_dot_product_attention(q, k, v, dropout_p=self.dropout if self.training else 0.0)
def forward(self, q, k, v):
"""
@@ -54,7 +53,7 @@ class Attend(nn.Module):
# similarity
sim = einsum(f"b h i d, b h j d -> b h i j", q, k) * scale
sim = einsum("b h i d, b h j d -> b h i j", q, k) * scale
# attention
@@ -63,6 +62,6 @@ class Attend(nn.Module):
# aggregate values
out = einsum(f"b h i j, b h j d -> b h i d", attn, v)
out = einsum("b h i j, b h j d -> b h i d", attn, v)
return out

View File

@@ -1,14 +1,14 @@
from functools import partial
import torch
from torch import nn, einsum, Tensor
from torch import nn
from torch.nn import Module, ModuleList
import torch.nn.functional as F
from bs_roformer.attend import Attend
from torch.utils.checkpoint import checkpoint
from typing import Tuple, Optional, List, Callable
from typing import Tuple, Optional, Callable
# from beartype.typing import Tuple, Optional, List, Callable
# from beartype import beartype
@@ -19,6 +19,7 @@ from einops.layers.torch import Rearrange
# helper functions
def exists(val):
return val is not None
@@ -37,14 +38,15 @@ def unpack_one(t, ps, pattern):
# norm
def l2norm(t):
return F.normalize(t, dim = -1, p = 2)
return F.normalize(t, dim=-1, p=2)
class RMSNorm(Module):
def __init__(self, dim):
super().__init__()
self.scale = dim ** 0.5
self.scale = dim**0.5
self.gamma = nn.Parameter(torch.ones(dim))
def forward(self, x):
@@ -53,13 +55,9 @@ class RMSNorm(Module):
# attention
class FeedForward(Module):
def __init__(
self,
dim,
mult=4,
dropout=0.
):
def __init__(self, dim, mult=4, dropout=0.0):
super().__init__()
dim_inner = int(dim * mult)
self.net = nn.Sequential(
@@ -68,7 +66,7 @@ class FeedForward(Module):
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(dim_inner, dim),
nn.Dropout(dropout)
nn.Dropout(dropout),
)
def forward(self, x):
@@ -76,18 +74,10 @@ class FeedForward(Module):
class Attention(Module):
def __init__(
self,
dim,
heads=8,
dim_head=64,
dropout=0.,
rotary_embed=None,
flash=True
):
def __init__(self, dim, heads=8, dim_head=64, dropout=0.0, rotary_embed=None, flash=True):
super().__init__()
self.heads = heads
self.scale = dim_head ** -0.5
self.scale = dim_head**-0.5
dim_inner = heads * dim_head
self.rotary_embed = rotary_embed
@@ -99,15 +89,12 @@ class Attention(Module):
self.to_gates = nn.Linear(dim, heads)
self.to_out = nn.Sequential(
nn.Linear(dim_inner, dim, bias=False),
nn.Dropout(dropout)
)
self.to_out = nn.Sequential(nn.Linear(dim_inner, dim, bias=False), nn.Dropout(dropout))
def forward(self, x):
x = self.norm(x)
q, k, v = rearrange(self.to_qkv(x), 'b n (qkv h d) -> qkv b h n d', qkv=3, h=self.heads)
q, k, v = rearrange(self.to_qkv(x), "b n (qkv h d) -> qkv b h n d", qkv=3, h=self.heads)
if exists(self.rotary_embed):
q = self.rotary_embed.rotate_queries_or_keys(q)
@@ -116,9 +103,9 @@ class Attention(Module):
out = self.attend(q, k, v)
gates = self.to_gates(x)
out = out * rearrange(gates, 'b n h -> b h n 1').sigmoid()
out = out * rearrange(gates, "b n h -> b h n 1").sigmoid()
out = rearrange(out, 'b h n d -> b n (h d)')
out = rearrange(out, "b h n d -> b n (h d)")
return self.to_out(out)
@@ -128,42 +115,22 @@ class LinearAttention(Module):
"""
# @beartype
def __init__(
self,
*,
dim,
dim_head=32,
heads=8,
scale=8,
flash=False,
dropout=0.
):
def __init__(self, *, dim, dim_head=32, heads=8, scale=8, flash=False, dropout=0.0):
super().__init__()
dim_inner = dim_head * heads
self.norm = RMSNorm(dim)
self.to_qkv = nn.Sequential(
nn.Linear(dim, dim_inner * 3, bias=False),
Rearrange('b n (qkv h d) -> qkv b h d n', qkv=3, h=heads)
nn.Linear(dim, dim_inner * 3, bias=False), Rearrange("b n (qkv h d) -> qkv b h d n", qkv=3, h=heads)
)
self.temperature = nn.Parameter(torch.ones(heads, 1, 1))
self.attend = Attend(
scale=scale,
dropout=dropout,
flash=flash
)
self.attend = Attend(scale=scale, dropout=dropout, flash=flash)
self.to_out = nn.Sequential(
Rearrange('b h d n -> b n (h d)'),
nn.Linear(dim_inner, dim, bias=False)
)
self.to_out = nn.Sequential(Rearrange("b h d n -> b n (h d)"), nn.Linear(dim_inner, dim, bias=False))
def forward(
self,
x
):
def forward(self, x):
x = self.norm(x)
q, k, v = self.to_qkv(x)
@@ -178,19 +145,19 @@ class LinearAttention(Module):
class Transformer(Module):
def __init__(
self,
*,
dim,
depth,
dim_head=64,
heads=8,
attn_dropout=0.,
ff_dropout=0.,
ff_mult=4,
norm_output=True,
rotary_embed=None,
flash_attn=True,
linear_attn=False
self,
*,
dim,
depth,
dim_head=64,
heads=8,
attn_dropout=0.0,
ff_dropout=0.0,
ff_mult=4,
norm_output=True,
rotary_embed=None,
flash_attn=True,
linear_attn=False,
):
super().__init__()
self.layers = ModuleList([])
@@ -199,18 +166,20 @@ class Transformer(Module):
if linear_attn:
attn = LinearAttention(dim=dim, dim_head=dim_head, heads=heads, dropout=attn_dropout, flash=flash_attn)
else:
attn = Attention(dim=dim, dim_head=dim_head, heads=heads, dropout=attn_dropout,
rotary_embed=rotary_embed, flash=flash_attn)
attn = Attention(
dim=dim,
dim_head=dim_head,
heads=heads,
dropout=attn_dropout,
rotary_embed=rotary_embed,
flash=flash_attn,
)
self.layers.append(ModuleList([
attn,
FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout)
]))
self.layers.append(ModuleList([attn, FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout)]))
self.norm = RMSNorm(dim) if norm_output else nn.Identity()
def forward(self, x):
for attn, ff in self.layers:
x = attn(x) + x
x = ff(x) + x
@@ -220,22 +189,16 @@ class Transformer(Module):
# bandsplit module
class BandSplit(Module):
# @beartype
def __init__(
self,
dim,
dim_inputs: Tuple[int, ...]
):
def __init__(self, dim, dim_inputs: Tuple[int, ...]):
super().__init__()
self.dim_inputs = dim_inputs
self.to_features = ModuleList([])
for dim_in in dim_inputs:
net = nn.Sequential(
RMSNorm(dim_in),
nn.Linear(dim_in, dim)
)
net = nn.Sequential(RMSNorm(dim_in), nn.Linear(dim_in, dim))
self.to_features.append(net)
@@ -250,13 +213,7 @@ class BandSplit(Module):
return torch.stack(outs, dim=-2)
def MLP(
dim_in,
dim_out,
dim_hidden=None,
depth=1,
activation=nn.Tanh
):
def MLP(dim_in, dim_out, dim_hidden=None, depth=1, activation=nn.Tanh):
dim_hidden = default(dim_hidden, dim_in)
net = []
@@ -277,13 +234,7 @@ def MLP(
class MaskEstimator(Module):
# @beartype
def __init__(
self,
dim,
dim_inputs: Tuple[int, ...],
depth,
mlp_expansion_factor=4
):
def __init__(self, dim, dim_inputs: Tuple[int, ...], depth, mlp_expansion_factor=4):
super().__init__()
self.dim_inputs = dim_inputs
self.to_freqs = ModuleList([])
@@ -292,10 +243,7 @@ class MaskEstimator(Module):
for dim_in in dim_inputs:
net = []
mlp = nn.Sequential(
MLP(dim, dim_in * 2, dim_hidden=dim_hidden, depth=depth),
nn.GLU(dim=-1)
)
mlp = nn.Sequential(MLP(dim, dim_in * 2, dim_hidden=dim_hidden, depth=depth), nn.GLU(dim=-1))
self.to_freqs.append(mlp)
@@ -314,53 +262,106 @@ class MaskEstimator(Module):
# main class
DEFAULT_FREQS_PER_BANDS = (
2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2,
4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
12, 12, 12, 12, 12, 12, 12, 12,
24, 24, 24, 24, 24, 24, 24, 24,
48, 48, 48, 48, 48, 48, 48, 48,
128, 129,
2,
2,
2,
2,
2,
2,
2,
2,
2,
2,
2,
2,
2,
2,
2,
2,
2,
2,
2,
2,
2,
2,
2,
2,
4,
4,
4,
4,
4,
4,
4,
4,
4,
4,
4,
4,
12,
12,
12,
12,
12,
12,
12,
12,
24,
24,
24,
24,
24,
24,
24,
24,
48,
48,
48,
48,
48,
48,
48,
48,
128,
129,
)
class BSRoformer(Module):
# @beartype
def __init__(
self,
dim,
*,
depth,
stereo=False,
num_stems=1,
time_transformer_depth=2,
freq_transformer_depth=2,
linear_transformer_depth=0,
freqs_per_bands: Tuple[int, ...] = DEFAULT_FREQS_PER_BANDS,
# in the paper, they divide into ~60 bands, test with 1 for starters
dim_head=64,
heads=8,
attn_dropout=0.,
ff_dropout=0.,
flash_attn=True,
dim_freqs_in=1025,
stft_n_fft=2048,
stft_hop_length=512,
# 10ms at 44100Hz, from sections 4.1, 4.4 in the paper - @faroit recommends // 2 or // 4 for better reconstruction
stft_win_length=2048,
stft_normalized=False,
stft_window_fn: Optional[Callable] = None,
mask_estimator_depth=2,
multi_stft_resolution_loss_weight=1.,
multi_stft_resolutions_window_sizes: Tuple[int, ...] = (4096, 2048, 1024, 512, 256),
multi_stft_hop_size=147,
multi_stft_normalized=False,
multi_stft_window_fn: Callable = torch.hann_window,
mlp_expansion_factor=4,
use_torch_checkpoint=False,
skip_connection=False,
self,
dim,
*,
depth,
stereo=False,
num_stems=1,
time_transformer_depth=2,
freq_transformer_depth=2,
linear_transformer_depth=0,
freqs_per_bands: Tuple[int, ...] = DEFAULT_FREQS_PER_BANDS,
# in the paper, they divide into ~60 bands, test with 1 for starters
dim_head=64,
heads=8,
attn_dropout=0.0,
ff_dropout=0.0,
flash_attn=True,
dim_freqs_in=1025,
stft_n_fft=2048,
stft_hop_length=512,
# 10ms at 44100Hz, from sections 4.1, 4.4 in the paper - @faroit recommends // 2 or // 4 for better reconstruction
stft_win_length=2048,
stft_normalized=False,
stft_window_fn: Optional[Callable] = None,
mask_estimator_depth=2,
multi_stft_resolution_loss_weight=1.0,
multi_stft_resolutions_window_sizes: Tuple[int, ...] = (4096, 2048, 1024, 512, 256),
multi_stft_hop_size=147,
multi_stft_normalized=False,
multi_stft_window_fn: Callable = torch.hann_window,
mlp_expansion_factor=4,
use_torch_checkpoint=False,
skip_connection=False,
):
super().__init__()
@@ -379,7 +380,7 @@ class BSRoformer(Module):
attn_dropout=attn_dropout,
ff_dropout=ff_dropout,
flash_attn=flash_attn,
norm_output=False
norm_output=False,
)
time_rotary_embed = RotaryEmbedding(dim=dim_head)
@@ -400,26 +401,23 @@ class BSRoformer(Module):
self.final_norm = RMSNorm(dim)
self.stft_kwargs = dict(
n_fft=stft_n_fft,
hop_length=stft_hop_length,
win_length=stft_win_length,
normalized=stft_normalized
n_fft=stft_n_fft, hop_length=stft_hop_length, win_length=stft_win_length, normalized=stft_normalized
)
self.stft_window_fn = partial(default(stft_window_fn, torch.hann_window), stft_win_length)
freqs = torch.stft(torch.randn(1, 4096), **self.stft_kwargs, window=torch.ones(stft_win_length), return_complex=True).shape[1]
freqs = torch.stft(
torch.randn(1, 4096), **self.stft_kwargs, window=torch.ones(stft_win_length), return_complex=True
).shape[1]
assert len(freqs_per_bands) > 1
assert sum(
freqs_per_bands) == freqs, f'the number of freqs in the bands must equal {freqs} based on the STFT settings, but got {sum(freqs_per_bands)}'
assert sum(freqs_per_bands) == freqs, (
f"the number of freqs in the bands must equal {freqs} based on the STFT settings, but got {sum(freqs_per_bands)}"
)
freqs_per_bands_with_complex = tuple(2 * f * self.audio_channels for f in freqs_per_bands)
self.band_split = BandSplit(
dim=dim,
dim_inputs=freqs_per_bands_with_complex
)
self.band_split = BandSplit(dim=dim, dim_inputs=freqs_per_bands_with_complex)
self.mask_estimators = nn.ModuleList([])
@@ -440,17 +438,9 @@ class BSRoformer(Module):
self.multi_stft_n_fft = stft_n_fft
self.multi_stft_window_fn = multi_stft_window_fn
self.multi_stft_kwargs = dict(
hop_length=multi_stft_hop_size,
normalized=multi_stft_normalized
)
self.multi_stft_kwargs = dict(hop_length=multi_stft_hop_size, normalized=multi_stft_normalized)
def forward(
self,
raw_audio,
target=None,
return_loss_breakdown=False
):
def forward(self, raw_audio, target=None, return_loss_breakdown=False):
"""
einops
@@ -469,14 +459,16 @@ class BSRoformer(Module):
x_is_mps = True if device.type == "mps" else False
if raw_audio.ndim == 2:
raw_audio = rearrange(raw_audio, 'b t -> b 1 t')
raw_audio = rearrange(raw_audio, "b t -> b 1 t")
channels = raw_audio.shape[1]
assert (not self.stereo and channels == 1) or (self.stereo and channels == 2), 'stereo needs to be set to True if passing in audio signal that is stereo (channel dimension of 2). also need to be False if mono (channel dimension of 1)'
assert (not self.stereo and channels == 1) or (self.stereo and channels == 2), (
"stereo needs to be set to True if passing in audio signal that is stereo (channel dimension of 2). also need to be False if mono (channel dimension of 1)"
)
# to stft
raw_audio, batch_audio_channel_packed_shape = pack_one(raw_audio, '* t')
raw_audio, batch_audio_channel_packed_shape = pack_one(raw_audio, "* t")
stft_window = self.stft_window_fn(device=device)
@@ -485,16 +477,21 @@ class BSRoformer(Module):
try:
stft_repr = torch.stft(raw_audio, **self.stft_kwargs, window=stft_window, return_complex=True)
except:
stft_repr = torch.stft(raw_audio.cpu() if x_is_mps else raw_audio, **self.stft_kwargs, window=stft_window.cpu() if x_is_mps else stft_window, return_complex=True).to(device)
stft_repr = torch.stft(
raw_audio.cpu() if x_is_mps else raw_audio,
**self.stft_kwargs,
window=stft_window.cpu() if x_is_mps else stft_window,
return_complex=True,
).to(device)
stft_repr = torch.view_as_real(stft_repr)
stft_repr = unpack_one(stft_repr, batch_audio_channel_packed_shape, '* f t c')
stft_repr = unpack_one(stft_repr, batch_audio_channel_packed_shape, "* f t c")
# merge stereo / mono into the frequency, with frequency leading dimension, for band splitting
stft_repr = rearrange(stft_repr,'b s f t c -> b (f s) t c')
stft_repr = rearrange(stft_repr, "b s f t c -> b (f s) t c")
x = rearrange(stft_repr, 'b f t c -> b t (f c)')
x = rearrange(stft_repr, "b f t c -> b t (f c)")
if self.use_torch_checkpoint:
x = checkpoint(self.band_split, x, use_reentrant=False)
@@ -505,16 +502,15 @@ class BSRoformer(Module):
store = [None] * len(self.layers)
for i, transformer_block in enumerate(self.layers):
if len(transformer_block) == 3:
linear_transformer, time_transformer, freq_transformer = transformer_block
x, ft_ps = pack([x], 'b * d')
x, ft_ps = pack([x], "b * d")
if self.use_torch_checkpoint:
x = checkpoint(linear_transformer, x, use_reentrant=False)
else:
x = linear_transformer(x)
x, = unpack(x, ft_ps, 'b * d')
(x,) = unpack(x, ft_ps, "b * d")
else:
time_transformer, freq_transformer = transformer_block
@@ -523,24 +519,24 @@ class BSRoformer(Module):
for j in range(i):
x = x + store[j]
x = rearrange(x, 'b t f d -> b f t d')
x, ps = pack([x], '* t d')
x = rearrange(x, "b t f d -> b f t d")
x, ps = pack([x], "* t d")
if self.use_torch_checkpoint:
x = checkpoint(time_transformer, x, use_reentrant=False)
else:
x = time_transformer(x)
x, = unpack(x, ps, '* t d')
x = rearrange(x, 'b f t d -> b t f d')
x, ps = pack([x], '* f d')
(x,) = unpack(x, ps, "* t d")
x = rearrange(x, "b f t d -> b t f d")
x, ps = pack([x], "* f d")
if self.use_torch_checkpoint:
x = checkpoint(freq_transformer, x, use_reentrant=False)
else:
x = freq_transformer(x)
x, = unpack(x, ps, '* f d')
(x,) = unpack(x, ps, "* f d")
if self.skip_connection:
store[i] = x
@@ -553,11 +549,11 @@ class BSRoformer(Module):
mask = torch.stack([checkpoint(fn, x, use_reentrant=False) for fn in self.mask_estimators], dim=1)
else:
mask = torch.stack([fn(x) for fn in self.mask_estimators], dim=1)
mask = rearrange(mask, 'b n t (f c) -> b n f t c', c=2)
mask = rearrange(mask, "b n t (f c) -> b n f t c", c=2)
# modulate frequency representation
stft_repr = rearrange(stft_repr, 'b f t c -> b 1 f t c')
stft_repr = rearrange(stft_repr, "b f t c -> b 1 f t c")
# complex number multiplication
@@ -568,18 +564,26 @@ class BSRoformer(Module):
# istft
stft_repr = rearrange(stft_repr, 'b n (f s) t -> (b n s) f t', s=self.audio_channels)
stft_repr = rearrange(stft_repr, "b n (f s) t -> (b n s) f t", s=self.audio_channels)
# same as torch.stft() fix for MacOS MPS above
try:
recon_audio = torch.istft(stft_repr, **self.stft_kwargs, window=stft_window, return_complex=False, length=raw_audio.shape[-1])
recon_audio = torch.istft(
stft_repr, **self.stft_kwargs, window=stft_window, return_complex=False, length=raw_audio.shape[-1]
)
except:
recon_audio = torch.istft(stft_repr.cpu() if x_is_mps else stft_repr, **self.stft_kwargs, window=stft_window.cpu() if x_is_mps else stft_window, return_complex=False, length=raw_audio.shape[-1]).to(device)
recon_audio = torch.istft(
stft_repr.cpu() if x_is_mps else stft_repr,
**self.stft_kwargs,
window=stft_window.cpu() if x_is_mps else stft_window,
return_complex=False,
length=raw_audio.shape[-1],
).to(device)
recon_audio = rearrange(recon_audio, '(b n s) t -> b n s t', s=self.audio_channels, n=num_stems)
recon_audio = rearrange(recon_audio, "(b n s) t -> b n s t", s=self.audio_channels, n=num_stems)
if num_stems == 1:
recon_audio = rearrange(recon_audio, 'b 1 s t -> b s t')
recon_audio = rearrange(recon_audio, "b 1 s t -> b s t")
# if a target is passed in, calculate loss for learning
@@ -590,13 +594,13 @@ class BSRoformer(Module):
assert target.ndim == 4 and target.shape[1] == self.num_stems
if target.ndim == 2:
target = rearrange(target, '... t -> ... 1 t')
target = rearrange(target, "... t -> ... 1 t")
target = target[..., :recon_audio.shape[-1]] # protect against lost length on istft
target = target[..., : recon_audio.shape[-1]] # protect against lost length on istft
loss = F.l1_loss(recon_audio, target)
multi_stft_resolution_loss = 0.
multi_stft_resolution_loss = 0.0
for window_size in self.multi_stft_resolutions_window_sizes:
res_stft_kwargs = dict(
@@ -607,8 +611,8 @@ class BSRoformer(Module):
**self.multi_stft_kwargs,
)
recon_Y = torch.stft(rearrange(recon_audio, '... s t -> (... s) t'), **res_stft_kwargs)
target_Y = torch.stft(rearrange(target, '... s t -> (... s) t'), **res_stft_kwargs)
recon_Y = torch.stft(rearrange(recon_audio, "... s t -> (... s) t"), **res_stft_kwargs)
target_Y = torch.stft(rearrange(target, "... s t -> (... s) t"), **res_stft_kwargs)
multi_stft_resolution_loss = multi_stft_resolution_loss + F.l1_loss(recon_Y, target_Y)
@@ -619,4 +623,4 @@ class BSRoformer(Module):
if not return_loss_breakdown:
return total_loss
return total_loss, (loss, multi_stft_resolution_loss)
return total_loss, (loss, multi_stft_resolution_loss)

View File

@@ -1,14 +1,14 @@
from functools import partial
import torch
from torch import nn, einsum, Tensor
from torch import nn
from torch.nn import Module, ModuleList
import torch.nn.functional as F
from bs_roformer.attend import Attend
from torch.utils.checkpoint import checkpoint
from typing import Tuple, Optional, List, Callable
from typing import Tuple, Optional, Callable
# from beartype.typing import Tuple, Optional, List, Callable
# from beartype import beartype
@@ -22,6 +22,7 @@ from librosa import filters
# helper functions
def exists(val):
return val is not None
@@ -38,9 +39,9 @@ def unpack_one(t, ps, pattern):
return unpack(t, ps, pattern)[0]
def pad_at_dim(t, pad, dim=-1, value=0.):
dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
zeros = ((0, 0) * dims_from_right)
def pad_at_dim(t, pad, dim=-1, value=0.0):
dims_from_right = (-dim - 1) if dim < 0 else (t.ndim - dim - 1)
zeros = (0, 0) * dims_from_right
return F.pad(t, (*zeros, *pad), value=value)
@@ -50,10 +51,11 @@ def l2norm(t):
# norm
class RMSNorm(Module):
def __init__(self, dim):
super().__init__()
self.scale = dim ** 0.5
self.scale = dim**0.5
self.gamma = nn.Parameter(torch.ones(dim))
def forward(self, x):
@@ -62,13 +64,9 @@ class RMSNorm(Module):
# attention
class FeedForward(Module):
def __init__(
self,
dim,
mult=4,
dropout=0.
):
def __init__(self, dim, mult=4, dropout=0.0):
super().__init__()
dim_inner = int(dim * mult)
self.net = nn.Sequential(
@@ -77,7 +75,7 @@ class FeedForward(Module):
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(dim_inner, dim),
nn.Dropout(dropout)
nn.Dropout(dropout),
)
def forward(self, x):
@@ -85,18 +83,10 @@ class FeedForward(Module):
class Attention(Module):
def __init__(
self,
dim,
heads=8,
dim_head=64,
dropout=0.,
rotary_embed=None,
flash=True
):
def __init__(self, dim, heads=8, dim_head=64, dropout=0.0, rotary_embed=None, flash=True):
super().__init__()
self.heads = heads
self.scale = dim_head ** -0.5
self.scale = dim_head**-0.5
dim_inner = heads * dim_head
self.rotary_embed = rotary_embed
@@ -108,15 +98,12 @@ class Attention(Module):
self.to_gates = nn.Linear(dim, heads)
self.to_out = nn.Sequential(
nn.Linear(dim_inner, dim, bias=False),
nn.Dropout(dropout)
)
self.to_out = nn.Sequential(nn.Linear(dim_inner, dim, bias=False), nn.Dropout(dropout))
def forward(self, x):
x = self.norm(x)
q, k, v = rearrange(self.to_qkv(x), 'b n (qkv h d) -> qkv b h n d', qkv=3, h=self.heads)
q, k, v = rearrange(self.to_qkv(x), "b n (qkv h d) -> qkv b h n d", qkv=3, h=self.heads)
if exists(self.rotary_embed):
q = self.rotary_embed.rotate_queries_or_keys(q)
@@ -125,9 +112,9 @@ class Attention(Module):
out = self.attend(q, k, v)
gates = self.to_gates(x)
out = out * rearrange(gates, 'b n h -> b h n 1').sigmoid()
out = out * rearrange(gates, "b n h -> b h n 1").sigmoid()
out = rearrange(out, 'b h n d -> b n (h d)')
out = rearrange(out, "b h n d -> b n (h d)")
return self.to_out(out)
@@ -137,42 +124,22 @@ class LinearAttention(Module):
"""
# @beartype
def __init__(
self,
*,
dim,
dim_head=32,
heads=8,
scale=8,
flash=False,
dropout=0.
):
def __init__(self, *, dim, dim_head=32, heads=8, scale=8, flash=False, dropout=0.0):
super().__init__()
dim_inner = dim_head * heads
self.norm = RMSNorm(dim)
self.to_qkv = nn.Sequential(
nn.Linear(dim, dim_inner * 3, bias=False),
Rearrange('b n (qkv h d) -> qkv b h d n', qkv=3, h=heads)
nn.Linear(dim, dim_inner * 3, bias=False), Rearrange("b n (qkv h d) -> qkv b h d n", qkv=3, h=heads)
)
self.temperature = nn.Parameter(torch.ones(heads, 1, 1))
self.attend = Attend(
scale=scale,
dropout=dropout,
flash=flash
)
self.attend = Attend(scale=scale, dropout=dropout, flash=flash)
self.to_out = nn.Sequential(
Rearrange('b h d n -> b n (h d)'),
nn.Linear(dim_inner, dim, bias=False)
)
self.to_out = nn.Sequential(Rearrange("b h d n -> b n (h d)"), nn.Linear(dim_inner, dim, bias=False))
def forward(
self,
x
):
def forward(self, x):
x = self.norm(x)
q, k, v = self.to_qkv(x)
@@ -187,19 +154,19 @@ class LinearAttention(Module):
class Transformer(Module):
def __init__(
self,
*,
dim,
depth,
dim_head=64,
heads=8,
attn_dropout=0.,
ff_dropout=0.,
ff_mult=4,
norm_output=True,
rotary_embed=None,
flash_attn=True,
linear_attn=False
self,
*,
dim,
depth,
dim_head=64,
heads=8,
attn_dropout=0.0,
ff_dropout=0.0,
ff_mult=4,
norm_output=True,
rotary_embed=None,
flash_attn=True,
linear_attn=False,
):
super().__init__()
self.layers = ModuleList([])
@@ -208,18 +175,20 @@ class Transformer(Module):
if linear_attn:
attn = LinearAttention(dim=dim, dim_head=dim_head, heads=heads, dropout=attn_dropout, flash=flash_attn)
else:
attn = Attention(dim=dim, dim_head=dim_head, heads=heads, dropout=attn_dropout,
rotary_embed=rotary_embed, flash=flash_attn)
attn = Attention(
dim=dim,
dim_head=dim_head,
heads=heads,
dropout=attn_dropout,
rotary_embed=rotary_embed,
flash=flash_attn,
)
self.layers.append(ModuleList([
attn,
FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout)
]))
self.layers.append(ModuleList([attn, FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout)]))
self.norm = RMSNorm(dim) if norm_output else nn.Identity()
def forward(self, x):
for attn, ff in self.layers:
x = attn(x) + x
x = ff(x) + x
@@ -229,22 +198,16 @@ class Transformer(Module):
# bandsplit module
class BandSplit(Module):
# @beartype
def __init__(
self,
dim,
dim_inputs: Tuple[int, ...]
):
def __init__(self, dim, dim_inputs: Tuple[int, ...]):
super().__init__()
self.dim_inputs = dim_inputs
self.to_features = ModuleList([])
for dim_in in dim_inputs:
net = nn.Sequential(
RMSNorm(dim_in),
nn.Linear(dim_in, dim)
)
net = nn.Sequential(RMSNorm(dim_in), nn.Linear(dim_in, dim))
self.to_features.append(net)
@@ -259,13 +222,7 @@ class BandSplit(Module):
return torch.stack(outs, dim=-2)
def MLP(
dim_in,
dim_out,
dim_hidden=None,
depth=1,
activation=nn.Tanh
):
def MLP(dim_in, dim_out, dim_hidden=None, depth=1, activation=nn.Tanh):
dim_hidden = default(dim_hidden, dim_in)
net = []
@@ -286,13 +243,7 @@ def MLP(
class MaskEstimator(Module):
# @beartype
def __init__(
self,
dim,
dim_inputs: Tuple[int, ...],
depth,
mlp_expansion_factor=4
):
def __init__(self, dim, dim_inputs: Tuple[int, ...], depth, mlp_expansion_factor=4):
super().__init__()
self.dim_inputs = dim_inputs
self.to_freqs = ModuleList([])
@@ -301,10 +252,7 @@ class MaskEstimator(Module):
for dim_in in dim_inputs:
net = []
mlp = nn.Sequential(
MLP(dim, dim_in * 2, dim_hidden=dim_hidden, depth=depth),
nn.GLU(dim=-1)
)
mlp = nn.Sequential(MLP(dim, dim_in * 2, dim_hidden=dim_hidden, depth=depth), nn.GLU(dim=-1))
self.to_freqs.append(mlp)
@@ -322,43 +270,43 @@ class MaskEstimator(Module):
# main class
class MelBandRoformer(Module):
class MelBandRoformer(Module):
# @beartype
def __init__(
self,
dim,
*,
depth,
stereo=False,
num_stems=1,
time_transformer_depth=2,
freq_transformer_depth=2,
linear_transformer_depth=0,
num_bands=60,
dim_head=64,
heads=8,
attn_dropout=0.1,
ff_dropout=0.1,
flash_attn=True,
dim_freqs_in=1025,
sample_rate=44100, # needed for mel filter bank from librosa
stft_n_fft=2048,
stft_hop_length=512,
# 10ms at 44100Hz, from sections 4.1, 4.4 in the paper - @faroit recommends // 2 or // 4 for better reconstruction
stft_win_length=2048,
stft_normalized=False,
stft_window_fn: Optional[Callable] = None,
mask_estimator_depth=1,
multi_stft_resolution_loss_weight=1.,
multi_stft_resolutions_window_sizes: Tuple[int, ...] = (4096, 2048, 1024, 512, 256),
multi_stft_hop_size=147,
multi_stft_normalized=False,
multi_stft_window_fn: Callable = torch.hann_window,
match_input_audio_length=False, # if True, pad output tensor to match length of input tensor
mlp_expansion_factor=4,
use_torch_checkpoint=False,
skip_connection=False,
self,
dim,
*,
depth,
stereo=False,
num_stems=1,
time_transformer_depth=2,
freq_transformer_depth=2,
linear_transformer_depth=0,
num_bands=60,
dim_head=64,
heads=8,
attn_dropout=0.1,
ff_dropout=0.1,
flash_attn=True,
dim_freqs_in=1025,
sample_rate=44100, # needed for mel filter bank from librosa
stft_n_fft=2048,
stft_hop_length=512,
# 10ms at 44100Hz, from sections 4.1, 4.4 in the paper - @faroit recommends // 2 or // 4 for better reconstruction
stft_win_length=2048,
stft_normalized=False,
stft_window_fn: Optional[Callable] = None,
mask_estimator_depth=1,
multi_stft_resolution_loss_weight=1.0,
multi_stft_resolutions_window_sizes: Tuple[int, ...] = (4096, 2048, 1024, 512, 256),
multi_stft_hop_size=147,
multi_stft_normalized=False,
multi_stft_window_fn: Callable = torch.hann_window,
match_input_audio_length=False, # if True, pad output tensor to match length of input tensor
mlp_expansion_factor=4,
use_torch_checkpoint=False,
skip_connection=False,
):
super().__init__()
@@ -376,7 +324,7 @@ class MelBandRoformer(Module):
dim_head=dim_head,
attn_dropout=attn_dropout,
ff_dropout=ff_dropout,
flash_attn=flash_attn
flash_attn=flash_attn,
)
time_rotary_embed = RotaryEmbedding(dim=dim_head)
@@ -397,13 +345,12 @@ class MelBandRoformer(Module):
self.stft_window_fn = partial(default(stft_window_fn, torch.hann_window), stft_win_length)
self.stft_kwargs = dict(
n_fft=stft_n_fft,
hop_length=stft_hop_length,
win_length=stft_win_length,
normalized=stft_normalized
n_fft=stft_n_fft, hop_length=stft_hop_length, win_length=stft_win_length, normalized=stft_normalized
)
freqs = torch.stft(torch.randn(1, 4096), **self.stft_kwargs, window=torch.ones(stft_n_fft), return_complex=True).shape[1]
freqs = torch.stft(
torch.randn(1, 4096), **self.stft_kwargs, window=torch.ones(stft_n_fft), return_complex=True
).shape[1]
# create mel filter bank
# with librosa.filters.mel as in section 2 of paper
@@ -414,43 +361,40 @@ class MelBandRoformer(Module):
# for some reason, it doesn't include the first freq? just force a value for now
mel_filter_bank[0][0] = 1.
mel_filter_bank[0][0] = 1.0
# In some systems/envs we get 0.0 instead of ~1.9e-18 in the last position,
# so let's force a positive value
mel_filter_bank[-1, -1] = 1.
mel_filter_bank[-1, -1] = 1.0
# binary as in paper (then estimated masks are averaged for overlapping regions)
freqs_per_band = mel_filter_bank > 0
assert freqs_per_band.any(dim=0).all(), 'all frequencies need to be covered by all bands for now'
assert freqs_per_band.any(dim=0).all(), "all frequencies need to be covered by all bands for now"
repeated_freq_indices = repeat(torch.arange(freqs), 'f -> b f', b=num_bands)
repeated_freq_indices = repeat(torch.arange(freqs), "f -> b f", b=num_bands)
freq_indices = repeated_freq_indices[freqs_per_band]
if stereo:
freq_indices = repeat(freq_indices, 'f -> f s', s=2)
freq_indices = repeat(freq_indices, "f -> f s", s=2)
freq_indices = freq_indices * 2 + torch.arange(2)
freq_indices = rearrange(freq_indices, 'f s -> (f s)')
freq_indices = rearrange(freq_indices, "f s -> (f s)")
self.register_buffer('freq_indices', freq_indices, persistent=False)
self.register_buffer('freqs_per_band', freqs_per_band, persistent=False)
self.register_buffer("freq_indices", freq_indices, persistent=False)
self.register_buffer("freqs_per_band", freqs_per_band, persistent=False)
num_freqs_per_band = reduce(freqs_per_band, 'b f -> b', 'sum')
num_bands_per_freq = reduce(freqs_per_band, 'b f -> f', 'sum')
num_freqs_per_band = reduce(freqs_per_band, "b f -> b", "sum")
num_bands_per_freq = reduce(freqs_per_band, "b f -> f", "sum")
self.register_buffer('num_freqs_per_band', num_freqs_per_band, persistent=False)
self.register_buffer('num_bands_per_freq', num_bands_per_freq, persistent=False)
self.register_buffer("num_freqs_per_band", num_freqs_per_band, persistent=False)
self.register_buffer("num_bands_per_freq", num_bands_per_freq, persistent=False)
# band split and mask estimator
freqs_per_bands_with_complex = tuple(2 * f * self.audio_channels for f in num_freqs_per_band.tolist())
self.band_split = BandSplit(
dim=dim,
dim_inputs=freqs_per_bands_with_complex
)
self.band_split = BandSplit(dim=dim, dim_inputs=freqs_per_bands_with_complex)
self.mask_estimators = nn.ModuleList([])
@@ -471,19 +415,11 @@ class MelBandRoformer(Module):
self.multi_stft_n_fft = stft_n_fft
self.multi_stft_window_fn = multi_stft_window_fn
self.multi_stft_kwargs = dict(
hop_length=multi_stft_hop_size,
normalized=multi_stft_normalized
)
self.multi_stft_kwargs = dict(hop_length=multi_stft_hop_size, normalized=multi_stft_normalized)
self.match_input_audio_length = match_input_audio_length
def forward(
self,
raw_audio,
target=None,
return_loss_breakdown=False
):
def forward(self, raw_audio, target=None, return_loss_breakdown=False):
"""
einops
@@ -499,28 +435,29 @@ class MelBandRoformer(Module):
device = raw_audio.device
if raw_audio.ndim == 2:
raw_audio = rearrange(raw_audio, 'b t -> b 1 t')
raw_audio = rearrange(raw_audio, "b t -> b 1 t")
batch, channels, raw_audio_length = raw_audio.shape
istft_length = raw_audio_length if self.match_input_audio_length else None
assert (not self.stereo and channels == 1) or (
self.stereo and channels == 2), 'stereo needs to be set to True if passing in audio signal that is stereo (channel dimension of 2). also need to be False if mono (channel dimension of 1)'
assert (not self.stereo and channels == 1) or (self.stereo and channels == 2), (
"stereo needs to be set to True if passing in audio signal that is stereo (channel dimension of 2). also need to be False if mono (channel dimension of 1)"
)
# to stft
raw_audio, batch_audio_channel_packed_shape = pack_one(raw_audio, '* t')
raw_audio, batch_audio_channel_packed_shape = pack_one(raw_audio, "* t")
stft_window = self.stft_window_fn(device=device)
stft_repr = torch.stft(raw_audio, **self.stft_kwargs, window=stft_window, return_complex=True)
stft_repr = torch.view_as_real(stft_repr)
stft_repr = unpack_one(stft_repr, batch_audio_channel_packed_shape, '* f t c')
stft_repr = unpack_one(stft_repr, batch_audio_channel_packed_shape, "* f t c")
# merge stereo / mono into the frequency, with frequency leading dimension, for band splitting
stft_repr = rearrange(stft_repr,'b s f t c -> b (f s) t c')
stft_repr = rearrange(stft_repr, "b s f t c -> b (f s) t c")
# index out all frequencies for all frequency ranges across bands ascending in one go
@@ -532,7 +469,7 @@ class MelBandRoformer(Module):
# fold the complex (real and imag) into the frequencies dimension
x = rearrange(x, 'b f t c -> b t (f c)')
x = rearrange(x, "b f t c -> b t (f c)")
if self.use_torch_checkpoint:
x = checkpoint(self.band_split, x, use_reentrant=False)
@@ -543,16 +480,15 @@ class MelBandRoformer(Module):
store = [None] * len(self.layers)
for i, transformer_block in enumerate(self.layers):
if len(transformer_block) == 3:
linear_transformer, time_transformer, freq_transformer = transformer_block
x, ft_ps = pack([x], 'b * d')
x, ft_ps = pack([x], "b * d")
if self.use_torch_checkpoint:
x = checkpoint(linear_transformer, x, use_reentrant=False)
else:
x = linear_transformer(x)
x, = unpack(x, ft_ps, 'b * d')
(x,) = unpack(x, ft_ps, "b * d")
else:
time_transformer, freq_transformer = transformer_block
@@ -561,24 +497,24 @@ class MelBandRoformer(Module):
for j in range(i):
x = x + store[j]
x = rearrange(x, 'b t f d -> b f t d')
x, ps = pack([x], '* t d')
x = rearrange(x, "b t f d -> b f t d")
x, ps = pack([x], "* t d")
if self.use_torch_checkpoint:
x = checkpoint(time_transformer, x, use_reentrant=False)
else:
x = time_transformer(x)
x, = unpack(x, ps, '* t d')
x = rearrange(x, 'b f t d -> b t f d')
x, ps = pack([x], '* f d')
(x,) = unpack(x, ps, "* t d")
x = rearrange(x, "b f t d -> b t f d")
x, ps = pack([x], "* f d")
if self.use_torch_checkpoint:
x = checkpoint(freq_transformer, x, use_reentrant=False)
else:
x = freq_transformer(x)
x, = unpack(x, ps, '* f d')
(x,) = unpack(x, ps, "* f d")
if self.skip_connection:
store[i] = x
@@ -588,11 +524,11 @@ class MelBandRoformer(Module):
masks = torch.stack([checkpoint(fn, x, use_reentrant=False) for fn in self.mask_estimators], dim=1)
else:
masks = torch.stack([fn(x) for fn in self.mask_estimators], dim=1)
masks = rearrange(masks, 'b n t (f c) -> b n f t c', c=2)
masks = rearrange(masks, "b n t (f c) -> b n f t c", c=2)
# modulate frequency representation
stft_repr = rearrange(stft_repr, 'b f t c -> b 1 f t c')
stft_repr = rearrange(stft_repr, "b f t c -> b 1 f t c")
# complex number multiplication
@@ -603,12 +539,12 @@ class MelBandRoformer(Module):
# need to average the estimated mask for the overlapped frequencies
scatter_indices = repeat(self.freq_indices, 'f -> b n f t', b=batch, n=num_stems, t=stft_repr.shape[-1])
scatter_indices = repeat(self.freq_indices, "f -> b n f t", b=batch, n=num_stems, t=stft_repr.shape[-1])
stft_repr_expanded_stems = repeat(stft_repr, 'b 1 ... -> b n ...', n=num_stems)
stft_repr_expanded_stems = repeat(stft_repr, "b 1 ... -> b n ...", n=num_stems)
masks_summed = torch.zeros_like(stft_repr_expanded_stems).scatter_add_(2, scatter_indices, masks)
denom = repeat(self.num_bands_per_freq, 'f -> (f r) 1', r=channels)
denom = repeat(self.num_bands_per_freq, "f -> (f r) 1", r=channels)
masks_averaged = masks_summed / denom.clamp(min=1e-8)
@@ -618,15 +554,16 @@ class MelBandRoformer(Module):
# istft
stft_repr = rearrange(stft_repr, 'b n (f s) t -> (b n s) f t', s=self.audio_channels)
stft_repr = rearrange(stft_repr, "b n (f s) t -> (b n s) f t", s=self.audio_channels)
recon_audio = torch.istft(stft_repr, **self.stft_kwargs, window=stft_window, return_complex=False,
length=istft_length)
recon_audio = torch.istft(
stft_repr, **self.stft_kwargs, window=stft_window, return_complex=False, length=istft_length
)
recon_audio = rearrange(recon_audio, '(b n s) t -> b n s t', b=batch, s=self.audio_channels, n=num_stems)
recon_audio = rearrange(recon_audio, "(b n s) t -> b n s t", b=batch, s=self.audio_channels, n=num_stems)
if num_stems == 1:
recon_audio = rearrange(recon_audio, 'b 1 s t -> b s t')
recon_audio = rearrange(recon_audio, "b 1 s t -> b s t")
# if a target is passed in, calculate loss for learning
@@ -637,13 +574,13 @@ class MelBandRoformer(Module):
assert target.ndim == 4 and target.shape[1] == self.num_stems
if target.ndim == 2:
target = rearrange(target, '... t -> ... 1 t')
target = rearrange(target, "... t -> ... 1 t")
target = target[..., :recon_audio.shape[-1]] # protect against lost length on istft
target = target[..., : recon_audio.shape[-1]] # protect against lost length on istft
loss = F.l1_loss(recon_audio, target)
multi_stft_resolution_loss = 0.
multi_stft_resolution_loss = 0.0
for window_size in self.multi_stft_resolutions_window_sizes:
res_stft_kwargs = dict(
@@ -654,8 +591,8 @@ class MelBandRoformer(Module):
**self.multi_stft_kwargs,
)
recon_Y = torch.stft(rearrange(recon_audio, '... s t -> (... s) t'), **res_stft_kwargs)
target_Y = torch.stft(rearrange(target, '... s t -> (... s) t'), **res_stft_kwargs)
recon_Y = torch.stft(rearrange(recon_audio, "... s t -> (... s) t"), **res_stft_kwargs)
target_Y = torch.stft(rearrange(target, "... s t -> (... s) t"), **res_stft_kwargs)
multi_stft_resolution_loss = multi_stft_resolution_loss + F.l1_loss(recon_Y, target_Y)

View File

@@ -1,28 +1,31 @@
# This code is modified from https://github.com/ZFTurbo/
import librosa
from tqdm import tqdm
import os
import torch
import warnings
import librosa
import numpy as np
import soundfile as sf
import torch
import torch.nn as nn
import yaml
import warnings
from tqdm import tqdm
warnings.filterwarnings("ignore")
class Roformer_Loader:
def get_config(self, config_path):
with open(config_path, 'r', encoding='utf-8') as f:
with open(config_path, "r", encoding="utf-8") as f:
# use fullloader to load tag !!python/tuple, code can be improved
config = yaml.load(f, Loader=yaml.FullLoader)
return config
def get_default_config(self):
default_config = None
if self.model_type == 'bs_roformer':
if self.model_type == "bs_roformer":
# Use model_bs_roformer_ep_368_sdr_12.9628.yaml and model_bs_roformer_ep_317_sdr_12.9755.yaml as default configuration files
# Other BS_Roformer models may not be compatible
# fmt: off
default_config = {
"audio": {"chunk_size": 352800, "sample_rate": 44100},
"model": {
@@ -51,9 +54,10 @@ class Roformer_Loader:
"multi_stft_normalized": False,
},
"training": {"instruments": ["vocals", "other"], "target_instrument": "vocals"},
"inference": {"batch_size": 2, "num_overlap": 2}
"inference": {"batch_size": 2, "num_overlap": 2},
}
elif self.model_type == 'mel_band_roformer':
# fmt: on
elif self.model_type == "mel_band_roformer":
# Use model_mel_band_roformer_ep_3005_sdr_11.4360.yaml as default configuration files
# Other Mel_Band_Roformer models may not be compatible
default_config = {
@@ -82,29 +86,30 @@ class Roformer_Loader:
"multi_stft_resolution_loss_weight": 1.0,
"multi_stft_resolutions_window_sizes": (4096, 2048, 1024, 512, 256),
"multi_stft_hop_size": 147,
"multi_stft_normalized": False
"multi_stft_normalized": False,
},
"training": {"instruments": ["vocals", "other"], "target_instrument": "vocals"},
"inference": {"batch_size": 2, "num_overlap": 2}
"inference": {"batch_size": 2, "num_overlap": 2},
}
return default_config
def get_model_from_config(self):
if self.model_type == 'bs_roformer':
if self.model_type == "bs_roformer":
from bs_roformer.bs_roformer import BSRoformer
model = BSRoformer(**dict(self.config["model"]))
elif self.model_type == 'mel_band_roformer':
elif self.model_type == "mel_band_roformer":
from bs_roformer.mel_band_roformer import MelBandRoformer
model = MelBandRoformer(**dict(self.config["model"]))
else:
print('Error: Unknown model: {}'.format(self.model_type))
print("Error: Unknown model: {}".format(self.model_type))
model = None
return model
def demix_track(self, model, mix, device):
C = self.config["audio"]["chunk_size"] # chunk_size
C = self.config["audio"]["chunk_size"] # chunk_size
N = self.config["inference"]["num_overlap"]
fade_size = C // 10
step = int(C // N)
@@ -116,7 +121,7 @@ class Roformer_Loader:
# Do pad from the beginning and end to account floating window results better
if length_init > 2 * border and (border > 0):
mix = nn.functional.pad(mix, (border, border), mode='reflect')
mix = nn.functional.pad(mix, (border, border), mode="reflect")
# Prepare windows arrays (do 1 time for speed up). This trick repairs click problems on the edges of segment
window_size = C
@@ -125,17 +130,17 @@ class Roformer_Loader:
window_start = torch.ones(window_size)
window_middle = torch.ones(window_size)
window_finish = torch.ones(window_size)
window_start[-fade_size:] *= fadeout # First audio chunk, no fadein
window_finish[:fade_size] *= fadein # Last audio chunk, no fadeout
window_start[-fade_size:] *= fadeout # First audio chunk, no fadein
window_finish[:fade_size] *= fadein # Last audio chunk, no fadeout
window_middle[-fade_size:] *= fadeout
window_middle[:fade_size] *= fadein
with torch.amp.autocast('cuda'):
with torch.amp.autocast("cuda"):
with torch.inference_mode():
if self.config["training"]["target_instrument"] is None:
req_shape = (len(self.config["training"]["instruments"]),) + tuple(mix.shape)
else:
req_shape = (1, ) + tuple(mix.shape)
req_shape = (1,) + tuple(mix.shape)
result = torch.zeros(req_shape, dtype=torch.float32)
counter = torch.zeros(req_shape, dtype=torch.float32)
@@ -143,15 +148,15 @@ class Roformer_Loader:
batch_data = []
batch_locations = []
while i < mix.shape[1]:
part = mix[:, i:i + C].to(device)
part = mix[:, i : i + C].to(device)
length = part.shape[-1]
if length < C:
if length > C // 2 + 1:
part = nn.functional.pad(input=part, pad=(0, C - length), mode='reflect')
part = nn.functional.pad(input=part, pad=(0, C - length), mode="reflect")
else:
part = nn.functional.pad(input=part, pad=(0, C - length, 0, 0), mode='constant', value=0)
part = nn.functional.pad(input=part, pad=(0, C - length, 0, 0), mode="constant", value=0)
if self.is_half:
part=part.half()
part = part.half()
batch_data.append(part)
batch_locations.append((i, length))
i += step
@@ -170,8 +175,8 @@ class Roformer_Loader:
for j in range(len(batch_locations)):
start, l = batch_locations[j]
result[..., start:start+l] += x[j][..., :l].cpu() * window[..., :l]
counter[..., start:start+l] += window[..., :l]
result[..., start : start + l] += x[j][..., :l].cpu() * window[..., :l]
counter[..., start : start + l] += window[..., :l]
batch_data = []
batch_locations = []
@@ -191,7 +196,6 @@ class Roformer_Loader:
else:
return {k: v for k, v in zip([self.config["training"]["target_instrument"]], estimated_sources)}
def run_folder(self, input, vocal_root, others_root, format):
self.model.eval()
path = input
@@ -200,20 +204,20 @@ class Roformer_Loader:
file_base_name = os.path.splitext(os.path.basename(path))[0]
sample_rate = 44100
if 'sample_rate' in self.config["audio"]:
sample_rate = self.config["audio"]['sample_rate']
if "sample_rate" in self.config["audio"]:
sample_rate = self.config["audio"]["sample_rate"]
try:
mix, sr = librosa.load(path, sr=sample_rate, mono=False)
except Exception as e:
print('Can read track: {}'.format(path))
print('Error message: {}'.format(str(e)))
print("Can read track: {}".format(path))
print("Error message: {}".format(str(e)))
return
# in case if model only supports mono tracks
isstereo = self.config["model"].get("stereo", True)
if not isstereo and len(mix.shape) != 1:
mix = np.mean(mix, axis=0) # if more than 2 channels, take mean
mix = np.mean(mix, axis=0) # if more than 2 channels, take mean
print("Warning: Track has more than 1 channels, but model is mono, taking mean of all channels.")
mix_orig = mix.copy()
@@ -226,7 +230,7 @@ class Roformer_Loader:
# other instruments are caculated by subtracting target instrument from mixture
target_instrument = self.config["training"]["target_instrument"]
other_instruments = [i for i in self.config["training"]["instruments"] if i != target_instrument]
other = mix_orig - res[target_instrument] # caculate other instruments
other = mix_orig - res[target_instrument] # caculate other instruments
path_vocal = "{}/{}_{}.wav".format(vocal_root, file_base_name, target_instrument)
path_other = "{}/{}_{}.wav".format(others_root, file_base_name, other_instruments[0])
@@ -237,11 +241,10 @@ class Roformer_Loader:
vocal_inst = self.config["training"]["instruments"][0]
path_vocal = "{}/{}_{}.wav".format(vocal_root, file_base_name, vocal_inst)
self.save_audio(path_vocal, res[vocal_inst].T, sr, format)
for other in self.config["training"]["instruments"][1:]: # save other instruments
for other in self.config["training"]["instruments"][1:]: # save other instruments
path_other = "{}/{}_{}.wav".format(others_root, file_base_name, other)
self.save_audio(path_other, res[other].T, sr, format)
def save_audio(self, path, data, sr, format):
# input path should be endwith '.wav'
if format in ["wav", "flac"]:
@@ -250,10 +253,11 @@ class Roformer_Loader:
sf.write(path, data, sr)
else:
sf.write(path, data, sr)
os.system("ffmpeg -i \"{}\" -vn \"{}\" -q:a 2 -y".format(path, path[:-3] + format))
try: os.remove(path)
except: pass
os.system('ffmpeg -i "{}" -vn "{}" -q:a 2 -y'.format(path, path[:-3] + format))
try:
os.remove(path)
except:
pass
def __init__(self, model_path, config_path, device, is_half):
self.device = device
@@ -270,7 +274,9 @@ class Roformer_Loader:
if not os.path.exists(config_path):
if self.model_type is None:
# if model_type is still None, raise an error
raise ValueError("Error: Unknown model type. If you are using a model without a configuration file, Ensure that your model name includes 'bs_roformer', 'bsroformer', 'mel_band_roformer', or 'melbandroformer'. Otherwise, you can manually place the model configuration file into 'tools/uvr5/uvr5w_weights' and ensure that the configuration file is named as '<model_name>.yaml' then try it again.")
raise ValueError(
"Error: Unknown model type. If you are using a model without a configuration file, Ensure that your model name includes 'bs_roformer', 'bsroformer', 'mel_band_roformer', or 'melbandroformer'. Otherwise, you can manually place the model configuration file into 'tools/uvr5/uvr5w_weights' and ensure that the configuration file is named as '<model_name>.yaml' then try it again."
)
self.config = self.get_default_config()
else:
# if there is a configuration file
@@ -289,12 +295,10 @@ class Roformer_Loader:
state_dict = torch.load(model_path, map_location="cpu")
model.load_state_dict(state_dict)
if(is_half==False):
if is_half == False:
self.model = model.to(device)
else:
self.model = model.half().to(device)
def _path_audio_(self, input, others_root, vocal_root, format, is_hp3=False):
self.run_folder(input, vocal_root, others_root, format)

View File

@@ -13,9 +13,7 @@ cpu = torch.device("cpu")
class ConvTDFNetTrim:
def __init__(
self, device, model_name, target_name, L, dim_f, dim_t, n_fft, hop=1024
):
def __init__(self, device, model_name, target_name, L, dim_f, dim_t, n_fft, hop=1024):
super(ConvTDFNetTrim, self).__init__()
self.dim_f = dim_f
@@ -24,17 +22,13 @@ class ConvTDFNetTrim:
self.hop = hop
self.n_bins = self.n_fft // 2 + 1
self.chunk_size = hop * (self.dim_t - 1)
self.window = torch.hann_window(window_length=self.n_fft, periodic=True).to(
device
)
self.window = torch.hann_window(window_length=self.n_fft, periodic=True).to(device)
self.target_name = target_name
self.blender = "blender" in model_name
self.dim_c = 4
out_c = self.dim_c * 4 if target_name == "*" else self.dim_c
self.freq_pad = torch.zeros(
[1, out_c, self.n_bins - self.dim_f, self.dim_t]
).to(device)
self.freq_pad = torch.zeros([1, out_c, self.n_bins - self.dim_f, self.dim_t]).to(device)
self.n = L // 2
@@ -50,28 +44,18 @@ class ConvTDFNetTrim:
)
x = torch.view_as_real(x)
x = x.permute([0, 3, 1, 2])
x = x.reshape([-1, 2, 2, self.n_bins, self.dim_t]).reshape(
[-1, self.dim_c, self.n_bins, self.dim_t]
)
x = x.reshape([-1, 2, 2, self.n_bins, self.dim_t]).reshape([-1, self.dim_c, self.n_bins, self.dim_t])
return x[:, :, : self.dim_f]
def istft(self, x, freq_pad=None):
freq_pad = (
self.freq_pad.repeat([x.shape[0], 1, 1, 1])
if freq_pad is None
else freq_pad
)
freq_pad = self.freq_pad.repeat([x.shape[0], 1, 1, 1]) if freq_pad is None else freq_pad
x = torch.cat([x, freq_pad], -2)
c = 4 * 2 if self.target_name == "*" else 2
x = x.reshape([-1, c, 2, self.n_bins, self.dim_t]).reshape(
[-1, 2, self.n_bins, self.dim_t]
)
x = x.reshape([-1, c, 2, self.n_bins, self.dim_t]).reshape([-1, 2, self.n_bins, self.dim_t])
x = x.permute([0, 2, 3, 1])
x = x.contiguous()
x = torch.view_as_complex(x)
x = torch.istft(
x, n_fft=self.n_fft, hop_length=self.hop, window=self.window, center=True
)
x = torch.istft(x, n_fft=self.n_fft, hop_length=self.hop, window=self.window, center=True)
return x.reshape([-1, c, self.chunk_size])
@@ -93,9 +77,7 @@ class Predictor:
logger.info(ort.get_available_providers())
self.args = args
self.model_ = get_models(
device=cpu, dim_f=args.dim_f, dim_t=args.dim_t, n_fft=args.n_fft
)
self.model_ = get_models(device=cpu, dim_f=args.dim_f, dim_t=args.dim_t, n_fft=args.n_fft)
self.model = ort.InferenceSession(
os.path.join(args.onnx, self.model_.target_name + ".onnx"),
providers=[
@@ -152,9 +134,7 @@ class Predictor:
trim = model.n_fft // 2
gen_size = model.chunk_size - 2 * trim
pad = gen_size - n_sample % gen_size
mix_p = np.concatenate(
(np.zeros((2, trim)), cmix, np.zeros((2, pad)), np.zeros((2, trim))), 1
)
mix_p = np.concatenate((np.zeros((2, trim)), cmix, np.zeros((2, pad)), np.zeros((2, trim))), 1)
mix_waves = []
i = 0
while i < n_sample + pad:
@@ -172,15 +152,8 @@ class Predictor:
)
tar_waves = model.istft(torch.tensor(spec_pred))
else:
tar_waves = model.istft(
torch.tensor(_ort.run(None, {"input": spek.cpu().numpy()})[0])
)
tar_signal = (
tar_waves[:, :, trim:-trim]
.transpose(0, 1)
.reshape(2, -1)
.numpy()[:, :-pad]
)
tar_waves = model.istft(torch.tensor(_ort.run(None, {"input": spek.cpu().numpy()})[0]))
tar_signal = tar_waves[:, :, trim:-trim].transpose(0, 1).reshape(2, -1).numpy()[:, :-pad]
start = 0 if mix == 0 else margin_size
end = None if mix == list(mixes.keys())[::-1][0] else -margin_size
@@ -207,9 +180,7 @@ class Predictor:
sources = self.demix(mix.T)
opt = sources[0].T
if format in ["wav", "flac"]:
sf.write(
"%s/%s_main_vocal.%s" % (vocal_root, basename, format), mix - opt, rate
)
sf.write("%s/%s_main_vocal.%s" % (vocal_root, basename, format), mix - opt, rate)
sf.write("%s/%s_others.%s" % (others_root, basename, format), opt, rate)
else:
path_vocal = "%s/%s_main_vocal.wav" % (vocal_root, basename)
@@ -219,18 +190,14 @@ class Predictor:
opt_path_vocal = path_vocal[:-4] + ".%s" % format
opt_path_other = path_other[:-4] + ".%s" % format
if os.path.exists(path_vocal):
os.system(
"ffmpeg -i '%s' -vn '%s' -q:a 2 -y" % (path_vocal, opt_path_vocal)
)
os.system("ffmpeg -i '%s' -vn '%s' -q:a 2 -y" % (path_vocal, opt_path_vocal))
if os.path.exists(opt_path_vocal):
try:
os.remove(path_vocal)
except:
pass
if os.path.exists(path_other):
os.system(
"ffmpeg -i '%s' -vn '%s' -q:a 2 -y" % (path_other, opt_path_other)
)
os.system("ffmpeg -i '%s' -vn '%s' -q:a 2 -y" % (path_other, opt_path_other))
if os.path.exists(opt_path_other):
try:
os.remove(path_other)
@@ -240,7 +207,7 @@ class Predictor:
class MDXNetDereverb:
def __init__(self, chunks):
self.onnx = "%s/uvr5_weights/onnx_dereverb_By_FoxJoy"%os.path.dirname(os.path.abspath(__file__))
self.onnx = "%s/uvr5_weights/onnx_dereverb_By_FoxJoy" % os.path.dirname(os.path.abspath(__file__))
self.shifts = 10 # 'Predict with randomised equivariant stabilisation'
self.mixing = "min_mag" # ['default','min_mag','max_mag']
self.chunks = chunks

View File

@@ -1,6 +1,8 @@
import os,sys
import os
parent_directory = os.path.dirname(os.path.abspath(__file__))
import logging,pdb
import logging
logger = logging.getLogger(__name__)
import librosa
@@ -27,7 +29,7 @@ class AudioPre:
"agg": agg,
"high_end_process": "mirroring",
}
mp = ModelParameters("%s/lib/lib_v5/modelparams/4band_v2.json"%parent_directory)
mp = ModelParameters("%s/lib/lib_v5/modelparams/4band_v2.json" % parent_directory)
model = Nets.CascadedASPPNet(mp.param["bins"] * 2)
cpk = torch.load(model_path, map_location="cpu")
model.load_state_dict(cpk)
@@ -40,9 +42,7 @@ class AudioPre:
self.mp = mp
self.model = model
def _path_audio_(
self, music_file, ins_root=None, vocal_root=None, format="flac", is_hp3=False
):
def _path_audio_(self, music_file, ins_root=None, vocal_root=None, format="flac", is_hp3=False):
if ins_root is None and vocal_root is None:
return "No save root."
name = os.path.basename(music_file)
@@ -61,19 +61,19 @@ class AudioPre:
_,
) = librosa.core.load( # 理论上librosa读取可能对某些音频有bug应该上ffmpeg读取但是太麻烦了弃坑
music_file,
sr = bp["sr"],
mono = False,
dtype = np.float32,
res_type = bp["res_type"],
sr=bp["sr"],
mono=False,
dtype=np.float32,
res_type=bp["res_type"],
)
if X_wave[d].ndim == 1:
X_wave[d] = np.asfortranarray([X_wave[d], X_wave[d]])
else: # lower bands
X_wave[d] = librosa.core.resample(
X_wave[d + 1],
orig_sr = self.mp.param["band"][d + 1]["sr"],
target_sr = bp["sr"],
res_type = bp["res_type"],
orig_sr=self.mp.param["band"][d + 1]["sr"],
target_sr=bp["sr"],
res_type=bp["res_type"],
)
# Stft of wave source
X_spec_s[d] = spec_utils.wave_to_spectrogram_mt(
@@ -89,9 +89,7 @@ class AudioPre:
input_high_end_h = (bp["n_fft"] // 2 - bp["crop_stop"]) + (
self.mp.param["pre_filter_stop"] - self.mp.param["pre_filter_start"]
)
input_high_end = X_spec_s[d][
:, bp["n_fft"] // 2 - input_high_end_h : bp["n_fft"] // 2, :
]
input_high_end = X_spec_s[d][:, bp["n_fft"] // 2 - input_high_end_h : bp["n_fft"] // 2, :]
X_spec_m = spec_utils.combine_spectrograms(X_spec_s, self.mp)
aggresive_set = float(self.data["agg"] / 100)
@@ -100,9 +98,7 @@ class AudioPre:
"split_bin": self.mp.param["band"][1]["crop_stop"],
}
with torch.no_grad():
pred, X_mag, X_phase = inference(
X_spec_m, self.device, self.model, aggressiveness, self.data
)
pred, X_mag, X_phase = inference(X_spec_m, self.device, self.model, aggressiveness, self.data)
# Postprocess
if self.data["postprocess"]:
pred_inv = np.clip(X_mag - pred, 0, np.inf)
@@ -111,13 +107,11 @@ class AudioPre:
v_spec_m = X_spec_m - y_spec_m
if is_hp3 == True:
ins_root,vocal_root = vocal_root,ins_root
ins_root, vocal_root = vocal_root, ins_root
if ins_root is not None:
if self.data["high_end_process"].startswith("mirroring"):
input_high_end_ = spec_utils.mirroring(
self.data["high_end_process"], y_spec_m, input_high_end, self.mp
)
input_high_end_ = spec_utils.mirroring(self.data["high_end_process"], y_spec_m, input_high_end, self.mp)
wav_instrument = spec_utils.cmb_spectrogram_to_wave(
y_spec_m, self.mp, input_high_end_h, input_high_end_
)
@@ -138,9 +132,7 @@ class AudioPre:
self.mp.param["sr"],
) #
else:
path = os.path.join(
ins_root, head + "{}_{}.wav".format(name, self.data["agg"])
)
path = os.path.join(ins_root, head + "{}_{}.wav".format(name, self.data["agg"]))
sf.write(
path,
(np.array(wav_instrument) * 32768).astype("int16"),
@@ -160,12 +152,8 @@ class AudioPre:
else:
head = "vocal_"
if self.data["high_end_process"].startswith("mirroring"):
input_high_end_ = spec_utils.mirroring(
self.data["high_end_process"], v_spec_m, input_high_end, self.mp
)
wav_vocals = spec_utils.cmb_spectrogram_to_wave(
v_spec_m, self.mp, input_high_end_h, input_high_end_
)
input_high_end_ = spec_utils.mirroring(self.data["high_end_process"], v_spec_m, input_high_end, self.mp)
wav_vocals = spec_utils.cmb_spectrogram_to_wave(v_spec_m, self.mp, input_high_end_h, input_high_end_)
else:
wav_vocals = spec_utils.cmb_spectrogram_to_wave(v_spec_m, self.mp)
logger.info("%s vocals done" % name)
@@ -179,9 +167,7 @@ class AudioPre:
self.mp.param["sr"],
)
else:
path = os.path.join(
vocal_root, head + "{}_{}.wav".format(name, self.data["agg"])
)
path = os.path.join(vocal_root, head + "{}_{}.wav".format(name, self.data["agg"]))
sf.write(
path,
(np.array(wav_vocals) * 32768).astype("int16"),
@@ -210,7 +196,7 @@ class AudioPreDeEcho:
"agg": agg,
"high_end_process": "mirroring",
}
mp = ModelParameters("%s/lib/lib_v5/modelparams/4band_v3.json"%parent_directory)
mp = ModelParameters("%s/lib/lib_v5/modelparams/4band_v3.json" % parent_directory)
nout = 64 if "DeReverb" in model_path else 48
model = CascadedNet(mp.param["bins"] * 2, nout)
cpk = torch.load(model_path, map_location="cpu")
@@ -245,19 +231,19 @@ class AudioPreDeEcho:
_,
) = librosa.core.load( # 理论上librosa读取可能对某些音频有bug应该上ffmpeg读取但是太麻烦了弃坑
music_file,
sr = bp["sr"],
mono = False,
dtype = np.float32,
res_type = bp["res_type"],
sr=bp["sr"],
mono=False,
dtype=np.float32,
res_type=bp["res_type"],
)
if X_wave[d].ndim == 1:
X_wave[d] = np.asfortranarray([X_wave[d], X_wave[d]])
else: # lower bands
X_wave[d] = librosa.core.resample(
X_wave[d + 1],
orig_sr = self.mp.param["band"][d + 1]["sr"],
target_sr = bp["sr"],
res_type = bp["res_type"],
orig_sr=self.mp.param["band"][d + 1]["sr"],
target_sr=bp["sr"],
res_type=bp["res_type"],
)
# Stft of wave source
X_spec_s[d] = spec_utils.wave_to_spectrogram_mt(
@@ -273,9 +259,7 @@ class AudioPreDeEcho:
input_high_end_h = (bp["n_fft"] // 2 - bp["crop_stop"]) + (
self.mp.param["pre_filter_stop"] - self.mp.param["pre_filter_start"]
)
input_high_end = X_spec_s[d][
:, bp["n_fft"] // 2 - input_high_end_h : bp["n_fft"] // 2, :
]
input_high_end = X_spec_s[d][:, bp["n_fft"] // 2 - input_high_end_h : bp["n_fft"] // 2, :]
X_spec_m = spec_utils.combine_spectrograms(X_spec_s, self.mp)
aggresive_set = float(self.data["agg"] / 100)
@@ -284,9 +268,7 @@ class AudioPreDeEcho:
"split_bin": self.mp.param["band"][1]["crop_stop"],
}
with torch.no_grad():
pred, X_mag, X_phase = inference(
X_spec_m, self.device, self.model, aggressiveness, self.data
)
pred, X_mag, X_phase = inference(X_spec_m, self.device, self.model, aggressiveness, self.data)
# Postprocess
if self.data["postprocess"]:
pred_inv = np.clip(X_mag - pred, 0, np.inf)
@@ -296,9 +278,7 @@ class AudioPreDeEcho:
if ins_root is not None:
if self.data["high_end_process"].startswith("mirroring"):
input_high_end_ = spec_utils.mirroring(
self.data["high_end_process"], y_spec_m, input_high_end, self.mp
)
input_high_end_ = spec_utils.mirroring(self.data["high_end_process"], y_spec_m, input_high_end, self.mp)
wav_instrument = spec_utils.cmb_spectrogram_to_wave(
y_spec_m, self.mp, input_high_end_h, input_high_end_
)
@@ -315,9 +295,7 @@ class AudioPreDeEcho:
self.mp.param["sr"],
) #
else:
path = os.path.join(
ins_root, "vocal_{}_{}.wav".format(name, self.data["agg"])
)
path = os.path.join(ins_root, "vocal_{}_{}.wav".format(name, self.data["agg"]))
sf.write(
path,
(np.array(wav_instrument) * 32768).astype("int16"),
@@ -333,12 +311,8 @@ class AudioPreDeEcho:
pass
if vocal_root is not None:
if self.data["high_end_process"].startswith("mirroring"):
input_high_end_ = spec_utils.mirroring(
self.data["high_end_process"], v_spec_m, input_high_end, self.mp
)
wav_vocals = spec_utils.cmb_spectrogram_to_wave(
v_spec_m, self.mp, input_high_end_h, input_high_end_
)
input_high_end_ = spec_utils.mirroring(self.data["high_end_process"], v_spec_m, input_high_end, self.mp)
wav_vocals = spec_utils.cmb_spectrogram_to_wave(v_spec_m, self.mp, input_high_end_h, input_high_end_)
else:
wav_vocals = spec_utils.cmb_spectrogram_to_wave(v_spec_m, self.mp)
logger.info("%s vocals done" % name)
@@ -352,9 +326,7 @@ class AudioPreDeEcho:
self.mp.param["sr"],
)
else:
path = os.path.join(
vocal_root, "instrument_{}_{}.wav".format(name, self.data["agg"])
)
path = os.path.join(vocal_root, "instrument_{}_{}.wav".format(name, self.data["agg"]))
sf.write(
path,
(np.array(wav_vocals) * 32768).astype("int16"),

View File

@@ -1,13 +1,14 @@
import os
import traceback,gradio as gr
import traceback
import gradio as gr
import logging
from tools.i18n.i18n import I18nAuto
from tools.my_utils import clean_path
i18n = I18nAuto()
logger = logging.getLogger(__name__)
import librosa,ffmpeg
import soundfile as sf
import ffmpeg
import torch
import sys
from mdxnet import MDXNetDereverb
@@ -16,8 +17,10 @@ from bsroformer import Roformer_Loader
try:
import gradio.analytics as analytics
analytics.version_check = lambda:None
except:...
analytics.version_check = lambda: None
except:
...
weight_uvr5_root = "tools/uvr5/uvr5_weights"
uvr5_names = []
@@ -25,21 +28,24 @@ for name in os.listdir(weight_uvr5_root):
if name.endswith(".pth") or name.endswith(".ckpt") or "onnx" in name:
uvr5_names.append(name.replace(".pth", "").replace(".ckpt", ""))
device=sys.argv[1]
is_half=eval(sys.argv[2])
webui_port_uvr5=int(sys.argv[3])
is_share=eval(sys.argv[4])
device = sys.argv[1]
is_half = eval(sys.argv[2])
webui_port_uvr5 = int(sys.argv[3])
is_share = eval(sys.argv[4])
def html_left(text, label='p'):
def html_left(text, label="p"):
return f"""<div style="text-align: left; margin: 0; padding: 0;">
<{label} style="margin: 0; padding: 0;">{text}</{label}>
</div>"""
def html_center(text, label='p'):
def html_center(text, label="p"):
return f"""<div style="text-align: center; margin: 100; padding: 50;">
<{label} style="margin: 0; padding: 0;">{text}</{label}>
</div>"""
def uvr(model_name, inp_root, save_root_vocal, paths, save_root_ins, agg, format0):
infos = []
try:
@@ -52,13 +58,15 @@ def uvr(model_name, inp_root, save_root_vocal, paths, save_root_ins, agg, format
elif "roformer" in model_name.lower():
func = Roformer_Loader
pre_fun = func(
model_path = os.path.join(weight_uvr5_root, model_name + ".ckpt"),
config_path = os.path.join(weight_uvr5_root, model_name + ".yaml"),
device = device,
is_half=is_half
model_path=os.path.join(weight_uvr5_root, model_name + ".ckpt"),
config_path=os.path.join(weight_uvr5_root, model_name + ".yaml"),
device=device,
is_half=is_half,
)
if not os.path.exists(os.path.join(weight_uvr5_root, model_name + ".yaml")):
infos.append("Warning: You are using a model without a configuration file. The program will automatically use the default configuration file. However, the default configuration file cannot guarantee that all models will run successfully. You can manually place the model configuration file into 'tools/uvr5/uvr5w_weights' and ensure that the configuration file is named as '<model_name>.yaml' then try it again. (For example, the configuration file corresponding to the model 'bs_roformer_ep_368_sdr_12.9628.ckpt' should be 'bs_roformer_ep_368_sdr_12.9628.yaml'.) Or you can just ignore this warning.")
infos.append(
"Warning: You are using a model without a configuration file. The program will automatically use the default configuration file. However, the default configuration file cannot guarantee that all models will run successfully. You can manually place the model configuration file into 'tools/uvr5/uvr5w_weights' and ensure that the configuration file is named as '<model_name>.yaml' then try it again. (For example, the configuration file corresponding to the model 'bs_roformer_ep_368_sdr_12.9628.ckpt' should be 'bs_roformer_ep_368_sdr_12.9628.yaml'.) Or you can just ignore this warning."
)
yield "\n".join(infos)
else:
func = AudioPre if "DeEcho" not in model_name else AudioPreDeEcho
@@ -74,19 +82,15 @@ def uvr(model_name, inp_root, save_root_vocal, paths, save_root_ins, agg, format
paths = [path.name for path in paths]
for path in paths:
inp_path = os.path.join(inp_root, path)
if(os.path.isfile(inp_path)==False):continue
if os.path.isfile(inp_path) == False:
continue
need_reformat = 1
done = 0
try:
info = ffmpeg.probe(inp_path, cmd="ffprobe")
if (
info["streams"][0]["channels"] == 2
and info["streams"][0]["sample_rate"] == "44100"
):
if info["streams"][0]["channels"] == 2 and info["streams"][0]["sample_rate"] == "44100":
need_reformat = 0
pre_fun._path_audio_(
inp_path, save_root_ins, save_root_vocal, format0,is_hp3
)
pre_fun._path_audio_(inp_path, save_root_ins, save_root_vocal, format0, is_hp3)
done = 1
except:
need_reformat = 1
@@ -96,21 +100,15 @@ def uvr(model_name, inp_root, save_root_vocal, paths, save_root_ins, agg, format
os.path.join(os.environ["TEMP"]),
os.path.basename(inp_path),
)
os.system(
f'ffmpeg -i "{inp_path}" -vn -acodec pcm_s16le -ac 2 -ar 44100 "{tmp_path}" -y'
)
os.system(f'ffmpeg -i "{inp_path}" -vn -acodec pcm_s16le -ac 2 -ar 44100 "{tmp_path}" -y')
inp_path = tmp_path
try:
if done == 0:
pre_fun._path_audio_(
inp_path, save_root_ins, save_root_vocal, format0,is_hp3
)
pre_fun._path_audio_(inp_path, save_root_ins, save_root_vocal, format0, is_hp3)
infos.append("%s->Success" % (os.path.basename(inp_path)))
yield "\n".join(infos)
except:
infos.append(
"%s->%s" % (os.path.basename(inp_path), traceback.format_exc())
)
infos.append("%s->%s" % (os.path.basename(inp_path), traceback.format_exc()))
yield "\n".join(infos)
except:
infos.append(traceback.format_exc())
@@ -130,80 +128,98 @@ def uvr(model_name, inp_root, save_root_vocal, paths, save_root_ins, agg, format
torch.cuda.empty_cache()
yield "\n".join(infos)
with gr.Blocks(title="UVR5 WebUI") as app:
gr.Markdown(
value=
i18n("本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责.") + "<br>" + i18n("如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录LICENSE.")
value=i18n("本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责.")
+ "<br>"
+ i18n("如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录LICENSE.")
)
with gr.Group():
gr.Markdown(html_center(i18n("伴奏人声分离&去混响&去回声"),'h2'))
gr.Markdown(html_center(i18n("伴奏人声分离&去混响&去回声"), "h2"))
with gr.Group():
gr.Markdown(
value=html_left(i18n("人声伴奏分离批量处理, 使用UVR5模型。") + "<br>" + \
i18n("合格的文件夹路径格式举例: E:\\codes\\py39\\vits_vc_gpu\\白鹭霜华测试样例(去文件管理器地址栏拷就行了)。")+ "<br>" + \
i18n("模型分为三类:") + "<br>" + \
i18n("1、保留人声不带和声的音频选这个对主人声保留比HP5更好。内置HP2和HP3两个模型HP3可能轻微漏伴奏但对主人声保留比HP2稍微好一丁点") + "<br>" + \
i18n("2、仅保留主人声带和声的音频选这个对主人声可能有削弱。内置HP5一个模型") + "<br>" + \
i18n("3、去混响、去延迟模型by FoxJoy") + "<br>" + \
i18n("(1)MDX-Net(onnx_dereverb):对于双通道混响是最好的选择,不能去除单通道混响;") + "<br>&emsp;" + \
i18n("(234)DeEcho:去除延迟效果。Aggressive比Normal去除得更彻底DeReverb额外去除混响可去除单声道混响但是对高频重的板式混响去不干净。") + "<br>" + \
i18n("去混响/去延迟,附:") + "<br>" + \
i18n("1、DeEcho-DeReverb模型的耗时是另外2个DeEcho模型的接近2倍") + "<br>" + \
i18n("2、MDX-Net-Dereverb模型挺慢的") + "<br>" + \
i18n("3、个人推荐的最干净的配置是先MDX-Net再DeEcho-Aggressive。"),'h4')
)
with gr.Row():
with gr.Column():
model_choose = gr.Dropdown(label=i18n("模型"), choices=uvr5_names)
dir_wav_input = gr.Textbox(
label=i18n("输入待处理音频文件夹路径"),
placeholder="C:\\Users\\Desktop\\todo-songs",
)
wav_inputs = gr.File(
file_count="multiple", label=i18n("也可批量输入音频文件, 二选一, 优先读文件夹")
)
with gr.Column():
agg = gr.Slider(
minimum=0,
maximum=20,
step=1,
label=i18n("人声提取激进程度"),
value=10,
interactive=True,
visible=False, # 先不开放调整
)
opt_vocal_root = gr.Textbox(
label=i18n("指定输出主人声文件夹"), value="output/uvr5_opt"
)
opt_ins_root = gr.Textbox(
label=i18n("指定输出非主人声文件夹"), value="output/uvr5_opt"
)
format0 = gr.Radio(
label=i18n("导出文件格式"),
choices=["wav", "flac", "mp3", "m4a"],
value="flac",
interactive=True,
)
with gr.Column():
with gr.Row():
but2 = gr.Button(i18n("转换"), variant="primary")
with gr.Row():
vc_output4 = gr.Textbox(label=i18n("输出信息"),lines=3)
but2.click(
uvr,
[
model_choose,
dir_wav_input,
opt_vocal_root,
wav_inputs,
opt_ins_root,
agg,
format0,
],
[vc_output4],
api_name="uvr_convert",
gr.Markdown(
value=html_left(
i18n("人声伴奏分离批量处理, 使用UVR5模型。")
+ "<br>"
+ i18n(
"合格的文件夹路径格式举例: E:\\codes\\py39\\vits_vc_gpu\\白鹭霜华测试样例(去文件管理器地址栏拷就行了)。"
)
app.queue().launch(#concurrency_count=511, max_size=1022
+ "<br>"
+ i18n("模型分为三类:")
+ "<br>"
+ i18n(
"1、保留人声不带和声的音频选这个对主人声保留比HP5更好。内置HP2和HP3两个模型HP3可能轻微漏伴奏但对主人声保留比HP2稍微好一丁点"
)
+ "<br>"
+ i18n("2、仅保留主人声带和声的音频选这个对主人声可能有削弱。内置HP5一个模型")
+ "<br>"
+ i18n("3、去混响、去延迟模型by FoxJoy")
+ "<br>"
+ i18n("(1)MDX-Net(onnx_dereverb):对于双通道混响是最好的选择,不能去除单通道混响;")
+ "<br>&emsp;"
+ i18n(
"(234)DeEcho:去除延迟效果。Aggressive比Normal去除得更彻底DeReverb额外去除混响可去除单声道混响但是对高频重的板式混响去不干净。"
)
+ "<br>"
+ i18n("去混响/去延迟,附:")
+ "<br>"
+ i18n("1、DeEcho-DeReverb模型的耗时是另外2个DeEcho模型的接近2倍")
+ "<br>"
+ i18n("2、MDX-Net-Dereverb模型挺慢的")
+ "<br>"
+ i18n("3、个人推荐的最干净的配置是先MDX-Net再DeEcho-Aggressive。"),
"h4",
)
)
with gr.Row():
with gr.Column():
model_choose = gr.Dropdown(label=i18n("模型"), choices=uvr5_names)
dir_wav_input = gr.Textbox(
label=i18n("输入待处理音频文件夹路径"),
placeholder="C:\\Users\\Desktop\\todo-songs",
)
wav_inputs = gr.File(
file_count="multiple", label=i18n("也可批量输入音频文件, 二选一, 优先读文件夹")
)
with gr.Column():
agg = gr.Slider(
minimum=0,
maximum=20,
step=1,
label=i18n("人声提取激进程度"),
value=10,
interactive=True,
visible=False, # 先不开放调整
)
opt_vocal_root = gr.Textbox(label=i18n("指定输出主人声文件夹"), value="output/uvr5_opt")
opt_ins_root = gr.Textbox(label=i18n("指定输出非主人声文件夹"), value="output/uvr5_opt")
format0 = gr.Radio(
label=i18n("导出文件格式"),
choices=["wav", "flac", "mp3", "m4a"],
value="flac",
interactive=True,
)
with gr.Column():
with gr.Row():
but2 = gr.Button(i18n("转换"), variant="primary")
with gr.Row():
vc_output4 = gr.Textbox(label=i18n("输出信息"), lines=3)
but2.click(
uvr,
[
model_choose,
dir_wav_input,
opt_vocal_root,
wav_inputs,
opt_ins_root,
agg,
format0,
],
[vc_output4],
api_name="uvr_convert",
)
app.queue().launch( # concurrency_count=511, max_size=1022
server_name="0.0.0.0",
inbrowser=True,
share=is_share,