添加导出成 TorchScript 的脚本用于支持python以外的语言 (#1640)
* Fix onnx_export to support v2 * delete some useless code & add some args type for export torch-script * Add export_torch_script.py * (export_torch_script.py) 整合 vits 和 t2s 成一个 model 导出 * 恢复 `t2s_model.py` 把改动移到 `export_torch_script.py`
This commit is contained in:
@@ -4,8 +4,8 @@ from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from module import commons
|
||||
from module.modules import LayerNorm
|
||||
|
||||
from typing import Optional
|
||||
|
||||
class LayerNorm(nn.Module):
|
||||
def __init__(self, channels, eps=1e-5):
|
||||
@@ -59,6 +59,7 @@ class Encoder(nn.Module):
|
||||
# self.cond_layer = weight_norm(cond_layer, name='weight')
|
||||
# self.gin_channels = 256
|
||||
self.cond_layer_idx = self.n_layers
|
||||
self.spk_emb_linear = nn.Linear(256, self.hidden_channels)
|
||||
if "gin_channels" in kwargs:
|
||||
self.gin_channels = kwargs["gin_channels"]
|
||||
if self.gin_channels != 0:
|
||||
@@ -98,22 +99,36 @@ class Encoder(nn.Module):
|
||||
)
|
||||
self.norm_layers_2.append(LayerNorm(hidden_channels))
|
||||
|
||||
def forward(self, x, x_mask, g=None):
|
||||
# def forward(self, x, x_mask, g=None):
|
||||
# attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
|
||||
# x = x * x_mask
|
||||
# for i in range(self.n_layers):
|
||||
# if i == self.cond_layer_idx and g is not None:
|
||||
# g = self.spk_emb_linear(g.transpose(1, 2))
|
||||
# g = g.transpose(1, 2)
|
||||
# x = x + g
|
||||
# x = x * x_mask
|
||||
# y = self.attn_layers[i](x, x, attn_mask)
|
||||
# y = self.drop(y)
|
||||
# x = self.norm_layers_1[i](x + y)
|
||||
|
||||
# y = self.ffn_layers[i](x, x_mask)
|
||||
# y = self.drop(y)
|
||||
# x = self.norm_layers_2[i](x + y)
|
||||
# x = x * x_mask
|
||||
# return x
|
||||
|
||||
def forward(self, x, x_mask):
|
||||
attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
|
||||
x = x * x_mask
|
||||
for i in range(self.n_layers):
|
||||
if i == self.cond_layer_idx and g is not None:
|
||||
g = self.spk_emb_linear(g.transpose(1, 2))
|
||||
g = g.transpose(1, 2)
|
||||
x = x + g
|
||||
x = x * x_mask
|
||||
y = self.attn_layers[i](x, x, attn_mask)
|
||||
for attn_layers,norm_layers_1,ffn_layers,norm_layers_2 in zip(self.attn_layers,self.norm_layers_1,self.ffn_layers,self.norm_layers_2):
|
||||
y = attn_layers(x, x, attn_mask)
|
||||
y = self.drop(y)
|
||||
x = self.norm_layers_1[i](x + y)
|
||||
x = norm_layers_1(x + y)
|
||||
|
||||
y = self.ffn_layers[i](x, x_mask)
|
||||
y = ffn_layers(x, x_mask)
|
||||
y = self.drop(y)
|
||||
x = self.norm_layers_2[i](x + y)
|
||||
x = norm_layers_2(x + y)
|
||||
x = x * x_mask
|
||||
return x
|
||||
|
||||
@@ -172,17 +187,18 @@ class MultiHeadAttention(nn.Module):
|
||||
self.conv_k.weight.copy_(self.conv_q.weight)
|
||||
self.conv_k.bias.copy_(self.conv_q.bias)
|
||||
|
||||
def forward(self, x, c, attn_mask=None):
|
||||
def forward(self, x, c, attn_mask:Optional[torch.Tensor]=None):
|
||||
q = self.conv_q(x)
|
||||
k = self.conv_k(c)
|
||||
v = self.conv_v(c)
|
||||
|
||||
x, self.attn = self.attention(q, k, v, mask=attn_mask)
|
||||
# x, self.attn = self.attention(q, k, v, mask=attn_mask)
|
||||
x, _ = self.attention(q, k, v, mask=attn_mask)
|
||||
|
||||
x = self.conv_o(x)
|
||||
return x
|
||||
|
||||
def attention(self, query, key, value, mask=None):
|
||||
def attention(self, query, key, value, mask:Optional[torch.Tensor]=None):
|
||||
# reshape [b, d, t] -> [b, n_h, t, d_k]
|
||||
b, d, t_s, _ = (*key.size(), query.size(2))
|
||||
query = query.view(b, self.n_heads, self.k_channels, -1).transpose(2, 3)
|
||||
@@ -304,7 +320,7 @@ class FFN(nn.Module):
|
||||
filter_channels,
|
||||
kernel_size,
|
||||
p_dropout=0.0,
|
||||
activation=None,
|
||||
activation="",
|
||||
causal=False,
|
||||
):
|
||||
super().__init__()
|
||||
@@ -316,10 +332,11 @@ class FFN(nn.Module):
|
||||
self.activation = activation
|
||||
self.causal = causal
|
||||
|
||||
if causal:
|
||||
self.padding = self._causal_padding
|
||||
else:
|
||||
self.padding = self._same_padding
|
||||
# 从上下文看这里一定是 False
|
||||
# if causal:
|
||||
# self.padding = self._causal_padding
|
||||
# else:
|
||||
# self.padding = self._same_padding
|
||||
|
||||
self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
|
||||
self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
|
||||
@@ -334,6 +351,9 @@ class FFN(nn.Module):
|
||||
x = self.drop(x)
|
||||
x = self.conv_2(self.padding(x * x_mask))
|
||||
return x * x_mask
|
||||
|
||||
def padding(self, x):
|
||||
return self._same_padding(x)
|
||||
|
||||
def _causal_padding(self, x):
|
||||
if self.kernel_size == 1:
|
||||
@@ -352,3 +372,35 @@ class FFN(nn.Module):
|
||||
padding = [[0, 0], [0, 0], [pad_l, pad_r]]
|
||||
x = F.pad(x, commons.convert_pad_shape(padding))
|
||||
return x
|
||||
|
||||
|
||||
class MRTE(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
content_enc_channels=192,
|
||||
hidden_size=512,
|
||||
out_channels=192,
|
||||
kernel_size=5,
|
||||
n_heads=4,
|
||||
ge_layer=2,
|
||||
):
|
||||
super(MRTE, self).__init__()
|
||||
self.cross_attention = MultiHeadAttention(hidden_size, hidden_size, n_heads)
|
||||
self.c_pre = nn.Conv1d(content_enc_channels, hidden_size, 1)
|
||||
self.text_pre = nn.Conv1d(content_enc_channels, hidden_size, 1)
|
||||
self.c_post = nn.Conv1d(hidden_size, out_channels, 1)
|
||||
|
||||
def forward(self, ssl_enc, ssl_mask, text, text_mask, ge):
|
||||
attn_mask = text_mask.unsqueeze(2) * ssl_mask.unsqueeze(-1)
|
||||
|
||||
ssl_enc = self.c_pre(ssl_enc * ssl_mask)
|
||||
text_enc = self.text_pre(text * text_mask)
|
||||
x = (
|
||||
self.cross_attention(
|
||||
ssl_enc * ssl_mask, text_enc * text_mask, attn_mask
|
||||
)
|
||||
+ ssl_enc
|
||||
+ ge
|
||||
)
|
||||
x = self.c_post(x * ssl_mask)
|
||||
return x
|
||||
|
||||
@@ -13,10 +13,10 @@ def get_padding(kernel_size, dilation=1):
|
||||
return int((kernel_size * dilation - dilation) / 2)
|
||||
|
||||
|
||||
def convert_pad_shape(pad_shape):
|
||||
l = pad_shape[::-1]
|
||||
pad_shape = [item for sublist in l for item in sublist]
|
||||
return pad_shape
|
||||
# def convert_pad_shape(pad_shape):
|
||||
# l = pad_shape[::-1]
|
||||
# pad_shape = [item for sublist in l for item in sublist]
|
||||
# return pad_shape
|
||||
|
||||
|
||||
def intersperse(lst, item):
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import copy
|
||||
import math
|
||||
from typing import Optional
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
@@ -11,7 +12,6 @@ from module import attentions_onnx as attentions
|
||||
from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
|
||||
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
|
||||
from module.commons import init_weights, get_padding
|
||||
from module.mrte_model import MRTE
|
||||
from module.quantize import ResidualVectorQuantizer
|
||||
# from text import symbols
|
||||
from text import symbols as symbols_v1
|
||||
@@ -218,7 +218,7 @@ class TextEncoder(nn.Module):
|
||||
symbols = symbols_v2.symbols
|
||||
self.text_embedding = nn.Embedding(len(symbols), hidden_channels)
|
||||
|
||||
self.mrte = MRTE()
|
||||
self.mrte = attentions.MRTE()
|
||||
|
||||
self.encoder2 = attentions.Encoder(
|
||||
hidden_channels,
|
||||
@@ -249,25 +249,6 @@ class TextEncoder(nn.Module):
|
||||
m, logs = torch.split(stats, self.out_channels, dim=1)
|
||||
return y, m, logs, y_mask
|
||||
|
||||
def extract_latent(self, x):
|
||||
x = self.ssl_proj(x)
|
||||
quantized, codes, commit_loss, quantized_list = self.quantizer(x)
|
||||
return codes.transpose(0, 1)
|
||||
|
||||
def decode_latent(self, codes, y_mask, refer, refer_mask, ge):
|
||||
quantized = self.quantizer.decode(codes)
|
||||
|
||||
y = self.vq_proj(quantized) * y_mask
|
||||
y = self.encoder_ssl(y * y_mask, y_mask)
|
||||
|
||||
y = self.mrte(y, y_mask, refer, refer_mask, ge)
|
||||
|
||||
y = self.encoder2(y * y_mask, y_mask)
|
||||
|
||||
stats = self.proj(y) * y_mask
|
||||
m, logs = torch.split(stats, self.out_channels, dim=1)
|
||||
return y, m, logs, y_mask, quantized
|
||||
|
||||
|
||||
class ResidualCouplingBlock(nn.Module):
|
||||
def __init__(
|
||||
@@ -448,7 +429,7 @@ class Generator(torch.nn.Module):
|
||||
if gin_channels != 0:
|
||||
self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
|
||||
|
||||
def forward(self, x, g=None):
|
||||
def forward(self, x, g:Optional[torch.Tensor]=None):
|
||||
x = self.conv_pre(x)
|
||||
if g is not None:
|
||||
x = x + self.cond(g)
|
||||
@@ -870,15 +851,15 @@ class SynthesizerTrn(nn.Module):
|
||||
upsample_kernel_sizes,
|
||||
gin_channels=gin_channels,
|
||||
)
|
||||
self.enc_q = PosteriorEncoder(
|
||||
spec_channels,
|
||||
inter_channels,
|
||||
hidden_channels,
|
||||
5,
|
||||
1,
|
||||
16,
|
||||
gin_channels=gin_channels,
|
||||
)
|
||||
# self.enc_q = PosteriorEncoder(
|
||||
# spec_channels,
|
||||
# inter_channels,
|
||||
# hidden_channels,
|
||||
# 5,
|
||||
# 1,
|
||||
# 16,
|
||||
# gin_channels=gin_channels,
|
||||
# )
|
||||
self.flow = ResidualCouplingBlock(
|
||||
inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user