Refactor: Format Code with Ruff and Update Deprecated G2PW Link (#2255)

* ruff check --fix

* ruff format --line-length 120 --target-version py39

* Change the link for G2PW Model

* update pytorch version and colab
This commit is contained in:
XXXXRT666
2025-04-07 09:42:47 +01:00
committed by GitHub
parent 9da7e17efe
commit 53cac93589
132 changed files with 8185 additions and 6648 deletions

View File

@@ -6,16 +6,16 @@ from export_torch_script import (
spectrogram_torch,
)
from f5_tts.model.backbones.dit import DiT
from feature_extractor import cnhubert
from inference_webui import get_phones_and_bert
import librosa
from module import commons
from module.mel_processing import mel_spectrogram_torch, spectral_normalize_torch
from module.mel_processing import mel_spectrogram_torch
from module.models_onnx import CFM, SynthesizerTrnV3
import numpy as np
import torch._dynamo.config
import torchaudio
import logging, uvicorn
import logging
import uvicorn
import torch
import soundfile
from librosa.filters import mel as librosa_mel_fn
@@ -32,7 +32,6 @@ now_dir = os.getcwd()
class MelSpectrgram(torch.nn.Module):
def __init__(
self,
dtype,
@@ -48,14 +47,12 @@ class MelSpectrgram(torch.nn.Module):
):
super().__init__()
self.hann_window = torch.hann_window(1024).to(device=device, dtype=dtype)
mel = librosa_mel_fn(
sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
)
mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
self.mel_basis = torch.from_numpy(mel).to(dtype=dtype, device=device)
self.n_fft:int = n_fft
self.hop_size:int = hop_size
self.win_size:int = win_size
self.center:bool = center
self.n_fft: int = n_fft
self.hop_size: int = hop_size
self.win_size: int = win_size
self.center: bool = center
def forward(self, y):
y = torch.nn.functional.pad(
@@ -172,9 +169,7 @@ class ExportCFM(torch.nn.Module):
):
T_min = fea_ref.size(2)
fea = torch.cat([fea_ref, fea_todo_chunk], 2).transpose(2, 1)
cfm_res = self.cfm(
fea, torch.LongTensor([fea.size(1)]).to(fea.device), mel2, sample_steps
)
cfm_res = self.cfm(fea, torch.LongTensor([fea.size(1)]).to(fea.device), mel2, sample_steps)
cfm_res = cfm_res[:, :, mel2.shape[2] :]
mel2 = cfm_res[:, :, -T_min:]
fea_ref = fea_todo_chunk[:, :, -T_min:]
@@ -198,6 +193,7 @@ mel_fn = lambda x: mel_spectrogram_torch(
spec_min = -12
spec_max = 2
@torch.jit.script
def norm_spec(x):
spec_min = -12
@@ -212,7 +208,6 @@ def denorm_spec(x):
class ExportGPTSovitsHalf(torch.nn.Module):
def __init__(self, hps, t2s_m: T2SModel, vq_model: SynthesizerTrnV3):
super().__init__()
self.hps = hps
@@ -231,15 +226,15 @@ class ExportGPTSovitsHalf(torch.nn.Module):
center=False,
)
# self.dtype = dtype
self.filter_length:int = hps.data.filter_length
self.sampling_rate:int = hps.data.sampling_rate
self.hop_length:int = hps.data.hop_length
self.win_length:int = hps.data.win_length
self.filter_length: int = hps.data.filter_length
self.sampling_rate: int = hps.data.sampling_rate
self.hop_length: int = hps.data.hop_length
self.win_length: int = hps.data.win_length
def forward(
self,
ssl_content,
ref_audio_32k:torch.FloatTensor,
ref_audio_32k: torch.FloatTensor,
phoneme_ids0,
phoneme_ids1,
bert1,
@@ -255,21 +250,17 @@ class ExportGPTSovitsHalf(torch.nn.Module):
center=False,
).to(ssl_content.dtype)
codes = self.vq_model.extract_latent(ssl_content)
prompt_semantic = codes[0, 0]
prompt = prompt_semantic.unsqueeze(0)
# print('extract_latent',codes.shape,datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
pred_semantic = self.t2s_m(
prompt, phoneme_ids0, phoneme_ids1, bert1, bert2, top_k
)
pred_semantic = self.t2s_m(prompt, phoneme_ids0, phoneme_ids1, bert1, bert2, top_k)
# print('t2s_m',pred_semantic.shape,datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
ge = self.vq_model.create_ge(refer)
# print('create_ge',datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
prompt_ = prompt.unsqueeze(0)
fea_ref = self.vq_model(prompt_, phoneme_ids0, ge)
# print('fea_ref',datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
@@ -293,6 +284,7 @@ class ExportGPTSovitsHalf(torch.nn.Module):
return fea_ref, fea_todo, mel2
class GPTSoVITSV3(torch.nn.Module):
def __init__(self, gpt_sovits_half, cfm, bigvgan):
super().__init__()
@@ -303,9 +295,9 @@ class GPTSoVITSV3(torch.nn.Module):
def forward(
self,
ssl_content,
ref_audio_32k:torch.FloatTensor,
phoneme_ids0:torch.LongTensor,
phoneme_ids1:torch.LongTensor,
ref_audio_32k: torch.FloatTensor,
phoneme_ids0: torch.LongTensor,
phoneme_ids1: torch.LongTensor,
bert1,
bert2,
top_k: torch.LongTensor,
@@ -313,7 +305,9 @@ class GPTSoVITSV3(torch.nn.Module):
):
# current_time = datetime.now()
# print("gpt_sovits_half",current_time.strftime("%Y-%m-%d %H:%M:%S"))
fea_ref, fea_todo, mel2 = self.gpt_sovits_half(ssl_content, ref_audio_32k, phoneme_ids0, phoneme_ids1, bert1, bert2, top_k)
fea_ref, fea_todo, mel2 = self.gpt_sovits_half(
ssl_content, ref_audio_32k, phoneme_ids0, phoneme_ids1, bert1, bert2, top_k
)
chunk_len = 934 - fea_ref.shape[2]
wav_gen_list = []
idx = 0
@@ -331,7 +325,13 @@ class GPTSoVITSV3(torch.nn.Module):
# 经过 bigvgan 之后音频长度就是 fea_todo.shape[2] * 256
complete_len = chunk_len - fea_todo_chunk.shape[-1]
if complete_len != 0:
fea_todo_chunk = torch.cat([fea_todo_chunk, torch.zeros(1, 512, complete_len).to(fea_todo_chunk.device).to(fea_todo_chunk.dtype)], 2)
fea_todo_chunk = torch.cat(
[
fea_todo_chunk,
torch.zeros(1, 512, complete_len).to(fea_todo_chunk.device).to(fea_todo_chunk.dtype),
],
2,
)
cfm_res, fea_ref, mel2 = self.cfm(fea_ref, fea_todo_chunk, mel2, sample_steps)
idx += chunk_len
@@ -339,17 +339,17 @@ class GPTSoVITSV3(torch.nn.Module):
cfm_res = denorm_spec(cfm_res)
bigvgan_res = self.bigvgan(cfm_res)
wav_gen_list.append(bigvgan_res)
wav_gen = torch.cat(wav_gen_list, 2)
return wav_gen[0][0][:wav_gen_length]
def init_bigvgan():
global bigvgan_model
from BigVGAN import bigvgan
bigvgan_model = bigvgan.BigVGAN.from_pretrained(
"%s/GPT_SoVITS/pretrained_models/models--nvidia--bigvgan_v2_24khz_100band_256x"
% (now_dir,),
"%s/GPT_SoVITS/pretrained_models/models--nvidia--bigvgan_v2_24khz_100band_256x" % (now_dir,),
use_cuda_kernel=False,
) # if True, RuntimeError: Ninja is required to load C++ extensions
# remove weight norm in the model and set to eval mode
@@ -467,10 +467,7 @@ def export_cfm(
cfm = e_cfm.cfm
B, T = mu.size(0), mu.size(1)
x = (
torch.randn([B, cfm.in_channels, T], device=mu.device, dtype=mu.dtype)
* temperature
)
x = torch.randn([B, cfm.in_channels, T], device=mu.device, dtype=mu.dtype) * temperature
print("x:", x.shape, x.dtype)
prompt_len = prompt.size(-1)
prompt_x = torch.zeros_like(x, dtype=mu.dtype)
@@ -565,11 +562,7 @@ def export():
wav16k = wav16k.to(device)
zero_wav_torch = zero_wav_torch.to(device)
wav16k = torch.cat([wav16k, zero_wav_torch])
ssl_content = ssl_model.model(wav16k.unsqueeze(0))[
"last_hidden_state"
].transpose(
1, 2
) # .float()
ssl_content = ssl_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2) # .float()
codes = sovits.vq_model.extract_latent(ssl_content)
prompt_semantic = codes[0, 0]
prompt = prompt_semantic.unsqueeze(0).to(device)
@@ -626,10 +619,7 @@ def export():
"create_ge": refer,
}
trace_vq_model = torch.jit.trace_module(
sovits.vq_model, inputs, optimize=True
)
trace_vq_model = torch.jit.trace_module(sovits.vq_model, inputs, optimize=True)
trace_vq_model.save("onnx/ad/vq_model.pt")
print(fea_ref.shape, fea_ref.dtype, ge.shape)
@@ -714,9 +704,7 @@ def export():
idx += chunk_len
cfm_res, fea_ref, mel2 = export_cfm_(
fea_ref, fea_todo_chunk, mel2, sample_steps
)
cfm_res, fea_ref, mel2 = export_cfm_(fea_ref, fea_todo_chunk, mel2, sample_steps)
cfm_resss.append(cfm_res)
continue
@@ -726,9 +714,7 @@ def export():
with torch.inference_mode():
cmf_res_rand = torch.randn(1, 100, 934).to(device).to(dtype)
torch._dynamo.mark_dynamic(cmf_res_rand, 2)
bigvgan_model_ = torch.jit.trace(
bigvgan_model, optimize=True, example_inputs=(cmf_res_rand,)
)
bigvgan_model_ = torch.jit.trace(bigvgan_model, optimize=True, example_inputs=(cmf_res_rand,))
bigvgan_model_.save("onnx/ad/bigvgan_model.pt")
wav_gen = bigvgan_model(cmf_res)
print("wav_gen:", wav_gen.shape, wav_gen.dtype)
@@ -748,7 +734,6 @@ def test_export(
bigvgan,
output,
):
# hps = sovits.hps
ref_wav_path = "onnx/ad/ref.wav"
speed = 1.0
@@ -773,13 +758,9 @@ def test_export(
wav16k = wav16k.to(device)
zero_wav_torch = zero_wav_torch.to(device)
wav16k = torch.cat([wav16k, zero_wav_torch])
ssl_content = ssl_model.model(wav16k.unsqueeze(0))[
"last_hidden_state"
].transpose(
1, 2
) # .float()
ssl_content = ssl_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2) # .float()
ref_audio_32k,_ = librosa.load(ref_wav_path, sr=32000)
ref_audio_32k, _ = librosa.load(ref_wav_path, sr=32000)
ref_audio_32k = torch.from_numpy(ref_audio_32k).unsqueeze(0).to(device).float()
phones1, bert1, norm_text1 = get_phones_and_bert(
@@ -799,8 +780,18 @@ def test_export(
current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
logger.info("start inference %s", current_time)
print(ssl_content.shape, ref_audio_32k.shape, phoneme_ids0.shape, phoneme_ids1.shape, bert1.shape, bert2.shape, top_k.shape)
fea_ref, fea_todo, mel2 = gpt_sovits_v3_half(ssl_content, ref_audio_32k, phoneme_ids0, phoneme_ids1, bert1, bert2, top_k)
print(
ssl_content.shape,
ref_audio_32k.shape,
phoneme_ids0.shape,
phoneme_ids1.shape,
bert1.shape,
bert2.shape,
top_k.shape,
)
fea_ref, fea_todo, mel2 = gpt_sovits_v3_half(
ssl_content, ref_audio_32k, phoneme_ids0, phoneme_ids1, bert1, bert2, top_k
)
chunk_len = 934 - fea_ref.shape[2]
print(fea_ref.shape, fea_todo.shape, mel2.shape)
@@ -812,7 +803,6 @@ def test_export(
wav_gen_length = fea_todo.shape[2] * 256
while 1:
current_time = datetime.now()
print("idx:", idx, current_time.strftime("%Y-%m-%d %H:%M:%S"))
fea_todo_chunk = fea_todo[:, :, idx : idx + chunk_len]
@@ -861,7 +851,6 @@ def test_export1(
gpt_sovits_v3,
output,
):
# hps = sovits.hps
ref_wav_path = "onnx/ad/ref.wav"
speed = 1.0
@@ -886,14 +875,10 @@ def test_export1(
wav16k = wav16k.to(device)
zero_wav_torch = zero_wav_torch.to(device)
wav16k = torch.cat([wav16k, zero_wav_torch])
ssl_content = ssl_model.model(wav16k.unsqueeze(0))[
"last_hidden_state"
].transpose(
1, 2
) # .float()
ssl_content = ssl_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2) # .float()
print("ssl_content:", ssl_content.shape, ssl_content.dtype)
ref_audio_32k,_ = librosa.load(ref_wav_path, sr=32000)
ref_audio_32k, _ = librosa.load(ref_wav_path, sr=32000)
ref_audio_32k = torch.from_numpy(ref_audio_32k).unsqueeze(0).to(device).float()
phones1, bert1, norm_text1 = get_phones_and_bert(
@@ -913,11 +898,19 @@ def test_export1(
current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
logger.info("start inference %s", current_time)
print(ssl_content.shape, ref_audio_32k.shape, phoneme_ids0.shape, phoneme_ids1.shape, bert1.shape, bert2.shape, top_k.shape)
print(
ssl_content.shape,
ref_audio_32k.shape,
phoneme_ids0.shape,
phoneme_ids1.shape,
bert1.shape,
bert2.shape,
top_k.shape,
)
wav_gen = gpt_sovits_v3(ssl_content, ref_audio_32k, phoneme_ids0, phoneme_ids1, bert1, bert2, top_k, sample_steps)
print("wav_gen:", wav_gen.shape, wav_gen.dtype)
wav_gen = torch.cat([wav_gen,zero_wav_torch],0)
wav_gen = torch.cat([wav_gen, zero_wav_torch], 0)
audio = wav_gen.cpu().detach().numpy()
logger.info("end bigvgan %s", datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
@@ -929,20 +922,19 @@ import time
def test_():
sovits = get_sovits_weights("GPT_SoVITS/pretrained_models/s2Gv3.pth")
# cfm = ExportCFM(sovits.cfm)
# cfm.cfm.estimator = dit
sovits.cfm = None
cfm = torch.jit.load("onnx/ad/cfm.pt", map_location=device)
# cfm = torch.jit.optimize_for_inference(cfm)
cfm = cfm.half().to(device)
cfm.eval()
logger.info(f"cfm ok")
logger.info("cfm ok")
dict_s1 = torch.load("GPT_SoVITS/pretrained_models/s1v3.ckpt")
# v2 的 gpt 也可以用
@@ -957,17 +949,14 @@ def test_():
t2s_m = torch.jit.script(t2s_m)
t2s_m.eval()
# t2s_m.top_k = 15
logger.info(f"t2s_m ok")
logger.info("t2s_m ok")
vq_model: torch.jit.ScriptModule = torch.jit.load(
"onnx/ad/vq_model.pt", map_location=device
)
vq_model: torch.jit.ScriptModule = torch.jit.load("onnx/ad/vq_model.pt", map_location=device)
# vq_model = torch.jit.optimize_for_inference(vq_model)
# vq_model = vq_model.half().to(device)
vq_model.eval()
# vq_model = sovits.vq_model
logger.info(f"vq_model ok")
logger.info("vq_model ok")
# gpt_sovits_v3_half = torch.jit.load("onnx/ad/gpt_sovits_v3_half.pt")
# gpt_sovits_v3_half = torch.jit.optimize_for_inference(gpt_sovits_v3_half)
@@ -975,7 +964,7 @@ def test_():
# gpt_sovits_v3_half = gpt_sovits_v3_half.cuda()
# gpt_sovits_v3_half.eval()
gpt_sovits_v3_half = ExportGPTSovitsHalf(sovits.hps, t2s_m, vq_model)
logger.info(f"gpt_sovits_v3_half ok")
logger.info("gpt_sovits_v3_half ok")
# init_bigvgan()
# global bigvgan_model
@@ -985,7 +974,7 @@ def test_():
bigvgan_model = bigvgan_model.cuda()
bigvgan_model.eval()
logger.info(f"bigvgan ok")
logger.info("bigvgan ok")
gpt_sovits_v3 = GPTSoVITSV3(gpt_sovits_v3_half, cfm, bigvgan_model)
gpt_sovits_v3 = torch.jit.script(gpt_sovits_v3)
@@ -1020,8 +1009,9 @@ def test_():
# "out2.wav",
# )
def test_export_gpt_sovits_v3():
gpt_sovits_v3 = torch.jit.load("onnx/ad/gpt_sovits_v3.pt",map_location=device)
gpt_sovits_v3 = torch.jit.load("onnx/ad/gpt_sovits_v3.pt", map_location=device)
# test_export1(
# "汗流浃背了呀!老弟~ My uncle has two dogs. One is big and the other is small. He likes them very much. He often plays with them. He takes them for a walk every day. He says they are his good friends. He is very happy with them. 最后还是我得了 MVP....",
# gpt_sovits_v3,