Refactor: Format Code with Ruff and Update Deprecated G2PW Link (#2255)

* ruff check --fix

* ruff format --line-length 120 --target-version py39

* Change the link for G2PW Model

* update pytorch version and colab
This commit is contained in:
XXXXRT666
2025-04-07 09:42:47 +01:00
committed by GitHub
parent 9da7e17efe
commit 53cac93589
132 changed files with 8185 additions and 6648 deletions

View File

@@ -13,9 +13,7 @@ cpu = torch.device("cpu")
class ConvTDFNetTrim:
def __init__(
self, device, model_name, target_name, L, dim_f, dim_t, n_fft, hop=1024
):
def __init__(self, device, model_name, target_name, L, dim_f, dim_t, n_fft, hop=1024):
super(ConvTDFNetTrim, self).__init__()
self.dim_f = dim_f
@@ -24,17 +22,13 @@ class ConvTDFNetTrim:
self.hop = hop
self.n_bins = self.n_fft // 2 + 1
self.chunk_size = hop * (self.dim_t - 1)
self.window = torch.hann_window(window_length=self.n_fft, periodic=True).to(
device
)
self.window = torch.hann_window(window_length=self.n_fft, periodic=True).to(device)
self.target_name = target_name
self.blender = "blender" in model_name
self.dim_c = 4
out_c = self.dim_c * 4 if target_name == "*" else self.dim_c
self.freq_pad = torch.zeros(
[1, out_c, self.n_bins - self.dim_f, self.dim_t]
).to(device)
self.freq_pad = torch.zeros([1, out_c, self.n_bins - self.dim_f, self.dim_t]).to(device)
self.n = L // 2
@@ -50,28 +44,18 @@ class ConvTDFNetTrim:
)
x = torch.view_as_real(x)
x = x.permute([0, 3, 1, 2])
x = x.reshape([-1, 2, 2, self.n_bins, self.dim_t]).reshape(
[-1, self.dim_c, self.n_bins, self.dim_t]
)
x = x.reshape([-1, 2, 2, self.n_bins, self.dim_t]).reshape([-1, self.dim_c, self.n_bins, self.dim_t])
return x[:, :, : self.dim_f]
def istft(self, x, freq_pad=None):
freq_pad = (
self.freq_pad.repeat([x.shape[0], 1, 1, 1])
if freq_pad is None
else freq_pad
)
freq_pad = self.freq_pad.repeat([x.shape[0], 1, 1, 1]) if freq_pad is None else freq_pad
x = torch.cat([x, freq_pad], -2)
c = 4 * 2 if self.target_name == "*" else 2
x = x.reshape([-1, c, 2, self.n_bins, self.dim_t]).reshape(
[-1, 2, self.n_bins, self.dim_t]
)
x = x.reshape([-1, c, 2, self.n_bins, self.dim_t]).reshape([-1, 2, self.n_bins, self.dim_t])
x = x.permute([0, 2, 3, 1])
x = x.contiguous()
x = torch.view_as_complex(x)
x = torch.istft(
x, n_fft=self.n_fft, hop_length=self.hop, window=self.window, center=True
)
x = torch.istft(x, n_fft=self.n_fft, hop_length=self.hop, window=self.window, center=True)
return x.reshape([-1, c, self.chunk_size])
@@ -93,9 +77,7 @@ class Predictor:
logger.info(ort.get_available_providers())
self.args = args
self.model_ = get_models(
device=cpu, dim_f=args.dim_f, dim_t=args.dim_t, n_fft=args.n_fft
)
self.model_ = get_models(device=cpu, dim_f=args.dim_f, dim_t=args.dim_t, n_fft=args.n_fft)
self.model = ort.InferenceSession(
os.path.join(args.onnx, self.model_.target_name + ".onnx"),
providers=[
@@ -152,9 +134,7 @@ class Predictor:
trim = model.n_fft // 2
gen_size = model.chunk_size - 2 * trim
pad = gen_size - n_sample % gen_size
mix_p = np.concatenate(
(np.zeros((2, trim)), cmix, np.zeros((2, pad)), np.zeros((2, trim))), 1
)
mix_p = np.concatenate((np.zeros((2, trim)), cmix, np.zeros((2, pad)), np.zeros((2, trim))), 1)
mix_waves = []
i = 0
while i < n_sample + pad:
@@ -172,15 +152,8 @@ class Predictor:
)
tar_waves = model.istft(torch.tensor(spec_pred))
else:
tar_waves = model.istft(
torch.tensor(_ort.run(None, {"input": spek.cpu().numpy()})[0])
)
tar_signal = (
tar_waves[:, :, trim:-trim]
.transpose(0, 1)
.reshape(2, -1)
.numpy()[:, :-pad]
)
tar_waves = model.istft(torch.tensor(_ort.run(None, {"input": spek.cpu().numpy()})[0]))
tar_signal = tar_waves[:, :, trim:-trim].transpose(0, 1).reshape(2, -1).numpy()[:, :-pad]
start = 0 if mix == 0 else margin_size
end = None if mix == list(mixes.keys())[::-1][0] else -margin_size
@@ -207,9 +180,7 @@ class Predictor:
sources = self.demix(mix.T)
opt = sources[0].T
if format in ["wav", "flac"]:
sf.write(
"%s/%s_main_vocal.%s" % (vocal_root, basename, format), mix - opt, rate
)
sf.write("%s/%s_main_vocal.%s" % (vocal_root, basename, format), mix - opt, rate)
sf.write("%s/%s_others.%s" % (others_root, basename, format), opt, rate)
else:
path_vocal = "%s/%s_main_vocal.wav" % (vocal_root, basename)
@@ -219,18 +190,14 @@ class Predictor:
opt_path_vocal = path_vocal[:-4] + ".%s" % format
opt_path_other = path_other[:-4] + ".%s" % format
if os.path.exists(path_vocal):
os.system(
"ffmpeg -i '%s' -vn '%s' -q:a 2 -y" % (path_vocal, opt_path_vocal)
)
os.system("ffmpeg -i '%s' -vn '%s' -q:a 2 -y" % (path_vocal, opt_path_vocal))
if os.path.exists(opt_path_vocal):
try:
os.remove(path_vocal)
except:
pass
if os.path.exists(path_other):
os.system(
"ffmpeg -i '%s' -vn '%s' -q:a 2 -y" % (path_other, opt_path_other)
)
os.system("ffmpeg -i '%s' -vn '%s' -q:a 2 -y" % (path_other, opt_path_other))
if os.path.exists(opt_path_other):
try:
os.remove(path_other)
@@ -240,7 +207,7 @@ class Predictor:
class MDXNetDereverb:
def __init__(self, chunks):
self.onnx = "%s/uvr5_weights/onnx_dereverb_By_FoxJoy"%os.path.dirname(os.path.abspath(__file__))
self.onnx = "%s/uvr5_weights/onnx_dereverb_By_FoxJoy" % os.path.dirname(os.path.abspath(__file__))
self.shifts = 10 # 'Predict with randomised equivariant stabilisation'
self.mixing = "min_mag" # ['default','min_mag','max_mag']
self.chunks = chunks