support gpt-sovits v4
support gpt-sovits v4
This commit is contained in:
@@ -414,7 +414,7 @@ class Generator(torch.nn.Module):
|
||||
upsample_rates,
|
||||
upsample_initial_channel,
|
||||
upsample_kernel_sizes,
|
||||
gin_channels=0,
|
||||
gin_channels=0,is_bias=False,
|
||||
):
|
||||
super(Generator, self).__init__()
|
||||
self.num_kernels = len(resblock_kernel_sizes)
|
||||
@@ -442,7 +442,7 @@ class Generator(torch.nn.Module):
|
||||
for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
|
||||
self.resblocks.append(resblock(ch, k, d))
|
||||
|
||||
self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
|
||||
self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=is_bias)
|
||||
self.ups.apply(init_weights)
|
||||
|
||||
if gin_channels != 0:
|
||||
@@ -1173,7 +1173,7 @@ class SynthesizerTrnV3(nn.Module):
|
||||
quantized = F.interpolate(quantized, scale_factor=2, mode="nearest") ##BCT
|
||||
x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge)
|
||||
fea = self.bridge(x)
|
||||
fea = F.interpolate(fea, scale_factor=1.875, mode="nearest") ##BCT
|
||||
fea = F.interpolate(fea, scale_factor=(1.875 if self.version=="v3"else 2), mode="nearest") ##BCT
|
||||
fea, y_mask_ = self.wns1(
|
||||
fea, mel_lengths, ge
|
||||
) ##If the 1-minute fine-tuning works fine, no need to manually adjust the learning rate.
|
||||
@@ -1196,9 +1196,9 @@ class SynthesizerTrnV3(nn.Module):
|
||||
ge = self.ref_enc(refer[:, :704] * refer_mask, refer_mask)
|
||||
y_lengths = torch.LongTensor([int(codes.size(2) * 2)]).to(codes.device)
|
||||
if speed == 1:
|
||||
sizee = int(codes.size(2) * 2.5 * 1.5)
|
||||
sizee = int(codes.size(2) * (3.875 if self.version=="v3"else 4))
|
||||
else:
|
||||
sizee = int(codes.size(2) * 2.5 * 1.5 / speed) + 1
|
||||
sizee = int(codes.size(2) * (3.875 if self.version=="v3"else 4) / speed) + 1
|
||||
y_lengths1 = torch.LongTensor([sizee]).to(codes.device)
|
||||
text_lengths = torch.LongTensor([text.size(-1)]).to(text.device)
|
||||
|
||||
@@ -1207,7 +1207,7 @@ class SynthesizerTrnV3(nn.Module):
|
||||
quantized = F.interpolate(quantized, scale_factor=2, mode="nearest") ##BCT
|
||||
x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge, speed)
|
||||
fea = self.bridge(x)
|
||||
fea = F.interpolate(fea, scale_factor=1.875, mode="nearest") ##BCT
|
||||
fea = F.interpolate(fea, scale_factor=(1.875 if self.version=="v3"else 2), mode="nearest") ##BCT
|
||||
####more wn parameter to learn mel
|
||||
fea, y_mask_ = self.wns1(fea, y_lengths1, ge)
|
||||
return fea, ge
|
||||
|
||||
Reference in New Issue
Block a user