Refactor: Format Code with Ruff and Update Deprecated G2PW Link (#2255)
* ruff check --fix
* ruff format --line-length 120 --target-version py39
* Change the link for the G2PW model
* Update the PyTorch version and the Colab notebook
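The same settings can also be pinned in the repository configuration so later runs do not need command-line flags. A minimal sketch follows; the `[tool.ruff]` table is standard Ruff configuration, but this commit only ran the CLI commands above, so treat the file contents as an assumption rather than part of the change:

    # pyproject.toml -- hypothetical equivalent of the flags used above
    [tool.ruff]
    line-length = 120
    target-version = "py39"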
api.py
@@ -140,9 +140,9 @@ RESP: 无

 """


 import argparse
-import os,re
+import os
+import re
 import sys

 now_dir = os.getcwd()
@@ -152,10 +152,11 @@ sys.path.append("%s/GPT_SoVITS" % (now_dir))
 import signal
 from text.LangSegmenter import LangSegmenter
 from time import time as ttime
-import torch, torchaudio
+import torch
+import torchaudio
 import librosa
 import soundfile as sf
-from fastapi import FastAPI, Request, Query, HTTPException
+from fastapi import FastAPI, Request, Query
 from fastapi.responses import StreamingResponse, JSONResponse
 import uvicorn
 from transformers import AutoModelForMaskedLM, AutoTokenizer
@@ -163,12 +164,11 @@ import numpy as np
 from feature_extractor import cnhubert
 from io import BytesIO
 from module.models import SynthesizerTrn, SynthesizerTrnV3
-from peft import LoraConfig, PeftModel, get_peft_model
+from peft import LoraConfig, get_peft_model
 from AR.models.t2s_lightning_module import Text2SemanticLightningModule
 from text import cleaned_text_to_sequence
 from text.cleaner import clean_text
 from module.mel_processing import spectrogram_torch
 from tools.my_utils import load_audio
 import config as global_config
 import logging
 import subprocess
@@ -201,7 +201,11 @@ def is_full(*items):  # 任意一项为空返回False
 def init_bigvgan():
     global bigvgan_model
     from BigVGAN import bigvgan
-    bigvgan_model = bigvgan.BigVGAN.from_pretrained("%s/GPT_SoVITS/pretrained_models/models--nvidia--bigvgan_v2_24khz_100band_256x" % (now_dir,), use_cuda_kernel=False)  # if True, RuntimeError: Ninja is required to load C++ extensions
+
+    bigvgan_model = bigvgan.BigVGAN.from_pretrained(
+        "%s/GPT_SoVITS/pretrained_models/models--nvidia--bigvgan_v2_24khz_100band_256x" % (now_dir,),
+        use_cuda_kernel=False,
+    )  # if True, RuntimeError: Ninja is required to load C++ extensions
     # remove weight norm in the model and set to eval mode
     bigvgan_model.remove_weight_norm()
     bigvgan_model = bigvgan_model.eval()
@@ -211,57 +215,71 @@ def init_bigvgan():
         bigvgan_model = bigvgan_model.to(device)


-resample_transform_dict={}
+resample_transform_dict = {}


 def resample(audio_tensor, sr0):
     global resample_transform_dict
     if sr0 not in resample_transform_dict:
-        resample_transform_dict[sr0] = torchaudio.transforms.Resample(
-            sr0, 24000
-        ).to(device)
+        resample_transform_dict[sr0] = torchaudio.transforms.Resample(sr0, 24000).to(device)
     return resample_transform_dict[sr0](audio_tensor)


-from module.mel_processing import spectrogram_torch,mel_spectrogram_torch
+from module.mel_processing import mel_spectrogram_torch

 spec_min = -12
 spec_max = 2


 def norm_spec(x):
     return (x - spec_min) / (spec_max - spec_min) * 2 - 1


 def denorm_spec(x):
     return (x + 1) / 2 * (spec_max - spec_min) + spec_min
-mel_fn=lambda x: mel_spectrogram_torch(x, **{
-    "n_fft": 1024,
-    "win_size": 1024,
-    "hop_size": 256,
-    "num_mels": 100,
-    "sampling_rate": 24000,
-    "fmin": 0,
-    "fmax": None,
-    "center": False
-})
-
-
-sr_model=None
-def audio_sr(audio,sr):
+
+
+mel_fn = lambda x: mel_spectrogram_torch(
+    x,
+    **{
+        "n_fft": 1024,
+        "win_size": 1024,
+        "hop_size": 256,
+        "num_mels": 100,
+        "sampling_rate": 24000,
+        "fmin": 0,
+        "fmax": None,
+        "center": False,
+    },
+)
+
+
+sr_model = None
+
+
+def audio_sr(audio, sr):
     global sr_model
-    if sr_model==None:
+    if sr_model == None:
         from tools.audio_sr import AP_BWE
+
         try:
-            sr_model=AP_BWE(device,DictToAttrRecursive)
+            sr_model = AP_BWE(device, DictToAttrRecursive)
         except FileNotFoundError:
             logger.info("你没有下载超分模型的参数,因此不进行超分。如想超分请先参照教程把文件下载")
-            return audio.cpu().detach().numpy(),sr
-    return sr_model(audio,sr)
+            return audio.cpu().detach().numpy(), sr
+    return sr_model(audio, sr)


 class Speaker:
-    def __init__(self, name, gpt, sovits, phones = None, bert = None, prompt = None):
+    def __init__(self, name, gpt, sovits, phones=None, bert=None, prompt=None):
         self.name = name
         self.sovits = sovits
         self.gpt = gpt
         self.phones = phones
         self.bert = bert
         self.prompt = prompt


 speaker_list = {}

@@ -270,22 +288,25 @@ class Sovits:
         self.vq_model = vq_model
         self.hps = hps

-from process_ckpt import get_sovits_version_from_path_fast,load_sovits_new
-def get_sovits_weights(sovits_path):
-    path_sovits_v3="GPT_SoVITS/pretrained_models/s2Gv3.pth"
-    is_exist_s2gv3=os.path.exists(path_sovits_v3)
-
-    version, model_version, if_lora_v3=get_sovits_version_from_path_fast(sovits_path)
-    if if_lora_v3==True and is_exist_s2gv3==False:
+from process_ckpt import get_sovits_version_from_path_fast, load_sovits_new
+
+
+def get_sovits_weights(sovits_path):
+    path_sovits_v3 = "GPT_SoVITS/pretrained_models/s2Gv3.pth"
+    is_exist_s2gv3 = os.path.exists(path_sovits_v3)
+
+    version, model_version, if_lora_v3 = get_sovits_version_from_path_fast(sovits_path)
+    if if_lora_v3 == True and is_exist_s2gv3 == False:
         logger.info("SoVITS V3 底模缺失,无法加载相应 LoRA 权重")

     dict_s2 = load_sovits_new(sovits_path)
     hps = dict_s2["config"]
     hps = DictToAttrRecursive(hps)
     hps.model.semantic_frame_rate = "25hz"
-    if 'enc_p.text_embedding.weight' not in dict_s2['weight']:
-        hps.model.version = "v2"#v3model,v2sybomls
-    elif dict_s2['weight']['enc_p.text_embedding.weight'].shape[0] == 322:
+    if "enc_p.text_embedding.weight" not in dict_s2["weight"]:
+        hps.model.version = "v2"  # v3model,v2sybomls
+    elif dict_s2["weight"]["enc_p.text_embedding.weight"].shape[0] == 322:
         hps.model.version = "v1"
     else:
         hps.model.version = "v2"
@@ -294,27 +315,28 @@ def get_sovits_weights(sovits_path):
         hps.model.version = "v3"

     model_params_dict = vars(hps.model)
-    if model_version!="v3":
+    if model_version != "v3":
         vq_model = SynthesizerTrn(
             hps.data.filter_length // 2 + 1,
             hps.train.segment_size // hps.data.hop_length,
             n_speakers=hps.data.n_speakers,
-            **model_params_dict
+            **model_params_dict,
         )
     else:
         vq_model = SynthesizerTrnV3(
             hps.data.filter_length // 2 + 1,
             hps.train.segment_size // hps.data.hop_length,
             n_speakers=hps.data.n_speakers,
-            **model_params_dict
+            **model_params_dict,
         )
         init_bigvgan()
-    model_version=hps.model.version
+    model_version = hps.model.version
     logger.info(f"模型版本: {model_version}")
-    if ("pretrained" not in sovits_path):
+    if "pretrained" not in sovits_path:
         try:
             del vq_model.enc_q
-        except:pass
+        except:
+            pass
     if is_half == True:
         vq_model = vq_model.half().to(device)
     else:
@@ -324,7 +346,7 @@ def get_sovits_weights(sovits_path):
         vq_model.load_state_dict(dict_s2["weight"], strict=False)
     else:
         vq_model.load_state_dict(load_sovits_new(path_sovits_v3)["weight"], strict=False)
-        lora_rank=dict_s2["lora_rank"]
+        lora_rank = dict_s2["lora_rank"]
         lora_config = LoraConfig(
             target_modules=["to_k", "to_q", "to_v", "to_out.0"],
             r=lora_rank,
@@ -340,13 +362,17 @@ def get_sovits_weights(sovits_path):
     sovits = Sovits(vq_model, hps)
     return sovits


 class Gpt:
     def __init__(self, max_sec, t2s_model):
         self.max_sec = max_sec
         self.t2s_model = t2s_model


 global hz
 hz = 50


 def get_gpt_weights(gpt_path):
     dict_s1 = torch.load(gpt_path, map_location="cpu")
     config = dict_s1["config"]
@@ -363,7 +389,8 @@ def get_gpt_weights(gpt_path):
     gpt = Gpt(max_sec, t2s_model)
     return gpt

-def change_gpt_sovits_weights(gpt_path,sovits_path):
+
+def change_gpt_sovits_weights(gpt_path, sovits_path):
     try:
         gpt = get_gpt_weights(gpt_path)
         sovits = get_sovits_weights(sovits_path)
@@ -392,16 +419,16 @@ def get_bert_feature(text, word2ph):


 def clean_text_inf(text, language, version):
-    language = language.replace("all_","")
+    language = language.replace("all_", "")
     phones, word2ph, norm_text = clean_text(text, language, version)
     phones = cleaned_text_to_sequence(phones, version)
     return phones, word2ph, norm_text


 def get_bert_inf(phones, word2ph, norm_text, language):
-    language=language.replace("all_","")
+    language = language.replace("all_", "")
     if language == "zh":
-        bert = get_bert_feature(norm_text, word2ph).to(device)#.to(dtype)
+        bert = get_bert_feature(norm_text, word2ph).to(device)  # .to(dtype)
     else:
         bert = torch.zeros(
             (1024, len(phones)),
@@ -410,24 +437,27 @@ def get_bert_inf(phones, word2ph, norm_text, language):

     return bert


 from text import chinese
-def get_phones_and_bert(text,language,version,final=False):
+
+
+def get_phones_and_bert(text, language, version, final=False):
     if language in {"en", "all_zh", "all_ja", "all_ko", "all_yue"}:
         formattext = text
         while "  " in formattext:
             formattext = formattext.replace("  ", " ")
         if language == "all_zh":
-            if re.search(r'[A-Za-z]', formattext):
-                formattext = re.sub(r'[a-z]', lambda x: x.group(0).upper(), formattext)
+            if re.search(r"[A-Za-z]", formattext):
+                formattext = re.sub(r"[a-z]", lambda x: x.group(0).upper(), formattext)
                 formattext = chinese.mix_text_normalize(formattext)
-                return get_phones_and_bert(formattext,"zh",version)
+                return get_phones_and_bert(formattext, "zh", version)
             else:
                 phones, word2ph, norm_text = clean_text_inf(formattext, language, version)
                 bert = get_bert_feature(norm_text, word2ph).to(device)
-        elif language == "all_yue" and re.search(r'[A-Za-z]', formattext):
-            formattext = re.sub(r'[a-z]', lambda x: x.group(0).upper(), formattext)
-            formattext = chinese.mix_text_normalize(formattext)
-            return get_phones_and_bert(formattext,"yue",version)
+        elif language == "all_yue" and re.search(r"[A-Za-z]", formattext):
+            formattext = re.sub(r"[a-z]", lambda x: x.group(0).upper(), formattext)
+            formattext = chinese.mix_text_normalize(formattext)
+            return get_phones_and_bert(formattext, "yue", version)
         else:
             phones, word2ph, norm_text = clean_text_inf(formattext, language, version)
             bert = torch.zeros(
@@ -435,8 +465,8 @@ def get_phones_and_bert(text,language,version,final=False):
                 dtype=torch.float16 if is_half == True else torch.float32,
             ).to(device)
     elif language in {"zh", "ja", "ko", "yue", "auto", "auto_yue"}:
-        textlist=[]
-        langlist=[]
+        textlist = []
+        langlist = []
         if language == "auto":
             for tmp in LangSegmenter.getTexts(text):
                 langlist.append(tmp["lang"])
@@ -467,12 +497,12 @@ def get_phones_and_bert(text,language,version,final=False):
             bert_list.append(bert)
         bert = torch.cat(bert_list, dim=1)
         phones = sum(phones_list, [])
-        norm_text = ''.join(norm_text_list)
+        norm_text = "".join(norm_text_list)

     if not final and len(phones) < 6:
-        return get_phones_and_bert("." + text,language,version,final=True)
+        return get_phones_and_bert("." + text, language, version, final=True)

-    return phones,bert.to(torch.float16 if is_half == True else torch.float32),norm_text
+    return phones, bert.to(torch.float16 if is_half == True else torch.float32), norm_text


 class DictToAttrRecursive(dict):
@@ -504,15 +534,21 @@ class DictToAttrRecursive(dict):


 def get_spepc(hps, filename):
-    audio,_ = librosa.load(filename, int(hps.data.sampling_rate))
+    audio, _ = librosa.load(filename, int(hps.data.sampling_rate))
     audio = torch.FloatTensor(audio)
-    maxx=audio.abs().max()
-    if(maxx>1):
-        audio/=min(2,maxx)
+    maxx = audio.abs().max()
+    if maxx > 1:
+        audio /= min(2, maxx)
     audio_norm = audio
     audio_norm = audio_norm.unsqueeze(0)
-    spec = spectrogram_torch(audio_norm, hps.data.filter_length, hps.data.sampling_rate, hps.data.hop_length,
-                             hps.data.win_length, center=False)
+    spec = spectrogram_torch(
+        audio_norm,
+        hps.data.filter_length,
+        hps.data.sampling_rate,
+        hps.data.hop_length,
+        hps.data.win_length,
+        center=False,
+    )
     return spec
@@ -546,10 +582,11 @@ def pack_ogg(audio_bytes, data, rate):
     # Or split the whole audio data into smaller audio segment to avoid stack overflow?

     def handle_pack_ogg():
-        with sf.SoundFile(audio_bytes, mode='w', samplerate=rate, channels=1, format='ogg') as audio_file:
+        with sf.SoundFile(audio_bytes, mode="w", samplerate=rate, channels=1, format="ogg") as audio_file:
             audio_file.write(data)

     import threading
+
     # See: https://docs.python.org/3/library/threading.html
     # The stack size of this thread is at least 32768
     # If stack overflow error still occurs, just modify the `stack_size`.
@@ -581,35 +618,47 @@ def pack_raw(audio_bytes, data, rate):

 def pack_wav(audio_bytes, rate):
     if is_int32:
-        data = np.frombuffer(audio_bytes.getvalue(),dtype=np.int32)
+        data = np.frombuffer(audio_bytes.getvalue(), dtype=np.int32)
         wav_bytes = BytesIO()
-        sf.write(wav_bytes, data, rate, format='WAV', subtype='PCM_32')
+        sf.write(wav_bytes, data, rate, format="WAV", subtype="PCM_32")
     else:
-        data = np.frombuffer(audio_bytes.getvalue(),dtype=np.int16)
+        data = np.frombuffer(audio_bytes.getvalue(), dtype=np.int16)
         wav_bytes = BytesIO()
-        sf.write(wav_bytes, data, rate, format='WAV')
+        sf.write(wav_bytes, data, rate, format="WAV")
     return wav_bytes


 def pack_aac(audio_bytes, data, rate):
     if is_int32:
-        pcm = 's32le'
-        bit_rate = '256k'
+        pcm = "s32le"
+        bit_rate = "256k"
     else:
-        pcm = 's16le'
-        bit_rate = '128k'
-    process = subprocess.Popen([
-        'ffmpeg',
-        '-f', pcm,  # 输入16位有符号小端整数PCM
-        '-ar', str(rate),  # 设置采样率
-        '-ac', '1',  # 单声道
-        '-i', 'pipe:0',  # 从管道读取输入
-        '-c:a', 'aac',  # 音频编码器为AAC
-        '-b:a', bit_rate,  # 比特率
-        '-vn',  # 不包含视频
-        '-f', 'adts',  # 输出AAC数据流格式
-        'pipe:1'  # 将输出写入管道
-    ], stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+        pcm = "s16le"
+        bit_rate = "128k"
+    process = subprocess.Popen(
+        [
+            "ffmpeg",
+            "-f",
+            pcm,  # 输入16位有符号小端整数PCM
+            "-ar",
+            str(rate),  # 设置采样率
+            "-ac",
+            "1",  # 单声道
+            "-i",
+            "pipe:0",  # 从管道读取输入
+            "-c:a",
+            "aac",  # 音频编码器为AAC
+            "-b:a",
+            bit_rate,  # 比特率
+            "-vn",  # 不包含视频
+            "-f",
+            "adts",  # 输出AAC数据流格式
+            "pipe:1",  # 将输出写入管道
+        ],
+        stdin=subprocess.PIPE,
+        stdout=subprocess.PIPE,
+        stderr=subprocess.PIPE,
+    )
     out, _ = process.communicate(input=data.tobytes())
     audio_bytes.write(out)
@@ -632,7 +681,7 @@ def cut_text(text, punc):
         items = re.split(f"({punds})", text)
         mergeitems = ["".join(group) for group in zip(items[::2], items[1::2])]
         # 在句子不存在符号或句尾无符号的时候保证文本完整
-        if len(items)%2 == 1:
+        if len(items) % 2 == 1:
             mergeitems.append(items[-1])
         text = "\n".join(mergeitems)
@@ -646,8 +695,38 @@ def only_punc(text):
     return not any(t.isalnum() or t.isalpha() for t in text)


-splits = {",", "。", "?", "!", ",", ".", "?", "!", "~", ":", ":", "—", "…", }
-def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, top_k= 15, top_p = 0.6, temperature = 0.6, speed = 1, inp_refs = None, sample_steps = 32, if_sr = False, spk = "default"):
+splits = {
+    ",",
+    "。",
+    "?",
+    "!",
+    ",",
+    ".",
+    "?",
+    "!",
+    "~",
+    ":",
+    ":",
+    "—",
+    "…",
+}
+
+
+def get_tts_wav(
+    ref_wav_path,
+    prompt_text,
+    prompt_language,
+    text,
+    text_language,
+    top_k=15,
+    top_p=0.6,
+    temperature=0.6,
+    speed=1,
+    inp_refs=None,
+    sample_steps=32,
+    if_sr=False,
+    spk="default",
+):
     infer_sovits = speaker_list[spk].sovits
     vq_model = infer_sovits.vq_model
     hps = infer_sovits.hps
@@ -659,7 +738,8 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,

     t0 = ttime()
     prompt_text = prompt_text.strip("\n")
-    if (prompt_text[-1] not in splits): prompt_text += "。" if prompt_language != "en" else "."
+    if prompt_text[-1] not in splits:
+        prompt_text += "。" if prompt_language != "en" else "."
     prompt_language, text = prompt_language, text.strip("\n")
     dtype = torch.float16 if is_half == True else torch.float32
     zero_wav = np.zeros(int(hps.data.sampling_rate * 0.3), dtype=np.float16 if is_half == True else np.float32)
@@ -667,7 +747,7 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
         wav16k, sr = librosa.load(ref_wav_path, sr=16000)
         wav16k = torch.from_numpy(wav16k)
         zero_wav_torch = torch.from_numpy(zero_wav)
-        if (is_half == True):
+        if is_half == True:
             wav16k = wav16k.half().to(device)
             zero_wav_torch = zero_wav_torch.half().to(device)
         else:
@@ -680,15 +760,15 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
         prompt = prompt_semantic.unsqueeze(0).to(device)

     if version != "v3":
-        refers=[]
-        if(inp_refs):
+        refers = []
+        if inp_refs:
             for path in inp_refs:
                 try:
                     refer = get_spepc(hps, path).to(dtype).to(device)
                     refers.append(refer)
                 except Exception as e:
                     logger.error(e)
-        if(len(refers)==0):
+        if len(refers) == 0:
             refers = [get_spepc(hps, ref_wav_path).to(dtype).to(device)]
     else:
         refer = get_spepc(hps, ref_wav_path).to(device).to(dtype)
@@ -707,7 +787,8 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
             continue

         audio_opt = []
-        if (text[-1] not in splits): text += "。" if text_language != "en" else "."
+        if text[-1] not in splits:
+            text += "。" if text_language != "en" else "."
         phones2, bert2, norm_text2 = get_phones_and_bert(text, text_language, version)
         bert = torch.cat([bert1, bert2], 1)

@@ -722,56 +803,62 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
                 prompt,
                 bert,
                 # prompt_phone_len=ph_offset,
-                top_k = top_k,
-                top_p = top_p,
-                temperature = temperature,
-                early_stop_num=hz * max_sec)
+                top_k=top_k,
+                top_p=top_p,
+                temperature=temperature,
+                early_stop_num=hz * max_sec,
+            )
             pred_semantic = pred_semantic[:, -idx:].unsqueeze(0)
         t3 = ttime()

         if version != "v3":
-            audio = \
-                vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0),
-                                refers,speed=speed).detach().cpu().numpy()[
-                    0, 0]  ###试试重建不带上prompt部分
+            audio = (
+                vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refers, speed=speed)
+                .detach()
+                .cpu()
+                .numpy()[0, 0]
+            )  ###试试重建不带上prompt部分
         else:
-            phoneme_ids0=torch.LongTensor(phones1).to(device).unsqueeze(0)
-            phoneme_ids1=torch.LongTensor(phones2).to(device).unsqueeze(0)
+            phoneme_ids0 = torch.LongTensor(phones1).to(device).unsqueeze(0)
+            phoneme_ids1 = torch.LongTensor(phones2).to(device).unsqueeze(0)
             # print(11111111, phoneme_ids0, phoneme_ids1)
-            fea_ref,ge = vq_model.decode_encp(prompt.unsqueeze(0), phoneme_ids0, refer)
+            fea_ref, ge = vq_model.decode_encp(prompt.unsqueeze(0), phoneme_ids0, refer)
             ref_audio, sr = torchaudio.load(ref_wav_path)
-            ref_audio=ref_audio.to(device).float()
-            if (ref_audio.shape[0] == 2):
+            ref_audio = ref_audio.to(device).float()
+            if ref_audio.shape[0] == 2:
                 ref_audio = ref_audio.mean(0).unsqueeze(0)
-            if sr!=24000:
-                ref_audio=resample(ref_audio,sr)
+            if sr != 24000:
+                ref_audio = resample(ref_audio, sr)
             # print("ref_audio",ref_audio.abs().mean())
             mel2 = mel_fn(ref_audio)
             mel2 = norm_spec(mel2)
             T_min = min(mel2.shape[2], fea_ref.shape[2])
             mel2 = mel2[:, :, :T_min]
             fea_ref = fea_ref[:, :, :T_min]
-            if (T_min > 468):
+            if T_min > 468:
                 mel2 = mel2[:, :, -468:]
                 fea_ref = fea_ref[:, :, -468:]
                 T_min = 468
             chunk_len = 934 - T_min
             # print("fea_ref",fea_ref,fea_ref.shape)
             # print("mel2",mel2)
-            mel2=mel2.to(dtype)
-            fea_todo, ge = vq_model.decode_encp(pred_semantic, phoneme_ids1, refer, ge,speed)
+            mel2 = mel2.to(dtype)
+            fea_todo, ge = vq_model.decode_encp(pred_semantic, phoneme_ids1, refer, ge, speed)
             # print("fea_todo",fea_todo)
             # print("ge",ge.abs().mean())
             cfm_resss = []
             idx = 0
-            while (1):
-                fea_todo_chunk = fea_todo[:, :, idx:idx + chunk_len]
-                if (fea_todo_chunk.shape[-1] == 0): break
+            while 1:
+                fea_todo_chunk = fea_todo[:, :, idx : idx + chunk_len]
+                if fea_todo_chunk.shape[-1] == 0:
+                    break
                 idx += chunk_len
                 fea = torch.cat([fea_ref, fea_todo_chunk], 2).transpose(2, 1)
                 # set_seed(123)
-                cfm_res = vq_model.cfm.inference(fea, torch.LongTensor([fea.size(1)]).to(fea.device), mel2, sample_steps, inference_cfg_rate=0)
-                cfm_res = cfm_res[:, :, mel2.shape[2]:]
+                cfm_res = vq_model.cfm.inference(
+                    fea, torch.LongTensor([fea.size(1)]).to(fea.device), mel2, sample_steps, inference_cfg_rate=0
+                )
+                cfm_res = cfm_res[:, :, mel2.shape[2] :]
                 mel2 = cfm_res[:, :, -T_min:]
                 # print("fea", fea)
                 # print("mel2in", mel2)
@@ -779,14 +866,15 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
                 cfm_resss.append(cfm_res)
             cmf_res = torch.cat(cfm_resss, 2)
             cmf_res = denorm_spec(cmf_res)
-            if bigvgan_model==None:init_bigvgan()
+            if bigvgan_model == None:
+                init_bigvgan()
             with torch.inference_mode():
                 wav_gen = bigvgan_model(cmf_res)
-                audio=wav_gen[0][0].cpu().detach().numpy()
+                audio = wav_gen[0][0].cpu().detach().numpy()

-        max_audio=np.abs(audio).max()
-        if max_audio>1:
-            audio/=max_audio
+        max_audio = np.abs(audio).max()
+        if max_audio > 1:
+            audio /= max_audio
         audio_opt.append(audio)
         audio_opt.append(zero_wav)
         audio_opt = np.concatenate(audio_opt, 0)
@@ -795,29 +883,29 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
         sr = hps.data.sampling_rate if version != "v3" else 24000
         if if_sr and sr == 24000:
             audio_opt = torch.from_numpy(audio_opt).float().to(device)
-            audio_opt,sr=audio_sr(audio_opt.unsqueeze(0),sr)
-            max_audio=np.abs(audio_opt).max()
-            if max_audio > 1: audio_opt /= max_audio
+            audio_opt, sr = audio_sr(audio_opt.unsqueeze(0), sr)
+            max_audio = np.abs(audio_opt).max()
+            if max_audio > 1:
+                audio_opt /= max_audio
             sr = 48000

         if is_int32:
-            audio_bytes = pack_audio(audio_bytes,(audio_opt * 2147483647).astype(np.int32),sr)
+            audio_bytes = pack_audio(audio_bytes, (audio_opt * 2147483647).astype(np.int32), sr)
         else:
-            audio_bytes = pack_audio(audio_bytes,(audio_opt * 32768).astype(np.int16),sr)
-        # logger.info("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3))
+            audio_bytes = pack_audio(audio_bytes, (audio_opt * 32768).astype(np.int16), sr)
+        # logger.info("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3))
         if stream_mode == "normal":
             audio_bytes, audio_chunk = read_clean_buffer(audio_bytes)
             yield audio_chunk

-    if not stream_mode == "normal":
+
+    if not stream_mode == "normal":
         if media_type == "wav":
             sr = 48000 if if_sr else 24000
             sr = hps.data.sampling_rate if version != "v3" else sr
-            audio_bytes = pack_wav(audio_bytes,sr)
+            audio_bytes = pack_wav(audio_bytes, sr)
         yield audio_bytes.getvalue()


 def handle_control(command):
     if command == "restart":
         os.execl(g_config.python_exec, g_config.python_exec, *sys.argv)
@@ -828,7 +916,9 @@ def handle_control(command):

 def handle_change(path, text, language):
     if is_empty(path, text, language):
-        return JSONResponse({"code": 400, "message": '缺少任意一项以下参数: "path", "text", "language"'}, status_code=400)
+        return JSONResponse(
+            {"code": 400, "message": '缺少任意一项以下参数: "path", "text", "language"'}, status_code=400
+        )

     if path != "" or path is not None:
         default_refer.path = path
@@ -842,15 +932,31 @@ def handle_change(path, text, language):
     logger.info(f"当前默认参考音频语种: {default_refer.language}")
     logger.info(f"is_ready: {default_refer.is_ready()}")

     return JSONResponse({"code": 0, "message": "Success"}, status_code=200)


-def handle(refer_wav_path, prompt_text, prompt_language, text, text_language, cut_punc, top_k, top_p, temperature, speed, inp_refs, sample_steps, if_sr):
+def handle(
+    refer_wav_path,
+    prompt_text,
+    prompt_language,
+    text,
+    text_language,
+    cut_punc,
+    top_k,
+    top_p,
+    temperature,
+    speed,
+    inp_refs,
+    sample_steps,
+    if_sr,
+):
     if (
-        refer_wav_path == "" or refer_wav_path is None
-        or prompt_text == "" or prompt_text is None
-        or prompt_language == "" or prompt_language is None
+        refer_wav_path == ""
+        or refer_wav_path is None
+        or prompt_text == ""
+        or prompt_text is None
+        or prompt_language == ""
+        or prompt_language is None
     ):
         refer_wav_path, prompt_text, prompt_language = (
             default_refer.path,
@@ -860,17 +966,31 @@ def handle(refer_wav_path, prompt_text, prompt_language, text, text_language, cu
         if not default_refer.is_ready():
             return JSONResponse({"code": 400, "message": "未指定参考音频且接口无预设"}, status_code=400)

-    if not sample_steps in [4,8,16,32]:
+    if sample_steps not in [4, 8, 16, 32]:
         sample_steps = 32

     if cut_punc == None:
-        text = cut_text(text,default_cut_punc)
+        text = cut_text(text, default_cut_punc)
     else:
-        text = cut_text(text,cut_punc)
-
-    return StreamingResponse(get_tts_wav(refer_wav_path, prompt_text, prompt_language, text, text_language, top_k, top_p, temperature, speed, inp_refs, sample_steps, if_sr), media_type="audio/"+media_type)
+        text = cut_text(text, cut_punc)
+
+    return StreamingResponse(
+        get_tts_wav(
+            refer_wav_path,
+            prompt_text,
+            prompt_language,
+            text,
+            text_language,
+            top_k,
+            top_p,
+            temperature,
+            speed,
+            inp_refs,
+            sample_steps,
+            if_sr,
+        ),
+        media_type="audio/" + media_type,
+    )


 # --------------------------------
@@ -886,7 +1006,7 @@ dict_language = {
     "粤英混合": "yue",
     "日英混合": "ja",
     "韩英混合": "ko",
-    "多语种混合": "auto",  #多语种启动切分识别语种
+    "多语种混合": "auto",  # 多语种启动切分识别语种
     "多语种混合(粤语)": "auto_yue",
     "all_zh": "all_zh",
     "all_yue": "all_yue",
@@ -903,7 +1023,7 @@ dict_language = {

 # logger
 logging.config.dictConfig(uvicorn.config.LOGGING_CONFIG)
-logger = logging.getLogger('uvicorn')
+logger = logging.getLogger("uvicorn")

 # 获取配置
 g_config = global_config.Config()
@@ -919,8 +1039,12 @@ parser.add_argument("-dl", "--default_refer_language", type=str, default="", hel
 parser.add_argument("-d", "--device", type=str, default=g_config.infer_device, help="cuda / cpu")
 parser.add_argument("-a", "--bind_addr", type=str, default="0.0.0.0", help="default: 0.0.0.0")
 parser.add_argument("-p", "--port", type=int, default=g_config.api_port, help="default: 9880")
-parser.add_argument("-fp", "--full_precision", action="store_true", default=False, help="覆盖config.is_half为False, 使用全精度")
-parser.add_argument("-hp", "--half_precision", action="store_true", default=False, help="覆盖config.is_half为True, 使用半精度")
+parser.add_argument(
+    "-fp", "--full_precision", action="store_true", default=False, help="覆盖config.is_half为False, 使用全精度"
+)
+parser.add_argument(
+    "-hp", "--half_precision", action="store_true", default=False, help="覆盖config.is_half为True, 使用半精度"
+)
 # bool值的用法为 `python ./api.py -fp ...`
 # 此时 full_precision==True, half_precision==False
 parser.add_argument("-sm", "--stream_mode", type=str, default="close", help="流式返回模式, close / normal / keepalive")
@@ -972,14 +1096,14 @@ if args.full_precision and args.half_precision:
 logger.info(f"半精: {is_half}")

 # 流式返回模式
-if args.stream_mode.lower() in ["normal","n"]:
+if args.stream_mode.lower() in ["normal", "n"]:
     stream_mode = "normal"
     logger.info("流式返回已开启")
 else:
     stream_mode = "close"

 # 音频编码格式
-if args.media_type.lower() in ["aac","ogg"]:
+if args.media_type.lower() in ["aac", "ogg"]:
     media_type = args.media_type.lower()
 elif stream_mode == "close":
     media_type = "wav"
@@ -988,12 +1112,12 @@ else:
 logger.info(f"编码格式: {media_type}")

 # 音频数据类型
-if args.sub_type.lower() == 'int32':
+if args.sub_type.lower() == "int32":
     is_int32 = True
-    logger.info(f"数据类型: int32")
+    logger.info("数据类型: int32")
 else:
     is_int32 = False
-    logger.info(f"数据类型: int16")
+    logger.info("数据类型: int16")

 # 初始化模型
 cnhubert.cnhubert_base_path = cnhubert_base_path
@@ -1006,8 +1130,7 @@ if is_half:
 else:
     bert_model = bert_model.to(device)
     ssl_model = ssl_model.to(device)
-change_gpt_sovits_weights(gpt_path = gpt_path, sovits_path = sovits_path)
+change_gpt_sovits_weights(gpt_path=gpt_path, sovits_path=sovits_path)


 # --------------------------------
@@ -1015,21 +1138,21 @@ change_gpt_sovits_weights(gpt_path = gpt_path, sovits_path = sovits_path)
 # --------------------------------
 app = FastAPI()


 @app.post("/set_model")
 async def set_model(request: Request):
     json_post_raw = await request.json()
     return change_gpt_sovits_weights(
-        gpt_path = json_post_raw.get("gpt_model_path"),
-        sovits_path = json_post_raw.get("sovits_model_path")
+        gpt_path=json_post_raw.get("gpt_model_path"), sovits_path=json_post_raw.get("sovits_model_path")
     )


 @app.get("/set_model")
 async def set_model(
-    gpt_model_path: str = None,
-    sovits_model_path: str = None,
+    gpt_model_path: str = None,
+    sovits_model_path: str = None,
 ):
-    return change_gpt_sovits_weights(gpt_path = gpt_model_path, sovits_path = sovits_model_path)
+    return change_gpt_sovits_weights(gpt_path=gpt_model_path, sovits_path=sovits_model_path)


 @app.post("/control")
@@ -1047,18 +1170,12 @@ async def control(command: str = None):
 async def change_refer(request: Request):
     json_post_raw = await request.json()
     return handle_change(
-        json_post_raw.get("refer_wav_path"),
-        json_post_raw.get("prompt_text"),
-        json_post_raw.get("prompt_language")
+        json_post_raw.get("refer_wav_path"), json_post_raw.get("prompt_text"), json_post_raw.get("prompt_language")
     )


 @app.get("/change_refer")
-async def change_refer(
-    refer_wav_path: str = None,
-    prompt_text: str = None,
-    prompt_language: str = None
-):
+async def change_refer(refer_wav_path: str = None, prompt_text: str = None, prompt_language: str = None):
     return handle_change(refer_wav_path, prompt_text, prompt_language)

@@ -1078,27 +1195,41 @@ async def tts_endpoint(request: Request):
         json_post_raw.get("speed", 1.0),
         json_post_raw.get("inp_refs", []),
         json_post_raw.get("sample_steps", 32),
-        json_post_raw.get("if_sr", False)
+        json_post_raw.get("if_sr", False),
     )


 @app.get("/")
 async def tts_endpoint(
-    refer_wav_path: str = None,
-    prompt_text: str = None,
-    prompt_language: str = None,
-    text: str = None,
-    text_language: str = None,
-    cut_punc: str = None,
-    top_k: int = 15,
-    top_p: float = 1.0,
-    temperature: float = 1.0,
-    speed: float = 1.0,
-    inp_refs: list = Query(default=[]),
-    sample_steps: int = 32,
-    if_sr: bool = False
+    refer_wav_path: str = None,
+    prompt_text: str = None,
+    prompt_language: str = None,
+    text: str = None,
+    text_language: str = None,
+    cut_punc: str = None,
+    top_k: int = 15,
+    top_p: float = 1.0,
+    temperature: float = 1.0,
+    speed: float = 1.0,
+    inp_refs: list = Query(default=[]),
+    sample_steps: int = 32,
+    if_sr: bool = False,
 ):
-    return handle(refer_wav_path, prompt_text, prompt_language, text, text_language, cut_punc, top_k, top_p, temperature, speed, inp_refs, sample_steps, if_sr)
+    return handle(
+        refer_wav_path,
+        prompt_text,
+        prompt_language,
+        text,
+        text_language,
+        cut_punc,
+        top_k,
+        top_p,
+        temperature,
+        speed,
+        inp_refs,
+        sample_steps,
+        if_sr,
+    )


 if __name__ == "__main__":
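For reference, a minimal sketch of calling the reformatted `GET /` endpoint shown above. This is not part of the commit: the host, the port (9880 is only the argparse default above), and the file paths are assumptions that must match a local deployment:

    import requests

    # Query parameters mirror the tts_endpoint signature in this commit.
    params = {
        "refer_wav_path": "ref.wav",  # hypothetical local reference audio
        "prompt_text": "参考音频对应的文本",
        "prompt_language": "zh",
        "text": "需要合成的文本",
        "text_language": "zh",
    }
    resp = requests.get("http://127.0.0.1:9880/", params=params)
    with open("out.wav", "wb") as f:
        f.write(resp.content)  # WAV by default when stream_mode is "close"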