Refactor: Format Code with Ruff and Update Deprecated G2PW Link (#2255)

* ruff check --fix

* ruff format --line-length 120 --target-version py39

* Change the link for G2PW Model

* update pytorch version and colab
This commit is contained in:
XXXXRT666
2025-04-07 09:42:47 +01:00
committed by GitHub
parent 9da7e17efe
commit 53cac93589
132 changed files with 8185 additions and 6648 deletions

View File

@@ -1,23 +1,22 @@
from module.models_onnx import SynthesizerTrn, symbols_v1, symbols_v2
from AR.models.t2s_lightning_module_onnx import Text2SemanticLightningModule
import torch
import torchaudio
from torch import nn
from AR.models.t2s_lightning_module_onnx import Text2SemanticLightningModule
from feature_extractor import cnhubert
from module.models_onnx import SynthesizerTrn, symbols_v1, symbols_v2
from torch import nn
cnhubert_base_path = "GPT_SoVITS/pretrained_models/chinese-hubert-base"
cnhubert.cnhubert_base_path = cnhubert_base_path
ssl_model = cnhubert.get_model()
from text import cleaned_text_to_sequence
import soundfile
from tools.my_utils import load_audio
import os
import json
import os
import soundfile
from text import cleaned_text_to_sequence
def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
hann_window = torch.hann_window(win_size).to(
dtype=y.dtype, device=y.device
)
hann_window = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)
y = torch.nn.functional.pad(
y.unsqueeze(1),
(int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
@@ -73,7 +72,7 @@ class T2SEncoder(nn.Module):
super().__init__()
self.encoder = t2s.onnx_encoder
self.vits = vits
def forward(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content):
codes = self.vits.extract_latent(ssl_content)
prompt_semantic = codes[0, 0]
@@ -102,22 +101,22 @@ class T2SModel(nn.Module):
self.onnx_encoder = T2SEncoder(self.t2s_model, self.vits_model)
self.first_stage_decoder = self.t2s_model.first_stage_decoder
self.stage_decoder = self.t2s_model.stage_decoder
#self.t2s_model = torch.jit.script(self.t2s_model)
# self.t2s_model = torch.jit.script(self.t2s_model)
def forward(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content):
early_stop_num = self.t2s_model.early_stop_num
#[1,N] [1,N] [N, 1024] [N, 1024] [1, 768, N]
# [1,N] [1,N] [N, 1024] [N, 1024] [1, 768, N]
x, prompts = self.onnx_encoder(ref_seq, text_seq, ref_bert, text_bert, ssl_content)
prefix_len = prompts.shape[1]
#[1,N,512] [1,N]
# [1,N,512] [1,N]
y, k, v, y_emb, x_example = self.first_stage_decoder(x, prompts)
stop = False
for idx in range(1, 1500):
#[1, N] [N_layer, N, 1, 512] [N_layer, N, 1, 512] [1, N, 512] [1] [1, N, 512] [1, N]
# [1, N] [N_layer, N, 1, 512] [N_layer, N, 1, 512] [1, N, 512] [1] [1, N, 512] [1, N]
enco = self.stage_decoder(y, k, v, y_emb, x_example)
y, k, v, y_emb, logits, samples = enco
if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
@@ -131,13 +130,11 @@ class T2SModel(nn.Module):
return y[:, -idx:].unsqueeze(0)
def export(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content, project_name, dynamo=False):
#self.onnx_encoder = torch.jit.script(self.onnx_encoder)
# self.onnx_encoder = torch.jit.script(self.onnx_encoder)
if dynamo:
export_options = torch.onnx.ExportOptions(dynamic_shapes=True)
onnx_encoder_export_output = torch.onnx.dynamo_export(
self.onnx_encoder,
(ref_seq, text_seq, ref_bert, text_bert, ssl_content),
export_options=export_options
self.onnx_encoder, (ref_seq, text_seq, ref_bert, text_bert, ssl_content), export_options=export_options
)
onnx_encoder_export_output.save(f"onnx/{project_name}/{project_name}_t2s_encoder.onnx")
return
@@ -149,13 +146,13 @@ class T2SModel(nn.Module):
input_names=["ref_seq", "text_seq", "ref_bert", "text_bert", "ssl_content"],
output_names=["x", "prompts"],
dynamic_axes={
"ref_seq": {1 : "ref_length"},
"text_seq": {1 : "text_length"},
"ref_bert": {0 : "ref_length"},
"text_bert": {0 : "text_length"},
"ssl_content": {2 : "ssl_length"},
"ref_seq": {1: "ref_length"},
"text_seq": {1: "text_length"},
"ref_bert": {0: "ref_length"},
"text_bert": {0: "text_length"},
"ssl_content": {2: "ssl_length"},
},
opset_version=16
opset_version=16,
)
x, prompts = self.onnx_encoder(ref_seq, text_seq, ref_bert, text_bert, ssl_content)
@@ -166,11 +163,11 @@ class T2SModel(nn.Module):
input_names=["x", "prompts"],
output_names=["y", "k", "v", "y_emb", "x_example"],
dynamic_axes={
"x": {1 : "x_length"},
"prompts": {1 : "prompts_length"},
"x": {1: "x_length"},
"prompts": {1: "prompts_length"},
},
verbose=False,
opset_version=16
opset_version=16,
)
y, k, v, y_emb, x_example = self.first_stage_decoder(x, prompts)
@@ -181,38 +178,38 @@ class T2SModel(nn.Module):
input_names=["iy", "ik", "iv", "iy_emb", "ix_example"],
output_names=["y", "k", "v", "y_emb", "logits", "samples"],
dynamic_axes={
"iy": {1 : "iy_length"},
"ik": {1 : "ik_length"},
"iv": {1 : "iv_length"},
"iy_emb": {1 : "iy_emb_length"},
"ix_example": {1 : "ix_example_length"},
"iy": {1: "iy_length"},
"ik": {1: "ik_length"},
"iv": {1: "iv_length"},
"iy_emb": {1: "iy_emb_length"},
"ix_example": {1: "ix_example_length"},
},
verbose=False,
opset_version=16
opset_version=16,
)
class VitsModel(nn.Module):
def __init__(self, vits_path):
super().__init__()
dict_s2 = torch.load(vits_path,map_location="cpu")
dict_s2 = torch.load(vits_path, map_location="cpu")
self.hps = dict_s2["config"]
if dict_s2['weight']['enc_p.text_embedding.weight'].shape[0] == 322:
if dict_s2["weight"]["enc_p.text_embedding.weight"].shape[0] == 322:
self.hps["model"]["version"] = "v1"
else:
self.hps["model"]["version"] = "v2"
self.hps = DictToAttrRecursive(self.hps)
self.hps.model.semantic_frame_rate = "25hz"
self.vq_model = SynthesizerTrn(
self.hps.data.filter_length // 2 + 1,
self.hps.train.segment_size // self.hps.data.hop_length,
n_speakers=self.hps.data.n_speakers,
**self.hps.model
**self.hps.model,
)
self.vq_model.eval()
self.vq_model.load_state_dict(dict_s2["weight"], strict=False)
def forward(self, text_seq, pred_semantic, ref_audio):
refer = spectrogram_torch(
ref_audio,
@@ -220,7 +217,7 @@ class VitsModel(nn.Module):
self.hps.data.sampling_rate,
self.hps.data.hop_length,
self.hps.data.win_length,
center=False
center=False,
)
return self.vq_model(pred_semantic, text_seq, refer)[0, 0]
@@ -230,18 +227,22 @@ class GptSoVits(nn.Module):
super().__init__()
self.vits = vits
self.t2s = t2s
def forward(self, ref_seq, text_seq, ref_bert, text_bert, ref_audio, ssl_content, debug=False):
pred_semantic = self.t2s(ref_seq, text_seq, ref_bert, text_bert, ssl_content)
audio = self.vits(text_seq, pred_semantic, ref_audio)
if debug:
import onnxruntime
sess = onnxruntime.InferenceSession("onnx/koharu/koharu_vits.onnx", providers=["CPU"])
audio1 = sess.run(None, {
"text_seq" : text_seq.detach().cpu().numpy(),
"pred_semantic" : pred_semantic.detach().cpu().numpy(),
"ref_audio" : ref_audio.detach().cpu().numpy()
})
audio1 = sess.run(
None,
{
"text_seq": text_seq.detach().cpu().numpy(),
"pred_semantic": pred_semantic.detach().cpu().numpy(),
"ref_audio": ref_audio.detach().cpu().numpy(),
},
)
return audio, audio1
return audio
@@ -255,12 +256,12 @@ class GptSoVits(nn.Module):
input_names=["text_seq", "pred_semantic", "ref_audio"],
output_names=["audio"],
dynamic_axes={
"text_seq": {1 : "text_length"},
"pred_semantic": {2 : "pred_length"},
"ref_audio": {1 : "audio_length"},
"text_seq": {1: "text_length"},
"pred_semantic": {2: "pred_length"},
"ref_audio": {1: "audio_length"},
},
opset_version=17,
verbose=False
verbose=False,
)
@@ -278,14 +279,67 @@ def export(vits_path, gpt_path, project_name, vits_model="v2"):
gpt = T2SModel(gpt_path, vits)
gpt_sovits = GptSoVits(vits, gpt)
ssl = SSLModel()
ref_seq = torch.LongTensor([cleaned_text_to_sequence(["n", "i2", "h", "ao3", ",", "w", "o3", "sh", "i4", "b", "ai2", "y", "e4"],version=vits_model)])
text_seq = torch.LongTensor([cleaned_text_to_sequence(["w", "o3", "sh", "i4", "b", "ai2", "y", "e4", "w", "o3", "sh", "i4", "b", "ai2", "y", "e4", "w", "o3", "sh", "i4", "b", "ai2", "y", "e4"],version=vits_model)])
ref_seq = torch.LongTensor(
[
cleaned_text_to_sequence(
[
"n",
"i2",
"h",
"ao3",
",",
"w",
"o3",
"sh",
"i4",
"b",
"ai2",
"y",
"e4",
],
version=vits_model,
)
]
)
text_seq = torch.LongTensor(
[
cleaned_text_to_sequence(
[
"w",
"o3",
"sh",
"i4",
"b",
"ai2",
"y",
"e4",
"w",
"o3",
"sh",
"i4",
"b",
"ai2",
"y",
"e4",
"w",
"o3",
"sh",
"i4",
"b",
"ai2",
"y",
"e4",
],
version=vits_model,
)
]
)
ref_bert = torch.randn((ref_seq.shape[1], 1024)).float()
text_bert = torch.randn((text_seq.shape[1], 1024)).float()
ref_audio = torch.randn((1, 48000 * 5)).float()
# ref_audio = torch.tensor([load_audio("rec.wav", 48000)]).float()
ref_audio_16k = torchaudio.functional.resample(ref_audio,48000,16000).float()
ref_audio_sr = torchaudio.functional.resample(ref_audio,48000,vits.hps.data.sampling_rate).float()
ref_audio_16k = torchaudio.functional.resample(ref_audio, 48000, 16000).float()
ref_audio_sr = torchaudio.functional.resample(ref_audio, 48000, vits.hps.data.sampling_rate).float()
try:
os.mkdir(f"onnx/{project_name}")
@@ -326,8 +380,8 @@ def export(vits_path, gpt_path, project_name, vits_model="v2"):
}
MoeVSConfJson = json.dumps(MoeVSConf)
with open(f"onnx/{project_name}.json", 'w') as MoeVsConfFile:
json.dump(MoeVSConf, MoeVsConfFile, indent = 4)
with open(f"onnx/{project_name}.json", "w") as MoeVsConfFile:
json.dump(MoeVSConf, MoeVsConfFile, indent=4)
if __name__ == "__main__":
@@ -341,4 +395,4 @@ if __name__ == "__main__":
exp_path = "nahida"
export(vits_path, gpt_path, exp_path)
# soundfile.write("out.wav", a, vits.hps.data.sampling_rate)
# soundfile.write("out.wav", a, vits.hps.data.sampling_rate)