more code refactor
This commit is contained in:
File diff suppressed because it is too large
Load Diff
@@ -1,189 +1,189 @@
|
||||
import math
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
|
||||
def init_weights(m, mean=0.0, std=0.01):
|
||||
classname = m.__class__.__name__
|
||||
if classname.find("Conv") != -1:
|
||||
m.weight.data.normal_(mean, std)
|
||||
classname = m.__class__.__name__
|
||||
if classname.find("Conv") != -1:
|
||||
m.weight.data.normal_(mean, std)
|
||||
|
||||
|
||||
def get_padding(kernel_size, dilation=1):
|
||||
return int((kernel_size*dilation - dilation)/2)
|
||||
return int((kernel_size * dilation - dilation) / 2)
|
||||
|
||||
|
||||
def convert_pad_shape(pad_shape):
|
||||
l = pad_shape[::-1]
|
||||
pad_shape = [item for sublist in l for item in sublist]
|
||||
return pad_shape
|
||||
l = pad_shape[::-1]
|
||||
pad_shape = [item for sublist in l for item in sublist]
|
||||
return pad_shape
|
||||
|
||||
|
||||
def intersperse(lst, item):
|
||||
result = [item] * (len(lst) * 2 + 1)
|
||||
result[1::2] = lst
|
||||
return result
|
||||
result = [item] * (len(lst) * 2 + 1)
|
||||
result[1::2] = lst
|
||||
return result
|
||||
|
||||
|
||||
def kl_divergence(m_p, logs_p, m_q, logs_q):
|
||||
"""KL(P||Q)"""
|
||||
kl = (logs_q - logs_p) - 0.5
|
||||
kl += 0.5 * (torch.exp(2. * logs_p) + ((m_p - m_q)**2)) * torch.exp(-2. * logs_q)
|
||||
return kl
|
||||
"""KL(P||Q)"""
|
||||
kl = (logs_q - logs_p) - 0.5
|
||||
kl += (
|
||||
0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q)
|
||||
)
|
||||
return kl
|
||||
|
||||
|
||||
def rand_gumbel(shape):
|
||||
"""Sample from the Gumbel distribution, protect from overflows."""
|
||||
uniform_samples = torch.rand(shape) * 0.99998 + 0.00001
|
||||
return -torch.log(-torch.log(uniform_samples))
|
||||
"""Sample from the Gumbel distribution, protect from overflows."""
|
||||
uniform_samples = torch.rand(shape) * 0.99998 + 0.00001
|
||||
return -torch.log(-torch.log(uniform_samples))
|
||||
|
||||
|
||||
def rand_gumbel_like(x):
|
||||
g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
|
||||
return g
|
||||
g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
|
||||
return g
|
||||
|
||||
|
||||
def slice_segments(x, ids_str, segment_size=4):
|
||||
ret = torch.zeros_like(x[:, :, :segment_size])
|
||||
for i in range(x.size(0)):
|
||||
idx_str = ids_str[i]
|
||||
idx_end = idx_str + segment_size
|
||||
ret[i] = x[i, :, idx_str:idx_end]
|
||||
return ret
|
||||
ret = torch.zeros_like(x[:, :, :segment_size])
|
||||
for i in range(x.size(0)):
|
||||
idx_str = ids_str[i]
|
||||
idx_end = idx_str + segment_size
|
||||
ret[i] = x[i, :, idx_str:idx_end]
|
||||
return ret
|
||||
|
||||
|
||||
def rand_slice_segments(x, x_lengths=None, segment_size=4):
|
||||
b, d, t = x.size()
|
||||
if x_lengths is None:
|
||||
x_lengths = t
|
||||
ids_str_max = x_lengths - segment_size + 1
|
||||
ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
|
||||
ret = slice_segments(x, ids_str, segment_size)
|
||||
return ret, ids_str
|
||||
b, d, t = x.size()
|
||||
if x_lengths is None:
|
||||
x_lengths = t
|
||||
ids_str_max = x_lengths - segment_size + 1
|
||||
ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
|
||||
ret = slice_segments(x, ids_str, segment_size)
|
||||
return ret, ids_str
|
||||
|
||||
|
||||
def get_timing_signal_1d(
|
||||
length, channels, min_timescale=1.0, max_timescale=1.0e4):
|
||||
position = torch.arange(length, dtype=torch.float)
|
||||
num_timescales = channels // 2
|
||||
log_timescale_increment = (
|
||||
math.log(float(max_timescale) / float(min_timescale)) /
|
||||
(num_timescales - 1))
|
||||
inv_timescales = min_timescale * torch.exp(
|
||||
torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment)
|
||||
scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
|
||||
signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
|
||||
signal = F.pad(signal, [0, 0, 0, channels % 2])
|
||||
signal = signal.view(1, channels, length)
|
||||
return signal
|
||||
def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
|
||||
position = torch.arange(length, dtype=torch.float)
|
||||
num_timescales = channels // 2
|
||||
log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / (
|
||||
num_timescales - 1
|
||||
)
|
||||
inv_timescales = min_timescale * torch.exp(
|
||||
torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment
|
||||
)
|
||||
scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
|
||||
signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
|
||||
signal = F.pad(signal, [0, 0, 0, channels % 2])
|
||||
signal = signal.view(1, channels, length)
|
||||
return signal
|
||||
|
||||
|
||||
def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
|
||||
b, channels, length = x.size()
|
||||
signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
|
||||
return x + signal.to(dtype=x.dtype, device=x.device)
|
||||
b, channels, length = x.size()
|
||||
signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
|
||||
return x + signal.to(dtype=x.dtype, device=x.device)
|
||||
|
||||
|
||||
def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
|
||||
b, channels, length = x.size()
|
||||
signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
|
||||
return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)
|
||||
b, channels, length = x.size()
|
||||
signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
|
||||
return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)
|
||||
|
||||
|
||||
def subsequent_mask(length):
|
||||
mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
|
||||
return mask
|
||||
mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
|
||||
return mask
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
|
||||
n_channels_int = n_channels[0]
|
||||
in_act = input_a + input_b
|
||||
t_act = torch.tanh(in_act[:, :n_channels_int, :])
|
||||
s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
|
||||
acts = t_act * s_act
|
||||
return acts
|
||||
n_channels_int = n_channels[0]
|
||||
in_act = input_a + input_b
|
||||
t_act = torch.tanh(in_act[:, :n_channels_int, :])
|
||||
s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
|
||||
acts = t_act * s_act
|
||||
return acts
|
||||
|
||||
|
||||
def convert_pad_shape(pad_shape):
|
||||
l = pad_shape[::-1]
|
||||
pad_shape = [item for sublist in l for item in sublist]
|
||||
return pad_shape
|
||||
l = pad_shape[::-1]
|
||||
pad_shape = [item for sublist in l for item in sublist]
|
||||
return pad_shape
|
||||
|
||||
|
||||
def shift_1d(x):
|
||||
x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
|
||||
return x
|
||||
x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
|
||||
return x
|
||||
|
||||
|
||||
def sequence_mask(length, max_length=None):
|
||||
if max_length is None:
|
||||
max_length = length.max()
|
||||
x = torch.arange(max_length, dtype=length.dtype, device=length.device)
|
||||
return x.unsqueeze(0) < length.unsqueeze(1)
|
||||
if max_length is None:
|
||||
max_length = length.max()
|
||||
x = torch.arange(max_length, dtype=length.dtype, device=length.device)
|
||||
return x.unsqueeze(0) < length.unsqueeze(1)
|
||||
|
||||
|
||||
def generate_path(duration, mask):
|
||||
"""
|
||||
duration: [b, 1, t_x]
|
||||
mask: [b, 1, t_y, t_x]
|
||||
"""
|
||||
device = duration.device
|
||||
|
||||
b, _, t_y, t_x = mask.shape
|
||||
cum_duration = torch.cumsum(duration, -1)
|
||||
|
||||
cum_duration_flat = cum_duration.view(b * t_x)
|
||||
path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
|
||||
path = path.view(b, t_x, t_y)
|
||||
path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
|
||||
path = path.unsqueeze(1).transpose(2,3) * mask
|
||||
return path
|
||||
"""
|
||||
duration: [b, 1, t_x]
|
||||
mask: [b, 1, t_y, t_x]
|
||||
"""
|
||||
device = duration.device
|
||||
|
||||
b, _, t_y, t_x = mask.shape
|
||||
cum_duration = torch.cumsum(duration, -1)
|
||||
|
||||
cum_duration_flat = cum_duration.view(b * t_x)
|
||||
path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
|
||||
path = path.view(b, t_x, t_y)
|
||||
path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
|
||||
path = path.unsqueeze(1).transpose(2, 3) * mask
|
||||
return path
|
||||
|
||||
|
||||
def clip_grad_value_(parameters, clip_value, norm_type=2):
|
||||
if isinstance(parameters, torch.Tensor):
|
||||
parameters = [parameters]
|
||||
parameters = list(filter(lambda p: p.grad is not None, parameters))
|
||||
norm_type = float(norm_type)
|
||||
if clip_value is not None:
|
||||
clip_value = float(clip_value)
|
||||
|
||||
total_norm = 0
|
||||
for p in parameters:
|
||||
param_norm = p.grad.data.norm(norm_type)
|
||||
total_norm += param_norm.item() ** norm_type
|
||||
if isinstance(parameters, torch.Tensor):
|
||||
parameters = [parameters]
|
||||
parameters = list(filter(lambda p: p.grad is not None, parameters))
|
||||
norm_type = float(norm_type)
|
||||
if clip_value is not None:
|
||||
p.grad.data.clamp_(min=-clip_value, max=clip_value)
|
||||
total_norm = total_norm ** (1. / norm_type)
|
||||
return total_norm
|
||||
clip_value = float(clip_value)
|
||||
|
||||
total_norm = 0
|
||||
for p in parameters:
|
||||
param_norm = p.grad.data.norm(norm_type)
|
||||
total_norm += param_norm.item() ** norm_type
|
||||
if clip_value is not None:
|
||||
p.grad.data.clamp_(min=-clip_value, max=clip_value)
|
||||
total_norm = total_norm ** (1.0 / norm_type)
|
||||
return total_norm
|
||||
|
||||
|
||||
def squeeze(x, x_mask=None, n_sqz=2):
|
||||
b, c, t = x.size()
|
||||
b, c, t = x.size()
|
||||
|
||||
t = (t // n_sqz) * n_sqz
|
||||
x = x[:, :, :t]
|
||||
x_sqz = x.view(b, c, t // n_sqz, n_sqz)
|
||||
x_sqz = x_sqz.permute(0, 3, 1, 2).contiguous().view(b, c * n_sqz, t // n_sqz)
|
||||
t = (t // n_sqz) * n_sqz
|
||||
x = x[:, :, :t]
|
||||
x_sqz = x.view(b, c, t // n_sqz, n_sqz)
|
||||
x_sqz = x_sqz.permute(0, 3, 1, 2).contiguous().view(b, c * n_sqz, t // n_sqz)
|
||||
|
||||
if x_mask is not None:
|
||||
x_mask = x_mask[:, :, n_sqz - 1::n_sqz]
|
||||
else:
|
||||
x_mask = torch.ones(b, 1, t // n_sqz).to(device=x.device, dtype=x.dtype)
|
||||
return x_sqz * x_mask, x_mask
|
||||
if x_mask is not None:
|
||||
x_mask = x_mask[:, :, n_sqz - 1 :: n_sqz]
|
||||
else:
|
||||
x_mask = torch.ones(b, 1, t // n_sqz).to(device=x.device, dtype=x.dtype)
|
||||
return x_sqz * x_mask, x_mask
|
||||
|
||||
|
||||
def unsqueeze(x, x_mask=None, n_sqz=2):
|
||||
b, c, t = x.size()
|
||||
b, c, t = x.size()
|
||||
|
||||
x_unsqz = x.view(b, n_sqz, c // n_sqz, t)
|
||||
x_unsqz = x_unsqz.permute(0, 2, 3, 1).contiguous().view(b, c // n_sqz, t * n_sqz)
|
||||
x_unsqz = x.view(b, n_sqz, c // n_sqz, t)
|
||||
x_unsqz = x_unsqz.permute(0, 2, 3, 1).contiguous().view(b, c // n_sqz, t * n_sqz)
|
||||
|
||||
if x_mask is not None:
|
||||
x_mask = x_mask.unsqueeze(-1).repeat(1, 1, 1, n_sqz).view(b, 1, t * n_sqz)
|
||||
else:
|
||||
x_mask = torch.ones(b, 1, t * n_sqz).to(device=x.device, dtype=x.dtype)
|
||||
return x_unsqz * x_mask, x_mask
|
||||
if x_mask is not None:
|
||||
x_mask = x_mask.unsqueeze(-1).repeat(1, 1, 1, n_sqz).view(b, 1, t * n_sqz)
|
||||
else:
|
||||
x_mask = torch.ones(b, 1, t * n_sqz).to(device=x.device, dtype=x.dtype)
|
||||
return x_unsqz * x_mask, x_mask
|
||||
|
||||
@@ -76,10 +76,8 @@ def kmeans(samples, num_clusters: int, num_iters: int = 10):
|
||||
|
||||
print("kmeans start ... ")
|
||||
for _ in tqdm(range(num_iters)):
|
||||
diffs = rearrange(samples, "n d -> n () d") - rearrange(
|
||||
means, "c d -> () c d"
|
||||
)
|
||||
dists = -(diffs ** 2).sum(dim=-1)
|
||||
diffs = rearrange(samples, "n d -> n () d") - rearrange(means, "c d -> () c d")
|
||||
dists = -(diffs**2).sum(dim=-1)
|
||||
|
||||
buckets = dists.max(dim=-1).indices
|
||||
bins = torch.bincount(buckets, minlength=num_clusters)
|
||||
@@ -110,6 +108,7 @@ class EuclideanCodebook(nn.Module):
|
||||
that have an exponential moving average cluster size less than the specified threshold with
|
||||
randomly selected vector from the current batch.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
@@ -122,7 +121,9 @@ class EuclideanCodebook(nn.Module):
|
||||
):
|
||||
super().__init__()
|
||||
self.decay = decay
|
||||
init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = uniform_init if not kmeans_init else torch.zeros
|
||||
init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = (
|
||||
uniform_init if not kmeans_init else torch.zeros
|
||||
)
|
||||
embed = init_fn(codebook_size, dim)
|
||||
|
||||
self.codebook_size = codebook_size
|
||||
@@ -147,7 +148,7 @@ class EuclideanCodebook(nn.Module):
|
||||
self.cluster_size.data.copy_(cluster_size)
|
||||
self.inited.data.copy_(torch.Tensor([True]))
|
||||
# Make sure all buffers across workers are in sync after initialization
|
||||
#broadcast_tensors(self.buffers())
|
||||
# broadcast_tensors(self.buffers())
|
||||
|
||||
def replace_(self, samples, mask):
|
||||
modified_codebook = torch.where(
|
||||
@@ -165,7 +166,7 @@ class EuclideanCodebook(nn.Module):
|
||||
|
||||
batch_samples = rearrange(batch_samples, "... d -> (...) d")
|
||||
self.replace_(batch_samples, mask=expired_codes)
|
||||
#broadcast_tensors(self.buffers())
|
||||
# broadcast_tensors(self.buffers())
|
||||
|
||||
def preprocess(self, x):
|
||||
x = rearrange(x, "... d -> (...) d")
|
||||
@@ -246,6 +247,7 @@ class VectorQuantization(nn.Module):
|
||||
randomly selected vector from the current batch.
|
||||
commitment_weight (float): Weight for commitment loss.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
@@ -256,22 +258,31 @@ class VectorQuantization(nn.Module):
|
||||
kmeans_init: bool = True,
|
||||
kmeans_iters: int = 50,
|
||||
threshold_ema_dead_code: int = 2,
|
||||
commitment_weight: float = 1.,
|
||||
commitment_weight: float = 1.0,
|
||||
):
|
||||
super().__init__()
|
||||
_codebook_dim: int = default(codebook_dim, dim)
|
||||
|
||||
requires_projection = _codebook_dim != dim
|
||||
self.project_in = (nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity())
|
||||
self.project_out = (nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity())
|
||||
self.project_in = (
|
||||
nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity()
|
||||
)
|
||||
self.project_out = (
|
||||
nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity()
|
||||
)
|
||||
|
||||
self.epsilon = epsilon
|
||||
self.commitment_weight = commitment_weight
|
||||
|
||||
self._codebook = EuclideanCodebook(dim=_codebook_dim, codebook_size=codebook_size,
|
||||
kmeans_init=kmeans_init, kmeans_iters=kmeans_iters,
|
||||
decay=decay, epsilon=epsilon,
|
||||
threshold_ema_dead_code=threshold_ema_dead_code)
|
||||
self._codebook = EuclideanCodebook(
|
||||
dim=_codebook_dim,
|
||||
codebook_size=codebook_size,
|
||||
kmeans_init=kmeans_init,
|
||||
kmeans_iters=kmeans_iters,
|
||||
decay=decay,
|
||||
epsilon=epsilon,
|
||||
threshold_ema_dead_code=threshold_ema_dead_code,
|
||||
)
|
||||
self.codebook_size = codebook_size
|
||||
|
||||
@property
|
||||
@@ -316,13 +327,16 @@ class ResidualVectorQuantization(nn.Module):
|
||||
"""Residual vector quantization implementation.
|
||||
Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf
|
||||
"""
|
||||
|
||||
def __init__(self, *, num_quantizers, **kwargs):
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList(
|
||||
[VectorQuantization(**kwargs) for _ in range(num_quantizers)]
|
||||
)
|
||||
|
||||
def forward(self, x, n_q: tp.Optional[int] = None, layers: tp.Optional[list] = None):
|
||||
def forward(
|
||||
self, x, n_q: tp.Optional[int] = None, layers: tp.Optional[list] = None
|
||||
):
|
||||
quantized_out = 0.0
|
||||
residual = x
|
||||
|
||||
@@ -345,7 +359,9 @@ class ResidualVectorQuantization(nn.Module):
|
||||
out_losses, out_indices = map(torch.stack, (all_losses, all_indices))
|
||||
return quantized_out, out_indices, out_losses, out_quantized
|
||||
|
||||
def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None, st: tp.Optional[int]= None) -> torch.Tensor:
|
||||
def encode(
|
||||
self, x: torch.Tensor, n_q: tp.Optional[int] = None, st: tp.Optional[int] = None
|
||||
) -> torch.Tensor:
|
||||
residual = x
|
||||
all_indices = []
|
||||
n_q = n_q or len(self.layers)
|
||||
@@ -358,10 +374,10 @@ class ResidualVectorQuantization(nn.Module):
|
||||
out_indices = torch.stack(all_indices)
|
||||
return out_indices
|
||||
|
||||
def decode(self, q_indices: torch.Tensor, st: int=0) -> torch.Tensor:
|
||||
def decode(self, q_indices: torch.Tensor, st: int = 0) -> torch.Tensor:
|
||||
quantized_out = torch.tensor(0.0, device=q_indices.device)
|
||||
for i, indices in enumerate(q_indices):
|
||||
layer = self.layers[st + i]
|
||||
quantized = layer.decode(indices)
|
||||
quantized_out = quantized_out + quantized
|
||||
return quantized_out
|
||||
return quantized_out
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import time,logging
|
||||
import time, logging
|
||||
import os
|
||||
import random,traceback
|
||||
import random, traceback
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.utils.data
|
||||
@@ -16,41 +16,44 @@ import torch
|
||||
import requests
|
||||
from scipy.io import wavfile
|
||||
from io import BytesIO
|
||||
|
||||
# from config import exp_dir
|
||||
from my_utils import load_audio
|
||||
|
||||
|
||||
class TextAudioSpeakerLoader(torch.utils.data.Dataset):
|
||||
"""
|
||||
1) loads audio, speaker_id, text pairs
|
||||
2) normalizes text and converts them to sequences of integers
|
||||
3) computes spectrograms from audio files.
|
||||
1) loads audio, speaker_id, text pairs
|
||||
2) normalizes text and converts them to sequences of integers
|
||||
3) computes spectrograms from audio files.
|
||||
"""
|
||||
|
||||
def __init__(self, hparams, val=False):
|
||||
exp_dir=hparams.exp_dir
|
||||
self.path2="%s/2-name2text.txt"%exp_dir
|
||||
self.path4="%s/4-cnhubert"%exp_dir
|
||||
self.path5="%s/5-wav32k"%exp_dir
|
||||
exp_dir = hparams.exp_dir
|
||||
self.path2 = "%s/2-name2text.txt" % exp_dir
|
||||
self.path4 = "%s/4-cnhubert" % exp_dir
|
||||
self.path5 = "%s/5-wav32k" % exp_dir
|
||||
assert os.path.exists(self.path2)
|
||||
assert os.path.exists(self.path4)
|
||||
assert os.path.exists(self.path5)
|
||||
names4=set([name[:-3]for name in list(os.listdir(self.path4))])#去除.pt后缀
|
||||
names5=set(os.listdir(self.path5))
|
||||
self.phoneme_data={}
|
||||
with open(self.path2,"r",encoding="utf8")as f:
|
||||
lines=f.read().strip("\n").split("\n")
|
||||
names4 = set([name[:-3] for name in list(os.listdir(self.path4))]) # 去除.pt后缀
|
||||
names5 = set(os.listdir(self.path5))
|
||||
self.phoneme_data = {}
|
||||
with open(self.path2, "r", encoding="utf8") as f:
|
||||
lines = f.read().strip("\n").split("\n")
|
||||
|
||||
for line in lines:
|
||||
tmp=line.split("\t")
|
||||
if(len(tmp)!=4):continue
|
||||
self.phoneme_data[tmp[0]]=[tmp[1]]
|
||||
tmp = line.split("\t")
|
||||
if len(tmp) != 4:
|
||||
continue
|
||||
self.phoneme_data[tmp[0]] = [tmp[1]]
|
||||
|
||||
self.audiopaths_sid_text=list(set(self.phoneme_data)&names4&names5)
|
||||
tmp=self.audiopaths_sid_text
|
||||
leng=len(tmp)
|
||||
min_num=100
|
||||
if(leng<min_num):
|
||||
self.audiopaths_sid_text=[]
|
||||
self.audiopaths_sid_text = list(set(self.phoneme_data) & names4 & names5)
|
||||
tmp = self.audiopaths_sid_text
|
||||
leng = len(tmp)
|
||||
min_num = 100
|
||||
if leng < min_num:
|
||||
self.audiopaths_sid_text = []
|
||||
for _ in range(max(2, int(min_num / leng))):
|
||||
self.audiopaths_sid_text += tmp
|
||||
self.max_wav_value = hparams.max_wav_value
|
||||
@@ -69,20 +72,20 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
|
||||
|
||||
audiopaths_sid_text_new = []
|
||||
lengths = []
|
||||
skipped_phone = 0
|
||||
skipped_dur = 0
|
||||
skipped_phone = 0
|
||||
skipped_dur = 0
|
||||
for audiopath in tqdm(self.audiopaths_sid_text):
|
||||
try:
|
||||
phoneme = self.phoneme_data[audiopath][0]
|
||||
phoneme = phoneme.split(' ')
|
||||
phoneme = phoneme.split(" ")
|
||||
phoneme_ids = cleaned_text_to_sequence(phoneme)
|
||||
except Exception:
|
||||
print(f"{audiopath} not in self.phoneme_data !")
|
||||
skipped_phone += 1
|
||||
skipped_phone += 1
|
||||
continue
|
||||
size=os.path.getsize("%s/%s"%(self.path5,audiopath))
|
||||
size = os.path.getsize("%s/%s" % (self.path5, audiopath))
|
||||
duration = size / self.sampling_rate / 2
|
||||
if (54 > duration > 0.6 or self.val):
|
||||
if 54 > duration > 0.6 or self.val:
|
||||
audiopaths_sid_text_new.append([audiopath, phoneme_ids])
|
||||
lengths.append(size // (2 * self.hop_length))
|
||||
else:
|
||||
@@ -90,7 +93,7 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
|
||||
continue
|
||||
print("skipped_phone: ", skipped_phone, ", skipped_dur: ", skipped_dur)
|
||||
print("total left: ", len(audiopaths_sid_text_new))
|
||||
assert len(audiopaths_sid_text_new)>1#至少能凑够batch size,这里todo
|
||||
assert len(audiopaths_sid_text_new) > 1 # 至少能凑够batch size,这里todo
|
||||
self.audiopaths_sid_text = audiopaths_sid_text_new
|
||||
self.lengths = lengths
|
||||
|
||||
@@ -98,30 +101,41 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
|
||||
audiopath, phoneme_ids = audiopath_sid_text
|
||||
text = torch.FloatTensor(phoneme_ids)
|
||||
try:
|
||||
spec, wav = self.get_audio("%s/%s"%(self.path5,audiopath))
|
||||
spec, wav = self.get_audio("%s/%s" % (self.path5, audiopath))
|
||||
with torch.no_grad():
|
||||
ssl = torch.load("%s/%s.pt"%(self.path4,audiopath),map_location="cpu")
|
||||
if(ssl.shape[-1]!=spec.shape[-1]):
|
||||
typee=ssl.dtype
|
||||
ssl=F.pad(ssl.float(),(0,1),mode="replicate").to(typee)
|
||||
ssl.requires_grad=False
|
||||
ssl = torch.load(
|
||||
"%s/%s.pt" % (self.path4, audiopath), map_location="cpu"
|
||||
)
|
||||
if ssl.shape[-1] != spec.shape[-1]:
|
||||
typee = ssl.dtype
|
||||
ssl = F.pad(ssl.float(), (0, 1), mode="replicate").to(typee)
|
||||
ssl.requires_grad = False
|
||||
except:
|
||||
traceback.print_exc()
|
||||
spec = torch.zeros(1025, 100)
|
||||
wav = torch.zeros(1, 100*self.hop_length)
|
||||
ssl=torch.zeros(1,768,100)
|
||||
text=text[-1:]
|
||||
wav = torch.zeros(1, 100 * self.hop_length)
|
||||
ssl = torch.zeros(1, 768, 100)
|
||||
text = text[-1:]
|
||||
print("load audio or ssl error!!!!!!", audiopath)
|
||||
# print(ssl.requires_grad,spec.requires_grad,wav.requires_grad,text.requires_grad)
|
||||
return (ssl, spec, wav, text)
|
||||
|
||||
def get_audio(self, filename):
|
||||
audio_array = load_audio(filename,self.sampling_rate)#load_audio的方法是已经归一化到-1~1之间的,不用再/32768
|
||||
audio_array = load_audio(
|
||||
filename, self.sampling_rate
|
||||
) # load_audio的方法是已经归一化到-1~1之间的,不用再/32768
|
||||
# print(filename,audio_array.max(),audio_array.min(),audio_array.mean())
|
||||
audio=torch.FloatTensor(audio_array)#/32768
|
||||
audio = torch.FloatTensor(audio_array) # /32768
|
||||
audio_norm = audio
|
||||
audio_norm = audio_norm.unsqueeze(0)
|
||||
spec = spectrogram_torch(audio_norm, self.filter_length,self.sampling_rate, self.hop_length, self.win_length,center=False)
|
||||
spec = spectrogram_torch(
|
||||
audio_norm,
|
||||
self.filter_length,
|
||||
self.sampling_rate,
|
||||
self.hop_length,
|
||||
self.win_length,
|
||||
center=False,
|
||||
)
|
||||
spec = torch.squeeze(spec, 0)
|
||||
return spec, audio_norm
|
||||
|
||||
@@ -131,39 +145,51 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
|
||||
|
||||
def __getitem__(self, index):
|
||||
# with torch.no_grad():
|
||||
return self.get_audio_text_speaker_pair(self.audiopaths_sid_text[index])
|
||||
return self.get_audio_text_speaker_pair(self.audiopaths_sid_text[index])
|
||||
|
||||
def __len__(self):
|
||||
return len(self.audiopaths_sid_text)
|
||||
|
||||
def random_slice(self, ssl, wav, mel):
|
||||
assert abs(ssl.shape[-1]- wav.shape[-1]//self.hop_length) < 3, ("first", ssl.shape, wav.shape)
|
||||
assert abs(ssl.shape[-1] - wav.shape[-1] // self.hop_length) < 3, (
|
||||
"first",
|
||||
ssl.shape,
|
||||
wav.shape,
|
||||
)
|
||||
|
||||
len_mel = mel.shape[1]
|
||||
if self.val:
|
||||
reference_mel = mel[:, :len_mel//3]
|
||||
reference_mel = mel[:, : len_mel // 3]
|
||||
return reference_mel, ssl, wav, mel
|
||||
dir = random.randint(0, 1)
|
||||
sep_point = random.randint(int(len_mel//3), int(len_mel//3*2))
|
||||
sep_point = random.randint(int(len_mel // 3), int(len_mel // 3 * 2))
|
||||
|
||||
if dir == 0:
|
||||
reference_mel = mel[:, :sep_point]
|
||||
ssl = ssl[:, :, sep_point:]
|
||||
wav2 = wav[:, sep_point*self.hop_length:]
|
||||
wav2 = wav[:, sep_point * self.hop_length :]
|
||||
mel = mel[:, sep_point:]
|
||||
else:
|
||||
reference_mel = mel[:, sep_point:]
|
||||
ssl = ssl[:, :, :sep_point]
|
||||
wav2 = wav[:, :sep_point*self.hop_length]
|
||||
wav2 = wav[:, : sep_point * self.hop_length]
|
||||
mel = mel[:, :sep_point]
|
||||
|
||||
assert abs(ssl.shape[-1]- wav2.shape[-1]//self.hop_length) < 3, (ssl.shape, wav.shape,wav2.shape, mel.shape, sep_point,self.hop_length, sep_point*self.hop_length, dir)
|
||||
assert abs(ssl.shape[-1] - wav2.shape[-1] // self.hop_length) < 3, (
|
||||
ssl.shape,
|
||||
wav.shape,
|
||||
wav2.shape,
|
||||
mel.shape,
|
||||
sep_point,
|
||||
self.hop_length,
|
||||
sep_point * self.hop_length,
|
||||
dir,
|
||||
)
|
||||
return reference_mel, ssl, wav2, mel
|
||||
|
||||
|
||||
class TextAudioSpeakerCollate():
|
||||
""" Zero-pads model inputs and targets
|
||||
"""
|
||||
class TextAudioSpeakerCollate:
|
||||
"""Zero-pads model inputs and targets"""
|
||||
|
||||
def __init__(self, return_ids=False):
|
||||
self.return_ids = return_ids
|
||||
@@ -176,8 +202,8 @@ class TextAudioSpeakerCollate():
|
||||
"""
|
||||
# Right zero-pad all one-hot text sequences to max input length
|
||||
_, ids_sorted_decreasing = torch.sort(
|
||||
torch.LongTensor([x[1].size(1) for x in batch]),
|
||||
dim=0, descending=True)
|
||||
torch.LongTensor([x[1].size(1) for x in batch]), dim=0, descending=True
|
||||
)
|
||||
|
||||
max_ssl_len = max([x[0].size(2) for x in batch])
|
||||
max_ssl_len = int(2 * ((max_ssl_len // 2) + 1))
|
||||
@@ -194,7 +220,7 @@ class TextAudioSpeakerCollate():
|
||||
spec_padded = torch.FloatTensor(len(batch), batch[0][1].size(0), max_spec_len)
|
||||
wav_padded = torch.FloatTensor(len(batch), 1, max_wav_len)
|
||||
ssl_padded = torch.FloatTensor(len(batch), batch[0][0].size(1), max_ssl_len)
|
||||
text_padded = torch.LongTensor(len(batch), max_text_len)
|
||||
text_padded = torch.LongTensor(len(batch), max_text_len)
|
||||
|
||||
spec_padded.zero_()
|
||||
wav_padded.zero_()
|
||||
@@ -205,23 +231,31 @@ class TextAudioSpeakerCollate():
|
||||
row = batch[ids_sorted_decreasing[i]]
|
||||
|
||||
ssl = row[0]
|
||||
ssl_padded[i, :, :ssl.size(2)] = ssl[0, :, :]
|
||||
ssl_padded[i, :, : ssl.size(2)] = ssl[0, :, :]
|
||||
ssl_lengths[i] = ssl.size(2)
|
||||
|
||||
spec = row[1]
|
||||
spec_padded[i, :, :spec.size(1)] = spec
|
||||
spec_padded[i, :, : spec.size(1)] = spec
|
||||
spec_lengths[i] = spec.size(1)
|
||||
|
||||
wav = row[2]
|
||||
wav_padded[i, :, :wav.size(1)] = wav
|
||||
wav_padded[i, :, : wav.size(1)] = wav
|
||||
wav_lengths[i] = wav.size(1)
|
||||
|
||||
text = row[3]
|
||||
text_padded[i, :text.size(0)] = text
|
||||
text_padded[i, : text.size(0)] = text
|
||||
text_lengths[i] = text.size(0)
|
||||
|
||||
|
||||
return ssl_padded, ssl_lengths, spec_padded, spec_lengths, wav_padded, wav_lengths, text_padded, text_lengths
|
||||
return (
|
||||
ssl_padded,
|
||||
ssl_lengths,
|
||||
spec_padded,
|
||||
spec_lengths,
|
||||
wav_padded,
|
||||
wav_lengths,
|
||||
text_padded,
|
||||
text_lengths,
|
||||
)
|
||||
|
||||
|
||||
class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler):
|
||||
@@ -234,7 +268,15 @@ class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler):
|
||||
Ex) boundaries = [b1, b2, b3] -> any x s.t. length(x) <= b1 or length(x) > b3 are discarded.
|
||||
"""
|
||||
|
||||
def __init__(self, dataset, batch_size, boundaries, num_replicas=None, rank=None, shuffle=True):
|
||||
def __init__(
|
||||
self,
|
||||
dataset,
|
||||
batch_size,
|
||||
boundaries,
|
||||
num_replicas=None,
|
||||
rank=None,
|
||||
shuffle=True,
|
||||
):
|
||||
super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)
|
||||
self.lengths = dataset.lengths
|
||||
# print(233333333333333,self.lengths,dir(dataset))
|
||||
@@ -254,7 +296,7 @@ class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler):
|
||||
buckets[idx_bucket].append(i)
|
||||
|
||||
for i in range(len(buckets) - 1, 0, -1):
|
||||
# for i in range(len(buckets) - 1, -1, -1):
|
||||
# for i in range(len(buckets) - 1, -1, -1):
|
||||
if len(buckets[i]) == 0:
|
||||
buckets.pop(i)
|
||||
self.boundaries.pop(i + 1)
|
||||
@@ -263,7 +305,9 @@ class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler):
|
||||
for i in range(len(buckets)):
|
||||
len_bucket = len(buckets[i])
|
||||
total_batch_size = self.num_replicas * self.batch_size
|
||||
rem = (total_batch_size - (len_bucket % total_batch_size)) % total_batch_size
|
||||
rem = (
|
||||
total_batch_size - (len_bucket % total_batch_size)
|
||||
) % total_batch_size
|
||||
num_samples_per_bucket.append(len_bucket + rem)
|
||||
return buckets, num_samples_per_bucket
|
||||
|
||||
@@ -289,14 +333,23 @@ class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler):
|
||||
|
||||
# add extra samples to make it evenly divisible
|
||||
rem = num_samples_bucket - len_bucket
|
||||
ids_bucket = ids_bucket + ids_bucket * (rem // len_bucket) + ids_bucket[:(rem % len_bucket)]
|
||||
ids_bucket = (
|
||||
ids_bucket
|
||||
+ ids_bucket * (rem // len_bucket)
|
||||
+ ids_bucket[: (rem % len_bucket)]
|
||||
)
|
||||
|
||||
# subsample
|
||||
ids_bucket = ids_bucket[self.rank::self.num_replicas]
|
||||
ids_bucket = ids_bucket[self.rank :: self.num_replicas]
|
||||
|
||||
# batching
|
||||
for j in range(len(ids_bucket) // self.batch_size):
|
||||
batch = [bucket[idx] for idx in ids_bucket[j * self.batch_size:(j + 1) * self.batch_size]]
|
||||
batch = [
|
||||
bucket[idx]
|
||||
for idx in ids_bucket[
|
||||
j * self.batch_size : (j + 1) * self.batch_size
|
||||
]
|
||||
]
|
||||
batches.append(batch)
|
||||
|
||||
if self.shuffle:
|
||||
|
||||
@@ -5,64 +5,69 @@ from torch.nn import functional as F
|
||||
|
||||
|
||||
def feature_loss(fmap_r, fmap_g):
|
||||
loss = 0
|
||||
for dr, dg in zip(fmap_r, fmap_g):
|
||||
for rl, gl in zip(dr, dg):
|
||||
rl = rl.float().detach()
|
||||
gl = gl.float()
|
||||
loss += torch.mean(torch.abs(rl - gl))
|
||||
loss = 0
|
||||
for dr, dg in zip(fmap_r, fmap_g):
|
||||
for rl, gl in zip(dr, dg):
|
||||
rl = rl.float().detach()
|
||||
gl = gl.float()
|
||||
loss += torch.mean(torch.abs(rl - gl))
|
||||
|
||||
return loss * 2
|
||||
return loss * 2
|
||||
|
||||
|
||||
def discriminator_loss(disc_real_outputs, disc_generated_outputs):
|
||||
loss = 0
|
||||
r_losses = []
|
||||
g_losses = []
|
||||
for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
|
||||
dr = dr.float()
|
||||
dg = dg.float()
|
||||
r_loss = torch.mean((1-dr)**2)
|
||||
g_loss = torch.mean(dg**2)
|
||||
loss += (r_loss + g_loss)
|
||||
r_losses.append(r_loss.item())
|
||||
g_losses.append(g_loss.item())
|
||||
loss = 0
|
||||
r_losses = []
|
||||
g_losses = []
|
||||
for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
|
||||
dr = dr.float()
|
||||
dg = dg.float()
|
||||
r_loss = torch.mean((1 - dr) ** 2)
|
||||
g_loss = torch.mean(dg**2)
|
||||
loss += r_loss + g_loss
|
||||
r_losses.append(r_loss.item())
|
||||
g_losses.append(g_loss.item())
|
||||
|
||||
return loss, r_losses, g_losses
|
||||
return loss, r_losses, g_losses
|
||||
|
||||
|
||||
def generator_loss(disc_outputs):
|
||||
loss = 0
|
||||
gen_losses = []
|
||||
for dg in disc_outputs:
|
||||
dg = dg.float()
|
||||
l = torch.mean((1-dg)**2)
|
||||
gen_losses.append(l)
|
||||
loss += l
|
||||
loss = 0
|
||||
gen_losses = []
|
||||
for dg in disc_outputs:
|
||||
dg = dg.float()
|
||||
l = torch.mean((1 - dg) ** 2)
|
||||
gen_losses.append(l)
|
||||
loss += l
|
||||
|
||||
return loss, gen_losses
|
||||
return loss, gen_losses
|
||||
|
||||
|
||||
def kl_loss(z_p, logs_q, m_p, logs_p, z_mask):
|
||||
"""
|
||||
z_p, logs_q: [b, h, t_t]
|
||||
m_p, logs_p: [b, h, t_t]
|
||||
"""
|
||||
z_p = z_p.float()
|
||||
logs_q = logs_q.float()
|
||||
m_p = m_p.float()
|
||||
logs_p = logs_p.float()
|
||||
z_mask = z_mask.float()
|
||||
"""
|
||||
z_p, logs_q: [b, h, t_t]
|
||||
m_p, logs_p: [b, h, t_t]
|
||||
"""
|
||||
z_p = z_p.float()
|
||||
logs_q = logs_q.float()
|
||||
m_p = m_p.float()
|
||||
logs_p = logs_p.float()
|
||||
z_mask = z_mask.float()
|
||||
|
||||
kl = logs_p - logs_q - 0.5
|
||||
kl += 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p)
|
||||
kl = torch.sum(kl * z_mask)
|
||||
l = kl / torch.sum(z_mask)
|
||||
return l
|
||||
|
||||
kl = logs_p - logs_q - 0.5
|
||||
kl += 0.5 * ((z_p - m_p)**2) * torch.exp(-2. * logs_p)
|
||||
kl = torch.sum(kl * z_mask)
|
||||
l = kl / torch.sum(z_mask)
|
||||
return l
|
||||
|
||||
def mle_loss(z, m, logs, logdet, mask):
|
||||
l = torch.sum(logs) + 0.5 * torch.sum(torch.exp(-2 * logs) * ((z - m)**2)) # neg normal likelihood w/o the constant term
|
||||
l = l - torch.sum(logdet) # log jacobian determinant
|
||||
l = l / torch.sum(torch.ones_like(z) * mask) # averaging across batch, channel and time axes
|
||||
l = l + 0.5 * math.log(2 * math.pi) # add the remaining constant term
|
||||
return l
|
||||
l = torch.sum(logs) + 0.5 * torch.sum(
|
||||
torch.exp(-2 * logs) * ((z - m) ** 2)
|
||||
) # neg normal likelihood w/o the constant term
|
||||
l = l - torch.sum(logdet) # log jacobian determinant
|
||||
l = l / torch.sum(
|
||||
torch.ones_like(z) * mask
|
||||
) # averaging across batch, channel and time axes
|
||||
l = l + 0.5 * math.log(2 * math.pi) # add the remaining constant term
|
||||
return l
|
||||
|
||||
@@ -49,21 +49,37 @@ hann_window = {}
|
||||
|
||||
|
||||
def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
|
||||
if torch.min(y) < -1.:
|
||||
print('min value is ', torch.min(y))
|
||||
if torch.max(y) > 1.:
|
||||
print('max value is ', torch.max(y))
|
||||
if torch.min(y) < -1.0:
|
||||
print("min value is ", torch.min(y))
|
||||
if torch.max(y) > 1.0:
|
||||
print("max value is ", torch.max(y))
|
||||
|
||||
global hann_window
|
||||
dtype_device = str(y.dtype) + '_' + str(y.device)
|
||||
wnsize_dtype_device = str(win_size) + '_' + dtype_device
|
||||
dtype_device = str(y.dtype) + "_" + str(y.device)
|
||||
wnsize_dtype_device = str(win_size) + "_" + dtype_device
|
||||
if wnsize_dtype_device not in hann_window:
|
||||
hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)
|
||||
hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(
|
||||
dtype=y.dtype, device=y.device
|
||||
)
|
||||
|
||||
y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect')
|
||||
y = torch.nn.functional.pad(
|
||||
y.unsqueeze(1),
|
||||
(int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
|
||||
mode="reflect",
|
||||
)
|
||||
y = y.squeeze(1)
|
||||
spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device],
|
||||
center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False)
|
||||
spec = torch.stft(
|
||||
y,
|
||||
n_fft,
|
||||
hop_length=hop_size,
|
||||
win_length=win_size,
|
||||
window=hann_window[wnsize_dtype_device],
|
||||
center=center,
|
||||
pad_mode="reflect",
|
||||
normalized=False,
|
||||
onesided=True,
|
||||
return_complex=False,
|
||||
)
|
||||
|
||||
spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
|
||||
return spec
|
||||
@@ -71,37 +87,63 @@ def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False)
|
||||
|
||||
def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax):
|
||||
global mel_basis
|
||||
dtype_device = str(spec.dtype) + '_' + str(spec.device)
|
||||
fmax_dtype_device = str(fmax) + '_' + dtype_device
|
||||
dtype_device = str(spec.dtype) + "_" + str(spec.device)
|
||||
fmax_dtype_device = str(fmax) + "_" + dtype_device
|
||||
if fmax_dtype_device not in mel_basis:
|
||||
mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
|
||||
mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device)
|
||||
mel = librosa_mel_fn(
|
||||
sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
|
||||
)
|
||||
mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(
|
||||
dtype=spec.dtype, device=spec.device
|
||||
)
|
||||
spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
|
||||
spec = spectral_normalize_torch(spec)
|
||||
return spec
|
||||
|
||||
|
||||
def mel_spectrogram_torch(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
|
||||
if torch.min(y) < -1.:
|
||||
print('min value is ', torch.min(y))
|
||||
if torch.max(y) > 1.:
|
||||
print('max value is ', torch.max(y))
|
||||
def mel_spectrogram_torch(
|
||||
y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False
|
||||
):
|
||||
if torch.min(y) < -1.0:
|
||||
print("min value is ", torch.min(y))
|
||||
if torch.max(y) > 1.0:
|
||||
print("max value is ", torch.max(y))
|
||||
|
||||
global mel_basis, hann_window
|
||||
dtype_device = str(y.dtype) + '_' + str(y.device)
|
||||
fmax_dtype_device = str(fmax) + '_' + dtype_device
|
||||
wnsize_dtype_device = str(win_size) + '_' + dtype_device
|
||||
dtype_device = str(y.dtype) + "_" + str(y.device)
|
||||
fmax_dtype_device = str(fmax) + "_" + dtype_device
|
||||
wnsize_dtype_device = str(win_size) + "_" + dtype_device
|
||||
if fmax_dtype_device not in mel_basis:
|
||||
mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
|
||||
mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=y.dtype, device=y.device)
|
||||
mel = librosa_mel_fn(
|
||||
sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
|
||||
)
|
||||
mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(
|
||||
dtype=y.dtype, device=y.device
|
||||
)
|
||||
if wnsize_dtype_device not in hann_window:
|
||||
hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)
|
||||
hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(
|
||||
dtype=y.dtype, device=y.device
|
||||
)
|
||||
|
||||
y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect')
|
||||
y = torch.nn.functional.pad(
|
||||
y.unsqueeze(1),
|
||||
(int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
|
||||
mode="reflect",
|
||||
)
|
||||
y = y.squeeze(1)
|
||||
|
||||
spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device],
|
||||
center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False)
|
||||
spec = torch.stft(
|
||||
y,
|
||||
n_fft,
|
||||
hop_length=hop_size,
|
||||
win_length=win_size,
|
||||
window=hann_window[wnsize_dtype_device],
|
||||
center=center,
|
||||
pad_mode="reflect",
|
||||
normalized=False,
|
||||
onesided=True,
|
||||
return_complex=False,
|
||||
)
|
||||
|
||||
spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
|
||||
|
||||
|
||||
@@ -12,12 +12,21 @@ from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
|
||||
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
|
||||
from module.commons import init_weights, get_padding
|
||||
from module.mrte_model import MRTE
|
||||
from module.quantize import ResidualVectorQuantizer
|
||||
from module.quantize import ResidualVectorQuantizer
|
||||
from text import symbols
|
||||
from torch.cuda.amp import autocast
|
||||
|
||||
|
||||
class StochasticDurationPredictor(nn.Module):
|
||||
def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, n_flows=4, gin_channels=0):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
filter_channels,
|
||||
kernel_size,
|
||||
p_dropout,
|
||||
n_flows=4,
|
||||
gin_channels=0,
|
||||
):
|
||||
super().__init__()
|
||||
filter_channels = in_channels # it needs to be removed from future version.
|
||||
self.in_channels = in_channels
|
||||
@@ -31,21 +40,29 @@ class StochasticDurationPredictor(nn.Module):
|
||||
self.flows = nn.ModuleList()
|
||||
self.flows.append(modules.ElementwiseAffine(2))
|
||||
for i in range(n_flows):
|
||||
self.flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3))
|
||||
self.flows.append(
|
||||
modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)
|
||||
)
|
||||
self.flows.append(modules.Flip())
|
||||
|
||||
self.post_pre = nn.Conv1d(1, filter_channels, 1)
|
||||
self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1)
|
||||
self.post_convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout)
|
||||
self.post_convs = modules.DDSConv(
|
||||
filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
|
||||
)
|
||||
self.post_flows = nn.ModuleList()
|
||||
self.post_flows.append(modules.ElementwiseAffine(2))
|
||||
for i in range(4):
|
||||
self.post_flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3))
|
||||
self.post_flows.append(
|
||||
modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)
|
||||
)
|
||||
self.post_flows.append(modules.Flip())
|
||||
|
||||
self.pre = nn.Conv1d(in_channels, filter_channels, 1)
|
||||
self.proj = nn.Conv1d(filter_channels, filter_channels, 1)
|
||||
self.convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout)
|
||||
self.convs = modules.DDSConv(
|
||||
filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
|
||||
)
|
||||
if gin_channels != 0:
|
||||
self.cond = nn.Conv1d(gin_channels, filter_channels, 1)
|
||||
|
||||
@@ -66,7 +83,10 @@ class StochasticDurationPredictor(nn.Module):
|
||||
h_w = self.post_pre(w)
|
||||
h_w = self.post_convs(h_w, x_mask)
|
||||
h_w = self.post_proj(h_w) * x_mask
|
||||
e_q = torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype) * x_mask
|
||||
e_q = (
|
||||
torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype)
|
||||
* x_mask
|
||||
)
|
||||
z_q = e_q
|
||||
for flow in self.post_flows:
|
||||
z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w))
|
||||
@@ -74,8 +94,13 @@ class StochasticDurationPredictor(nn.Module):
|
||||
z_u, z1 = torch.split(z_q, [1, 1], 1)
|
||||
u = torch.sigmoid(z_u) * x_mask
|
||||
z0 = (w - u) * x_mask
|
||||
logdet_tot_q += torch.sum((F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2])
|
||||
logq = torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q ** 2)) * x_mask, [1, 2]) - logdet_tot_q
|
||||
logdet_tot_q += torch.sum(
|
||||
(F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2]
|
||||
)
|
||||
logq = (
|
||||
torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q**2)) * x_mask, [1, 2])
|
||||
- logdet_tot_q
|
||||
)
|
||||
|
||||
logdet_tot = 0
|
||||
z0, logdet = self.log_flow(z0, x_mask)
|
||||
@@ -84,12 +109,18 @@ class StochasticDurationPredictor(nn.Module):
|
||||
for flow in flows:
|
||||
z, logdet = flow(z, x_mask, g=x, reverse=reverse)
|
||||
logdet_tot = logdet_tot + logdet
|
||||
nll = torch.sum(0.5 * (math.log(2 * math.pi) + (z ** 2)) * x_mask, [1, 2]) - logdet_tot
|
||||
nll = (
|
||||
torch.sum(0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask, [1, 2])
|
||||
- logdet_tot
|
||||
)
|
||||
return nll + logq # [b]
|
||||
else:
|
||||
flows = list(reversed(self.flows))
|
||||
flows = flows[:-2] + [flows[-1]] # remove a useless vflow
|
||||
z = torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype) * noise_scale
|
||||
z = (
|
||||
torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype)
|
||||
* noise_scale
|
||||
)
|
||||
for flow in flows:
|
||||
z = flow(z, x_mask, g=x, reverse=reverse)
|
||||
z0, z1 = torch.split(z, [1, 1], 1)
|
||||
@@ -98,7 +129,9 @@ class StochasticDurationPredictor(nn.Module):
|
||||
|
||||
|
||||
class DurationPredictor(nn.Module):
|
||||
def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0):
|
||||
def __init__(
|
||||
self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.in_channels = in_channels
|
||||
@@ -108,9 +141,13 @@ class DurationPredictor(nn.Module):
|
||||
self.gin_channels = gin_channels
|
||||
|
||||
self.drop = nn.Dropout(p_dropout)
|
||||
self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
|
||||
self.conv_1 = nn.Conv1d(
|
||||
in_channels, filter_channels, kernel_size, padding=kernel_size // 2
|
||||
)
|
||||
self.norm_1 = modules.LayerNorm(filter_channels)
|
||||
self.conv_2 = nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size // 2)
|
||||
self.conv_2 = nn.Conv1d(
|
||||
filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
|
||||
)
|
||||
self.norm_2 = modules.LayerNorm(filter_channels)
|
||||
self.proj = nn.Conv1d(filter_channels, 1, 1)
|
||||
|
||||
@@ -135,15 +172,17 @@ class DurationPredictor(nn.Module):
|
||||
|
||||
|
||||
class TextEncoder(nn.Module):
|
||||
def __init__(self,
|
||||
out_channels,
|
||||
hidden_channels,
|
||||
filter_channels,
|
||||
n_heads,
|
||||
n_layers,
|
||||
kernel_size,
|
||||
p_dropout,
|
||||
latent_channels=192):
|
||||
def __init__(
|
||||
self,
|
||||
out_channels,
|
||||
hidden_channels,
|
||||
filter_channels,
|
||||
n_heads,
|
||||
n_layers,
|
||||
kernel_size,
|
||||
p_dropout,
|
||||
latent_channels=192,
|
||||
):
|
||||
super().__init__()
|
||||
self.out_channels = out_channels
|
||||
self.hidden_channels = hidden_channels
|
||||
@@ -160,17 +199,14 @@ class TextEncoder(nn.Module):
|
||||
hidden_channels,
|
||||
filter_channels,
|
||||
n_heads,
|
||||
n_layers//2,
|
||||
n_layers // 2,
|
||||
kernel_size,
|
||||
p_dropout)
|
||||
p_dropout,
|
||||
)
|
||||
|
||||
self.encoder_text = attentions.Encoder(
|
||||
hidden_channels,
|
||||
filter_channels,
|
||||
n_heads,
|
||||
n_layers,
|
||||
kernel_size,
|
||||
p_dropout)
|
||||
hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
|
||||
)
|
||||
self.text_embedding = nn.Embedding(len(symbols), hidden_channels)
|
||||
|
||||
self.mrte = MRTE()
|
||||
@@ -179,21 +215,25 @@ class TextEncoder(nn.Module):
|
||||
hidden_channels,
|
||||
filter_channels,
|
||||
n_heads,
|
||||
n_layers//2,
|
||||
n_layers // 2,
|
||||
kernel_size,
|
||||
p_dropout)
|
||||
|
||||
p_dropout,
|
||||
)
|
||||
|
||||
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
|
||||
|
||||
def forward(self, y, y_lengths, text, text_lengths, ge, test=None):
|
||||
y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(y.dtype)
|
||||
y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(
|
||||
y.dtype
|
||||
)
|
||||
|
||||
y = self.ssl_proj(y * y_mask) * y_mask
|
||||
y = self.encoder_ssl(y * y_mask, y_mask)
|
||||
|
||||
text_mask = torch.unsqueeze(commons.sequence_mask(text_lengths, text.size(1)), 1).to(y.dtype)
|
||||
if test == 1 :
|
||||
text_mask = torch.unsqueeze(
|
||||
commons.sequence_mask(text_lengths, text.size(1)), 1
|
||||
).to(y.dtype)
|
||||
if test == 1:
|
||||
text[:, :] = 0
|
||||
text = self.text_embedding(text).transpose(1, 2)
|
||||
text = self.encoder_text(text * text_mask, text_mask)
|
||||
@@ -208,9 +248,9 @@ class TextEncoder(nn.Module):
|
||||
def extract_latent(self, x):
|
||||
x = self.ssl_proj(x)
|
||||
quantized, codes, commit_loss, quantized_list = self.quantizer(x)
|
||||
return codes.transpose(0,1)
|
||||
def decode_latent(self, codes, y_mask, refer,refer_mask, ge):
|
||||
return codes.transpose(0, 1)
|
||||
|
||||
def decode_latent(self, codes, y_mask, refer, refer_mask, ge):
|
||||
quantized = self.quantizer.decode(codes)
|
||||
|
||||
y = self.vq_proj(quantized) * y_mask
|
||||
@@ -224,15 +264,18 @@ class TextEncoder(nn.Module):
|
||||
m, logs = torch.split(stats, self.out_channels, dim=1)
|
||||
return y, m, logs, y_mask, quantized
|
||||
|
||||
|
||||
class ResidualCouplingBlock(nn.Module):
|
||||
def __init__(self,
|
||||
channels,
|
||||
hidden_channels,
|
||||
kernel_size,
|
||||
dilation_rate,
|
||||
n_layers,
|
||||
n_flows=4,
|
||||
gin_channels=0):
|
||||
def __init__(
|
||||
self,
|
||||
channels,
|
||||
hidden_channels,
|
||||
kernel_size,
|
||||
dilation_rate,
|
||||
n_layers,
|
||||
n_flows=4,
|
||||
gin_channels=0,
|
||||
):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.hidden_channels = hidden_channels
|
||||
@@ -245,8 +288,16 @@ class ResidualCouplingBlock(nn.Module):
|
||||
self.flows = nn.ModuleList()
|
||||
for i in range(n_flows):
|
||||
self.flows.append(
|
||||
modules.ResidualCouplingLayer(channels, hidden_channels, kernel_size, dilation_rate, n_layers,
|
||||
gin_channels=gin_channels, mean_only=True))
|
||||
modules.ResidualCouplingLayer(
|
||||
channels,
|
||||
hidden_channels,
|
||||
kernel_size,
|
||||
dilation_rate,
|
||||
n_layers,
|
||||
gin_channels=gin_channels,
|
||||
mean_only=True,
|
||||
)
|
||||
)
|
||||
self.flows.append(modules.Flip())
|
||||
|
||||
def forward(self, x, x_mask, g=None, reverse=False):
|
||||
@@ -260,14 +311,16 @@ class ResidualCouplingBlock(nn.Module):
|
||||
|
||||
|
||||
class PosteriorEncoder(nn.Module):
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
hidden_channels,
|
||||
kernel_size,
|
||||
dilation_rate,
|
||||
n_layers,
|
||||
gin_channels=0):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
hidden_channels,
|
||||
kernel_size,
|
||||
dilation_rate,
|
||||
n_layers,
|
||||
gin_channels=0,
|
||||
):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
@@ -278,13 +331,21 @@ class PosteriorEncoder(nn.Module):
|
||||
self.gin_channels = gin_channels
|
||||
|
||||
self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
|
||||
self.enc = modules.WN(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels)
|
||||
self.enc = modules.WN(
|
||||
hidden_channels,
|
||||
kernel_size,
|
||||
dilation_rate,
|
||||
n_layers,
|
||||
gin_channels=gin_channels,
|
||||
)
|
||||
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
|
||||
|
||||
def forward(self, x, x_lengths, g=None):
|
||||
if(g!=None):
|
||||
if g != None:
|
||||
g = g.detach()
|
||||
x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
|
||||
x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
|
||||
x.dtype
|
||||
)
|
||||
x = self.pre(x) * x_mask
|
||||
x = self.enc(x, x_mask, g=g)
|
||||
stats = self.proj(x) * x_mask
|
||||
@@ -294,14 +355,16 @@ class PosteriorEncoder(nn.Module):
|
||||
|
||||
|
||||
class WNEncoder(nn.Module):
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
hidden_channels,
|
||||
kernel_size,
|
||||
dilation_rate,
|
||||
n_layers,
|
||||
gin_channels=0):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
hidden_channels,
|
||||
kernel_size,
|
||||
dilation_rate,
|
||||
n_layers,
|
||||
gin_channels=0,
|
||||
):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
@@ -312,11 +375,20 @@ class WNEncoder(nn.Module):
|
||||
self.gin_channels = gin_channels
|
||||
|
||||
self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
|
||||
self.enc = modules.WN(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels)
|
||||
self.enc = modules.WN(
|
||||
hidden_channels,
|
||||
kernel_size,
|
||||
dilation_rate,
|
||||
n_layers,
|
||||
gin_channels=gin_channels,
|
||||
)
|
||||
self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
|
||||
self.norm = modules.LayerNorm(out_channels)
|
||||
|
||||
def forward(self, x, x_lengths, g=None):
|
||||
x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
|
||||
x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
|
||||
x.dtype
|
||||
)
|
||||
x = self.pre(x) * x_mask
|
||||
x = self.enc(x, x_mask, g=g)
|
||||
out = self.proj(x) * x_mask
|
||||
@@ -325,24 +397,45 @@ class WNEncoder(nn.Module):
|
||||
|
||||
|
||||
class Generator(torch.nn.Module):
|
||||
def __init__(self, initial_channel, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates,
|
||||
upsample_initial_channel, upsample_kernel_sizes, gin_channels=0):
|
||||
def __init__(
|
||||
self,
|
||||
initial_channel,
|
||||
resblock,
|
||||
resblock_kernel_sizes,
|
||||
resblock_dilation_sizes,
|
||||
upsample_rates,
|
||||
upsample_initial_channel,
|
||||
upsample_kernel_sizes,
|
||||
gin_channels=0,
|
||||
):
|
||||
super(Generator, self).__init__()
|
||||
self.num_kernels = len(resblock_kernel_sizes)
|
||||
self.num_upsamples = len(upsample_rates)
|
||||
self.conv_pre = Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3)
|
||||
resblock = modules.ResBlock1 if resblock == '1' else modules.ResBlock2
|
||||
self.conv_pre = Conv1d(
|
||||
initial_channel, upsample_initial_channel, 7, 1, padding=3
|
||||
)
|
||||
resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2
|
||||
|
||||
self.ups = nn.ModuleList()
|
||||
for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
|
||||
self.ups.append(weight_norm(
|
||||
ConvTranspose1d(upsample_initial_channel // (2 ** i), upsample_initial_channel // (2 ** (i + 1)),
|
||||
k, u, padding=(k - u) // 2)))
|
||||
self.ups.append(
|
||||
weight_norm(
|
||||
ConvTranspose1d(
|
||||
upsample_initial_channel // (2**i),
|
||||
upsample_initial_channel // (2 ** (i + 1)),
|
||||
k,
|
||||
u,
|
||||
padding=(k - u) // 2,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
self.resblocks = nn.ModuleList()
|
||||
for i in range(len(self.ups)):
|
||||
ch = upsample_initial_channel // (2 ** (i + 1))
|
||||
for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
|
||||
for j, (k, d) in enumerate(
|
||||
zip(resblock_kernel_sizes, resblock_dilation_sizes)
|
||||
):
|
||||
self.resblocks.append(resblock(ch, k, d))
|
||||
|
||||
self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
|
||||
@@ -373,7 +466,7 @@ class Generator(torch.nn.Module):
|
||||
return x
|
||||
|
||||
def remove_weight_norm(self):
|
||||
print('Removing weight norm...')
|
||||
print("Removing weight norm...")
|
||||
for l in self.ups:
|
||||
remove_weight_norm(l)
|
||||
for l in self.resblocks:
|
||||
@@ -386,13 +479,55 @@ class DiscriminatorP(torch.nn.Module):
|
||||
self.period = period
|
||||
self.use_spectral_norm = use_spectral_norm
|
||||
norm_f = weight_norm if use_spectral_norm == False else spectral_norm
|
||||
self.convs = nn.ModuleList([
|
||||
norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
|
||||
norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
|
||||
norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
|
||||
norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
|
||||
norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(get_padding(kernel_size, 1), 0))),
|
||||
])
|
||||
self.convs = nn.ModuleList(
|
||||
[
|
||||
norm_f(
|
||||
Conv2d(
|
||||
1,
|
||||
32,
|
||||
(kernel_size, 1),
|
||||
(stride, 1),
|
||||
padding=(get_padding(kernel_size, 1), 0),
|
||||
)
|
||||
),
|
||||
norm_f(
|
||||
Conv2d(
|
||||
32,
|
||||
128,
|
||||
(kernel_size, 1),
|
||||
(stride, 1),
|
||||
padding=(get_padding(kernel_size, 1), 0),
|
||||
)
|
||||
),
|
||||
norm_f(
|
||||
Conv2d(
|
||||
128,
|
||||
512,
|
||||
(kernel_size, 1),
|
||||
(stride, 1),
|
||||
padding=(get_padding(kernel_size, 1), 0),
|
||||
)
|
||||
),
|
||||
norm_f(
|
||||
Conv2d(
|
||||
512,
|
||||
1024,
|
||||
(kernel_size, 1),
|
||||
(stride, 1),
|
||||
padding=(get_padding(kernel_size, 1), 0),
|
||||
)
|
||||
),
|
||||
norm_f(
|
||||
Conv2d(
|
||||
1024,
|
||||
1024,
|
||||
(kernel_size, 1),
|
||||
1,
|
||||
padding=(get_padding(kernel_size, 1), 0),
|
||||
)
|
||||
),
|
||||
]
|
||||
)
|
||||
self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
|
||||
|
||||
def forward(self, x):
|
||||
@@ -421,14 +556,16 @@ class DiscriminatorS(torch.nn.Module):
|
||||
def __init__(self, use_spectral_norm=False):
|
||||
super(DiscriminatorS, self).__init__()
|
||||
norm_f = weight_norm if use_spectral_norm == False else spectral_norm
|
||||
self.convs = nn.ModuleList([
|
||||
norm_f(Conv1d(1, 16, 15, 1, padding=7)),
|
||||
norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
|
||||
norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
|
||||
norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
|
||||
norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
|
||||
norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
|
||||
])
|
||||
self.convs = nn.ModuleList(
|
||||
[
|
||||
norm_f(Conv1d(1, 16, 15, 1, padding=7)),
|
||||
norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
|
||||
norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
|
||||
norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
|
||||
norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
|
||||
norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
|
||||
]
|
||||
)
|
||||
self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
|
||||
|
||||
def forward(self, x):
|
||||
@@ -451,7 +588,9 @@ class MultiPeriodDiscriminator(torch.nn.Module):
|
||||
periods = [2, 3, 5, 7, 11]
|
||||
|
||||
discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
|
||||
discs = discs + [DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods]
|
||||
discs = discs + [
|
||||
DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods
|
||||
]
|
||||
self.discriminators = nn.ModuleList(discs)
|
||||
|
||||
def forward(self, y, y_hat):
|
||||
@@ -469,31 +608,40 @@ class MultiPeriodDiscriminator(torch.nn.Module):
|
||||
|
||||
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
|
||||
|
||||
|
||||
class ReferenceEncoder(nn.Module):
|
||||
'''
|
||||
"""
|
||||
inputs --- [N, Ty/r, n_mels*r] mels
|
||||
outputs --- [N, ref_enc_gru_size]
|
||||
'''
|
||||
"""
|
||||
|
||||
def __init__(self, spec_channels, gin_channels=0):
|
||||
|
||||
super().__init__()
|
||||
self.spec_channels = spec_channels
|
||||
ref_enc_filters = [32, 32, 64, 64, 128, 128]
|
||||
K = len(ref_enc_filters)
|
||||
filters = [1] + ref_enc_filters
|
||||
convs = [weight_norm(nn.Conv2d(in_channels=filters[i],
|
||||
out_channels=filters[i + 1],
|
||||
kernel_size=(3, 3),
|
||||
stride=(2, 2),
|
||||
padding=(1, 1))) for i in range(K)]
|
||||
convs = [
|
||||
weight_norm(
|
||||
nn.Conv2d(
|
||||
in_channels=filters[i],
|
||||
out_channels=filters[i + 1],
|
||||
kernel_size=(3, 3),
|
||||
stride=(2, 2),
|
||||
padding=(1, 1),
|
||||
)
|
||||
)
|
||||
for i in range(K)
|
||||
]
|
||||
self.convs = nn.ModuleList(convs)
|
||||
# self.wns = nn.ModuleList([weight_norm(num_features=ref_enc_filters[i]) for i in range(K)])
|
||||
|
||||
out_channels = self.calculate_channels(spec_channels, 3, 2, 1, K)
|
||||
self.gru = nn.GRU(input_size=ref_enc_filters[-1] * out_channels,
|
||||
hidden_size=256 // 2,
|
||||
batch_first=True)
|
||||
self.gru = nn.GRU(
|
||||
input_size=ref_enc_filters[-1] * out_channels,
|
||||
hidden_size=256 // 2,
|
||||
batch_first=True,
|
||||
)
|
||||
self.proj = nn.Linear(128, gin_channels)
|
||||
|
||||
def forward(self, inputs):
|
||||
@@ -527,23 +675,31 @@ class Quantizer_module(torch.nn.Module):
|
||||
self.embedding.weight.data.uniform_(-1.0 / n_e, 1.0 / n_e)
|
||||
|
||||
def forward(self, x):
|
||||
d = torch.sum(x ** 2, 1, keepdim=True) + torch.sum(self.embedding.weight ** 2, 1) - 2 * torch.matmul(x, self.embedding.weight.T)
|
||||
d = (
|
||||
torch.sum(x**2, 1, keepdim=True)
|
||||
+ torch.sum(self.embedding.weight**2, 1)
|
||||
- 2 * torch.matmul(x, self.embedding.weight.T)
|
||||
)
|
||||
min_indicies = torch.argmin(d, 1)
|
||||
z_q = self.embedding(min_indicies)
|
||||
return z_q, min_indicies
|
||||
|
||||
|
||||
class Quantizer(torch.nn.Module):
|
||||
def __init__(self, embed_dim=512, n_code_groups=4, n_codes=160):
|
||||
super(Quantizer, self).__init__()
|
||||
assert embed_dim % n_code_groups == 0
|
||||
self.quantizer_modules = nn.ModuleList([
|
||||
Quantizer_module(n_codes, embed_dim // n_code_groups) for _ in range(n_code_groups)
|
||||
])
|
||||
self.quantizer_modules = nn.ModuleList(
|
||||
[
|
||||
Quantizer_module(n_codes, embed_dim // n_code_groups)
|
||||
for _ in range(n_code_groups)
|
||||
]
|
||||
)
|
||||
self.n_code_groups = n_code_groups
|
||||
self.embed_dim = embed_dim
|
||||
|
||||
def forward(self, xin):
|
||||
#B, C, T
|
||||
# B, C, T
|
||||
B, C, T = xin.shape
|
||||
xin = xin.transpose(1, 2)
|
||||
x = xin.reshape(-1, self.embed_dim)
|
||||
@@ -553,38 +709,41 @@ class Quantizer(torch.nn.Module):
|
||||
for _x, m in zip(x, self.quantizer_modules):
|
||||
_z_q, _min_indicies = m(_x)
|
||||
z_q.append(_z_q)
|
||||
min_indicies.append(_min_indicies) #B * T,
|
||||
min_indicies.append(_min_indicies) # B * T,
|
||||
z_q = torch.cat(z_q, -1).reshape(xin.shape)
|
||||
loss = 0.25 * torch.mean((z_q.detach() - xin) ** 2) + torch.mean((z_q - xin.detach()) ** 2)
|
||||
loss = 0.25 * torch.mean((z_q.detach() - xin) ** 2) + torch.mean(
|
||||
(z_q - xin.detach()) ** 2
|
||||
)
|
||||
z_q = xin + (z_q - xin).detach()
|
||||
z_q = z_q.transpose(1, 2)
|
||||
codes = torch.stack(min_indicies, -1).reshape(B, T, self.n_code_groups)
|
||||
return z_q, loss, codes.transpose(1, 2)
|
||||
|
||||
def embed(self, x):
|
||||
#idx: N, 4, T
|
||||
x=x.transpose(1, 2)
|
||||
# idx: N, 4, T
|
||||
x = x.transpose(1, 2)
|
||||
x = torch.split(x, 1, 2)
|
||||
ret = []
|
||||
for q, embed in zip(x, self.quantizer_modules):
|
||||
q = embed.embedding(q.squeeze(-1))
|
||||
ret.append(q)
|
||||
ret = torch.cat(ret, -1)
|
||||
return ret.transpose(1, 2) #N, C, T
|
||||
return ret.transpose(1, 2) # N, C, T
|
||||
|
||||
|
||||
class CodePredictor(nn.Module):
|
||||
def __init__(self,
|
||||
hidden_channels,
|
||||
filter_channels,
|
||||
n_heads,
|
||||
n_layers,
|
||||
kernel_size,
|
||||
p_dropout,
|
||||
n_q=8,
|
||||
dims=1024,
|
||||
ssl_dim=768
|
||||
):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_channels,
|
||||
filter_channels,
|
||||
n_heads,
|
||||
n_layers,
|
||||
kernel_size,
|
||||
p_dropout,
|
||||
n_q=8,
|
||||
dims=1024,
|
||||
ssl_dim=768,
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_channels = hidden_channels
|
||||
self.filter_channels = filter_channels
|
||||
@@ -594,19 +753,18 @@ class CodePredictor(nn.Module):
|
||||
self.p_dropout = p_dropout
|
||||
|
||||
self.vq_proj = nn.Conv1d(ssl_dim, hidden_channels, 1)
|
||||
self.ref_enc = modules.MelStyleEncoder(ssl_dim, style_vector_dim=hidden_channels)
|
||||
self.ref_enc = modules.MelStyleEncoder(
|
||||
ssl_dim, style_vector_dim=hidden_channels
|
||||
)
|
||||
|
||||
self.encoder = attentions.Encoder(
|
||||
hidden_channels,
|
||||
filter_channels,
|
||||
n_heads,
|
||||
n_layers,
|
||||
kernel_size,
|
||||
p_dropout)
|
||||
hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
|
||||
)
|
||||
|
||||
self.out_proj = nn.Conv1d(hidden_channels, (n_q-1) * dims, 1)
|
||||
self.out_proj = nn.Conv1d(hidden_channels, (n_q - 1) * dims, 1)
|
||||
self.n_q = n_q
|
||||
self.dims = dims
|
||||
|
||||
def forward(self, x, x_mask, refer, codes, infer=False):
|
||||
x = x.detach()
|
||||
x = self.vq_proj(x * x_mask) * x_mask
|
||||
@@ -614,7 +772,9 @@ class CodePredictor(nn.Module):
|
||||
x = x + g
|
||||
x = self.encoder(x * x_mask, x_mask)
|
||||
x = self.out_proj(x * x_mask) * x_mask
|
||||
logits = x.reshape(x.shape[0], self.n_q - 1, self.dims, x.shape[-1]).transpose(2, 3)
|
||||
logits = x.reshape(x.shape[0], self.n_q - 1, self.dims, x.shape[-1]).transpose(
|
||||
2, 3
|
||||
)
|
||||
target = codes[1:].transpose(0, 1)
|
||||
if not infer:
|
||||
logits = logits.reshape(-1, self.dims)
|
||||
@@ -626,44 +786,44 @@ class CodePredictor(nn.Module):
|
||||
correct_top10 = torch.any(top10_preds == target.unsqueeze(-1), dim=-1)
|
||||
top3_acc = 100 * torch.mean(correct_top10.float()).detach().cpu().item()
|
||||
|
||||
print('Top-10 Accuracy:', top3_acc, "%")
|
||||
print("Top-10 Accuracy:", top3_acc, "%")
|
||||
|
||||
pred_codes = torch.argmax(logits, dim=-1)
|
||||
acc = 100 * torch.mean((pred_codes == target).float()).detach().cpu().item()
|
||||
print('Top-1 Accuracy:', acc, "%")
|
||||
print("Top-1 Accuracy:", acc, "%")
|
||||
|
||||
return pred_codes.transpose(0, 1)
|
||||
|
||||
|
||||
|
||||
class SynthesizerTrn(nn.Module):
|
||||
"""
|
||||
Synthesizer for Training
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
spec_channels,
|
||||
segment_size,
|
||||
inter_channels,
|
||||
hidden_channels,
|
||||
filter_channels,
|
||||
n_heads,
|
||||
n_layers,
|
||||
kernel_size,
|
||||
p_dropout,
|
||||
resblock,
|
||||
resblock_kernel_sizes,
|
||||
resblock_dilation_sizes,
|
||||
upsample_rates,
|
||||
upsample_initial_channel,
|
||||
upsample_kernel_sizes,
|
||||
n_speakers=0,
|
||||
gin_channels=0,
|
||||
use_sdp=True,
|
||||
semantic_frame_rate=None,
|
||||
freeze_quantizer=None,
|
||||
**kwargs):
|
||||
Synthesizer for Training
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
spec_channels,
|
||||
segment_size,
|
||||
inter_channels,
|
||||
hidden_channels,
|
||||
filter_channels,
|
||||
n_heads,
|
||||
n_layers,
|
||||
kernel_size,
|
||||
p_dropout,
|
||||
resblock,
|
||||
resblock_kernel_sizes,
|
||||
resblock_dilation_sizes,
|
||||
upsample_rates,
|
||||
upsample_initial_channel,
|
||||
upsample_kernel_sizes,
|
||||
n_speakers=0,
|
||||
gin_channels=0,
|
||||
use_sdp=True,
|
||||
semantic_frame_rate=None,
|
||||
freeze_quantizer=None,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__()
|
||||
self.spec_channels = spec_channels
|
||||
self.inter_channels = inter_channels
|
||||
@@ -685,34 +845,50 @@ class SynthesizerTrn(nn.Module):
|
||||
|
||||
self.use_sdp = use_sdp
|
||||
self.enc_p = TextEncoder(
|
||||
inter_channels,
|
||||
hidden_channels,
|
||||
filter_channels,
|
||||
n_heads,
|
||||
n_layers,
|
||||
kernel_size,
|
||||
p_dropout)
|
||||
self.dec = Generator(inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates,
|
||||
upsample_initial_channel, upsample_kernel_sizes, gin_channels=gin_channels)
|
||||
self.enc_q = PosteriorEncoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16,
|
||||
gin_channels=gin_channels)
|
||||
self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels)
|
||||
inter_channels,
|
||||
hidden_channels,
|
||||
filter_channels,
|
||||
n_heads,
|
||||
n_layers,
|
||||
kernel_size,
|
||||
p_dropout,
|
||||
)
|
||||
self.dec = Generator(
|
||||
inter_channels,
|
||||
resblock,
|
||||
resblock_kernel_sizes,
|
||||
resblock_dilation_sizes,
|
||||
upsample_rates,
|
||||
upsample_initial_channel,
|
||||
upsample_kernel_sizes,
|
||||
gin_channels=gin_channels,
|
||||
)
|
||||
self.enc_q = PosteriorEncoder(
|
||||
spec_channels,
|
||||
inter_channels,
|
||||
hidden_channels,
|
||||
5,
|
||||
1,
|
||||
16,
|
||||
gin_channels=gin_channels,
|
||||
)
|
||||
self.flow = ResidualCouplingBlock(
|
||||
inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels
|
||||
)
|
||||
|
||||
self.ref_enc = modules.MelStyleEncoder(spec_channels, style_vector_dim=gin_channels)
|
||||
self.ref_enc = modules.MelStyleEncoder(
|
||||
spec_channels, style_vector_dim=gin_channels
|
||||
)
|
||||
|
||||
ssl_dim = 768
|
||||
assert semantic_frame_rate in ['25hz', "50hz"]
|
||||
assert semantic_frame_rate in ["25hz", "50hz"]
|
||||
self.semantic_frame_rate = semantic_frame_rate
|
||||
if semantic_frame_rate == '25hz':
|
||||
if semantic_frame_rate == "25hz":
|
||||
self.ssl_proj = nn.Conv1d(ssl_dim, ssl_dim, 2, stride=2)
|
||||
else:
|
||||
self.ssl_proj = nn.Conv1d(ssl_dim, ssl_dim, 1, stride=1)
|
||||
|
||||
self.quantizer = ResidualVectorQuantizer(
|
||||
dimension=ssl_dim,
|
||||
n_q=1,
|
||||
bins=1024
|
||||
)
|
||||
self.quantizer = ResidualVectorQuantizer(dimension=ssl_dim, n_q=1, bins=1024)
|
||||
if freeze_quantizer:
|
||||
self.ssl_proj.requires_grad_(False)
|
||||
self.quantizer.requires_grad_(False)
|
||||
@@ -721,56 +897,85 @@ class SynthesizerTrn(nn.Module):
|
||||
# self.enc_p.mrte.requires_grad_(False)
|
||||
|
||||
def forward(self, ssl, y, y_lengths, text, text_lengths):
|
||||
y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(y.dtype)
|
||||
y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(
|
||||
y.dtype
|
||||
)
|
||||
ge = self.ref_enc(y * y_mask, y_mask)
|
||||
|
||||
with autocast(enabled=False):
|
||||
ssl = self.ssl_proj(ssl)
|
||||
quantized, codes, commit_loss, quantized_list = self.quantizer(ssl, layers=[0])
|
||||
quantized, codes, commit_loss, quantized_list = self.quantizer(
|
||||
ssl, layers=[0]
|
||||
)
|
||||
|
||||
if self.semantic_frame_rate == '25hz':
|
||||
quantized = F.interpolate(quantized, size=int(quantized.shape[-1] * 2), mode="nearest")
|
||||
if self.semantic_frame_rate == "25hz":
|
||||
quantized = F.interpolate(
|
||||
quantized, size=int(quantized.shape[-1] * 2), mode="nearest"
|
||||
)
|
||||
|
||||
x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge)
|
||||
x, m_p, logs_p, y_mask = self.enc_p(
|
||||
quantized, y_lengths, text, text_lengths, ge
|
||||
)
|
||||
z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=ge)
|
||||
z_p = self.flow(z, y_mask, g=ge)
|
||||
|
||||
z_slice, ids_slice = commons.rand_slice_segments(z, y_lengths, self.segment_size)
|
||||
z_slice, ids_slice = commons.rand_slice_segments(
|
||||
z, y_lengths, self.segment_size
|
||||
)
|
||||
o = self.dec(z_slice, g=ge)
|
||||
return o, commit_loss, ids_slice, y_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q), quantized
|
||||
return (
|
||||
o,
|
||||
commit_loss,
|
||||
ids_slice,
|
||||
y_mask,
|
||||
y_mask,
|
||||
(z, z_p, m_p, logs_p, m_q, logs_q),
|
||||
quantized,
|
||||
)
|
||||
|
||||
def infer(self, ssl, y, y_lengths, text, text_lengths, test=None, noise_scale=0.5):
|
||||
y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(y.dtype)
|
||||
y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(
|
||||
y.dtype
|
||||
)
|
||||
ge = self.ref_enc(y * y_mask, y_mask)
|
||||
|
||||
ssl = self.ssl_proj(ssl)
|
||||
ssl = self.ssl_proj(ssl)
|
||||
quantized, codes, commit_loss, _ = self.quantizer(ssl, layers=[0])
|
||||
if self.semantic_frame_rate == '25hz':
|
||||
quantized = F.interpolate(quantized, size=int(quantized.shape[-1] * 2), mode="nearest")
|
||||
if self.semantic_frame_rate == "25hz":
|
||||
quantized = F.interpolate(
|
||||
quantized, size=int(quantized.shape[-1] * 2), mode="nearest"
|
||||
)
|
||||
|
||||
x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge, test=test)
|
||||
x, m_p, logs_p, y_mask = self.enc_p(
|
||||
quantized, y_lengths, text, text_lengths, ge, test=test
|
||||
)
|
||||
z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
|
||||
|
||||
z = self.flow(z_p, y_mask, g=ge, reverse=True)
|
||||
|
||||
o = self.dec((z * y_mask)[:, :, :], g=ge)
|
||||
return o,y_mask, (z, z_p, m_p, logs_p)
|
||||
|
||||
return o, y_mask, (z, z_p, m_p, logs_p)
|
||||
|
||||
@torch.no_grad()
|
||||
def decode(self, codes,text, refer, noise_scale=0.5):
|
||||
def decode(self, codes, text, refer, noise_scale=0.5):
|
||||
refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device)
|
||||
refer_mask = torch.unsqueeze(commons.sequence_mask(refer_lengths, refer.size(2)), 1).to(refer.dtype)
|
||||
refer_mask = torch.unsqueeze(
|
||||
commons.sequence_mask(refer_lengths, refer.size(2)), 1
|
||||
).to(refer.dtype)
|
||||
ge = self.ref_enc(refer * refer_mask, refer_mask)
|
||||
|
||||
y_lengths = torch.LongTensor([codes.size(2)*2]).to(codes.device)
|
||||
y_lengths = torch.LongTensor([codes.size(2) * 2]).to(codes.device)
|
||||
text_lengths = torch.LongTensor([text.size(-1)]).to(text.device)
|
||||
|
||||
quantized = self.quantizer.decode(codes)
|
||||
if self.semantic_frame_rate == '25hz':
|
||||
quantized = F.interpolate(quantized, size=int(quantized.shape[-1] * 2), mode="nearest")
|
||||
if self.semantic_frame_rate == "25hz":
|
||||
quantized = F.interpolate(
|
||||
quantized, size=int(quantized.shape[-1] * 2), mode="nearest"
|
||||
)
|
||||
|
||||
x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge)
|
||||
x, m_p, logs_p, y_mask = self.enc_p(
|
||||
quantized, y_lengths, text, text_lengths, ge
|
||||
)
|
||||
z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
|
||||
|
||||
z = self.flow(z_p, y_mask, g=ge, reverse=True)
|
||||
@@ -779,6 +984,6 @@ class SynthesizerTrn(nn.Module):
|
||||
return o
|
||||
|
||||
def extract_latent(self, x):
|
||||
ssl = self.ssl_proj(x)
|
||||
ssl = self.ssl_proj(x)
|
||||
quantized, codes, commit_loss, quantized_list = self.quantizer(ssl)
|
||||
return codes.transpose(0,1)
|
||||
return codes.transpose(0, 1)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -5,46 +5,74 @@ from torch import nn
|
||||
from torch.nn.utils import remove_weight_norm, weight_norm
|
||||
from module.attentions import MultiHeadAttention
|
||||
|
||||
|
||||
class MRTE(nn.Module):
|
||||
def __init__(self,
|
||||
content_enc_channels=192,
|
||||
hidden_size=512,
|
||||
out_channels=192,
|
||||
kernel_size=5,
|
||||
n_heads=4,
|
||||
ge_layer = 2
|
||||
):
|
||||
def __init__(
|
||||
self,
|
||||
content_enc_channels=192,
|
||||
hidden_size=512,
|
||||
out_channels=192,
|
||||
kernel_size=5,
|
||||
n_heads=4,
|
||||
ge_layer=2,
|
||||
):
|
||||
super(MRTE, self).__init__()
|
||||
self.cross_attention = MultiHeadAttention(hidden_size,hidden_size,n_heads)
|
||||
self.c_pre = nn.Conv1d(content_enc_channels,hidden_size, 1)
|
||||
self.text_pre = nn.Conv1d(content_enc_channels,hidden_size, 1)
|
||||
self.c_post = nn.Conv1d(hidden_size,out_channels, 1)
|
||||
self.cross_attention = MultiHeadAttention(hidden_size, hidden_size, n_heads)
|
||||
self.c_pre = nn.Conv1d(content_enc_channels, hidden_size, 1)
|
||||
self.text_pre = nn.Conv1d(content_enc_channels, hidden_size, 1)
|
||||
self.c_post = nn.Conv1d(hidden_size, out_channels, 1)
|
||||
|
||||
def forward(self, ssl_enc, ssl_mask, text, text_mask, ge, test=None):
|
||||
if(ge==None):ge=0
|
||||
if ge == None:
|
||||
ge = 0
|
||||
attn_mask = text_mask.unsqueeze(2) * ssl_mask.unsqueeze(-1)
|
||||
|
||||
ssl_enc = self.c_pre(ssl_enc * ssl_mask)
|
||||
text_enc = self.text_pre(text * text_mask)
|
||||
if test != None:
|
||||
if test == 0:
|
||||
x = self.cross_attention(ssl_enc * ssl_mask, text_enc * text_mask, attn_mask) + ssl_enc + ge
|
||||
x = (
|
||||
self.cross_attention(
|
||||
ssl_enc * ssl_mask, text_enc * text_mask, attn_mask
|
||||
)
|
||||
+ ssl_enc
|
||||
+ ge
|
||||
)
|
||||
elif test == 1:
|
||||
x = ssl_enc + ge
|
||||
elif test ==2:
|
||||
x = self.cross_attention(ssl_enc*0 * ssl_mask, text_enc * text_mask, attn_mask) + ge
|
||||
elif test == 2:
|
||||
x = (
|
||||
self.cross_attention(
|
||||
ssl_enc * 0 * ssl_mask, text_enc * text_mask, attn_mask
|
||||
)
|
||||
+ ge
|
||||
)
|
||||
else:
|
||||
raise ValueError("test should be 0,1,2")
|
||||
else:
|
||||
x = self.cross_attention(ssl_enc * ssl_mask, text_enc * text_mask, attn_mask) + ssl_enc + ge
|
||||
x = (
|
||||
self.cross_attention(
|
||||
ssl_enc * ssl_mask, text_enc * text_mask, attn_mask
|
||||
)
|
||||
+ ssl_enc
|
||||
+ ge
|
||||
)
|
||||
x = self.c_post(x * ssl_mask)
|
||||
return x
|
||||
|
||||
|
||||
|
||||
class SpeakerEncoder(torch.nn.Module):
|
||||
def __init__(self, mel_n_channels=80, model_num_layers=2, model_hidden_size=256, model_embedding_size=256):
|
||||
def __init__(
|
||||
self,
|
||||
mel_n_channels=80,
|
||||
model_num_layers=2,
|
||||
model_hidden_size=256,
|
||||
model_embedding_size=256,
|
||||
):
|
||||
super(SpeakerEncoder, self).__init__()
|
||||
self.lstm = nn.LSTM(mel_n_channels, model_hidden_size, model_num_layers, batch_first=True)
|
||||
self.lstm = nn.LSTM(
|
||||
mel_n_channels, model_hidden_size, model_num_layers, batch_first=True
|
||||
)
|
||||
self.linear = nn.Linear(model_hidden_size, model_embedding_size)
|
||||
self.relu = nn.ReLU()
|
||||
|
||||
@@ -56,13 +84,15 @@ class SpeakerEncoder(torch.nn.Module):
|
||||
|
||||
|
||||
class MELEncoder(nn.Module):
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
hidden_channels,
|
||||
kernel_size,
|
||||
dilation_rate,
|
||||
n_layers):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
hidden_channels,
|
||||
kernel_size,
|
||||
dilation_rate,
|
||||
n_layers,
|
||||
):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
@@ -81,80 +111,82 @@ class MELEncoder(nn.Module):
|
||||
x = self.enc(x)
|
||||
x = self.proj(x)
|
||||
return x
|
||||
|
||||
|
||||
|
||||
class WN(torch.nn.Module):
|
||||
def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers):
|
||||
super(WN, self).__init__()
|
||||
assert(kernel_size % 2 == 1)
|
||||
self.hidden_channels =hidden_channels
|
||||
self.kernel_size = kernel_size
|
||||
self.dilation_rate = dilation_rate
|
||||
self.n_layers = n_layers
|
||||
def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers):
|
||||
super(WN, self).__init__()
|
||||
assert kernel_size % 2 == 1
|
||||
self.hidden_channels = hidden_channels
|
||||
self.kernel_size = kernel_size
|
||||
self.dilation_rate = dilation_rate
|
||||
self.n_layers = n_layers
|
||||
|
||||
self.in_layers = torch.nn.ModuleList()
|
||||
self.res_skip_layers = torch.nn.ModuleList()
|
||||
self.in_layers = torch.nn.ModuleList()
|
||||
self.res_skip_layers = torch.nn.ModuleList()
|
||||
|
||||
for i in range(n_layers):
|
||||
dilation = dilation_rate ** i
|
||||
padding = int((kernel_size * dilation - dilation) / 2)
|
||||
in_layer = nn.Conv1d(hidden_channels, 2*hidden_channels, kernel_size,
|
||||
dilation=dilation, padding=padding)
|
||||
in_layer = weight_norm(in_layer)
|
||||
self.in_layers.append(in_layer)
|
||||
for i in range(n_layers):
|
||||
dilation = dilation_rate**i
|
||||
padding = int((kernel_size * dilation - dilation) / 2)
|
||||
in_layer = nn.Conv1d(
|
||||
hidden_channels,
|
||||
2 * hidden_channels,
|
||||
kernel_size,
|
||||
dilation=dilation,
|
||||
padding=padding,
|
||||
)
|
||||
in_layer = weight_norm(in_layer)
|
||||
self.in_layers.append(in_layer)
|
||||
|
||||
# last one is not necessary
|
||||
if i < n_layers - 1:
|
||||
res_skip_channels = 2 * hidden_channels
|
||||
else:
|
||||
res_skip_channels = hidden_channels
|
||||
# last one is not necessary
|
||||
if i < n_layers - 1:
|
||||
res_skip_channels = 2 * hidden_channels
|
||||
else:
|
||||
res_skip_channels = hidden_channels
|
||||
|
||||
res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
|
||||
res_skip_layer = weight_norm(res_skip_layer, name='weight')
|
||||
self.res_skip_layers.append(res_skip_layer)
|
||||
res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
|
||||
res_skip_layer = weight_norm(res_skip_layer, name="weight")
|
||||
self.res_skip_layers.append(res_skip_layer)
|
||||
|
||||
def forward(self, x):
|
||||
output = torch.zeros_like(x)
|
||||
n_channels_tensor = torch.IntTensor([self.hidden_channels])
|
||||
def forward(self, x):
|
||||
output = torch.zeros_like(x)
|
||||
n_channels_tensor = torch.IntTensor([self.hidden_channels])
|
||||
|
||||
for i in range(self.n_layers):
|
||||
x_in = self.in_layers[i](x)
|
||||
for i in range(self.n_layers):
|
||||
x_in = self.in_layers[i](x)
|
||||
|
||||
acts = fused_add_tanh_sigmoid_multiply(
|
||||
x_in,
|
||||
n_channels_tensor)
|
||||
acts = fused_add_tanh_sigmoid_multiply(x_in, n_channels_tensor)
|
||||
|
||||
res_skip_acts = self.res_skip_layers[i](acts)
|
||||
if i < self.n_layers - 1:
|
||||
res_acts = res_skip_acts[:,:self.hidden_channels,:]
|
||||
x = (x + res_acts)
|
||||
output = output + res_skip_acts[:,self.hidden_channels:,:]
|
||||
else:
|
||||
output = output + res_skip_acts
|
||||
return output
|
||||
res_skip_acts = self.res_skip_layers[i](acts)
|
||||
if i < self.n_layers - 1:
|
||||
res_acts = res_skip_acts[:, : self.hidden_channels, :]
|
||||
x = x + res_acts
|
||||
output = output + res_skip_acts[:, self.hidden_channels :, :]
|
||||
else:
|
||||
output = output + res_skip_acts
|
||||
return output
|
||||
|
||||
def remove_weight_norm(self):
|
||||
for l in self.in_layers:
|
||||
remove_weight_norm(l)
|
||||
for l in self.res_skip_layers:
|
||||
remove_weight_norm(l)
|
||||
def remove_weight_norm(self):
|
||||
for l in self.in_layers:
|
||||
remove_weight_norm(l)
|
||||
for l in self.res_skip_layers:
|
||||
remove_weight_norm(l)
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def fused_add_tanh_sigmoid_multiply(input, n_channels):
|
||||
n_channels_int = n_channels[0]
|
||||
t_act = torch.tanh(input[:, :n_channels_int, :])
|
||||
s_act = torch.sigmoid(input[:, n_channels_int:, :])
|
||||
acts = t_act * s_act
|
||||
return acts
|
||||
n_channels_int = n_channels[0]
|
||||
t_act = torch.tanh(input[:, :n_channels_int, :])
|
||||
s_act = torch.sigmoid(input[:, n_channels_int:, :])
|
||||
acts = t_act * s_act
|
||||
return acts
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
content_enc = torch.randn(3,192,100)
|
||||
content_mask = torch.ones(3,1,100)
|
||||
ref_mel = torch.randn(3,128,30)
|
||||
ref_mask = torch.ones(3,1,30)
|
||||
if __name__ == "__main__":
|
||||
content_enc = torch.randn(3, 192, 100)
|
||||
content_mask = torch.ones(3, 1, 100)
|
||||
ref_mel = torch.randn(3, 128, 30)
|
||||
ref_mask = torch.ones(3, 1, 30)
|
||||
model = MRTE()
|
||||
out = model(content_enc,content_mask,ref_mel,ref_mask)
|
||||
print(out.shape)
|
||||
out = model(content_enc, content_mask, ref_mel, ref_mask)
|
||||
print(out.shape)
|
||||
|
||||
@@ -38,6 +38,7 @@ class ResidualVectorQuantizer(nn.Module):
|
||||
that have an exponential moving average cluster size less than the specified threshold with
|
||||
randomly selected vector from the current batch.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dimension: int = 256,
|
||||
@@ -66,7 +67,12 @@ class ResidualVectorQuantizer(nn.Module):
|
||||
threshold_ema_dead_code=self.threshold_ema_dead_code,
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor, n_q: tp.Optional[int] = None, layers: tp.Optional[list] = None) -> QuantizedResult:
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
n_q: tp.Optional[int] = None,
|
||||
layers: tp.Optional[list] = None,
|
||||
) -> QuantizedResult:
|
||||
"""Residual vector quantization on the given input tensor.
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor.
|
||||
@@ -79,12 +85,17 @@ class ResidualVectorQuantizer(nn.Module):
|
||||
"""
|
||||
n_q = n_q if n_q else self.n_q
|
||||
if layers and max(layers) >= n_q:
|
||||
raise ValueError(f'Last layer index in layers: A {max(layers)}. Number of quantizers in RVQ: B {self.n_q}. A must less than B.')
|
||||
quantized, codes, commit_loss, quantized_list = self.vq(x, n_q=n_q, layers=layers)
|
||||
raise ValueError(
|
||||
f"Last layer index in layers: A {max(layers)}. Number of quantizers in RVQ: B {self.n_q}. A must less than B."
|
||||
)
|
||||
quantized, codes, commit_loss, quantized_list = self.vq(
|
||||
x, n_q=n_q, layers=layers
|
||||
)
|
||||
return quantized, codes, torch.mean(commit_loss), quantized_list
|
||||
|
||||
|
||||
def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None, st: tp.Optional[int] = None) -> torch.Tensor:
|
||||
def encode(
|
||||
self, x: torch.Tensor, n_q: tp.Optional[int] = None, st: tp.Optional[int] = None
|
||||
) -> torch.Tensor:
|
||||
"""Encode a given input tensor with the specified sample rate at the given bandwidth.
|
||||
The RVQ encode method sets the appropriate number of quantizer to use
|
||||
and returns indices for each quantizer.
|
||||
@@ -105,4 +116,4 @@ class ResidualVectorQuantizer(nn.Module):
|
||||
st (int): Start to decode input codes from which layers. Default: 0.
|
||||
"""
|
||||
quantized = self.vq.decode(codes, st=st)
|
||||
return quantized
|
||||
return quantized
|
||||
|
||||
@@ -9,66 +9,63 @@ DEFAULT_MIN_BIN_HEIGHT = 1e-3
|
||||
DEFAULT_MIN_DERIVATIVE = 1e-3
|
||||
|
||||
|
||||
def piecewise_rational_quadratic_transform(inputs,
|
||||
unnormalized_widths,
|
||||
unnormalized_heights,
|
||||
unnormalized_derivatives,
|
||||
inverse=False,
|
||||
tails=None,
|
||||
tail_bound=1.,
|
||||
min_bin_width=DEFAULT_MIN_BIN_WIDTH,
|
||||
min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
|
||||
min_derivative=DEFAULT_MIN_DERIVATIVE):
|
||||
|
||||
def piecewise_rational_quadratic_transform(
|
||||
inputs,
|
||||
unnormalized_widths,
|
||||
unnormalized_heights,
|
||||
unnormalized_derivatives,
|
||||
inverse=False,
|
||||
tails=None,
|
||||
tail_bound=1.0,
|
||||
min_bin_width=DEFAULT_MIN_BIN_WIDTH,
|
||||
min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
|
||||
min_derivative=DEFAULT_MIN_DERIVATIVE,
|
||||
):
|
||||
if tails is None:
|
||||
spline_fn = rational_quadratic_spline
|
||||
spline_kwargs = {}
|
||||
else:
|
||||
spline_fn = unconstrained_rational_quadratic_spline
|
||||
spline_kwargs = {
|
||||
'tails': tails,
|
||||
'tail_bound': tail_bound
|
||||
}
|
||||
spline_kwargs = {"tails": tails, "tail_bound": tail_bound}
|
||||
|
||||
outputs, logabsdet = spline_fn(
|
||||
inputs=inputs,
|
||||
unnormalized_widths=unnormalized_widths,
|
||||
unnormalized_heights=unnormalized_heights,
|
||||
unnormalized_derivatives=unnormalized_derivatives,
|
||||
inverse=inverse,
|
||||
min_bin_width=min_bin_width,
|
||||
min_bin_height=min_bin_height,
|
||||
min_derivative=min_derivative,
|
||||
**spline_kwargs
|
||||
inputs=inputs,
|
||||
unnormalized_widths=unnormalized_widths,
|
||||
unnormalized_heights=unnormalized_heights,
|
||||
unnormalized_derivatives=unnormalized_derivatives,
|
||||
inverse=inverse,
|
||||
min_bin_width=min_bin_width,
|
||||
min_bin_height=min_bin_height,
|
||||
min_derivative=min_derivative,
|
||||
**spline_kwargs
|
||||
)
|
||||
return outputs, logabsdet
|
||||
|
||||
|
||||
def searchsorted(bin_locations, inputs, eps=1e-6):
|
||||
bin_locations[..., -1] += eps
|
||||
return torch.sum(
|
||||
inputs[..., None] >= bin_locations,
|
||||
dim=-1
|
||||
) - 1
|
||||
return torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1
|
||||
|
||||
|
||||
def unconstrained_rational_quadratic_spline(inputs,
|
||||
unnormalized_widths,
|
||||
unnormalized_heights,
|
||||
unnormalized_derivatives,
|
||||
inverse=False,
|
||||
tails='linear',
|
||||
tail_bound=1.,
|
||||
min_bin_width=DEFAULT_MIN_BIN_WIDTH,
|
||||
min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
|
||||
min_derivative=DEFAULT_MIN_DERIVATIVE):
|
||||
def unconstrained_rational_quadratic_spline(
|
||||
inputs,
|
||||
unnormalized_widths,
|
||||
unnormalized_heights,
|
||||
unnormalized_derivatives,
|
||||
inverse=False,
|
||||
tails="linear",
|
||||
tail_bound=1.0,
|
||||
min_bin_width=DEFAULT_MIN_BIN_WIDTH,
|
||||
min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
|
||||
min_derivative=DEFAULT_MIN_DERIVATIVE,
|
||||
):
|
||||
inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
|
||||
outside_interval_mask = ~inside_interval_mask
|
||||
|
||||
outputs = torch.zeros_like(inputs)
|
||||
logabsdet = torch.zeros_like(inputs)
|
||||
|
||||
if tails == 'linear':
|
||||
if tails == "linear":
|
||||
unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1))
|
||||
constant = np.log(np.exp(1 - min_derivative) - 1)
|
||||
unnormalized_derivatives[..., 0] = constant
|
||||
@@ -77,45 +74,57 @@ def unconstrained_rational_quadratic_spline(inputs,
|
||||
outputs[outside_interval_mask] = inputs[outside_interval_mask]
|
||||
logabsdet[outside_interval_mask] = 0
|
||||
else:
|
||||
raise RuntimeError('{} tails are not implemented.'.format(tails))
|
||||
raise RuntimeError("{} tails are not implemented.".format(tails))
|
||||
|
||||
outputs[inside_interval_mask], logabsdet[inside_interval_mask] = rational_quadratic_spline(
|
||||
(
|
||||
outputs[inside_interval_mask],
|
||||
logabsdet[inside_interval_mask],
|
||||
) = rational_quadratic_spline(
|
||||
inputs=inputs[inside_interval_mask],
|
||||
unnormalized_widths=unnormalized_widths[inside_interval_mask, :],
|
||||
unnormalized_heights=unnormalized_heights[inside_interval_mask, :],
|
||||
unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :],
|
||||
inverse=inverse,
|
||||
left=-tail_bound, right=tail_bound, bottom=-tail_bound, top=tail_bound,
|
||||
left=-tail_bound,
|
||||
right=tail_bound,
|
||||
bottom=-tail_bound,
|
||||
top=tail_bound,
|
||||
min_bin_width=min_bin_width,
|
||||
min_bin_height=min_bin_height,
|
||||
min_derivative=min_derivative
|
||||
min_derivative=min_derivative,
|
||||
)
|
||||
|
||||
return outputs, logabsdet
|
||||
|
||||
def rational_quadratic_spline(inputs,
|
||||
unnormalized_widths,
|
||||
unnormalized_heights,
|
||||
unnormalized_derivatives,
|
||||
inverse=False,
|
||||
left=0., right=1., bottom=0., top=1.,
|
||||
min_bin_width=DEFAULT_MIN_BIN_WIDTH,
|
||||
min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
|
||||
min_derivative=DEFAULT_MIN_DERIVATIVE):
|
||||
|
||||
def rational_quadratic_spline(
|
||||
inputs,
|
||||
unnormalized_widths,
|
||||
unnormalized_heights,
|
||||
unnormalized_derivatives,
|
||||
inverse=False,
|
||||
left=0.0,
|
||||
right=1.0,
|
||||
bottom=0.0,
|
||||
top=1.0,
|
||||
min_bin_width=DEFAULT_MIN_BIN_WIDTH,
|
||||
min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
|
||||
min_derivative=DEFAULT_MIN_DERIVATIVE,
|
||||
):
|
||||
if torch.min(inputs) < left or torch.max(inputs) > right:
|
||||
raise ValueError('Input to a transform is not within its domain')
|
||||
raise ValueError("Input to a transform is not within its domain")
|
||||
|
||||
num_bins = unnormalized_widths.shape[-1]
|
||||
|
||||
if min_bin_width * num_bins > 1.0:
|
||||
raise ValueError('Minimal bin width too large for the number of bins')
|
||||
raise ValueError("Minimal bin width too large for the number of bins")
|
||||
if min_bin_height * num_bins > 1.0:
|
||||
raise ValueError('Minimal bin height too large for the number of bins')
|
||||
raise ValueError("Minimal bin height too large for the number of bins")
|
||||
|
||||
widths = F.softmax(unnormalized_widths, dim=-1)
|
||||
widths = min_bin_width + (1 - min_bin_width * num_bins) * widths
|
||||
cumwidths = torch.cumsum(widths, dim=-1)
|
||||
cumwidths = F.pad(cumwidths, pad=(1, 0), mode='constant', value=0.0)
|
||||
cumwidths = F.pad(cumwidths, pad=(1, 0), mode="constant", value=0.0)
|
||||
cumwidths = (right - left) * cumwidths + left
|
||||
cumwidths[..., 0] = left
|
||||
cumwidths[..., -1] = right
|
||||
@@ -126,7 +135,7 @@ def rational_quadratic_spline(inputs,
|
||||
heights = F.softmax(unnormalized_heights, dim=-1)
|
||||
heights = min_bin_height + (1 - min_bin_height * num_bins) * heights
|
||||
cumheights = torch.cumsum(heights, dim=-1)
|
||||
cumheights = F.pad(cumheights, pad=(1, 0), mode='constant', value=0.0)
|
||||
cumheights = F.pad(cumheights, pad=(1, 0), mode="constant", value=0.0)
|
||||
cumheights = (top - bottom) * cumheights + bottom
|
||||
cumheights[..., 0] = bottom
|
||||
cumheights[..., -1] = top
|
||||
@@ -150,15 +159,13 @@ def rational_quadratic_spline(inputs,
|
||||
input_heights = heights.gather(-1, bin_idx)[..., 0]
|
||||
|
||||
if inverse:
|
||||
a = (((inputs - input_cumheights) * (input_derivatives
|
||||
+ input_derivatives_plus_one
|
||||
- 2 * input_delta)
|
||||
+ input_heights * (input_delta - input_derivatives)))
|
||||
b = (input_heights * input_derivatives
|
||||
- (inputs - input_cumheights) * (input_derivatives
|
||||
+ input_derivatives_plus_one
|
||||
- 2 * input_delta))
|
||||
c = - input_delta * (inputs - input_cumheights)
|
||||
a = (inputs - input_cumheights) * (
|
||||
input_derivatives + input_derivatives_plus_one - 2 * input_delta
|
||||
) + input_heights * (input_delta - input_derivatives)
|
||||
b = input_heights * input_derivatives - (inputs - input_cumheights) * (
|
||||
input_derivatives + input_derivatives_plus_one - 2 * input_delta
|
||||
)
|
||||
c = -input_delta * (inputs - input_cumheights)
|
||||
|
||||
discriminant = b.pow(2) - 4 * a * c
|
||||
assert (discriminant >= 0).all()
|
||||
@@ -167,11 +174,15 @@ def rational_quadratic_spline(inputs,
|
||||
outputs = root * input_bin_widths + input_cumwidths
|
||||
|
||||
theta_one_minus_theta = root * (1 - root)
|
||||
denominator = input_delta + ((input_derivatives + input_derivatives_plus_one - 2 * input_delta)
|
||||
* theta_one_minus_theta)
|
||||
derivative_numerator = input_delta.pow(2) * (input_derivatives_plus_one * root.pow(2)
|
||||
+ 2 * input_delta * theta_one_minus_theta
|
||||
+ input_derivatives * (1 - root).pow(2))
|
||||
denominator = input_delta + (
|
||||
(input_derivatives + input_derivatives_plus_one - 2 * input_delta)
|
||||
* theta_one_minus_theta
|
||||
)
|
||||
derivative_numerator = input_delta.pow(2) * (
|
||||
input_derivatives_plus_one * root.pow(2)
|
||||
+ 2 * input_delta * theta_one_minus_theta
|
||||
+ input_derivatives * (1 - root).pow(2)
|
||||
)
|
||||
logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
|
||||
|
||||
return outputs, -logabsdet
|
||||
@@ -179,15 +190,20 @@ def rational_quadratic_spline(inputs,
|
||||
theta = (inputs - input_cumwidths) / input_bin_widths
|
||||
theta_one_minus_theta = theta * (1 - theta)
|
||||
|
||||
numerator = input_heights * (input_delta * theta.pow(2)
|
||||
+ input_derivatives * theta_one_minus_theta)
|
||||
denominator = input_delta + ((input_derivatives + input_derivatives_plus_one - 2 * input_delta)
|
||||
* theta_one_minus_theta)
|
||||
numerator = input_heights * (
|
||||
input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta
|
||||
)
|
||||
denominator = input_delta + (
|
||||
(input_derivatives + input_derivatives_plus_one - 2 * input_delta)
|
||||
* theta_one_minus_theta
|
||||
)
|
||||
outputs = input_cumheights + numerator / denominator
|
||||
|
||||
derivative_numerator = input_delta.pow(2) * (input_derivatives_plus_one * theta.pow(2)
|
||||
+ 2 * input_delta * theta_one_minus_theta
|
||||
+ input_derivatives * (1 - theta).pow(2))
|
||||
derivative_numerator = input_delta.pow(2) * (
|
||||
input_derivatives_plus_one * theta.pow(2)
|
||||
+ 2 * input_delta * theta_one_minus_theta
|
||||
+ input_derivatives * (1 - theta).pow(2)
|
||||
)
|
||||
logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
|
||||
|
||||
return outputs, logabsdet
|
||||
|
||||
Reference in New Issue
Block a user