Refactor: Format Code with Ruff and Update Deprecated G2PW Link (#2255)

* ruff check --fix

* ruff format --line-length 120 --target-version py39

* Change the link for G2PW Model

* Update PyTorch version and Colab
Author: XXXXRT666
Date: 2025-04-07 09:42:47 +01:00
Committed by: GitHub
Parent: 9da7e17efe
Commit: 53cac93589
132 changed files with 8185 additions and 6648 deletions

api.py: 507 changes

@@ -140,9 +140,9 @@ RESP: 无
 """
 import argparse
-import os,re
+import os
+import re
 import sys
 now_dir = os.getcwd()
@@ -152,10 +152,11 @@ sys.path.append("%s/GPT_SoVITS" % (now_dir))
 import signal
 from text.LangSegmenter import LangSegmenter
 from time import time as ttime
-import torch, torchaudio
+import torch
+import torchaudio
 import librosa
 import soundfile as sf
-from fastapi import FastAPI, Request, Query, HTTPException
+from fastapi import FastAPI, Request, Query
 from fastapi.responses import StreamingResponse, JSONResponse
 import uvicorn
 from transformers import AutoModelForMaskedLM, AutoTokenizer
@@ -163,12 +164,11 @@ import numpy as np
 from feature_extractor import cnhubert
 from io import BytesIO
 from module.models import SynthesizerTrn, SynthesizerTrnV3
-from peft import LoraConfig, PeftModel, get_peft_model
+from peft import LoraConfig, get_peft_model
 from AR.models.t2s_lightning_module import Text2SemanticLightningModule
 from text import cleaned_text_to_sequence
 from text.cleaner import clean_text
 from module.mel_processing import spectrogram_torch
 from tools.my_utils import load_audio
 import config as global_config
 import logging
 import subprocess
@@ -201,7 +201,11 @@ def is_full(*items):  # returns False if any item is empty
 def init_bigvgan():
     global bigvgan_model
     from BigVGAN import bigvgan
-    bigvgan_model = bigvgan.BigVGAN.from_pretrained("%s/GPT_SoVITS/pretrained_models/models--nvidia--bigvgan_v2_24khz_100band_256x" % (now_dir,), use_cuda_kernel=False)  # if True, RuntimeError: Ninja is required to load C++ extensions
+    bigvgan_model = bigvgan.BigVGAN.from_pretrained(
+        "%s/GPT_SoVITS/pretrained_models/models--nvidia--bigvgan_v2_24khz_100band_256x" % (now_dir,),
+        use_cuda_kernel=False,
+    )  # if True, RuntimeError: Ninja is required to load C++ extensions
     # remove weight norm in the model and set to eval mode
     bigvgan_model.remove_weight_norm()
     bigvgan_model = bigvgan_model.eval()
@@ -211,57 +215,71 @@ def init_bigvgan():
         bigvgan_model = bigvgan_model.to(device)
-resample_transform_dict={}
+resample_transform_dict = {}
 def resample(audio_tensor, sr0):
     global resample_transform_dict
     if sr0 not in resample_transform_dict:
-        resample_transform_dict[sr0] = torchaudio.transforms.Resample(
-            sr0, 24000
-        ).to(device)
+        resample_transform_dict[sr0] = torchaudio.transforms.Resample(sr0, 24000).to(device)
     return resample_transform_dict[sr0](audio_tensor)
-from module.mel_processing import spectrogram_torch,mel_spectrogram_torch
+from module.mel_processing import mel_spectrogram_torch
 spec_min = -12
 spec_max = 2
 def norm_spec(x):
     return (x - spec_min) / (spec_max - spec_min) * 2 - 1
 def denorm_spec(x):
     return (x + 1) / 2 * (spec_max - spec_min) + spec_min
-mel_fn=lambda x: mel_spectrogram_torch(x, **{
-    "n_fft": 1024,
-    "win_size": 1024,
-    "hop_size": 256,
-    "num_mels": 100,
-    "sampling_rate": 24000,
-    "fmin": 0,
-    "fmax": None,
-    "center": False
-})
-sr_model=None
-def audio_sr(audio,sr):
+mel_fn = lambda x: mel_spectrogram_torch(
+    x,
+    **{
+        "n_fft": 1024,
+        "win_size": 1024,
+        "hop_size": 256,
+        "num_mels": 100,
+        "sampling_rate": 24000,
+        "fmin": 0,
+        "fmax": None,
+        "center": False,
+    },
+)
+sr_model = None
+def audio_sr(audio, sr):
     global sr_model
-    if sr_model==None:
+    if sr_model == None:
         from tools.audio_sr import AP_BWE
         try:
-            sr_model=AP_BWE(device,DictToAttrRecursive)
+            sr_model = AP_BWE(device, DictToAttrRecursive)
         except FileNotFoundError:
             logger.info("你没有下载超分模型的参数,因此不进行超分。如想超分请先参照教程把文件下载")
-            return audio.cpu().detach().numpy(),sr
-    return sr_model(audio,sr)
+            return audio.cpu().detach().numpy(), sr
+    return sr_model(audio, sr)
 class Speaker:
-    def __init__(self, name, gpt, sovits, phones = None, bert = None, prompt = None):
+    def __init__(self, name, gpt, sovits, phones=None, bert=None, prompt=None):
         self.name = name
         self.sovits = sovits
         self.gpt = gpt
         self.phones = phones
         self.bert = bert
         self.prompt = prompt
 speaker_list = {}
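A note on the `resample()` hunk above: rather than constructing a `torchaudio` resampler on every call, the code keeps one `Resample` transform per source sample rate and reuses it. A minimal, self-contained sketch of the same pattern (hypothetical names, CPU-only, 24 kHz target as in api.py):

```python
import torch
import torchaudio

_resamplers: dict[int, torchaudio.transforms.Resample] = {}

def resample_to_24k(audio: torch.Tensor, sr0: int) -> torch.Tensor:
    # build the transform once per source rate, then reuse it on later calls
    if sr0 not in _resamplers:
        _resamplers[sr0] = torchaudio.transforms.Resample(sr0, 24000)
    return _resamplers[sr0](audio)

wav = torch.randn(1, 48000)        # one second of fake audio at 48 kHz
out = resample_to_24k(wav, 48000)  # first call constructs the transform
print(out.shape)                   # torch.Size([1, 24000])
```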
@@ -270,22 +288,25 @@ class Sovits:
         self.vq_model = vq_model
         self.hps = hps
-from process_ckpt import get_sovits_version_from_path_fast,load_sovits_new
-def get_sovits_weights(sovits_path):
-    path_sovits_v3="GPT_SoVITS/pretrained_models/s2Gv3.pth"
-    is_exist_s2gv3=os.path.exists(path_sovits_v3)
-    version, model_version, if_lora_v3=get_sovits_version_from_path_fast(sovits_path)
-    if if_lora_v3==True and is_exist_s2gv3==False:
+from process_ckpt import get_sovits_version_from_path_fast, load_sovits_new
+def get_sovits_weights(sovits_path):
+    path_sovits_v3 = "GPT_SoVITS/pretrained_models/s2Gv3.pth"
+    is_exist_s2gv3 = os.path.exists(path_sovits_v3)
+    version, model_version, if_lora_v3 = get_sovits_version_from_path_fast(sovits_path)
+    if if_lora_v3 == True and is_exist_s2gv3 == False:
         logger.info("SoVITS V3 底模缺失,无法加载相应 LoRA 权重")
     dict_s2 = load_sovits_new(sovits_path)
     hps = dict_s2["config"]
     hps = DictToAttrRecursive(hps)
     hps.model.semantic_frame_rate = "25hz"
-    if 'enc_p.text_embedding.weight' not in dict_s2['weight']:
-        hps.model.version = "v2"#v3model,v2sybomls
-    elif dict_s2['weight']['enc_p.text_embedding.weight'].shape[0] == 322:
+    if "enc_p.text_embedding.weight" not in dict_s2["weight"]:
+        hps.model.version = "v2"  # v3model,v2sybomls
+    elif dict_s2["weight"]["enc_p.text_embedding.weight"].shape[0] == 322:
         hps.model.version = "v1"
     else:
         hps.model.version = "v2"
@@ -294,27 +315,28 @@ def get_sovits_weights(sovits_path):
         hps.model.version = "v3"
     model_params_dict = vars(hps.model)
-    if model_version!="v3":
+    if model_version != "v3":
         vq_model = SynthesizerTrn(
             hps.data.filter_length // 2 + 1,
             hps.train.segment_size // hps.data.hop_length,
             n_speakers=hps.data.n_speakers,
-            **model_params_dict
+            **model_params_dict,
         )
     else:
         vq_model = SynthesizerTrnV3(
             hps.data.filter_length // 2 + 1,
             hps.train.segment_size // hps.data.hop_length,
             n_speakers=hps.data.n_speakers,
-            **model_params_dict
+            **model_params_dict,
         )
         init_bigvgan()
-    model_version=hps.model.version
+    model_version = hps.model.version
     logger.info(f"模型版本: {model_version}")
-    if ("pretrained" not in sovits_path):
+    if "pretrained" not in sovits_path:
         try:
             del vq_model.enc_q
-        except:pass
+        except:
+            pass
     if is_half == True:
         vq_model = vq_model.half().to(device)
     else:
@@ -324,7 +346,7 @@ def get_sovits_weights(sovits_path):
         vq_model.load_state_dict(dict_s2["weight"], strict=False)
     else:
         vq_model.load_state_dict(load_sovits_new(path_sovits_v3)["weight"], strict=False)
-        lora_rank=dict_s2["lora_rank"]
+        lora_rank = dict_s2["lora_rank"]
         lora_config = LoraConfig(
             target_modules=["to_k", "to_q", "to_v", "to_out.0"],
             r=lora_rank,
@@ -340,13 +362,17 @@ def get_sovits_weights(sovits_path):
     sovits = Sovits(vq_model, hps)
     return sovits
 class Gpt:
     def __init__(self, max_sec, t2s_model):
         self.max_sec = max_sec
         self.t2s_model = t2s_model
 global hz
 hz = 50
 def get_gpt_weights(gpt_path):
     dict_s1 = torch.load(gpt_path, map_location="cpu")
     config = dict_s1["config"]
@@ -363,7 +389,8 @@ def get_gpt_weights(gpt_path):
     gpt = Gpt(max_sec, t2s_model)
     return gpt
-def change_gpt_sovits_weights(gpt_path,sovits_path):
+def change_gpt_sovits_weights(gpt_path, sovits_path):
     try:
         gpt = get_gpt_weights(gpt_path)
         sovits = get_sovits_weights(sovits_path)
@@ -392,16 +419,16 @@ def get_bert_feature(text, word2ph):
 def clean_text_inf(text, language, version):
-    language = language.replace("all_","")
+    language = language.replace("all_", "")
     phones, word2ph, norm_text = clean_text(text, language, version)
     phones = cleaned_text_to_sequence(phones, version)
     return phones, word2ph, norm_text
 def get_bert_inf(phones, word2ph, norm_text, language):
-    language=language.replace("all_","")
+    language = language.replace("all_", "")
     if language == "zh":
-        bert = get_bert_feature(norm_text, word2ph).to(device)#.to(dtype)
+        bert = get_bert_feature(norm_text, word2ph).to(device)  # .to(dtype)
     else:
         bert = torch.zeros(
             (1024, len(phones)),
@@ -410,24 +437,27 @@ def get_bert_inf(phones, word2ph, norm_text, language):
     return bert
 from text import chinese
-def get_phones_and_bert(text,language,version,final=False):
+def get_phones_and_bert(text, language, version, final=False):
     if language in {"en", "all_zh", "all_ja", "all_ko", "all_yue"}:
         formattext = text
         while "  " in formattext:
             formattext = formattext.replace("  ", " ")
         if language == "all_zh":
-            if re.search(r'[A-Za-z]', formattext):
-                formattext = re.sub(r'[a-z]', lambda x: x.group(0).upper(), formattext)
+            if re.search(r"[A-Za-z]", formattext):
+                formattext = re.sub(r"[a-z]", lambda x: x.group(0).upper(), formattext)
                 formattext = chinese.mix_text_normalize(formattext)
-                return get_phones_and_bert(formattext,"zh",version)
+                return get_phones_and_bert(formattext, "zh", version)
             else:
                 phones, word2ph, norm_text = clean_text_inf(formattext, language, version)
                 bert = get_bert_feature(norm_text, word2ph).to(device)
-        elif language == "all_yue" and re.search(r'[A-Za-z]', formattext):
-            formattext = re.sub(r'[a-z]', lambda x: x.group(0).upper(), formattext)
-            formattext = chinese.mix_text_normalize(formattext)
-            return get_phones_and_bert(formattext,"yue",version)
+        elif language == "all_yue" and re.search(r"[A-Za-z]", formattext):
+            formattext = re.sub(r"[a-z]", lambda x: x.group(0).upper(), formattext)
+            formattext = chinese.mix_text_normalize(formattext)
+            return get_phones_and_bert(formattext, "yue", version)
         else:
             phones, word2ph, norm_text = clean_text_inf(formattext, language, version)
             bert = torch.zeros(
@@ -435,8 +465,8 @@ def get_phones_and_bert(text,language,version,final=False):
                 dtype=torch.float16 if is_half == True else torch.float32,
             ).to(device)
     elif language in {"zh", "ja", "ko", "yue", "auto", "auto_yue"}:
-        textlist=[]
-        langlist=[]
+        textlist = []
+        langlist = []
         if language == "auto":
             for tmp in LangSegmenter.getTexts(text):
                 langlist.append(tmp["lang"])
@@ -467,12 +497,12 @@ def get_phones_and_bert(text,language,version,final=False):
             bert_list.append(bert)
         bert = torch.cat(bert_list, dim=1)
         phones = sum(phones_list, [])
-        norm_text = ''.join(norm_text_list)
+        norm_text = "".join(norm_text_list)
     if not final and len(phones) < 6:
-        return get_phones_and_bert("." + text,language,version,final=True)
+        return get_phones_and_bert("." + text, language, version, final=True)
-    return phones,bert.to(torch.float16 if is_half == True else torch.float32),norm_text
+    return phones, bert.to(torch.float16 if is_half == True else torch.float32), norm_text
 class DictToAttrRecursive(dict):
@@ -504,15 +534,21 @@ class DictToAttrRecursive(dict):
 def get_spepc(hps, filename):
-    audio,_ = librosa.load(filename, int(hps.data.sampling_rate))
+    audio, _ = librosa.load(filename, int(hps.data.sampling_rate))
     audio = torch.FloatTensor(audio)
-    maxx=audio.abs().max()
-    if(maxx>1):
-        audio/=min(2,maxx)
+    maxx = audio.abs().max()
+    if maxx > 1:
+        audio /= min(2, maxx)
     audio_norm = audio
     audio_norm = audio_norm.unsqueeze(0)
-    spec = spectrogram_torch(audio_norm, hps.data.filter_length, hps.data.sampling_rate, hps.data.hop_length,
-                             hps.data.win_length, center=False)
+    spec = spectrogram_torch(
+        audio_norm,
+        hps.data.filter_length,
+        hps.data.sampling_rate,
+        hps.data.hop_length,
+        hps.data.win_length,
+        center=False,
+    )
     return spec
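The reformatted `get_spepc()` also shows a small peak-handling rule worth spelling out: audio that clips (|x| > 1) is attenuated, but never by more than a factor of 2. A standalone sketch of that behavior (hypothetical helper, for illustration only):

```python
import torch

def soften_peaks(audio: torch.Tensor) -> torch.Tensor:
    # attenuate clipping audio, but by at most a factor of 2
    maxx = audio.abs().max()
    if maxx > 1:
        audio = audio / min(2, maxx)
    return audio

print(soften_peaks(torch.tensor([0.5, -1.5])))  # divided by 1.5 -> [0.3333, -1.0000]
print(soften_peaks(torch.tensor([0.5, -4.0])))  # divided by 2 only -> [0.25, -2.0], still clips
```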
@@ -546,10 +582,11 @@ def pack_ogg(audio_bytes, data, rate):
     # Or split the whole audio data into smaller audio segment to avoid stack overflow?
     def handle_pack_ogg():
-        with sf.SoundFile(audio_bytes, mode='w', samplerate=rate, channels=1, format='ogg') as audio_file:
+        with sf.SoundFile(audio_bytes, mode="w", samplerate=rate, channels=1, format="ogg") as audio_file:
             audio_file.write(data)
     import threading
     # See: https://docs.python.org/3/library/threading.html
     # The stack size of this thread is at least 32768
     # If stack overflow error still occurs, just modify the `stack_size`.
@@ -581,35 +618,47 @@ def pack_raw(audio_bytes, data, rate):
 def pack_wav(audio_bytes, rate):
     if is_int32:
-        data = np.frombuffer(audio_bytes.getvalue(),dtype=np.int32)
+        data = np.frombuffer(audio_bytes.getvalue(), dtype=np.int32)
         wav_bytes = BytesIO()
-        sf.write(wav_bytes, data, rate, format='WAV', subtype='PCM_32')
+        sf.write(wav_bytes, data, rate, format="WAV", subtype="PCM_32")
     else:
-        data = np.frombuffer(audio_bytes.getvalue(),dtype=np.int16)
+        data = np.frombuffer(audio_bytes.getvalue(), dtype=np.int16)
         wav_bytes = BytesIO()
-        sf.write(wav_bytes, data, rate, format='WAV')
+        sf.write(wav_bytes, data, rate, format="WAV")
     return wav_bytes
 def pack_aac(audio_bytes, data, rate):
     if is_int32:
-        pcm = 's32le'
-        bit_rate = '256k'
+        pcm = "s32le"
+        bit_rate = "256k"
     else:
-        pcm = 's16le'
-        bit_rate = '128k'
-    process = subprocess.Popen([
-        'ffmpeg',
-        '-f', pcm,  # input: signed little-endian PCM
-        '-ar', str(rate),  # sample rate
-        '-ac', '1',  # mono
-        '-i', 'pipe:0',  # read input from the stdin pipe
-        '-c:a', 'aac',  # encode audio as AAC
-        '-b:a', bit_rate,  # bitrate
-        '-vn',  # no video
-        '-f', 'adts',  # output an ADTS AAC stream
-        'pipe:1'  # write output to the stdout pipe
-    ], stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+        pcm = "s16le"
+        bit_rate = "128k"
+    process = subprocess.Popen(
+        [
+            "ffmpeg",
+            "-f",
+            pcm,  # input: signed little-endian PCM
+            "-ar",
+            str(rate),  # sample rate
+            "-ac",
+            "1",  # mono
+            "-i",
+            "pipe:0",  # read input from the stdin pipe
+            "-c:a",
+            "aac",  # encode audio as AAC
+            "-b:a",
+            bit_rate,  # bitrate
+            "-vn",  # no video
+            "-f",
+            "adts",  # output an ADTS AAC stream
+            "pipe:1",  # write output to the stdout pipe
+        ],
+        stdin=subprocess.PIPE,
+        stdout=subprocess.PIPE,
+        stderr=subprocess.PIPE,
+    )
     out, _ = process.communicate(input=data.tobytes())
     audio_bytes.write(out)
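For context on `pack_aac()`: raw PCM is piped into an ffmpeg subprocess on stdin and an ADTS AAC stream is read back from stdout. A minimal sketch of the same pipeline, assuming an `ffmpeg` binary on PATH (the sample rate and bitrate here are example values, not the API's configuration):

```python
import subprocess
import numpy as np

def pcm16_to_aac(pcm: np.ndarray, rate: int = 24000) -> bytes:
    # feed s16le mono PCM to ffmpeg and collect the encoded AAC bytes
    process = subprocess.Popen(
        ["ffmpeg", "-f", "s16le", "-ar", str(rate), "-ac", "1", "-i", "pipe:0",
         "-c:a", "aac", "-b:a", "128k", "-vn", "-f", "adts", "pipe:1"],
        stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE,
    )
    out, _ = process.communicate(input=pcm.tobytes())
    return out

# one second of a 440 Hz tone as int16 PCM
tone = (np.sin(2 * np.pi * 440 * np.arange(24000) / 24000) * 32767).astype(np.int16)
aac = pcm16_to_aac(tone)
print(len(aac), "bytes of AAC")
```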
@@ -632,7 +681,7 @@ def cut_text(text, punc):
         items = re.split(f"({punds})", text)
         mergeitems = ["".join(group) for group in zip(items[::2], items[1::2])]
         # keep the text intact when a sentence has no punctuation, or none at its end
-        if len(items)%2 == 1:
+        if len(items) % 2 == 1:
             mergeitems.append(items[-1])
         text = "\n".join(mergeitems)
@@ -646,8 +695,38 @@ def only_punc(text):
     return not any(t.isalnum() or t.isalpha() for t in text)
-splits = {",", "。", "?", "!", ",", ".", "?", "!", "~", ":", ":", "—", "…", }
-def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, top_k= 15, top_p = 0.6, temperature = 0.6, speed = 1, inp_refs = None, sample_steps = 32, if_sr = False, spk = "default"):
+splits = {
+    ",",
+    "。",
+    "?",
+    "!",
+    ",",
+    ".",
+    "?",
+    "!",
+    "~",
+    ":",
+    ":",
+    "—",
+    "…",
+}
+def get_tts_wav(
+    ref_wav_path,
+    prompt_text,
+    prompt_language,
+    text,
+    text_language,
+    top_k=15,
+    top_p=0.6,
+    temperature=0.6,
+    speed=1,
+    inp_refs=None,
+    sample_steps=32,
+    if_sr=False,
+    spk="default",
+):
     infer_sovits = speaker_list[spk].sovits
     vq_model = infer_sovits.vq_model
     hps = infer_sovits.hps
@@ -659,7 +738,8 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
     t0 = ttime()
     prompt_text = prompt_text.strip("\n")
-    if (prompt_text[-1] not in splits): prompt_text += "。" if prompt_language != "en" else "."
+    if prompt_text[-1] not in splits:
+        prompt_text += "。" if prompt_language != "en" else "."
     prompt_language, text = prompt_language, text.strip("\n")
     dtype = torch.float16 if is_half == True else torch.float32
     zero_wav = np.zeros(int(hps.data.sampling_rate * 0.3), dtype=np.float16 if is_half == True else np.float32)
@@ -667,7 +747,7 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
         wav16k, sr = librosa.load(ref_wav_path, sr=16000)
         wav16k = torch.from_numpy(wav16k)
         zero_wav_torch = torch.from_numpy(zero_wav)
-        if (is_half == True):
+        if is_half == True:
             wav16k = wav16k.half().to(device)
             zero_wav_torch = zero_wav_torch.half().to(device)
         else:
@@ -680,15 +760,15 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
         prompt = prompt_semantic.unsqueeze(0).to(device)
     if version != "v3":
-        refers=[]
-        if(inp_refs):
+        refers = []
+        if inp_refs:
             for path in inp_refs:
                 try:
                     refer = get_spepc(hps, path).to(dtype).to(device)
                     refers.append(refer)
                 except Exception as e:
                     logger.error(e)
-        if(len(refers)==0):
+        if len(refers) == 0:
             refers = [get_spepc(hps, ref_wav_path).to(dtype).to(device)]
     else:
         refer = get_spepc(hps, ref_wav_path).to(device).to(dtype)
@@ -707,7 +787,8 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
             continue
         audio_opt = []
-        if (text[-1] not in splits): text += "。" if text_language != "en" else "."
+        if text[-1] not in splits:
+            text += "。" if text_language != "en" else "."
         phones2, bert2, norm_text2 = get_phones_and_bert(text, text_language, version)
         bert = torch.cat([bert1, bert2], 1)
@@ -722,56 +803,62 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
                 prompt,
                 bert,
                 # prompt_phone_len=ph_offset,
-                top_k = top_k,
-                top_p = top_p,
-                temperature = temperature,
-                early_stop_num=hz * max_sec)
+                top_k=top_k,
+                top_p=top_p,
+                temperature=temperature,
+                early_stop_num=hz * max_sec,
+            )
         pred_semantic = pred_semantic[:, -idx:].unsqueeze(0)
         t3 = ttime()
         if version != "v3":
-            audio = \
-                vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0),
-                                refers,speed=speed).detach().cpu().numpy()[
-                    0, 0]  ### try reconstructing without the prompt part
+            audio = (
+                vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refers, speed=speed)
+                .detach()
+                .cpu()
+                .numpy()[0, 0]
+            )  ### try reconstructing without the prompt part
         else:
-            phoneme_ids0=torch.LongTensor(phones1).to(device).unsqueeze(0)
-            phoneme_ids1=torch.LongTensor(phones2).to(device).unsqueeze(0)
+            phoneme_ids0 = torch.LongTensor(phones1).to(device).unsqueeze(0)
+            phoneme_ids1 = torch.LongTensor(phones2).to(device).unsqueeze(0)
             # print(11111111, phoneme_ids0, phoneme_ids1)
-            fea_ref,ge = vq_model.decode_encp(prompt.unsqueeze(0), phoneme_ids0, refer)
+            fea_ref, ge = vq_model.decode_encp(prompt.unsqueeze(0), phoneme_ids0, refer)
             ref_audio, sr = torchaudio.load(ref_wav_path)
-            ref_audio=ref_audio.to(device).float()
-            if (ref_audio.shape[0] == 2):
+            ref_audio = ref_audio.to(device).float()
+            if ref_audio.shape[0] == 2:
                 ref_audio = ref_audio.mean(0).unsqueeze(0)
-            if sr!=24000:
-                ref_audio=resample(ref_audio,sr)
+            if sr != 24000:
+                ref_audio = resample(ref_audio, sr)
             # print("ref_audio",ref_audio.abs().mean())
             mel2 = mel_fn(ref_audio)
             mel2 = norm_spec(mel2)
             T_min = min(mel2.shape[2], fea_ref.shape[2])
             mel2 = mel2[:, :, :T_min]
             fea_ref = fea_ref[:, :, :T_min]
-            if (T_min > 468):
+            if T_min > 468:
                 mel2 = mel2[:, :, -468:]
                 fea_ref = fea_ref[:, :, -468:]
                 T_min = 468
             chunk_len = 934 - T_min
             # print("fea_ref",fea_ref,fea_ref.shape)
             # print("mel2",mel2)
-            mel2=mel2.to(dtype)
-            fea_todo, ge = vq_model.decode_encp(pred_semantic, phoneme_ids1, refer, ge,speed)
+            mel2 = mel2.to(dtype)
+            fea_todo, ge = vq_model.decode_encp(pred_semantic, phoneme_ids1, refer, ge, speed)
             # print("fea_todo",fea_todo)
             # print("ge",ge.abs().mean())
             cfm_resss = []
             idx = 0
-            while (1):
-                fea_todo_chunk = fea_todo[:, :, idx:idx + chunk_len]
-                if (fea_todo_chunk.shape[-1] == 0): break
+            while 1:
+                fea_todo_chunk = fea_todo[:, :, idx : idx + chunk_len]
+                if fea_todo_chunk.shape[-1] == 0:
+                    break
                 idx += chunk_len
                 fea = torch.cat([fea_ref, fea_todo_chunk], 2).transpose(2, 1)
                 # set_seed(123)
-                cfm_res = vq_model.cfm.inference(fea, torch.LongTensor([fea.size(1)]).to(fea.device), mel2, sample_steps, inference_cfg_rate=0)
-                cfm_res = cfm_res[:, :, mel2.shape[2]:]
+                cfm_res = vq_model.cfm.inference(
+                    fea, torch.LongTensor([fea.size(1)]).to(fea.device), mel2, sample_steps, inference_cfg_rate=0
+                )
+                cfm_res = cfm_res[:, :, mel2.shape[2] :]
                 mel2 = cfm_res[:, :, -T_min:]
                 # print("fea", fea)
                 # print("mel2in", mel2)
@@ -779,14 +866,15 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
                 cfm_resss.append(cfm_res)
             cmf_res = torch.cat(cfm_resss, 2)
             cmf_res = denorm_spec(cmf_res)
-            if bigvgan_model==None:init_bigvgan()
+            if bigvgan_model == None:
+                init_bigvgan()
             with torch.inference_mode():
                 wav_gen = bigvgan_model(cmf_res)
-                audio=wav_gen[0][0].cpu().detach().numpy()
+                audio = wav_gen[0][0].cpu().detach().numpy()
-        max_audio=np.abs(audio).max()
-        if max_audio>1:
-            audio/=max_audio
+        max_audio = np.abs(audio).max()
+        if max_audio > 1:
+            audio /= max_audio
         audio_opt.append(audio)
         audio_opt.append(zero_wav)
         audio_opt = np.concatenate(audio_opt, 0)
@@ -795,29 +883,29 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
     sr = hps.data.sampling_rate if version != "v3" else 24000
     if if_sr and sr == 24000:
         audio_opt = torch.from_numpy(audio_opt).float().to(device)
-        audio_opt,sr=audio_sr(audio_opt.unsqueeze(0),sr)
-        max_audio=np.abs(audio_opt).max()
-        if max_audio > 1: audio_opt /= max_audio
+        audio_opt, sr = audio_sr(audio_opt.unsqueeze(0), sr)
+        max_audio = np.abs(audio_opt).max()
+        if max_audio > 1:
+            audio_opt /= max_audio
         sr = 48000
     if is_int32:
-        audio_bytes = pack_audio(audio_bytes,(audio_opt * 2147483647).astype(np.int32),sr)
+        audio_bytes = pack_audio(audio_bytes, (audio_opt * 2147483647).astype(np.int32), sr)
     else:
-        audio_bytes = pack_audio(audio_bytes,(audio_opt * 32768).astype(np.int16),sr)
-    # logger.info("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3))
+        audio_bytes = pack_audio(audio_bytes, (audio_opt * 32768).astype(np.int16), sr)
+    # logger.info("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3))
     if stream_mode == "normal":
         audio_bytes, audio_chunk = read_clean_buffer(audio_bytes)
         yield audio_chunk
-    if not stream_mode == "normal":
+    if not stream_mode == "normal":
         if media_type == "wav":
             sr = 48000 if if_sr else 24000
             sr = hps.data.sampling_rate if version != "v3" else sr
-            audio_bytes = pack_wav(audio_bytes,sr)
+            audio_bytes = pack_wav(audio_bytes, sr)
         yield audio_bytes.getvalue()
 def handle_control(command):
     if command == "restart":
         os.execl(g_config.python_exec, g_config.python_exec, *sys.argv)
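Background for the streaming branch above: `get_tts_wav()` is a generator (it yields packed audio chunks), which is what lets `handle()` wrap it in FastAPI's `StreamingResponse`. A minimal standalone analogue of that arrangement (hypothetical `/demo` endpoint, not part of api.py):

```python
from fastapi import FastAPI
from fastapi.responses import StreamingResponse

app = FastAPI()

def fake_audio_chunks():
    # stand-in for the packed audio chunks get_tts_wav() yields
    for i in range(3):
        yield bytes([i]) * 1024

@app.get("/demo")
async def demo():
    # FastAPI drains the generator and streams each chunk to the client
    return StreamingResponse(fake_audio_chunks(), media_type="audio/wav")
```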
@@ -828,7 +916,9 @@ def handle_control(command):
 def handle_change(path, text, language):
     if is_empty(path, text, language):
-        return JSONResponse({"code": 400, "message": '缺少任意一项以下参数: "path", "text", "language"'}, status_code=400)
+        return JSONResponse(
+            {"code": 400, "message": '缺少任意一项以下参数: "path", "text", "language"'}, status_code=400
+        )
     if path != "" or path is not None:
         default_refer.path = path
@@ -842,15 +932,31 @@ def handle_change(path, text, language):
     logger.info(f"当前默认参考音频语种: {default_refer.language}")
     logger.info(f"is_ready: {default_refer.is_ready()}")
     return JSONResponse({"code": 0, "message": "Success"}, status_code=200)
-def handle(refer_wav_path, prompt_text, prompt_language, text, text_language, cut_punc, top_k, top_p, temperature, speed, inp_refs, sample_steps, if_sr):
+def handle(
+    refer_wav_path,
+    prompt_text,
+    prompt_language,
+    text,
+    text_language,
+    cut_punc,
+    top_k,
+    top_p,
+    temperature,
+    speed,
+    inp_refs,
+    sample_steps,
+    if_sr,
+):
     if (
-        refer_wav_path == "" or refer_wav_path is None
-        or prompt_text == "" or prompt_text is None
-        or prompt_language == "" or prompt_language is None
+        refer_wav_path == ""
+        or refer_wav_path is None
+        or prompt_text == ""
+        or prompt_text is None
+        or prompt_language == ""
+        or prompt_language is None
     ):
         refer_wav_path, prompt_text, prompt_language = (
             default_refer.path,
@@ -860,17 +966,31 @@ def handle(refer_wav_path, prompt_text, prompt_language, text, text_language, cu
         if not default_refer.is_ready():
             return JSONResponse({"code": 400, "message": "未指定参考音频且接口无预设"}, status_code=400)
-    if not sample_steps in [4,8,16,32]:
+    if sample_steps not in [4, 8, 16, 32]:
         sample_steps = 32
     if cut_punc == None:
-        text = cut_text(text,default_cut_punc)
+        text = cut_text(text, default_cut_punc)
     else:
-        text = cut_text(text,cut_punc)
-    return StreamingResponse(get_tts_wav(refer_wav_path, prompt_text, prompt_language, text, text_language, top_k, top_p, temperature, speed, inp_refs, sample_steps, if_sr), media_type="audio/"+media_type)
+        text = cut_text(text, cut_punc)
+    return StreamingResponse(
+        get_tts_wav(
+            refer_wav_path,
+            prompt_text,
+            prompt_language,
+            text,
+            text_language,
+            top_k,
+            top_p,
+            temperature,
+            speed,
+            inp_refs,
+            sample_steps,
+            if_sr,
+        ),
+        media_type="audio/" + media_type,
+    )
 # --------------------------------
@@ -886,7 +1006,7 @@ dict_language = {
     "粤英混合": "yue",
     "日英混合": "ja",
     "韩英混合": "ko",
-    "多语种混合": "auto",  #multilingual: segment and detect language automatically
+    "多语种混合": "auto",  # multilingual: segment and detect language automatically
     "多语种混合(粤语)": "auto_yue",
     "all_zh": "all_zh",
     "all_yue": "all_yue",
@@ -903,7 +1023,7 @@ dict_language = {
 # logger
 logging.config.dictConfig(uvicorn.config.LOGGING_CONFIG)
-logger = logging.getLogger('uvicorn')
+logger = logging.getLogger("uvicorn")
 # load configuration
 g_config = global_config.Config()
@@ -919,8 +1039,12 @@ parser.add_argument("-dl", "--default_refer_language", type=str, default="", hel
 parser.add_argument("-d", "--device", type=str, default=g_config.infer_device, help="cuda / cpu")
 parser.add_argument("-a", "--bind_addr", type=str, default="0.0.0.0", help="default: 0.0.0.0")
 parser.add_argument("-p", "--port", type=int, default=g_config.api_port, help="default: 9880")
-parser.add_argument("-fp", "--full_precision", action="store_true", default=False, help="覆盖config.is_half为False, 使用全精度")
-parser.add_argument("-hp", "--half_precision", action="store_true", default=False, help="覆盖config.is_half为True, 使用半精度")
+parser.add_argument(
+    "-fp", "--full_precision", action="store_true", default=False, help="覆盖config.is_half为False, 使用全精度"
+)
+parser.add_argument(
+    "-hp", "--half_precision", action="store_true", default=False, help="覆盖config.is_half为True, 使用半精度"
+)
 # boolean flags are used like `python ./api.py -fp ...`
 # in that case full_precision==True, half_precision==False
 parser.add_argument("-sm", "--stream_mode", type=str, default="close", help="流式返回模式, close / normal / keepalive")
@@ -972,14 +1096,14 @@ if args.full_precision and args.half_precision:
 logger.info(f"半精: {is_half}")
 # streaming mode
-if args.stream_mode.lower() in ["normal","n"]:
+if args.stream_mode.lower() in ["normal", "n"]:
     stream_mode = "normal"
     logger.info("流式返回已开启")
 else:
     stream_mode = "close"
 # audio encoding format
-if args.media_type.lower() in ["aac","ogg"]:
+if args.media_type.lower() in ["aac", "ogg"]:
     media_type = args.media_type.lower()
 elif stream_mode == "close":
     media_type = "wav"
@@ -988,12 +1112,12 @@ else:
 logger.info(f"编码格式: {media_type}")
 # audio sample type
-if args.sub_type.lower() == 'int32':
+if args.sub_type.lower() == "int32":
     is_int32 = True
-    logger.info(f"数据类型: int32")
+    logger.info("数据类型: int32")
 else:
     is_int32 = False
-    logger.info(f"数据类型: int16")
+    logger.info("数据类型: int16")
 # initialize models
 cnhubert.cnhubert_base_path = cnhubert_base_path
@@ -1006,8 +1130,7 @@ if is_half:
 else:
     bert_model = bert_model.to(device)
     ssl_model = ssl_model.to(device)
-change_gpt_sovits_weights(gpt_path = gpt_path, sovits_path = sovits_path)
+change_gpt_sovits_weights(gpt_path=gpt_path, sovits_path=sovits_path)
 # --------------------------------
@@ -1015,21 +1138,21 @@ change_gpt_sovits_weights(gpt_path = gpt_path, sovits_path = sovits_path)
 # --------------------------------
 app = FastAPI()
 @app.post("/set_model")
 async def set_model(request: Request):
     json_post_raw = await request.json()
     return change_gpt_sovits_weights(
-        gpt_path = json_post_raw.get("gpt_model_path"),
-        sovits_path = json_post_raw.get("sovits_model_path")
+        gpt_path=json_post_raw.get("gpt_model_path"), sovits_path=json_post_raw.get("sovits_model_path")
     )
 @app.get("/set_model")
 async def set_model(
-    gpt_model_path: str = None,
-    sovits_model_path: str = None,
+    gpt_model_path: str = None,
+    sovits_model_path: str = None,
 ):
-    return change_gpt_sovits_weights(gpt_path = gpt_model_path, sovits_path = sovits_model_path)
+    return change_gpt_sovits_weights(gpt_path=gpt_model_path, sovits_path=sovits_model_path)
 @app.post("/control")
@@ -1047,18 +1170,12 @@ async def control(command: str = None):
 async def change_refer(request: Request):
     json_post_raw = await request.json()
     return handle_change(
-        json_post_raw.get("refer_wav_path"),
-        json_post_raw.get("prompt_text"),
-        json_post_raw.get("prompt_language")
+        json_post_raw.get("refer_wav_path"), json_post_raw.get("prompt_text"), json_post_raw.get("prompt_language")
     )
 @app.get("/change_refer")
-async def change_refer(
-    refer_wav_path: str = None,
-    prompt_text: str = None,
-    prompt_language: str = None
-):
+async def change_refer(refer_wav_path: str = None, prompt_text: str = None, prompt_language: str = None):
     return handle_change(refer_wav_path, prompt_text, prompt_language)
@@ -1078,27 +1195,41 @@ async def tts_endpoint(request: Request):
         json_post_raw.get("speed", 1.0),
         json_post_raw.get("inp_refs", []),
         json_post_raw.get("sample_steps", 32),
-        json_post_raw.get("if_sr", False)
+        json_post_raw.get("if_sr", False),
     )
 @app.get("/")
 async def tts_endpoint(
-    refer_wav_path: str = None,
-    prompt_text: str = None,
-    prompt_language: str = None,
-    text: str = None,
-    text_language: str = None,
-    cut_punc: str = None,
-    top_k: int = 15,
-    top_p: float = 1.0,
-    temperature: float = 1.0,
-    speed: float = 1.0,
-    inp_refs: list = Query(default=[]),
-    sample_steps: int = 32,
-    if_sr: bool = False
+    refer_wav_path: str = None,
+    prompt_text: str = None,
+    prompt_language: str = None,
+    text: str = None,
+    text_language: str = None,
+    cut_punc: str = None,
+    top_k: int = 15,
+    top_p: float = 1.0,
+    temperature: float = 1.0,
+    speed: float = 1.0,
+    inp_refs: list = Query(default=[]),
+    sample_steps: int = 32,
+    if_sr: bool = False,
 ):
-    return handle(refer_wav_path, prompt_text, prompt_language, text, text_language, cut_punc, top_k, top_p, temperature, speed, inp_refs, sample_steps, if_sr)
+    return handle(
+        refer_wav_path,
+        prompt_text,
+        prompt_language,
+        text,
+        text_language,
+        cut_punc,
+        top_k,
+        top_p,
+        temperature,
+        speed,
+        inp_refs,
+        sample_steps,
+        if_sr,
+    )
 if __name__ == "__main__":
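Finally, a hedged usage sketch for the reformatted GET "/" endpoint, assuming a server on the default port 9880 (per the argparse defaults above) running in the default non-streaming WAV mode; paths and texts are placeholders:

```python
import requests

params = {
    "refer_wav_path": "ref.wav",      # hypothetical reference clip
    "prompt_text": "reference text",
    "prompt_language": "en",
    "text": "Text to synthesize.",
    "text_language": "en",
    "sample_steps": 32,
}
resp = requests.get("http://127.0.0.1:9880/", params=params)
resp.raise_for_status()
with open("out.wav", "wb") as f:
    f.write(resp.content)  # WAV bytes in the default close/wav mode
```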