为api_v2和inference_webui_fast适配V3版本 (#2188)
* modified: GPT_SoVITS/TTS_infer_pack/TTS.py modified: GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py modified: GPT_SoVITS/inference_webui_fast.py * 适配V3版本 * api_v2.py和inference_webui_fast.py的v3适配 * 修改了个远古bug,增加了更友好的提示信息 * 优化webui * 修改为正确的path * 修复v3 lora模型的载入问题 * 修复读取tts_infer.yaml文件时遇到的编码不匹配的问题
This commit is contained in:
@@ -7,7 +7,7 @@
|
||||
全部按日文识别
|
||||
'''
|
||||
import random
|
||||
import os, re, logging
|
||||
import os, re, logging, json
|
||||
import sys
|
||||
now_dir = os.getcwd()
|
||||
sys.path.append(now_dir)
|
||||
@@ -44,7 +44,7 @@ bert_path = os.environ.get("bert_path", None)
|
||||
version=os.environ.get("version","v2")
|
||||
|
||||
import gradio as gr
|
||||
from TTS_infer_pack.TTS import TTS, TTS_Config
|
||||
from TTS_infer_pack.TTS import TTS, TTS_Config, NO_PROMPT_ERROR
|
||||
from TTS_infer_pack.text_segmentation_method import get_method
|
||||
from tools.i18n.i18n import I18nAuto, scan_language_list
|
||||
|
||||
@@ -62,6 +62,9 @@ if torch.cuda.is_available():
|
||||
else:
|
||||
device = "cpu"
|
||||
|
||||
# is_half = False
|
||||
# device = "cpu"
|
||||
|
||||
dict_language_v1 = {
|
||||
i18n("中文"): "all_zh",#全部按中文识别
|
||||
i18n("英文"): "en",#全部按英文识别#######不变
|
||||
@@ -123,11 +126,11 @@ def inference(text, text_lang,
|
||||
speed_factor, ref_text_free,
|
||||
split_bucket,fragment_interval,
|
||||
seed, keep_random, parallel_infer,
|
||||
repetition_penalty
|
||||
repetition_penalty, sample_steps, super_sampling,
|
||||
):
|
||||
|
||||
seed = -1 if keep_random else seed
|
||||
actual_seed = seed if seed not in [-1, "", None] else random.randrange(1 << 32)
|
||||
actual_seed = seed if seed not in [-1, "", None] else random.randint(0, 2**32 - 1)
|
||||
inputs={
|
||||
"text": text,
|
||||
"text_lang": dict_language[text_lang],
|
||||
@@ -147,9 +150,14 @@ def inference(text, text_lang,
|
||||
"seed":actual_seed,
|
||||
"parallel_infer": parallel_infer,
|
||||
"repetition_penalty": repetition_penalty,
|
||||
"sample_steps": int(sample_steps),
|
||||
"super_sampling": super_sampling,
|
||||
}
|
||||
for item in tts_pipeline.run(inputs):
|
||||
yield item, actual_seed
|
||||
try:
|
||||
for item in tts_pipeline.run(inputs):
|
||||
yield item, actual_seed
|
||||
except NO_PROMPT_ERROR:
|
||||
gr.Warning(i18n('V3不支持无参考文本模式,请填写参考文本!'))
|
||||
|
||||
def custom_sort_key(s):
|
||||
# 使用正则表达式提取字符串中的数字部分和非数字部分
|
||||
@@ -163,19 +171,38 @@ def change_choices():
|
||||
SoVITS_names, GPT_names = get_weights_names(GPT_weight_root, SoVITS_weight_root)
|
||||
return {"choices": sorted(SoVITS_names, key=custom_sort_key), "__type__": "update"}, {"choices": sorted(GPT_names, key=custom_sort_key), "__type__": "update"}
|
||||
|
||||
path_sovits_v3="GPT_SoVITS/pretrained_models/s2Gv3.pth"
|
||||
pretrained_sovits_name=["GPT_SoVITS/pretrained_models/s2G488k.pth", "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth",path_sovits_v3]
|
||||
pretrained_gpt_name=["GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt","GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt", "GPT_SoVITS/pretrained_models/s1v3.ckpt"]
|
||||
|
||||
pretrained_sovits_name=["GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth", "GPT_SoVITS/pretrained_models/s2G488k.pth"]
|
||||
pretrained_gpt_name=["GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt", "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt"]
|
||||
_ =[[],[]]
|
||||
for i in range(2):
|
||||
if os.path.exists(pretrained_gpt_name[i]):
|
||||
_[0].append(pretrained_gpt_name[i])
|
||||
if os.path.exists(pretrained_sovits_name[i]):
|
||||
_[-1].append(pretrained_sovits_name[i])
|
||||
for i in range(3):
|
||||
if os.path.exists(pretrained_gpt_name[i]):_[0].append(pretrained_gpt_name[i])
|
||||
if os.path.exists(pretrained_sovits_name[i]):_[-1].append(pretrained_sovits_name[i])
|
||||
pretrained_gpt_name,pretrained_sovits_name = _
|
||||
|
||||
SoVITS_weight_root=["SoVITS_weights_v2","SoVITS_weights"]
|
||||
GPT_weight_root=["GPT_weights_v2","GPT_weights"]
|
||||
|
||||
if os.path.exists(f"./weight.json"):
|
||||
pass
|
||||
else:
|
||||
with open(f"./weight.json", 'w', encoding="utf-8") as file:json.dump({'GPT':{},'SoVITS':{}},file)
|
||||
|
||||
with open(f"./weight.json", 'r', encoding="utf-8") as file:
|
||||
weight_data = file.read()
|
||||
weight_data=json.loads(weight_data)
|
||||
gpt_path = os.environ.get(
|
||||
"gpt_path", weight_data.get('GPT',{}).get(version,pretrained_gpt_name))
|
||||
sovits_path = os.environ.get(
|
||||
"sovits_path", weight_data.get('SoVITS',{}).get(version,pretrained_sovits_name))
|
||||
if isinstance(gpt_path,list):
|
||||
gpt_path = gpt_path[0]
|
||||
if isinstance(sovits_path,list):
|
||||
sovits_path = sovits_path[0]
|
||||
|
||||
|
||||
|
||||
SoVITS_weight_root=["SoVITS_weights","SoVITS_weights_v2","SoVITS_weights_v3"]
|
||||
GPT_weight_root=["GPT_weights","GPT_weights_v2","GPT_weights_v3"]
|
||||
for path in SoVITS_weight_root+GPT_weight_root:
|
||||
os.makedirs(path,exist_ok=True)
|
||||
|
||||
@@ -194,10 +221,18 @@ def get_weights_names(GPT_weight_root, SoVITS_weight_root):
|
||||
SoVITS_names, GPT_names = get_weights_names(GPT_weight_root, SoVITS_weight_root)
|
||||
|
||||
|
||||
|
||||
from process_ckpt import get_sovits_version_from_path_fast
|
||||
def change_sovits_weights(sovits_path,prompt_language=None,text_language=None):
|
||||
tts_pipeline.init_vits_weights(sovits_path)
|
||||
global version, dict_language
|
||||
version, model_version, if_lora_v3=get_sovits_version_from_path_fast(sovits_path)
|
||||
|
||||
if if_lora_v3 and not os.path.exists(path_sovits_v3):
|
||||
info= path_sovits_v3 + i18n("SoVITS V3 底模缺失,无法加载相应 LoRA 权重")
|
||||
gr.Warning(info)
|
||||
raise FileExistsError(info)
|
||||
|
||||
tts_pipeline.init_vits_weights(sovits_path)
|
||||
|
||||
dict_language = dict_language_v1 if tts_pipeline.configs.version =='v1' else dict_language_v2
|
||||
if prompt_language is not None and text_language is not None:
|
||||
if prompt_language in list(dict_language.keys()):
|
||||
@@ -210,9 +245,19 @@ def change_sovits_weights(sovits_path,prompt_language=None,text_language=None):
|
||||
else:
|
||||
text_update = {'__type__':'update', 'value':''}
|
||||
text_language_update = {'__type__':'update', 'value':i18n("中文")}
|
||||
return {'__type__':'update', 'choices':list(dict_language.keys())}, {'__type__':'update', 'choices':list(dict_language.keys())}, prompt_text_update, prompt_language_update, text_update, text_language_update
|
||||
|
||||
if model_version=="v3":
|
||||
visible_sample_steps=True
|
||||
visible_inp_refs=False
|
||||
else:
|
||||
visible_sample_steps=False
|
||||
visible_inp_refs=True
|
||||
yield {'__type__':'update', 'choices':list(dict_language.keys())}, {'__type__':'update', 'choices':list(dict_language.keys())}, prompt_text_update, prompt_language_update, text_update, text_language_update,{"__type__": "update", "visible": visible_sample_steps},{"__type__": "update", "visible": visible_inp_refs},{"__type__": "update", "value": False,"interactive":True if model_version!="v3"else False},{"__type__": "update", "visible":True if model_version=="v3"else False}
|
||||
|
||||
with open("./weight.json")as f:
|
||||
data=f.read()
|
||||
data=json.loads(data)
|
||||
data["SoVITS"][version]=sovits_path
|
||||
with open("./weight.json","w")as f:f.write(json.dumps(data))
|
||||
|
||||
with gr.Blocks(title="GPT-SoVITS WebUI") as app:
|
||||
gr.Markdown(
|
||||
@@ -257,13 +302,19 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
|
||||
with gr.Row():
|
||||
|
||||
with gr.Column():
|
||||
batch_size = gr.Slider(minimum=1,maximum=200,step=1,label=i18n("batch_size"),value=20,interactive=True)
|
||||
fragment_interval = gr.Slider(minimum=0.01,maximum=1,step=0.01,label=i18n("分段间隔(秒)"),value=0.3,interactive=True)
|
||||
speed_factor = gr.Slider(minimum=0.6,maximum=1.65,step=0.05,label="speed_factor",value=1.0,interactive=True)
|
||||
top_k = gr.Slider(minimum=1,maximum=100,step=1,label=i18n("top_k"),value=5,interactive=True)
|
||||
top_p = gr.Slider(minimum=0,maximum=1,step=0.05,label=i18n("top_p"),value=1,interactive=True)
|
||||
temperature = gr.Slider(minimum=0,maximum=1,step=0.05,label=i18n("temperature"),value=1,interactive=True)
|
||||
repetition_penalty = gr.Slider(minimum=0,maximum=2,step=0.05,label=i18n("重复惩罚"),value=1.35,interactive=True)
|
||||
with gr.Row():
|
||||
batch_size = gr.Slider(minimum=1,maximum=200,step=1,label=i18n("batch_size"),value=20,interactive=True)
|
||||
sample_steps = gr.Radio(label=i18n("采样步数(仅对V3生效)"),value=32,choices=[4,8,16,32],visible=True)
|
||||
with gr.Row():
|
||||
fragment_interval = gr.Slider(minimum=0.01,maximum=1,step=0.01,label=i18n("分段间隔(秒)"),value=0.3,interactive=True)
|
||||
speed_factor = gr.Slider(minimum=0.6,maximum=1.65,step=0.05,label="语速",value=1.0,interactive=True)
|
||||
with gr.Row():
|
||||
top_k = gr.Slider(minimum=1,maximum=100,step=1,label=i18n("top_k"),value=5,interactive=True)
|
||||
top_p = gr.Slider(minimum=0,maximum=1,step=0.05,label=i18n("top_p"),value=1,interactive=True)
|
||||
with gr.Row():
|
||||
temperature = gr.Slider(minimum=0,maximum=1,step=0.05,label=i18n("temperature"),value=1,interactive=True)
|
||||
repetition_penalty = gr.Slider(minimum=0,maximum=2,step=0.05,label=i18n("重复惩罚"),value=1.35,interactive=True)
|
||||
|
||||
with gr.Column():
|
||||
with gr.Row():
|
||||
how_to_cut = gr.Dropdown(
|
||||
@@ -272,10 +323,14 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
|
||||
value=i18n("凑四句一切"),
|
||||
interactive=True, scale=1
|
||||
)
|
||||
super_sampling = gr.Checkbox(label=i18n("音频超采样(仅对V3生效))"), value=False, interactive=True, show_label=True)
|
||||
|
||||
with gr.Row():
|
||||
parallel_infer = gr.Checkbox(label=i18n("并行推理"), value=True, interactive=True, show_label=True)
|
||||
split_bucket = gr.Checkbox(label=i18n("数据分桶(并行推理时会降低一点计算量)"), value=True, interactive=True, show_label=True)
|
||||
|
||||
with gr.Row():
|
||||
|
||||
seed = gr.Number(label=i18n("随机种子"),value=-1)
|
||||
keep_random = gr.Checkbox(label=i18n("保持随机"), value=True, interactive=True, show_label=True)
|
||||
|
||||
@@ -295,7 +350,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
|
||||
speed_factor, ref_text_free,
|
||||
split_bucket,fragment_interval,
|
||||
seed, keep_random, parallel_infer,
|
||||
repetition_penalty
|
||||
repetition_penalty, sample_steps, super_sampling,
|
||||
],
|
||||
[output, seed],
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user