Merge pull request #2460 from L-jasmine/export_v2pro
优化 torch_script 导出模型
This commit is contained in:
@@ -243,6 +243,7 @@ class ExportGPTSovitsHalf(torch.nn.Module):
|
||||
self.sampling_rate: int = hps.data.sampling_rate
|
||||
self.hop_length: int = hps.data.hop_length
|
||||
self.win_length: int = hps.data.win_length
|
||||
self.hann_window = torch.hann_window(self.win_length, device=device, dtype=torch.float32)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -255,6 +256,7 @@ class ExportGPTSovitsHalf(torch.nn.Module):
|
||||
top_k,
|
||||
):
|
||||
refer = spectrogram_torch(
|
||||
self.hann_window,
|
||||
ref_audio_32k,
|
||||
self.filter_length,
|
||||
self.sampling_rate,
|
||||
@@ -321,6 +323,7 @@ class ExportGPTSovitsV4Half(torch.nn.Module):
|
||||
self.sampling_rate: int = hps.data.sampling_rate
|
||||
self.hop_length: int = hps.data.hop_length
|
||||
self.win_length: int = hps.data.win_length
|
||||
self.hann_window = torch.hann_window(self.win_length, device=device, dtype=torch.float32)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -333,6 +336,7 @@ class ExportGPTSovitsV4Half(torch.nn.Module):
|
||||
top_k,
|
||||
):
|
||||
refer = spectrogram_torch(
|
||||
self.hann_window,
|
||||
ref_audio_32k,
|
||||
self.filter_length,
|
||||
self.sampling_rate,
|
||||
@@ -1149,7 +1153,7 @@ def export_2(version="v3"):
|
||||
raw_t2s = raw_t2s.half().to(device)
|
||||
t2s_m = T2SModel(raw_t2s).half().to(device)
|
||||
t2s_m.eval()
|
||||
t2s_m = torch.jit.script(t2s_m)
|
||||
t2s_m = torch.jit.script(t2s_m).to(device)
|
||||
t2s_m.eval()
|
||||
# t2s_m.top_k = 15
|
||||
logger.info("t2s_m ok")
|
||||
@@ -1251,6 +1255,6 @@ def test_export_gpt_sovits_v3():
|
||||
|
||||
|
||||
with torch.no_grad():
|
||||
export_1("onnx/ad/ref.wav","你这老坏蛋,我找了你这么久,真没想到在这里找到你。他说。","v4")
|
||||
# export_2("v4")
|
||||
# export_1("onnx/ad/ref.wav","你这老坏蛋,我找了你这么久,真没想到在这里找到你。他说。","v4")
|
||||
export_2("v4")
|
||||
# test_export_gpt_sovits_v3()
|
||||
|
||||
Reference in New Issue
Block a user