新增VITS批量推理 GPT_SoVITS/TTS_infer_pack/TTS.py

fix some bugs   GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py
	fix some bugs   GPT_SoVITS/TTS_infer_pack/text_segmentation_method.py
	fix some bugs   GPT_SoVITS/inference_webui.py
	fix some bugs   GPT_SoVITS/module/models.py
This commit is contained in:
chasonjiang
2024-03-10 21:37:28 +08:00
parent 174c4bbab3
commit 3535cfe3b0
5 changed files with 182 additions and 44 deletions

View File

@@ -1,5 +1,6 @@
import copy
import math
from typing import List
import torch
from torch import nn
from torch.nn import functional as F
@@ -985,6 +986,55 @@ class SynthesizerTrn(nn.Module):
o = self.dec((z * y_mask)[:, :, :], g=ge)
return o
@torch.no_grad()
def batched_decode(self, codes, y_lengths, text, text_lengths, refer, noise_scale=0.5):
ge = None
if refer is not None:
refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device)
refer_mask = torch.unsqueeze(
commons.sequence_mask(refer_lengths, refer.size(2)), 1
).to(refer.dtype)
ge = self.ref_enc(refer * refer_mask, refer_mask)
# y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, codes.size(2)), 1).to(
# codes.dtype
# )
y_lengths = (y_lengths * 2).long().to(codes.device)
text_lengths = text_lengths.long().to(text.device)
# y_lengths = torch.LongTensor([codes.size(2) * 2]).to(codes.device)
# text_lengths = torch.LongTensor([text.size(-1)]).to(text.device)
# 假设padding之后再decode没有问题, 影响未知,但听起来好像没问题?
quantized = self.quantizer.decode(codes)
if self.semantic_frame_rate == "25hz":
quantized = F.interpolate(
quantized, size=int(quantized.shape[-1] * 2), mode="nearest"
)
x, m_p, logs_p, y_mask = self.enc_p(
quantized, y_lengths, text, text_lengths, ge
)
z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
z = self.flow(z_p, y_mask, g=ge, reverse=True)
z_masked = (z * y_mask)[:, :, :]
# 串行。把padding部分去掉再decode
o_list:List[torch.Tensor] = []
for i in range(z_masked.shape[0]):
z_slice = z_masked[i, :, :y_lengths[i]].unsqueeze(0)
o = self.dec(z_slice, g=ge)[0, 0, :].detach()
o_list.append(o)
# 并行会有问题。先decode再把padding的部分去掉
# o = self.dec(z_masked, g=ge)
# upsample_rate = int(math.prod(self.upsample_rates))
# o_lengths = y_lengths*upsample_rate
# o_list = [o[i, 0, :idx].detach() for i, idx in enumerate(o_lengths)]
return o_list
def extract_latent(self, x):
ssl = self.ssl_proj(x)