Add VITS batch inference GPT_SoVITS/TTS_infer_pack/TTS.py
fix some bugs GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py
fix some bugs GPT_SoVITS/TTS_infer_pack/text_segmentation_method.py
fix some bugs GPT_SoVITS/inference_webui.py
fix some bugs GPT_SoVITS/module/models.py
@@ -1,5 +1,6 @@
import copy
import math
from typing import List

import torch
from torch import nn
from torch.nn import functional as F
@@ -985,6 +986,55 @@ class SynthesizerTrn(nn.Module):

        o = self.dec((z * y_mask)[:, :, :], g=ge)
        return o

    @torch.no_grad()
    def batched_decode(self, codes, y_lengths, text, text_lengths, refer, noise_scale=0.5):
        ge = None
        if refer is not None:
            refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device)
            refer_mask = torch.unsqueeze(
                commons.sequence_mask(refer_lengths, refer.size(2)), 1
            ).to(refer.dtype)
            ge = self.ref_enc(refer * refer_mask, refer_mask)

        # y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, codes.size(2)), 1).to(
        #     codes.dtype
        # )
        y_lengths = (y_lengths * 2).long().to(codes.device)
        text_lengths = text_lengths.long().to(text.device)
        # y_lengths = torch.LongTensor([codes.size(2) * 2]).to(codes.device)
        # text_lengths = torch.LongTensor([text.size(-1)]).to(text.device)

        # Assume decoding with the padding still in place is fine; the impact is
        # unknown, but the result seems to sound OK.
        quantized = self.quantizer.decode(codes)
        if self.semantic_frame_rate == "25hz":
            quantized = F.interpolate(
                quantized, size=int(quantized.shape[-1] * 2), mode="nearest"
            )

        x, m_p, logs_p, y_mask = self.enc_p(
            quantized, y_lengths, text, text_lengths, ge
        )
        z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale

        z = self.flow(z_p, y_mask, g=ge, reverse=True)
        z_masked = (z * y_mask)[:, :, :]

        # Serial: strip the padding from each sample, then decode one by one.
        o_list: List[torch.Tensor] = []
        for i in range(z_masked.shape[0]):
            z_slice = z_masked[i, :, :y_lengths[i]].unsqueeze(0)
            o = self.dec(z_slice, g=ge)[0, 0, :].detach()
            o_list.append(o)

        # Parallel (problematic): decode the whole batch first, then strip the padding.
        # o = self.dec(z_masked, g=ge)
        # upsample_rate = int(math.prod(self.upsample_rates))
        # o_lengths = y_lengths * upsample_rate
        # o_list = [o[i, 0, :idx].detach() for i, idx in enumerate(o_lengths)]

        return o_list

    def extract_latent(self, x):
        ssl = self.ssl_proj(x)
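For context, below is a minimal caller-side sketch of how the new batched_decode could be driven. This helper is not part of the commit: the name synthesize_batch, the zero pad value, and the assumed shapes (one [1, T_code] LongTensor of semantic tokens and one [T_text] LongTensor of phoneme IDs per utterance, plus a [1, n_mels, T_ref] reference mel) are illustrative assumptions.

import torch
from typing import List

def synthesize_batch(vq_model, codes_list: List[torch.Tensor],
                     phones_list: List[torch.Tensor],
                     refer: torch.Tensor, device: str = "cuda") -> List[torch.Tensor]:
    """Right-pad a batch of utterances and run SynthesizerTrn.batched_decode on it."""
    # Lengths before padding; batched_decode doubles y_lengths internally
    # to match the interpolated feature frame rate.
    y_lengths = torch.LongTensor([c.size(-1) for c in codes_list]).to(device)
    text_lengths = torch.LongTensor([p.size(-1) for p in phones_list]).to(device)

    max_code = int(y_lengths.max())
    max_text = int(text_lengths.max())

    # Right-pad every sequence to the batch maximum; zero is used as the pad
    # value here purely for illustration.
    codes = torch.zeros(len(codes_list), 1, max_code, dtype=torch.long, device=device)
    text = torch.zeros(len(phones_list), max_text, dtype=torch.long, device=device)
    for i, (c, p) in enumerate(zip(codes_list, phones_list)):
        codes[i, :, : c.size(-1)] = c
        text[i, : p.size(-1)] = p

    # batched_decode is already wrapped in @torch.no_grad(); it returns one
    # 1-D waveform tensor per utterance.
    return vq_model.batched_decode(codes, y_lengths, text, text_lengths, refer.to(device))

Because batched_decode slices each sample back to its true length before vocoding (the serial path above), the returned waveforms need no further trimming.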