14 Commits

Author SHA1 Message Date
zzz
7dec5f5bb0 Merge pull request #2460 from L-jasmine/export_v2pro
Optimize torch_script model export
2025-06-13 22:10:11 +08:00
RVC-Boss
1a9b8854ee Merge pull request #2456 from L-jasmine/export_v2pro
export_torch_script.py support v2Pro & v2ProPlus
2025-06-12 23:15:46 +08:00
csh
5c91e66d2e export_torch_script.py support v2Pro & v2ProPlus 2025-06-12 21:53:14 +08:00
RVC-Boss
ed89a02337 修复“修复ge.sum数值可能爆炸的”可能导致的训练爆炸的问题
修复“修复ge.sum数值可能爆炸的”可能导致的训练爆炸的问题
2025-06-11 23:14:52 +08:00
RVC-Boss
cd6de7398e Merge pull request #2449 from KamioRinn/maga
support v4 v2Pro v2ProPlus for api & optimize LangSegmenter
2025-06-11 10:29:39 +08:00
YYuX-1145
dd2b9253aa Update TTS.py (#2450) 2025-06-11 10:28:42 +08:00
KamioRinn
29165eb02e support v4 v2Pro v2ProPlus for api 2025-06-11 02:09:07 +08:00
KamioRinn
746cb536c6 Fix LangSegmenter 2025-06-10 19:18:05 +08:00
Emmanuel Ferdman
0d2f273402 Resolve Python Logger warnings (#2379)
Signed-off-by: Emmanuel Ferdman <emmanuelferdman@gmail.com>
2025-06-10 18:03:23 +08:00
RVC-Boss
d39836b8fa Update Changelog_CN.md 2025-06-10 17:30:06 +08:00
RVC-Boss
2c0436b9ce 修复实验名结尾出现空格在win中路径不正确的问题
修复实验名结尾出现空格在win中路径不正确的问题
2025-06-10 14:58:00 +08:00
RVC-Boss
8056efe4ab 修复ge.sum数值可能爆炸问题
修复ge.sum数值可能爆炸问题
2025-06-09 23:53:16 +08:00
wzy3650
d6b78c927a fix configs error (#2439)
* fix configs error

* fix configs error

---------

Co-authored-by: wangzeyuan <wangzeyuan@agora.io>
Co-authored-by: wangzeyuan <wangzeyuan@shengwang.cn>
2025-06-09 11:25:55 +08:00
RVC-Boss
74e79ae6d6 Delete batch_inference.py 2025-06-07 14:40:30 +08:00
15 changed files with 554 additions and 562 deletions

View File

@@ -354,7 +354,7 @@ class ScaledAdam(BatchedOptimizer):
if ans < 1.0:
first_state["num_clipped"] += 1
if ans < 0.1:
logging.warn(f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}")
logging.warning(f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}")
if self.show_dominant_parameters:
assert p.shape[0] == len(param_names)
self._show_gradient_dominating_parameter(tuples, tot_sumsq)
@@ -362,7 +362,7 @@ class ScaledAdam(BatchedOptimizer):
def _show_gradient_dominating_parameter(self, tuples: List[Tuple[Tensor, dict, List[str]]], tot_sumsq: Tensor):
"""
Show information of parameter wihch dominanting tot_sumsq.
Show information of parameter which dominating tot_sumsq.
Args:
tuples: a list of tuples of (param, state, param_names)
@@ -415,7 +415,7 @@ class ScaledAdam(BatchedOptimizer):
dominant_grad,
) = sorted_by_proportion[dominant_param_name]
logging.info(
f"Parameter Dominanting tot_sumsq {dominant_param_name}"
f"Parameter Dominating tot_sumsq {dominant_param_name}"
f" with proportion {dominant_proportion:.2f},"
f" where dominant_sumsq=(grad_sumsq*orig_rms_sq)"
f"={dominant_sumsq:.3e},"

View File

@@ -1073,7 +1073,7 @@ class TTS:
###### setting reference audio and prompt text preprocessing ########
t0 = time.perf_counter()
if (ref_audio_path is not None) and (ref_audio_path != self.prompt_cache["ref_audio_path"]):
if (ref_audio_path is not None) and (ref_audio_path != self.prompt_cache["ref_audio_path"] or (self.is_v2pro and self.prompt_cache["refer_spec"][0][1] is None)):
if not os.path.exists(ref_audio_path):
raise ValueError(f"{ref_audio_path} not exists")
self.set_ref_audio(ref_audio_path)

View File

@@ -159,6 +159,10 @@ class TextPreprocessor:
textlist.append(tmp["text"])
else:
for tmp in LangSegmenter.getTexts(text):
if langlist:
if (tmp["lang"] == "en" and langlist[-1] == "en") or (tmp["lang"] != "en" and langlist[-1] != "en"):
textlist[-1] += tmp["text"]
continue
if tmp["lang"] == "en":
langlist.append(tmp["lang"])
else:
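A note on the branch added above (a hedged reading): once langlist is non-empty, a segment returned by LangSegmenter.getTexts is appended to the previous text chunk whenever both segments fall on the same side of the English / non-English split, so consecutive same-class segments are merged instead of producing separate entries. A self-contained sketch of that merging rule, using a hypothetical segment list:

# Sketch of the merge rule; the segment list stands in for LangSegmenter.getTexts(text)
segments = [
    {"lang": "en", "text": "Hello "},
    {"lang": "en", "text": "world. "},
    {"lang": "zh", "text": "你好。"},
    {"lang": "ja", "text": "こんにちは"},
]
langlist, textlist = [], []
for tmp in segments:
    if langlist and ((tmp["lang"] == "en") == (langlist[-1] == "en")):
        textlist[-1] += tmp["text"]          # same class as the previous chunk: merge
        continue
    langlist.append(tmp["lang"])
    textlist.append(tmp["text"])
print(textlist)  # ['Hello world. ', '你好。こんにちは']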

View File

@@ -1,6 +1,7 @@
# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_model.py
# reference: https://github.com/lifeiteng/vall-e
import argparse
from io import BytesIO
from typing import Optional
from my_utils import load_audio
import torch
@@ -17,6 +18,9 @@ from module.models_onnx import SynthesizerTrn
from inference_webui import get_phones_and_bert
from sv import SV
import kaldi as Kaldi
import os
import soundfile
@@ -32,6 +36,22 @@ default_config = {
"EOS": 1024,
}
sv_cn_model = None
def init_sv_cn(device, is_half):
global sv_cn_model
sv_cn_model = SV(device, is_half)
def load_sovits_new(sovits_path):
f = open(sovits_path, "rb")
meta = f.read(2)
if meta != b"PK":
data = b"PK" + f.read()
bio = BytesIO()
bio.write(data)
bio.seek(0)
return torch.load(bio, map_location="cpu", weights_only=False)
return torch.load(sovits_path, map_location="cpu", weights_only=False)
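For context on load_sovits_new (a hedged reading of the code above): a standard torch checkpoint is a zip archive whose first two bytes are the magic b"PK"; when that magic is missing, the loader assumes the first two bytes were overwritten by a custom header and writes b"PK" back before handing an in-memory stream to torch.load. A minimal standalone version of the same sniff, with a hypothetical path:

# Minimal sketch of the header-tolerant load (path is hypothetical)
from io import BytesIO
import torch

def load_checkpoint_tolerant(path):
    with open(path, "rb") as f:
        if f.read(2) != b"PK":                   # zip magic absent: first two bytes were replaced
            bio = BytesIO(b"PK" + f.read())      # put the magic back, keep the rest
            return torch.load(bio, map_location="cpu", weights_only=False)
    return torch.load(path, map_location="cpu", weights_only=False)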
def get_raw_t2s_model(dict_s1) -> Text2SemanticLightningModule:
config = dict_s1["config"]
@@ -83,7 +103,7 @@ def logits_to_probs(
@torch.jit.script
def multinomial_sample_one_no_sync(probs_sort):
# Does multinomial sampling without a cuda synchronization
q = torch.randn_like(probs_sort)
q = torch.empty_like(probs_sort).exponential_(1.0)
return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
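The change above replaces Gaussian noise with Exponential(1) noise, the standard synchronization-free sampling trick: with E_i ~ Exp(1), argmax(p_i / E_i) selects index i with probability proportional to p_i, matching torch.multinomial without forcing a CUDA sync (Gaussian noise does not reproduce this distribution). A quick standalone sanity check:

# Empirical check of the Exp(1) race trick against the target distribution
import torch

probs = torch.tensor([0.1, 0.2, 0.7])
counts = torch.zeros(3)
for _ in range(20000):
    q = torch.empty_like(probs).exponential_(1.0)   # E_i ~ Exp(1)
    counts[torch.argmax(probs / q)] += 1
print(counts / counts.sum())  # ≈ tensor([0.10, 0.20, 0.70])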
@@ -94,7 +114,7 @@ def sample(
temperature: float = 1.0,
top_k: Optional[int] = None,
top_p: Optional[int] = None,
repetition_penalty: float = 1.0,
repetition_penalty: float = 1.35,
):
probs = logits_to_probs(
logits=logits,
@@ -109,8 +129,8 @@ def sample(
@torch.jit.script
def spectrogram_torch(y: Tensor, n_fft: int, sampling_rate: int, hop_size: int, win_size: int, center: bool = False):
hann_window = torch.hann_window(win_size, device=y.device, dtype=y.dtype)
def spectrogram_torch(hann_window:Tensor, y: Tensor, n_fft: int, sampling_rate: int, hop_size: int, win_size: int, center: bool = False):
# hann_window = torch.hann_window(win_size, device=y.device, dtype=y.dtype)
y = torch.nn.functional.pad(
y.unsqueeze(1),
(int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
@@ -289,8 +309,9 @@ class T2SBlock:
attn = F.scaled_dot_product_attention(q, k, v)
attn = attn.permute(2, 0, 1, 3).reshape(batch_size * q_len, self.hidden_dim)
attn = attn.view(q_len, batch_size, self.hidden_dim).transpose(1, 0)
# attn = attn.permute(2, 0, 1, 3).reshape(batch_size * q_len, self.hidden_dim)
# attn = attn.view(q_len, batch_size, self.hidden_dim).transpose(1, 0)
attn = attn.transpose(1, 2).reshape(batch_size, q_len, -1)
attn = F.linear(attn, self.out_w, self.out_b)
x = x + attn
@@ -328,15 +349,22 @@ class T2STransformer:
class VitsModel(nn.Module):
def __init__(self, vits_path):
def __init__(self, vits_path, version=None, is_half=True, device="cpu"):
super().__init__()
# dict_s2 = torch.load(vits_path,map_location="cpu")
dict_s2 = torch.load(vits_path, weights_only=False)
dict_s2 = load_sovits_new(vits_path)
self.hps = dict_s2["config"]
if dict_s2["weight"]["enc_p.text_embedding.weight"].shape[0] == 322:
self.hps["model"]["version"] = "v1"
if version is None:
if dict_s2["weight"]["enc_p.text_embedding.weight"].shape[0] == 322:
self.hps["model"]["version"] = "v1"
else:
self.hps["model"]["version"] = "v2"
else:
self.hps["model"]["version"] = "v2"
if version in ["v1", "v2", "v3", "v4", "v2Pro", "v2ProPlus"]:
self.hps["model"]["version"] = version
else:
raise ValueError(f"Unsupported version: {version}")
self.hps = DictToAttrRecursive(self.hps)
self.hps.model.semantic_frame_rate = "25hz"
@@ -346,11 +374,18 @@ class VitsModel(nn.Module):
n_speakers=self.hps.data.n_speakers,
**self.hps.model,
)
self.vq_model.eval()
self.vq_model.load_state_dict(dict_s2["weight"], strict=False)
self.vq_model.dec.remove_weight_norm()
if is_half:
self.vq_model = self.vq_model.half()
self.vq_model = self.vq_model.to(device)
self.vq_model.eval()
self.hann_window = torch.hann_window(self.hps.data.win_length, device=device, dtype= torch.float16 if is_half else torch.float32)
def forward(self, text_seq, pred_semantic, ref_audio, speed=1.0):
def forward(self, text_seq, pred_semantic, ref_audio, speed=1.0, sv_emb=None):
refer = spectrogram_torch(
self.hann_window,
ref_audio,
self.hps.data.filter_length,
self.hps.data.sampling_rate,
@@ -358,7 +393,7 @@ class VitsModel(nn.Module):
self.hps.data.win_length,
center=False,
)
return self.vq_model(pred_semantic, text_seq, refer, speed)[0, 0]
return self.vq_model(pred_semantic, text_seq, refer, speed=speed, sv_emb=sv_emb)[0, 0]
class T2SModel(nn.Module):
@@ -632,7 +667,7 @@ def export(gpt_path, vits_path, ref_audio_path, ref_text, output_path, export_be
ref_seq = torch.LongTensor([ref_seq_id]).to(device)
ref_bert = ref_bert_T.T.to(ref_seq.device)
text_seq_id, text_bert_T, norm_text = get_phones_and_bert(
"这是一条测试语音,说什么无所谓,只是给它一个例子", "all_zh", "v2"
"这是一个简单的示例真没想到这么简单就完成了。The King and His Stories.Once there was a king. He likes to write stories, but his stories were not good. As people were afraid of him, they all said his stories were good.After reading them, the writer at once turned to the soldiers and said: Take me back to prison, please.", "auto", "v2"
)
text_seq = torch.LongTensor([text_seq_id]).to(device)
text_bert = text_bert_T.T.to(text_seq.device)
@@ -640,7 +675,7 @@ def export(gpt_path, vits_path, ref_audio_path, ref_text, output_path, export_be
ssl_content = ssl(ref_audio).to(device)
# vits_path = "SoVITS_weights_v2/xw_e8_s216.pth"
vits = VitsModel(vits_path).to(device)
vits = VitsModel(vits_path,device=device,is_half=False)
vits.eval()
# gpt_path = "GPT_weights_v2/xw-e15.ckpt"
@@ -679,6 +714,124 @@ def export(gpt_path, vits_path, ref_audio_path, ref_text, output_path, export_be
print("#### exported gpt_sovits ####")
def export_prov2(
gpt_path,
vits_path,
version,
ref_audio_path,
ref_text,
output_path,
export_bert_and_ssl=False,
device="cpu",
is_half=True,
):
if sv_cn_model == None:
init_sv_cn(device,is_half)
if not os.path.exists(output_path):
os.makedirs(output_path)
print(f"目录已创建: {output_path}")
else:
print(f"目录已存在: {output_path}")
ref_audio = torch.tensor([load_audio(ref_audio_path, 16000)]).float()
ssl = SSLModel()
if export_bert_and_ssl:
s = ExportSSLModel(torch.jit.trace(ssl, example_inputs=(ref_audio)))
ssl_path = os.path.join(output_path, "ssl_model.pt")
torch.jit.script(s).save(ssl_path)
print("#### exported ssl ####")
export_bert(output_path)
else:
s = ExportSSLModel(ssl)
print(f"device: {device}")
ref_seq_id, ref_bert_T, ref_norm_text = get_phones_and_bert(
ref_text, "all_zh", "v2"
)
ref_seq = torch.LongTensor([ref_seq_id]).to(device)
ref_bert = ref_bert_T.T
if is_half:
ref_bert = ref_bert.half()
ref_bert = ref_bert.to(ref_seq.device)
text_seq_id, text_bert_T, norm_text = get_phones_and_bert(
"这是一个简单的示例真没想到这么简单就完成了。The King and His Stories.Once there was a king. He likes to write stories, but his stories were not good. As people were afraid of him, they all said his stories were good.After reading them, the writer at once turned to the soldiers and said: Take me back to prison, please.", "auto", "v2"
)
text_seq = torch.LongTensor([text_seq_id]).to(device)
text_bert = text_bert_T.T
if is_half:
text_bert = text_bert.half()
text_bert = text_bert.to(text_seq.device)
ssl_content = ssl(ref_audio)
if is_half:
ssl_content = ssl_content.half()
ssl_content = ssl_content.to(device)
sv_model = ExportERes2NetV2(sv_cn_model)
# vits_path = "SoVITS_weights_v2/xw_e8_s216.pth"
vits = VitsModel(vits_path, version,is_half=is_half,device=device)
vits.eval()
# gpt_path = "GPT_weights_v2/xw-e15.ckpt"
# dict_s1 = torch.load(gpt_path, map_location=device)
dict_s1 = torch.load(gpt_path, weights_only=False)
raw_t2s = get_raw_t2s_model(dict_s1).to(device)
print("#### get_raw_t2s_model ####")
print(raw_t2s.config)
if is_half:
raw_t2s = raw_t2s.half()
t2s_m = T2SModel(raw_t2s)
t2s_m.eval()
t2s = torch.jit.script(t2s_m).to(device)
print("#### script t2s_m ####")
print("vits.hps.data.sampling_rate:", vits.hps.data.sampling_rate)
gpt_sovits = GPT_SoVITS_V2Pro(t2s, vits, sv_model).to(device)
gpt_sovits.eval()
ref_audio_sr = s.resample(ref_audio, 16000, 32000)
if is_half:
ref_audio_sr = ref_audio_sr.half()
ref_audio_sr = ref_audio_sr.to(device)
torch._dynamo.mark_dynamic(ssl_content, 2)
torch._dynamo.mark_dynamic(ref_audio_sr, 1)
torch._dynamo.mark_dynamic(ref_seq, 1)
torch._dynamo.mark_dynamic(text_seq, 1)
torch._dynamo.mark_dynamic(ref_bert, 0)
torch._dynamo.mark_dynamic(text_bert, 0)
# torch._dynamo.mark_dynamic(sv_emb, 0)
top_k = torch.LongTensor([5]).to(device)
# Run sv_model once first so it loads its cache; see L880 for details
gpt_sovits.sv_model(ref_audio_sr)
with torch.no_grad():
gpt_sovits_export = torch.jit.trace(
gpt_sovits,
example_inputs=(
ssl_content,
ref_audio_sr,
ref_seq,
text_seq,
ref_bert,
text_bert,
top_k,
),
)
gpt_sovits_path = os.path.join(output_path, "gpt_sovits_model.pt")
gpt_sovits_export.save(gpt_sovits_path)
print("#### exported gpt_sovits ####")
audio = gpt_sovits_export(ssl_content, ref_audio_sr, ref_seq, text_seq, ref_bert, text_bert, top_k)
print("start write wav")
soundfile.write("out.wav", audio.float().detach().cpu().numpy(), 32000)
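Once export_prov2 has written gpt_sovits_model.pt, the file can be reloaded as an ordinary TorchScript module and called with the same seven tensors used at trace time (ssl_content, ref_audio_sr, ref_seq, text_seq, ref_bert, text_bert, top_k). A hedged consumer-side sketch; the path is hypothetical and the 32 kHz rate follows the write above:

# Sketch: reload and run the exported v2Pro model with trace-time-style inputs
import torch
import soundfile

def run_exported_model(model_path, ssl_content, ref_audio_sr, ref_seq, text_seq, ref_bert, text_bert, top_k):
    # Load the traced module and run it with tensors prepared as in export_prov2 above
    gpt_sovits = torch.jit.load(model_path, map_location="cpu").eval()
    with torch.no_grad():
        audio = gpt_sovits(ssl_content, ref_audio_sr, ref_seq, text_seq, ref_bert, text_bert, top_k)
    soundfile.write("exported_test.wav", audio.float().cpu().numpy(), 32000)
    return audio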
@torch.jit.script
def parse_audio(ref_audio):
ref_audio_16k = torchaudio.functional.resample(ref_audio, 48000, 16000).float() # .to(ref_audio.device)
@@ -717,6 +870,66 @@ class GPT_SoVITS(nn.Module):
return audio
class ExportERes2NetV2(nn.Module):
def __init__(self, sv_cn_model:SV):
super(ExportERes2NetV2, self).__init__()
self.bn1 = sv_cn_model.embedding_model.bn1
self.conv1 = sv_cn_model.embedding_model.conv1
self.layer1 = sv_cn_model.embedding_model.layer1
self.layer2 = sv_cn_model.embedding_model.layer2
self.layer3 = sv_cn_model.embedding_model.layer3
self.layer4 = sv_cn_model.embedding_model.layer4
self.layer3_ds = sv_cn_model.embedding_model.layer3_ds
self.fuse34 = sv_cn_model.embedding_model.fuse34
# audio_16k.shape: [1,N]
def forward(self, audio_16k):
# the fbank function keeps a cache, but that is fine: it does not depend on the length of audio_16k,
# only on device and dtype
x = Kaldi.fbank(audio_16k, num_mel_bins=80, sample_frequency=16000, dither=0)
x = torch.stack([x])
x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T)
x = x.unsqueeze_(1)
out = F.relu(self.bn1(self.conv1(x)))
out1 = self.layer1(out)
out2 = self.layer2(out1)
out3 = self.layer3(out2)
out4 = self.layer4(out3)
out3_ds = self.layer3_ds(out3)
fuse_out34 = self.fuse34(out4, out3_ds)
return fuse_out34.flatten(start_dim=1,end_dim=2).mean(-1)
class GPT_SoVITS_V2Pro(nn.Module):
def __init__(self, t2s: T2SModel, vits: VitsModel,sv_model:ExportERes2NetV2):
super().__init__()
self.t2s = t2s
self.vits = vits
self.sv_model = sv_model
def forward(
self,
ssl_content: torch.Tensor,
ref_audio_sr: torch.Tensor,
ref_seq: Tensor,
text_seq: Tensor,
ref_bert: Tensor,
text_bert: Tensor,
top_k: LongTensor,
speed=1.0,
):
codes = self.vits.vq_model.extract_latent(ssl_content)
prompt_semantic = codes[0, 0]
prompts = prompt_semantic.unsqueeze(0)
audio_16k = resamplex(ref_audio_sr, 32000, 16000).to(ref_audio_sr.dtype)
sv_emb = self.sv_model(audio_16k)
pred_semantic = self.t2s(prompts, ref_seq, text_seq, ref_bert, text_bert, top_k)
audio = self.vits(text_seq, pred_semantic, ref_audio_sr, speed, sv_emb)
return audio
def test():
parser = argparse.ArgumentParser(description="GPT-SoVITS Command Line Tool")
parser.add_argument("--gpt_model", required=True, help="Path to the GPT model file")
@@ -833,29 +1046,53 @@ def export_symbel(version="v2"):
def main():
parser = argparse.ArgumentParser(description="GPT-SoVITS Command Line Tool")
parser.add_argument("--gpt_model", required=True, help="Path to the GPT model file")
parser.add_argument("--sovits_model", required=True, help="Path to the SoVITS model file")
parser.add_argument("--ref_audio", required=True, help="Path to the reference audio file")
parser.add_argument("--ref_text", required=True, help="Path to the reference text file")
parser.add_argument("--output_path", required=True, help="Path to the output directory")
parser.add_argument("--export_common_model", action="store_true", help="Export Bert and SSL model")
parser.add_argument(
"--sovits_model", required=True, help="Path to the SoVITS model file"
)
parser.add_argument(
"--ref_audio", required=True, help="Path to the reference audio file"
)
parser.add_argument(
"--ref_text", required=True, help="Path to the reference text file"
)
parser.add_argument(
"--output_path", required=True, help="Path to the output directory"
)
parser.add_argument(
"--export_common_model", action="store_true", help="Export Bert and SSL model"
)
parser.add_argument("--device", help="Device to use")
parser.add_argument("--version", help="version of the model", default="v2")
parser.add_argument("--no-half", action="store_true", help = "Do not use half precision for model weights")
args = parser.parse_args()
export(
gpt_path=args.gpt_model,
vits_path=args.sovits_model,
ref_audio_path=args.ref_audio,
ref_text=args.ref_text,
output_path=args.output_path,
device=args.device,
export_bert_and_ssl=args.export_common_model,
)
if args.version in ["v2Pro", "v2ProPlus"]:
is_half = not args.no_half
print(f"Using half precision: {is_half}")
export_prov2(
gpt_path=args.gpt_model,
vits_path=args.sovits_model,
version=args.version,
ref_audio_path=args.ref_audio,
ref_text=args.ref_text,
output_path=args.output_path,
export_bert_and_ssl=args.export_common_model,
device=args.device,
is_half=is_half,
)
else:
export(
gpt_path=args.gpt_model,
vits_path=args.sovits_model,
ref_audio_path=args.ref_audio,
ref_text=args.ref_text,
output_path=args.output_path,
device=args.device,
export_bert_and_ssl=args.export_common_model,
)
import inference_webui
if __name__ == "__main__":
inference_webui.is_half = False
inference_webui.dtype = torch.float32
main()
with torch.no_grad():
main()
# test()

View File

@@ -243,6 +243,7 @@ class ExportGPTSovitsHalf(torch.nn.Module):
self.sampling_rate: int = hps.data.sampling_rate
self.hop_length: int = hps.data.hop_length
self.win_length: int = hps.data.win_length
self.hann_window = torch.hann_window(self.win_length, device=device, dtype=torch.float32)
def forward(
self,
@@ -255,6 +256,7 @@ class ExportGPTSovitsHalf(torch.nn.Module):
top_k,
):
refer = spectrogram_torch(
self.hann_window,
ref_audio_32k,
self.filter_length,
self.sampling_rate,
@@ -321,6 +323,7 @@ class ExportGPTSovitsV4Half(torch.nn.Module):
self.sampling_rate: int = hps.data.sampling_rate
self.hop_length: int = hps.data.hop_length
self.win_length: int = hps.data.win_length
self.hann_window = torch.hann_window(self.win_length, device=device, dtype=torch.float32)
def forward(
self,
@@ -333,6 +336,7 @@ class ExportGPTSovitsV4Half(torch.nn.Module):
top_k,
):
refer = spectrogram_torch(
self.hann_window,
ref_audio_32k,
self.filter_length,
self.sampling_rate,
@@ -1149,7 +1153,7 @@ def export_2(version="v3"):
raw_t2s = raw_t2s.half().to(device)
t2s_m = T2SModel(raw_t2s).half().to(device)
t2s_m.eval()
t2s_m = torch.jit.script(t2s_m)
t2s_m = torch.jit.script(t2s_m).to(device)
t2s_m.eval()
# t2s_m.top_k = 15
logger.info("t2s_m ok")
@@ -1251,6 +1255,6 @@ def test_export_gpt_sovits_v3():
with torch.no_grad():
export_1("onnx/ad/ref.wav","你这老坏蛋,我找了你这么久,真没想到在这里找到你。他说。","v4")
# export_2("v4")
# export_1("onnx/ad/ref.wav","你这老坏蛋,我找了你这么久,真没想到在这里找到你。他说。","v4")
export_2("v4")
# test_export_gpt_sovits_v3()

View File

@@ -214,7 +214,7 @@ v3v4set = {"v3", "v4"}
def change_sovits_weights(sovits_path, prompt_language=None, text_language=None):
if "！" in sovits_path:
if "！" in sovits_path or "!" in sovits_path:
sovits_path = name2sovits_path[sovits_path]
global vq_model, hps, version, model_version, dict_language, if_lora_v3
version, model_version, if_lora_v3 = get_sovits_version_from_path_fast(sovits_path)
@@ -361,7 +361,7 @@ except:
def change_gpt_weights(gpt_path):
if "！" in gpt_path:
if "！" in gpt_path or "!" in gpt_path:
gpt_path = name2gpt_path[gpt_path]
global hz, max_sec, t2s_model, config
hz = 50
@@ -623,6 +623,10 @@ def get_phones_and_bert(text, language, version, final=False):
textlist.append(tmp["text"])
else:
for tmp in LangSegmenter.getTexts(text):
if langlist:
if (tmp["lang"] == "en" and langlist[-1] == "en") or (tmp["lang"] != "en" and langlist[-1] != "en"):
textlist[-1] += tmp["text"]
continue
if tmp["lang"] == "en":
langlist.append(tmp["lang"])
else:

View File

@@ -114,11 +114,11 @@ tts_config.device = device
tts_config.is_half = is_half
tts_config.version = version
if gpt_path is not None:
if "！" in gpt_path:
if "！" in gpt_path or "!" in gpt_path:
gpt_path = name2gpt_path[gpt_path]
tts_config.t2s_weights_path = gpt_path
if sovits_path is not None:
if "！" in sovits_path:
if "！" in sovits_path or "!" in sovits_path:
sovits_path = name2sovits_path[sovits_path]
tts_config.vits_weights_path = sovits_path
if cnhubert_base_path is not None:
@@ -217,7 +217,7 @@ v3v4set = {"v3", "v4"}
def change_sovits_weights(sovits_path, prompt_language=None, text_language=None):
if "！" in sovits_path:
if "！" in sovits_path or "!" in sovits_path:
sovits_path = name2sovits_path[sovits_path]
global version, model_version, dict_language, if_lora_v3
version, model_version, if_lora_v3 = get_sovits_version_from_path_fast(sovits_path)
@@ -283,6 +283,12 @@ def change_sovits_weights(sovits_path, prompt_language=None, text_language=None)
f.write(json.dumps(data))
def change_gpt_weights(gpt_path):
if "！" in gpt_path or "!" in gpt_path:
gpt_path = name2gpt_path[gpt_path]
tts_pipeline.init_t2s_weights(gpt_path)
with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css) as app:
gr.HTML(
top_html.format(
@@ -457,7 +463,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css
inference_button,
],
) #
GPT_dropdown.change(tts_pipeline.init_t2s_weights, [GPT_dropdown], [])
GPT_dropdown.change(change_gpt_weights, [GPT_dropdown], [])
with gr.Group():
gr.Markdown(

View File

@@ -762,6 +762,7 @@ class CodePredictor(nn.Module):
return pred_codes.transpose(0, 1)
v2pro_set={"v2Pro","v2ProPlus"}
class SynthesizerTrn(nn.Module):
"""
@@ -867,20 +868,33 @@ class SynthesizerTrn(nn.Module):
# self.enc_p.text_embedding.requires_grad_(False)
# self.enc_p.encoder_text.requires_grad_(False)
# self.enc_p.mrte.requires_grad_(False)
self.is_v2pro=self.version in v2pro_set
if self.is_v2pro:
self.sv_emb = nn.Linear(20480, gin_channels)
self.ge_to512 = nn.Linear(gin_channels, 512)
self.prelu = nn.PReLU(num_parameters=gin_channels)
def forward(self, codes, text, refer, noise_scale=0.5, speed=1):
def forward(self, codes, text, refer, noise_scale=0.5, speed=1, sv_emb=None):
refer_mask = torch.ones_like(refer[:1, :1, :])
if self.version == "v1":
ge = self.ref_enc(refer * refer_mask, refer_mask)
else:
ge = self.ref_enc(refer[:, :704] * refer_mask, refer_mask)
if self.is_v2pro:
sv_emb = self.sv_emb(sv_emb)
ge += sv_emb.unsqueeze(-1)
ge = self.prelu(ge)
quantized = self.quantizer.decode(codes)
if self.semantic_frame_rate == "25hz":
dquantized = torch.cat([quantized, quantized]).permute(1, 2, 0)
quantized = dquantized.contiguous().view(1, self.ssl_dim, -1)
x, m_p, logs_p, y_mask = self.enc_p(quantized, text, ge, speed)
if self.is_v2pro:
ge_ = self.ge_to512(ge.transpose(2,1)).transpose(2,1)
x, m_p, logs_p, y_mask = self.enc_p(quantized, text, ge_, speed)
else:
x, m_p, logs_p, y_mask = self.enc_p(quantized, text, ge, speed)
z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
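A hedged summary of the v2Pro branch added above: the speaker-verification embedding sv_emb (a flat vector of length 20480 from the ERes2NetV2 export) is projected to gin_channels, broadcast-added onto the reference encoding ge of shape [B, gin_channels, 1], passed through a PReLU, and then mapped by ge_to512 to the 512-dim conditioning that enc_p expects. A shape-only sketch with illustrative sizes:

# Shape sketch of the v2Pro conditioning fusion (gin_channels=512 is an assumption for illustration)
import torch
import torch.nn as nn

gin_channels = 512
sv_proj  = nn.Linear(20480, gin_channels)
ge_to512 = nn.Linear(gin_channels, 512)
prelu    = nn.PReLU(num_parameters=gin_channels)

ge     = torch.randn(1, gin_channels, 1)   # reference-encoder output
sv_emb = torch.randn(1, 20480)             # speaker-verification embedding

ge = prelu(ge + sv_proj(sv_emb).unsqueeze(-1))          # fuse speaker identity into ge
ge_512 = ge_to512(ge.transpose(2, 1)).transpose(2, 1)   # [1, 512, 1], passed to enc_p
print(ge.shape, ge_512.shape)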

View File

@@ -1,4 +1,6 @@
import math
import pdb
import numpy as np
import torch
from torch import nn
@@ -718,8 +720,10 @@ class MelStyleEncoder(nn.Module):
else:
len_ = (~mask).sum(dim=1).unsqueeze(1)
x = x.masked_fill(mask.unsqueeze(-1), 0)
x = x.sum(dim=1)
out = torch.div(x, len_)
dtype=x.dtype
x = x.float()
x=torch.div(x,len_.unsqueeze(1))
out=x.sum(dim=1).to(dtype)
return out
def forward(self, x, mask=None):
@@ -743,7 +747,6 @@ class MelStyleEncoder(nn.Module):
x = self.fc(x)
# temoral average pooling
w = self.temporal_avg_pool(x, mask=mask)
return w.unsqueeze(-1)
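The masked-mean rewrite above appears to be the guard behind the "ge.sum may explode" fix: summing thousands of frames first can overflow float16 (max ≈ 65504), so the pooled mean is now computed in float32 with the division applied before the sum, then cast back to the original dtype. A small illustration of the failure mode (values are illustrative):

# Why divide-before-sum in float32 matters under fp16 (illustrative values)
import torch

x = torch.full((1, 4000, 192), 20.0, dtype=torch.float16)   # 4000 frames of moderate activations
len_ = torch.tensor([[4000.0]], dtype=torch.float16)

naive = x.sum(dim=1) / len_                                            # sum first: 80000 overflows fp16 -> inf
safe = torch.div(x.float(), len_.unsqueeze(1)).sum(dim=1).to(x.dtype)  # mirrors the new code path

print(torch.isinf(naive).any().item(), torch.isinf(safe).any().item())  # True False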

View File

@@ -127,7 +127,7 @@ def get_sovits_version_from_path_fast(sovits_path):
def load_sovits_new(sovits_path):
f = open(sovits_path, "rb")
meta = f.read(2)
if meta != "PK":
if meta != b"PK":
data = b"PK" + f.read()
bio = BytesIO()
bio.write(data)
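The one-character change above fixes a str/bytes mix-up: f.read(2) returns bytes, and in Python 3 a bytes value never compares equal to the str "PK", so the old condition was always true and the rebuild branch ran even for ordinary checkpoints. A two-line demonstration:

# Why the old check always fired: bytes and str never compare equal in Python 3
print(b"PK" == "PK")   # False -> the old "meta != 'PK'" condition was always True
print(b"PK" == b"PK")  # True  -> fixed comparison against the zip magic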

View File

@@ -283,7 +283,7 @@ def get_hparams_from_file(config_path):
def check_git_hash(model_dir):
source_dir = os.path.dirname(os.path.realpath(__file__))
if not os.path.exists(os.path.join(source_dir, ".git")):
logger.warn(
logger.warning(
"{} is not a git repository, therefore hash value comparison will be ignored.".format(
source_dir,
)
@@ -296,7 +296,7 @@ def check_git_hash(model_dir):
if os.path.exists(path):
saved_hash = open(path).read()
if saved_hash != cur_hash:
logger.warn(
logger.warning(
"git hash values are different. {}(saved) != {}(current)".format(
saved_hash[:8],
cur_hash[:8],

268
api.py
View File

@@ -163,7 +163,7 @@ from transformers import AutoModelForMaskedLM, AutoTokenizer
import numpy as np
from feature_extractor import cnhubert
from io import BytesIO
from module.models import SynthesizerTrn, SynthesizerTrnV3
from module.models import Generator, SynthesizerTrn, SynthesizerTrnV3
from peft import LoraConfig, get_peft_model
from AR.models.t2s_lightning_module import Text2SemanticLightningModule
from text import cleaned_text_to_sequence
@@ -198,8 +198,38 @@ def is_full(*items): # 任意一项为空返回False
return True
def init_bigvgan():
bigvgan_model = hifigan_model = sv_cn_model = None
def clean_hifigan_model():
global hifigan_model
if hifigan_model:
hifigan_model = hifigan_model.cpu()
hifigan_model = None
try:
torch.cuda.empty_cache()
except:
pass
def clean_bigvgan_model():
global bigvgan_model
if bigvgan_model:
bigvgan_model = bigvgan_model.cpu()
bigvgan_model = None
try:
torch.cuda.empty_cache()
except:
pass
def clean_sv_cn_model():
global sv_cn_model
if sv_cn_model:
sv_cn_model.embedding_model = sv_cn_model.embedding_model.cpu()
sv_cn_model = None
try:
torch.cuda.empty_cache()
except:
pass
def init_bigvgan():
global bigvgan_model, hifigan_model,sv_cn_model
from BigVGAN import bigvgan
bigvgan_model = bigvgan.BigVGAN.from_pretrained(
@@ -209,20 +239,53 @@ def init_bigvgan():
# remove weight norm in the model and set to eval mode
bigvgan_model.remove_weight_norm()
bigvgan_model = bigvgan_model.eval()
if is_half == True:
bigvgan_model = bigvgan_model.half().to(device)
else:
bigvgan_model = bigvgan_model.to(device)
resample_transform_dict = {}
def init_hifigan():
global hifigan_model, bigvgan_model,sv_cn_model
hifigan_model = Generator(
initial_channel=100,
resblock="1",
resblock_kernel_sizes=[3, 7, 11],
resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
upsample_rates=[10, 6, 2, 2, 2],
upsample_initial_channel=512,
upsample_kernel_sizes=[20, 12, 4, 4, 4],
gin_channels=0,
is_bias=True,
)
hifigan_model.eval()
hifigan_model.remove_weight_norm()
state_dict_g = torch.load(
"%s/GPT_SoVITS/pretrained_models/gsv-v4-pretrained/vocoder.pth" % (now_dir,), map_location="cpu", weights_only=False
)
print("loading vocoder", hifigan_model.load_state_dict(state_dict_g))
if is_half == True:
hifigan_model = hifigan_model.half().to(device)
else:
hifigan_model = hifigan_model.to(device)
def resample(audio_tensor, sr0):
from sv import SV
def init_sv_cn():
global hifigan_model, bigvgan_model, sv_cn_model
sv_cn_model = SV(device, is_half)
resample_transform_dict={}
def resample(audio_tensor, sr0,sr1,device):
global resample_transform_dict
if sr0 not in resample_transform_dict:
resample_transform_dict[sr0] = torchaudio.transforms.Resample(sr0, 24000).to(device)
return resample_transform_dict[sr0](audio_tensor)
key="%s-%s-%s"%(sr0,sr1,str(device))
if key not in resample_transform_dict:
resample_transform_dict[key] = torchaudio.transforms.Resample(
sr0, sr1
).to(device)
return resample_transform_dict[key](audio_tensor)
from module.mel_processing import mel_spectrogram_torch
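The resample helper is rewritten to cache one torchaudio.transforms.Resample instance per (source rate, target rate, device) key instead of assuming a fixed 24 kHz target, which lets the same helper serve the 24 kHz v3 path, the 32 kHz v4 path, and the 16 kHz speaker-verification input. A brief usage sketch of the helper defined above, with hypothetical audio:

# Usage sketch of the keyed resampler cache (audio tensor is hypothetical)
import torch

wav48k = torch.randn(1, 48000)                         # 1 s of fake 48 kHz audio
wav16k = resample(wav48k, 48000, 16000, "cpu")         # builds and caches Resample(48000, 16000) on cpu
wav16k_again = resample(wav48k, 48000, 16000, "cpu")   # second call reuses the cached transform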
@@ -252,6 +315,19 @@ mel_fn = lambda x: mel_spectrogram_torch(
"center": False,
},
)
mel_fn_v4 = lambda x: mel_spectrogram_torch(
x,
**{
"n_fft": 1280,
"win_size": 1280,
"hop_size": 320,
"num_mels": 100,
"sampling_rate": 32000,
"fmin": 0,
"fmax": None,
"center": False,
},
)
sr_model = None
@@ -293,12 +369,18 @@ from process_ckpt import get_sovits_version_from_path_fast, load_sovits_new
def get_sovits_weights(sovits_path):
path_sovits_v3 = "GPT_SoVITS/pretrained_models/s2Gv3.pth"
from config import pretrained_sovits_name
path_sovits_v3 = pretrained_sovits_name["v3"]
path_sovits_v4 = pretrained_sovits_name["v4"]
is_exist_s2gv3 = os.path.exists(path_sovits_v3)
is_exist_s2gv4 = os.path.exists(path_sovits_v4)
version, model_version, if_lora_v3 = get_sovits_version_from_path_fast(sovits_path)
if if_lora_v3 == True and is_exist_s2gv3 == False:
logger.info("SoVITS V3 底模缺失,无法加载相应 LoRA 权重")
is_exist = is_exist_s2gv3 if model_version == "v3" else is_exist_s2gv4
path_sovits = path_sovits_v3 if model_version == "v3" else path_sovits_v4
if if_lora_v3 == True and is_exist == False:
logger.info("SoVITS %s 底模缺失,无法加载相应 LoRA 权重" % model_version)
dict_s2 = load_sovits_new(sovits_path)
hps = dict_s2["config"]
@@ -311,11 +393,13 @@ def get_sovits_weights(sovits_path):
else:
hps.model.version = "v2"
if model_version == "v3":
hps.model.version = "v3"
model_params_dict = vars(hps.model)
if model_version != "v3":
if model_version not in {"v3", "v4"}:
if "Pro" in model_version:
hps.model.version = model_version
if sv_cn_model == None:
init_sv_cn()
vq_model = SynthesizerTrn(
hps.data.filter_length // 2 + 1,
hps.train.segment_size // hps.data.hop_length,
@@ -323,13 +407,18 @@ def get_sovits_weights(sovits_path):
**model_params_dict,
)
else:
hps.model.version = model_version
vq_model = SynthesizerTrnV3(
hps.data.filter_length // 2 + 1,
hps.train.segment_size // hps.data.hop_length,
n_speakers=hps.data.n_speakers,
**model_params_dict,
)
init_bigvgan()
if model_version == "v3":
init_bigvgan()
if model_version == "v4":
init_hifigan()
model_version = hps.model.version
logger.info(f"模型版本: {model_version}")
if "pretrained" not in sovits_path:
@@ -345,7 +434,8 @@ def get_sovits_weights(sovits_path):
if if_lora_v3 == False:
vq_model.load_state_dict(dict_s2["weight"], strict=False)
else:
vq_model.load_state_dict(load_sovits_new(path_sovits_v3)["weight"], strict=False)
path_sovits = path_sovits_v3 if model_version == "v3" else path_sovits_v4
vq_model.load_state_dict(load_sovits_new(path_sovits)["weight"], strict=False)
lora_rank = dict_s2["lora_rank"]
lora_config = LoraConfig(
target_modules=["to_k", "to_q", "to_v", "to_out.0"],
@@ -479,6 +569,10 @@ def get_phones_and_bert(text, language, version, final=False):
textlist.append(tmp["text"])
else:
for tmp in LangSegmenter.getTexts(text):
if langlist:
if (tmp["lang"] == "en" and langlist[-1] == "en") or (tmp["lang"] != "en" and langlist[-1] != "en"):
textlist[-1] += tmp["text"]
continue
if tmp["lang"] == "en":
langlist.append(tmp["lang"])
else:
@@ -533,23 +627,32 @@ class DictToAttrRecursive(dict):
raise AttributeError(f"Attribute {item} not found")
def get_spepc(hps, filename):
audio, _ = librosa.load(filename, sr=int(hps.data.sampling_rate))
audio = torch.FloatTensor(audio)
def get_spepc(hps, filename, dtype, device, is_v2pro=False):
sr1=int(hps.data.sampling_rate)
audio, sr0=torchaudio.load(filename)
if sr0!=sr1:
audio=audio.to(device)
if(audio.shape[0]==2):audio=audio.mean(0).unsqueeze(0)
audio=resample(audio,sr0,sr1,device)
else:
audio=audio.to(device)
if(audio.shape[0]==2):audio=audio.mean(0).unsqueeze(0)
maxx = audio.abs().max()
if maxx > 1:
audio /= min(2, maxx)
audio_norm = audio
audio_norm = audio_norm.unsqueeze(0)
spec = spectrogram_torch(
audio_norm,
audio,
hps.data.filter_length,
hps.data.sampling_rate,
hps.data.hop_length,
hps.data.win_length,
center=False,
)
return spec
spec=spec.to(dtype)
if is_v2pro==True:
audio=resample(audio,sr1,16000,device).to(dtype)
return spec,audio
def pack_audio(audio_bytes, data, rate):
@@ -736,6 +839,16 @@ def get_tts_wav(
t2s_model = infer_gpt.t2s_model
max_sec = infer_gpt.max_sec
if version == "v3":
if sample_steps not in [4, 8, 16, 32, 64, 128]:
sample_steps = 32
elif version == "v4":
if sample_steps not in [4, 8, 16, 32]:
sample_steps = 8
if if_sr and version != "v3":
if_sr = False
t0 = ttime()
prompt_text = prompt_text.strip("\n")
if prompt_text[-1] not in splits:
@@ -759,19 +872,29 @@ def get_tts_wav(
prompt_semantic = codes[0, 0]
prompt = prompt_semantic.unsqueeze(0).to(device)
if version != "v3":
is_v2pro = version in {"v2Pro","v2ProPlus"}
if version not in {"v3", "v4"}:
refers = []
if is_v2pro:
sv_emb= []
if sv_cn_model == None:
init_sv_cn()
if inp_refs:
for path in inp_refs:
try:
refer = get_spepc(hps, path).to(dtype).to(device)
try:  ##### also extract the SV embedding here: either a list of sv_emb paired with a list of refers, or a single sv_emb with a single refer
refer,audio_tensor = get_spepc(hps, path.name, dtype, device, is_v2pro)
refers.append(refer)
if is_v2pro:
sv_emb.append(sv_cn_model.compute_embedding3(audio_tensor))
except Exception as e:
logger.error(e)
if len(refers) == 0:
refers = [get_spepc(hps, ref_wav_path).to(dtype).to(device)]
refers,audio_tensor = get_spepc(hps, ref_wav_path, dtype, device, is_v2pro)
refers=[refers]
if is_v2pro:
sv_emb=[sv_cn_model.compute_embedding3(audio_tensor)]
else:
refer = get_spepc(hps, ref_wav_path).to(device).to(dtype)
refer,audio_tensor = get_spepc(hps, ref_wav_path, dtype, device)
t1 = ttime()
# os.environ['version'] = version
@@ -811,41 +934,48 @@ def get_tts_wav(
pred_semantic = pred_semantic[:, -idx:].unsqueeze(0)
t3 = ttime()
if version != "v3":
audio = (
vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refers, speed=speed)
.detach()
.cpu()
.numpy()[0, 0]
)  ### try reconstructing without the prompt part
if version not in {"v3", "v4"}:
if is_v2pro:
audio = (
vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refers, speed=speed,sv_emb=sv_emb)
.detach()
.cpu()
.numpy()[0, 0]
)
else:
audio = (
vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refers, speed=speed)
.detach()
.cpu()
.numpy()[0, 0]
)
else:
phoneme_ids0 = torch.LongTensor(phones1).to(device).unsqueeze(0)
phoneme_ids1 = torch.LongTensor(phones2).to(device).unsqueeze(0)
# print(11111111, phoneme_ids0, phoneme_ids1)
fea_ref, ge = vq_model.decode_encp(prompt.unsqueeze(0), phoneme_ids0, refer)
ref_audio, sr = torchaudio.load(ref_wav_path)
ref_audio = ref_audio.to(device).float()
if ref_audio.shape[0] == 2:
ref_audio = ref_audio.mean(0).unsqueeze(0)
if sr != 24000:
ref_audio = resample(ref_audio, sr)
# print("ref_audio",ref_audio.abs().mean())
mel2 = mel_fn(ref_audio)
tgt_sr = 24000 if version == "v3" else 32000
if sr != tgt_sr:
ref_audio = resample(ref_audio, sr, tgt_sr, device)
mel2 = mel_fn(ref_audio) if version == "v3" else mel_fn_v4(ref_audio)
mel2 = norm_spec(mel2)
T_min = min(mel2.shape[2], fea_ref.shape[2])
mel2 = mel2[:, :, :T_min]
fea_ref = fea_ref[:, :, :T_min]
if T_min > 468:
mel2 = mel2[:, :, -468:]
fea_ref = fea_ref[:, :, -468:]
T_min = 468
chunk_len = 934 - T_min
# print("fea_ref",fea_ref,fea_ref.shape)
# print("mel2",mel2)
Tref = 468 if version == "v3" else 500
Tchunk = 934 if version == "v3" else 1000
if T_min > Tref:
mel2 = mel2[:, :, -Tref:]
fea_ref = fea_ref[:, :, -Tref:]
T_min = Tref
chunk_len = Tchunk - T_min
mel2 = mel2.to(dtype)
fea_todo, ge = vq_model.decode_encp(pred_semantic, phoneme_ids1, refer, ge, speed)
# print("fea_todo",fea_todo)
# print("ge",ge.abs().mean())
cfm_resss = []
idx = 0
while 1:
@@ -854,22 +984,24 @@ def get_tts_wav(
break
idx += chunk_len
fea = torch.cat([fea_ref, fea_todo_chunk], 2).transpose(2, 1)
# set_seed(123)
cfm_res = vq_model.cfm.inference(
fea, torch.LongTensor([fea.size(1)]).to(fea.device), mel2, sample_steps, inference_cfg_rate=0
)
cfm_res = cfm_res[:, :, mel2.shape[2] :]
mel2 = cfm_res[:, :, -T_min:]
# print("fea", fea)
# print("mel2in", mel2)
fea_ref = fea_todo_chunk[:, :, -T_min:]
cfm_resss.append(cfm_res)
cmf_res = torch.cat(cfm_resss, 2)
cmf_res = denorm_spec(cmf_res)
if bigvgan_model == None:
init_bigvgan()
cfm_res = torch.cat(cfm_resss, 2)
cfm_res = denorm_spec(cfm_res)
if version == "v3":
if bigvgan_model == None:
init_bigvgan()
else: # v4
if hifigan_model == None:
init_hifigan()
vocoder_model = bigvgan_model if version == "v3" else hifigan_model
with torch.inference_mode():
wav_gen = bigvgan_model(cmf_res)
wav_gen = vocoder_model(cfm_res)
audio = wav_gen[0][0].cpu().detach().numpy()
max_audio = np.abs(audio).max()
@@ -880,7 +1012,13 @@ def get_tts_wav(
audio_opt = np.concatenate(audio_opt, 0)
t4 = ttime()
sr = hps.data.sampling_rate if version != "v3" else 24000
if version in {"v1", "v2", "v2Pro", "v2ProPlus"}:
sr = 32000
elif version == "v3":
sr = 24000
else:
sr = 48000 # v4
if if_sr and sr == 24000:
audio_opt = torch.from_numpy(audio_opt).float().to(device)
audio_opt, sr = audio_sr(audio_opt.unsqueeze(0), sr)
@@ -900,8 +1038,12 @@ def get_tts_wav(
if not stream_mode == "normal":
if media_type == "wav":
sr = 48000 if if_sr else 24000
sr = hps.data.sampling_rate if version != "v3" else sr
if version in {"v1", "v2", "v2Pro", "v2ProPlus"}:
sr = 32000
elif version == "v3":
sr = 48000 if if_sr else 24000
else:
sr = 48000 # v4
audio_bytes = pack_wav(audio_bytes, sr)
yield audio_bytes.getvalue()
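Both replaced blocks above encode the same version-to-output-rate mapping; stated once (a hedged summary of those branches): v1/v2/v2Pro/v2ProPlus synthesize at 32 kHz, v3 at 24 kHz (48 kHz when the optional super-resolution pass is enabled), and v4 at 48 kHz.

# Output sampling rate per model version, as implied by the branches above
def output_sr(version: str, if_sr: bool = False) -> int:
    if version in {"v1", "v2", "v2Pro", "v2ProPlus"}:
        return 32000
    if version == "v3":
        return 48000 if if_sr else 24000
    return 48000  # v4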
@@ -966,8 +1108,6 @@ def handle(
if not default_refer.is_ready():
return JSONResponse({"code": 400, "message": "未指定参考音频且接口无预设"}, status_code=400)
if sample_steps not in [4, 8, 16, 32]:
sample_steps = 32
if cut_punc == None:
text = cut_text(text, default_cut_punc)
@@ -1071,10 +1211,10 @@ default_refer = DefaultRefer(args.default_refer_path, args.default_refer_text, a
# 模型路径检查
if sovits_path == "":
sovits_path = g_config.pretrained_sovits_path
logger.warn(f"未指定SoVITS模型路径, fallback后当前值: {sovits_path}")
logger.warning(f"未指定SoVITS模型路径, fallback后当前值: {sovits_path}")
if gpt_path == "":
gpt_path = g_config.pretrained_gpt_path
logger.warn(f"未指定GPT模型路径, fallback后当前值: {gpt_path}")
logger.warning(f"未指定GPT模型路径, fallback后当前值: {gpt_path}")
# Default reference audio, used when the caller provides no or incomplete reference-audio parameters
if default_refer.path == "" or default_refer.text == "" or default_refer.language == "":

View File

@@ -1,442 +0,0 @@
import argparse
import os
import pdb
import signal
import sys
from time import time as ttime
import torch
import librosa
import soundfile as sf
from fastapi import FastAPI, Request, HTTPException
from fastapi.responses import StreamingResponse
import uvicorn
from transformers import AutoModelForMaskedLM, AutoTokenizer
import numpy as np
from feature_extractor import cnhubert
from io import BytesIO
from module.models import SynthesizerTrn
from AR.models.t2s_lightning_module import Text2SemanticLightningModule
from text import cleaned_text_to_sequence
from text.cleaner import clean_text
from module.mel_processing import spectrogram_torch
from my_utils import load_audio
import config as global_config
g_config = global_config.Config()
# AVAILABLE_COMPUTE = "cuda" if torch.cuda.is_available() else "cpu"
parser = argparse.ArgumentParser(description="GPT-SoVITS api")
parser.add_argument("-s", "--sovits_path", type=str, default=g_config.sovits_path, help="SoVITS模型路径")
parser.add_argument("-g", "--gpt_path", type=str, default=g_config.gpt_path, help="GPT模型路径")
parser.add_argument("-dr", "--default_refer_path", type=str, default="",
help="默认参考音频路径, 请求缺少参考音频时调用")
parser.add_argument("-dt", "--default_refer_text", type=str, default="", help="默认参考音频文本")
parser.add_argument("-dl", "--default_refer_language", type=str, default="", help="默认参考音频语种")
parser.add_argument("-d", "--device", type=str, default=g_config.infer_device, help="cuda / cpu")
parser.add_argument("-p", "--port", type=int, default=g_config.api_port, help="default: 9880")
parser.add_argument("-a", "--bind_addr", type=str, default="127.0.0.1", help="default: 127.0.0.1")
parser.add_argument("-fp", "--full_precision", action="store_true", default=False, help="覆盖config.is_half为False, 使用全精度")
parser.add_argument("-hp", "--half_precision", action="store_true", default=False, help="覆盖config.is_half为True, 使用半精度")
# These boolean flags are used as, e.g., `python ./api.py -fp ...`
# in which case full_precision==True and half_precision==False
parser.add_argument("-hb", "--hubert_path", type=str, default=g_config.cnhubert_path, help="覆盖config.cnhubert_path")
parser.add_argument("-b", "--bert_path", type=str, default=g_config.bert_path, help="覆盖config.bert_path")
args = parser.parse_args()
sovits_path = args.sovits_path
gpt_path = args.gpt_path
default_refer_path = args.default_refer_path
default_refer_text = args.default_refer_text
default_refer_language = args.default_refer_language
has_preset = False
device = args.device
port = args.port
host = args.bind_addr
if sovits_path == "":
sovits_path = g_config.pretrained_sovits_path
print(f"[WARN] 未指定SoVITS模型路径, fallback后当前值: {sovits_path}")
if gpt_path == "":
gpt_path = g_config.pretrained_gpt_path
print(f"[WARN] 未指定GPT模型路径, fallback后当前值: {gpt_path}")
# Default reference audio, used when the caller provides no or incomplete reference-audio parameters
if default_refer_path == "" or default_refer_text == "" or default_refer_language == "":
default_refer_path, default_refer_text, default_refer_language = "", "", ""
print("[INFO] 未指定默认参考音频")
has_preset = False
else:
print(f"[INFO] 默认参考音频路径: {default_refer_path}")
print(f"[INFO] 默认参考音频文本: {default_refer_text}")
print(f"[INFO] 默认参考音频语种: {default_refer_language}")
has_preset = True
is_half = g_config.is_half
if args.full_precision:
is_half = False
if args.half_precision:
is_half = True
if args.full_precision and args.half_precision:
is_half = g_config.is_half  # conflicting flags: fall back to the config value
print(f"[INFO] 半精: {is_half}")
cnhubert_base_path = args.hubert_path
bert_path = args.bert_path
cnhubert.cnhubert_base_path = cnhubert_base_path
tokenizer = AutoTokenizer.from_pretrained(bert_path)
bert_model = AutoModelForMaskedLM.from_pretrained(bert_path)
if is_half:
bert_model = bert_model.half().to(device)
else:
bert_model = bert_model.to(device)
def get_bert_feature(text, word2ph):
with torch.no_grad():
inputs = tokenizer(text, return_tensors="pt")
for i in inputs:
inputs[i] = inputs[i].to(device)  ##### inputs are long tensors, so precision is not an issue; precision follows bert_model
res = bert_model(**inputs, output_hidden_states=True)
res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]
assert len(word2ph) == len(text)
phone_level_feature = []
for i in range(len(word2ph)):
repeat_feature = res[i].repeat(word2ph[i], 1)
phone_level_feature.append(repeat_feature)
phone_level_feature = torch.cat(phone_level_feature, dim=0)
# if(is_half==True):phone_level_feature=phone_level_feature.half()
return phone_level_feature.T
n_semantic = 1024
dict_s2 = torch.load(sovits_path, map_location="cpu", weights_only=False)
hps = dict_s2["config"]
print(hps)
class DictToAttrRecursive(dict):
def __init__(self, input_dict):
super().__init__(input_dict)
for key, value in input_dict.items():
if isinstance(value, dict):
value = DictToAttrRecursive(value)
self[key] = value
setattr(self, key, value)
def __getattr__(self, item):
try:
return self[item]
except KeyError:
raise AttributeError(f"Attribute {item} not found")
def __setattr__(self, key, value):
if isinstance(value, dict):
value = DictToAttrRecursive(value)
super(DictToAttrRecursive, self).__setitem__(key, value)
super().__setattr__(key, value)
def __delattr__(self, item):
try:
del self[item]
except KeyError:
raise AttributeError(f"Attribute {item} not found")
hps = DictToAttrRecursive(hps)
hps.model.semantic_frame_rate = "25hz"
dict_s1 = torch.load(gpt_path, map_location="cpu", weights_only=False)
config = dict_s1["config"]
ssl_model = cnhubert.get_model()
if is_half:
ssl_model = ssl_model.half().to(device)
else:
ssl_model = ssl_model.to(device)
vq_model = SynthesizerTrn(
hps.data.filter_length // 2 + 1,
hps.train.segment_size // hps.data.hop_length,
n_speakers=hps.data.n_speakers,
**hps.model)
if is_half:
vq_model = vq_model.half().to(device)
else:
vq_model = vq_model.to(device)
vq_model.eval()
print(vq_model.load_state_dict(dict_s2["weight"], strict=False))
hz = 50
max_sec = config['data']['max_sec']
t2s_model = Text2SemanticLightningModule(config, "ojbk", is_train=False)
t2s_model.load_state_dict(dict_s1["weight"])
if is_half:
t2s_model = t2s_model.half()
t2s_model = t2s_model.to(device)
t2s_model.eval()
total = sum([param.nelement() for param in t2s_model.parameters()])
print("Number of parameter: %.2fM" % (total / 1e6))
def get_spepc(hps, filename):
audio = load_audio(filename, int(hps.data.sampling_rate))
audio = torch.FloatTensor(audio)
audio_norm = audio
audio_norm = audio_norm.unsqueeze(0)
spec = spectrogram_torch(audio_norm, hps.data.filter_length, hps.data.sampling_rate, hps.data.hop_length,
hps.data.win_length, center=False)
return spec
dict_language = {
"中文": "zh",
"英文": "en",
"日文": "ja",
"ZH": "zh",
"EN": "en",
"JA": "ja",
"zh": "zh",
"en": "en",
"ja": "ja"
}
def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language):
t0 = ttime()
prompt_text = prompt_text.strip("\n")
prompt_language, text = prompt_language, text.strip("\n")
zero_wav = np.zeros(int(hps.data.sampling_rate * 0.3), dtype=np.float16 if is_half == True else np.float32)
with torch.no_grad():
wav16k, sr = librosa.load(ref_wav_path, sr=16000)
wav16k = torch.from_numpy(wav16k)
zero_wav_torch = torch.from_numpy(zero_wav)
if (is_half == True):
wav16k = wav16k.half().to(device)
zero_wav_torch = zero_wav_torch.half().to(device)
else:
wav16k = wav16k.to(device)
zero_wav_torch = zero_wav_torch.to(device)
wav16k=torch.cat([wav16k,zero_wav_torch])
ssl_content = ssl_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2) # .float()
codes = vq_model.extract_latent(ssl_content)
prompt_semantic = codes[0, 0]
t1 = ttime()
prompt_language = dict_language[prompt_language]
text_language = dict_language[text_language]
phones1, word2ph1, norm_text1 = clean_text(prompt_text, prompt_language)
phones1 = cleaned_text_to_sequence(phones1)
texts = text.split("\n")
audio_opt = []
for text in texts:
phones2, word2ph2, norm_text2 = clean_text(text, text_language)
phones2 = cleaned_text_to_sequence(phones2)
if (prompt_language == "zh"):
bert1 = get_bert_feature(norm_text1, word2ph1).to(device)
else:
bert1 = torch.zeros((1024, len(phones1)), dtype=torch.float16 if is_half == True else torch.float32).to(
device)
if (text_language == "zh"):
bert2 = get_bert_feature(norm_text2, word2ph2).to(device)
else:
bert2 = torch.zeros((1024, len(phones2))).to(bert1)
bert = torch.cat([bert1, bert2], 1)
all_phoneme_ids = torch.LongTensor(phones1 + phones2).to(device).unsqueeze(0)
bert = bert.to(device).unsqueeze(0)
all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device)
prompt = prompt_semantic.unsqueeze(0).to(device)
t2 = ttime()
with torch.no_grad():
# pred_semantic = t2s_model.model.infer(
pred_semantic, idx = t2s_model.model.infer_panel(
all_phoneme_ids,
all_phoneme_len,
prompt,
bert,
# prompt_phone_len=ph_offset,
top_k=config['inference']['top_k'],
early_stop_num=hz * max_sec)
t3 = ttime()
# print(pred_semantic.shape,idx)
pred_semantic = pred_semantic[:, -idx:].unsqueeze(0)  # .unsqueeze(0)  # mq needs one extra unsqueeze here
refer = get_spepc(hps, ref_wav_path) # .to(device)
if (is_half == True):
refer = refer.half().to(device)
else:
refer = refer.to(device)
# audio = vq_model.decode(pred_semantic, all_phoneme_ids, refer).detach().cpu().numpy()[0, 0]
audio = \
vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0),
refer).detach().cpu().numpy()[
0, 0]  ### try reconstructing without the prompt part
audio_opt.append(audio)
audio_opt.append(zero_wav)
t4 = ttime()
print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3))
# yield hps.data.sampling_rate, (np.concatenate(audio_opt, 0) * 32768).astype(np.int16)
return hps.data.sampling_rate, (np.concatenate(audio_opt, 0) * 32768).astype(np.int16)
def get_tts_wavs(ref_wav_path, prompt_text, prompt_language, textss, text_language):
t0 = ttime()
prompt_text = prompt_text.strip("\n")
zero_wav = np.zeros(int(hps.data.sampling_rate * 0.3), dtype=np.float16 if is_half == True else np.float32)
with torch.no_grad():
wav16k, sr = librosa.load(ref_wav_path, sr=16000)
wav16k = torch.from_numpy(wav16k)
zero_wav_torch = torch.from_numpy(zero_wav)
if (is_half == True):
wav16k = wav16k.half().to(device)
zero_wav_torch = zero_wav_torch.half().to(device)
else:
wav16k = wav16k.to(device)
zero_wav_torch = zero_wav_torch.to(device)
wav16k=torch.cat([wav16k,zero_wav_torch])
ssl_content = ssl_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2) # .float()
codes = vq_model.extract_latent(ssl_content)
prompt_semantic = codes[0, 0]
t1 = ttime()
prompt_language = dict_language[prompt_language]
text_language = dict_language[text_language]
phones1, word2ph1, norm_text1 = clean_text(prompt_text, prompt_language)
phones1 = cleaned_text_to_sequence(phones1)
audios_opt=[]
for text0 in textss:
texts = text0.strip("\n").split("\n")
audio_opt = []
for text in texts:
text=text.strip("。")+"。"
phones2, word2ph2, norm_text2 = clean_text(text, text_language)
phones2 = cleaned_text_to_sequence(phones2)
if (prompt_language == "zh"):
bert1 = get_bert_feature(norm_text1, word2ph1).to(device)
else:
bert1 = torch.zeros((1024, len(phones1)), dtype=torch.float16 if is_half == True else torch.float32).to(
device)
if (text_language == "zh"):
bert2 = get_bert_feature(norm_text2, word2ph2).to(device)
else:
bert2 = torch.zeros((1024, len(phones2))).to(bert1)
bert = torch.cat([bert1, bert2], 1)
all_phoneme_ids = torch.LongTensor(phones1 + phones2).to(device).unsqueeze(0)
bert = bert.to(device).unsqueeze(0)
all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device)
prompt = prompt_semantic.unsqueeze(0).to(device)
t2 = ttime()
with torch.no_grad():
# pred_semantic = t2s_model.model.infer(
pred_semantic, idx = t2s_model.model.infer_panel(
all_phoneme_ids,
all_phoneme_len,
prompt,
bert,
# prompt_phone_len=ph_offset,
top_k=config['inference']['top_k'],
early_stop_num=hz * max_sec)
t3 = ttime()
# print(pred_semantic.shape,idx)
pred_semantic = pred_semantic[:, -idx:].unsqueeze(0)  # .unsqueeze(0)  # mq needs one extra unsqueeze here
refer = get_spepc(hps, ref_wav_path) # .to(device)
if (is_half == True):
refer = refer.half().to(device)
else:
refer = refer.to(device)
# audio = vq_model.decode(pred_semantic, all_phoneme_ids, refer).detach().cpu().numpy()[0, 0]
audio = \
vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0),
refer).detach().cpu().numpy()[
0, 0]  ### try reconstructing without the prompt part
audio_opt.append(audio)
audio_opt.append(zero_wav)
t4 = ttime()
print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3))
audios_opt.append([text0,(np.concatenate(audio_opt, 0) * 32768).astype(np.int16)])
return audios_opt
# get_tts_wav(r"D:\BaiduNetdiskDownload\gsv\speech\萧逸声音-你得先从滑雪的基本技巧学起.wav", "你得先从滑雪的基本技巧学起。", "中文", "我觉得还是该给喜欢的女孩子一场认真的告白。", "中文")
# with open(r"D:\BaiduNetdiskDownload\gsv\烟嗓-todo1.txt","r",encoding="utf8")as f:
# with open(r"D:\BaiduNetdiskDownload\gsv\年下-todo1.txt","r",encoding="utf8")as f:
# with open(r"D:\BaiduNetdiskDownload\gsv\萧逸3b.txt","r",encoding="utf8")as f:
with open(r"D:\BaiduNetdiskDownload\gsv\萧逸4.txt","r",encoding="utf8")as f:
textss=f.read().split("\n")
for idx,(text,audio)in enumerate(get_tts_wavs(r"D:\BaiduNetdiskDownload\gsv\speech\萧逸声音-你得先从滑雪的基本技巧学起.wav", "你得先从滑雪的基本技巧学起。", "中文", textss, "中文")):
# for idx,(text,audio)in enumerate(get_tts_wavs(r"D:\BaiduNetdiskDownload\gsv\足够的能力,去制定好自己的生活规划。低沉烟嗓.MP3_1940480_2095360.wav", "足够的能力,去制定好自己的生活规划。", "中文", textss, "中文")):
# for idx,(text,audio)in enumerate(get_tts_wavs(r"D:\BaiduNetdiskDownload\gsv\不会呀!你前几天才吃过你还说好吃来着。年下少年音.MP3_537600_711040.wav", "不会呀!你前几天才吃过你还说好吃来着。", "中文", textss, "中文")):
print(idx,text)
# sf.write(r"D:\BaiduNetdiskDownload\gsv\output\烟嗓第一批\%04d-%s.wav"%(idx,text),audio,32000)
# sf.write(r"D:\BaiduNetdiskDownload\gsv\output\年下\%04d-%s.wav"%(idx,text),audio,32000)
sf.write(r"D:\BaiduNetdiskDownload\gsv\output\萧逸第4批\%04d-%s.wav"%(idx,text),audio,32000)
# def handle(command, refer_wav_path, prompt_text, prompt_language, text, text_language):
# if command == "/restart":
# os.execl(g_config.python_exec, g_config.python_exec, *sys.argv)
# elif command == "/exit":
# os.kill(os.getpid(), signal.SIGTERM)
# exit(0)
#
# if (
# refer_wav_path == "" or refer_wav_path is None
# or prompt_text == "" or prompt_text is None
# or prompt_language == "" or prompt_language is None
# ):
# refer_wav_path, prompt_text, prompt_language = (
# default_refer_path,
# default_refer_text,
# default_refer_language,
# )
# if not has_preset:
# raise HTTPException(status_code=400, detail="未指定参考音频且接口无预设")
#
# with torch.no_grad():
# gen = get_tts_wav(
# refer_wav_path, prompt_text, prompt_language, text, text_language
# )
# sampling_rate, audio_data = next(gen)
#
# wav = BytesIO()
# sf.write(wav, audio_data, sampling_rate, format="wav")
# wav.seek(0)
#
# torch.cuda.empty_cache()
# return StreamingResponse(wav, media_type="audio/wav")
# app = FastAPI()
#
#
# @app.post("/")
# async def tts_endpoint(request: Request):
# json_post_raw = await request.json()
# return handle(
# json_post_raw.get("command"),
# json_post_raw.get("refer_wav_path"),
# json_post_raw.get("prompt_text"),
# json_post_raw.get("prompt_language"),
# json_post_raw.get("text"),
# json_post_raw.get("text_language"),
# )
#
#
# @app.get("/")
# async def tts_endpoint(
# command: str = None,
# refer_wav_path: str = None,
# prompt_text: str = None,
# prompt_language: str = None,
# text: str = None,
# text_language: str = None,
# ):
# return handle(command, refer_wav_path, prompt_text, prompt_language, text, text_language)
#
#
# if __name__ == "__main__":
# uvicorn.run(app, host=host, port=port, workers=1)

View File

@@ -578,3 +578,19 @@
  - Content: Optimize the automatic precision detection logic; add collapsible sections to the WebUI front-end modules.
  - Type: New feature
  - Contributors: XXXXRT666, RVC-Boss
- 2025.06.06 [PR#2427](https://github.com/RVC-Boss/GPT-SoVITS/pull/2427)
  - Content: Fix polyphonic-character handling for the "X一X" pattern.
  - Type: Fix
  - Contributor: wzy3650
- 2025.06.05 [PR#2439](https://github.com/RVC-Boss/GPT-SoVITS/pull/2439)
  - Content: Config fixes; fix SoVITS model loading.
  - Type: Fix
  - Contributor: wzy3650
- 2025.06.09 [Commit#8056efe4](https://github.com/RVC-Boss/GPT-SoVITS/commit/8056efe4ab7bbc3610c72ae356a6f37518441f7d)
  - Content: Fix possible numerical explosion of ge.sum that could make inference produce silent output.
  - Type: Fix
  - Contributor: RVC-Boss
- 2025.06.10 [Commit#2c0436b9](https://github.com/RVC-Boss/GPT-SoVITS/commit/2c0436b9ce397424ae03476c836fb64c6e5ebcc6)
  - Content: Fix incorrect paths on Windows when the experiment name ends with a trailing space.
  - Type: Fix
  - Contributor: RVC-Boss

View File

@@ -507,6 +507,7 @@ def open1Ba(
):
global p_train_SoVITS
if p_train_SoVITS == None:
exp_name=exp_name.rstrip(" ")
config_file = (
"GPT_SoVITS/configs/s2.json"
if version not in {"v2Pro", "v2ProPlus"}
@@ -603,6 +604,7 @@ def open1Bb(
):
global p_train_GPT
if p_train_GPT == None:
exp_name=exp_name.rstrip(" ")
with open(
"GPT_SoVITS/configs/s1longer.yaml" if version == "v1" else "GPT_SoVITS/configs/s1longer-v2.yaml"
) as f:
@@ -785,6 +787,7 @@ def open1a(inp_text, inp_wav_dir, exp_name, gpu_numbers, bert_pretrained_dir):
inp_wav_dir = my_utils.clean_path(inp_wav_dir)
if check_for_existance([inp_text, inp_wav_dir], is_dataset_processing=True):
check_details([inp_text, inp_wav_dir], is_dataset_processing=True)
exp_name = exp_name.rstrip(" ")
if ps1a == []:
opt_dir = "%s/%s" % (exp_root, exp_name)
config = {
@@ -874,6 +877,7 @@ def open1b(version, inp_text, inp_wav_dir, exp_name, gpu_numbers, ssl_pretrained
inp_wav_dir = my_utils.clean_path(inp_wav_dir)
if check_for_existance([inp_text, inp_wav_dir], is_dataset_processing=True):
check_details([inp_text, inp_wav_dir], is_dataset_processing=True)
exp_name = exp_name.rstrip(" ")
if ps1b == []:
config = {
"inp_text": inp_text,
@@ -962,6 +966,7 @@ def open1c(version, inp_text, inp_wav_dir, exp_name, gpu_numbers, pretrained_s2G
inp_text = my_utils.clean_path(inp_text)
if check_for_existance([inp_text, inp_wav_dir], is_dataset_processing=True):
check_details([inp_text, inp_wav_dir], is_dataset_processing=True)
exp_name = exp_name.rstrip(" ")
if ps1c == []:
opt_dir = "%s/%s" % (exp_root, exp_name)
config_file = (
@@ -1059,6 +1064,7 @@ def open1abc(
inp_wav_dir = my_utils.clean_path(inp_wav_dir)
if check_for_existance([inp_text, inp_wav_dir], is_dataset_processing=True):
check_details([inp_text, inp_wav_dir], is_dataset_processing=True)
exp_name = exp_name.rstrip(" ")
if ps1abc == []:
opt_dir = "%s/%s" % (exp_root, exp_name)
try: