Support for mel_band_roformer (#2078)
* support for mel_band_roformer * Remove unnecessary audio channel judgments * remove context manager and fix path * Update webui.py * Update README.md
This commit is contained in:
@@ -1,6 +1,4 @@
|
||||
# This code is modified from https://github.com/ZFTurbo/
|
||||
import pdb
|
||||
|
||||
import librosa
|
||||
from tqdm import tqdm
|
||||
import os
|
||||
@@ -8,61 +6,113 @@ import torch
|
||||
import numpy as np
|
||||
import soundfile as sf
|
||||
import torch.nn as nn
|
||||
|
||||
import yaml
|
||||
import warnings
|
||||
warnings.filterwarnings("ignore")
|
||||
from bs_roformer.bs_roformer import BSRoformer
|
||||
|
||||
class BsRoformer_Loader:
|
||||
|
||||
class Roformer_Loader:
|
||||
def get_config(self, config_path):
|
||||
with open(config_path, 'r', encoding='utf-8') as f:
|
||||
# use fullloader to load tag !!python/tuple, code can be improved
|
||||
config = yaml.load(f, Loader=yaml.FullLoader)
|
||||
return config
|
||||
|
||||
def get_default_config(self):
|
||||
default_config = None
|
||||
if self.model_type == 'bs_roformer':
|
||||
# Use model_bs_roformer_ep_368_sdr_12.9628.yaml and model_bs_roformer_ep_317_sdr_12.9755.yaml as default configuration files
|
||||
# Other BS_Roformer models may not be compatible
|
||||
default_config = {
|
||||
"audio": {"chunk_size": 352800, "sample_rate": 44100},
|
||||
"model": {
|
||||
"dim": 512,
|
||||
"depth": 12,
|
||||
"stereo": True,
|
||||
"num_stems": 1,
|
||||
"time_transformer_depth": 1,
|
||||
"freq_transformer_depth": 1,
|
||||
"linear_transformer_depth": 0,
|
||||
"freqs_per_bands": (2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 12, 12, 12, 12, 12, 12, 12, 12, 24, 24, 24, 24, 24, 24, 24, 24, 48, 48, 48, 48, 48, 48, 48, 48, 128, 129),
|
||||
"dim_head": 64,
|
||||
"heads": 8,
|
||||
"attn_dropout": 0.1,
|
||||
"ff_dropout": 0.1,
|
||||
"flash_attn": True,
|
||||
"dim_freqs_in": 1025,
|
||||
"stft_n_fft": 2048,
|
||||
"stft_hop_length": 441,
|
||||
"stft_win_length": 2048,
|
||||
"stft_normalized": False,
|
||||
"mask_estimator_depth": 2,
|
||||
"multi_stft_resolution_loss_weight": 1.0,
|
||||
"multi_stft_resolutions_window_sizes": (4096, 2048, 1024, 512, 256),
|
||||
"multi_stft_hop_size": 147,
|
||||
"multi_stft_normalized": False,
|
||||
},
|
||||
"training": {"instruments": ["vocals", "other"], "target_instrument": "vocals"},
|
||||
"inference": {"batch_size": 2, "num_overlap": 2}
|
||||
}
|
||||
elif self.model_type == 'mel_band_roformer':
|
||||
# Use model_mel_band_roformer_ep_3005_sdr_11.4360.yaml as default configuration files
|
||||
# Other Mel_Band_Roformer models may not be compatible
|
||||
default_config = {
|
||||
"audio": {"chunk_size": 352800, "sample_rate": 44100},
|
||||
"model": {
|
||||
"dim": 384,
|
||||
"depth": 12,
|
||||
"stereo": True,
|
||||
"num_stems": 1,
|
||||
"time_transformer_depth": 1,
|
||||
"freq_transformer_depth": 1,
|
||||
"linear_transformer_depth": 0,
|
||||
"num_bands": 60,
|
||||
"dim_head": 64,
|
||||
"heads": 8,
|
||||
"attn_dropout": 0.1,
|
||||
"ff_dropout": 0.1,
|
||||
"flash_attn": True,
|
||||
"dim_freqs_in": 1025,
|
||||
"sample_rate": 44100,
|
||||
"stft_n_fft": 2048,
|
||||
"stft_hop_length": 441,
|
||||
"stft_win_length": 2048,
|
||||
"stft_normalized": False,
|
||||
"mask_estimator_depth": 2,
|
||||
"multi_stft_resolution_loss_weight": 1.0,
|
||||
"multi_stft_resolutions_window_sizes": (4096, 2048, 1024, 512, 256),
|
||||
"multi_stft_hop_size": 147,
|
||||
"multi_stft_normalized": False
|
||||
},
|
||||
"training": {"instruments": ["vocals", "other"], "target_instrument": "vocals"},
|
||||
"inference": {"batch_size": 2, "num_overlap": 2}
|
||||
}
|
||||
return default_config
|
||||
|
||||
|
||||
def get_model_from_config(self):
|
||||
config = {
|
||||
"attn_dropout": 0.1,
|
||||
"depth": 12,
|
||||
"dim": 512,
|
||||
"dim_freqs_in": 1025,
|
||||
"dim_head": 64,
|
||||
"ff_dropout": 0.1,
|
||||
"flash_attn": True,
|
||||
"freq_transformer_depth": 1,
|
||||
"freqs_per_bands":(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 12, 12, 12, 12, 12, 12, 12, 12, 24, 24, 24, 24, 24, 24, 24, 24, 48, 48, 48, 48, 48, 48, 48, 48, 128, 129),
|
||||
"heads": 8,
|
||||
"linear_transformer_depth": 0,
|
||||
"mask_estimator_depth": 2,
|
||||
"multi_stft_hop_size": 147,
|
||||
"multi_stft_normalized": False,
|
||||
"multi_stft_resolution_loss_weight": 1.0,
|
||||
"multi_stft_resolutions_window_sizes":(4096, 2048, 1024, 512, 256),
|
||||
"num_stems": 1,
|
||||
"stereo": True,
|
||||
"stft_hop_length": 441,
|
||||
"stft_n_fft": 2048,
|
||||
"stft_normalized": False,
|
||||
"stft_win_length": 2048,
|
||||
"time_transformer_depth": 1,
|
||||
|
||||
}
|
||||
|
||||
|
||||
model = BSRoformer(
|
||||
**dict(config)
|
||||
)
|
||||
|
||||
if self.model_type == 'bs_roformer':
|
||||
from bs_roformer.bs_roformer import BSRoformer
|
||||
model = BSRoformer(**dict(self.config["model"]))
|
||||
elif self.model_type == 'mel_band_roformer':
|
||||
from bs_roformer.mel_band_roformer import MelBandRoformer
|
||||
model = MelBandRoformer(**dict(self.config["model"]))
|
||||
else:
|
||||
print('Error: Unknown model: {}'.format(self.model_type))
|
||||
model = None
|
||||
return model
|
||||
|
||||
|
||||
|
||||
def demix_track(self, model, mix, device):
|
||||
C = 352800
|
||||
# num_overlap
|
||||
N = 1
|
||||
C = self.config["audio"]["chunk_size"] # chunk_size
|
||||
N = self.config["inference"]["num_overlap"]
|
||||
fade_size = C // 10
|
||||
step = int(C // N)
|
||||
border = C - step
|
||||
batch_size = 4
|
||||
batch_size = self.config["inference"]["batch_size"]
|
||||
|
||||
length_init = mix.shape[-1]
|
||||
|
||||
progress_bar = tqdm(total=length_init // step + 1)
|
||||
progress_bar.set_description("Processing")
|
||||
progress_bar = tqdm(total=length_init // step + 1, desc="Processing", leave=False)
|
||||
|
||||
# Do pad from the beginning and end to account floating window results better
|
||||
if length_init > 2 * border and (border > 0):
|
||||
@@ -82,7 +132,10 @@ class BsRoformer_Loader:
|
||||
|
||||
with torch.amp.autocast('cuda'):
|
||||
with torch.inference_mode():
|
||||
req_shape = (1, ) + tuple(mix.shape)
|
||||
if self.config["training"]["target_instrument"] is None:
|
||||
req_shape = (len(self.config["training"]["instruments"]),) + tuple(mix.shape)
|
||||
else:
|
||||
req_shape = (1, ) + tuple(mix.shape)
|
||||
|
||||
result = torch.zeros(req_shape, dtype=torch.float32)
|
||||
counter = torch.zeros(req_shape, dtype=torch.float32)
|
||||
@@ -97,7 +150,7 @@ class BsRoformer_Loader:
|
||||
part = nn.functional.pad(input=part, pad=(0, C - length), mode='reflect')
|
||||
else:
|
||||
part = nn.functional.pad(input=part, pad=(0, C - length, 0, 0), mode='constant', value=0)
|
||||
if(self.is_half==True):
|
||||
if self.is_half:
|
||||
part=part.half()
|
||||
batch_data.append(part)
|
||||
batch_locations.append((i, length))
|
||||
@@ -133,78 +186,109 @@ class BsRoformer_Loader:
|
||||
|
||||
progress_bar.close()
|
||||
|
||||
return {k: v for k, v in zip(['vocals', 'other'], estimated_sources)}
|
||||
if self.config["training"]["target_instrument"] is None:
|
||||
return {k: v for k, v in zip(self.config["training"]["instruments"], estimated_sources)}
|
||||
else:
|
||||
return {k: v for k, v in zip([self.config["training"]["target_instrument"]], estimated_sources)}
|
||||
|
||||
|
||||
def run_folder(self,input, vocal_root, others_root, format):
|
||||
# start_time = time.time()
|
||||
def run_folder(self, input, vocal_root, others_root, format):
|
||||
self.model.eval()
|
||||
path = input
|
||||
os.makedirs(vocal_root, exist_ok=True)
|
||||
os.makedirs(others_root, exist_ok=True)
|
||||
file_base_name = os.path.splitext(os.path.basename(path))[0]
|
||||
|
||||
if not os.path.isdir(vocal_root):
|
||||
os.mkdir(vocal_root)
|
||||
|
||||
if not os.path.isdir(others_root):
|
||||
os.mkdir(others_root)
|
||||
sample_rate = 44100
|
||||
if 'sample_rate' in self.config["audio"]:
|
||||
sample_rate = self.config["audio"]['sample_rate']
|
||||
|
||||
try:
|
||||
mix, sr = librosa.load(path, sr=44100, mono=False)
|
||||
mix, sr = librosa.load(path, sr=sample_rate, mono=False)
|
||||
except Exception as e:
|
||||
print('Can read track: {}'.format(path))
|
||||
print('Error message: {}'.format(str(e)))
|
||||
return
|
||||
|
||||
# Convert mono to stereo if needed
|
||||
if len(mix.shape) == 1:
|
||||
mix = np.stack([mix, mix], axis=0)
|
||||
# in case if model only supports mono tracks
|
||||
isstereo = self.config["model"].get("stereo", True)
|
||||
if not isstereo and len(mix.shape) != 1:
|
||||
mix = np.mean(mix, axis=0) # if more than 2 channels, take mean
|
||||
print("Warning: Track has more than 1 channels, but model is mono, taking mean of all channels.")
|
||||
|
||||
mix_orig = mix.copy()
|
||||
|
||||
mixture = torch.tensor(mix, dtype=torch.float32)
|
||||
res = self.demix_track(self.model, mixture, self.device)
|
||||
|
||||
estimates = res['vocals'].T
|
||||
|
||||
if format in ["wav", "flac"]:
|
||||
sf.write("{}/{}_{}.{}".format(vocal_root, os.path.basename(path)[:-4], 'vocals', format), estimates, sr)
|
||||
sf.write("{}/{}_{}.{}".format(others_root, os.path.basename(path)[:-4], 'instrumental', format), mix_orig.T - estimates, sr)
|
||||
if self.config["training"]["target_instrument"] is not None:
|
||||
# if target instrument is specified, save target instrument as vocal and other instruments as others
|
||||
# other instruments are caculated by subtracting target instrument from mixture
|
||||
target_instrument = self.config["training"]["target_instrument"]
|
||||
other_instruments = [i for i in self.config["training"]["instruments"] if i != target_instrument]
|
||||
other = mix_orig - res[target_instrument] # caculate other instruments
|
||||
|
||||
path_vocal = "{}/{}_{}.wav".format(vocal_root, file_base_name, target_instrument)
|
||||
path_other = "{}/{}_{}.wav".format(others_root, file_base_name, other_instruments[0])
|
||||
self.save_audio(path_vocal, res[target_instrument].T, sr, format)
|
||||
self.save_audio(path_other, other.T, sr, format)
|
||||
else:
|
||||
path_vocal = "%s/%s_vocals.wav" % (vocal_root, os.path.basename(path)[:-4])
|
||||
path_other = "%s/%s_instrumental.wav" % (others_root, os.path.basename(path)[:-4])
|
||||
sf.write(path_vocal, estimates, sr)
|
||||
sf.write(path_other, mix_orig.T - estimates, sr)
|
||||
opt_path_vocal = path_vocal[:-4] + ".%s" % format
|
||||
opt_path_other = path_other[:-4] + ".%s" % format
|
||||
if os.path.exists(path_vocal):
|
||||
os.system(
|
||||
"ffmpeg -i '%s' -vn '%s' -q:a 2 -y" % (path_vocal, opt_path_vocal)
|
||||
)
|
||||
if os.path.exists(opt_path_vocal):
|
||||
try:
|
||||
os.remove(path_vocal)
|
||||
except:
|
||||
pass
|
||||
if os.path.exists(path_other):
|
||||
os.system(
|
||||
"ffmpeg -i '%s' -vn '%s' -q:a 2 -y" % (path_other, opt_path_other)
|
||||
)
|
||||
if os.path.exists(opt_path_other):
|
||||
try:
|
||||
os.remove(path_other)
|
||||
except:
|
||||
pass
|
||||
|
||||
# print("Elapsed time: {:.2f} sec".format(time.time() - start_time))
|
||||
# if target instrument is not specified, save the first instrument as vocal and the rest as others
|
||||
vocal_inst = self.config["training"]["instruments"][0]
|
||||
path_vocal = "{}/{}_{}.wav".format(vocal_root, file_base_name, vocal_inst)
|
||||
self.save_audio(path_vocal, res[vocal_inst].T, sr, format)
|
||||
for other in self.config["training"]["instruments"][1:]: # save other instruments
|
||||
path_other = "{}/{}_{}.wav".format(others_root, file_base_name, other)
|
||||
self.save_audio(path_other, res[other].T, sr, format)
|
||||
|
||||
|
||||
def __init__(self, model_path, device,is_half):
|
||||
def save_audio(self, path, data, sr, format):
|
||||
# input path should be endwith '.wav'
|
||||
if format in ["wav", "flac"]:
|
||||
if format == "flac":
|
||||
path = path[:-3] + "flac"
|
||||
sf.write(path, data, sr)
|
||||
else:
|
||||
sf.write(path, data, sr)
|
||||
os.system("ffmpeg -i \"{}\" -vn \"{}\" -q:a 2 -y".format(path, path[:-3] + format))
|
||||
try: os.remove(path)
|
||||
except: pass
|
||||
|
||||
|
||||
def __init__(self, model_path, config_path, device, is_half):
|
||||
self.device = device
|
||||
self.extract_instrumental=True
|
||||
self.is_half = is_half
|
||||
self.model_type = None
|
||||
self.config = None
|
||||
|
||||
# get model_type, first try:
|
||||
if "bs_roformer" in model_path.lower() or "bsroformer" in model_path.lower():
|
||||
self.model_type = "bs_roformer"
|
||||
elif "mel_band_roformer" in model_path.lower() or "melbandroformer" in model_path.lower():
|
||||
self.model_type = "mel_band_roformer"
|
||||
|
||||
if not os.path.exists(config_path):
|
||||
if self.model_type is None:
|
||||
# if model_type is still None, raise an error
|
||||
raise ValueError("Error: Unknown model type. If you are using a model without a configuration file, Ensure that your model name includes 'bs_roformer', 'bsroformer', 'mel_band_roformer', or 'melbandroformer'. Otherwise, you can manually place the model configuration file into 'tools/uvr5/uvr5w_weights' and ensure that the configuration file is named as '<model_name>.yaml' then try it again.")
|
||||
self.config = self.get_default_config()
|
||||
else:
|
||||
# if there is a configuration file
|
||||
self.config = self.get_config(config_path)
|
||||
if self.model_type is None:
|
||||
# if model_type is still None, second try, get model_type from the configuration file
|
||||
if "freqs_per_bands" in self.config["model"]:
|
||||
# if freqs_per_bands in config, it's a bs_roformer model
|
||||
self.model_type = "bs_roformer"
|
||||
else:
|
||||
# else it's a mel_band_roformer model
|
||||
self.model_type = "mel_band_roformer"
|
||||
|
||||
print("Detected model type: {}".format(self.model_type))
|
||||
model = self.get_model_from_config()
|
||||
state_dict = torch.load(model_path,map_location="cpu")
|
||||
state_dict = torch.load(model_path, map_location="cpu")
|
||||
model.load_state_dict(state_dict)
|
||||
self.is_half=is_half
|
||||
|
||||
if(is_half==False):
|
||||
self.model = model.to(device)
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user