Merge pull request #2460 from L-jasmine/export_v2pro

优化 torch_script 导出模型
This commit is contained in:
zzz
2025-06-13 22:10:11 +08:00
committed by GitHub
parent 1a9b8854ee
commit 7dec5f5bb0
2 changed files with 25 additions and 16 deletions

View File

@@ -243,6 +243,7 @@ class ExportGPTSovitsHalf(torch.nn.Module):
self.sampling_rate: int = hps.data.sampling_rate
self.hop_length: int = hps.data.hop_length
self.win_length: int = hps.data.win_length
self.hann_window = torch.hann_window(self.win_length, device=device, dtype=torch.float32)
def forward(
self,
@@ -255,6 +256,7 @@ class ExportGPTSovitsHalf(torch.nn.Module):
top_k,
):
refer = spectrogram_torch(
self.hann_window,
ref_audio_32k,
self.filter_length,
self.sampling_rate,
@@ -321,6 +323,7 @@ class ExportGPTSovitsV4Half(torch.nn.Module):
self.sampling_rate: int = hps.data.sampling_rate
self.hop_length: int = hps.data.hop_length
self.win_length: int = hps.data.win_length
self.hann_window = torch.hann_window(self.win_length, device=device, dtype=torch.float32)
def forward(
self,
@@ -333,6 +336,7 @@ class ExportGPTSovitsV4Half(torch.nn.Module):
top_k,
):
refer = spectrogram_torch(
self.hann_window,
ref_audio_32k,
self.filter_length,
self.sampling_rate,
@@ -1149,7 +1153,7 @@ def export_2(version="v3"):
raw_t2s = raw_t2s.half().to(device)
t2s_m = T2SModel(raw_t2s).half().to(device)
t2s_m.eval()
t2s_m = torch.jit.script(t2s_m)
t2s_m = torch.jit.script(t2s_m).to(device)
t2s_m.eval()
# t2s_m.top_k = 15
logger.info("t2s_m ok")
@@ -1251,6 +1255,6 @@ def test_export_gpt_sovits_v3():
with torch.no_grad():
export_1("onnx/ad/ref.wav","你这老坏蛋,我找了你这么久,真没想到在这里找到你。他说。","v4")
# export_2("v4")
# export_1("onnx/ad/ref.wav","你这老坏蛋,我找了你这么久,真没想到在这里找到你。他说。","v4")
export_2("v4")
# test_export_gpt_sovits_v3()