Refactor: Format Code with Ruff and Update Deprecated G2PW Link (#2255)
* ruff check --fix
* ruff format --line-length 120 --target-version py39
* Change the link for G2PW Model
* update pytorch version and colab
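The reformatting throughout this diff is the mechanical output of the two Ruff commands named above. As a rough sketch, assuming Ruff is installed (e.g. pip install ruff) and the commands are run from the repository root — whether the project also pins these settings in pyproject.toml rather than passing them as CLI flags is an assumption:

    # Lint with autofixes (drops unused imports, etc.), then reformat in place:
    ruff check --fix
    ruff format --line-length 120 --target-version py39

    # Hypothetical pyproject.toml equivalent of the same flags:
    # [tool.ruff]
    # line-length = 120
    # target-version = "py39"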
@@ -5,24 +5,31 @@ import torchaudio
 import torch.utils.data
 import torchaudio.functional as aF


 def amp_pha_stft(audio, n_fft, hop_size, win_size, center=True):
     hann_window = torch.hann_window(win_size).to(audio.device)
-    stft_spec = torch.stft(audio, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window,
-                           center=center, pad_mode='reflect', normalized=False, return_complex=True)
-    log_amp = torch.log(torch.abs(stft_spec)+1e-4)
+    stft_spec = torch.stft(
+        audio,
+        n_fft,
+        hop_length=hop_size,
+        win_length=win_size,
+        window=hann_window,
+        center=center,
+        pad_mode="reflect",
+        normalized=False,
+        return_complex=True,
+    )
+    log_amp = torch.log(torch.abs(stft_spec) + 1e-4)
     pha = torch.angle(stft_spec)

-    com = torch.stack((torch.exp(log_amp)*torch.cos(pha),
-                       torch.exp(log_amp)*torch.sin(pha)), dim=-1)
+    com = torch.stack((torch.exp(log_amp) * torch.cos(pha), torch.exp(log_amp) * torch.sin(pha)), dim=-1)

     return log_amp, pha, com


 def amp_pha_istft(log_amp, pha, n_fft, hop_size, win_size, center=True):
-
     amp = torch.exp(log_amp)
-    com = torch.complex(amp*torch.cos(pha), amp*torch.sin(pha))
+    com = torch.complex(amp * torch.cos(pha), amp * torch.sin(pha))
     hann_window = torch.hann_window(win_size).to(com.device)
     audio = torch.istft(com, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window, center=center)

@@ -30,18 +37,28 @@ def amp_pha_istft(log_amp, pha, n_fft, hop_size, win_size, center=True):


 def get_dataset_filelist(a):
-    with open(a.input_training_file, 'r', encoding='utf-8') as fi:
-        training_indexes = [x.split('|')[0] for x in fi.read().split('\n') if len(x) > 0]
+    with open(a.input_training_file, "r", encoding="utf-8") as fi:
+        training_indexes = [x.split("|")[0] for x in fi.read().split("\n") if len(x) > 0]

-    with open(a.input_validation_file, 'r', encoding='utf-8') as fi:
-        validation_indexes = [x.split('|')[0] for x in fi.read().split('\n') if len(x) > 0]
+    with open(a.input_validation_file, "r", encoding="utf-8") as fi:
+        validation_indexes = [x.split("|")[0] for x in fi.read().split("\n") if len(x) > 0]

     return training_indexes, validation_indexes


 class Dataset(torch.utils.data.Dataset):
-    def __init__(self, training_indexes, wavs_dir, segment_size, hr_sampling_rate, lr_sampling_rate,
-                 split=True, shuffle=True, n_cache_reuse=1, device=None):
+    def __init__(
+        self,
+        training_indexes,
+        wavs_dir,
+        segment_size,
+        hr_sampling_rate,
+        lr_sampling_rate,
+        split=True,
+        shuffle=True,
+        n_cache_reuse=1,
+        device=None,
+    ):
         self.audio_indexes = training_indexes
         random.seed(1234)
         if shuffle:
@@ -59,7 +76,7 @@ class Dataset(torch.utils.data.Dataset):
     def __getitem__(self, index):
         filename = self.audio_indexes[index]
         if self._cache_ref_count == 0:
-            audio, orig_sampling_rate = torchaudio.load(os.path.join(self.wavs_dir, filename + '.wav'))
+            audio, orig_sampling_rate = torchaudio.load(os.path.join(self.wavs_dir, filename + ".wav"))
             self.cached_wav = audio
             self._cache_ref_count = self.n_cache_reuse
         else:
@@ -79,14 +96,13 @@ class Dataset(torch.utils.data.Dataset):
         if audio_hr.size(1) >= self.segment_size:
             max_audio_start = audio_hr.size(1) - self.segment_size
             audio_start = random.randint(0, max_audio_start)
-            audio_hr = audio_hr[:, audio_start: audio_start+self.segment_size]
-            audio_lr = audio_lr[:, audio_start: audio_start+self.segment_size]
+            audio_hr = audio_hr[:, audio_start : audio_start + self.segment_size]
+            audio_lr = audio_lr[:, audio_start : audio_start + self.segment_size]
         else:
-            audio_hr = torch.nn.functional.pad(audio_hr, (0, self.segment_size - audio_hr.size(1)), 'constant')
-            audio_lr = torch.nn.functional.pad(audio_lr, (0, self.segment_size - audio_lr.size(1)), 'constant')
+            audio_hr = torch.nn.functional.pad(audio_hr, (0, self.segment_size - audio_hr.size(1)), "constant")
+            audio_lr = torch.nn.functional.pad(audio_lr, (0, self.segment_size - audio_lr.size(1)), "constant")

         return (audio_hr.squeeze(), audio_lr.squeeze())

     def __len__(self):
-
         return len(self.audio_indexes)
@@ -1,20 +1,26 @@
 import torch
 import torch.nn.functional as F
 import torch.nn as nn
-from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
+from torch.nn.utils import weight_norm, spectral_norm


 # from utils import init_weights, get_padding
 def get_padding(kernel_size, dilation=1):
-    return int((kernel_size*dilation - dilation)/2)
+    return int((kernel_size * dilation - dilation) / 2)


 def init_weights(m, mean=0.0, std=0.01):
     classname = m.__class__.__name__
     if classname.find("Conv") != -1:
         m.weight.data.normal_(mean, std)


 import numpy as np
 from typing import Tuple, List

 LRELU_SLOPE = 0.1


 class ConvNeXtBlock(nn.Module):
     """ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal.

@@ -30,24 +36,24 @@ class ConvNeXtBlock(nn.Module):
     def __init__(
         self,
         dim: int,
-        layer_scale_init_value= None,
-        adanorm_num_embeddings = None,
+        layer_scale_init_value=None,
+        adanorm_num_embeddings=None,
     ):
         super().__init__()
         self.dwconv = nn.Conv1d(dim, dim, kernel_size=7, padding=3, groups=dim)  # depthwise conv
         self.adanorm = adanorm_num_embeddings is not None

         self.norm = nn.LayerNorm(dim, eps=1e-6)
-        self.pwconv1 = nn.Linear(dim, dim*3)  # pointwise/1x1 convs, implemented with linear layers
+        self.pwconv1 = nn.Linear(dim, dim * 3)  # pointwise/1x1 convs, implemented with linear layers
         self.act = nn.GELU()
-        self.pwconv2 = nn.Linear(dim*3, dim)
+        self.pwconv2 = nn.Linear(dim * 3, dim)
         self.gamma = (
             nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
             if layer_scale_init_value > 0
             else None
         )

-    def forward(self, x, cond_embedding_id = None) :
+    def forward(self, x, cond_embedding_id=None):
         residual = x
         x = self.dwconv(x)
         x = x.transpose(1, 2)  # (B, C, T) -> (B, T, C)
@@ -72,11 +78,11 @@ class APNet_BWE_Model(torch.nn.Module):
         super(APNet_BWE_Model, self).__init__()
         self.h = h
         self.adanorm_num_embeddings = None
-        layer_scale_init_value = 1 / h.ConvNeXt_layers
+        layer_scale_init_value = 1 / h.ConvNeXt_layers

-        self.conv_pre_mag = nn.Conv1d(h.n_fft//2+1, h.ConvNeXt_channels, 7, 1, padding=get_padding(7, 1))
+        self.conv_pre_mag = nn.Conv1d(h.n_fft // 2 + 1, h.ConvNeXt_channels, 7, 1, padding=get_padding(7, 1))
         self.norm_pre_mag = nn.LayerNorm(h.ConvNeXt_channels, eps=1e-6)
-        self.conv_pre_pha = nn.Conv1d(h.n_fft//2+1, h.ConvNeXt_channels, 7, 1, padding=get_padding(7, 1))
+        self.conv_pre_pha = nn.Conv1d(h.n_fft // 2 + 1, h.ConvNeXt_channels, 7, 1, padding=get_padding(7, 1))
         self.norm_pre_pha = nn.LayerNorm(h.ConvNeXt_channels, eps=1e-6)

         self.convnext_mag = nn.ModuleList(
@@ -104,9 +110,9 @@ class APNet_BWE_Model(torch.nn.Module):
         self.norm_post_mag = nn.LayerNorm(h.ConvNeXt_channels, eps=1e-6)
         self.norm_post_pha = nn.LayerNorm(h.ConvNeXt_channels, eps=1e-6)
         self.apply(self._init_weights)
-        self.linear_post_mag = nn.Linear(h.ConvNeXt_channels, h.n_fft//2+1)
-        self.linear_post_pha_r = nn.Linear(h.ConvNeXt_channels, h.n_fft//2+1)
-        self.linear_post_pha_i = nn.Linear(h.ConvNeXt_channels, h.n_fft//2+1)
+        self.linear_post_mag = nn.Linear(h.ConvNeXt_channels, h.n_fft // 2 + 1)
+        self.linear_post_pha_r = nn.Linear(h.ConvNeXt_channels, h.n_fft // 2 + 1)
+        self.linear_post_pha_i = nn.Linear(h.ConvNeXt_channels, h.n_fft // 2 + 1)

     def _init_weights(self, m):
         if isinstance(m, (nn.Conv1d, nn.Linear)):
@@ -114,7 +120,6 @@ class APNet_BWE_Model(torch.nn.Module):
             nn.init.constant_(m.bias, 0)

     def forward(self, mag_nb, pha_nb):
-
         x_mag = self.conv_pre_mag(mag_nb)
         x_pha = self.conv_pre_pha(pha_nb)
         x_mag = self.norm_pre_mag(x_mag.transpose(1, 2)).transpose(1, 2)
@@ -134,11 +139,9 @@ class APNet_BWE_Model(torch.nn.Module):
         x_pha_i = self.linear_post_pha_i(x_pha)
         pha_wb = torch.atan2(x_pha_i, x_pha_r).transpose(1, 2)

-        com_wb = torch.stack((torch.exp(mag_wb)*torch.cos(pha_wb),
-                              torch.exp(mag_wb)*torch.sin(pha_wb)), dim=-1)
-
-        return mag_wb, pha_wb, com_wb
+        com_wb = torch.stack((torch.exp(mag_wb) * torch.cos(pha_wb), torch.exp(mag_wb) * torch.sin(pha_wb)), dim=-1)
+
+        return mag_wb, pha_wb, com_wb


 class DiscriminatorP(torch.nn.Module):
@@ -146,13 +149,15 @@ class DiscriminatorP(torch.nn.Module):
         super(DiscriminatorP, self).__init__()
         self.period = period
         norm_f = weight_norm if use_spectral_norm == False else spectral_norm
-        self.convs = nn.ModuleList([
-            norm_f(nn.Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
-            norm_f(nn.Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
-            norm_f(nn.Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
-            norm_f(nn.Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
-            norm_f(nn.Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),
-        ])
+        self.convs = nn.ModuleList(
+            [
+                norm_f(nn.Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
+                norm_f(nn.Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
+                norm_f(nn.Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
+                norm_f(nn.Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
+                norm_f(nn.Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),
+            ]
+        )
         self.conv_post = norm_f(nn.Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))

     def forward(self, x):
@@ -160,13 +165,13 @@ class DiscriminatorP(torch.nn.Module):

         # 1d to 2d
         b, c, t = x.shape
-        if t % self.period != 0: # pad first
+        if t % self.period != 0:  # pad first
             n_pad = self.period - (t % self.period)
             x = F.pad(x, (0, n_pad), "reflect")
             t = t + n_pad
         x = x.view(b, c, t // self.period, self.period)

-        for i,l in enumerate(self.convs):
+        for i, l in enumerate(self.convs):
             x = l(x)
             x = F.leaky_relu(x, LRELU_SLOPE)
             if i > 0:
@@ -181,13 +186,15 @@ class DiscriminatorP(torch.nn.Module):
 class MultiPeriodDiscriminator(torch.nn.Module):
     def __init__(self):
         super(MultiPeriodDiscriminator, self).__init__()
-        self.discriminators = nn.ModuleList([
-            DiscriminatorP(2),
-            DiscriminatorP(3),
-            DiscriminatorP(5),
-            DiscriminatorP(7),
-            DiscriminatorP(11),
-        ])
+        self.discriminators = nn.ModuleList(
+            [
+                DiscriminatorP(2),
+                DiscriminatorP(3),
+                DiscriminatorP(5),
+                DiscriminatorP(7),
+                DiscriminatorP(11),
+            ]
+        )

     def forward(self, y, y_hat):
         y_d_rs = []
@@ -264,8 +271,8 @@ class DiscriminatorAR(nn.Module):
         self, x: torch.Tensor, cond_embedding_id: torch.Tensor = None
     ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
         fmap = []
-        x=x.squeeze(1)
+        x = x.squeeze(1)

         x = self.spectrogram(x)
         x = x.unsqueeze(1)
         for l in self.convs:
@@ -358,8 +365,8 @@ class DiscriminatorPR(nn.Module):
         self, x: torch.Tensor, cond_embedding_id: torch.Tensor = None
     ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
         fmap = []
-        x=x.squeeze(1)
+        x = x.squeeze(1)

         x = self.spectrogram(x)
         x = x.unsqueeze(1)
         for l in self.convs:
@@ -407,11 +414,11 @@ def discriminator_loss(disc_real_outputs, disc_generated_outputs):
     r_losses = []
     g_losses = []
     for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
-        r_loss = torch.mean(torch.clamp(1 - dr, min=0))
-        g_loss = torch.mean(torch.clamp(1 + dg, min=0))
-        loss += r_loss + g_loss
-        r_losses.append(r_loss.item())
-        g_losses.append(g_loss.item())
+        r_loss = torch.mean(torch.clamp(1 - dr, min=0))
+        g_loss = torch.mean(torch.clamp(1 + dg, min=0))
+        loss += r_loss + g_loss
+        r_losses.append(r_loss.item())
+        g_losses.append(g_loss.item())

     return loss, r_losses, g_losses

@@ -420,35 +427,37 @@ def generator_loss(disc_outputs):
     loss = 0
     gen_losses = []
     for dg in disc_outputs:
-        l = torch.mean(torch.clamp(1 - dg, min=0))
-        gen_losses.append(l)
-        loss += l
+        l = torch.mean(torch.clamp(1 - dg, min=0))
+        gen_losses.append(l)
+        loss += l

     return loss, gen_losses


 def phase_losses(phase_r, phase_g):
-
     ip_loss = torch.mean(anti_wrapping_function(phase_r - phase_g))
     gd_loss = torch.mean(anti_wrapping_function(torch.diff(phase_r, dim=1) - torch.diff(phase_g, dim=1)))
     iaf_loss = torch.mean(anti_wrapping_function(torch.diff(phase_r, dim=2) - torch.diff(phase_g, dim=2)))

     return ip_loss, gd_loss, iaf_loss

-def anti_wrapping_function(x):
+
+def anti_wrapping_function(x):
     return torch.abs(x - torch.round(x / (2 * np.pi)) * 2 * np.pi)


 def stft_mag(audio, n_fft=2048, hop_length=512):
     hann_window = torch.hann_window(n_fft).to(audio.device)
     stft_spec = torch.stft(audio, n_fft, hop_length, window=hann_window, return_complex=True)
     stft_mag = torch.abs(stft_spec)
-    return(stft_mag)
+    return stft_mag


 def cal_snr(pred, target):
     snr = (20 * torch.log10(torch.norm(target, dim=-1) / torch.norm(pred - target, dim=-1).clamp(min=1e-8))).mean()
     return snr


 def cal_lsd(pred, target):
     sp = torch.log10(stft_mag(pred).square().clamp(1e-8))
     st = torch.log10(stft_mag(target).square().clamp(1e-8))
@@ -1,33 +1,36 @@
 import os


 def check_fw_local_models():
-    '''
+    """
     启动时检查本地是否有 Faster Whisper 模型.
-    '''
+    """
     model_size_list = [
-        "tiny", "tiny.en",
-        "base", "base.en",
-        "small", "small.en",
-        "medium", "medium.en",
-        "large", "large-v1",
-        "large-v2", "large-v3"]
+        "tiny",
+        "tiny.en",
+        "base",
+        "base.en",
+        "small",
+        "small.en",
+        "medium",
+        "medium.en",
+        "large",
+        "large-v1",
+        "large-v2",
+        "large-v3",
+    ]
     for i, size in enumerate(model_size_list):
-        if os.path.exists(f'tools/asr/models/faster-whisper-{size}'):
-            model_size_list[i] = size + '-local'
+        if os.path.exists(f"tools/asr/models/faster-whisper-{size}"):
+            model_size_list[i] = size + "-local"
     return model_size_list


 asr_dict = {
-    "达摩 ASR (中文)": {
-        'lang': ['zh','yue'],
-        'size': ['large'],
-        'path': 'funasr_asr.py',
-        'precision': ['float32']
-    },
+    "达摩 ASR (中文)": {"lang": ["zh", "yue"], "size": ["large"], "path": "funasr_asr.py", "precision": ["float32"]},
     "Faster Whisper (多语种)": {
-        'lang': ['auto', 'zh', 'en', 'ja', 'ko', 'yue'],
-        'size': check_fw_local_models(),
-        'path': 'fasterwhisper_asr.py',
-        'precision': ['float32', 'float16', 'int8']
+        "lang": ["auto", "zh", "en", "ja", "ko", "yue"],
+        "size": check_fw_local_models(),
+        "path": "fasterwhisper_asr.py",
+        "precision": ["float32", "float16", "int8"],
     },
 }
@@ -2,7 +2,7 @@ import argparse
 import os
 import traceback

-os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
+os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
 os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

 import torch
@@ -11,6 +11,7 @@ from tqdm import tqdm

 from tools.asr.config import check_fw_local_models

+# fmt: off
 language_code_list = [
     "af", "am", "ar", "as", "az",
     "ba", "be", "bg", "bn", "bo",
@@ -32,82 +33,97 @@ language_code_list = [
     "te", "tg", "th", "tk", "tl",
     "tr", "tt", "uk", "ur", "uz",
     "vi", "yi", "yo", "zh", "yue",
-    "auto"]
+    "auto"]
+# fmt: on


 def execute_asr(input_folder, output_folder, model_size, language, precision):
-    if '-local' in model_size:
+    if "-local" in model_size:
         model_size = model_size[:-6]
-        model_path = f'tools/asr/models/faster-whisper-{model_size}'
+        model_path = f"tools/asr/models/faster-whisper-{model_size}"
     else:
         model_path = model_size
-    if language == 'auto':
-        language = None #不设置语种由模型自动输出概率最高的语种
-    print("loading faster whisper model:",model_size,model_path)
-    device = 'cuda' if torch.cuda.is_available() else 'cpu'
+    if language == "auto":
+        language = None  # 不设置语种由模型自动输出概率最高的语种
+    print("loading faster whisper model:", model_size, model_path)
+    device = "cuda" if torch.cuda.is_available() else "cpu"
     try:
         model = WhisperModel(model_path, device=device, compute_type=precision)
     except:
         return print(traceback.format_exc())

     input_file_names = os.listdir(input_folder)
     input_file_names.sort()

     output = []
     output_file_name = os.path.basename(input_folder)

     for file_name in tqdm(input_file_names):
         try:
             file_path = os.path.join(input_folder, file_name)
             segments, info = model.transcribe(
-                audio = file_path,
-                beam_size = 5,
-                vad_filter = True,
-                vad_parameters = dict(min_silence_duration_ms=700),
-                language = language)
-            text = ''
+                audio=file_path,
+                beam_size=5,
+                vad_filter=True,
+                vad_parameters=dict(min_silence_duration_ms=700),
+                language=language,
+            )
+            text = ""

             if info.language == "zh":
                 print("检测为中文文本, 转 FunASR 处理")
-                if("only_asr" not in globals()):
-                    from tools.asr.funasr_asr import only_asr #如果用英文就不需要导入下载模型
+                if "only_asr" not in globals():
+                    from tools.asr.funasr_asr import only_asr  # 如果用英文就不需要导入下载模型
                 text = only_asr(file_path, language=info.language.lower())

-            if text == '':
+            if text == "":
                 for segment in segments:
                     text += segment.text
             output.append(f"{file_path}|{output_file_name}|{info.language.upper()}|{text}")
         except:
             print(traceback.format_exc())

     output_folder = output_folder or "output/asr_opt"
     os.makedirs(output_folder, exist_ok=True)
-    output_file_path = os.path.abspath(f'{output_folder}/{output_file_name}.list')
+    output_file_path = os.path.abspath(f"{output_folder}/{output_file_name}.list")

     with open(output_file_path, "w", encoding="utf-8") as f:
         f.write("\n".join(output))
         print(f"ASR 任务完成->标注文件路径: {output_file_path}\n")
     return output_file_path

-if __name__ == '__main__':
+
+if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument("-i", "--input_folder", type=str, required=True,
-                        help="Path to the folder containing WAV files.")
-    parser.add_argument("-o", "--output_folder", type=str, required=True,
-                        help="Output folder to store transcriptions.")
-    parser.add_argument("-s", "--model_size", type=str, default='large-v3',
-                        choices=check_fw_local_models(),
-                        help="Model Size of Faster Whisper")
-    parser.add_argument("-l", "--language", type=str, default='ja',
-                        choices=language_code_list,
-                        help="Language of the audio files.")
-    parser.add_argument("-p", "--precision", type=str, default='float16', choices=['float16','float32','int8'],
-                        help="fp16, int8 or fp32")
+    parser.add_argument(
+        "-i", "--input_folder", type=str, required=True, help="Path to the folder containing WAV files."
+    )
+    parser.add_argument("-o", "--output_folder", type=str, required=True, help="Output folder to store transcriptions.")
+    parser.add_argument(
+        "-s",
+        "--model_size",
+        type=str,
+        default="large-v3",
+        choices=check_fw_local_models(),
+        help="Model Size of Faster Whisper",
+    )
+    parser.add_argument(
+        "-l", "--language", type=str, default="ja", choices=language_code_list, help="Language of the audio files."
+    )
+    parser.add_argument(
+        "-p",
+        "--precision",
+        type=str,
+        default="float16",
+        choices=["float16", "float32", "int8"],
+        help="fp16, int8 or fp32",
+    )

     cmd = parser.parse_args()
     output_file_path = execute_asr(
-        input_folder = cmd.input_folder,
-        output_folder = cmd.output_folder,
-        model_size = cmd.model_size,
-        language = cmd.language,
-        precision = cmd.precision,
+        input_folder=cmd.input_folder,
+        output_folder=cmd.output_folder,
+        model_size=cmd.model_size,
+        language=cmd.language,
+        precision=cmd.precision,
     )
@@ -9,31 +9,41 @@ import traceback
 from funasr import AutoModel
 from tqdm import tqdm

-funasr_models = {} # 存储模型避免重复加载
+funasr_models = {}  # 存储模型避免重复加载


 def only_asr(input_file, language):
     try:
         model = create_model(language)
         text = model.generate(input=input_file)[0]["text"]
     except:
-        text = ''
+        text = ""
         print(traceback.format_exc())
     return text


 def create_model(language="zh"):
-    path_vad = 'tools/asr/models/speech_fsmn_vad_zh-cn-16k-common-pytorch'
-    path_punc = 'tools/asr/models/punc_ct-transformer_zh-cn-common-vocab272727-pytorch'
-    path_vad = path_vad if os.path.exists(path_vad) else "iic/speech_fsmn_vad_zh-cn-16k-common-pytorch"
+    path_vad = "tools/asr/models/speech_fsmn_vad_zh-cn-16k-common-pytorch"
+    path_punc = "tools/asr/models/punc_ct-transformer_zh-cn-common-vocab272727-pytorch"
+    path_vad = path_vad if os.path.exists(path_vad) else "iic/speech_fsmn_vad_zh-cn-16k-common-pytorch"
     path_punc = path_punc if os.path.exists(path_punc) else "iic/punc_ct-transformer_zh-cn-common-vocab272727-pytorch"
     vad_model_revision = punc_model_revision = "v2.0.4"

     if language == "zh":
-        path_asr = 'tools/asr/models/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch'
-        path_asr = path_asr if os.path.exists(path_asr) else "iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
+        path_asr = "tools/asr/models/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
+        path_asr = (
+            path_asr
+            if os.path.exists(path_asr)
+            else "iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
+        )
         model_revision = "v2.0.4"
     elif language == "yue":
-        path_asr = 'tools/asr/models/speech_UniASR_asr_2pass-cantonese-CHS-16k-common-vocab1468-tensorflow1-online'
-        path_asr = path_asr if os.path.exists(path_asr) else "iic/speech_UniASR_asr_2pass-cantonese-CHS-16k-common-vocab1468-tensorflow1-online"
+        path_asr = "tools/asr/models/speech_UniASR_asr_2pass-cantonese-CHS-16k-common-vocab1468-tensorflow1-online"
+        path_asr = (
+            path_asr
+            if os.path.exists(path_asr)
+            else "iic/speech_UniASR_asr_2pass-cantonese-CHS-16k-common-vocab1468-tensorflow1-online"
+        )
         model_revision = "master"
         path_vad = path_punc = None
         vad_model_revision = punc_model_revision = None
@@ -45,25 +55,26 @@ def create_model(language="zh"):
         return funasr_models[language]
     else:
         model = AutoModel(
-            model = path_asr,
-            model_revision = model_revision,
-            vad_model = path_vad,
-            vad_model_revision = vad_model_revision,
-            punc_model = path_punc,
-            punc_model_revision = punc_model_revision,
+            model=path_asr,
+            model_revision=model_revision,
+            vad_model=path_vad,
+            vad_model_revision=vad_model_revision,
+            punc_model=path_punc,
+            punc_model_revision=punc_model_revision,
         )
         print(f"FunASR 模型加载完成: {language.upper()}")

         funasr_models[language] = model
         return model


 def execute_asr(input_folder, output_folder, model_size, language):
     input_file_names = os.listdir(input_folder)
     input_file_names.sort()

     output = []
     output_file_name = os.path.basename(input_folder)

     model = create_model(language)

     for file_name in tqdm(input_file_names):
@@ -77,29 +88,31 @@ def execute_asr(input_folder, output_folder, model_size, language):

     output_folder = output_folder or "output/asr_opt"
     os.makedirs(output_folder, exist_ok=True)
-    output_file_path = os.path.abspath(f'{output_folder}/{output_file_name}.list')
+    output_file_path = os.path.abspath(f"{output_folder}/{output_file_name}.list")

     with open(output_file_path, "w", encoding="utf-8") as f:
         f.write("\n".join(output))
         print(f"ASR 任务完成->标注文件路径: {output_file_path}\n")
     return output_file_path

-if __name__ == '__main__':
+
+if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument("-i", "--input_folder", type=str, required=True,
-                        help="Path to the folder containing WAV files.")
-    parser.add_argument("-o", "--output_folder", type=str, required=True,
-                        help="Output folder to store transcriptions.")
-    parser.add_argument("-s", "--model_size", type=str, default='large',
-                        help="Model Size of FunASR is Large")
-    parser.add_argument("-l", "--language", type=str, default='zh', choices=['zh','yue','auto'],
-                        help="Language of the audio files.")
-    parser.add_argument("-p", "--precision", type=str, default='float16', choices=['float16','float32'],
-                        help="fp16 or fp32")#还没接入
+    parser.add_argument(
+        "-i", "--input_folder", type=str, required=True, help="Path to the folder containing WAV files."
+    )
+    parser.add_argument("-o", "--output_folder", type=str, required=True, help="Output folder to store transcriptions.")
+    parser.add_argument("-s", "--model_size", type=str, default="large", help="Model Size of FunASR is Large")
+    parser.add_argument(
+        "-l", "--language", type=str, default="zh", choices=["zh", "yue", "auto"], help="Language of the audio files."
+    )
+    parser.add_argument(
+        "-p", "--precision", type=str, default="float16", choices=["float16", "float32"], help="fp16 or fp32"
+    )  # 还没接入
     cmd = parser.parse_args()
     execute_asr(
-        input_folder = cmd.input_folder,
-        output_folder = cmd.output_folder,
-        model_size = cmd.model_size,
-        language = cmd.language,
+        input_folder=cmd.input_folder,
+        output_folder=cmd.output_folder,
+        model_size=cmd.model_size,
+        language=cmd.language,
     )
@@ -1,50 +1,44 @@
 from __future__ import absolute_import, division, print_function, unicode_literals
-import sys,os
-import traceback
-AP_BWE_main_dir_path=os.path.join(os.path.dirname(os.path.abspath(__file__)), 'AP_BWE_main')
+import sys
+import os
+
+AP_BWE_main_dir_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "AP_BWE_main")
 sys.path.append(AP_BWE_main_dir_path)
-import glob
-import argparse
 import json
-from re import S
 import torch
-import numpy as np
 import torchaudio
-import time
 import torchaudio.functional as aF
+
 # from attrdict import AttrDict####will be bug in py3.10
 from datasets1.dataset import amp_pha_stft, amp_pha_istft
 from models.model import APNet_BWE_Model
-import soundfile as sf
-import matplotlib.pyplot as plt
-from rich.progress import track

-class AP_BWE():
-    def __init__(self,device,DictToAttrRecursive,checkpoint_file=None):
-        if checkpoint_file==None:
-            checkpoint_file="%s/24kto48k/g_24kto48k.zip"%(AP_BWE_main_dir_path)
-            if os.path.exists(checkpoint_file)==False:
+
+class AP_BWE:
+    def __init__(self, device, DictToAttrRecursive, checkpoint_file=None):
+        if checkpoint_file == None:
+            checkpoint_file = "%s/24kto48k/g_24kto48k.zip" % (AP_BWE_main_dir_path)
+            if os.path.exists(checkpoint_file) == False:
                 raise FileNotFoundError
-        config_file = os.path.join(os.path.split(checkpoint_file)[0], 'config.json')
-        with open(config_file) as f:data = f.read()
+        config_file = os.path.join(os.path.split(checkpoint_file)[0], "config.json")
+        with open(config_file) as f:
+            data = f.read()
         json_config = json.loads(data)
         # h = AttrDict(json_config)
         h = DictToAttrRecursive(json_config)
         model = APNet_BWE_Model(h).to(device)
-        state_dict = torch.load(checkpoint_file,map_location="cpu",weights_only=False)
-        model.load_state_dict(state_dict['generator'])
+        state_dict = torch.load(checkpoint_file, map_location="cpu", weights_only=False)
+        model.load_state_dict(state_dict["generator"])
         model.eval()
-        self.device=device
-        self.model=model
-        self.h=h
+        self.device = device
+        self.model = model
+        self.h = h

     def to(self, *arg, **kwargs):
         self.model.to(*arg, **kwargs)
         self.device = self.model.conv_pre_mag.weight.device
         return self

-    def __call__(self, audio,orig_sampling_rate):
+    def __call__(self, audio, orig_sampling_rate):
         with torch.no_grad():
             # audio, orig_sampling_rate = torchaudio.load(inp_path)
             # audio = audio.to(self.device)
@@ -53,4 +47,4 @@ class AP_BWE():
             amp_wb_g, pha_wb_g, com_wb_g = self.model(amp_nb, pha_nb)
             audio_hr_g = amp_pha_istft(amp_wb_g, pha_wb_g, self.h.n_fft, self.h.hop_size, self.h.win_size)
             # sf.write(opt_path, audio_hr_g.squeeze().cpu().numpy(), self.h.hr_sampling_rate, 'PCM_16')
-            return audio_hr_g.squeeze().cpu().numpy(),self.h.hr_sampling_rate
+            return audio_hr_g.squeeze().cpu().numpy(), self.h.hr_sampling_rate
@@ -1,33 +1,38 @@
-import os,argparse
+import os
+import argparse
 import traceback

 from modelscope.pipelines import pipeline
 from modelscope.utils.constant import Tasks
 from tqdm import tqdm

-path_denoise = 'tools/denoise-model/speech_frcrn_ans_cirm_16k'
-path_denoise = path_denoise if os.path.exists(path_denoise) else "damo/speech_frcrn_ans_cirm_16k"
-ans = pipeline(Tasks.acoustic_noise_suppression,model=path_denoise)
-def execute_denoise(input_folder,output_folder):
-    os.makedirs(output_folder,exist_ok=True)
+path_denoise = "tools/denoise-model/speech_frcrn_ans_cirm_16k"
+path_denoise = path_denoise if os.path.exists(path_denoise) else "damo/speech_frcrn_ans_cirm_16k"
+ans = pipeline(Tasks.acoustic_noise_suppression, model=path_denoise)
+
+
+def execute_denoise(input_folder, output_folder):
+    os.makedirs(output_folder, exist_ok=True)
     # print(input_folder)
     # print(list(os.listdir(input_folder).sort()))
     for name in tqdm(os.listdir(input_folder)):
         try:
-            ans("%s/%s"%(input_folder,name),output_path='%s/%s'%(output_folder,name))
+            ans("%s/%s" % (input_folder, name), output_path="%s/%s" % (output_folder, name))
         except:
             traceback.print_exc()

-if __name__ == '__main__':
+
+if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument("-i", "--input_folder", type=str, required=True,
-                        help="Path to the folder containing WAV files.")
-    parser.add_argument("-o", "--output_folder", type=str, required=True,
-                        help="Output folder to store transcriptions.")
-    parser.add_argument("-p", "--precision", type=str, default='float16', choices=['float16','float32'],
-                        help="fp16 or fp32")#还没接入
+    parser.add_argument(
+        "-i", "--input_folder", type=str, required=True, help="Path to the folder containing WAV files."
+    )
+    parser.add_argument("-o", "--output_folder", type=str, required=True, help="Output folder to store transcriptions.")
+    parser.add_argument(
+        "-p", "--precision", type=str, default="float16", choices=["float16", "float32"], help="fp16 or fp32"
+    )  # 还没接入
     cmd = parser.parse_args()
     execute_denoise(
-        input_folder = cmd.input_folder,
-        output_folder = cmd.output_folder,
+        input_folder=cmd.input_folder,
+        output_folder=cmd.output_folder,
     )
@@ -2,23 +2,27 @@ import json
 import locale
 import os

-I18N_JSON_DIR : os.PathLike = os.path.join(os.path.dirname(os.path.relpath(__file__)), 'locale')
+I18N_JSON_DIR: os.PathLike = os.path.join(os.path.dirname(os.path.relpath(__file__)), "locale")


 def load_language_list(language):
     with open(os.path.join(I18N_JSON_DIR, f"{language}.json"), "r", encoding="utf-8") as f:
         language_list = json.load(f)
     return language_list


 def scan_language_list():
     language_list = []
     for name in os.listdir(I18N_JSON_DIR):
-        if name.endswith(".json"):language_list.append(name.split('.')[0])
+        if name.endswith(".json"):
+            language_list.append(name.split(".")[0])
     return language_list


 class I18nAuto:
     def __init__(self, language=None):
         if language in ["Auto", None]:
-            language = locale.getdefaultlocale()[0]
+            language = locale.getdefaultlocale()[0]
             # getlocale can't identify the system's language ((None, None))
         if not os.path.exists(os.path.join(I18N_JSON_DIR, f"{language}.json")):
             language = "en_US"
@@ -31,6 +35,7 @@ class I18nAuto:
     def __repr__(self):
         return "Use Language: " + self.language

+
 if __name__ == "__main__":
-    i18n = I18nAuto(language='en_US')
+    i18n = I18nAuto(language="en_US")
     print(i18n)
@@ -4,21 +4,18 @@ import json
 import os
 from collections import OrderedDict

-I18N_JSON_DIR : os.PathLike = os.path.join(os.path.dirname(os.path.relpath(__file__)), 'locale')
-DEFAULT_LANGUAGE: str = "zh_CN" # 默认语言
-TITLE_LEN : int = 60 # 标题显示长度
-KEY_LEN : int = 30 # 键名显示长度
-SHOW_KEYS : bool = False # 是否显示键信息
-SORT_KEYS : bool = False # 是否按全局键名写入文件
+I18N_JSON_DIR: os.PathLike = os.path.join(os.path.dirname(os.path.relpath(__file__)), "locale")
+DEFAULT_LANGUAGE: str = "zh_CN"  # 默认语言
+TITLE_LEN: int = 60  # 标题显示长度
+KEY_LEN: int = 30  # 键名显示长度
+SHOW_KEYS: bool = False  # 是否显示键信息
+SORT_KEYS: bool = False  # 是否按全局键名写入文件


 def extract_i18n_strings(node):
     i18n_strings = []

-    if (
-        isinstance(node, ast.Call)
-        and isinstance(node.func, ast.Name)
-        and node.func.id == "i18n"
-    ):
+    if isinstance(node, ast.Call) and isinstance(node.func, ast.Name) and node.func.id == "i18n":
         for arg in node.args:
             if isinstance(arg, ast.Str):
                 i18n_strings.append(arg.s)
@@ -28,6 +25,7 @@ def extract_i18n_strings(node):

     return i18n_strings


 def scan_i18n_strings():
     """
     scan the directory for all .py files (recursively)
@@ -43,7 +41,7 @@ def scan_i18n_strings():
         if "I18nAuto" in code:
             tree = ast.parse(code)
             i18n_strings = extract_i18n_strings(tree)
-            print(f"{filename.ljust(KEY_LEN*3//2)}: {len(i18n_strings)}")
+            print(f"{filename.ljust(KEY_LEN * 3 // 2)}: {len(i18n_strings)}")
             if SHOW_KEYS:
                 print("\n".join([s for s in i18n_strings]))
             strings.extend(i18n_strings)
@@ -51,9 +49,10 @@ def scan_i18n_strings():
             print(f"\033[31m[Failed] Error occur at {filename}: {e}\033[0m")

     code_keys = set(strings)
-    print(f"{'Total Unique'.ljust(KEY_LEN*3//2)}: {len(code_keys)}")
+    print(f"{'Total Unique'.ljust(KEY_LEN * 3 // 2)}: {len(code_keys)}")
     return code_keys


 def update_i18n_json(json_file, standard_keys):
     standard_keys = sorted(standard_keys)
     print(f" Process {json_file} ".center(TITLE_LEN, "="))
@@ -89,8 +88,10 @@ def update_i18n_json(json_file, standard_keys):
         sorted(
             json_data.items(),
             key=lambda x: (
-                list(standard_keys).index(x[0]) if x[0] in standard_keys and not x[1].startswith('#!') else len(json_data),
-            )
+                list(standard_keys).index(x[0])
+                if x[0] in standard_keys and not x[1].startswith("#!")
+                else len(json_data),
+            ),
         )
     )
     # 打印处理后的 JSON 条目数
@@ -111,21 +112,26 @@ def update_i18n_json(json_file, standard_keys):
     # 打印是否有重复的值
     for value, keys in duplicate_items.items():
         if len(keys) > 1:
-            print("\n".join([f"\033[31m{'[Failed] Duplicate Value'.ljust(KEY_LEN)}: {key} -> {value}\033[0m" for key in keys]))
+            print(
+                "\n".join(
+                    [f"\033[31m{'[Failed] Duplicate Value'.ljust(KEY_LEN)}: {key} -> {value}\033[0m" for key in keys]
+                )
+            )

     if num_miss_translation > 0:
         print(f"\033[31m{'[Failed] Missing Translation'.ljust(KEY_LEN)}: {num_miss_translation}\033[0m")
     else:
-        print(f"\033[32m[Passed] All Keys Translated\033[0m")
+        print("\033[32m[Passed] All Keys Translated\033[0m")
     # 将处理后的结果写入 JSON 文件
     with open(json_file, "w", encoding="utf-8") as f:
         json.dump(json_data, f, ensure_ascii=False, indent=4, sort_keys=SORT_KEYS)
         f.write("\n")
-    print(f" Updated {json_file} ".center(TITLE_LEN, "=") + '\n')
+    print(f" Updated {json_file} ".center(TITLE_LEN, "=") + "\n")


 if __name__ == "__main__":
     code_keys = scan_i18n_strings()
     for json_file in os.listdir(I18N_JSON_DIR):
         if json_file.endswith(r".json"):
             json_file = os.path.join(I18N_JSON_DIR, json_file)
-            update_i18n_json(json_file, code_keys)
+            update_i18n_json(json_file, code_keys)
@@ -1,10 +1,13 @@
-import platform,os,traceback
+import os
+import traceback
 import ffmpeg
 import numpy as np
 import gradio as gr
 from tools.i18n.i18n import I18nAuto
 import pandas as pd
-i18n = I18nAuto(language=os.environ.get('language','Auto'))
+
+i18n = I18nAuto(language=os.environ.get("language", "Auto"))


 def load_audio(file, sr):
     try:
@@ -13,45 +16,49 @@ def load_audio(file, sr):
         # Requires the ffmpeg CLI and `ffmpeg-python` package to be installed.
         file = clean_path(file)  # 防止小白拷路径头尾带了空格和"和回车
         if os.path.exists(file) == False:
-            raise RuntimeError(
-                "You input a wrong audio path that does not exists, please fix it!"
-            )
+            raise RuntimeError("You input a wrong audio path that does not exists, please fix it!")
         out, _ = (
             ffmpeg.input(file, threads=0)
             .output("-", format="f32le", acodec="pcm_f32le", ac=1, ar=sr)
            .run(cmd=["ffmpeg", "-nostdin"], capture_stdout=True, capture_stderr=True)
         )
-    except Exception as e:
+    except Exception:
         traceback.print_exc()
         raise RuntimeError(i18n("音频加载失败"))

     return np.frombuffer(out, np.float32).flatten()


-def clean_path(path_str:str):
-    if path_str.endswith(('\\','/')):
+def clean_path(path_str: str):
+    if path_str.endswith(("\\", "/")):
         return clean_path(path_str[0:-1])
-    path_str = path_str.replace('/', os.sep).replace('\\', os.sep)
-    return path_str.strip(" \'\n\"\u202a")#path_str.strip(" ").strip('\'').strip("\n").strip('"').strip(" ").strip("\u202a")
+    path_str = path_str.replace("/", os.sep).replace("\\", os.sep)
+    return path_str.strip(
+        " '\n\"\u202a"
+    )  # path_str.strip(" ").strip('\'').strip("\n").strip('"').strip(" ").strip("\u202a")


-def check_for_existance(file_list:list=None,is_train=False,is_dataset_processing=False):
-    files_status=[]
+def check_for_existance(file_list: list = None, is_train=False, is_dataset_processing=False):
+    files_status = []
     if is_train == True and file_list:
-        file_list.append(os.path.join(file_list[0],'2-name2text.txt'))
-        file_list.append(os.path.join(file_list[0],'3-bert'))
-        file_list.append(os.path.join(file_list[0],'4-cnhubert'))
-        file_list.append(os.path.join(file_list[0],'5-wav32k'))
-        file_list.append(os.path.join(file_list[0],'6-name2semantic.tsv'))
+        file_list.append(os.path.join(file_list[0], "2-name2text.txt"))
+        file_list.append(os.path.join(file_list[0], "3-bert"))
+        file_list.append(os.path.join(file_list[0], "4-cnhubert"))
+        file_list.append(os.path.join(file_list[0], "5-wav32k"))
+        file_list.append(os.path.join(file_list[0], "6-name2semantic.tsv"))
     for file in file_list:
-        if os.path.exists(file):files_status.append(True)
-        else:files_status.append(False)
-    if sum(files_status)!=len(files_status):
+        if os.path.exists(file):
+            files_status.append(True)
+        else:
+            files_status.append(False)
+    if sum(files_status) != len(files_status):
         if is_train:
-            for file,status in zip(file_list,files_status):
-                if status:pass
-                else:gr.Warning(file)
-            gr.Warning(i18n('以下文件或文件夹不存在'))
+            for file, status in zip(file_list, files_status):
+                if status:
+                    pass
+                else:
+                    gr.Warning(file)
+            gr.Warning(i18n("以下文件或文件夹不存在"))
             return False
         elif is_dataset_processing:
             if files_status[0]:
@@ -60,56 +67,63 @@ def check_for_existance(file_list:list=None,is_train=False,is_dataset_processing
                 gr.Warning(file_list[0])
             elif not files_status[1] and file_list[1]:
                 gr.Warning(file_list[1])
-            gr.Warning(i18n('以下文件或文件夹不存在'))
+            gr.Warning(i18n("以下文件或文件夹不存在"))
             return False
         else:
            if file_list[0]:
                 gr.Warning(file_list[0])
-                gr.Warning(i18n('以下文件或文件夹不存在'))
+                gr.Warning(i18n("以下文件或文件夹不存在"))
             else:
-                gr.Warning(i18n('路径不能为空'))
+                gr.Warning(i18n("路径不能为空"))
             return False
     return True

-def check_details(path_list=None,is_train=False,is_dataset_processing=False):
+
+def check_details(path_list=None, is_train=False, is_dataset_processing=False):
     if is_dataset_processing:
         list_path, audio_path = path_list
-        if (not list_path.endswith('.list')):
-            gr.Warning(i18n('请填入正确的List路径'))
+        if not list_path.endswith(".list"):
+            gr.Warning(i18n("请填入正确的List路径"))
             return
         if audio_path:
             if not os.path.isdir(audio_path):
-                gr.Warning(i18n('请填入正确的音频文件夹路径'))
+                gr.Warning(i18n("请填入正确的音频文件夹路径"))
                 return
-        with open(list_path,"r",encoding="utf8")as f:
-            line=f.readline().strip("\n").split("\n")
+        with open(list_path, "r", encoding="utf8") as f:
+            line = f.readline().strip("\n").split("\n")
         wav_name, _, __, ___ = line[0].split("|")
-        wav_name=clean_path(wav_name)
-        if (audio_path != "" and audio_path != None):
+        wav_name = clean_path(wav_name)
+        if audio_path != "" and audio_path != None:
             wav_name = os.path.basename(wav_name)
-            wav_path = "%s/%s"%(audio_path, wav_name)
+            wav_path = "%s/%s" % (audio_path, wav_name)
         else:
-            wav_path=wav_name
+            wav_path = wav_name
         if os.path.exists(wav_path):
             ...
         else:
-            gr.Warning(i18n('路径错误'))
+            gr.Warning(i18n("路径错误"))
             return
     if is_train:
-        path_list.append(os.path.join(path_list[0],'2-name2text.txt'))
-        path_list.append(os.path.join(path_list[0],'4-cnhubert'))
-        path_list.append(os.path.join(path_list[0],'5-wav32k'))
-        path_list.append(os.path.join(path_list[0],'6-name2semantic.tsv'))
+        path_list.append(os.path.join(path_list[0], "2-name2text.txt"))
+        path_list.append(os.path.join(path_list[0], "4-cnhubert"))
+        path_list.append(os.path.join(path_list[0], "5-wav32k"))
+        path_list.append(os.path.join(path_list[0], "6-name2semantic.tsv"))
         phone_path, hubert_path, wav_path, semantic_path = path_list[1:]
-        with open(phone_path,'r',encoding='utf-8') as f:
-            if f.read(1):...
-            else:gr.Warning(i18n('缺少音素数据集'))
-        if os.listdir(hubert_path):...
-        else:gr.Warning(i18n('缺少Hubert数据集'))
-        if os.listdir(wav_path):...
-        else:gr.Warning(i18n('缺少音频数据集'))
-        df = pd.read_csv(
-            semantic_path, delimiter="\t", encoding="utf-8"
-        )
-        if len(df) >= 1:...
-        else:gr.Warning(i18n('缺少语义数据集'))
+        with open(phone_path, "r", encoding="utf-8") as f:
+            if f.read(1):
+                ...
+            else:
+                gr.Warning(i18n("缺少音素数据集"))
+        if os.listdir(hubert_path):
+            ...
+        else:
+            gr.Warning(i18n("缺少Hubert数据集"))
+        if os.listdir(wav_path):
+            ...
+        else:
+            gr.Warning(i18n("缺少音频数据集"))
+        df = pd.read_csv(semantic_path, delimiter="\t", encoding="utf-8")
+        if len(df) >= 1:
+            ...
+        else:
+            gr.Warning(i18n("缺少语义数据集"))
@@ -1,30 +1,34 @@
-import os,sys,numpy as np
+import os
+import sys
+import numpy as np
 import traceback
 from scipy.io import wavfile

 # parent_directory = os.path.dirname(os.path.abspath(__file__))
 # sys.path.append(parent_directory)
 from tools.my_utils import load_audio
 from slicer2 import Slicer

-def slice(inp,opt_root,threshold,min_length,min_interval,hop_size,max_sil_kept,_max,alpha,i_part,all_part):
-    os.makedirs(opt_root,exist_ok=True)
+
+def slice(inp, opt_root, threshold, min_length, min_interval, hop_size, max_sil_kept, _max, alpha, i_part, all_part):
+    os.makedirs(opt_root, exist_ok=True)
     if os.path.isfile(inp):
-        input=[inp]
+        input = [inp]
     elif os.path.isdir(inp):
-        input=[os.path.join(inp, name) for name in sorted(list(os.listdir(inp)))]
+        input = [os.path.join(inp, name) for name in sorted(list(os.listdir(inp)))]
     else:
         return "输入路径存在但既不是文件也不是文件夹"
     slicer = Slicer(
         sr=32000,  # 长音频采样率
-        threshold= int(threshold),  # 音量小于这个值视作静音的备选切割点
-        min_length= int(min_length),  # 每段最小多长,如果第一段太短一直和后面段连起来直到超过这个值
-        min_interval= int(min_interval),  # 最短切割间隔
-        hop_size= int(hop_size),  # 怎么算音量曲线,越小精度越大计算量越高(不是精度越大效果越好)
-        max_sil_kept= int(max_sil_kept),  # 切完后静音最多留多长
+        threshold=int(threshold),  # 音量小于这个值视作静音的备选切割点
+        min_length=int(min_length),  # 每段最小多长,如果第一段太短一直和后面段连起来直到超过这个值
+        min_interval=int(min_interval),  # 最短切割间隔
+        hop_size=int(hop_size),  # 怎么算音量曲线,越小精度越大计算量越高(不是精度越大效果越好)
+        max_sil_kept=int(max_sil_kept),  # 切完后静音最多留多长
     )
-    _max=float(_max)
-    alpha=float(alpha)
-    for inp_path in input[int(i_part)::int(all_part)]:
+    _max = float(_max)
+    alpha = float(alpha)
+    for inp_path in input[int(i_part) :: int(all_part)]:
         # print(inp_path)
         try:
             name = os.path.basename(inp_path)
@@ -32,7 +36,8 @@ def slice(inp,opt_root,threshold,min_length,min_interval,hop_size,max_sil_kept,_
             # print(audio.shape)
             for chunk, start, end in slicer.slice(audio):  # start和end是帧数
                 tmp_max = np.abs(chunk).max()
-                if(tmp_max>1):chunk/=tmp_max
+                if tmp_max > 1:
+                    chunk /= tmp_max
                 chunk = (chunk / tmp_max * (_max * alpha)) + (1 - alpha) * chunk
                 wavfile.write(
                     "%s/%s_%010d_%010d.wav" % (opt_root, name, start, end),
@@ -41,8 +46,8 @@ def slice(inp,opt_root,threshold,min_length,min_interval,hop_size,max_sil_kept,_
                     (chunk * 32767).astype(np.int16),
                 )
         except:
-            print(inp_path,"->fail->",traceback.format_exc())
+            print(inp_path, "->fail->", traceback.format_exc())
     return "执行完毕,请检查输出文件"

-print(slice(*sys.argv[1:]))
+
+print(slice(*sys.argv[1:]))
@@ -46,13 +46,9 @@ class Slicer:
         max_sil_kept: int = 5000,
     ):
         if not min_length >= min_interval >= hop_size:
-            raise ValueError(
-                "The following condition must be satisfied: min_length >= min_interval >= hop_size"
-            )
+            raise ValueError("The following condition must be satisfied: min_length >= min_interval >= hop_size")
         if not max_sil_kept >= hop_size:
-            raise ValueError(
-                "The following condition must be satisfied: max_sil_kept >= hop_size"
-            )
+            raise ValueError("The following condition must be satisfied: max_sil_kept >= hop_size")
         min_interval = sr * min_interval / 1000
         self.threshold = 10 ** (threshold / 20.0)
         self.hop_size = round(sr * hop_size / 1000)
@@ -63,13 +59,9 @@ class Slicer:

     def _apply_slice(self, waveform, begin, end):
         if len(waveform.shape) > 1:
-            return waveform[
-                :, begin * self.hop_size : min(waveform.shape[1], end * self.hop_size)
-            ]
+            return waveform[:, begin * self.hop_size : min(waveform.shape[1], end * self.hop_size)]
         else:
-            return waveform[
-                begin * self.hop_size : min(waveform.shape[0], end * self.hop_size)
-            ]
+            return waveform[begin * self.hop_size : min(waveform.shape[0], end * self.hop_size)]

     # @timeit
     def slice(self, waveform):
@@ -79,9 +71,7 @@ class Slicer:
             samples = waveform
         if samples.shape[0] <= self.min_length:
             return [waveform]
-        rms_list = get_rms(
-            y=samples, frame_length=self.win_size, hop_length=self.hop_size
-        ).squeeze(0)
+        rms_list = get_rms(y=samples, frame_length=self.win_size, hop_length=self.hop_size).squeeze(0)
         sil_tags = []
         silence_start = None
         clip_start = 0
@@ -97,10 +87,7 @@ class Slicer:
             continue
             # Clear recorded silence start if interval is not enough or clip is too short
             is_leading_silence = silence_start == 0 and i > self.max_sil_kept
-            need_slice_middle = (
-                i - silence_start >= self.min_interval
-                and i - clip_start >= self.min_length
-            )
+            need_slice_middle = i - silence_start >= self.min_interval and i - clip_start >= self.min_length
            if not is_leading_silence and not need_slice_middle:
                 silence_start = None
                 continue
@@ -113,21 +100,10 @@ class Slicer:
                 sil_tags.append((pos, pos))
                 clip_start = pos
             elif i - silence_start <= self.max_sil_kept * 2:
-                pos = rms_list[
-                    i - self.max_sil_kept : silence_start + self.max_sil_kept + 1
-                ].argmin()
+                pos = rms_list[i - self.max_sil_kept : silence_start + self.max_sil_kept + 1].argmin()
                 pos += i - self.max_sil_kept
-                pos_l = (
-                    rms_list[
-                        silence_start : silence_start + self.max_sil_kept + 1
-                    ].argmin()
-                    + silence_start
-                )
-                pos_r = (
-                    rms_list[i - self.max_sil_kept : i + 1].argmin()
-                    + i
-                    - self.max_sil_kept
-                )
+                pos_l = rms_list[silence_start : silence_start + self.max_sil_kept + 1].argmin() + silence_start
+                pos_r = rms_list[i - self.max_sil_kept : i + 1].argmin() + i - self.max_sil_kept
                 if silence_start == 0:
                     sil_tags.append((0, pos_r))
                     clip_start = pos_r
@@ -135,17 +111,8 @@ class Slicer:
                 sil_tags.append((min(pos_l, pos), max(pos_r, pos)))
                 clip_start = max(pos_r, pos)
             else:
-                pos_l = (
-                    rms_list[
-                        silence_start : silence_start + self.max_sil_kept + 1
-                    ].argmin()
-                    + silence_start
-                )
-                pos_r = (
-                    rms_list[i - self.max_sil_kept : i + 1].argmin()
-                    + i
-                    - self.max_sil_kept
-                )
+                pos_l = rms_list[silence_start : silence_start + self.max_sil_kept + 1].argmin() + silence_start
+                pos_r = rms_list[i - self.max_sil_kept : i + 1].argmin() + i - self.max_sil_kept
                 if silence_start == 0:
                     sil_tags.append((0, pos_r))
                 else:
@@ -154,28 +121,33 @@ class Slicer:
             silence_start = None
         # Deal with trailing silence.
         total_frames = rms_list.shape[0]
-        if (
-            silence_start is not None
-            and total_frames - silence_start >= self.min_interval
-        ):
+        if silence_start is not None and total_frames - silence_start >= self.min_interval:
             silence_end = min(total_frames, silence_start + self.max_sil_kept)
             pos = rms_list[silence_start : silence_end + 1].argmin() + silence_start
             sil_tags.append((pos, total_frames + 1))
         # Apply and return slices.
         ####音频+起始时间+终止时间
         if len(sil_tags) == 0:
-            return [[waveform,0,int(total_frames*self.hop_size)]]
+            return [[waveform, 0, int(total_frames * self.hop_size)]]
         else:
             chunks = []
             if sil_tags[0][0] > 0:
-                chunks.append([self._apply_slice(waveform, 0, sil_tags[0][0]),0,int(sil_tags[0][0]*self.hop_size)])
+                chunks.append([self._apply_slice(waveform, 0, sil_tags[0][0]), 0, int(sil_tags[0][0] * self.hop_size)])
             for i in range(len(sil_tags) - 1):
                 chunks.append(
-                    [self._apply_slice(waveform, sil_tags[i][1], sil_tags[i + 1][0]),int(sil_tags[i][1]*self.hop_size),int(sil_tags[i + 1][0]*self.hop_size)]
+                    [
+                        self._apply_slice(waveform, sil_tags[i][1], sil_tags[i + 1][0]),
+                        int(sil_tags[i][1] * self.hop_size),
+                        int(sil_tags[i + 1][0] * self.hop_size),
+                    ]
                 )
             if sil_tags[-1][1] < total_frames:
                 chunks.append(
-                    [self._apply_slice(waveform, sil_tags[-1][1], total_frames),int(sil_tags[-1][1]*self.hop_size),int(total_frames*self.hop_size)]
+                    [
+                        self._apply_slice(waveform, sil_tags[-1][1], total_frames),
+                        int(sil_tags[-1][1] * self.hop_size),
+                        int(total_frames * self.hop_size),
+                    ]
                 )
             return chunks
@@ -189,9 +161,7 @@ def main():

     parser = ArgumentParser()
     parser.add_argument("audio", type=str, help="The audio to be sliced")
-    parser.add_argument(
-        "--out", type=str, help="Output directory of the sliced audio clips"
-    )
+    parser.add_argument("--out", type=str, help="Output directory of the sliced audio clips")
     parser.add_argument(
         "--db_thresh",
         type=float,
@@ -249,8 +219,7 @@ def main():
         soundfile.write(
             os.path.join(
                 out,
-                f"%s_%d.wav"
-                % (os.path.basename(args.audio).rsplit(".", maxsplit=1)[0], i),
+                "%s_%d.wav" % (os.path.basename(args.audio).rsplit(".", maxsplit=1)[0], i),
             ),
             chunk,
             sr,
@@ -1,13 +1,15 @@
import argparse,os
import argparse
import os
import copy
import json
import os
import uuid

try:
    import gradio.analytics as analytics
    analytics.version_check = lambda:None
except:...

    analytics.version_check = lambda: None
except:
    ...

import librosa
import gradio as gr
@@ -33,15 +35,10 @@ def reload_data(index, batch):
    g_index = index
    global g_batch
    g_batch = batch
    datas = g_data_json[index:index+batch]
    datas = g_data_json[index : index + batch]
    output = []
    for d in datas:
        output.append(
            {
                g_json_key_text: d[g_json_key_text],
                g_json_key_path: d[g_json_key_path]
            }
        )
        output.append({g_json_key_text: d[g_json_key_text], g_json_key_path: d[g_json_key_path]})
    return output


@@ -50,17 +47,13 @@ def b_change_index(index, batch):
    g_index, g_batch = index, batch
    datas = reload_data(index, batch)
    output = []
    for i , _ in enumerate(datas):
    for i, _ in enumerate(datas):
        output.append(
            # gr.Textbox(
            #     label=f"Text {i+index}",
            #     value=_[g_json_key_text]#text
            # )
            {
                "__type__":"update",
                "label":f"Text {i+index}",
                "value":_[g_json_key_text]
            }
            {"__type__": "update", "label": f"Text {i + index}", "value": _[g_json_key_text]}
        )
    for _ in range(g_batch - len(datas)):
        output.append(
@@ -68,11 +61,7 @@ def b_change_index(index, batch):
            #     label=f"Text",
            #     value=""
            # )
            {
                "__type__": "update",
                "label": f"Text",
                "value": ""
            }
            {"__type__": "update", "label": "Text", "value": ""}
        )
    for _ in datas:
        output.append(_[g_json_key_path])
@@ -86,7 +75,7 @@ def b_change_index(index, batch):
def b_next_index(index, batch):
    b_save_file()
    if (index + batch) <= g_max_json_index:
        return index + batch , *b_change_index(index + batch, batch)
        return index + batch, *b_change_index(index + batch, batch)
    else:
        return index, *b_change_index(index, batch)

@@ -94,7 +83,7 @@ def b_next_index(index, batch):
def b_previous_index(index, batch):
    b_save_file()
    if (index - batch) >= 0:
        return index - batch , *b_change_index(index - batch, batch)
        return index - batch, *b_change_index(index - batch, batch)
    else:
        return 0, *b_change_index(0, batch)

@@ -104,8 +93,8 @@ def b_submit_change(*text_list):
    change = False
    for i, new_text in enumerate(text_list):
        if g_index + i <= g_max_json_index:
            new_text = new_text.strip()+' '
            if (g_data_json[g_index + i][g_json_key_text] != new_text):
            new_text = new_text.strip() + " "
            if g_data_json[g_index + i][g_json_key_text] != new_text:
                g_data_json[g_index + i][g_json_key_text] = new_text
                change = True
    if change:
@@ -119,18 +108,22 @@ def b_delete_audio(*checkbox_list):
    change = False
    for i, checkbox in reversed(list(enumerate(checkbox_list))):
        if g_index + i < len(g_data_json):
            if (checkbox == True):
            if checkbox == True:
                g_data_json.pop(g_index + i)
                change = True

    g_max_json_index = len(g_data_json)-1

    g_max_json_index = len(g_data_json) - 1
    if g_index > g_max_json_index:
        g_index = g_max_json_index
    g_index = g_index if g_index >= 0 else 0
    if change:
        b_save_file()
    # return gr.Slider(value=g_index, maximum=(g_max_json_index if g_max_json_index>=0 else 0)), *b_change_index(g_index, g_batch)
    return {"value":g_index,"__type__":"update","maximum":(g_max_json_index if g_max_json_index>=0 else 0)},*b_change_index(g_index, g_batch)
    return {
        "value": g_index,
        "__type__": "update",
        "maximum": (g_max_json_index if g_max_json_index >= 0 else 0),
    }, *b_change_index(g_index, g_batch)


def b_invert_selection(*checkbox_list):
@@ -143,18 +136,18 @@ def get_next_path(filename):
    base_name = os.path.splitext(os.path.basename(filename))[0]
    for i in range(100):
        new_path = os.path.join(base_dir, f"{base_name}_{str(i).zfill(2)}.wav")
        if not os.path.exists(new_path) :
        if not os.path.exists(new_path):
            return new_path
    return os.path.join(base_dir, f'{str(uuid.uuid4())}.wav')
    return os.path.join(base_dir, f"{str(uuid.uuid4())}.wav")


def b_audio_split(audio_breakpoint, *checkbox_list):
    global g_data_json , g_max_json_index
    global g_data_json, g_max_json_index
    checked_index = []
    for i, checkbox in enumerate(checkbox_list):
        if (checkbox == True and g_index+i < len(g_data_json)):
        if checkbox == True and g_index + i < len(g_data_json):
            checked_index.append(g_index + i)
    if len(checked_index) == 1 :
    if len(checked_index) == 1:
        index = checked_index[0]
        audio_json = copy.deepcopy(g_data_json[index])
        path = audio_json[g_json_key_path]
@@ -162,7 +155,7 @@ def b_audio_split(audio_breakpoint, *checkbox_list):
        audio_maxframe = len(data)
        break_frame = int(audio_breakpoint * sample_rate)

        if (break_frame >= 1 and break_frame < audio_maxframe):
        if break_frame >= 1 and break_frame < audio_maxframe:
            audio_first = data[0:break_frame]
            audio_second = data[break_frame:]
            nextpath = get_next_path(path)
@@ -174,19 +167,20 @@ def b_audio_split(audio_breakpoint, *checkbox_list):

    g_max_json_index = len(g_data_json) - 1
    # return gr.Slider(value=g_index, maximum=g_max_json_index), *b_change_index(g_index, g_batch)
    return {"value":g_index,"maximum":g_max_json_index,"__type__":"update"}, *b_change_index(g_index, g_batch)
    return {"value": g_index, "maximum": g_max_json_index, "__type__": "update"}, *b_change_index(g_index, g_batch)


def b_merge_audio(interval_r, *checkbox_list):
    global g_data_json , g_max_json_index
    global g_data_json, g_max_json_index
    b_save_file()
    checked_index = []
    audios_path = []
    audios_text = []
    for i, checkbox in enumerate(checkbox_list):
        if (checkbox == True and g_index+i < len(g_data_json)):
        if checkbox == True and g_index + i < len(g_data_json):
            checked_index.append(g_index + i)

    if (len(checked_index)>1):

    if len(checked_index) > 1:
        for i in checked_index:
            audios_path.append(g_data_json[i][g_json_key_path])
            audios_text.append(g_data_json[i][g_json_key_text])
@@ -202,7 +196,7 @@ def b_merge_audio(interval_r, *checkbox_list):
        for i, path in enumerate(audios_path):
            data, sample_rate = librosa.load(path, sr=l_sample_rate, mono=True)
            l_sample_rate = sample_rate
            if (i > 0):
            if i > 0:
                silence = np.zeros(int(l_sample_rate * interval_r))
                audio_list.append(silence)

@@ -213,32 +207,32 @@ def b_merge_audio(interval_r, *checkbox_list):
        soundfile.write(base_path, audio_concat, l_sample_rate)

        b_save_file()


    g_max_json_index = len(g_data_json) - 1


    # return gr.Slider(value=g_index, maximum=g_max_json_index), *b_change_index(g_index, g_batch)
    return {"value":g_index,"maximum":g_max_json_index,"__type__":"update"}, *b_change_index(g_index, g_batch)
    return {"value": g_index, "maximum": g_max_json_index, "__type__": "update"}, *b_change_index(g_index, g_batch)


def b_save_json():
    with open(g_load_file,'w', encoding="utf-8") as file:
    with open(g_load_file, "w", encoding="utf-8") as file:
        for data in g_data_json:
            file.write(f'{json.dumps(data, ensure_ascii = False)}\n')
            file.write(f"{json.dumps(data, ensure_ascii=False)}\n")


def b_save_list():
    with open(g_load_file,'w', encoding="utf-8") as file:
    with open(g_load_file, "w", encoding="utf-8") as file:
        for data in g_data_json:
            wav_path = data["wav_path"]
            speaker_name = data["speaker_name"]
            language = data["language"]
            text = data["text"]
            file.write(f"{wav_path}|{speaker_name}|{language}|{text}".strip()+'\n')
            file.write(f"{wav_path}|{speaker_name}|{language}|{text}".strip() + "\n")


def b_load_json():
    global g_data_json, g_max_json_index
    with open(g_load_file, 'r', encoding="utf-8") as file:
    with open(g_load_file, "r", encoding="utf-8") as file:
        g_data_json = file.readlines()
        g_data_json = [json.loads(line) for line in g_data_json]
        g_max_json_index = len(g_data_json) - 1
@@ -246,19 +240,14 @@ def b_load_json():

def b_load_list():
    global g_data_json, g_max_json_index
    with open(g_load_file, 'r', encoding="utf-8") as source:
    with open(g_load_file, "r", encoding="utf-8") as source:
        data_list = source.readlines()
        for _ in data_list:
            data = _.split('|')
            if (len(data) == 4):
            data = _.split("|")
            if len(data) == 4:
                wav_path, speaker_name, language, text = data
                g_data_json.append(
                    {
                        'wav_path':wav_path,
                        'speaker_name':speaker_name,
                        'language':language,
                        'text':text.strip()
                    }
                    {"wav_path": wav_path, "speaker_name": speaker_name, "language": language, "text": text.strip()}
                )
            else:
                print("error line:", data)
@@ -283,17 +272,17 @@ def set_global(load_json, load_list, json_key_text, json_key_path, batch):
    global g_json_key_text, g_json_key_path, g_load_file, g_load_format, g_batch

    g_batch = int(batch)

    if (load_json != "None"):

    if load_json != "None":
        g_load_format = "json"
        g_load_file = load_json
    elif (load_list != "None"):
    elif load_list != "None":
        g_load_format = "list"
        g_load_file = load_list
    else:
        g_load_format = "list"
        g_load_file = "demo.list"


    g_json_key_text = json_key_text
    g_json_key_path = json_key_path

@@ -301,21 +290,20 @@ def set_global(load_json, load_list, json_key_text, json_key_path, batch):


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Process some integers.')
    parser.add_argument('--load_json', default="None", help='source file, like demo.json')
    parser.add_argument('--is_share', default="False", help='whether webui is_share=True')
    parser.add_argument('--load_list', default="None", help='source file, like demo.list')
    parser.add_argument('--webui_port_subfix', default=9871, help='source file, like demo.list')
    parser.add_argument('--json_key_text', default="text", help='the text key name in json, Default: text')
    parser.add_argument('--json_key_path', default="wav_path", help='the path key name in json, Default: wav_path')
    parser.add_argument('--g_batch', default=10, help='max number g_batch wav to display, Default: 10')
    parser = argparse.ArgumentParser(description="Process some integers.")
    parser.add_argument("--load_json", default="None", help="source file, like demo.json")
    parser.add_argument("--is_share", default="False", help="whether webui is_share=True")
    parser.add_argument("--load_list", default="None", help="source file, like demo.list")
    parser.add_argument("--webui_port_subfix", default=9871, help="source file, like demo.list")
    parser.add_argument("--json_key_text", default="text", help="the text key name in json, Default: text")
    parser.add_argument("--json_key_path", default="wav_path", help="the path key name in json, Default: wav_path")
    parser.add_argument("--g_batch", default=10, help="max number g_batch wav to display, Default: 10")

    args = parser.parse_args()

    set_global(args.load_json, args.load_list, args.json_key_text, args.json_key_path, args.g_batch)

    with gr.Blocks() as demo:

    with gr.Blocks() as demo:
        with gr.Row():
            btn_change_index = gr.Button("Change Index")
            btn_submit_change = gr.Button("Submit Text")
@@ -323,79 +311,50 @@ if __name__ == "__main__":
            btn_delete_audio = gr.Button("Delete Audio")
            btn_previous_index = gr.Button("Previous Index")
            btn_next_index = gr.Button("Next Index")


        with gr.Row():
            index_slider = gr.Slider(
                minimum=0, maximum=g_max_json_index, value=g_index, step=1, label="Index", scale=3
            )
            index_slider = gr.Slider(minimum=0, maximum=g_max_json_index, value=g_index, step=1, label="Index", scale=3)
            splitpoint_slider = gr.Slider(
                minimum=0, maximum=120.0, value=0, step=0.1, label="Audio Split Point(s)", scale=3
                minimum=0, maximum=120.0, value=0, step=0.1, label="Audio Split Point(s)", scale=3
            )
            btn_audio_split = gr.Button("Split Audio", scale=1)
            btn_save_json = gr.Button("Save File", visible=True, scale=1)
            btn_invert_selection = gr.Button("Invert Selection", scale=1)


        with gr.Row():
            with gr.Column():
                for _ in range(0,g_batch):
                for _ in range(0, g_batch):
                    with gr.Row():
                        text = gr.Textbox(
                            label = "Text",
                            visible = True,
                            scale=5
                        )
                        audio_output = gr.Audio(
                            label="Output Audio",
                            visible = True,
                            scale=5
                        )
                        audio_check = gr.Checkbox(
                            label="Yes",
                            show_label = True,
                            info = "Choose Audio",
                            scale=1
                        )
                        text = gr.Textbox(label="Text", visible=True, scale=5)
                        audio_output = gr.Audio(label="Output Audio", visible=True, scale=5)
                        audio_check = gr.Checkbox(label="Yes", show_label=True, info="Choose Audio", scale=1)
                        g_text_list.append(text)
                        g_audio_list.append(audio_output)
                        g_checkbox_list.append(audio_check)



        with gr.Row():
            batchsize_slider = gr.Slider(
                minimum=1, maximum=g_batch, value=g_batch, step=1, label="Batch Size", scale=3, interactive=False
            )
            interval_slider = gr.Slider(
                minimum=0, maximum=2, value=0, step=0.01, label="Interval", scale=3
                minimum=1, maximum=g_batch, value=g_batch, step=1, label="Batch Size", scale=3, interactive=False
            )
            interval_slider = gr.Slider(minimum=0, maximum=2, value=0, step=0.01, label="Interval", scale=3)
            btn_theme_dark = gr.Button("Light Theme", link="?__theme=light", scale=1)
            btn_theme_light = gr.Button("Dark Theme", link="?__theme=dark", scale=1)


        btn_change_index.click(
            b_change_index,
            inputs=[
                index_slider,
                batchsize_slider,
            ],
            outputs=[
                *g_text_list,
                *g_audio_list,
                *g_checkbox_list
            ],
            outputs=[*g_text_list, *g_audio_list, *g_checkbox_list],
        )


        btn_submit_change.click(
            b_submit_change,
            inputs=[
                *g_text_list,
            ],
            outputs=[
                index_slider,
                *g_text_list,
                *g_audio_list,
                *g_checkbox_list
            ],
            outputs=[index_slider, *g_text_list, *g_audio_list, *g_checkbox_list],
        )

        btn_previous_index.click(
@@ -404,82 +363,39 @@ if __name__ == "__main__":
                index_slider,
                batchsize_slider,
            ],
            outputs=[
                index_slider,
                *g_text_list,
                *g_audio_list,
                *g_checkbox_list
            ],
            outputs=[index_slider, *g_text_list, *g_audio_list, *g_checkbox_list],
        )


        btn_next_index.click(
            b_next_index,
            inputs=[
                index_slider,
                batchsize_slider,
            ],
            outputs=[
                index_slider,
                *g_text_list,
                *g_audio_list,
                *g_checkbox_list
            ],
            outputs=[index_slider, *g_text_list, *g_audio_list, *g_checkbox_list],
        )

        btn_delete_audio.click(
            b_delete_audio,
            inputs=[
                *g_checkbox_list
            ],
            outputs=[
                index_slider,
                *g_text_list,
                *g_audio_list,
                *g_checkbox_list
            ]
            inputs=[*g_checkbox_list],
            outputs=[index_slider, *g_text_list, *g_audio_list, *g_checkbox_list],
        )

        btn_merge_audio.click(
            b_merge_audio,
            inputs=[
                interval_slider,
                *g_checkbox_list
            ],
            outputs=[
                index_slider,
                *g_text_list,
                *g_audio_list,
                *g_checkbox_list
            ]
            inputs=[interval_slider, *g_checkbox_list],
            outputs=[index_slider, *g_text_list, *g_audio_list, *g_checkbox_list],
        )

        btn_audio_split.click(
            b_audio_split,
            inputs=[
                splitpoint_slider,
                *g_checkbox_list
            ],
            outputs=[
                index_slider,
                *g_text_list,
                *g_audio_list,
                *g_checkbox_list
            ]
            inputs=[splitpoint_slider, *g_checkbox_list],
            outputs=[index_slider, *g_text_list, *g_audio_list, *g_checkbox_list],
        )

        btn_invert_selection.click(
            b_invert_selection,
            inputs=[
                *g_checkbox_list
            ],
            outputs=[
                *g_checkbox_list
            ]
        )
        btn_invert_selection.click(b_invert_selection, inputs=[*g_checkbox_list], outputs=[*g_checkbox_list])

        btn_save_json.click(
            b_save_file
        )
        btn_save_json.click(b_save_file)

        demo.load(
            b_change_index,
@@ -487,17 +403,13 @@ if __name__ == "__main__":
                index_slider,
                batchsize_slider,
            ],
            outputs=[
                *g_text_list,
                *g_audio_list,
                *g_checkbox_list
            ],
            outputs=[*g_text_list, *g_audio_list, *g_checkbox_list],
        )


    demo.launch(
        server_name="0.0.0.0",
        inbrowser=True,
        quiet=True,
        share=eval(args.is_share),
        server_port=int(args.webui_port_subfix)
    )
        server_port=int(args.webui_port_subfix),
    )

@@ -7,23 +7,22 @@ import torch.nn.functional as F
def exists(val):
    return val is not None


def default(v, d):
    return v if exists(v) else d


class Attend(nn.Module):
    def __init__(
        self,
        dropout = 0.,
        flash = False,
        scale = None
    ):
    def __init__(self, dropout=0.0, flash=False, scale=None):
        super().__init__()
        self.scale = scale
        self.dropout = dropout
        self.attn_dropout = nn.Dropout(dropout)

        self.flash = flash
        assert not (flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'
        assert not (flash and version.parse(torch.__version__) < version.parse("2.0.0")), (
            "in order to use flash attention, you must be using pytorch 2.0 or above"
        )

    def flash_attn(self, q, k, v):
        # _, heads, q_len, _, k_len, is_cuda, device = *q.shape, k.shape[-2], q.is_cuda, q.device
@@ -34,7 +33,7 @@ class Attend(nn.Module):

        # pytorch 2.0 flash attn: q, k, v, mask, dropout, softmax_scale
        # with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=True):
        return F.scaled_dot_product_attention(q, k, v,dropout_p = self.dropout if self.training else 0.)
        return F.scaled_dot_product_attention(q, k, v, dropout_p=self.dropout if self.training else 0.0)

    def forward(self, q, k, v):
        """
@@ -54,7 +53,7 @@ class Attend(nn.Module):

        # similarity

        sim = einsum(f"b h i d, b h j d -> b h i j", q, k) * scale
        sim = einsum("b h i d, b h j d -> b h i j", q, k) * scale

        # attention

@@ -63,6 +62,6 @@ class Attend(nn.Module):

        # aggregate values

        out = einsum(f"b h i j, b h j d -> b h i d", attn, v)
        out = einsum("b h i j, b h j d -> b h i d", attn, v)

        return out

@@ -1,14 +1,14 @@
from functools import partial

import torch
from torch import nn, einsum, Tensor
from torch import nn
from torch.nn import Module, ModuleList
import torch.nn.functional as F

from bs_roformer.attend import Attend
from torch.utils.checkpoint import checkpoint

from typing import Tuple, Optional, List, Callable
from typing import Tuple, Optional, Callable
# from beartype.typing import Tuple, Optional, List, Callable
# from beartype import beartype

@@ -19,6 +19,7 @@ from einops.layers.torch import Rearrange

# helper functions


def exists(val):
    return val is not None

@@ -37,14 +38,15 @@ def unpack_one(t, ps, pattern):

# norm


def l2norm(t):
    return F.normalize(t, dim = -1, p = 2)
    return F.normalize(t, dim=-1, p=2)


class RMSNorm(Module):
    def __init__(self, dim):
        super().__init__()
        self.scale = dim ** 0.5
        self.scale = dim**0.5
        self.gamma = nn.Parameter(torch.ones(dim))

    def forward(self, x):
@@ -53,13 +55,9 @@ class RMSNorm(Module):

# attention


class FeedForward(Module):
    def __init__(
            self,
            dim,
            mult=4,
            dropout=0.
    ):
    def __init__(self, dim, mult=4, dropout=0.0):
        super().__init__()
        dim_inner = int(dim * mult)
        self.net = nn.Sequential(
@@ -68,7 +66,7 @@ class FeedForward(Module):
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim_inner, dim),
            nn.Dropout(dropout)
            nn.Dropout(dropout),
        )

    def forward(self, x):
@@ -76,18 +74,10 @@ class FeedForward(Module):


class Attention(Module):
    def __init__(
            self,
            dim,
            heads=8,
            dim_head=64,
            dropout=0.,
            rotary_embed=None,
            flash=True
    ):
    def __init__(self, dim, heads=8, dim_head=64, dropout=0.0, rotary_embed=None, flash=True):
        super().__init__()
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.scale = dim_head**-0.5
        dim_inner = heads * dim_head

        self.rotary_embed = rotary_embed
@@ -99,15 +89,12 @@ class Attention(Module):

        self.to_gates = nn.Linear(dim, heads)

        self.to_out = nn.Sequential(
            nn.Linear(dim_inner, dim, bias=False),
            nn.Dropout(dropout)
        )
        self.to_out = nn.Sequential(nn.Linear(dim_inner, dim, bias=False), nn.Dropout(dropout))

    def forward(self, x):
        x = self.norm(x)

        q, k, v = rearrange(self.to_qkv(x), 'b n (qkv h d) -> qkv b h n d', qkv=3, h=self.heads)
        q, k, v = rearrange(self.to_qkv(x), "b n (qkv h d) -> qkv b h n d", qkv=3, h=self.heads)

        if exists(self.rotary_embed):
            q = self.rotary_embed.rotate_queries_or_keys(q)
@@ -116,9 +103,9 @@ class Attention(Module):
        out = self.attend(q, k, v)

        gates = self.to_gates(x)
        out = out * rearrange(gates, 'b n h -> b h n 1').sigmoid()
        out = out * rearrange(gates, "b n h -> b h n 1").sigmoid()

        out = rearrange(out, 'b h n d -> b n (h d)')
        out = rearrange(out, "b h n d -> b n (h d)")
        return self.to_out(out)


@@ -128,42 +115,22 @@ class LinearAttention(Module):
    """

    # @beartype
    def __init__(
            self,
            *,
            dim,
            dim_head=32,
            heads=8,
            scale=8,
            flash=False,
            dropout=0.
    ):
    def __init__(self, *, dim, dim_head=32, heads=8, scale=8, flash=False, dropout=0.0):
        super().__init__()
        dim_inner = dim_head * heads
        self.norm = RMSNorm(dim)

        self.to_qkv = nn.Sequential(
            nn.Linear(dim, dim_inner * 3, bias=False),
            Rearrange('b n (qkv h d) -> qkv b h d n', qkv=3, h=heads)
            nn.Linear(dim, dim_inner * 3, bias=False), Rearrange("b n (qkv h d) -> qkv b h d n", qkv=3, h=heads)
        )

        self.temperature = nn.Parameter(torch.ones(heads, 1, 1))

        self.attend = Attend(
            scale=scale,
            dropout=dropout,
            flash=flash
        )
        self.attend = Attend(scale=scale, dropout=dropout, flash=flash)

        self.to_out = nn.Sequential(
            Rearrange('b h d n -> b n (h d)'),
            nn.Linear(dim_inner, dim, bias=False)
        )
        self.to_out = nn.Sequential(Rearrange("b h d n -> b n (h d)"), nn.Linear(dim_inner, dim, bias=False))

    def forward(
            self,
            x
    ):
    def forward(self, x):
        x = self.norm(x)

        q, k, v = self.to_qkv(x)
@@ -178,19 +145,19 @@ class LinearAttention(Module):

class Transformer(Module):
    def __init__(
            self,
            *,
            dim,
            depth,
            dim_head=64,
            heads=8,
            attn_dropout=0.,
            ff_dropout=0.,
            ff_mult=4,
            norm_output=True,
            rotary_embed=None,
            flash_attn=True,
            linear_attn=False
        self,
        *,
        dim,
        depth,
        dim_head=64,
        heads=8,
        attn_dropout=0.0,
        ff_dropout=0.0,
        ff_mult=4,
        norm_output=True,
        rotary_embed=None,
        flash_attn=True,
        linear_attn=False,
    ):
        super().__init__()
        self.layers = ModuleList([])
@@ -199,18 +166,20 @@ class Transformer(Module):
            if linear_attn:
                attn = LinearAttention(dim=dim, dim_head=dim_head, heads=heads, dropout=attn_dropout, flash=flash_attn)
            else:
                attn = Attention(dim=dim, dim_head=dim_head, heads=heads, dropout=attn_dropout,
                                 rotary_embed=rotary_embed, flash=flash_attn)
                attn = Attention(
                    dim=dim,
                    dim_head=dim_head,
                    heads=heads,
                    dropout=attn_dropout,
                    rotary_embed=rotary_embed,
                    flash=flash_attn,
                )

            self.layers.append(ModuleList([
                attn,
                FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout)
            ]))
            self.layers.append(ModuleList([attn, FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout)]))

        self.norm = RMSNorm(dim) if norm_output else nn.Identity()

    def forward(self, x):

        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
@@ -220,22 +189,16 @@ class Transformer(Module):

# bandsplit module


class BandSplit(Module):
    # @beartype
    def __init__(
            self,
            dim,
            dim_inputs: Tuple[int, ...]
    ):
    def __init__(self, dim, dim_inputs: Tuple[int, ...]):
        super().__init__()
        self.dim_inputs = dim_inputs
        self.to_features = ModuleList([])

        for dim_in in dim_inputs:
            net = nn.Sequential(
                RMSNorm(dim_in),
                nn.Linear(dim_in, dim)
            )
            net = nn.Sequential(RMSNorm(dim_in), nn.Linear(dim_in, dim))

            self.to_features.append(net)

@@ -250,13 +213,7 @@ class BandSplit(Module):
        return torch.stack(outs, dim=-2)


def MLP(
        dim_in,
        dim_out,
        dim_hidden=None,
        depth=1,
        activation=nn.Tanh
):
def MLP(dim_in, dim_out, dim_hidden=None, depth=1, activation=nn.Tanh):
    dim_hidden = default(dim_hidden, dim_in)

    net = []
@@ -277,13 +234,7 @@ def MLP(

class MaskEstimator(Module):
    # @beartype
    def __init__(
            self,
            dim,
            dim_inputs: Tuple[int, ...],
            depth,
            mlp_expansion_factor=4
    ):
    def __init__(self, dim, dim_inputs: Tuple[int, ...], depth, mlp_expansion_factor=4):
        super().__init__()
        self.dim_inputs = dim_inputs
        self.to_freqs = ModuleList([])
@@ -292,10 +243,7 @@ class MaskEstimator(Module):
        for dim_in in dim_inputs:
            net = []

            mlp = nn.Sequential(
                MLP(dim, dim_in * 2, dim_hidden=dim_hidden, depth=depth),
                nn.GLU(dim=-1)
            )
            mlp = nn.Sequential(MLP(dim, dim_in * 2, dim_hidden=dim_hidden, depth=depth), nn.GLU(dim=-1))

            self.to_freqs.append(mlp)

@@ -314,53 +262,106 @@ class MaskEstimator(Module):
# main class

DEFAULT_FREQS_PER_BANDS = (
    2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
    2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
    2, 2, 2, 2,
    4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
    12, 12, 12, 12, 12, 12, 12, 12,
    24, 24, 24, 24, 24, 24, 24, 24,
    48, 48, 48, 48, 48, 48, 48, 48,
    128, 129,
    2,
    2,
    2,
    2,
    2,
    2,
    2,
    2,
    2,
    2,
    2,
    2,
    2,
    2,
    2,
    2,
    2,
    2,
    2,
    2,
    2,
    2,
    2,
    2,
    4,
    4,
    4,
    4,
    4,
    4,
    4,
    4,
    4,
    4,
    4,
    4,
    12,
    12,
    12,
    12,
    12,
    12,
    12,
    12,
    24,
    24,
    24,
    24,
    24,
    24,
    24,
    24,
    48,
    48,
    48,
    48,
    48,
    48,
    48,
    48,
    128,
    129,
)


class BSRoformer(Module):

    # @beartype
    def __init__(
            self,
            dim,
            *,
            depth,
            stereo=False,
            num_stems=1,
            time_transformer_depth=2,
            freq_transformer_depth=2,
            linear_transformer_depth=0,
            freqs_per_bands: Tuple[int, ...] = DEFAULT_FREQS_PER_BANDS,
            # in the paper, they divide into ~60 bands, test with 1 for starters
            dim_head=64,
            heads=8,
            attn_dropout=0.,
            ff_dropout=0.,
            flash_attn=True,
            dim_freqs_in=1025,
            stft_n_fft=2048,
            stft_hop_length=512,
            # 10ms at 44100Hz, from sections 4.1, 4.4 in the paper - @faroit recommends // 2 or // 4 for better reconstruction
            stft_win_length=2048,
            stft_normalized=False,
            stft_window_fn: Optional[Callable] = None,
            mask_estimator_depth=2,
            multi_stft_resolution_loss_weight=1.,
            multi_stft_resolutions_window_sizes: Tuple[int, ...] = (4096, 2048, 1024, 512, 256),
            multi_stft_hop_size=147,
            multi_stft_normalized=False,
            multi_stft_window_fn: Callable = torch.hann_window,
            mlp_expansion_factor=4,
            use_torch_checkpoint=False,
            skip_connection=False,
        self,
        dim,
        *,
        depth,
        stereo=False,
        num_stems=1,
        time_transformer_depth=2,
        freq_transformer_depth=2,
        linear_transformer_depth=0,
        freqs_per_bands: Tuple[int, ...] = DEFAULT_FREQS_PER_BANDS,
        # in the paper, they divide into ~60 bands, test with 1 for starters
        dim_head=64,
        heads=8,
        attn_dropout=0.0,
        ff_dropout=0.0,
        flash_attn=True,
        dim_freqs_in=1025,
        stft_n_fft=2048,
        stft_hop_length=512,
        # 10ms at 44100Hz, from sections 4.1, 4.4 in the paper - @faroit recommends // 2 or // 4 for better reconstruction
        stft_win_length=2048,
        stft_normalized=False,
        stft_window_fn: Optional[Callable] = None,
        mask_estimator_depth=2,
        multi_stft_resolution_loss_weight=1.0,
        multi_stft_resolutions_window_sizes: Tuple[int, ...] = (4096, 2048, 1024, 512, 256),
        multi_stft_hop_size=147,
        multi_stft_normalized=False,
        multi_stft_window_fn: Callable = torch.hann_window,
        mlp_expansion_factor=4,
        use_torch_checkpoint=False,
        skip_connection=False,
    ):
        super().__init__()

@@ -379,7 +380,7 @@ class BSRoformer(Module):
            attn_dropout=attn_dropout,
            ff_dropout=ff_dropout,
            flash_attn=flash_attn,
            norm_output=False
            norm_output=False,
        )

        time_rotary_embed = RotaryEmbedding(dim=dim_head)
@@ -400,26 +401,23 @@ class BSRoformer(Module):
        self.final_norm = RMSNorm(dim)

        self.stft_kwargs = dict(
            n_fft=stft_n_fft,
            hop_length=stft_hop_length,
            win_length=stft_win_length,
            normalized=stft_normalized
            n_fft=stft_n_fft, hop_length=stft_hop_length, win_length=stft_win_length, normalized=stft_normalized
        )

        self.stft_window_fn = partial(default(stft_window_fn, torch.hann_window), stft_win_length)

        freqs = torch.stft(torch.randn(1, 4096), **self.stft_kwargs, window=torch.ones(stft_win_length), return_complex=True).shape[1]
        freqs = torch.stft(
            torch.randn(1, 4096), **self.stft_kwargs, window=torch.ones(stft_win_length), return_complex=True
        ).shape[1]

        assert len(freqs_per_bands) > 1
        assert sum(
            freqs_per_bands) == freqs, f'the number of freqs in the bands must equal {freqs} based on the STFT settings, but got {sum(freqs_per_bands)}'
        assert sum(freqs_per_bands) == freqs, (
            f"the number of freqs in the bands must equal {freqs} based on the STFT settings, but got {sum(freqs_per_bands)}"
        )

        freqs_per_bands_with_complex = tuple(2 * f * self.audio_channels for f in freqs_per_bands)

        self.band_split = BandSplit(
            dim=dim,
            dim_inputs=freqs_per_bands_with_complex
        )
        self.band_split = BandSplit(dim=dim, dim_inputs=freqs_per_bands_with_complex)

        self.mask_estimators = nn.ModuleList([])

@@ -440,17 +438,9 @@ class BSRoformer(Module):
        self.multi_stft_n_fft = stft_n_fft
        self.multi_stft_window_fn = multi_stft_window_fn

        self.multi_stft_kwargs = dict(
            hop_length=multi_stft_hop_size,
            normalized=multi_stft_normalized
        )
        self.multi_stft_kwargs = dict(hop_length=multi_stft_hop_size, normalized=multi_stft_normalized)

    def forward(
            self,
            raw_audio,
            target=None,
            return_loss_breakdown=False
    ):
    def forward(self, raw_audio, target=None, return_loss_breakdown=False):
        """
        einops

@@ -469,14 +459,16 @@ class BSRoformer(Module):
        x_is_mps = True if device.type == "mps" else False

        if raw_audio.ndim == 2:
            raw_audio = rearrange(raw_audio, 'b t -> b 1 t')
            raw_audio = rearrange(raw_audio, "b t -> b 1 t")

        channels = raw_audio.shape[1]
        assert (not self.stereo and channels == 1) or (self.stereo and channels == 2), 'stereo needs to be set to True if passing in audio signal that is stereo (channel dimension of 2). also need to be False if mono (channel dimension of 1)'
        assert (not self.stereo and channels == 1) or (self.stereo and channels == 2), (
            "stereo needs to be set to True if passing in audio signal that is stereo (channel dimension of 2). also need to be False if mono (channel dimension of 1)"
        )

        # to stft

        raw_audio, batch_audio_channel_packed_shape = pack_one(raw_audio, '* t')
        raw_audio, batch_audio_channel_packed_shape = pack_one(raw_audio, "* t")

        stft_window = self.stft_window_fn(device=device)

@@ -485,16 +477,21 @@ class BSRoformer(Module):
        try:
            stft_repr = torch.stft(raw_audio, **self.stft_kwargs, window=stft_window, return_complex=True)
        except:
            stft_repr = torch.stft(raw_audio.cpu() if x_is_mps else raw_audio, **self.stft_kwargs, window=stft_window.cpu() if x_is_mps else stft_window, return_complex=True).to(device)
            stft_repr = torch.stft(
                raw_audio.cpu() if x_is_mps else raw_audio,
                **self.stft_kwargs,
                window=stft_window.cpu() if x_is_mps else stft_window,
                return_complex=True,
            ).to(device)

        stft_repr = torch.view_as_real(stft_repr)

        stft_repr = unpack_one(stft_repr, batch_audio_channel_packed_shape, '* f t c')
        stft_repr = unpack_one(stft_repr, batch_audio_channel_packed_shape, "* f t c")

        # merge stereo / mono into the frequency, with frequency leading dimension, for band splitting
        stft_repr = rearrange(stft_repr,'b s f t c -> b (f s) t c')
        stft_repr = rearrange(stft_repr, "b s f t c -> b (f s) t c")

        x = rearrange(stft_repr, 'b f t c -> b t (f c)')
        x = rearrange(stft_repr, "b f t c -> b t (f c)")

        if self.use_torch_checkpoint:
            x = checkpoint(self.band_split, x, use_reentrant=False)
@@ -505,16 +502,15 @@ class BSRoformer(Module):

        store = [None] * len(self.layers)
        for i, transformer_block in enumerate(self.layers):

            if len(transformer_block) == 3:
                linear_transformer, time_transformer, freq_transformer = transformer_block

                x, ft_ps = pack([x], 'b * d')
                x, ft_ps = pack([x], "b * d")
                if self.use_torch_checkpoint:
                    x = checkpoint(linear_transformer, x, use_reentrant=False)
                else:
                    x = linear_transformer(x)
                x, = unpack(x, ft_ps, 'b * d')
                (x,) = unpack(x, ft_ps, "b * d")
            else:
                time_transformer, freq_transformer = transformer_block

@@ -523,24 +519,24 @@ class BSRoformer(Module):
                for j in range(i):
                    x = x + store[j]

            x = rearrange(x, 'b t f d -> b f t d')
            x, ps = pack([x], '* t d')
            x = rearrange(x, "b t f d -> b f t d")
            x, ps = pack([x], "* t d")

            if self.use_torch_checkpoint:
                x = checkpoint(time_transformer, x, use_reentrant=False)
            else:
                x = time_transformer(x)

            x, = unpack(x, ps, '* t d')
            x = rearrange(x, 'b f t d -> b t f d')
            x, ps = pack([x], '* f d')
            (x,) = unpack(x, ps, "* t d")
            x = rearrange(x, "b f t d -> b t f d")
            x, ps = pack([x], "* f d")

            if self.use_torch_checkpoint:
                x = checkpoint(freq_transformer, x, use_reentrant=False)
            else:
                x = freq_transformer(x)

            x, = unpack(x, ps, '* f d')
            (x,) = unpack(x, ps, "* f d")

            if self.skip_connection:
                store[i] = x
@@ -553,11 +549,11 @@ class BSRoformer(Module):
            mask = torch.stack([checkpoint(fn, x, use_reentrant=False) for fn in self.mask_estimators], dim=1)
        else:
            mask = torch.stack([fn(x) for fn in self.mask_estimators], dim=1)
        mask = rearrange(mask, 'b n t (f c) -> b n f t c', c=2)
        mask = rearrange(mask, "b n t (f c) -> b n f t c", c=2)

        # modulate frequency representation

        stft_repr = rearrange(stft_repr, 'b f t c -> b 1 f t c')
        stft_repr = rearrange(stft_repr, "b f t c -> b 1 f t c")

        # complex number multiplication

@@ -568,18 +564,26 @@ class BSRoformer(Module):

        # istft

        stft_repr = rearrange(stft_repr, 'b n (f s) t -> (b n s) f t', s=self.audio_channels)
        stft_repr = rearrange(stft_repr, "b n (f s) t -> (b n s) f t", s=self.audio_channels)

        # same as torch.stft() fix for MacOS MPS above
        try:
            recon_audio = torch.istft(stft_repr, **self.stft_kwargs, window=stft_window, return_complex=False, length=raw_audio.shape[-1])
            recon_audio = torch.istft(
                stft_repr, **self.stft_kwargs, window=stft_window, return_complex=False, length=raw_audio.shape[-1]
            )
        except:
            recon_audio = torch.istft(stft_repr.cpu() if x_is_mps else stft_repr, **self.stft_kwargs, window=stft_window.cpu() if x_is_mps else stft_window, return_complex=False, length=raw_audio.shape[-1]).to(device)
            recon_audio = torch.istft(
                stft_repr.cpu() if x_is_mps else stft_repr,
                **self.stft_kwargs,
                window=stft_window.cpu() if x_is_mps else stft_window,
                return_complex=False,
                length=raw_audio.shape[-1],
            ).to(device)

        recon_audio = rearrange(recon_audio, '(b n s) t -> b n s t', s=self.audio_channels, n=num_stems)
        recon_audio = rearrange(recon_audio, "(b n s) t -> b n s t", s=self.audio_channels, n=num_stems)

        if num_stems == 1:
            recon_audio = rearrange(recon_audio, 'b 1 s t -> b s t')
            recon_audio = rearrange(recon_audio, "b 1 s t -> b s t")

        # if a target is passed in, calculate loss for learning

@@ -590,13 +594,13 @@ class BSRoformer(Module):
            assert target.ndim == 4 and target.shape[1] == self.num_stems

        if target.ndim == 2:
            target = rearrange(target, '... t -> ... 1 t')
            target = rearrange(target, "... t -> ... 1 t")

        target = target[..., :recon_audio.shape[-1]]  # protect against lost length on istft
        target = target[..., : recon_audio.shape[-1]]  # protect against lost length on istft

        loss = F.l1_loss(recon_audio, target)

        multi_stft_resolution_loss = 0.
        multi_stft_resolution_loss = 0.0

        for window_size in self.multi_stft_resolutions_window_sizes:
            res_stft_kwargs = dict(
@@ -607,8 +611,8 @@ class BSRoformer(Module):
                **self.multi_stft_kwargs,
            )

            recon_Y = torch.stft(rearrange(recon_audio, '... s t -> (... s) t'), **res_stft_kwargs)
            target_Y = torch.stft(rearrange(target, '... s t -> (... s) t'), **res_stft_kwargs)
            recon_Y = torch.stft(rearrange(recon_audio, "... s t -> (... s) t"), **res_stft_kwargs)
            target_Y = torch.stft(rearrange(target, "... s t -> (... s) t"), **res_stft_kwargs)

            multi_stft_resolution_loss = multi_stft_resolution_loss + F.l1_loss(recon_Y, target_Y)

@@ -619,4 +623,4 @@ class BSRoformer(Module):
        if not return_loss_breakdown:
            return total_loss

        return total_loss, (loss, multi_stft_resolution_loss)
        return total_loss, (loss, multi_stft_resolution_loss)

@@ -1,14 +1,14 @@
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
from torch import nn, einsum, Tensor
|
||||
from torch import nn
|
||||
from torch.nn import Module, ModuleList
|
||||
import torch.nn.functional as F
|
||||
|
||||
from bs_roformer.attend import Attend
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
|
||||
from typing import Tuple, Optional, List, Callable
|
||||
from typing import Tuple, Optional, Callable
|
||||
# from beartype.typing import Tuple, Optional, List, Callable
|
||||
# from beartype import beartype
|
||||
|
||||
@@ -22,6 +22,7 @@ from librosa import filters
|
||||
|
||||
# helper functions
|
||||
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
@@ -38,9 +39,9 @@ def unpack_one(t, ps, pattern):
|
||||
return unpack(t, ps, pattern)[0]
|
||||
|
||||
|
||||
def pad_at_dim(t, pad, dim=-1, value=0.):
|
||||
dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
|
||||
zeros = ((0, 0) * dims_from_right)
|
||||
def pad_at_dim(t, pad, dim=-1, value=0.0):
|
||||
dims_from_right = (-dim - 1) if dim < 0 else (t.ndim - dim - 1)
|
||||
zeros = (0, 0) * dims_from_right
|
||||
return F.pad(t, (*zeros, *pad), value=value)
|
||||
|
||||
|
||||
@@ -50,10 +51,11 @@ def l2norm(t):
|
||||
|
||||
# norm
|
||||
|
||||
|
||||
class RMSNorm(Module):
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
self.scale = dim ** 0.5
|
||||
self.scale = dim**0.5
|
||||
self.gamma = nn.Parameter(torch.ones(dim))
|
||||
|
||||
def forward(self, x):
|
||||
@@ -62,13 +64,9 @@ class RMSNorm(Module):
|
||||
|
||||
# attention
|
||||
|
||||
|
||||
class FeedForward(Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
mult=4,
|
||||
dropout=0.
|
||||
):
|
||||
def __init__(self, dim, mult=4, dropout=0.0):
|
||||
super().__init__()
|
||||
dim_inner = int(dim * mult)
|
||||
self.net = nn.Sequential(
|
||||
@@ -77,7 +75,7 @@ class FeedForward(Module):
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(dim_inner, dim),
|
||||
nn.Dropout(dropout)
|
||||
nn.Dropout(dropout),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
@@ -85,18 +83,10 @@ class FeedForward(Module):
|
||||
|
||||
|
||||
class Attention(Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
heads=8,
|
||||
dim_head=64,
|
||||
dropout=0.,
|
||||
rotary_embed=None,
|
||||
flash=True
|
||||
):
|
||||
def __init__(self, dim, heads=8, dim_head=64, dropout=0.0, rotary_embed=None, flash=True):
|
||||
super().__init__()
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
self.scale = dim_head**-0.5
|
||||
dim_inner = heads * dim_head
|
||||
|
||||
self.rotary_embed = rotary_embed
|
||||
@@ -108,15 +98,12 @@ class Attention(Module):
|
||||
|
||||
self.to_gates = nn.Linear(dim, heads)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Linear(dim_inner, dim, bias=False),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
self.to_out = nn.Sequential(nn.Linear(dim_inner, dim, bias=False), nn.Dropout(dropout))
|
||||
|
||||
def forward(self, x):
|
||||
x = self.norm(x)
|
||||
|
||||
q, k, v = rearrange(self.to_qkv(x), 'b n (qkv h d) -> qkv b h n d', qkv=3, h=self.heads)
|
||||
q, k, v = rearrange(self.to_qkv(x), "b n (qkv h d) -> qkv b h n d", qkv=3, h=self.heads)
|
||||
|
||||
if exists(self.rotary_embed):
|
||||
q = self.rotary_embed.rotate_queries_or_keys(q)
|
||||
@@ -125,9 +112,9 @@ class Attention(Module):
|
||||
out = self.attend(q, k, v)
|
||||
|
||||
gates = self.to_gates(x)
|
||||
out = out * rearrange(gates, 'b n h -> b h n 1').sigmoid()
|
||||
out = out * rearrange(gates, "b n h -> b h n 1").sigmoid()
|
||||
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
out = rearrange(out, "b h n d -> b n (h d)")
|
||||
return self.to_out(out)
|
||||
|
||||
|
||||
@@ -137,42 +124,22 @@ class LinearAttention(Module):
|
||||
"""
|
||||
|
||||
# @beartype
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
dim,
|
||||
dim_head=32,
|
||||
heads=8,
|
||||
scale=8,
|
||||
flash=False,
|
||||
dropout=0.
|
||||
):
|
||||
def __init__(self, *, dim, dim_head=32, heads=8, scale=8, flash=False, dropout=0.0):
|
||||
super().__init__()
|
||||
dim_inner = dim_head * heads
|
||||
self.norm = RMSNorm(dim)
|
||||
|
||||
self.to_qkv = nn.Sequential(
|
||||
nn.Linear(dim, dim_inner * 3, bias=False),
|
||||
Rearrange('b n (qkv h d) -> qkv b h d n', qkv=3, h=heads)
|
||||
nn.Linear(dim, dim_inner * 3, bias=False), Rearrange("b n (qkv h d) -> qkv b h d n", qkv=3, h=heads)
|
||||
)
|
||||
|
||||
self.temperature = nn.Parameter(torch.ones(heads, 1, 1))
|
||||
|
||||
self.attend = Attend(
|
||||
scale=scale,
|
||||
dropout=dropout,
|
||||
flash=flash
|
||||
)
|
||||
self.attend = Attend(scale=scale, dropout=dropout, flash=flash)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
Rearrange('b h d n -> b n (h d)'),
|
||||
nn.Linear(dim_inner, dim, bias=False)
|
||||
)
|
||||
self.to_out = nn.Sequential(Rearrange("b h d n -> b n (h d)"), nn.Linear(dim_inner, dim, bias=False))
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x
|
||||
):
|
||||
def forward(self, x):
|
||||
x = self.norm(x)
|
||||
|
||||
q, k, v = self.to_qkv(x)
|
||||
@@ -187,19 +154,19 @@ class LinearAttention(Module):
|
||||
|
||||
class Transformer(Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
dim,
|
||||
depth,
|
||||
dim_head=64,
|
||||
heads=8,
|
||||
attn_dropout=0.,
|
||||
ff_dropout=0.,
|
||||
ff_mult=4,
|
||||
norm_output=True,
|
||||
rotary_embed=None,
|
||||
flash_attn=True,
|
||||
linear_attn=False
|
||||
self,
|
||||
*,
|
||||
dim,
|
||||
depth,
|
||||
dim_head=64,
|
||||
heads=8,
|
||||
attn_dropout=0.0,
|
||||
ff_dropout=0.0,
|
||||
ff_mult=4,
|
||||
norm_output=True,
|
||||
rotary_embed=None,
|
||||
flash_attn=True,
|
||||
linear_attn=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.layers = ModuleList([])
|
||||
@@ -208,18 +175,20 @@ class Transformer(Module):
|
||||
if linear_attn:
|
||||
attn = LinearAttention(dim=dim, dim_head=dim_head, heads=heads, dropout=attn_dropout, flash=flash_attn)
|
||||
else:
|
||||
attn = Attention(dim=dim, dim_head=dim_head, heads=heads, dropout=attn_dropout,
|
||||
rotary_embed=rotary_embed, flash=flash_attn)
|
||||
attn = Attention(
|
||||
dim=dim,
|
||||
dim_head=dim_head,
|
||||
heads=heads,
|
||||
dropout=attn_dropout,
|
||||
rotary_embed=rotary_embed,
|
||||
flash=flash_attn,
|
||||
)
|
||||
|
||||
self.layers.append(ModuleList([
|
||||
attn,
|
||||
FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout)
|
||||
]))
|
||||
self.layers.append(ModuleList([attn, FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout)]))
|
||||
|
||||
self.norm = RMSNorm(dim) if norm_output else nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
for attn, ff in self.layers:
|
||||
x = attn(x) + x
|
||||
x = ff(x) + x
|
||||
@@ -229,22 +198,16 @@ class Transformer(Module):
|
||||
|
||||
# bandsplit module
|
||||
|
||||
|
||||
class BandSplit(Module):
|
||||
# @beartype
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
dim_inputs: Tuple[int, ...]
|
||||
):
|
||||
def __init__(self, dim, dim_inputs: Tuple[int, ...]):
|
||||
super().__init__()
|
||||
self.dim_inputs = dim_inputs
|
||||
self.to_features = ModuleList([])
|
||||
|
||||
for dim_in in dim_inputs:
|
||||
net = nn.Sequential(
|
||||
RMSNorm(dim_in),
|
||||
nn.Linear(dim_in, dim)
|
||||
)
|
||||
net = nn.Sequential(RMSNorm(dim_in), nn.Linear(dim_in, dim))
|
||||
|
||||
self.to_features.append(net)
|
||||
|
||||
@@ -259,13 +222,7 @@ class BandSplit(Module):
|
||||
return torch.stack(outs, dim=-2)
|
||||
|
||||
|
||||
def MLP(
|
||||
dim_in,
|
||||
dim_out,
|
||||
dim_hidden=None,
|
||||
depth=1,
|
||||
activation=nn.Tanh
|
||||
):
|
||||
def MLP(dim_in, dim_out, dim_hidden=None, depth=1, activation=nn.Tanh):
|
||||
dim_hidden = default(dim_hidden, dim_in)
|
||||
|
||||
net = []
|
||||
@@ -286,13 +243,7 @@ def MLP(
|
||||
|
||||
class MaskEstimator(Module):
|
||||
# @beartype
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
dim_inputs: Tuple[int, ...],
|
||||
depth,
|
||||
mlp_expansion_factor=4
|
||||
):
|
||||
def __init__(self, dim, dim_inputs: Tuple[int, ...], depth, mlp_expansion_factor=4):
|
||||
super().__init__()
|
||||
self.dim_inputs = dim_inputs
|
||||
self.to_freqs = ModuleList([])
|
||||
@@ -301,10 +252,7 @@ class MaskEstimator(Module):
|
||||
for dim_in in dim_inputs:
|
||||
net = []
|
||||
|
||||
mlp = nn.Sequential(
|
||||
MLP(dim, dim_in * 2, dim_hidden=dim_hidden, depth=depth),
|
||||
nn.GLU(dim=-1)
|
||||
)
|
||||
mlp = nn.Sequential(MLP(dim, dim_in * 2, dim_hidden=dim_hidden, depth=depth), nn.GLU(dim=-1))
|
||||
|
||||
self.to_freqs.append(mlp)
|
||||
|
||||
@@ -322,43 +270,43 @@ class MaskEstimator(Module):
|
||||
|
||||
# main class
|
||||
|
||||
class MelBandRoformer(Module):
|
||||
|
||||
class MelBandRoformer(Module):
|
||||
# @beartype
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
*,
|
||||
depth,
|
||||
stereo=False,
|
||||
num_stems=1,
|
||||
time_transformer_depth=2,
|
||||
freq_transformer_depth=2,
|
||||
linear_transformer_depth=0,
|
||||
num_bands=60,
|
||||
dim_head=64,
|
||||
heads=8,
|
||||
attn_dropout=0.1,
|
||||
ff_dropout=0.1,
|
||||
flash_attn=True,
|
||||
dim_freqs_in=1025,
|
||||
sample_rate=44100, # needed for mel filter bank from librosa
|
||||
stft_n_fft=2048,
|
||||
stft_hop_length=512,
|
||||
# 10ms at 44100Hz, from sections 4.1, 4.4 in the paper - @faroit recommends // 2 or // 4 for better reconstruction
|
||||
stft_win_length=2048,
|
||||
stft_normalized=False,
|
||||
stft_window_fn: Optional[Callable] = None,
|
||||
mask_estimator_depth=1,
|
||||
multi_stft_resolution_loss_weight=1.,
|
||||
multi_stft_resolutions_window_sizes: Tuple[int, ...] = (4096, 2048, 1024, 512, 256),
|
||||
multi_stft_hop_size=147,
|
||||
multi_stft_normalized=False,
|
||||
multi_stft_window_fn: Callable = torch.hann_window,
|
||||
match_input_audio_length=False, # if True, pad output tensor to match length of input tensor
|
||||
mlp_expansion_factor=4,
|
||||
use_torch_checkpoint=False,
|
||||
skip_connection=False,
|
||||
self,
|
||||
dim,
|
||||
*,
|
||||
depth,
|
||||
stereo=False,
|
||||
num_stems=1,
|
||||
time_transformer_depth=2,
|
||||
freq_transformer_depth=2,
|
||||
linear_transformer_depth=0,
|
||||
num_bands=60,
|
||||
dim_head=64,
|
||||
heads=8,
|
||||
attn_dropout=0.1,
|
||||
ff_dropout=0.1,
|
||||
flash_attn=True,
|
||||
dim_freqs_in=1025,
|
||||
sample_rate=44100, # needed for mel filter bank from librosa
|
||||
stft_n_fft=2048,
|
||||
stft_hop_length=512,
|
||||
# 10ms at 44100Hz, from sections 4.1, 4.4 in the paper - @faroit recommends // 2 or // 4 for better reconstruction
|
||||
stft_win_length=2048,
|
||||
stft_normalized=False,
|
||||
stft_window_fn: Optional[Callable] = None,
|
||||
mask_estimator_depth=1,
|
||||
multi_stft_resolution_loss_weight=1.0,
|
||||
multi_stft_resolutions_window_sizes: Tuple[int, ...] = (4096, 2048, 1024, 512, 256),
|
||||
multi_stft_hop_size=147,
|
||||
multi_stft_normalized=False,
|
||||
multi_stft_window_fn: Callable = torch.hann_window,
|
||||
match_input_audio_length=False, # if True, pad output tensor to match length of input tensor
|
||||
mlp_expansion_factor=4,
|
||||
use_torch_checkpoint=False,
|
||||
skip_connection=False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -376,7 +324,7 @@ class MelBandRoformer(Module):
|
||||
dim_head=dim_head,
|
||||
attn_dropout=attn_dropout,
|
||||
ff_dropout=ff_dropout,
|
||||
flash_attn=flash_attn
|
||||
flash_attn=flash_attn,
|
||||
)
|
||||
|
||||
time_rotary_embed = RotaryEmbedding(dim=dim_head)
|
||||
@@ -397,13 +345,12 @@ class MelBandRoformer(Module):
|
||||
self.stft_window_fn = partial(default(stft_window_fn, torch.hann_window), stft_win_length)
|
||||
|
||||
self.stft_kwargs = dict(
|
||||
n_fft=stft_n_fft,
|
||||
hop_length=stft_hop_length,
|
||||
win_length=stft_win_length,
|
||||
normalized=stft_normalized
|
||||
n_fft=stft_n_fft, hop_length=stft_hop_length, win_length=stft_win_length, normalized=stft_normalized
|
||||
)
|
||||
|
||||
freqs = torch.stft(torch.randn(1, 4096), **self.stft_kwargs, window=torch.ones(stft_n_fft), return_complex=True).shape[1]
|
||||
freqs = torch.stft(
|
||||
torch.randn(1, 4096), **self.stft_kwargs, window=torch.ones(stft_n_fft), return_complex=True
|
||||
).shape[1]
|
||||
|
||||
# create mel filter bank
|
||||
# with librosa.filters.mel as in section 2 of paper
|
||||
@@ -414,43 +361,40 @@ class MelBandRoformer(Module):
|
||||
|
||||
# for some reason, it doesn't include the first freq? just force a value for now
|
||||
|
||||
mel_filter_bank[0][0] = 1.
|
||||
mel_filter_bank[0][0] = 1.0
|
||||
|
||||
# In some systems/envs we get 0.0 instead of ~1.9e-18 in the last position,
|
||||
# so let's force a positive value
|
||||
|
||||
mel_filter_bank[-1, -1] = 1.
|
||||
mel_filter_bank[-1, -1] = 1.0
|
||||
|
||||
# binary as in paper (then estimated masks are averaged for overlapping regions)
|
||||
|
||||
freqs_per_band = mel_filter_bank > 0
|
||||
assert freqs_per_band.any(dim=0).all(), 'all frequencies need to be covered by all bands for now'
|
||||
assert freqs_per_band.any(dim=0).all(), "all frequencies need to be covered by all bands for now"
|
||||
|
||||
repeated_freq_indices = repeat(torch.arange(freqs), 'f -> b f', b=num_bands)
|
||||
repeated_freq_indices = repeat(torch.arange(freqs), "f -> b f", b=num_bands)
|
||||
freq_indices = repeated_freq_indices[freqs_per_band]
|
||||
|
||||
if stereo:
|
||||
freq_indices = repeat(freq_indices, 'f -> f s', s=2)
|
||||
freq_indices = repeat(freq_indices, "f -> f s", s=2)
|
||||
freq_indices = freq_indices * 2 + torch.arange(2)
|
||||
freq_indices = rearrange(freq_indices, 'f s -> (f s)')
|
||||
freq_indices = rearrange(freq_indices, "f s -> (f s)")
        self.register_buffer('freq_indices', freq_indices, persistent=False)
        self.register_buffer('freqs_per_band', freqs_per_band, persistent=False)
        self.register_buffer("freq_indices", freq_indices, persistent=False)
        self.register_buffer("freqs_per_band", freqs_per_band, persistent=False)

        num_freqs_per_band = reduce(freqs_per_band, 'b f -> b', 'sum')
        num_bands_per_freq = reduce(freqs_per_band, 'b f -> f', 'sum')
        num_freqs_per_band = reduce(freqs_per_band, "b f -> b", "sum")
        num_bands_per_freq = reduce(freqs_per_band, "b f -> f", "sum")

        self.register_buffer('num_freqs_per_band', num_freqs_per_band, persistent=False)
        self.register_buffer('num_bands_per_freq', num_bands_per_freq, persistent=False)
        self.register_buffer("num_freqs_per_band", num_freqs_per_band, persistent=False)
        self.register_buffer("num_bands_per_freq", num_bands_per_freq, persistent=False)

        # band split and mask estimator

        freqs_per_bands_with_complex = tuple(2 * f * self.audio_channels for f in num_freqs_per_band.tolist())

        self.band_split = BandSplit(
            dim=dim,
            dim_inputs=freqs_per_bands_with_complex
        )
        self.band_split = BandSplit(dim=dim, dim_inputs=freqs_per_bands_with_complex)

        self.mask_estimators = nn.ModuleList([])

@@ -471,19 +415,11 @@ class MelBandRoformer(Module):

        self.multi_stft_n_fft = stft_n_fft
        self.multi_stft_window_fn = multi_stft_window_fn

        self.multi_stft_kwargs = dict(
            hop_length=multi_stft_hop_size,
            normalized=multi_stft_normalized
        )
        self.multi_stft_kwargs = dict(hop_length=multi_stft_hop_size, normalized=multi_stft_normalized)

        self.match_input_audio_length = match_input_audio_length

    def forward(
        self,
        raw_audio,
        target=None,
        return_loss_breakdown=False
    ):
    def forward(self, raw_audio, target=None, return_loss_breakdown=False):
        """
        einops

@@ -499,28 +435,29 @@ class MelBandRoformer(Module):

        device = raw_audio.device

        if raw_audio.ndim == 2:
            raw_audio = rearrange(raw_audio, 'b t -> b 1 t')
            raw_audio = rearrange(raw_audio, "b t -> b 1 t")

        batch, channels, raw_audio_length = raw_audio.shape

        istft_length = raw_audio_length if self.match_input_audio_length else None

        assert (not self.stereo and channels == 1) or (
            self.stereo and channels == 2), 'stereo needs to be set to True if passing in audio signal that is stereo (channel dimension of 2). also need to be False if mono (channel dimension of 1)'
        assert (not self.stereo and channels == 1) or (self.stereo and channels == 2), (
            "stereo needs to be set to True if passing in audio signal that is stereo (channel dimension of 2). also need to be False if mono (channel dimension of 1)"
        )

        # to stft

        raw_audio, batch_audio_channel_packed_shape = pack_one(raw_audio, '* t')
        raw_audio, batch_audio_channel_packed_shape = pack_one(raw_audio, "* t")

        stft_window = self.stft_window_fn(device=device)

        stft_repr = torch.stft(raw_audio, **self.stft_kwargs, window=stft_window, return_complex=True)
        stft_repr = torch.view_as_real(stft_repr)

        stft_repr = unpack_one(stft_repr, batch_audio_channel_packed_shape, '* f t c')
        stft_repr = unpack_one(stft_repr, batch_audio_channel_packed_shape, "* f t c")

        # merge stereo / mono into the frequency, with frequency leading dimension, for band splitting
        stft_repr = rearrange(stft_repr,'b s f t c -> b (f s) t c')
        stft_repr = rearrange(stft_repr, "b s f t c -> b (f s) t c")

        # index out all frequencies for all frequency ranges across bands ascending in one go

@@ -532,7 +469,7 @@ class MelBandRoformer(Module):

        # fold the complex (real and imag) into the frequencies dimension

        x = rearrange(x, 'b f t c -> b t (f c)')
        x = rearrange(x, "b f t c -> b t (f c)")

        if self.use_torch_checkpoint:
            x = checkpoint(self.band_split, x, use_reentrant=False)

@@ -543,16 +480,15 @@ class MelBandRoformer(Module):

        store = [None] * len(self.layers)
        for i, transformer_block in enumerate(self.layers):

            if len(transformer_block) == 3:
                linear_transformer, time_transformer, freq_transformer = transformer_block

                x, ft_ps = pack([x], 'b * d')
                x, ft_ps = pack([x], "b * d")
                if self.use_torch_checkpoint:
                    x = checkpoint(linear_transformer, x, use_reentrant=False)
                else:
                    x = linear_transformer(x)
                x, = unpack(x, ft_ps, 'b * d')
                (x,) = unpack(x, ft_ps, "b * d")
            else:
                time_transformer, freq_transformer = transformer_block

@@ -561,24 +497,24 @@ class MelBandRoformer(Module):

                for j in range(i):
                    x = x + store[j]

            x = rearrange(x, 'b t f d -> b f t d')
            x, ps = pack([x], '* t d')
            x = rearrange(x, "b t f d -> b f t d")
            x, ps = pack([x], "* t d")

            if self.use_torch_checkpoint:
                x = checkpoint(time_transformer, x, use_reentrant=False)
            else:
                x = time_transformer(x)

            x, = unpack(x, ps, '* t d')
            x = rearrange(x, 'b f t d -> b t f d')
            x, ps = pack([x], '* f d')
            (x,) = unpack(x, ps, "* t d")
            x = rearrange(x, "b f t d -> b t f d")
            x, ps = pack([x], "* f d")

            if self.use_torch_checkpoint:
                x = checkpoint(freq_transformer, x, use_reentrant=False)
            else:
                x = freq_transformer(x)

            x, = unpack(x, ps, '* f d')
            (x,) = unpack(x, ps, "* f d")

            if self.skip_connection:
                store[i] = x

@@ -588,11 +524,11 @@ class MelBandRoformer(Module):

            masks = torch.stack([checkpoint(fn, x, use_reentrant=False) for fn in self.mask_estimators], dim=1)
        else:
            masks = torch.stack([fn(x) for fn in self.mask_estimators], dim=1)
        masks = rearrange(masks, 'b n t (f c) -> b n f t c', c=2)
        masks = rearrange(masks, "b n t (f c) -> b n f t c", c=2)

        # modulate frequency representation

        stft_repr = rearrange(stft_repr, 'b f t c -> b 1 f t c')
        stft_repr = rearrange(stft_repr, "b f t c -> b 1 f t c")

        # complex number multiplication

@@ -603,12 +539,12 @@ class MelBandRoformer(Module):

        # need to average the estimated mask for the overlapped frequencies

        scatter_indices = repeat(self.freq_indices, 'f -> b n f t', b=batch, n=num_stems, t=stft_repr.shape[-1])
        scatter_indices = repeat(self.freq_indices, "f -> b n f t", b=batch, n=num_stems, t=stft_repr.shape[-1])

        stft_repr_expanded_stems = repeat(stft_repr, 'b 1 ... -> b n ...', n=num_stems)
        stft_repr_expanded_stems = repeat(stft_repr, "b 1 ... -> b n ...", n=num_stems)
        masks_summed = torch.zeros_like(stft_repr_expanded_stems).scatter_add_(2, scatter_indices, masks)

        denom = repeat(self.num_bands_per_freq, 'f -> (f r) 1', r=channels)
        denom = repeat(self.num_bands_per_freq, "f -> (f r) 1", r=channels)

        masks_averaged = masks_summed / denom.clamp(min=1e-8)
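A toy sketch of this scatter-add averaging with assumed shapes (one batch, one stem, five band entries over four output bins); where two bands cover the same bin, the result is their mean:

import torch

freq_indices = torch.tensor([0, 1, 2, 2, 3])  # output bin for each band entry
masks = torch.tensor([10.0, 20.0, 30.0, 50.0, 40.0]).view(1, 1, 5, 1).repeat(1, 1, 1, 4)
index = freq_indices.view(1, 1, 5, 1).repeat(1, 1, 1, 4)
summed = torch.zeros(1, 1, 4, 4).scatter_add_(2, index, masks)
counts = torch.zeros(1, 1, 4, 4).scatter_add_(2, index, torch.ones_like(masks))
averaged = summed / counts.clamp(min=1e-8)  # bin 2 -> (30 + 50) / 2 = 40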
@@ -618,15 +554,16 @@ class MelBandRoformer(Module):

        # istft

        stft_repr = rearrange(stft_repr, 'b n (f s) t -> (b n s) f t', s=self.audio_channels)
        stft_repr = rearrange(stft_repr, "b n (f s) t -> (b n s) f t", s=self.audio_channels)

        recon_audio = torch.istft(stft_repr, **self.stft_kwargs, window=stft_window, return_complex=False,
                                  length=istft_length)
        recon_audio = torch.istft(
            stft_repr, **self.stft_kwargs, window=stft_window, return_complex=False, length=istft_length
        )

        recon_audio = rearrange(recon_audio, '(b n s) t -> b n s t', b=batch, s=self.audio_channels, n=num_stems)
        recon_audio = rearrange(recon_audio, "(b n s) t -> b n s t", b=batch, s=self.audio_channels, n=num_stems)

        if num_stems == 1:
            recon_audio = rearrange(recon_audio, 'b 1 s t -> b s t')
            recon_audio = rearrange(recon_audio, "b 1 s t -> b s t")

        # if a target is passed in, calculate loss for learning

@@ -637,13 +574,13 @@ class MelBandRoformer(Module):

        assert target.ndim == 4 and target.shape[1] == self.num_stems

        if target.ndim == 2:
            target = rearrange(target, '... t -> ... 1 t')
            target = rearrange(target, "... t -> ... 1 t")

        target = target[..., :recon_audio.shape[-1]]  # protect against lost length on istft
        target = target[..., : recon_audio.shape[-1]]  # protect against lost length on istft

        loss = F.l1_loss(recon_audio, target)

        multi_stft_resolution_loss = 0.
        multi_stft_resolution_loss = 0.0

        for window_size in self.multi_stft_resolutions_window_sizes:
            res_stft_kwargs = dict(

@@ -654,8 +591,8 @@ class MelBandRoformer(Module):

                **self.multi_stft_kwargs,
            )

            recon_Y = torch.stft(rearrange(recon_audio, '... s t -> (... s) t'), **res_stft_kwargs)
            target_Y = torch.stft(rearrange(target, '... s t -> (... s) t'), **res_stft_kwargs)
            recon_Y = torch.stft(rearrange(recon_audio, "... s t -> (... s) t"), **res_stft_kwargs)
            target_Y = torch.stft(rearrange(target, "... s t -> (... s) t"), **res_stft_kwargs)

            multi_stft_resolution_loss = multi_stft_resolution_loss + F.l1_loss(recon_Y, target_Y)
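Condensed sketch of the loss being reformatted here, under the assumption that each resolution uses a Hann window with n_fft equal to the window size (the real res_stft_kwargs are elided by this hunk); the window sizes and hop mirror the defaults quoted in the config further down:

import torch
import torch.nn.functional as F

def multi_res_stft_l1(recon, target, window_sizes=(4096, 2048, 1024, 512, 256), hop=147):
    # Sum of L1 distances between complex STFTs at several resolutions.
    loss = torch.tensor(0.0)
    for win in window_sizes:
        kwargs = dict(
            n_fft=win, win_length=win, hop_length=hop,
            window=torch.hann_window(win), return_complex=True,
        )
        loss = loss + F.l1_loss(torch.stft(recon, **kwargs), torch.stft(target, **kwargs))
    return loss

# usage: multi_res_stft_l1(torch.randn(2, 44100), torch.randn(2, 44100))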
@@ -1,28 +1,31 @@
# This code is modified from https://github.com/ZFTurbo/
import librosa
from tqdm import tqdm
import os
import torch
import warnings

import librosa
import numpy as np
import soundfile as sf
import torch
import torch.nn as nn
import yaml
import warnings
from tqdm import tqdm

warnings.filterwarnings("ignore")


class Roformer_Loader:
    def get_config(self, config_path):
        with open(config_path, 'r', encoding='utf-8') as f:
        with open(config_path, "r", encoding="utf-8") as f:
            # use fullloader to load tag !!python/tuple, code can be improved
            config = yaml.load(f, Loader=yaml.FullLoader)
        return config

    def get_default_config(self):
        default_config = None
        if self.model_type == 'bs_roformer':
        if self.model_type == "bs_roformer":
            # Use model_bs_roformer_ep_368_sdr_12.9628.yaml and model_bs_roformer_ep_317_sdr_12.9755.yaml as default configuration files
            # Other BS_Roformer models may not be compatible
            # fmt: off
            default_config = {
                "audio": {"chunk_size": 352800, "sample_rate": 44100},
                "model": {

@@ -51,9 +54,10 @@ class Roformer_Loader:

                    "multi_stft_normalized": False,
                },
                "training": {"instruments": ["vocals", "other"], "target_instrument": "vocals"},
                "inference": {"batch_size": 2, "num_overlap": 2}
                "inference": {"batch_size": 2, "num_overlap": 2},
            }
        elif self.model_type == 'mel_band_roformer':
            # fmt: on
        elif self.model_type == "mel_band_roformer":
            # Use model_mel_band_roformer_ep_3005_sdr_11.4360.yaml as default configuration files
            # Other Mel_Band_Roformer models may not be compatible
            default_config = {

@@ -82,29 +86,30 @@ class Roformer_Loader:

                    "multi_stft_resolution_loss_weight": 1.0,
                    "multi_stft_resolutions_window_sizes": (4096, 2048, 1024, 512, 256),
                    "multi_stft_hop_size": 147,
                    "multi_stft_normalized": False
                    "multi_stft_normalized": False,
                },
                "training": {"instruments": ["vocals", "other"], "target_instrument": "vocals"},
                "inference": {"batch_size": 2, "num_overlap": 2}
                "inference": {"batch_size": 2, "num_overlap": 2},
            }

        return default_config

    def get_model_from_config(self):
        if self.model_type == 'bs_roformer':
        if self.model_type == "bs_roformer":
            from bs_roformer.bs_roformer import BSRoformer

            model = BSRoformer(**dict(self.config["model"]))
        elif self.model_type == 'mel_band_roformer':
        elif self.model_type == "mel_band_roformer":
            from bs_roformer.mel_band_roformer import MelBandRoformer

            model = MelBandRoformer(**dict(self.config["model"]))
        else:
            print('Error: Unknown model: {}'.format(self.model_type))
            print("Error: Unknown model: {}".format(self.model_type))
            model = None
        return model

    def demix_track(self, model, mix, device):
        C = self.config["audio"]["chunk_size"] # chunk_size
        C = self.config["audio"]["chunk_size"]  # chunk_size
        N = self.config["inference"]["num_overlap"]
        fade_size = C // 10
        step = int(C // N)

@@ -116,7 +121,7 @@ class Roformer_Loader:

        # Do pad from the beginning and end to account floating window results better
        if length_init > 2 * border and (border > 0):
            mix = nn.functional.pad(mix, (border, border), mode='reflect')
            mix = nn.functional.pad(mix, (border, border), mode="reflect")

        # Prepare windows arrays (do 1 time for speed up). This trick repairs click problems on the edges of segment
        window_size = C

@@ -125,17 +130,17 @@ class Roformer_Loader:

        window_start = torch.ones(window_size)
        window_middle = torch.ones(window_size)
        window_finish = torch.ones(window_size)
        window_start[-fade_size:] *= fadeout # First audio chunk, no fadein
        window_finish[:fade_size] *= fadein # Last audio chunk, no fadeout
        window_start[-fade_size:] *= fadeout  # First audio chunk, no fadein
        window_finish[:fade_size] *= fadein  # Last audio chunk, no fadeout
        window_middle[-fade_size:] *= fadeout
        window_middle[:fade_size] *= fadein

        with torch.amp.autocast('cuda'):
        with torch.amp.autocast("cuda"):
            with torch.inference_mode():
                if self.config["training"]["target_instrument"] is None:
                    req_shape = (len(self.config["training"]["instruments"]),) + tuple(mix.shape)
                else:
                    req_shape = (1, ) + tuple(mix.shape)
                    req_shape = (1,) + tuple(mix.shape)

                result = torch.zeros(req_shape, dtype=torch.float32)
                counter = torch.zeros(req_shape, dtype=torch.float32)

@@ -143,15 +148,15 @@ class Roformer_Loader:

                batch_data = []
                batch_locations = []
                while i < mix.shape[1]:
                    part = mix[:, i:i + C].to(device)
                    part = mix[:, i : i + C].to(device)
                    length = part.shape[-1]
                    if length < C:
                        if length > C // 2 + 1:
                            part = nn.functional.pad(input=part, pad=(0, C - length), mode='reflect')
                            part = nn.functional.pad(input=part, pad=(0, C - length), mode="reflect")
                        else:
                            part = nn.functional.pad(input=part, pad=(0, C - length, 0, 0), mode='constant', value=0)
                            part = nn.functional.pad(input=part, pad=(0, C - length, 0, 0), mode="constant", value=0)
                    if self.is_half:
                        part=part.half()
                        part = part.half()
                    batch_data.append(part)
                    batch_locations.append((i, length))
                    i += step

@@ -170,8 +175,8 @@ class Roformer_Loader:

                        for j in range(len(batch_locations)):
                            start, l = batch_locations[j]
                            result[..., start:start+l] += x[j][..., :l].cpu() * window[..., :l]
                            counter[..., start:start+l] += window[..., :l]
                            result[..., start : start + l] += x[j][..., :l].cpu() * window[..., :l]
                            counter[..., start : start + l] += window[..., :l]
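A self-contained toy version of this windowed overlap-add (assumed chunk size 8, step 4, fade length 2): each chunk is weighted by a fade window, accumulated, and normalized by the summed window weight, which is the trick that removes clicks at segment edges. As in the loader, the first chunk skips the fade-in and the last skips the fade-out:

import torch

C, step, fade = 8, 4, 2
fadein = torch.linspace(0, 1, fade)
window_middle = torch.ones(C)
window_middle[:fade] *= fadein
window_middle[-fade:] *= fadein.flip(0)
window_start = torch.ones(C)
window_start[-fade:] *= fadein.flip(0)  # first chunk: no fade-in
window_finish = torch.ones(C)
window_finish[:fade] *= fadein          # last chunk: no fade-out

signal = torch.randn(20)                # stands in for per-chunk model output
out, weight = torch.zeros(20), torch.zeros(20)
starts = list(range(0, 20 - C + 1, step))
for s in starts:
    window = window_start if s == starts[0] else window_finish if s == starts[-1] else window_middle
    out[s : s + C] += signal[s : s + C] * window
    weight[s : s + C] += window
out = out / weight.clamp(min=1e-8)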
                        batch_data = []
                        batch_locations = []

@@ -191,7 +196,6 @@ class Roformer_Loader:

        else:
            return {k: v for k, v in zip([self.config["training"]["target_instrument"]], estimated_sources)}

    def run_folder(self, input, vocal_root, others_root, format):
        self.model.eval()
        path = input

@@ -200,20 +204,20 @@ class Roformer_Loader:

        file_base_name = os.path.splitext(os.path.basename(path))[0]

        sample_rate = 44100
        if 'sample_rate' in self.config["audio"]:
            sample_rate = self.config["audio"]['sample_rate']
        if "sample_rate" in self.config["audio"]:
            sample_rate = self.config["audio"]["sample_rate"]

        try:
            mix, sr = librosa.load(path, sr=sample_rate, mono=False)
        except Exception as e:
            print('Can read track: {}'.format(path))
            print('Error message: {}'.format(str(e)))
            print("Can read track: {}".format(path))
            print("Error message: {}".format(str(e)))
            return

        # in case if model only supports mono tracks
        isstereo = self.config["model"].get("stereo", True)
        if not isstereo and len(mix.shape) != 1:
            mix = np.mean(mix, axis=0) # if more than 2 channels, take mean
            mix = np.mean(mix, axis=0)  # if more than 2 channels, take mean
            print("Warning: Track has more than 1 channels, but model is mono, taking mean of all channels.")

        mix_orig = mix.copy()

@@ -226,7 +230,7 @@ class Roformer_Loader:

            # other instruments are caculated by subtracting target instrument from mixture
            target_instrument = self.config["training"]["target_instrument"]
            other_instruments = [i for i in self.config["training"]["instruments"] if i != target_instrument]
            other = mix_orig - res[target_instrument] # caculate other instruments
            other = mix_orig - res[target_instrument]  # caculate other instruments

            path_vocal = "{}/{}_{}.wav".format(vocal_root, file_base_name, target_instrument)
            path_other = "{}/{}_{}.wav".format(others_root, file_base_name, other_instruments[0])

@@ -237,11 +241,10 @@ class Roformer_Loader:

            vocal_inst = self.config["training"]["instruments"][0]
            path_vocal = "{}/{}_{}.wav".format(vocal_root, file_base_name, vocal_inst)
            self.save_audio(path_vocal, res[vocal_inst].T, sr, format)
            for other in self.config["training"]["instruments"][1:]: # save other instruments
            for other in self.config["training"]["instruments"][1:]:  # save other instruments
                path_other = "{}/{}_{}.wav".format(others_root, file_base_name, other)
                self.save_audio(path_other, res[other].T, sr, format)

    def save_audio(self, path, data, sr, format):
        # input path should be endwith '.wav'
        if format in ["wav", "flac"]:

@@ -250,10 +253,11 @@ class Roformer_Loader:

            sf.write(path, data, sr)
        else:
            sf.write(path, data, sr)
            os.system("ffmpeg -i \"{}\" -vn \"{}\" -q:a 2 -y".format(path, path[:-3] + format))
            try: os.remove(path)
            except: pass

            os.system('ffmpeg -i "{}" -vn "{}" -q:a 2 -y'.format(path, path[:-3] + format))
            try:
                os.remove(path)
            except:
                pass
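Hedged alternative, not in the commit: calling ffmpeg through subprocess with an argument list sidesteps the shell-quoting problems the os.system string above can hit on paths containing quotes or spaces:

import subprocess

def convert_with_ffmpeg(wav_path: str, out_path: str) -> None:
    # Arguments are passed to ffmpeg verbatim, so no shell quoting is involved.
    subprocess.run(["ffmpeg", "-i", wav_path, "-vn", "-q:a", "2", "-y", out_path], check=True)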
    def __init__(self, model_path, config_path, device, is_half):
        self.device = device

@@ -270,7 +274,9 @@ class Roformer_Loader:

        if not os.path.exists(config_path):
            if self.model_type is None:
                # if model_type is still None, raise an error
                raise ValueError("Error: Unknown model type. If you are using a model without a configuration file, Ensure that your model name includes 'bs_roformer', 'bsroformer', 'mel_band_roformer', or 'melbandroformer'. Otherwise, you can manually place the model configuration file into 'tools/uvr5/uvr5w_weights' and ensure that the configuration file is named as '<model_name>.yaml' then try it again.")
                raise ValueError(
                    "Error: Unknown model type. If you are using a model without a configuration file, Ensure that your model name includes 'bs_roformer', 'bsroformer', 'mel_band_roformer', or 'melbandroformer'. Otherwise, you can manually place the model configuration file into 'tools/uvr5/uvr5w_weights' and ensure that the configuration file is named as '<model_name>.yaml' then try it again."
                )
            self.config = self.get_default_config()
        else:
            # if there is a configuration file

@@ -289,12 +295,10 @@ class Roformer_Loader:

        state_dict = torch.load(model_path, map_location="cpu")
        model.load_state_dict(state_dict)

        if(is_half==False):
        if is_half == False:
            self.model = model.to(device)
        else:
            self.model = model.half().to(device)

    def _path_audio_(self, input, others_root, vocal_root, format, is_hp3=False):
        self.run_folder(input, vocal_root, others_root, format)


@@ -13,9 +13,7 @@ cpu = torch.device("cpu")


class ConvTDFNetTrim:
    def __init__(
        self, device, model_name, target_name, L, dim_f, dim_t, n_fft, hop=1024
    ):
    def __init__(self, device, model_name, target_name, L, dim_f, dim_t, n_fft, hop=1024):
        super(ConvTDFNetTrim, self).__init__()

        self.dim_f = dim_f

@@ -24,17 +22,13 @@ class ConvTDFNetTrim:

        self.hop = hop
        self.n_bins = self.n_fft // 2 + 1
        self.chunk_size = hop * (self.dim_t - 1)
        self.window = torch.hann_window(window_length=self.n_fft, periodic=True).to(
            device
        )
        self.window = torch.hann_window(window_length=self.n_fft, periodic=True).to(device)
        self.target_name = target_name
        self.blender = "blender" in model_name

        self.dim_c = 4
        out_c = self.dim_c * 4 if target_name == "*" else self.dim_c
        self.freq_pad = torch.zeros(
            [1, out_c, self.n_bins - self.dim_f, self.dim_t]
        ).to(device)
        self.freq_pad = torch.zeros([1, out_c, self.n_bins - self.dim_f, self.dim_t]).to(device)

        self.n = L // 2

@@ -50,28 +44,18 @@ class ConvTDFNetTrim:

        )
        x = torch.view_as_real(x)
        x = x.permute([0, 3, 1, 2])
        x = x.reshape([-1, 2, 2, self.n_bins, self.dim_t]).reshape(
            [-1, self.dim_c, self.n_bins, self.dim_t]
        )
        x = x.reshape([-1, 2, 2, self.n_bins, self.dim_t]).reshape([-1, self.dim_c, self.n_bins, self.dim_t])
        return x[:, :, : self.dim_f]

    def istft(self, x, freq_pad=None):
        freq_pad = (
            self.freq_pad.repeat([x.shape[0], 1, 1, 1])
            if freq_pad is None
            else freq_pad
        )
        freq_pad = self.freq_pad.repeat([x.shape[0], 1, 1, 1]) if freq_pad is None else freq_pad
        x = torch.cat([x, freq_pad], -2)
        c = 4 * 2 if self.target_name == "*" else 2
        x = x.reshape([-1, c, 2, self.n_bins, self.dim_t]).reshape(
            [-1, 2, self.n_bins, self.dim_t]
        )
        x = x.reshape([-1, c, 2, self.n_bins, self.dim_t]).reshape([-1, 2, self.n_bins, self.dim_t])
        x = x.permute([0, 2, 3, 1])
        x = x.contiguous()
        x = torch.view_as_complex(x)
        x = torch.istft(
            x, n_fft=self.n_fft, hop_length=self.hop, window=self.window, center=True
        )
        x = torch.istft(x, n_fft=self.n_fft, hop_length=self.hop, window=self.window, center=True)
        return x.reshape([-1, c, self.chunk_size])
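Round-trip sketch of the real/imag packing used by the stft/istft pair above (shapes assumed, not the model's): view_as_real splits a complex spectrogram into a trailing (real, imag) axis and view_as_complex restores it:

import torch

spec = torch.randn(1, 1025, 44, dtype=torch.complex64)
pair = torch.view_as_real(spec)                      # (1, 1025, 44, 2)
restored = torch.view_as_complex(pair.contiguous())  # back to complex64
assert torch.equal(spec, restored)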
@@ -93,9 +77,7 @@ class Predictor:

        logger.info(ort.get_available_providers())
        self.args = args
        self.model_ = get_models(
            device=cpu, dim_f=args.dim_f, dim_t=args.dim_t, n_fft=args.n_fft
        )
        self.model_ = get_models(device=cpu, dim_f=args.dim_f, dim_t=args.dim_t, n_fft=args.n_fft)
        self.model = ort.InferenceSession(
            os.path.join(args.onnx, self.model_.target_name + ".onnx"),
            providers=[

@@ -152,9 +134,7 @@ class Predictor:

        trim = model.n_fft // 2
        gen_size = model.chunk_size - 2 * trim
        pad = gen_size - n_sample % gen_size
        mix_p = np.concatenate(
            (np.zeros((2, trim)), cmix, np.zeros((2, pad)), np.zeros((2, trim))), 1
        )
        mix_p = np.concatenate((np.zeros((2, trim)), cmix, np.zeros((2, pad)), np.zeros((2, trim))), 1)
        mix_waves = []
        i = 0
        while i < n_sample + pad:

@@ -172,15 +152,8 @@ class Predictor:

                )
                tar_waves = model.istft(torch.tensor(spec_pred))
            else:
                tar_waves = model.istft(
                    torch.tensor(_ort.run(None, {"input": spek.cpu().numpy()})[0])
                )
            tar_signal = (
                tar_waves[:, :, trim:-trim]
                .transpose(0, 1)
                .reshape(2, -1)
                .numpy()[:, :-pad]
            )
                tar_waves = model.istft(torch.tensor(_ort.run(None, {"input": spek.cpu().numpy()})[0]))
            tar_signal = tar_waves[:, :, trim:-trim].transpose(0, 1).reshape(2, -1).numpy()[:, :-pad]

            start = 0 if mix == 0 else margin_size
            end = None if mix == list(mixes.keys())[::-1][0] else -margin_size

@@ -207,9 +180,7 @@ class Predictor:

        sources = self.demix(mix.T)
        opt = sources[0].T
        if format in ["wav", "flac"]:
            sf.write(
                "%s/%s_main_vocal.%s" % (vocal_root, basename, format), mix - opt, rate
            )
            sf.write("%s/%s_main_vocal.%s" % (vocal_root, basename, format), mix - opt, rate)
            sf.write("%s/%s_others.%s" % (others_root, basename, format), opt, rate)
        else:
            path_vocal = "%s/%s_main_vocal.wav" % (vocal_root, basename)

@@ -219,18 +190,14 @@ class Predictor:

            opt_path_vocal = path_vocal[:-4] + ".%s" % format
            opt_path_other = path_other[:-4] + ".%s" % format
            if os.path.exists(path_vocal):
                os.system(
                    "ffmpeg -i '%s' -vn '%s' -q:a 2 -y" % (path_vocal, opt_path_vocal)
                )
                os.system("ffmpeg -i '%s' -vn '%s' -q:a 2 -y" % (path_vocal, opt_path_vocal))
                if os.path.exists(opt_path_vocal):
                    try:
                        os.remove(path_vocal)
                    except:
                        pass
            if os.path.exists(path_other):
                os.system(
                    "ffmpeg -i '%s' -vn '%s' -q:a 2 -y" % (path_other, opt_path_other)
                )
                os.system("ffmpeg -i '%s' -vn '%s' -q:a 2 -y" % (path_other, opt_path_other))
                if os.path.exists(opt_path_other):
                    try:
                        os.remove(path_other)

@@ -240,7 +207,7 @@ class Predictor:


class MDXNetDereverb:
    def __init__(self, chunks):
        self.onnx = "%s/uvr5_weights/onnx_dereverb_By_FoxJoy"%os.path.dirname(os.path.abspath(__file__))
        self.onnx = "%s/uvr5_weights/onnx_dereverb_By_FoxJoy" % os.path.dirname(os.path.abspath(__file__))
        self.shifts = 10  # 'Predict with randomised equivariant stabilisation'
        self.mixing = "min_mag"  # ['default','min_mag','max_mag']
        self.chunks = chunks
tools/uvr5/vr.py
@@ -1,6 +1,8 @@
import os,sys
import os

parent_directory = os.path.dirname(os.path.abspath(__file__))
import logging,pdb
import logging

logger = logging.getLogger(__name__)

import librosa

@@ -27,7 +29,7 @@ class AudioPre:

            "agg": agg,
            "high_end_process": "mirroring",
        }
        mp = ModelParameters("%s/lib/lib_v5/modelparams/4band_v2.json"%parent_directory)
        mp = ModelParameters("%s/lib/lib_v5/modelparams/4band_v2.json" % parent_directory)
        model = Nets.CascadedASPPNet(mp.param["bins"] * 2)
        cpk = torch.load(model_path, map_location="cpu")
        model.load_state_dict(cpk)

@@ -40,9 +42,7 @@ class AudioPre:

        self.mp = mp
        self.model = model

    def _path_audio_(
        self, music_file, ins_root=None, vocal_root=None, format="flac", is_hp3=False
    ):
    def _path_audio_(self, music_file, ins_root=None, vocal_root=None, format="flac", is_hp3=False):
        if ins_root is None and vocal_root is None:
            return "No save root."
        name = os.path.basename(music_file)

@@ -61,19 +61,19 @@ class AudioPre:

                    _,
                ) = librosa.core.load(  # 理论上librosa读取可能对某些音频有bug,应该上ffmpeg读取,但是太麻烦了弃坑
                    music_file,
                    sr = bp["sr"],
                    mono = False,
                    dtype = np.float32,
                    res_type = bp["res_type"],
                    sr=bp["sr"],
                    mono=False,
                    dtype=np.float32,
                    res_type=bp["res_type"],
                )
                if X_wave[d].ndim == 1:
                    X_wave[d] = np.asfortranarray([X_wave[d], X_wave[d]])
            else:  # lower bands
                X_wave[d] = librosa.core.resample(
                    X_wave[d + 1],
                    orig_sr = self.mp.param["band"][d + 1]["sr"],
                    target_sr = bp["sr"],
                    res_type = bp["res_type"],
                    orig_sr=self.mp.param["band"][d + 1]["sr"],
                    target_sr=bp["sr"],
                    res_type=bp["res_type"],
                )
            # Stft of wave source
            X_spec_s[d] = spec_utils.wave_to_spectrogram_mt(

@@ -89,9 +89,7 @@ class AudioPre:

                input_high_end_h = (bp["n_fft"] // 2 - bp["crop_stop"]) + (
                    self.mp.param["pre_filter_stop"] - self.mp.param["pre_filter_start"]
                )
                input_high_end = X_spec_s[d][
                    :, bp["n_fft"] // 2 - input_high_end_h : bp["n_fft"] // 2, :
                ]
                input_high_end = X_spec_s[d][:, bp["n_fft"] // 2 - input_high_end_h : bp["n_fft"] // 2, :]

        X_spec_m = spec_utils.combine_spectrograms(X_spec_s, self.mp)
        aggresive_set = float(self.data["agg"] / 100)

@@ -100,9 +98,7 @@ class AudioPre:

            "split_bin": self.mp.param["band"][1]["crop_stop"],
        }
        with torch.no_grad():
            pred, X_mag, X_phase = inference(
                X_spec_m, self.device, self.model, aggressiveness, self.data
            )
            pred, X_mag, X_phase = inference(X_spec_m, self.device, self.model, aggressiveness, self.data)
        # Postprocess
        if self.data["postprocess"]:
            pred_inv = np.clip(X_mag - pred, 0, np.inf)

@@ -111,13 +107,11 @@ class AudioPre:

        v_spec_m = X_spec_m - y_spec_m

        if is_hp3 == True:
            ins_root,vocal_root = vocal_root,ins_root
            ins_root, vocal_root = vocal_root, ins_root

        if ins_root is not None:
            if self.data["high_end_process"].startswith("mirroring"):
                input_high_end_ = spec_utils.mirroring(
                    self.data["high_end_process"], y_spec_m, input_high_end, self.mp
                )
                input_high_end_ = spec_utils.mirroring(self.data["high_end_process"], y_spec_m, input_high_end, self.mp)
                wav_instrument = spec_utils.cmb_spectrogram_to_wave(
                    y_spec_m, self.mp, input_high_end_h, input_high_end_
                )

@@ -138,9 +132,7 @@ class AudioPre:

                    self.mp.param["sr"],
                )  #
            else:
                path = os.path.join(
                    ins_root, head + "{}_{}.wav".format(name, self.data["agg"])
                )
                path = os.path.join(ins_root, head + "{}_{}.wav".format(name, self.data["agg"]))
                sf.write(
                    path,
                    (np.array(wav_instrument) * 32768).astype("int16"),

@@ -160,12 +152,8 @@ class AudioPre:

            else:
                head = "vocal_"
            if self.data["high_end_process"].startswith("mirroring"):
                input_high_end_ = spec_utils.mirroring(
                    self.data["high_end_process"], v_spec_m, input_high_end, self.mp
                )
                wav_vocals = spec_utils.cmb_spectrogram_to_wave(
                    v_spec_m, self.mp, input_high_end_h, input_high_end_
                )
                input_high_end_ = spec_utils.mirroring(self.data["high_end_process"], v_spec_m, input_high_end, self.mp)
                wav_vocals = spec_utils.cmb_spectrogram_to_wave(v_spec_m, self.mp, input_high_end_h, input_high_end_)
            else:
                wav_vocals = spec_utils.cmb_spectrogram_to_wave(v_spec_m, self.mp)
            logger.info("%s vocals done" % name)

@@ -179,9 +167,7 @@ class AudioPre:

                    self.mp.param["sr"],
                )
            else:
                path = os.path.join(
                    vocal_root, head + "{}_{}.wav".format(name, self.data["agg"])
                )
                path = os.path.join(vocal_root, head + "{}_{}.wav".format(name, self.data["agg"]))
                sf.write(
                    path,
                    (np.array(wav_vocals) * 32768).astype("int16"),

@@ -210,7 +196,7 @@ class AudioPreDeEcho:

            "agg": agg,
            "high_end_process": "mirroring",
        }
        mp = ModelParameters("%s/lib/lib_v5/modelparams/4band_v3.json"%parent_directory)
        mp = ModelParameters("%s/lib/lib_v5/modelparams/4band_v3.json" % parent_directory)
        nout = 64 if "DeReverb" in model_path else 48
        model = CascadedNet(mp.param["bins"] * 2, nout)
        cpk = torch.load(model_path, map_location="cpu")

@@ -245,19 +231,19 @@ class AudioPreDeEcho:

                    _,
                ) = librosa.core.load(  # 理论上librosa读取可能对某些音频有bug,应该上ffmpeg读取,但是太麻烦了弃坑
                    music_file,
                    sr = bp["sr"],
                    mono = False,
                    dtype = np.float32,
                    res_type = bp["res_type"],
                    sr=bp["sr"],
                    mono=False,
                    dtype=np.float32,
                    res_type=bp["res_type"],
                )
                if X_wave[d].ndim == 1:
                    X_wave[d] = np.asfortranarray([X_wave[d], X_wave[d]])
            else:  # lower bands
                X_wave[d] = librosa.core.resample(
                    X_wave[d + 1],
                    orig_sr = self.mp.param["band"][d + 1]["sr"],
                    target_sr = bp["sr"],
                    res_type = bp["res_type"],
                    orig_sr=self.mp.param["band"][d + 1]["sr"],
                    target_sr=bp["sr"],
                    res_type=bp["res_type"],
                )
            # Stft of wave source
            X_spec_s[d] = spec_utils.wave_to_spectrogram_mt(

@@ -273,9 +259,7 @@ class AudioPreDeEcho:

                input_high_end_h = (bp["n_fft"] // 2 - bp["crop_stop"]) + (
                    self.mp.param["pre_filter_stop"] - self.mp.param["pre_filter_start"]
                )
                input_high_end = X_spec_s[d][
                    :, bp["n_fft"] // 2 - input_high_end_h : bp["n_fft"] // 2, :
                ]
                input_high_end = X_spec_s[d][:, bp["n_fft"] // 2 - input_high_end_h : bp["n_fft"] // 2, :]

        X_spec_m = spec_utils.combine_spectrograms(X_spec_s, self.mp)
        aggresive_set = float(self.data["agg"] / 100)

@@ -284,9 +268,7 @@ class AudioPreDeEcho:

            "split_bin": self.mp.param["band"][1]["crop_stop"],
        }
        with torch.no_grad():
            pred, X_mag, X_phase = inference(
                X_spec_m, self.device, self.model, aggressiveness, self.data
            )
            pred, X_mag, X_phase = inference(X_spec_m, self.device, self.model, aggressiveness, self.data)
        # Postprocess
        if self.data["postprocess"]:
            pred_inv = np.clip(X_mag - pred, 0, np.inf)

@@ -296,9 +278,7 @@ class AudioPreDeEcho:

        if ins_root is not None:
            if self.data["high_end_process"].startswith("mirroring"):
                input_high_end_ = spec_utils.mirroring(
                    self.data["high_end_process"], y_spec_m, input_high_end, self.mp
                )
                input_high_end_ = spec_utils.mirroring(self.data["high_end_process"], y_spec_m, input_high_end, self.mp)
                wav_instrument = spec_utils.cmb_spectrogram_to_wave(
                    y_spec_m, self.mp, input_high_end_h, input_high_end_
                )

@@ -315,9 +295,7 @@ class AudioPreDeEcho:

                    self.mp.param["sr"],
                )  #
            else:
                path = os.path.join(
                    ins_root, "vocal_{}_{}.wav".format(name, self.data["agg"])
                )
                path = os.path.join(ins_root, "vocal_{}_{}.wav".format(name, self.data["agg"]))
                sf.write(
                    path,
                    (np.array(wav_instrument) * 32768).astype("int16"),

@@ -333,12 +311,8 @@ class AudioPreDeEcho:

                pass
        if vocal_root is not None:
            if self.data["high_end_process"].startswith("mirroring"):
                input_high_end_ = spec_utils.mirroring(
                    self.data["high_end_process"], v_spec_m, input_high_end, self.mp
                )
                wav_vocals = spec_utils.cmb_spectrogram_to_wave(
                    v_spec_m, self.mp, input_high_end_h, input_high_end_
                )
                input_high_end_ = spec_utils.mirroring(self.data["high_end_process"], v_spec_m, input_high_end, self.mp)
                wav_vocals = spec_utils.cmb_spectrogram_to_wave(v_spec_m, self.mp, input_high_end_h, input_high_end_)
            else:
                wav_vocals = spec_utils.cmb_spectrogram_to_wave(v_spec_m, self.mp)
            logger.info("%s vocals done" % name)

@@ -352,9 +326,7 @@ class AudioPreDeEcho:

                    self.mp.param["sr"],
                )
            else:
                path = os.path.join(
                    vocal_root, "instrument_{}_{}.wav".format(name, self.data["agg"])
                )
                path = os.path.join(vocal_root, "instrument_{}_{}.wav".format(name, self.data["agg"]))
                sf.write(
                    path,
                    (np.array(wav_vocals) * 32768).astype("int16"),

@@ -1,13 +1,14 @@
import os
import traceback,gradio as gr
import traceback
import gradio as gr
import logging
from tools.i18n.i18n import I18nAuto
from tools.my_utils import clean_path

i18n = I18nAuto()

logger = logging.getLogger(__name__)
import librosa,ffmpeg
import soundfile as sf
import ffmpeg
import torch
import sys
from mdxnet import MDXNetDereverb

@@ -16,8 +17,10 @@ from bsroformer import Roformer_Loader

try:
    import gradio.analytics as analytics
    analytics.version_check = lambda:None
except:...

    analytics.version_check = lambda: None
except:
    ...

weight_uvr5_root = "tools/uvr5/uvr5_weights"
uvr5_names = []

@@ -25,21 +28,24 @@ for name in os.listdir(weight_uvr5_root):

    if name.endswith(".pth") or name.endswith(".ckpt") or "onnx" in name:
        uvr5_names.append(name.replace(".pth", "").replace(".ckpt", ""))

device=sys.argv[1]
is_half=eval(sys.argv[2])
webui_port_uvr5=int(sys.argv[3])
is_share=eval(sys.argv[4])
device = sys.argv[1]
is_half = eval(sys.argv[2])
webui_port_uvr5 = int(sys.argv[3])
is_share = eval(sys.argv[4])
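Aside, not in the commit: eval() on raw argv values will execute arbitrary expressions; a hedged hardening sketch with a hypothetical parse_bool helper:

import sys

def parse_bool(value: str) -> bool:
    # Hypothetical helper, not in the repo: accept common truthy spellings only.
    return value.strip().lower() in ("1", "true", "yes")

# usage sketch: is_half = parse_bool(sys.argv[2]); is_share = parse_bool(sys.argv[4])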

def html_left(text, label='p'):
def html_left(text, label="p"):
    return f"""<div style="text-align: left; margin: 0; padding: 0;">
        <{label} style="margin: 0; padding: 0;">{text}</{label}>
        </div>"""

def html_center(text, label='p'):
def html_center(text, label="p"):
    return f"""<div style="text-align: center; margin: 100; padding: 50;">
        <{label} style="margin: 0; padding: 0;">{text}</{label}>
        </div>"""


def uvr(model_name, inp_root, save_root_vocal, paths, save_root_ins, agg, format0):
    infos = []
    try:

@@ -52,13 +58,15 @@ def uvr(model_name, inp_root, save_root_vocal, paths, save_root_ins, agg, format

        elif "roformer" in model_name.lower():
            func = Roformer_Loader
            pre_fun = func(
                model_path = os.path.join(weight_uvr5_root, model_name + ".ckpt"),
                config_path = os.path.join(weight_uvr5_root, model_name + ".yaml"),
                device = device,
                is_half=is_half
                model_path=os.path.join(weight_uvr5_root, model_name + ".ckpt"),
                config_path=os.path.join(weight_uvr5_root, model_name + ".yaml"),
                device=device,
                is_half=is_half,
            )
            if not os.path.exists(os.path.join(weight_uvr5_root, model_name + ".yaml")):
                infos.append("Warning: You are using a model without a configuration file. The program will automatically use the default configuration file. However, the default configuration file cannot guarantee that all models will run successfully. You can manually place the model configuration file into 'tools/uvr5/uvr5w_weights' and ensure that the configuration file is named as '<model_name>.yaml' then try it again. (For example, the configuration file corresponding to the model 'bs_roformer_ep_368_sdr_12.9628.ckpt' should be 'bs_roformer_ep_368_sdr_12.9628.yaml'.) Or you can just ignore this warning.")
                infos.append(
                    "Warning: You are using a model without a configuration file. The program will automatically use the default configuration file. However, the default configuration file cannot guarantee that all models will run successfully. You can manually place the model configuration file into 'tools/uvr5/uvr5w_weights' and ensure that the configuration file is named as '<model_name>.yaml' then try it again. (For example, the configuration file corresponding to the model 'bs_roformer_ep_368_sdr_12.9628.ckpt' should be 'bs_roformer_ep_368_sdr_12.9628.yaml'.) Or you can just ignore this warning."
                )
                yield "\n".join(infos)
        else:
            func = AudioPre if "DeEcho" not in model_name else AudioPreDeEcho

@@ -74,19 +82,15 @@ def uvr(model_name, inp_root, save_root_vocal, paths, save_root_ins, agg, format

            paths = [path.name for path in paths]
        for path in paths:
            inp_path = os.path.join(inp_root, path)
            if(os.path.isfile(inp_path)==False):continue
            if os.path.isfile(inp_path) == False:
                continue
            need_reformat = 1
            done = 0
            try:
                info = ffmpeg.probe(inp_path, cmd="ffprobe")
                if (
                    info["streams"][0]["channels"] == 2
                    and info["streams"][0]["sample_rate"] == "44100"
                ):
                if info["streams"][0]["channels"] == 2 and info["streams"][0]["sample_rate"] == "44100":
                    need_reformat = 0
                    pre_fun._path_audio_(
                        inp_path, save_root_ins, save_root_vocal, format0,is_hp3
                    )
                    pre_fun._path_audio_(inp_path, save_root_ins, save_root_vocal, format0, is_hp3)
                    done = 1
            except:
                need_reformat = 1

@@ -96,21 +100,15 @@ def uvr(model_name, inp_root, save_root_vocal, paths, save_root_ins, agg, format

                    os.path.join(os.environ["TEMP"]),
                    os.path.basename(inp_path),
                )
                os.system(
                    f'ffmpeg -i "{inp_path}" -vn -acodec pcm_s16le -ac 2 -ar 44100 "{tmp_path}" -y'
                )
                os.system(f'ffmpeg -i "{inp_path}" -vn -acodec pcm_s16le -ac 2 -ar 44100 "{tmp_path}" -y')
                inp_path = tmp_path
            try:
                if done == 0:
                    pre_fun._path_audio_(
                        inp_path, save_root_ins, save_root_vocal, format0,is_hp3
                    )
                    pre_fun._path_audio_(inp_path, save_root_ins, save_root_vocal, format0, is_hp3)
                infos.append("%s->Success" % (os.path.basename(inp_path)))
                yield "\n".join(infos)
            except:
                infos.append(
                    "%s->%s" % (os.path.basename(inp_path), traceback.format_exc())
                )
                infos.append("%s->%s" % (os.path.basename(inp_path), traceback.format_exc()))
                yield "\n".join(infos)
    except:
        infos.append(traceback.format_exc())

@@ -130,80 +128,98 @@ def uvr(model_name, inp_root, save_root_vocal, paths, save_root_ins, agg, format

        torch.cuda.empty_cache()
    yield "\n".join(infos)


with gr.Blocks(title="UVR5 WebUI") as app:
    gr.Markdown(
        value=
        i18n("本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责.") + "<br>" + i18n("如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录LICENSE.")
        value=i18n("本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责.")
        + "<br>"
        + i18n("如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录LICENSE.")
    )
    with gr.Group():
        gr.Markdown(html_center(i18n("伴奏人声分离&去混响&去回声"),'h2'))
        gr.Markdown(html_center(i18n("伴奏人声分离&去混响&去回声"), "h2"))
        with gr.Group():
            gr.Markdown(
                value=html_left(i18n("人声伴奏分离批量处理, 使用UVR5模型。") + "<br>" + \
                    i18n("合格的文件夹路径格式举例: E:\\codes\\py39\\vits_vc_gpu\\白鹭霜华测试样例(去文件管理器地址栏拷就行了)。")+ "<br>" + \
                    i18n("模型分为三类:") + "<br>" + \
                    i18n("1、保留人声:不带和声的音频选这个,对主人声保留比HP5更好。内置HP2和HP3两个模型,HP3可能轻微漏伴奏但对主人声保留比HP2稍微好一丁点;") + "<br>" + \
                    i18n("2、仅保留主人声:带和声的音频选这个,对主人声可能有削弱。内置HP5一个模型;") + "<br>" + \
                    i18n("3、去混响、去延迟模型(by FoxJoy):") + "<br> " + \
                    i18n("(1)MDX-Net(onnx_dereverb):对于双通道混响是最好的选择,不能去除单通道混响;") + "<br> " + \
                    i18n("(234)DeEcho:去除延迟效果。Aggressive比Normal去除得更彻底,DeReverb额外去除混响,可去除单声道混响,但是对高频重的板式混响去不干净。") + "<br>" + \
                    i18n("去混响/去延迟,附:") + "<br>" + \
                    i18n("1、DeEcho-DeReverb模型的耗时是另外2个DeEcho模型的接近2倍;") + "<br>" + \
                    i18n("2、MDX-Net-Dereverb模型挺慢的;") + "<br>" + \
                    i18n("3、个人推荐的最干净的配置是先MDX-Net再DeEcho-Aggressive。"),'h4')
            )
            with gr.Row():
                with gr.Column():
                    model_choose = gr.Dropdown(label=i18n("模型"), choices=uvr5_names)
                    dir_wav_input = gr.Textbox(
                        label=i18n("输入待处理音频文件夹路径"),
                        placeholder="C:\\Users\\Desktop\\todo-songs",
                    )
                    wav_inputs = gr.File(
                        file_count="multiple", label=i18n("也可批量输入音频文件, 二选一, 优先读文件夹")
                    )
                with gr.Column():
                    agg = gr.Slider(
                        minimum=0,
                        maximum=20,
                        step=1,
                        label=i18n("人声提取激进程度"),
                        value=10,
                        interactive=True,
                        visible=False,  # 先不开放调整
                    )
                    opt_vocal_root = gr.Textbox(
                        label=i18n("指定输出主人声文件夹"), value="output/uvr5_opt"
                    )
                    opt_ins_root = gr.Textbox(
                        label=i18n("指定输出非主人声文件夹"), value="output/uvr5_opt"
                    )
                    format0 = gr.Radio(
                        label=i18n("导出文件格式"),
                        choices=["wav", "flac", "mp3", "m4a"],
                        value="flac",
                        interactive=True,
                    )
                with gr.Column():
                    with gr.Row():
                        but2 = gr.Button(i18n("转换"), variant="primary")
                    with gr.Row():
                        vc_output4 = gr.Textbox(label=i18n("输出信息"),lines=3)
                    but2.click(
                        uvr,
                        [
                            model_choose,
                            dir_wav_input,
                            opt_vocal_root,
                            wav_inputs,
                            opt_ins_root,
                            agg,
                            format0,
                        ],
                        [vc_output4],
                        api_name="uvr_convert",
            gr.Markdown(
                value=html_left(
                    i18n("人声伴奏分离批量处理, 使用UVR5模型。")
                    + "<br>"
                    + i18n(
                        "合格的文件夹路径格式举例: E:\\codes\\py39\\vits_vc_gpu\\白鹭霜华测试样例(去文件管理器地址栏拷就行了)。"
                    )
                    + "<br>"
                    + i18n("模型分为三类:")
                    + "<br>"
                    + i18n(
                        "1、保留人声:不带和声的音频选这个,对主人声保留比HP5更好。内置HP2和HP3两个模型,HP3可能轻微漏伴奏但对主人声保留比HP2稍微好一丁点;"
                    )
                    + "<br>"
                    + i18n("2、仅保留主人声:带和声的音频选这个,对主人声可能有削弱。内置HP5一个模型;")
                    + "<br>"
                    + i18n("3、去混响、去延迟模型(by FoxJoy):")
                    + "<br> "
                    + i18n("(1)MDX-Net(onnx_dereverb):对于双通道混响是最好的选择,不能去除单通道混响;")
                    + "<br> "
                    + i18n(
                        "(234)DeEcho:去除延迟效果。Aggressive比Normal去除得更彻底,DeReverb额外去除混响,可去除单声道混响,但是对高频重的板式混响去不干净。"
                    )
                    + "<br>"
                    + i18n("去混响/去延迟,附:")
                    + "<br>"
                    + i18n("1、DeEcho-DeReverb模型的耗时是另外2个DeEcho模型的接近2倍;")
                    + "<br>"
                    + i18n("2、MDX-Net-Dereverb模型挺慢的;")
                    + "<br>"
                    + i18n("3、个人推荐的最干净的配置是先MDX-Net再DeEcho-Aggressive。"),
                    "h4",
                )
            )
            with gr.Row():
                with gr.Column():
                    model_choose = gr.Dropdown(label=i18n("模型"), choices=uvr5_names)
                    dir_wav_input = gr.Textbox(
                        label=i18n("输入待处理音频文件夹路径"),
                        placeholder="C:\\Users\\Desktop\\todo-songs",
                    )
                    wav_inputs = gr.File(
                        file_count="multiple", label=i18n("也可批量输入音频文件, 二选一, 优先读文件夹")
                    )
                with gr.Column():
                    agg = gr.Slider(
                        minimum=0,
                        maximum=20,
                        step=1,
                        label=i18n("人声提取激进程度"),
                        value=10,
                        interactive=True,
                        visible=False,  # 先不开放调整
                    )
                    opt_vocal_root = gr.Textbox(label=i18n("指定输出主人声文件夹"), value="output/uvr5_opt")
                    opt_ins_root = gr.Textbox(label=i18n("指定输出非主人声文件夹"), value="output/uvr5_opt")
                    format0 = gr.Radio(
                        label=i18n("导出文件格式"),
                        choices=["wav", "flac", "mp3", "m4a"],
                        value="flac",
                        interactive=True,
                    )
                with gr.Column():
                    with gr.Row():
                        but2 = gr.Button(i18n("转换"), variant="primary")
                    with gr.Row():
                        vc_output4 = gr.Textbox(label=i18n("输出信息"), lines=3)
                    but2.click(
                        uvr,
                        [
                            model_choose,
                            dir_wav_input,
                            opt_vocal_root,
                            wav_inputs,
                            opt_ins_root,
                            agg,
                            format0,
                        ],
                        [vc_output4],
                        api_name="uvr_convert",
                    )
app.queue().launch(#concurrency_count=511, max_size=1022
app.queue().launch(  # concurrency_count=511, max_size=1022
    server_name="0.0.0.0",
    inbrowser=True,
    share=is_share,