Refactor: Format Code with Ruff and Update Deprecated G2PW Link (#2255)
* ruff check --fix
* ruff format --line-length 120 --target-version py39
* Change the link for the G2PW model
* Update the PyTorch version and the Colab notebook
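Most of the diff below is mechanical: `ruff check --fix` removes unused imports, splits one-line multiple imports, and rewrites f-strings that have no placeholders, while `ruff format --line-length 120` re-wraps calls that the previous 88-column style had split across lines. A minimal before/after sketch of the formatting pass, using a call taken from the diff:

# Wrapped at the old 88-column default:
pred_semantic = self.t2s_m(
    prompt, phoneme_ids0, phoneme_ids1, bert1, bert2, top_k
)

# After `ruff format --line-length 120 --target-version py39`, the call fits on one line:
pred_semantic = self.t2s_m(prompt, phoneme_ids0, phoneme_ids1, bert1, bert2, top_k)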
@@ -6,16 +6,16 @@ from export_torch_script import (
     spectrogram_torch,
 )
 from f5_tts.model.backbones.dit import DiT
 from feature_extractor import cnhubert
 from inference_webui import get_phones_and_bert
 import librosa
 from module import commons
-from module.mel_processing import mel_spectrogram_torch, spectral_normalize_torch
+from module.mel_processing import mel_spectrogram_torch
 from module.models_onnx import CFM, SynthesizerTrnV3
 import numpy as np
 import torch._dynamo.config
 import torchaudio
-import logging, uvicorn
+import logging
+import uvicorn
 import torch
 import soundfile
 from librosa.filters import mel as librosa_mel_fn
@@ -32,7 +32,6 @@ now_dir = os.getcwd()
 
 
 class MelSpectrgram(torch.nn.Module):
-
     def __init__(
         self,
         dtype,
@@ -48,14 +47,12 @@ class MelSpectrgram(torch.nn.Module):
     ):
         super().__init__()
         self.hann_window = torch.hann_window(1024).to(device=device, dtype=dtype)
-        mel = librosa_mel_fn(
-            sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
-        )
+        mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
         self.mel_basis = torch.from_numpy(mel).to(dtype=dtype, device=device)
-        self.n_fft:int = n_fft
-        self.hop_size:int = hop_size
-        self.win_size:int = win_size
-        self.center:bool = center
+        self.n_fft: int = n_fft
+        self.hop_size: int = hop_size
+        self.win_size: int = win_size
+        self.center: bool = center
 
     def forward(self, y):
         y = torch.nn.functional.pad(
@@ -172,9 +169,7 @@ class ExportCFM(torch.nn.Module):
     ):
         T_min = fea_ref.size(2)
         fea = torch.cat([fea_ref, fea_todo_chunk], 2).transpose(2, 1)
-        cfm_res = self.cfm(
-            fea, torch.LongTensor([fea.size(1)]).to(fea.device), mel2, sample_steps
-        )
+        cfm_res = self.cfm(fea, torch.LongTensor([fea.size(1)]).to(fea.device), mel2, sample_steps)
         cfm_res = cfm_res[:, :, mel2.shape[2] :]
         mel2 = cfm_res[:, :, -T_min:]
         fea_ref = fea_todo_chunk[:, :, -T_min:]
@@ -198,6 +193,7 @@ mel_fn = lambda x: mel_spectrogram_torch(
 spec_min = -12
 spec_max = 2
 
+
 @torch.jit.script
 def norm_spec(x):
     spec_min = -12
@@ -212,7 +208,6 @@ def denorm_spec(x):
 
 
 class ExportGPTSovitsHalf(torch.nn.Module):
-
     def __init__(self, hps, t2s_m: T2SModel, vq_model: SynthesizerTrnV3):
         super().__init__()
         self.hps = hps
@@ -231,15 +226,15 @@ class ExportGPTSovitsHalf(torch.nn.Module):
             center=False,
         )
         # self.dtype = dtype
-        self.filter_length:int = hps.data.filter_length
-        self.sampling_rate:int = hps.data.sampling_rate
-        self.hop_length:int = hps.data.hop_length
-        self.win_length:int = hps.data.win_length
+        self.filter_length: int = hps.data.filter_length
+        self.sampling_rate: int = hps.data.sampling_rate
+        self.hop_length: int = hps.data.hop_length
+        self.win_length: int = hps.data.win_length
 
     def forward(
         self,
         ssl_content,
-        ref_audio_32k:torch.FloatTensor,
+        ref_audio_32k: torch.FloatTensor,
         phoneme_ids0,
         phoneme_ids1,
         bert1,
@@ -255,21 +250,17 @@ class ExportGPTSovitsHalf(torch.nn.Module):
             center=False,
         ).to(ssl_content.dtype)
 
-
         codes = self.vq_model.extract_latent(ssl_content)
         prompt_semantic = codes[0, 0]
         prompt = prompt_semantic.unsqueeze(0)
         # print('extract_latent',codes.shape,datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
 
-        pred_semantic = self.t2s_m(
-            prompt, phoneme_ids0, phoneme_ids1, bert1, bert2, top_k
-        )
+        pred_semantic = self.t2s_m(prompt, phoneme_ids0, phoneme_ids1, bert1, bert2, top_k)
         # print('t2s_m',pred_semantic.shape,datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
 
-
         ge = self.vq_model.create_ge(refer)
         # print('create_ge',datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
 
-
         prompt_ = prompt.unsqueeze(0)
         fea_ref = self.vq_model(prompt_, phoneme_ids0, ge)
         # print('fea_ref',datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
@@ -293,6 +284,7 @@ class ExportGPTSovitsHalf(torch.nn.Module):
 
         return fea_ref, fea_todo, mel2
 
+
 class GPTSoVITSV3(torch.nn.Module):
     def __init__(self, gpt_sovits_half, cfm, bigvgan):
         super().__init__()
@@ -303,9 +295,9 @@ class GPTSoVITSV3(torch.nn.Module):
     def forward(
         self,
         ssl_content,
-        ref_audio_32k:torch.FloatTensor,
-        phoneme_ids0:torch.LongTensor,
-        phoneme_ids1:torch.LongTensor,
+        ref_audio_32k: torch.FloatTensor,
+        phoneme_ids0: torch.LongTensor,
+        phoneme_ids1: torch.LongTensor,
         bert1,
         bert2,
         top_k: torch.LongTensor,
@@ -313,7 +305,9 @@ class GPTSoVITSV3(torch.nn.Module):
     ):
         # current_time = datetime.now()
         # print("gpt_sovits_half",current_time.strftime("%Y-%m-%d %H:%M:%S"))
-        fea_ref, fea_todo, mel2 = self.gpt_sovits_half(ssl_content, ref_audio_32k, phoneme_ids0, phoneme_ids1, bert1, bert2, top_k)
+        fea_ref, fea_todo, mel2 = self.gpt_sovits_half(
+            ssl_content, ref_audio_32k, phoneme_ids0, phoneme_ids1, bert1, bert2, top_k
+        )
         chunk_len = 934 - fea_ref.shape[2]
         wav_gen_list = []
         idx = 0
@@ -331,7 +325,13 @@ class GPTSoVITSV3(torch.nn.Module):
             # after bigvgan, the audio length becomes fea_todo.shape[2] * 256
             complete_len = chunk_len - fea_todo_chunk.shape[-1]
             if complete_len != 0:
-                fea_todo_chunk = torch.cat([fea_todo_chunk, torch.zeros(1, 512, complete_len).to(fea_todo_chunk.device).to(fea_todo_chunk.dtype)], 2)
+                fea_todo_chunk = torch.cat(
+                    [
+                        fea_todo_chunk,
+                        torch.zeros(1, 512, complete_len).to(fea_todo_chunk.device).to(fea_todo_chunk.dtype),
+                    ],
+                    2,
+                )
 
             cfm_res, fea_ref, mel2 = self.cfm(fea_ref, fea_todo_chunk, mel2, sample_steps)
             idx += chunk_len
@@ -339,17 +339,17 @@ class GPTSoVITSV3(torch.nn.Module):
             cfm_res = denorm_spec(cfm_res)
             bigvgan_res = self.bigvgan(cfm_res)
             wav_gen_list.append(bigvgan_res)
 
-
         wav_gen = torch.cat(wav_gen_list, 2)
         return wav_gen[0][0][:wav_gen_length]
 
 
 def init_bigvgan():
     global bigvgan_model
     from BigVGAN import bigvgan
 
     bigvgan_model = bigvgan.BigVGAN.from_pretrained(
-        "%s/GPT_SoVITS/pretrained_models/models--nvidia--bigvgan_v2_24khz_100band_256x"
-        % (now_dir,),
+        "%s/GPT_SoVITS/pretrained_models/models--nvidia--bigvgan_v2_24khz_100band_256x" % (now_dir,),
         use_cuda_kernel=False,
     )  # if True, RuntimeError: Ninja is required to load C++ extensions
     # remove weight norm in the model and set to eval mode
@@ -467,10 +467,7 @@ def export_cfm(
     cfm = e_cfm.cfm
 
     B, T = mu.size(0), mu.size(1)
-    x = (
-        torch.randn([B, cfm.in_channels, T], device=mu.device, dtype=mu.dtype)
-        * temperature
-    )
+    x = torch.randn([B, cfm.in_channels, T], device=mu.device, dtype=mu.dtype) * temperature
     print("x:", x.shape, x.dtype)
     prompt_len = prompt.size(-1)
     prompt_x = torch.zeros_like(x, dtype=mu.dtype)
@@ -565,11 +562,7 @@ def export():
     wav16k = wav16k.to(device)
     zero_wav_torch = zero_wav_torch.to(device)
     wav16k = torch.cat([wav16k, zero_wav_torch])
-    ssl_content = ssl_model.model(wav16k.unsqueeze(0))[
-        "last_hidden_state"
-    ].transpose(
-        1, 2
-    )  # .float()
+    ssl_content = ssl_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2)  # .float()
     codes = sovits.vq_model.extract_latent(ssl_content)
     prompt_semantic = codes[0, 0]
     prompt = prompt_semantic.unsqueeze(0).to(device)
@@ -626,10 +619,7 @@ def export():
         "create_ge": refer,
     }
 
-
-    trace_vq_model = torch.jit.trace_module(
-        sovits.vq_model, inputs, optimize=True
-    )
+    trace_vq_model = torch.jit.trace_module(sovits.vq_model, inputs, optimize=True)
     trace_vq_model.save("onnx/ad/vq_model.pt")
 
     print(fea_ref.shape, fea_ref.dtype, ge.shape)
@@ -714,9 +704,7 @@ def export():
 
             idx += chunk_len
 
-            cfm_res, fea_ref, mel2 = export_cfm_(
-                fea_ref, fea_todo_chunk, mel2, sample_steps
-            )
+            cfm_res, fea_ref, mel2 = export_cfm_(fea_ref, fea_todo_chunk, mel2, sample_steps)
             cfm_resss.append(cfm_res)
             continue
 
@@ -726,9 +714,7 @@ def export():
     with torch.inference_mode():
         cmf_res_rand = torch.randn(1, 100, 934).to(device).to(dtype)
         torch._dynamo.mark_dynamic(cmf_res_rand, 2)
-        bigvgan_model_ = torch.jit.trace(
-            bigvgan_model, optimize=True, example_inputs=(cmf_res_rand,)
-        )
+        bigvgan_model_ = torch.jit.trace(bigvgan_model, optimize=True, example_inputs=(cmf_res_rand,))
         bigvgan_model_.save("onnx/ad/bigvgan_model.pt")
         wav_gen = bigvgan_model(cmf_res)
         print("wav_gen:", wav_gen.shape, wav_gen.dtype)
@@ -748,7 +734,6 @@ def test_export(
     bigvgan,
     output,
 ):
-
     # hps = sovits.hps
     ref_wav_path = "onnx/ad/ref.wav"
     speed = 1.0
@@ -773,13 +758,9 @@ def test_export(
     wav16k = wav16k.to(device)
     zero_wav_torch = zero_wav_torch.to(device)
    wav16k = torch.cat([wav16k, zero_wav_torch])
-    ssl_content = ssl_model.model(wav16k.unsqueeze(0))[
-        "last_hidden_state"
-    ].transpose(
-        1, 2
-    )  # .float()
+    ssl_content = ssl_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2)  # .float()
 
-    ref_audio_32k,_ = librosa.load(ref_wav_path, sr=32000)
+    ref_audio_32k, _ = librosa.load(ref_wav_path, sr=32000)
     ref_audio_32k = torch.from_numpy(ref_audio_32k).unsqueeze(0).to(device).float()
 
     phones1, bert1, norm_text1 = get_phones_and_bert(
@@ -799,8 +780,18 @@ def test_export(
 
     current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
     logger.info("start inference %s", current_time)
-    print(ssl_content.shape, ref_audio_32k.shape, phoneme_ids0.shape, phoneme_ids1.shape, bert1.shape, bert2.shape, top_k.shape)
-    fea_ref, fea_todo, mel2 = gpt_sovits_v3_half(ssl_content, ref_audio_32k, phoneme_ids0, phoneme_ids1, bert1, bert2, top_k)
+    print(
+        ssl_content.shape,
+        ref_audio_32k.shape,
+        phoneme_ids0.shape,
+        phoneme_ids1.shape,
+        bert1.shape,
+        bert2.shape,
+        top_k.shape,
+    )
+    fea_ref, fea_todo, mel2 = gpt_sovits_v3_half(
+        ssl_content, ref_audio_32k, phoneme_ids0, phoneme_ids1, bert1, bert2, top_k
+    )
     chunk_len = 934 - fea_ref.shape[2]
     print(fea_ref.shape, fea_todo.shape, mel2.shape)
 
@@ -812,7 +803,6 @@ def test_export(
     wav_gen_length = fea_todo.shape[2] * 256
 
     while 1:
-
         current_time = datetime.now()
         print("idx:", idx, current_time.strftime("%Y-%m-%d %H:%M:%S"))
         fea_todo_chunk = fea_todo[:, :, idx : idx + chunk_len]
@@ -861,7 +851,6 @@ def test_export1(
     gpt_sovits_v3,
     output,
 ):
-
     # hps = sovits.hps
     ref_wav_path = "onnx/ad/ref.wav"
     speed = 1.0
@@ -886,14 +875,10 @@ def test_export1(
     wav16k = wav16k.to(device)
     zero_wav_torch = zero_wav_torch.to(device)
     wav16k = torch.cat([wav16k, zero_wav_torch])
-    ssl_content = ssl_model.model(wav16k.unsqueeze(0))[
-        "last_hidden_state"
-    ].transpose(
-        1, 2
-    )  # .float()
+    ssl_content = ssl_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2)  # .float()
     print("ssl_content:", ssl_content.shape, ssl_content.dtype)
 
-    ref_audio_32k,_ = librosa.load(ref_wav_path, sr=32000)
+    ref_audio_32k, _ = librosa.load(ref_wav_path, sr=32000)
     ref_audio_32k = torch.from_numpy(ref_audio_32k).unsqueeze(0).to(device).float()
 
     phones1, bert1, norm_text1 = get_phones_and_bert(
@@ -913,11 +898,19 @@ def test_export1(
 
     current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
     logger.info("start inference %s", current_time)
-    print(ssl_content.shape, ref_audio_32k.shape, phoneme_ids0.shape, phoneme_ids1.shape, bert1.shape, bert2.shape, top_k.shape)
+    print(
+        ssl_content.shape,
+        ref_audio_32k.shape,
+        phoneme_ids0.shape,
+        phoneme_ids1.shape,
+        bert1.shape,
+        bert2.shape,
+        top_k.shape,
+    )
     wav_gen = gpt_sovits_v3(ssl_content, ref_audio_32k, phoneme_ids0, phoneme_ids1, bert1, bert2, top_k, sample_steps)
     print("wav_gen:", wav_gen.shape, wav_gen.dtype)
 
-    wav_gen = torch.cat([wav_gen,zero_wav_torch],0)
+    wav_gen = torch.cat([wav_gen, zero_wav_torch], 0)
 
     audio = wav_gen.cpu().detach().numpy()
     logger.info("end bigvgan %s", datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
@@ -929,20 +922,19 @@ import time
 
 
 def test_():
-
     sovits = get_sovits_weights("GPT_SoVITS/pretrained_models/s2Gv3.pth")
 
     # cfm = ExportCFM(sovits.cfm)
     # cfm.cfm.estimator = dit
     sovits.cfm = None
 
-
     cfm = torch.jit.load("onnx/ad/cfm.pt", map_location=device)
     # cfm = torch.jit.optimize_for_inference(cfm)
     cfm = cfm.half().to(device)
 
-
     cfm.eval()
 
-    logger.info(f"cfm ok")
+    logger.info("cfm ok")
 
     dict_s1 = torch.load("GPT_SoVITS/pretrained_models/s1v3.ckpt")
     # the v2 GPT can also be used here
@@ -957,17 +949,14 @@ def test_():
     t2s_m = torch.jit.script(t2s_m)
     t2s_m.eval()
     # t2s_m.top_k = 15
-    logger.info(f"t2s_m ok")
+    logger.info("t2s_m ok")
 
-
-    vq_model: torch.jit.ScriptModule = torch.jit.load(
-        "onnx/ad/vq_model.pt", map_location=device
-    )
+    vq_model: torch.jit.ScriptModule = torch.jit.load("onnx/ad/vq_model.pt", map_location=device)
     # vq_model = torch.jit.optimize_for_inference(vq_model)
     # vq_model = vq_model.half().to(device)
     vq_model.eval()
     # vq_model = sovits.vq_model
-    logger.info(f"vq_model ok")
+    logger.info("vq_model ok")
 
     # gpt_sovits_v3_half = torch.jit.load("onnx/ad/gpt_sovits_v3_half.pt")
     # gpt_sovits_v3_half = torch.jit.optimize_for_inference(gpt_sovits_v3_half)
@@ -975,7 +964,7 @@ def test_():
     # gpt_sovits_v3_half = gpt_sovits_v3_half.cuda()
     # gpt_sovits_v3_half.eval()
     gpt_sovits_v3_half = ExportGPTSovitsHalf(sovits.hps, t2s_m, vq_model)
-    logger.info(f"gpt_sovits_v3_half ok")
+    logger.info("gpt_sovits_v3_half ok")
 
     # init_bigvgan()
     # global bigvgan_model
@@ -985,7 +974,7 @@ def test_():
     bigvgan_model = bigvgan_model.cuda()
     bigvgan_model.eval()
 
-    logger.info(f"bigvgan ok")
+    logger.info("bigvgan ok")
 
     gpt_sovits_v3 = GPTSoVITSV3(gpt_sovits_v3_half, cfm, bigvgan_model)
     gpt_sovits_v3 = torch.jit.script(gpt_sovits_v3)
@@ -1020,8 +1009,9 @@ def test_():
     # "out2.wav",
     # )
 
+
 def test_export_gpt_sovits_v3():
-    gpt_sovits_v3 = torch.jit.load("onnx/ad/gpt_sovits_v3.pt",map_location=device)
+    gpt_sovits_v3 = torch.jit.load("onnx/ad/gpt_sovits_v3.pt", map_location=device)
     # test_export1(
     #     "汗流浃背了呀!老弟~ My uncle has two dogs. One is big and the other is small. He likes them very much. He often plays with them. He takes them for a walk every day. He says they are his good friends. He is very happy with them. 最后还是我得了 MVP....",
     #     gpt_sovits_v3,
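Besides the re-wrapping shown above, one lint rewrite recurs throughout this diff: `ruff check --fix` downgrades f-strings that contain no placeholders to plain string literals (rule F541). A minimal sketch, using a pair of lines taken from the diff:

logger.info(f"cfm ok")  # before: f-string with nothing to interpolate (flagged by F541)
logger.info("cfm ok")   # after: the autofix drops the redundant f prefix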