Fix onnx_export to support v2 (#1604)

This commit is contained in:
zzz
2024-09-13 11:27:22 +08:00
committed by GitHub
parent 570da092c9
commit 0c000191b3
2 changed files with 58 additions and 31 deletions

View File

@@ -1,11 +1,12 @@
from module.models_onnx import SynthesizerTrn, symbols
from module.models_onnx import SynthesizerTrn, symbols_v1, symbols_v2
from AR.models.t2s_lightning_module_onnx import Text2SemanticLightningModule
import torch
import torchaudio
from torch import nn
from feature_extractor import cnhubert
cnhubert_base_path = "pretrained_models/chinese-hubert-base"
cnhubert.cnhubert_base_path=cnhubert_base_path
cnhubert_base_path = "GPT_SoVITS/pretrained_models/chinese-hubert-base"
cnhubert.cnhubert_base_path = cnhubert_base_path
ssl_model = cnhubert.get_model()
from text import cleaned_text_to_sequence
import soundfile
@@ -196,6 +197,11 @@ class VitsModel(nn.Module):
super().__init__()
dict_s2 = torch.load(vits_path,map_location="cpu")
self.hps = dict_s2["config"]
if dict_s2['weight']['enc_p.text_embedding.weight'].shape[0] == 322:
self.hps["model"]["version"] = "v1"
else:
self.hps["model"]["version"] = "v2"
self.hps = DictToAttrRecursive(self.hps)
self.hps.model.semantic_frame_rate = "25hz"
self.vq_model = SynthesizerTrn(
@@ -267,13 +273,13 @@ class SSLModel(nn.Module):
return self.ssl.model(ref_audio_16k)["last_hidden_state"].transpose(1, 2)
def export(vits_path, gpt_path, project_name):
def export(vits_path, gpt_path, project_name, vits_model="v2"):
vits = VitsModel(vits_path)
gpt = T2SModel(gpt_path, vits)
gpt_sovits = GptSoVits(vits, gpt)
ssl = SSLModel()
ref_seq = torch.LongTensor([cleaned_text_to_sequence(["n", "i2", "h", "ao3", ",", "w", "o3", "sh", "i4", "b", "ai2", "y", "e4"])])
text_seq = torch.LongTensor([cleaned_text_to_sequence(["w", "o3", "sh", "i4", "b", "ai2", "y", "e4", "w", "o3", "sh", "i4", "b", "ai2", "y", "e4", "w", "o3", "sh", "i4", "b", "ai2", "y", "e4"])])
ref_seq = torch.LongTensor([cleaned_text_to_sequence(["n", "i2", "h", "ao3", ",", "w", "o3", "sh", "i4", "b", "ai2", "y", "e4"],version=vits_model)])
text_seq = torch.LongTensor([cleaned_text_to_sequence(["w", "o3", "sh", "i4", "b", "ai2", "y", "e4", "w", "o3", "sh", "i4", "b", "ai2", "y", "e4", "w", "o3", "sh", "i4", "b", "ai2", "y", "e4"],version=vits_model)])
ref_bert = torch.randn((ref_seq.shape[1], 1024)).float()
text_bert = torch.randn((text_seq.shape[1], 1024)).float()
ref_audio = torch.randn((1, 48000 * 5)).float()
@@ -287,34 +293,38 @@ def export(vits_path, gpt_path, project_name):
pass
ssl_content = ssl(ref_audio_16k).float()
debug = False
# debug = False
debug = True
# gpt_sovits.export(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content, project_name)
if debug:
a, b = gpt_sovits(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content, debug=debug)
soundfile.write("out1.wav", a.cpu().detach().numpy(), vits.hps.data.sampling_rate)
soundfile.write("out2.wav", b[0], vits.hps.data.sampling_rate)
return
a = gpt_sovits(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content).detach().cpu().numpy()
else:
a = gpt_sovits(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content).detach().cpu().numpy()
soundfile.write("out.wav", a, vits.hps.data.sampling_rate)
soundfile.write("out.wav", a, vits.hps.data.sampling_rate)
gpt_sovits.export(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content, project_name)
if vits_model == "v1":
symbols = symbols_v1
else:
symbols = symbols_v2
MoeVSConf = {
"Folder" : f"{project_name}",
"Name" : f"{project_name}",
"Type" : "GPT-SoVits",
"Rate" : vits.hps.data.sampling_rate,
"NumLayers": gpt.t2s_model.num_layers,
"EmbeddingDim": gpt.t2s_model.embedding_dim,
"Dict": "BasicDict",
"BertPath": "chinese-roberta-wwm-ext-large",
"Symbol": symbols,
"AddBlank": False
}
"Folder": f"{project_name}",
"Name": f"{project_name}",
"Type": "GPT-SoVits",
"Rate": vits.hps.data.sampling_rate,
"NumLayers": gpt.t2s_model.num_layers,
"EmbeddingDim": gpt.t2s_model.embedding_dim,
"Dict": "BasicDict",
"BertPath": "chinese-roberta-wwm-ext-large",
# "Symbol": symbols,
"AddBlank": False,
}
MoeVSConfJson = json.dumps(MoeVSConf)
with open(f"onnx/{project_name}.json", 'w') as MoeVsConfFile:
json.dump(MoeVSConf, MoeVsConfFile, indent = 4)