Compare commits: 20250606v2 ... main (14 commits)

Commit SHA1s:
7dec5f5bb0
1a9b8854ee
5c91e66d2e
ed89a02337
cd6de7398e
dd2b9253aa
29165eb02e
746cb536c6
0d2f273402
d39836b8fa
2c0436b9ce
8056efe4ab
d6b78c927a
74e79ae6d6
@@ -354,7 +354,7 @@ class ScaledAdam(BatchedOptimizer):
if ans < 1.0:
first_state["num_clipped"] += 1
if ans < 0.1:
logging.warn(f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}")
logging.warning(f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}")
if self.show_dominant_parameters:
assert p.shape[0] == len(param_names)
self._show_gradient_dominating_parameter(tuples, tot_sumsq)
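The only functional change in this hunk is logging.warn to logging.warning. For reference, logging.warn has been a deprecated alias of logging.warning since Python 3.3 and emits a DeprecationWarning; a minimal illustration:

import logging

logging.warning("Scaling gradients by %s, model_norm_threshold=%s", 0.05, 1.0)  # preferred spelling
# logging.warn(...) still works, but raises DeprecationWarning when warnings are enabled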
@@ -362,7 +362,7 @@ class ScaledAdam(BatchedOptimizer):

def _show_gradient_dominating_parameter(self, tuples: List[Tuple[Tensor, dict, List[str]]], tot_sumsq: Tensor):
"""
Show information of parameter wihch dominanting tot_sumsq.
Show information of parameter which dominating tot_sumsq.

Args:
tuples: a list of tuples of (param, state, param_names)
@@ -415,7 +415,7 @@ class ScaledAdam(BatchedOptimizer):
dominant_grad,
) = sorted_by_proportion[dominant_param_name]
logging.info(
f"Parameter Dominanting tot_sumsq {dominant_param_name}"
f"Parameter Dominating tot_sumsq {dominant_param_name}"
f" with proportion {dominant_proportion:.2f},"
f" where dominant_sumsq=(grad_sumsq*orig_rms_sq)"
f"={dominant_sumsq:.3e},"
@@ -1073,7 +1073,7 @@ class TTS:

###### setting reference audio and prompt text preprocessing ########
t0 = time.perf_counter()
if (ref_audio_path is not None) and (ref_audio_path != self.prompt_cache["ref_audio_path"]):
if (ref_audio_path is not None) and (ref_audio_path != self.prompt_cache["ref_audio_path"] or (self.is_v2pro and self.prompt_cache["refer_spec"][0][1] is None)):
if not os.path.exists(ref_audio_path):
raise ValueError(f"{ref_audio_path} not exists")
self.set_ref_audio(ref_audio_path)
@@ -159,6 +159,10 @@ class TextPreprocessor:
textlist.append(tmp["text"])
else:
for tmp in LangSegmenter.getTexts(text):
if langlist:
if (tmp["lang"] == "en" and langlist[-1] == "en") or (tmp["lang"] != "en" and langlist[-1] != "en"):
textlist[-1] += tmp["text"]
continue
if tmp["lang"] == "en":
langlist.append(tmp["lang"])
else:
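The four added lines merge adjacent segments returned by LangSegmenter when both are English or both are non-English, so the downstream phoneme/BERT front end sees one longer run per language instead of many short fragments. A simplified, self-contained sketch of that merging rule (toy segment data, for illustration only):

segments = [
    {"lang": "zh", "text": "你好,"},
    {"lang": "zh", "text": "世界。"},
    {"lang": "en", "text": "Hello "},
    {"lang": "en", "text": "world."},
]
textlist, langlist = [], []
for tmp in segments:
    # glue onto the previous entry when both sides are English, or both are non-English
    if langlist and ((tmp["lang"] == "en") == (langlist[-1] == "en")):
        textlist[-1] += tmp["text"]
        continue
    langlist.append(tmp["lang"])
    textlist.append(tmp["text"])
print(langlist, textlist)  # ['zh', 'en'] ['你好,世界。', 'Hello world.']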
@@ -1,6 +1,7 @@
# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_model.py
# reference: https://github.com/lifeiteng/vall-e
import argparse
from io import BytesIO
from typing import Optional
from my_utils import load_audio
import torch
@@ -17,6 +18,9 @@ from module.models_onnx import SynthesizerTrn

from inference_webui import get_phones_and_bert

from sv import SV
import kaldi as Kaldi

import os
import soundfile
@@ -32,6 +36,22 @@ default_config = {
"EOS": 1024,
}

sv_cn_model = None
def init_sv_cn(device, is_half):
global sv_cn_model
sv_cn_model = SV(device, is_half)

def load_sovits_new(sovits_path):
f = open(sovits_path, "rb")
meta = f.read(2)
if meta != b"PK":
data = b"PK" + f.read()
bio = BytesIO()
bio.write(data)
bio.seek(0)
return torch.load(bio, map_location="cpu", weights_only=False)
return torch.load(sovits_path, map_location="cpu", weights_only=False)


def get_raw_t2s_model(dict_s1) -> Text2SemanticLightningModule:
config = dict_s1["config"]
@@ -83,7 +103,7 @@ def logits_to_probs(
@torch.jit.script
def multinomial_sample_one_no_sync(probs_sort):
# Does multinomial sampling without a cuda synchronization
q = torch.randn_like(probs_sort)
q = torch.empty_like(probs_sort).exponential_(1.0)
return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
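This hunk swaps Gaussian noise for Exponential(1) noise, which is what the exponential-race (Gumbel-style) trick actually requires: argmax(p_i / E_i) with E_i ~ Exp(1) returns index i with probability proportional to p_i, so the sampler now matches torch.multinomial while still avoiding a CUDA synchronization. A standalone sanity check:

import torch

def sample_one(probs: torch.Tensor) -> torch.Tensor:
    # exponential-race sampling: index i wins with probability probs[i]
    q = torch.empty_like(probs).exponential_(1.0)
    return torch.argmax(probs / q, dim=-1, keepdim=True).to(dtype=torch.int)

# vectorised check: empirical frequencies approach the target distribution
probs = torch.tensor([0.1, 0.2, 0.7])
q = torch.empty(20000, 3).exponential_(1.0)
idx = torch.argmax(probs / q, dim=-1)
print(torch.bincount(idx, minlength=3) / 20000.0)  # roughly tensor([0.10, 0.20, 0.70])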
@@ -94,7 +114,7 @@ def sample(
temperature: float = 1.0,
top_k: Optional[int] = None,
top_p: Optional[int] = None,
repetition_penalty: float = 1.0,
repetition_penalty: float = 1.35,
):
probs = logits_to_probs(
logits=logits,
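Raising the default repetition_penalty from 1.0 (no penalty) to 1.35 makes the semantic-token decoder less likely to loop on tokens it has already emitted. For reference, a hedged sketch of how such a penalty is commonly applied to the logits of previously generated tokens before softmax (not necessarily the exact body of logits_to_probs):

import torch

def apply_repetition_penalty(logits: torch.Tensor, previous_tokens: torch.Tensor, penalty: float) -> torch.Tensor:
    # scale down logits of tokens that were already generated:
    # divide when positive, multiply when negative
    score = torch.gather(logits, dim=-1, index=previous_tokens)
    score = torch.where(score < 0, score * penalty, score / penalty)
    return logits.scatter(dim=-1, index=previous_tokens, src=score)

logits = torch.randn(1, 1025)
prev = torch.tensor([[3, 17, 17, 42]])
print(apply_repetition_penalty(logits, prev, 1.35).shape)  # torch.Size([1, 1025])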
@@ -109,8 +129,8 @@ def sample(


@torch.jit.script
def spectrogram_torch(y: Tensor, n_fft: int, sampling_rate: int, hop_size: int, win_size: int, center: bool = False):
hann_window = torch.hann_window(win_size, device=y.device, dtype=y.dtype)
def spectrogram_torch(hann_window: Tensor, y: Tensor, n_fft: int, sampling_rate: int, hop_size: int, win_size: int, center: bool = False):
# hann_window = torch.hann_window(win_size, device=y.device, dtype=y.dtype)
y = torch.nn.functional.pad(
y.unsqueeze(1),
(int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
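spectrogram_torch now receives the Hann window as an argument instead of rebuilding it on every call; the callers (see the VitsModel and ExportGPTSovits*Half changes below) create it once on the right device and dtype and keep it as a module attribute, avoiding a per-call allocation inside the scripted graph. A minimal sketch of the caller side, with made-up sizes:

import torch

win_size, n_fft, hop_size = 2048, 2048, 640  # illustrative values only
hann_window = torch.hann_window(win_size, dtype=torch.float32)  # built once, reused

def spectrogram(hann_window: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    spec = torch.stft(
        y, n_fft, hop_length=hop_size, win_length=win_size,
        window=hann_window, center=False, return_complex=True,
    )
    return spec.abs()

print(spectrogram(hann_window, torch.randn(1, 32000)).shape)  # (1, n_fft//2 + 1, frames)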
@@ -289,8 +309,9 @@ class T2SBlock:

attn = F.scaled_dot_product_attention(q, k, v)

attn = attn.permute(2, 0, 1, 3).reshape(batch_size * q_len, self.hidden_dim)
attn = attn.view(q_len, batch_size, self.hidden_dim).transpose(1, 0)
# attn = attn.permute(2, 0, 1, 3).reshape(batch_size * q_len, self.hidden_dim)
# attn = attn.view(q_len, batch_size, self.hidden_dim).transpose(1, 0)
attn = attn.transpose(1, 2).reshape(batch_size, q_len, -1)
attn = F.linear(attn, self.out_w, self.out_b)

x = x + attn
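F.scaled_dot_product_attention returns a (batch, heads, seq, head_dim) tensor; the new single line transposes to (batch, seq, heads, head_dim) and merges the last two axes, giving the same (batch, seq, hidden) layout as the old permute/view/transpose chain with fewer reshapes. A quick equivalence check with toy shapes:

import torch

batch_size, n_head, q_len, head_dim = 2, 4, 5, 8
attn = torch.randn(batch_size, n_head, q_len, head_dim)

# new path: (B, H, T, D) -> (B, T, H*D)
new = attn.transpose(1, 2).reshape(batch_size, q_len, n_head * head_dim)

# old path: permute to (T, B, H, D), flatten, view back, then swap batch/time
old = attn.permute(2, 0, 1, 3).reshape(batch_size * q_len, n_head * head_dim)
old = old.view(q_len, batch_size, n_head * head_dim).transpose(1, 0)

print(torch.equal(new, old))  # True: both produce the same tensor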
@@ -328,15 +349,22 @@ class T2STransformer:


class VitsModel(nn.Module):
def __init__(self, vits_path):
def __init__(self, vits_path, version=None, is_half=True, device="cpu"):
super().__init__()
# dict_s2 = torch.load(vits_path,map_location="cpu")
dict_s2 = torch.load(vits_path, weights_only=False)
dict_s2 = load_sovits_new(vits_path)
self.hps = dict_s2["config"]

if version is None:
if dict_s2["weight"]["enc_p.text_embedding.weight"].shape[0] == 322:
self.hps["model"]["version"] = "v1"
else:
self.hps["model"]["version"] = "v2"
else:
if version in ["v1", "v2", "v3", "v4", "v2Pro", "v2ProPlus"]:
self.hps["model"]["version"] = version
else:
raise ValueError(f"Unsupported version: {version}")

self.hps = DictToAttrRecursive(self.hps)
self.hps.model.semantic_frame_rate = "25hz"
@@ -346,11 +374,18 @@ class VitsModel(nn.Module):
n_speakers=self.hps.data.n_speakers,
**self.hps.model,
)
self.vq_model.eval()
self.vq_model.load_state_dict(dict_s2["weight"], strict=False)
self.vq_model.dec.remove_weight_norm()
if is_half:
self.vq_model = self.vq_model.half()
self.vq_model = self.vq_model.to(device)
self.vq_model.eval()
self.hann_window = torch.hann_window(self.hps.data.win_length, device=device, dtype=torch.float16 if is_half else torch.float32)

def forward(self, text_seq, pred_semantic, ref_audio, speed=1.0):

def forward(self, text_seq, pred_semantic, ref_audio, speed=1.0, sv_emb=None):
refer = spectrogram_torch(
self.hann_window,
ref_audio,
self.hps.data.filter_length,
self.hps.data.sampling_rate,
@@ -358,7 +393,7 @@ class VitsModel(nn.Module):
self.hps.data.win_length,
center=False,
)
return self.vq_model(pred_semantic, text_seq, refer, speed)[0, 0]
return self.vq_model(pred_semantic, text_seq, refer, speed=speed, sv_emb=sv_emb)[0, 0]


class T2SModel(nn.Module):
@@ -632,7 +667,7 @@ def export(gpt_path, vits_path, ref_audio_path, ref_text, output_path, export_be
ref_seq = torch.LongTensor([ref_seq_id]).to(device)
ref_bert = ref_bert_T.T.to(ref_seq.device)
text_seq_id, text_bert_T, norm_text = get_phones_and_bert(
"这是一条测试语音,说什么无所谓,只是给它一个例子", "all_zh", "v2"
"这是一个简单的示例,真没想到这么简单就完成了。The King and His Stories.Once there was a king. He likes to write stories, but his stories were not good. As people were afraid of him, they all said his stories were good.After reading them, the writer at once turned to the soldiers and said: Take me back to prison, please.", "auto", "v2"
)
text_seq = torch.LongTensor([text_seq_id]).to(device)
text_bert = text_bert_T.T.to(text_seq.device)
@@ -640,7 +675,7 @@ def export(gpt_path, vits_path, ref_audio_path, ref_text, output_path, export_be
ssl_content = ssl(ref_audio).to(device)

# vits_path = "SoVITS_weights_v2/xw_e8_s216.pth"
vits = VitsModel(vits_path).to(device)
vits = VitsModel(vits_path, device=device, is_half=False)
vits.eval()

# gpt_path = "GPT_weights_v2/xw-e15.ckpt"
@@ -679,6 +714,124 @@ def export(gpt_path, vits_path, ref_audio_path, ref_text, output_path, export_be
print("#### exported gpt_sovits ####")


def export_prov2(
gpt_path,
vits_path,
version,
ref_audio_path,
ref_text,
output_path,
export_bert_and_ssl=False,
device="cpu",
is_half=True,
):
if sv_cn_model == None:
init_sv_cn(device, is_half)

if not os.path.exists(output_path):
os.makedirs(output_path)
print(f"目录已创建: {output_path}")
else:
print(f"目录已存在: {output_path}")

ref_audio = torch.tensor([load_audio(ref_audio_path, 16000)]).float()
ssl = SSLModel()
if export_bert_and_ssl:
s = ExportSSLModel(torch.jit.trace(ssl, example_inputs=(ref_audio)))
ssl_path = os.path.join(output_path, "ssl_model.pt")
torch.jit.script(s).save(ssl_path)
print("#### exported ssl ####")
export_bert(output_path)
else:
s = ExportSSLModel(ssl)

print(f"device: {device}")

ref_seq_id, ref_bert_T, ref_norm_text = get_phones_and_bert(
ref_text, "all_zh", "v2"
)
ref_seq = torch.LongTensor([ref_seq_id]).to(device)
ref_bert = ref_bert_T.T
if is_half:
ref_bert = ref_bert.half()
ref_bert = ref_bert.to(ref_seq.device)

text_seq_id, text_bert_T, norm_text = get_phones_and_bert(
"这是一个简单的示例,真没想到这么简单就完成了。The King and His Stories.Once there was a king. He likes to write stories, but his stories were not good. As people were afraid of him, they all said his stories were good.After reading them, the writer at once turned to the soldiers and said: Take me back to prison, please.", "auto", "v2"
)
text_seq = torch.LongTensor([text_seq_id]).to(device)
text_bert = text_bert_T.T
if is_half:
text_bert = text_bert.half()
text_bert = text_bert.to(text_seq.device)

ssl_content = ssl(ref_audio)
if is_half:
ssl_content = ssl_content.half()
ssl_content = ssl_content.to(device)

sv_model = ExportERes2NetV2(sv_cn_model)

# vits_path = "SoVITS_weights_v2/xw_e8_s216.pth"
vits = VitsModel(vits_path, version, is_half=is_half, device=device)
vits.eval()

# gpt_path = "GPT_weights_v2/xw-e15.ckpt"
# dict_s1 = torch.load(gpt_path, map_location=device)
dict_s1 = torch.load(gpt_path, weights_only=False)
raw_t2s = get_raw_t2s_model(dict_s1).to(device)
print("#### get_raw_t2s_model ####")
print(raw_t2s.config)
if is_half:
raw_t2s = raw_t2s.half()
t2s_m = T2SModel(raw_t2s)
t2s_m.eval()
t2s = torch.jit.script(t2s_m).to(device)
print("#### script t2s_m ####")

print("vits.hps.data.sampling_rate:", vits.hps.data.sampling_rate)
gpt_sovits = GPT_SoVITS_V2Pro(t2s, vits, sv_model).to(device)
gpt_sovits.eval()

ref_audio_sr = s.resample(ref_audio, 16000, 32000)
if is_half:
ref_audio_sr = ref_audio_sr.half()
ref_audio_sr = ref_audio_sr.to(device)

torch._dynamo.mark_dynamic(ssl_content, 2)
torch._dynamo.mark_dynamic(ref_audio_sr, 1)
torch._dynamo.mark_dynamic(ref_seq, 1)
torch._dynamo.mark_dynamic(text_seq, 1)
torch._dynamo.mark_dynamic(ref_bert, 0)
torch._dynamo.mark_dynamic(text_bert, 0)
# torch._dynamo.mark_dynamic(sv_emb, 0)

top_k = torch.LongTensor([5]).to(device)
# Run sv_model once first so it loads its cache; see L880 for details
gpt_sovits.sv_model(ref_audio_sr)

with torch.no_grad():
gpt_sovits_export = torch.jit.trace(
gpt_sovits,
example_inputs=(
ssl_content,
ref_audio_sr,
ref_seq,
text_seq,
ref_bert,
text_bert,
top_k,
),
)

gpt_sovits_path = os.path.join(output_path, "gpt_sovits_model.pt")
gpt_sovits_export.save(gpt_sovits_path)
print("#### exported gpt_sovits ####")
audio = gpt_sovits_export(ssl_content, ref_audio_sr, ref_seq, text_seq, ref_bert, text_bert, top_k)
print("start write wav")
soundfile.write("out.wav", audio.float().detach().cpu().numpy(), 32000)


@torch.jit.script
def parse_audio(ref_audio):
ref_audio_16k = torchaudio.functional.resample(ref_audio, 48000, 16000).float()  # .to(ref_audio.device)
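torch._dynamo.mark_dynamic(tensor, dim) tells the PyTorch compiler stack to treat that dimension as variable-length rather than specializing on the example input's size; the axes marked above are the SSL frame axis, the audio sample axis, and the phoneme/BERT sequence axes, which all change with every reference clip and target text. A minimal sketch of the call pattern (shapes are illustrative only):

import torch

ssl_content = torch.randn(1, 768, 120)   # (batch, feature, frames) - assumed shape
ref_audio_sr = torch.randn(1, 96000)     # (batch, samples) - assumed shape
torch._dynamo.mark_dynamic(ssl_content, 2)   # frame count varies per reference clip
torch._dynamo.mark_dynamic(ref_audio_sr, 1)  # sample count varies per reference clip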
@@ -717,6 +870,66 @@ class GPT_SoVITS(nn.Module):
return audio


class ExportERes2NetV2(nn.Module):
def __init__(self, sv_cn_model: SV):
super(ExportERes2NetV2, self).__init__()
self.bn1 = sv_cn_model.embedding_model.bn1
self.conv1 = sv_cn_model.embedding_model.conv1
self.layer1 = sv_cn_model.embedding_model.layer1
self.layer2 = sv_cn_model.embedding_model.layer2
self.layer3 = sv_cn_model.embedding_model.layer3
self.layer4 = sv_cn_model.embedding_model.layer4
self.layer3_ds = sv_cn_model.embedding_model.layer3_ds
self.fuse34 = sv_cn_model.embedding_model.fuse34

# audio_16k.shape: [1,N]
def forward(self, audio_16k):
# The fbank function keeps a cache, but that is fine: the cache does not depend on the length of audio_16k,
# only on device and dtype
x = Kaldi.fbank(audio_16k, num_mel_bins=80, sample_frequency=16000, dither=0)
x = torch.stack([x])

x = x.permute(0, 2, 1)  # (B,T,F) => (B,F,T)
x = x.unsqueeze_(1)
out = F.relu(self.bn1(self.conv1(x)))
out1 = self.layer1(out)
out2 = self.layer2(out1)
out3 = self.layer3(out2)
out4 = self.layer4(out3)
out3_ds = self.layer3_ds(out3)
fuse_out34 = self.fuse34(out4, out3_ds)
return fuse_out34.flatten(start_dim=1, end_dim=2).mean(-1)


class GPT_SoVITS_V2Pro(nn.Module):
def __init__(self, t2s: T2SModel, vits: VitsModel, sv_model: ExportERes2NetV2):
super().__init__()
self.t2s = t2s
self.vits = vits
self.sv_model = sv_model

def forward(
self,
ssl_content: torch.Tensor,
ref_audio_sr: torch.Tensor,
ref_seq: Tensor,
text_seq: Tensor,
ref_bert: Tensor,
text_bert: Tensor,
top_k: LongTensor,
speed=1.0,
):
codes = self.vits.vq_model.extract_latent(ssl_content)
prompt_semantic = codes[0, 0]
prompts = prompt_semantic.unsqueeze(0)

audio_16k = resamplex(ref_audio_sr, 32000, 16000).to(ref_audio_sr.dtype)
sv_emb = self.sv_model(audio_16k)

pred_semantic = self.t2s(prompts, ref_seq, text_seq, ref_bert, text_bert, top_k)
audio = self.vits(text_seq, pred_semantic, ref_audio_sr, speed, sv_emb)
return audio


def test():
parser = argparse.ArgumentParser(description="GPT-SoVITS Command Line Tool")
parser.add_argument("--gpt_model", required=True, help="Path to the GPT model file")
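ExportERes2NetV2 re-wires the speaker-verification backbone into a traceable module: 80-bin Kaldi-style fbank features computed at 16 kHz flow through the ERes2NetV2 layers, and the fused layer-3/layer-4 output is flattened and time-averaged into the sv_emb vector that SynthesizerTrn consumes. For reference, an equivalent way to produce the same style of input features with torchaudio (the project wraps its own kaldi module; this sketch only mirrors the parameters used above):

import torch
import torchaudio

audio_16k = torch.randn(1, 16000)  # one second of placeholder audio at 16 kHz
feats = torchaudio.compliance.kaldi.fbank(
    audio_16k, num_mel_bins=80, sample_frequency=16000, dither=0.0
)
print(feats.shape)  # (num_frames, 80)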
@@ -833,14 +1046,41 @@ def export_symbel(version="v2"):
def main():
parser = argparse.ArgumentParser(description="GPT-SoVITS Command Line Tool")
parser.add_argument("--gpt_model", required=True, help="Path to the GPT model file")
parser.add_argument("--sovits_model", required=True, help="Path to the SoVITS model file")
parser.add_argument("--ref_audio", required=True, help="Path to the reference audio file")
parser.add_argument("--ref_text", required=True, help="Path to the reference text file")
parser.add_argument("--output_path", required=True, help="Path to the output directory")
parser.add_argument("--export_common_model", action="store_true", help="Export Bert and SSL model")
parser.add_argument(
"--sovits_model", required=True, help="Path to the SoVITS model file"
)
parser.add_argument(
"--ref_audio", required=True, help="Path to the reference audio file"
)
parser.add_argument(
"--ref_text", required=True, help="Path to the reference text file"
)
parser.add_argument(
"--output_path", required=True, help="Path to the output directory"
)
parser.add_argument(
"--export_common_model", action="store_true", help="Export Bert and SSL model"
)
parser.add_argument("--device", help="Device to use")
parser.add_argument("--version", help="version of the model", default="v2")
parser.add_argument("--no-half", action="store_true", help="Do not use half precision for model weights")

args = parser.parse_args()
if args.version in ["v2Pro", "v2ProPlus"]:
is_half = not args.no_half
print(f"Using half precision: {is_half}")
export_prov2(
gpt_path=args.gpt_model,
vits_path=args.sovits_model,
version=args.version,
ref_audio_path=args.ref_audio,
ref_text=args.ref_text,
output_path=args.output_path,
export_bert_and_ssl=args.export_common_model,
device=args.device,
is_half=is_half,
)
else:
export(
gpt_path=args.gpt_model,
vits_path=args.sovits_model,
@@ -852,10 +1092,7 @@ def main():
)


import inference_webui

if __name__ == "__main__":
inference_webui.is_half = False
inference_webui.dtype = torch.float32
with torch.no_grad():
main()
# test()
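With the arguments above, a v2Pro/v2ProPlus export takes the export_prov2 path while everything else falls back to the original export function. A hedged sketch of the direct call that main() ends up making for --version v2Pro (all paths and the reference text are placeholders):

export_prov2(
    gpt_path="GPT_weights_v2/your_model-e15.ckpt",
    vits_path="SoVITS_weights_v2Pro/your_model_e8_s216.pth",
    version="v2Pro",
    ref_audio_path="ref.wav",
    ref_text="你得先从滑雪的基本技巧学起。",
    output_path="onnx/exported_v2pro",
    export_bert_and_ssl=True,   # same effect as passing --export_common_model
    device="cuda",
    is_half=True,               # disabled when --no-half is given
)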
@@ -243,6 +243,7 @@ class ExportGPTSovitsHalf(torch.nn.Module):
self.sampling_rate: int = hps.data.sampling_rate
self.hop_length: int = hps.data.hop_length
self.win_length: int = hps.data.win_length
self.hann_window = torch.hann_window(self.win_length, device=device, dtype=torch.float32)

def forward(
self,
@@ -255,6 +256,7 @@ class ExportGPTSovitsHalf(torch.nn.Module):
top_k,
):
refer = spectrogram_torch(
self.hann_window,
ref_audio_32k,
self.filter_length,
self.sampling_rate,
@@ -321,6 +323,7 @@ class ExportGPTSovitsV4Half(torch.nn.Module):
self.sampling_rate: int = hps.data.sampling_rate
self.hop_length: int = hps.data.hop_length
self.win_length: int = hps.data.win_length
self.hann_window = torch.hann_window(self.win_length, device=device, dtype=torch.float32)

def forward(
self,
@@ -333,6 +336,7 @@ class ExportGPTSovitsV4Half(torch.nn.Module):
top_k,
):
refer = spectrogram_torch(
self.hann_window,
ref_audio_32k,
self.filter_length,
self.sampling_rate,
@@ -1149,7 +1153,7 @@ def export_2(version="v3"):
raw_t2s = raw_t2s.half().to(device)
t2s_m = T2SModel(raw_t2s).half().to(device)
t2s_m.eval()
t2s_m = torch.jit.script(t2s_m)
t2s_m = torch.jit.script(t2s_m).to(device)
t2s_m.eval()
# t2s_m.top_k = 15
logger.info("t2s_m ok")
@@ -1251,6 +1255,6 @@ def test_export_gpt_sovits_v3():


with torch.no_grad():
export_1("onnx/ad/ref.wav","你这老坏蛋,我找了你这么久,真没想到在这里找到你。他说。","v4")
# export_2("v4")
# export_1("onnx/ad/ref.wav","你这老坏蛋,我找了你这么久,真没想到在这里找到你。他说。","v4")
export_2("v4")
# test_export_gpt_sovits_v3()
@@ -214,7 +214,7 @@ v3v4set = {"v3", "v4"}


def change_sovits_weights(sovits_path, prompt_language=None, text_language=None):
if "!" in sovits_path:
if "!" in sovits_path or "!" in sovits_path:
sovits_path = name2sovits_path[sovits_path]
global vq_model, hps, version, model_version, dict_language, if_lora_v3
version, model_version, if_lora_v3 = get_sovits_version_from_path_fast(sovits_path)
@@ -361,7 +361,7 @@ except:


def change_gpt_weights(gpt_path):
if "!" in gpt_path:
if "!" in gpt_path or "!" in gpt_path:
gpt_path = name2gpt_path[gpt_path]
global hz, max_sec, t2s_model, config
hz = 50
@@ -623,6 +623,10 @@ def get_phones_and_bert(text, language, version, final=False):
textlist.append(tmp["text"])
else:
for tmp in LangSegmenter.getTexts(text):
if langlist:
if (tmp["lang"] == "en" and langlist[-1] == "en") or (tmp["lang"] != "en" and langlist[-1] != "en"):
textlist[-1] += tmp["text"]
continue
if tmp["lang"] == "en":
langlist.append(tmp["lang"])
else:
@@ -114,11 +114,11 @@ tts_config.device = device
tts_config.is_half = is_half
tts_config.version = version
if gpt_path is not None:
if "!" in gpt_path:
if "!" in gpt_path or "!" in gpt_path:
gpt_path = name2gpt_path[gpt_path]
tts_config.t2s_weights_path = gpt_path
if sovits_path is not None:
if "!" in sovits_path:
if "!" in sovits_path or "!" in sovits_path:
sovits_path = name2sovits_path[sovits_path]
tts_config.vits_weights_path = sovits_path
if cnhubert_base_path is not None:
@@ -217,7 +217,7 @@ v3v4set = {"v3", "v4"}


def change_sovits_weights(sovits_path, prompt_language=None, text_language=None):
if "!" in sovits_path:
if "!" in sovits_path or "!" in sovits_path:
sovits_path = name2sovits_path[sovits_path]
global version, model_version, dict_language, if_lora_v3
version, model_version, if_lora_v3 = get_sovits_version_from_path_fast(sovits_path)
@@ -283,6 +283,12 @@ def change_sovits_weights(sovits_path, prompt_language=None, text_language=None)
f.write(json.dumps(data))


def change_gpt_weights(gpt_path):
if "!" in gpt_path or "!" in gpt_path:
gpt_path = name2gpt_path[gpt_path]
tts_pipeline.init_t2s_weights(gpt_path)


with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css) as app:
gr.HTML(
top_html.format(
@@ -457,7 +463,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css
inference_button,
],
) #
GPT_dropdown.change(tts_pipeline.init_t2s_weights, [GPT_dropdown], [])
GPT_dropdown.change(change_gpt_weights, [GPT_dropdown], [])

with gr.Group():
gr.Markdown(
@@ -762,6 +762,7 @@ class CodePredictor(nn.Module):

return pred_codes.transpose(0, 1)

v2pro_set = {"v2Pro", "v2ProPlus"}

class SynthesizerTrn(nn.Module):
"""
@@ -867,19 +868,32 @@ class SynthesizerTrn(nn.Module):
# self.enc_p.text_embedding.requires_grad_(False)
# self.enc_p.encoder_text.requires_grad_(False)
# self.enc_p.mrte.requires_grad_(False)
self.is_v2pro = self.version in v2pro_set
if self.is_v2pro:
self.sv_emb = nn.Linear(20480, gin_channels)
self.ge_to512 = nn.Linear(gin_channels, 512)
self.prelu = nn.PReLU(num_parameters=gin_channels)

def forward(self, codes, text, refer, noise_scale=0.5, speed=1):
def forward(self, codes, text, refer, noise_scale=0.5, speed=1, sv_emb=None):
refer_mask = torch.ones_like(refer[:1, :1, :])
if self.version == "v1":
ge = self.ref_enc(refer * refer_mask, refer_mask)
else:
ge = self.ref_enc(refer[:, :704] * refer_mask, refer_mask)
if self.is_v2pro:
sv_emb = self.sv_emb(sv_emb)
ge += sv_emb.unsqueeze(-1)
ge = self.prelu(ge)

quantized = self.quantizer.decode(codes)
if self.semantic_frame_rate == "25hz":
dquantized = torch.cat([quantized, quantized]).permute(1, 2, 0)
quantized = dquantized.contiguous().view(1, self.ssl_dim, -1)

if self.is_v2pro:
ge_ = self.ge_to512(ge.transpose(2, 1)).transpose(2, 1)
x, m_p, logs_p, y_mask = self.enc_p(quantized, text, ge_, speed)
else:
x, m_p, logs_p, y_mask = self.enc_p(quantized, text, ge, speed)

z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
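For the v2Pro variants, the 20480-dim speaker-verification embedding is projected to gin_channels, added to the reference-encoder output, passed through a per-channel PReLU, and a further 512-dim projection of the result is what enc_p receives. A hedged, self-contained sketch of that fusion (gin_channels is assumed to be 512 here; module names mirror the diff but the values are made up):

import torch
from torch import nn

gin_channels = 512                      # assumption for illustration
sv_proj = nn.Linear(20480, gin_channels)
ge_to512 = nn.Linear(gin_channels, 512)
prelu = nn.PReLU(num_parameters=gin_channels)

ge = torch.randn(1, gin_channels, 1)    # reference-encoder output, shape (B, C, 1)
sv_emb = torch.randn(1, 20480)          # speaker-verification embedding

ge = ge + sv_proj(sv_emb).unsqueeze(-1) # project the SV embedding and add it per channel
ge = prelu(ge)
ge_512 = ge_to512(ge.transpose(2, 1)).transpose(2, 1)  # (B, 512, 1), fed to enc_p
print(ge_512.shape)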
@@ -1,4 +1,6 @@
import math
import pdb

import numpy as np
import torch
from torch import nn
@@ -718,8 +720,10 @@ class MelStyleEncoder(nn.Module):
else:
len_ = (~mask).sum(dim=1).unsqueeze(1)
x = x.masked_fill(mask.unsqueeze(-1), 0)
x = x.sum(dim=1)
out = torch.div(x, len_)
dtype = x.dtype
x = x.float()
x = torch.div(x, len_.unsqueeze(1))
out = x.sum(dim=1).to(dtype)
return out

def forward(self, x, mask=None):
@@ -743,7 +747,6 @@ class MelStyleEncoder(nn.Module):
x = self.fc(x)
# temoral average pooling
w = self.temporal_avg_pool(x, mask=mask)

return w.unsqueeze(-1)
@@ -127,7 +127,7 @@ def get_sovits_version_from_path_fast(sovits_path):
def load_sovits_new(sovits_path):
f = open(sovits_path, "rb")
meta = f.read(2)
if meta != "PK":
if meta != b"PK":
data = b"PK" + f.read()
bio = BytesIO()
bio.write(data)
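The fix matters because f.read(2) on a file opened in "rb" mode returns bytes, and in Python 3 bytes never compare equal to a str, so the old check against "PK" was always true and every checkpoint took the prepend-PK path even when it already started with the zip magic. A minimal illustration (the path is a placeholder):

with open("model.pth", "rb") as f:  # placeholder path, for illustration only
    meta = f.read(2)
print(meta == "PK")   # always False: bytes vs str never compare equal
print(meta == b"PK")  # True only when the file really starts with the zip magic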
@@ -283,7 +283,7 @@ def get_hparams_from_file(config_path):
def check_git_hash(model_dir):
source_dir = os.path.dirname(os.path.realpath(__file__))
if not os.path.exists(os.path.join(source_dir, ".git")):
logger.warn(
logger.warning(
"{} is not a git repository, therefore hash value comparison will be ignored.".format(
source_dir,
)
@@ -296,7 +296,7 @@ def check_git_hash(model_dir):
if os.path.exists(path):
saved_hash = open(path).read()
if saved_hash != cur_hash:
logger.warn(
logger.warning(
"git hash values are different. {}(saved) != {}(current)".format(
saved_hash[:8],
cur_hash[:8],
api.py (252 changed lines)
@@ -163,7 +163,7 @@ from transformers import AutoModelForMaskedLM, AutoTokenizer
|
||||
import numpy as np
|
||||
from feature_extractor import cnhubert
|
||||
from io import BytesIO
|
||||
from module.models import SynthesizerTrn, SynthesizerTrnV3
|
||||
from module.models import Generator, SynthesizerTrn, SynthesizerTrnV3
|
||||
from peft import LoraConfig, get_peft_model
|
||||
from AR.models.t2s_lightning_module import Text2SemanticLightningModule
|
||||
from text import cleaned_text_to_sequence
|
||||
@@ -198,8 +198,38 @@ def is_full(*items): # 任意一项为空返回False
|
||||
return True
|
||||
|
||||
|
||||
def init_bigvgan():
|
||||
bigvgan_model = hifigan_model = sv_cn_model = None
|
||||
def clean_hifigan_model():
|
||||
global hifigan_model
|
||||
if hifigan_model:
|
||||
hifigan_model = hifigan_model.cpu()
|
||||
hifigan_model = None
|
||||
try:
|
||||
torch.cuda.empty_cache()
|
||||
except:
|
||||
pass
|
||||
def clean_bigvgan_model():
|
||||
global bigvgan_model
|
||||
if bigvgan_model:
|
||||
bigvgan_model = bigvgan_model.cpu()
|
||||
bigvgan_model = None
|
||||
try:
|
||||
torch.cuda.empty_cache()
|
||||
except:
|
||||
pass
|
||||
def clean_sv_cn_model():
|
||||
global sv_cn_model
|
||||
if sv_cn_model:
|
||||
sv_cn_model.embedding_model = sv_cn_model.embedding_model.cpu()
|
||||
sv_cn_model = None
|
||||
try:
|
||||
torch.cuda.empty_cache()
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
def init_bigvgan():
|
||||
global bigvgan_model, hifigan_model,sv_cn_model
|
||||
from BigVGAN import bigvgan
|
||||
|
||||
bigvgan_model = bigvgan.BigVGAN.from_pretrained(
|
||||
@@ -209,20 +239,53 @@ def init_bigvgan():
|
||||
# remove weight norm in the model and set to eval mode
|
||||
bigvgan_model.remove_weight_norm()
|
||||
bigvgan_model = bigvgan_model.eval()
|
||||
|
||||
if is_half == True:
|
||||
bigvgan_model = bigvgan_model.half().to(device)
|
||||
else:
|
||||
bigvgan_model = bigvgan_model.to(device)
|
||||
|
||||
|
||||
def init_hifigan():
|
||||
global hifigan_model, bigvgan_model,sv_cn_model
|
||||
hifigan_model = Generator(
|
||||
initial_channel=100,
|
||||
resblock="1",
|
||||
resblock_kernel_sizes=[3, 7, 11],
|
||||
resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
|
||||
upsample_rates=[10, 6, 2, 2, 2],
|
||||
upsample_initial_channel=512,
|
||||
upsample_kernel_sizes=[20, 12, 4, 4, 4],
|
||||
gin_channels=0,
|
||||
is_bias=True,
|
||||
)
|
||||
hifigan_model.eval()
|
||||
hifigan_model.remove_weight_norm()
|
||||
state_dict_g = torch.load(
|
||||
"%s/GPT_SoVITS/pretrained_models/gsv-v4-pretrained/vocoder.pth" % (now_dir,), map_location="cpu", weights_only=False
|
||||
)
|
||||
print("loading vocoder", hifigan_model.load_state_dict(state_dict_g))
|
||||
if is_half == True:
|
||||
hifigan_model = hifigan_model.half().to(device)
|
||||
else:
|
||||
hifigan_model = hifigan_model.to(device)
|
||||
|
||||
|
||||
from sv import SV
|
||||
def init_sv_cn():
|
||||
global hifigan_model, bigvgan_model, sv_cn_model
|
||||
sv_cn_model = SV(device, is_half)
|
||||
|
||||
|
||||
resample_transform_dict = {}


def resample(audio_tensor, sr0):
def resample(audio_tensor, sr0, sr1, device):
global resample_transform_dict
if sr0 not in resample_transform_dict:
resample_transform_dict[sr0] = torchaudio.transforms.Resample(sr0, 24000).to(device)
return resample_transform_dict[sr0](audio_tensor)
key = "%s-%s-%s" % (sr0, sr1, str(device))
if key not in resample_transform_dict:
resample_transform_dict[key] = torchaudio.transforms.Resample(
sr0, sr1
).to(device)
return resample_transform_dict[key](audio_tensor)
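The resample helper now takes an explicit target rate and device and caches one torchaudio Resample module per (sr0, sr1, device) combination, instead of assuming a fixed 24 kHz target. A self-contained version of the same idea:

import torch
import torchaudio

resample_transform_dict = {}

def resample(audio_tensor, sr0, sr1, device):
    # one Resample module per (source rate, target rate, device), built lazily
    key = "%s-%s-%s" % (sr0, sr1, str(device))
    if key not in resample_transform_dict:
        resample_transform_dict[key] = torchaudio.transforms.Resample(sr0, sr1).to(device)
    return resample_transform_dict[key](audio_tensor)

wav = torch.randn(1, 32000)
print(resample(wav, 32000, 16000, "cpu").shape)  # torch.Size([1, 16000])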
|
||||
|
||||
|
||||
from module.mel_processing import mel_spectrogram_torch
|
||||
@@ -252,6 +315,19 @@ mel_fn = lambda x: mel_spectrogram_torch(
|
||||
"center": False,
|
||||
},
|
||||
)
|
||||
mel_fn_v4 = lambda x: mel_spectrogram_torch(
|
||||
x,
|
||||
**{
|
||||
"n_fft": 1280,
|
||||
"win_size": 1280,
|
||||
"hop_size": 320,
|
||||
"num_mels": 100,
|
||||
"sampling_rate": 32000,
|
||||
"fmin": 0,
|
||||
"fmax": None,
|
||||
"center": False,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
sr_model = None
|
||||
@@ -293,12 +369,18 @@ from process_ckpt import get_sovits_version_from_path_fast, load_sovits_new
|
||||
|
||||
|
||||
def get_sovits_weights(sovits_path):
|
||||
path_sovits_v3 = "GPT_SoVITS/pretrained_models/s2Gv3.pth"
|
||||
from config import pretrained_sovits_name
|
||||
path_sovits_v3 = pretrained_sovits_name["v3"]
|
||||
path_sovits_v4 = pretrained_sovits_name["v4"]
|
||||
is_exist_s2gv3 = os.path.exists(path_sovits_v3)
|
||||
is_exist_s2gv4 = os.path.exists(path_sovits_v4)
|
||||
|
||||
version, model_version, if_lora_v3 = get_sovits_version_from_path_fast(sovits_path)
|
||||
if if_lora_v3 == True and is_exist_s2gv3 == False:
|
||||
logger.info("SoVITS V3 底模缺失,无法加载相应 LoRA 权重")
|
||||
is_exist = is_exist_s2gv3 if model_version == "v3" else is_exist_s2gv4
|
||||
path_sovits = path_sovits_v3 if model_version == "v3" else path_sovits_v4
|
||||
|
||||
if if_lora_v3 == True and is_exist == False:
|
||||
logger.info("SoVITS %s 底模缺失,无法加载相应 LoRA 权重" % model_version)
|
||||
|
||||
dict_s2 = load_sovits_new(sovits_path)
|
||||
hps = dict_s2["config"]
|
||||
@@ -311,11 +393,13 @@ def get_sovits_weights(sovits_path):
|
||||
else:
|
||||
hps.model.version = "v2"
|
||||
|
||||
if model_version == "v3":
|
||||
hps.model.version = "v3"
|
||||
|
||||
model_params_dict = vars(hps.model)
|
||||
if model_version != "v3":
|
||||
if model_version not in {"v3", "v4"}:
|
||||
if "Pro" in model_version:
|
||||
hps.model.version = model_version
|
||||
if sv_cn_model == None:
|
||||
init_sv_cn()
|
||||
|
||||
vq_model = SynthesizerTrn(
|
||||
hps.data.filter_length // 2 + 1,
|
||||
hps.train.segment_size // hps.data.hop_length,
|
||||
@@ -323,13 +407,18 @@ def get_sovits_weights(sovits_path):
|
||||
**model_params_dict,
|
||||
)
|
||||
else:
|
||||
hps.model.version = model_version
|
||||
vq_model = SynthesizerTrnV3(
|
||||
hps.data.filter_length // 2 + 1,
|
||||
hps.train.segment_size // hps.data.hop_length,
|
||||
n_speakers=hps.data.n_speakers,
|
||||
**model_params_dict,
|
||||
)
|
||||
if model_version == "v3":
|
||||
init_bigvgan()
|
||||
if model_version == "v4":
|
||||
init_hifigan()
|
||||
|
||||
model_version = hps.model.version
|
||||
logger.info(f"模型版本: {model_version}")
|
||||
if "pretrained" not in sovits_path:
|
||||
@@ -345,7 +434,8 @@ def get_sovits_weights(sovits_path):
|
||||
if if_lora_v3 == False:
|
||||
vq_model.load_state_dict(dict_s2["weight"], strict=False)
|
||||
else:
|
||||
vq_model.load_state_dict(load_sovits_new(path_sovits_v3)["weight"], strict=False)
|
||||
path_sovits = path_sovits_v3 if model_version == "v3" else path_sovits_v4
|
||||
vq_model.load_state_dict(load_sovits_new(path_sovits)["weight"], strict=False)
|
||||
lora_rank = dict_s2["lora_rank"]
|
||||
lora_config = LoraConfig(
|
||||
target_modules=["to_k", "to_q", "to_v", "to_out.0"],
|
||||
@@ -479,6 +569,10 @@ def get_phones_and_bert(text, language, version, final=False):
|
||||
textlist.append(tmp["text"])
|
||||
else:
|
||||
for tmp in LangSegmenter.getTexts(text):
|
||||
if langlist:
|
||||
if (tmp["lang"] == "en" and langlist[-1] == "en") or (tmp["lang"] != "en" and langlist[-1] != "en"):
|
||||
textlist[-1] += tmp["text"]
|
||||
continue
|
||||
if tmp["lang"] == "en":
|
||||
langlist.append(tmp["lang"])
|
||||
else:
|
||||
@@ -533,23 +627,32 @@ class DictToAttrRecursive(dict):
|
||||
raise AttributeError(f"Attribute {item} not found")
|
||||
|
||||
|
||||
def get_spepc(hps, filename):
|
||||
audio, _ = librosa.load(filename, sr=int(hps.data.sampling_rate))
|
||||
audio = torch.FloatTensor(audio)
|
||||
def get_spepc(hps, filename, dtype, device, is_v2pro=False):
|
||||
sr1=int(hps.data.sampling_rate)
|
||||
audio, sr0=torchaudio.load(filename)
|
||||
if sr0!=sr1:
|
||||
audio=audio.to(device)
|
||||
if(audio.shape[0]==2):audio=audio.mean(0).unsqueeze(0)
|
||||
audio=resample(audio,sr0,sr1,device)
|
||||
else:
|
||||
audio=audio.to(device)
|
||||
if(audio.shape[0]==2):audio=audio.mean(0).unsqueeze(0)
|
||||
|
||||
maxx = audio.abs().max()
|
||||
if maxx > 1:
|
||||
audio /= min(2, maxx)
|
||||
audio_norm = audio
|
||||
audio_norm = audio_norm.unsqueeze(0)
|
||||
spec = spectrogram_torch(
|
||||
audio_norm,
|
||||
audio,
|
||||
hps.data.filter_length,
|
||||
hps.data.sampling_rate,
|
||||
hps.data.hop_length,
|
||||
hps.data.win_length,
|
||||
center=False,
|
||||
)
|
||||
return spec
|
||||
spec=spec.to(dtype)
|
||||
if is_v2pro==True:
|
||||
audio=resample(audio,sr1,16000,device).to(dtype)
|
||||
return spec,audio
|
||||
|
||||
|
||||
def pack_audio(audio_bytes, data, rate):
|
||||
@@ -736,6 +839,16 @@ def get_tts_wav(
|
||||
t2s_model = infer_gpt.t2s_model
|
||||
max_sec = infer_gpt.max_sec
|
||||
|
||||
if version == "v3":
|
||||
if sample_steps not in [4, 8, 16, 32, 64, 128]:
|
||||
sample_steps = 32
|
||||
elif version == "v4":
|
||||
if sample_steps not in [4, 8, 16, 32]:
|
||||
sample_steps = 8
|
||||
|
||||
if if_sr and version != "v3":
|
||||
if_sr = False
|
||||
|
||||
t0 = ttime()
|
||||
prompt_text = prompt_text.strip("\n")
|
||||
if prompt_text[-1] not in splits:
|
||||
@@ -759,19 +872,29 @@ def get_tts_wav(
|
||||
prompt_semantic = codes[0, 0]
|
||||
prompt = prompt_semantic.unsqueeze(0).to(device)
|
||||
|
||||
if version != "v3":
|
||||
is_v2pro = version in {"v2Pro","v2ProPlus"}
|
||||
if version not in {"v3", "v4"}:
|
||||
refers = []
|
||||
if is_v2pro:
|
||||
sv_emb= []
|
||||
if sv_cn_model == None:
|
||||
init_sv_cn()
|
||||
if inp_refs:
|
||||
for path in inp_refs:
|
||||
try:
|
||||
refer = get_spepc(hps, path).to(dtype).to(device)
|
||||
try:#####这里加上提取sv的逻辑,要么一堆sv一堆refer,要么单个sv单个refer
|
||||
refer,audio_tensor = get_spepc(hps, path.name, dtype, device, is_v2pro)
|
||||
refers.append(refer)
|
||||
if is_v2pro:
|
||||
sv_emb.append(sv_cn_model.compute_embedding3(audio_tensor))
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
if len(refers) == 0:
|
||||
refers = [get_spepc(hps, ref_wav_path).to(dtype).to(device)]
|
||||
refers,audio_tensor = get_spepc(hps, ref_wav_path, dtype, device, is_v2pro)
|
||||
refers=[refers]
|
||||
if is_v2pro:
|
||||
sv_emb=[sv_cn_model.compute_embedding3(audio_tensor)]
|
||||
else:
|
||||
refer = get_spepc(hps, ref_wav_path).to(device).to(dtype)
|
||||
refer,audio_tensor = get_spepc(hps, ref_wav_path, dtype, device)
|
||||
|
||||
t1 = ttime()
|
||||
# os.environ['version'] = version
|
||||
@@ -811,41 +934,48 @@ def get_tts_wav(
|
||||
pred_semantic = pred_semantic[:, -idx:].unsqueeze(0)
|
||||
t3 = ttime()
|
||||
|
||||
if version != "v3":
|
||||
if version not in {"v3", "v4"}:
|
||||
if is_v2pro:
|
||||
audio = (
|
||||
vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refers, speed=speed,sv_emb=sv_emb)
|
||||
.detach()
|
||||
.cpu()
|
||||
.numpy()[0, 0]
|
||||
)
|
||||
else:
|
||||
audio = (
|
||||
vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refers, speed=speed)
|
||||
.detach()
|
||||
.cpu()
|
||||
.numpy()[0, 0]
|
||||
) ###试试重建不带上prompt部分
|
||||
)
|
||||
else:
|
||||
phoneme_ids0 = torch.LongTensor(phones1).to(device).unsqueeze(0)
|
||||
phoneme_ids1 = torch.LongTensor(phones2).to(device).unsqueeze(0)
|
||||
# print(11111111, phoneme_ids0, phoneme_ids1)
|
||||
|
||||
fea_ref, ge = vq_model.decode_encp(prompt.unsqueeze(0), phoneme_ids0, refer)
|
||||
ref_audio, sr = torchaudio.load(ref_wav_path)
|
||||
ref_audio = ref_audio.to(device).float()
|
||||
if ref_audio.shape[0] == 2:
|
||||
ref_audio = ref_audio.mean(0).unsqueeze(0)
|
||||
if sr != 24000:
|
||||
ref_audio = resample(ref_audio, sr)
|
||||
# print("ref_audio",ref_audio.abs().mean())
|
||||
mel2 = mel_fn(ref_audio)
|
||||
|
||||
tgt_sr = 24000 if version == "v3" else 32000
|
||||
if sr != tgt_sr:
|
||||
ref_audio = resample(ref_audio, sr, tgt_sr, device)
|
||||
mel2 = mel_fn(ref_audio) if version == "v3" else mel_fn_v4(ref_audio)
|
||||
mel2 = norm_spec(mel2)
|
||||
T_min = min(mel2.shape[2], fea_ref.shape[2])
|
||||
mel2 = mel2[:, :, :T_min]
|
||||
fea_ref = fea_ref[:, :, :T_min]
|
||||
if T_min > 468:
|
||||
mel2 = mel2[:, :, -468:]
|
||||
fea_ref = fea_ref[:, :, -468:]
|
||||
T_min = 468
|
||||
chunk_len = 934 - T_min
|
||||
# print("fea_ref",fea_ref,fea_ref.shape)
|
||||
# print("mel2",mel2)
|
||||
Tref = 468 if version == "v3" else 500
|
||||
Tchunk = 934 if version == "v3" else 1000
|
||||
if T_min > Tref:
|
||||
mel2 = mel2[:, :, -Tref:]
|
||||
fea_ref = fea_ref[:, :, -Tref:]
|
||||
T_min = Tref
|
||||
chunk_len = Tchunk - T_min
|
||||
mel2 = mel2.to(dtype)
|
||||
fea_todo, ge = vq_model.decode_encp(pred_semantic, phoneme_ids1, refer, ge, speed)
|
||||
# print("fea_todo",fea_todo)
|
||||
# print("ge",ge.abs().mean())
|
||||
cfm_resss = []
|
||||
idx = 0
|
||||
while 1:
|
||||
@@ -854,22 +984,24 @@ def get_tts_wav(
|
||||
break
|
||||
idx += chunk_len
|
||||
fea = torch.cat([fea_ref, fea_todo_chunk], 2).transpose(2, 1)
|
||||
# set_seed(123)
|
||||
cfm_res = vq_model.cfm.inference(
|
||||
fea, torch.LongTensor([fea.size(1)]).to(fea.device), mel2, sample_steps, inference_cfg_rate=0
|
||||
)
|
||||
cfm_res = cfm_res[:, :, mel2.shape[2] :]
|
||||
mel2 = cfm_res[:, :, -T_min:]
|
||||
# print("fea", fea)
|
||||
# print("mel2in", mel2)
|
||||
fea_ref = fea_todo_chunk[:, :, -T_min:]
|
||||
cfm_resss.append(cfm_res)
|
||||
cmf_res = torch.cat(cfm_resss, 2)
|
||||
cmf_res = denorm_spec(cmf_res)
|
||||
cfm_res = torch.cat(cfm_resss, 2)
|
||||
cfm_res = denorm_spec(cfm_res)
|
||||
if version == "v3":
|
||||
if bigvgan_model == None:
|
||||
init_bigvgan()
|
||||
else: # v4
|
||||
if hifigan_model == None:
|
||||
init_hifigan()
|
||||
vocoder_model = bigvgan_model if version == "v3" else hifigan_model
|
||||
with torch.inference_mode():
|
||||
wav_gen = bigvgan_model(cmf_res)
|
||||
wav_gen = vocoder_model(cfm_res)
|
||||
audio = wav_gen[0][0].cpu().detach().numpy()
|
||||
|
||||
max_audio = np.abs(audio).max()
|
||||
@@ -880,7 +1012,13 @@ def get_tts_wav(
|
||||
audio_opt = np.concatenate(audio_opt, 0)
|
||||
t4 = ttime()
|
||||
|
||||
sr = hps.data.sampling_rate if version != "v3" else 24000
|
||||
if version in {"v1", "v2", "v2Pro", "v2ProPlus"}:
|
||||
sr = 32000
|
||||
elif version == "v3":
|
||||
sr = 24000
|
||||
else:
|
||||
sr = 48000 # v4
|
||||
|
||||
if if_sr and sr == 24000:
|
||||
audio_opt = torch.from_numpy(audio_opt).float().to(device)
|
||||
audio_opt, sr = audio_sr(audio_opt.unsqueeze(0), sr)
|
||||
@@ -900,8 +1038,12 @@ def get_tts_wav(
|
||||
|
||||
if not stream_mode == "normal":
|
||||
if media_type == "wav":
|
||||
if version in {"v1", "v2", "v2Pro", "v2ProPlus"}:
|
||||
sr = 32000
|
||||
elif version == "v3":
|
||||
sr = 48000 if if_sr else 24000
|
||||
sr = hps.data.sampling_rate if version != "v3" else sr
|
||||
else:
|
||||
sr = 48000 # v4
|
||||
audio_bytes = pack_wav(audio_bytes, sr)
|
||||
yield audio_bytes.getvalue()
|
||||
|
||||
@@ -966,8 +1108,6 @@ def handle(
|
||||
if not default_refer.is_ready():
|
||||
return JSONResponse({"code": 400, "message": "未指定参考音频且接口无预设"}, status_code=400)
|
||||
|
||||
if sample_steps not in [4, 8, 16, 32]:
|
||||
sample_steps = 32
|
||||
|
||||
if cut_punc == None:
|
||||
text = cut_text(text, default_cut_punc)
|
||||
@@ -1071,10 +1211,10 @@ default_refer = DefaultRefer(args.default_refer_path, args.default_refer_text, a
# Model path check
if sovits_path == "":
sovits_path = g_config.pretrained_sovits_path
logger.warn(f"未指定SoVITS模型路径, fallback后当前值: {sovits_path}")
logger.warning(f"未指定SoVITS模型路径, fallback后当前值: {sovits_path}")
if gpt_path == "":
gpt_path = g_config.pretrained_gpt_path
logger.warn(f"未指定GPT模型路径, fallback后当前值: {gpt_path}")
logger.warning(f"未指定GPT模型路径, fallback后当前值: {gpt_path}")
|
||||
# 指定默认参考音频, 调用方 未提供/未给全 参考音频参数时使用
|
||||
if default_refer.path == "" or default_refer.text == "" or default_refer.language == "":
|
||||
|
||||
@@ -1,442 +0,0 @@
|
||||
import argparse
|
||||
import os
|
||||
import pdb
|
||||
import signal
|
||||
import sys
|
||||
from time import time as ttime
|
||||
import torch
|
||||
import librosa
|
||||
import soundfile as sf
|
||||
from fastapi import FastAPI, Request, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
import uvicorn
|
||||
from transformers import AutoModelForMaskedLM, AutoTokenizer
|
||||
import numpy as np
|
||||
from feature_extractor import cnhubert
|
||||
from io import BytesIO
|
||||
from module.models import SynthesizerTrn
|
||||
from AR.models.t2s_lightning_module import Text2SemanticLightningModule
|
||||
from text import cleaned_text_to_sequence
|
||||
from text.cleaner import clean_text
|
||||
from module.mel_processing import spectrogram_torch
|
||||
from my_utils import load_audio
|
||||
import config as global_config
|
||||
|
||||
g_config = global_config.Config()
|
||||
|
||||
# AVAILABLE_COMPUTE = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
parser = argparse.ArgumentParser(description="GPT-SoVITS api")
|
||||
|
||||
parser.add_argument("-s", "--sovits_path", type=str, default=g_config.sovits_path, help="SoVITS模型路径")
|
||||
parser.add_argument("-g", "--gpt_path", type=str, default=g_config.gpt_path, help="GPT模型路径")
|
||||
|
||||
parser.add_argument("-dr", "--default_refer_path", type=str, default="",
|
||||
help="默认参考音频路径, 请求缺少参考音频时调用")
|
||||
parser.add_argument("-dt", "--default_refer_text", type=str, default="", help="默认参考音频文本")
|
||||
parser.add_argument("-dl", "--default_refer_language", type=str, default="", help="默认参考音频语种")
|
||||
|
||||
parser.add_argument("-d", "--device", type=str, default=g_config.infer_device, help="cuda / cpu")
|
||||
parser.add_argument("-p", "--port", type=int, default=g_config.api_port, help="default: 9880")
|
||||
parser.add_argument("-a", "--bind_addr", type=str, default="127.0.0.1", help="default: 127.0.0.1")
|
||||
parser.add_argument("-fp", "--full_precision", action="store_true", default=False, help="覆盖config.is_half为False, 使用全精度")
|
||||
parser.add_argument("-hp", "--half_precision", action="store_true", default=False, help="覆盖config.is_half为True, 使用半精度")
|
||||
# bool值的用法为 `python ./api.py -fp ...`
|
||||
# 此时 full_precision==True, half_precision==False
|
||||
|
||||
parser.add_argument("-hb", "--hubert_path", type=str, default=g_config.cnhubert_path, help="覆盖config.cnhubert_path")
|
||||
parser.add_argument("-b", "--bert_path", type=str, default=g_config.bert_path, help="覆盖config.bert_path")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
sovits_path = args.sovits_path
|
||||
gpt_path = args.gpt_path
|
||||
|
||||
default_refer_path = args.default_refer_path
|
||||
default_refer_text = args.default_refer_text
|
||||
default_refer_language = args.default_refer_language
|
||||
has_preset = False
|
||||
|
||||
device = args.device
|
||||
port = args.port
|
||||
host = args.bind_addr
|
||||
|
||||
if sovits_path == "":
|
||||
sovits_path = g_config.pretrained_sovits_path
|
||||
print(f"[WARN] 未指定SoVITS模型路径, fallback后当前值: {sovits_path}")
|
||||
if gpt_path == "":
|
||||
gpt_path = g_config.pretrained_gpt_path
|
||||
print(f"[WARN] 未指定GPT模型路径, fallback后当前值: {gpt_path}")
|
||||
|
||||
# 指定默认参考音频, 调用方 未提供/未给全 参考音频参数时使用
|
||||
if default_refer_path == "" or default_refer_text == "" or default_refer_language == "":
|
||||
default_refer_path, default_refer_text, default_refer_language = "", "", ""
|
||||
print("[INFO] 未指定默认参考音频")
|
||||
has_preset = False
|
||||
else:
|
||||
print(f"[INFO] 默认参考音频路径: {default_refer_path}")
|
||||
print(f"[INFO] 默认参考音频文本: {default_refer_text}")
|
||||
print(f"[INFO] 默认参考音频语种: {default_refer_language}")
|
||||
has_preset = True
|
||||
|
||||
is_half = g_config.is_half
|
||||
if args.full_precision:
|
||||
is_half = False
|
||||
if args.half_precision:
|
||||
is_half = True
|
||||
if args.full_precision and args.half_precision:
|
||||
is_half = g_config.is_half # 炒饭fallback
|
||||
|
||||
print(f"[INFO] 半精: {is_half}")
|
||||
|
||||
cnhubert_base_path = args.hubert_path
|
||||
bert_path = args.bert_path
|
||||
|
||||
cnhubert.cnhubert_base_path = cnhubert_base_path
|
||||
tokenizer = AutoTokenizer.from_pretrained(bert_path)
|
||||
bert_model = AutoModelForMaskedLM.from_pretrained(bert_path)
|
||||
if is_half:
|
||||
bert_model = bert_model.half().to(device)
|
||||
else:
|
||||
bert_model = bert_model.to(device)
|
||||
|
||||
|
||||
def get_bert_feature(text, word2ph):
|
||||
with torch.no_grad():
|
||||
inputs = tokenizer(text, return_tensors="pt")
|
||||
for i in inputs:
|
||||
inputs[i] = inputs[i].to(device) #####输入是long不用管精度问题,精度随bert_model
|
||||
res = bert_model(**inputs, output_hidden_states=True)
|
||||
res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]
|
||||
assert len(word2ph) == len(text)
|
||||
phone_level_feature = []
|
||||
for i in range(len(word2ph)):
|
||||
repeat_feature = res[i].repeat(word2ph[i], 1)
|
||||
phone_level_feature.append(repeat_feature)
|
||||
phone_level_feature = torch.cat(phone_level_feature, dim=0)
|
||||
# if(is_half==True):phone_level_feature=phone_level_feature.half()
|
||||
return phone_level_feature.T
|
||||
|
||||
|
||||
n_semantic = 1024
|
||||
dict_s2 = torch.load(sovits_path, map_location="cpu", weights_only=False)
|
||||
hps = dict_s2["config"]
|
||||
print(hps)
|
||||
|
||||
class DictToAttrRecursive(dict):
|
||||
def __init__(self, input_dict):
|
||||
super().__init__(input_dict)
|
||||
for key, value in input_dict.items():
|
||||
if isinstance(value, dict):
|
||||
value = DictToAttrRecursive(value)
|
||||
self[key] = value
|
||||
setattr(self, key, value)
|
||||
|
||||
def __getattr__(self, item):
|
||||
try:
|
||||
return self[item]
|
||||
except KeyError:
|
||||
raise AttributeError(f"Attribute {item} not found")
|
||||
|
||||
def __setattr__(self, key, value):
|
||||
if isinstance(value, dict):
|
||||
value = DictToAttrRecursive(value)
|
||||
super(DictToAttrRecursive, self).__setitem__(key, value)
|
||||
super().__setattr__(key, value)
|
||||
|
||||
def __delattr__(self, item):
|
||||
try:
|
||||
del self[item]
|
||||
except KeyError:
|
||||
raise AttributeError(f"Attribute {item} not found")
|
||||
|
||||
|
||||
hps = DictToAttrRecursive(hps)
|
||||
hps.model.semantic_frame_rate = "25hz"
|
||||
dict_s1 = torch.load(gpt_path, map_location="cpu", weights_only=False)
|
||||
config = dict_s1["config"]
|
||||
ssl_model = cnhubert.get_model()
|
||||
if is_half:
|
||||
ssl_model = ssl_model.half().to(device)
|
||||
else:
|
||||
ssl_model = ssl_model.to(device)
|
||||
|
||||
vq_model = SynthesizerTrn(
|
||||
hps.data.filter_length // 2 + 1,
|
||||
hps.train.segment_size // hps.data.hop_length,
|
||||
n_speakers=hps.data.n_speakers,
|
||||
**hps.model)
|
||||
if is_half:
|
||||
vq_model = vq_model.half().to(device)
|
||||
else:
|
||||
vq_model = vq_model.to(device)
|
||||
vq_model.eval()
|
||||
print(vq_model.load_state_dict(dict_s2["weight"], strict=False))
|
||||
hz = 50
|
||||
max_sec = config['data']['max_sec']
|
||||
t2s_model = Text2SemanticLightningModule(config, "ojbk", is_train=False)
|
||||
t2s_model.load_state_dict(dict_s1["weight"])
|
||||
if is_half:
|
||||
t2s_model = t2s_model.half()
|
||||
t2s_model = t2s_model.to(device)
|
||||
t2s_model.eval()
|
||||
total = sum([param.nelement() for param in t2s_model.parameters()])
|
||||
print("Number of parameter: %.2fM" % (total / 1e6))
|
||||
|
||||
|
||||
def get_spepc(hps, filename):
|
||||
audio = load_audio(filename, int(hps.data.sampling_rate))
|
||||
audio = torch.FloatTensor(audio)
|
||||
audio_norm = audio
|
||||
audio_norm = audio_norm.unsqueeze(0)
|
||||
spec = spectrogram_torch(audio_norm, hps.data.filter_length, hps.data.sampling_rate, hps.data.hop_length,
|
||||
hps.data.win_length, center=False)
|
||||
return spec
|
||||
|
||||
|
||||
dict_language = {
|
||||
"中文": "zh",
|
||||
"英文": "en",
|
||||
"日文": "ja",
|
||||
"ZH": "zh",
|
||||
"EN": "en",
|
||||
"JA": "ja",
|
||||
"zh": "zh",
|
||||
"en": "en",
|
||||
"ja": "ja"
|
||||
}
|
||||
|
||||
|
||||
def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language):
|
||||
t0 = ttime()
|
||||
prompt_text = prompt_text.strip("\n")
|
||||
prompt_language, text = prompt_language, text.strip("\n")
|
||||
zero_wav = np.zeros(int(hps.data.sampling_rate * 0.3), dtype=np.float16 if is_half == True else np.float32)
|
||||
with torch.no_grad():
|
||||
wav16k, sr = librosa.load(ref_wav_path, sr=16000)
|
||||
wav16k = torch.from_numpy(wav16k)
|
||||
zero_wav_torch = torch.from_numpy(zero_wav)
|
||||
if (is_half == True):
|
||||
wav16k = wav16k.half().to(device)
|
||||
zero_wav_torch = zero_wav_torch.half().to(device)
|
||||
else:
|
||||
wav16k = wav16k.to(device)
|
||||
zero_wav_torch = zero_wav_torch.to(device)
|
||||
wav16k=torch.cat([wav16k,zero_wav_torch])
|
||||
ssl_content = ssl_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2) # .float()
|
||||
codes = vq_model.extract_latent(ssl_content)
|
||||
prompt_semantic = codes[0, 0]
|
||||
t1 = ttime()
|
||||
prompt_language = dict_language[prompt_language]
|
||||
text_language = dict_language[text_language]
|
||||
phones1, word2ph1, norm_text1 = clean_text(prompt_text, prompt_language)
|
||||
phones1 = cleaned_text_to_sequence(phones1)
|
||||
texts = text.split("\n")
|
||||
audio_opt = []
|
||||
|
||||
for text in texts:
|
||||
phones2, word2ph2, norm_text2 = clean_text(text, text_language)
|
||||
phones2 = cleaned_text_to_sequence(phones2)
|
||||
if (prompt_language == "zh"):
|
||||
bert1 = get_bert_feature(norm_text1, word2ph1).to(device)
|
||||
else:
|
||||
bert1 = torch.zeros((1024, len(phones1)), dtype=torch.float16 if is_half == True else torch.float32).to(
|
||||
device)
|
||||
if (text_language == "zh"):
|
||||
bert2 = get_bert_feature(norm_text2, word2ph2).to(device)
|
||||
else:
|
||||
bert2 = torch.zeros((1024, len(phones2))).to(bert1)
|
||||
bert = torch.cat([bert1, bert2], 1)
|
||||
|
||||
all_phoneme_ids = torch.LongTensor(phones1 + phones2).to(device).unsqueeze(0)
|
||||
bert = bert.to(device).unsqueeze(0)
|
||||
all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device)
|
||||
prompt = prompt_semantic.unsqueeze(0).to(device)
|
||||
t2 = ttime()
|
||||
with torch.no_grad():
|
||||
# pred_semantic = t2s_model.model.infer(
|
||||
pred_semantic, idx = t2s_model.model.infer_panel(
|
||||
all_phoneme_ids,
|
||||
all_phoneme_len,
|
||||
prompt,
|
||||
bert,
|
||||
# prompt_phone_len=ph_offset,
|
||||
top_k=config['inference']['top_k'],
|
||||
early_stop_num=hz * max_sec)
|
||||
t3 = ttime()
|
||||
# print(pred_semantic.shape,idx)
|
||||
pred_semantic = pred_semantic[:, -idx:].unsqueeze(0) # .unsqueeze(0)#mq要多unsqueeze一次
|
||||
refer = get_spepc(hps, ref_wav_path) # .to(device)
|
||||
if (is_half == True):
|
||||
refer = refer.half().to(device)
|
||||
else:
|
||||
refer = refer.to(device)
|
||||
# audio = vq_model.decode(pred_semantic, all_phoneme_ids, refer).detach().cpu().numpy()[0, 0]
|
||||
audio = \
|
||||
vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0),
|
||||
refer).detach().cpu().numpy()[
|
||||
0, 0] ###试试重建不带上prompt部分
|
||||
audio_opt.append(audio)
|
||||
audio_opt.append(zero_wav)
|
||||
t4 = ttime()
|
||||
print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3))
|
||||
# yield hps.data.sampling_rate, (np.concatenate(audio_opt, 0) * 32768).astype(np.int16)
|
||||
return hps.data.sampling_rate, (np.concatenate(audio_opt, 0) * 32768).astype(np.int16)
|
||||


def get_tts_wavs(ref_wav_path, prompt_text, prompt_language, textss, text_language):
    t0 = ttime()
    prompt_text = prompt_text.strip("\n")
    zero_wav = np.zeros(int(hps.data.sampling_rate * 0.3), dtype=np.float16 if is_half else np.float32)
    with torch.no_grad():
        wav16k, sr = librosa.load(ref_wav_path, sr=16000)
        wav16k = torch.from_numpy(wav16k)
        zero_wav_torch = torch.from_numpy(zero_wav)
        if is_half:
            wav16k = wav16k.half().to(device)
            zero_wav_torch = zero_wav_torch.half().to(device)
        else:
            wav16k = wav16k.to(device)
            zero_wav_torch = zero_wav_torch.to(device)
        wav16k = torch.cat([wav16k, zero_wav_torch])
        ssl_content = ssl_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2)  # .float()
        codes = vq_model.extract_latent(ssl_content)
        prompt_semantic = codes[0, 0]
    t1 = ttime()
    prompt_language = dict_language[prompt_language]
    text_language = dict_language[text_language]
    phones1, word2ph1, norm_text1 = clean_text(prompt_text, prompt_language)
    phones1 = cleaned_text_to_sequence(phones1)
    audios_opt = []
    for text0 in textss:
        texts = text0.strip("\n").split("\n")
        audio_opt = []
        for text in texts:
            text = text.strip("。") + "。"
            phones2, word2ph2, norm_text2 = clean_text(text, text_language)
            phones2 = cleaned_text_to_sequence(phones2)
            if prompt_language == "zh":
                bert1 = get_bert_feature(norm_text1, word2ph1).to(device)
            else:
                bert1 = torch.zeros(
                    (1024, len(phones1)), dtype=torch.float16 if is_half else torch.float32
                ).to(device)
            if text_language == "zh":
                bert2 = get_bert_feature(norm_text2, word2ph2).to(device)
            else:
                bert2 = torch.zeros((1024, len(phones2))).to(bert1)
            bert = torch.cat([bert1, bert2], 1)

            all_phoneme_ids = torch.LongTensor(phones1 + phones2).to(device).unsqueeze(0)
            bert = bert.to(device).unsqueeze(0)
            all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device)
            prompt = prompt_semantic.unsqueeze(0).to(device)
            t2 = ttime()
            with torch.no_grad():
                # pred_semantic = t2s_model.model.infer(
                pred_semantic, idx = t2s_model.model.infer_panel(
                    all_phoneme_ids,
                    all_phoneme_len,
                    prompt,
                    bert,
                    # prompt_phone_len=ph_offset,
                    top_k=config['inference']['top_k'],
                    early_stop_num=hz * max_sec,
                )
            t3 = ttime()
            # print(pred_semantic.shape, idx)
            pred_semantic = pred_semantic[:, -idx:].unsqueeze(0)  # mq needs one extra unsqueeze
            refer = get_spepc(hps, ref_wav_path)  # .to(device)
            if is_half:
                refer = refer.half().to(device)
            else:
                refer = refer.to(device)
            # audio = vq_model.decode(pred_semantic, all_phoneme_ids, refer).detach().cpu().numpy()[0, 0]
            # decode with only the target-text phonemes: try reconstructing without the prompt part
            audio = (
                vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refer)
                .detach()
                .cpu()
                .numpy()[0, 0]
            )
            audio_opt.append(audio)
            audio_opt.append(zero_wav)
            t4 = ttime()
            print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3))
        audios_opt.append([text0, (np.concatenate(audio_opt, 0) * 32768).astype(np.int16)])
    return audios_opt

# get_tts_wav(r"D:\BaiduNetdiskDownload\gsv\speech\萧逸声音-你得先从滑雪的基本技巧学起.wav", "你得先从滑雪的基本技巧学起。", "中文", "我觉得还是该给喜欢的女孩子一场认真的告白。", "中文")
# with open(r"D:\BaiduNetdiskDownload\gsv\烟嗓-todo1.txt", "r", encoding="utf8") as f:
# with open(r"D:\BaiduNetdiskDownload\gsv\年下-todo1.txt", "r", encoding="utf8") as f:
# with open(r"D:\BaiduNetdiskDownload\gsv\萧逸3b.txt", "r", encoding="utf8") as f:
with open(r"D:\BaiduNetdiskDownload\gsv\萧逸4.txt", "r", encoding="utf8") as f:
    textss = f.read().split("\n")
for idx, (text, audio) in enumerate(get_tts_wavs(r"D:\BaiduNetdiskDownload\gsv\speech\萧逸声音-你得先从滑雪的基本技巧学起.wav", "你得先从滑雪的基本技巧学起。", "中文", textss, "中文")):
    # for idx, (text, audio) in enumerate(get_tts_wavs(r"D:\BaiduNetdiskDownload\gsv\足够的能力,去制定好自己的生活规划。低沉烟嗓.MP3_1940480_2095360.wav", "足够的能力,去制定好自己的生活规划。", "中文", textss, "中文")):
    # for idx, (text, audio) in enumerate(get_tts_wavs(r"D:\BaiduNetdiskDownload\gsv\不会呀!你前几天才吃过你还说好吃来着。年下少年音.MP3_537600_711040.wav", "不会呀!你前几天才吃过你还说好吃来着。", "中文", textss, "中文")):
    print(idx, text)
    # sf.write(r"D:\BaiduNetdiskDownload\gsv\output\烟嗓第一批\%04d-%s.wav" % (idx, text), audio, 32000)
    # sf.write(r"D:\BaiduNetdiskDownload\gsv\output\年下\%04d-%s.wav" % (idx, text), audio, 32000)
    sf.write(r"D:\BaiduNetdiskDownload\gsv\output\萧逸第4批\%04d-%s.wav" % (idx, text), audio, 32000)


# def handle(command, refer_wav_path, prompt_text, prompt_language, text, text_language):
#     if command == "/restart":
#         os.execl(g_config.python_exec, g_config.python_exec, *sys.argv)
#     elif command == "/exit":
#         os.kill(os.getpid(), signal.SIGTERM)
#         exit(0)
#
#     if (
#         refer_wav_path == "" or refer_wav_path is None
#         or prompt_text == "" or prompt_text is None
#         or prompt_language == "" or prompt_language is None
#     ):
#         refer_wav_path, prompt_text, prompt_language = (
#             default_refer_path,
#             default_refer_text,
#             default_refer_language,
#         )
#         if not has_preset:
#             raise HTTPException(status_code=400, detail="未指定参考音频且接口无预设")
#
#     with torch.no_grad():
#         gen = get_tts_wav(
#             refer_wav_path, prompt_text, prompt_language, text, text_language
#         )
#         sampling_rate, audio_data = next(gen)
#
#     wav = BytesIO()
#     sf.write(wav, audio_data, sampling_rate, format="wav")
#     wav.seek(0)
#
#     torch.cuda.empty_cache()
#     return StreamingResponse(wav, media_type="audio/wav")


# app = FastAPI()
#
#
# @app.post("/")
# async def tts_endpoint(request: Request):
#     json_post_raw = await request.json()
#     return handle(
#         json_post_raw.get("command"),
#         json_post_raw.get("refer_wav_path"),
#         json_post_raw.get("prompt_text"),
#         json_post_raw.get("prompt_language"),
#         json_post_raw.get("text"),
#         json_post_raw.get("text_language"),
#     )
#
#
# @app.get("/")
# async def tts_endpoint(
#     command: str = None,
#     refer_wav_path: str = None,
#     prompt_text: str = None,
#     prompt_language: str = None,
#     text: str = None,
#     text_language: str = None,
# ):
#     return handle(command, refer_wav_path, prompt_text, prompt_language, text, text_language)
#
#
# if __name__ == "__main__":
#     uvicorn.run(app, host=host, port=port, workers=1)

@@ -578,3 +578,19 @@
  - Content: Improved the automatic precision detection logic and added collapsible sections to the WebUI front-end interface modules.
  - Type: New feature
  - Contributors: XXXXRT666, RVC-Boss
- 2025.06.06 [PR#2427](https://github.com/RVC-Boss/GPT-SoVITS/pull/2427)
  - Content: Fixed polyphonic-character disambiguation for the "X一X" pattern.
  - Type: Fix
  - Contributors: wzy3650
- 2025.06.05 [PR#2439](https://github.com/RVC-Boss/GPT-SoVITS/pull/2439)
  - Content: Config fix; SoVITS model loading fix.
  - Type: Fix
  - Contributors: wzy3650
- 2025.06.09 [Commit#8056efe4](https://github.com/RVC-Boss/GPT-SoVITS/commit/8056efe4ab7bbc3610c72ae356a6f37518441f7d)
  - Content: Fixed an issue where the value of ge.sum could blow up numerically, causing inference to produce silent audio.
  - Type: Fix
  - Contributors: RVC-Boss
- 2025.06.10 [Commit#2c0436b9](https://github.com/RVC-Boss/GPT-SoVITS/commit/2c0436b9ce397424ae03476c836fb64c6e5ebcc6)
  - Content: Fixed incorrect paths on Windows when the experiment name ends with a trailing space.
  - Type: Fix
  - Contributors: RVC-Boss
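The trailing-space entry above corresponds to the `exp_name = exp_name.rstrip(" ")` lines added in the webui.py hunks below. A minimal sketch of why a trailing space breaks paths on Windows; the directory names and the `exp_root` value here are hypothetical illustrations, not taken from the repository:

import os

exp_root = "logs"            # hypothetical experiment root
exp_name = "my_experiment "  # user input with a trailing space

# Win32 path handling typically strips trailing spaces from directory
# components, so the directory created on disk becomes "logs/my_experiment"
# while the string kept in memory stays "logs/my_experiment "; later lookups
# that use the un-stripped string then fail to find the files.
opt_dir = "%s/%s" % (exp_root, exp_name)

# Stripping the name up front keeps the on-disk path and the in-memory
# string consistent, which is what the added rstrip(" ") calls do.
exp_name = exp_name.rstrip(" ")
opt_dir = "%s/%s" % (exp_root, exp_name)
os.makedirs(opt_dir, exist_ok=True)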
webui.py
@@ -507,6 +507,7 @@ def open1Ba(
):
    global p_train_SoVITS
    if p_train_SoVITS == None:
        exp_name=exp_name.rstrip(" ")
        config_file = (
            "GPT_SoVITS/configs/s2.json"
            if version not in {"v2Pro", "v2ProPlus"}

@@ -603,6 +604,7 @@ def open1Bb(
):
    global p_train_GPT
    if p_train_GPT == None:
        exp_name=exp_name.rstrip(" ")
        with open(
            "GPT_SoVITS/configs/s1longer.yaml" if version == "v1" else "GPT_SoVITS/configs/s1longer-v2.yaml"
        ) as f:

@@ -785,6 +787,7 @@ def open1a(inp_text, inp_wav_dir, exp_name, gpu_numbers, bert_pretrained_dir):
    inp_wav_dir = my_utils.clean_path(inp_wav_dir)
    if check_for_existance([inp_text, inp_wav_dir], is_dataset_processing=True):
        check_details([inp_text, inp_wav_dir], is_dataset_processing=True)
    exp_name = exp_name.rstrip(" ")
    if ps1a == []:
        opt_dir = "%s/%s" % (exp_root, exp_name)
        config = {

@@ -874,6 +877,7 @@ def open1b(version, inp_text, inp_wav_dir, exp_name, gpu_numbers, ssl_pretrained
    inp_wav_dir = my_utils.clean_path(inp_wav_dir)
    if check_for_existance([inp_text, inp_wav_dir], is_dataset_processing=True):
        check_details([inp_text, inp_wav_dir], is_dataset_processing=True)
    exp_name = exp_name.rstrip(" ")
    if ps1b == []:
        config = {
            "inp_text": inp_text,

@@ -962,6 +966,7 @@ def open1c(version, inp_text, inp_wav_dir, exp_name, gpu_numbers, pretrained_s2G
    inp_text = my_utils.clean_path(inp_text)
    if check_for_existance([inp_text, inp_wav_dir], is_dataset_processing=True):
        check_details([inp_text, inp_wav_dir], is_dataset_processing=True)
    exp_name = exp_name.rstrip(" ")
    if ps1c == []:
        opt_dir = "%s/%s" % (exp_root, exp_name)
        config_file = (

@@ -1059,6 +1064,7 @@ def open1abc(
    inp_wav_dir = my_utils.clean_path(inp_wav_dir)
    if check_for_existance([inp_text, inp_wav_dir], is_dataset_processing=True):
        check_details([inp_text, inp_wav_dir], is_dataset_processing=True)
    exp_name = exp_name.rstrip(" ")
    if ps1abc == []:
        opt_dir = "%s/%s" % (exp_root, exp_name)
        try: