bsroformer support fp16 inference
bsroformer support fp16 inference
This commit is contained in:
@@ -17,8 +17,8 @@ from bsroformer import BsRoformer_Loader
|
||||
weight_uvr5_root = "tools/uvr5/uvr5_weights"
|
||||
uvr5_names = []
|
||||
for name in os.listdir(weight_uvr5_root):
|
||||
if name.endswith(".pth") or "onnx" in name:
|
||||
uvr5_names.append(name.replace(".pth", ""))
|
||||
if name.endswith(".pth") or name.endswith(".ckpt") or "onnx" in name:
|
||||
uvr5_names.append(name.replace(".pth", "").replace(".ckpt", ""))
|
||||
|
||||
device=sys.argv[1]
|
||||
is_half=eval(sys.argv[2])
|
||||
@@ -37,8 +37,9 @@ def uvr(model_name, inp_root, save_root_vocal, paths, save_root_ins, agg, format
|
||||
elif model_name == "Bs_Roformer" or "bs_roformer" in model_name.lower():
|
||||
func = BsRoformer_Loader
|
||||
pre_fun = func(
|
||||
model_path = os.path.join(weight_uvr5_root, model_name + ".pth"),
|
||||
model_path = os.path.join(weight_uvr5_root, model_name + ".ckpt"),
|
||||
device = device,
|
||||
is_half=is_half
|
||||
)
|
||||
else:
|
||||
func = AudioPre if "DeEcho" not in model_name else AudioPreDeEcho
|
||||
|
||||
Reference in New Issue
Block a user