Adapt api_v2 and inference_webui_fast to V3 (#2188)

* modified:   GPT_SoVITS/TTS_infer_pack/TTS.py
	modified:   GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py
	modified:   GPT_SoVITS/inference_webui_fast.py

* Adapt to V3

* V3 adaptation of api_v2.py and inference_webui_fast.py

* Fixed a long-standing bug and added friendlier prompt messages

* Optimized the WebUI

* Changed to the correct path

* Fixed the loading of v3 LoRA models

* Fixed an encoding mismatch encountered when reading the tts_infer.yaml file (a sketch of this kind of fix follows below)
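
The encoding fix mentioned in the last item is not part of the diff excerpt below. Here is a minimal sketch of that kind of change, assuming the config lives at GPT_SoVITS/configs/tts_infer.yaml (a path not shown in this commit) and is parsed with PyYAML:

```python
import yaml

# Assumed config location; the actual path used by the pipeline may differ.
config_path = "GPT_SoVITS/configs/tts_infer.yaml"

# Reading with an explicit encoding="utf-8" avoids the mismatch that occurs
# when the YAML contains Chinese text but the platform's default codec
# (e.g. GBK on a Chinese-locale Windows) is used to decode it.
with open(config_path, "r", encoding="utf-8") as f:
    tts_infer_config = yaml.safe_load(f)
```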
ChasonJiang authored this commit on 2025-03-26 14:34:51 +08:00; committed by GitHub.
parent 165882d64f, commit 7394dc7b0c
12 changed files with 486 additions and 146 deletions


@@ -7,7 +7,7 @@
全部按日文识别
'''
import random
import os, re, logging
import os, re, logging, json
import sys
now_dir = os.getcwd()
sys.path.append(now_dir)
@@ -44,7 +44,7 @@ bert_path = os.environ.get("bert_path", None)
version=os.environ.get("version","v2")
import gradio as gr
from TTS_infer_pack.TTS import TTS, TTS_Config
from TTS_infer_pack.TTS import TTS, TTS_Config, NO_PROMPT_ERROR
from TTS_infer_pack.text_segmentation_method import get_method
from tools.i18n.i18n import I18nAuto, scan_language_list
@@ -62,6 +62,9 @@ if torch.cuda.is_available():
else:
device = "cpu"
# is_half = False
# device = "cpu"
dict_language_v1 = {
i18n("中文"): "all_zh",#全部按中文识别
i18n("英文"): "en",#全部按英文识别#######不变
@@ -123,11 +126,11 @@ def inference(text, text_lang,
speed_factor, ref_text_free,
split_bucket,fragment_interval,
seed, keep_random, parallel_infer,
repetition_penalty
repetition_penalty, sample_steps, super_sampling,
):
seed = -1 if keep_random else seed
actual_seed = seed if seed not in [-1, "", None] else random.randrange(1 << 32)
actual_seed = seed if seed not in [-1, "", None] else random.randint(0, 2**32 - 1)
inputs={
"text": text,
"text_lang": dict_language[text_lang],
@@ -147,9 +150,14 @@ def inference(text, text_lang,
"seed":actual_seed,
"parallel_infer": parallel_infer,
"repetition_penalty": repetition_penalty,
"sample_steps": int(sample_steps),
"super_sampling": super_sampling,
}
for item in tts_pipeline.run(inputs):
yield item, actual_seed
try:
for item in tts_pipeline.run(inputs):
yield item, actual_seed
except NO_PROMPT_ERROR:
gr.Warning(i18n('V3不支持无参考文本模式请填写参考文本'))
def custom_sort_key(s):
# 使用正则表达式提取字符串中的数字部分和非数字部分
@@ -163,19 +171,38 @@ def change_choices():
SoVITS_names, GPT_names = get_weights_names(GPT_weight_root, SoVITS_weight_root)
return {"choices": sorted(SoVITS_names, key=custom_sort_key), "__type__": "update"}, {"choices": sorted(GPT_names, key=custom_sort_key), "__type__": "update"}
path_sovits_v3="GPT_SoVITS/pretrained_models/s2Gv3.pth"
pretrained_sovits_name=["GPT_SoVITS/pretrained_models/s2G488k.pth", "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth",path_sovits_v3]
pretrained_gpt_name=["GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt","GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt", "GPT_SoVITS/pretrained_models/s1v3.ckpt"]
pretrained_sovits_name=["GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth", "GPT_SoVITS/pretrained_models/s2G488k.pth"]
pretrained_gpt_name=["GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt", "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt"]
_ =[[],[]]
for i in range(2):
if os.path.exists(pretrained_gpt_name[i]):
_[0].append(pretrained_gpt_name[i])
if os.path.exists(pretrained_sovits_name[i]):
_[-1].append(pretrained_sovits_name[i])
for i in range(3):
if os.path.exists(pretrained_gpt_name[i]):_[0].append(pretrained_gpt_name[i])
if os.path.exists(pretrained_sovits_name[i]):_[-1].append(pretrained_sovits_name[i])
pretrained_gpt_name,pretrained_sovits_name = _
SoVITS_weight_root=["SoVITS_weights_v2","SoVITS_weights"]
GPT_weight_root=["GPT_weights_v2","GPT_weights"]
if os.path.exists(f"./weight.json"):
pass
else:
with open(f"./weight.json", 'w', encoding="utf-8") as file:json.dump({'GPT':{},'SoVITS':{}},file)
with open(f"./weight.json", 'r', encoding="utf-8") as file:
weight_data = file.read()
weight_data=json.loads(weight_data)
gpt_path = os.environ.get(
"gpt_path", weight_data.get('GPT',{}).get(version,pretrained_gpt_name))
sovits_path = os.environ.get(
"sovits_path", weight_data.get('SoVITS',{}).get(version,pretrained_sovits_name))
if isinstance(gpt_path,list):
gpt_path = gpt_path[0]
if isinstance(sovits_path,list):
sovits_path = sovits_path[0]
SoVITS_weight_root=["SoVITS_weights","SoVITS_weights_v2","SoVITS_weights_v3"]
GPT_weight_root=["GPT_weights","GPT_weights_v2","GPT_weights_v3"]
for path in SoVITS_weight_root+GPT_weight_root:
os.makedirs(path,exist_ok=True)
@@ -194,10 +221,18 @@ def get_weights_names(GPT_weight_root, SoVITS_weight_root):
SoVITS_names, GPT_names = get_weights_names(GPT_weight_root, SoVITS_weight_root)
from process_ckpt import get_sovits_version_from_path_fast
def change_sovits_weights(sovits_path,prompt_language=None,text_language=None):
tts_pipeline.init_vits_weights(sovits_path)
global version, dict_language
version, model_version, if_lora_v3=get_sovits_version_from_path_fast(sovits_path)
if if_lora_v3 and not os.path.exists(path_sovits_v3):
info= path_sovits_v3 + i18n("SoVITS V3 底模缺失,无法加载相应 LoRA 权重")
gr.Warning(info)
raise FileExistsError(info)
tts_pipeline.init_vits_weights(sovits_path)
dict_language = dict_language_v1 if tts_pipeline.configs.version =='v1' else dict_language_v2
if prompt_language is not None and text_language is not None:
if prompt_language in list(dict_language.keys()):
@@ -210,9 +245,19 @@ def change_sovits_weights(sovits_path,prompt_language=None,text_language=None):
else:
text_update = {'__type__':'update', 'value':''}
text_language_update = {'__type__':'update', 'value':i18n("中文")}
return {'__type__':'update', 'choices':list(dict_language.keys())}, {'__type__':'update', 'choices':list(dict_language.keys())}, prompt_text_update, prompt_language_update, text_update, text_language_update
if model_version=="v3":
visible_sample_steps=True
visible_inp_refs=False
else:
visible_sample_steps=False
visible_inp_refs=True
yield {'__type__':'update', 'choices':list(dict_language.keys())}, {'__type__':'update', 'choices':list(dict_language.keys())}, prompt_text_update, prompt_language_update, text_update, text_language_update,{"__type__": "update", "visible": visible_sample_steps},{"__type__": "update", "visible": visible_inp_refs},{"__type__": "update", "value": False,"interactive":True if model_version!="v3"else False},{"__type__": "update", "visible":True if model_version=="v3"else False}
with open("./weight.json")as f:
data=f.read()
data=json.loads(data)
data["SoVITS"][version]=sovits_path
with open("./weight.json","w")as f:f.write(json.dumps(data))
with gr.Blocks(title="GPT-SoVITS WebUI") as app:
gr.Markdown(
@@ -257,13 +302,19 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
with gr.Row():
with gr.Column():
batch_size = gr.Slider(minimum=1,maximum=200,step=1,label=i18n("batch_size"),value=20,interactive=True)
fragment_interval = gr.Slider(minimum=0.01,maximum=1,step=0.01,label=i18n("分段间隔(秒)"),value=0.3,interactive=True)
speed_factor = gr.Slider(minimum=0.6,maximum=1.65,step=0.05,label="speed_factor",value=1.0,interactive=True)
top_k = gr.Slider(minimum=1,maximum=100,step=1,label=i18n("top_k"),value=5,interactive=True)
top_p = gr.Slider(minimum=0,maximum=1,step=0.05,label=i18n("top_p"),value=1,interactive=True)
temperature = gr.Slider(minimum=0,maximum=1,step=0.05,label=i18n("temperature"),value=1,interactive=True)
repetition_penalty = gr.Slider(minimum=0,maximum=2,step=0.05,label=i18n("重复惩罚"),value=1.35,interactive=True)
with gr.Row():
batch_size = gr.Slider(minimum=1,maximum=200,step=1,label=i18n("batch_size"),value=20,interactive=True)
sample_steps = gr.Radio(label=i18n("采样步数(仅对V3生效)"),value=32,choices=[4,8,16,32],visible=True)
with gr.Row():
fragment_interval = gr.Slider(minimum=0.01,maximum=1,step=0.01,label=i18n("分段间隔(秒)"),value=0.3,interactive=True)
speed_factor = gr.Slider(minimum=0.6,maximum=1.65,step=0.05,label="语速",value=1.0,interactive=True)
with gr.Row():
top_k = gr.Slider(minimum=1,maximum=100,step=1,label=i18n("top_k"),value=5,interactive=True)
top_p = gr.Slider(minimum=0,maximum=1,step=0.05,label=i18n("top_p"),value=1,interactive=True)
with gr.Row():
temperature = gr.Slider(minimum=0,maximum=1,step=0.05,label=i18n("temperature"),value=1,interactive=True)
repetition_penalty = gr.Slider(minimum=0,maximum=2,step=0.05,label=i18n("重复惩罚"),value=1.35,interactive=True)
with gr.Column():
with gr.Row():
how_to_cut = gr.Dropdown(
@@ -272,10 +323,14 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
value=i18n("凑四句一切"),
interactive=True, scale=1
)
super_sampling = gr.Checkbox(label=i18n("音频超采样(仅对V3生效))"), value=False, interactive=True, show_label=True)
with gr.Row():
parallel_infer = gr.Checkbox(label=i18n("并行推理"), value=True, interactive=True, show_label=True)
split_bucket = gr.Checkbox(label=i18n("数据分桶(并行推理时会降低一点计算量)"), value=True, interactive=True, show_label=True)
with gr.Row():
seed = gr.Number(label=i18n("随机种子"),value=-1)
keep_random = gr.Checkbox(label=i18n("保持随机"), value=True, interactive=True, show_label=True)
@@ -295,7 +350,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
speed_factor, ref_text_free,
split_bucket,fragment_interval,
seed, keep_random, parallel_infer,
repetition_penalty
repetition_penalty, sample_steps, super_sampling,
],
[output, seed],
)
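
As a usage note, here is a minimal sketch of driving the updated pipeline with the two new parameters and the NO_PROMPT_ERROR guard added in this commit. Only the keys that appear in the diff above are taken from this commit; the reference audio, prompt text, and config path are placeholders, and the remaining keys follow the project's existing input schema and are assumptions here:

```python
from TTS_infer_pack.TTS import TTS, TTS_Config, NO_PROMPT_ERROR

tts_config = TTS_Config("GPT_SoVITS/configs/tts_infer.yaml")  # assumed config path
tts_pipeline = TTS(tts_config)

inputs = {
    "text": "你好，世界。",                      # placeholder text to synthesize
    "text_lang": "zh",
    "ref_audio_path": "path/to/reference.wav",  # placeholder reference audio
    "prompt_text": "参考音频对应的文本",          # V3 requires a non-empty prompt text
    "prompt_lang": "zh",
    "sample_steps": 32,       # new in this commit: sampling steps, effective for V3 only
    "super_sampling": False,  # new in this commit: audio super-sampling, V3 only
}

try:
    for item in tts_pipeline.run(inputs):
        pass  # each yielded item is a synthesized audio chunk (as consumed by the WebUI)
except NO_PROMPT_ERROR:
    # Raised when V3 is asked to run in no-reference-text mode.
    print("V3 does not support the no-reference-text mode; please provide prompt text.")
```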