bsroformer support fp16 inference

bsroformer support fp16 inference
This commit is contained in:
RVC-Boss
2024-08-01 21:26:59 +08:00
committed by GitHub
parent 10e885d9ac
commit e62e965323
2 changed files with 17 additions and 7 deletions

View File

@@ -1,4 +1,5 @@
# This code is modified from https://github.com/ZFTurbo/
import pdb
import librosa
from tqdm import tqdm
@@ -10,6 +11,7 @@ import torch.nn as nn
import warnings
warnings.filterwarnings("ignore")
from bs_roformer.bs_roformer import BSRoformer
class BsRoformer_Loader:
def get_model_from_config(self):
@@ -40,7 +42,7 @@ class BsRoformer_Loader:
}
from bs_roformer.bs_roformer import BSRoformer
model = BSRoformer(
**dict(config)
)
@@ -95,6 +97,8 @@ class BsRoformer_Loader:
part = nn.functional.pad(input=part, pad=(0, C - length), mode='reflect')
else:
part = nn.functional.pad(input=part, pad=(0, C - length, 0, 0), mode='constant', value=0)
if(self.is_half==True):
part=part.half()
batch_data.append(part)
batch_locations.append((i, length))
i += step
@@ -102,6 +106,7 @@ class BsRoformer_Loader:
if len(batch_data) >= batch_size or (i >= mix.shape[1]):
arr = torch.stack(batch_data, dim=0)
# print(23333333,arr.dtype)
x = model(arr)
window = window_middle
@@ -192,14 +197,18 @@ class BsRoformer_Loader:
# print("Elapsed time: {:.2f} sec".format(time.time() - start_time))
def __init__(self, model_path, device):
def __init__(self, model_path, device,is_half):
self.device = device
self.extract_instrumental=True
model = self.get_model_from_config()
state_dict = torch.load(model_path)
state_dict = torch.load(model_path,map_location="cpu")
model.load_state_dict(state_dict)
self.model = model.to(device)
self.is_half=is_half
if(is_half==False):
self.model = model.to(device)
else:
self.model = model.half().to(device)
def _path_audio_(self, input, others_root, vocal_root, format, is_hp3=False):