添加导出成 TorchScript 的脚本用于支持python以外的语言 (#1640)

* Fix onnx_export to support v2

* delete some useless code & add some args type for export torch-script

* Add export_torch_script.py

* (export_torch_script.py) 整合 vits 和 t2s 成一个 model 导出

* 恢复 `t2s_model.py` 把改动移到 `export_torch_script.py`
This commit is contained in:
zzz
2024-09-29 17:28:02 +08:00
committed by GitHub
parent 78c68d46cb
commit 5efb960898
4 changed files with 825 additions and 55 deletions

View File

@@ -4,8 +4,8 @@ from torch import nn
from torch.nn import functional as F
from module import commons
from module.modules import LayerNorm
from typing import Optional
class LayerNorm(nn.Module):
def __init__(self, channels, eps=1e-5):
@@ -59,6 +59,7 @@ class Encoder(nn.Module):
# self.cond_layer = weight_norm(cond_layer, name='weight')
# self.gin_channels = 256
self.cond_layer_idx = self.n_layers
self.spk_emb_linear = nn.Linear(256, self.hidden_channels)
if "gin_channels" in kwargs:
self.gin_channels = kwargs["gin_channels"]
if self.gin_channels != 0:
@@ -98,22 +99,36 @@ class Encoder(nn.Module):
)
self.norm_layers_2.append(LayerNorm(hidden_channels))
def forward(self, x, x_mask, g=None):
# def forward(self, x, x_mask, g=None):
# attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
# x = x * x_mask
# for i in range(self.n_layers):
# if i == self.cond_layer_idx and g is not None:
# g = self.spk_emb_linear(g.transpose(1, 2))
# g = g.transpose(1, 2)
# x = x + g
# x = x * x_mask
# y = self.attn_layers[i](x, x, attn_mask)
# y = self.drop(y)
# x = self.norm_layers_1[i](x + y)
# y = self.ffn_layers[i](x, x_mask)
# y = self.drop(y)
# x = self.norm_layers_2[i](x + y)
# x = x * x_mask
# return x
def forward(self, x, x_mask):
attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
x = x * x_mask
for i in range(self.n_layers):
if i == self.cond_layer_idx and g is not None:
g = self.spk_emb_linear(g.transpose(1, 2))
g = g.transpose(1, 2)
x = x + g
x = x * x_mask
y = self.attn_layers[i](x, x, attn_mask)
for attn_layers,norm_layers_1,ffn_layers,norm_layers_2 in zip(self.attn_layers,self.norm_layers_1,self.ffn_layers,self.norm_layers_2):
y = attn_layers(x, x, attn_mask)
y = self.drop(y)
x = self.norm_layers_1[i](x + y)
x = norm_layers_1(x + y)
y = self.ffn_layers[i](x, x_mask)
y = ffn_layers(x, x_mask)
y = self.drop(y)
x = self.norm_layers_2[i](x + y)
x = norm_layers_2(x + y)
x = x * x_mask
return x
@@ -172,17 +187,18 @@ class MultiHeadAttention(nn.Module):
self.conv_k.weight.copy_(self.conv_q.weight)
self.conv_k.bias.copy_(self.conv_q.bias)
def forward(self, x, c, attn_mask=None):
def forward(self, x, c, attn_mask:Optional[torch.Tensor]=None):
q = self.conv_q(x)
k = self.conv_k(c)
v = self.conv_v(c)
x, self.attn = self.attention(q, k, v, mask=attn_mask)
# x, self.attn = self.attention(q, k, v, mask=attn_mask)
x, _ = self.attention(q, k, v, mask=attn_mask)
x = self.conv_o(x)
return x
def attention(self, query, key, value, mask=None):
def attention(self, query, key, value, mask:Optional[torch.Tensor]=None):
# reshape [b, d, t] -> [b, n_h, t, d_k]
b, d, t_s, _ = (*key.size(), query.size(2))
query = query.view(b, self.n_heads, self.k_channels, -1).transpose(2, 3)
@@ -304,7 +320,7 @@ class FFN(nn.Module):
filter_channels,
kernel_size,
p_dropout=0.0,
activation=None,
activation="",
causal=False,
):
super().__init__()
@@ -316,10 +332,11 @@ class FFN(nn.Module):
self.activation = activation
self.causal = causal
if causal:
self.padding = self._causal_padding
else:
self.padding = self._same_padding
# 从上下文看这里一定是 False
# if causal:
# self.padding = self._causal_padding
# else:
# self.padding = self._same_padding
self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
@@ -334,6 +351,9 @@ class FFN(nn.Module):
x = self.drop(x)
x = self.conv_2(self.padding(x * x_mask))
return x * x_mask
def padding(self, x):
return self._same_padding(x)
def _causal_padding(self, x):
if self.kernel_size == 1:
@@ -352,3 +372,35 @@ class FFN(nn.Module):
padding = [[0, 0], [0, 0], [pad_l, pad_r]]
x = F.pad(x, commons.convert_pad_shape(padding))
return x
class MRTE(nn.Module):
def __init__(
self,
content_enc_channels=192,
hidden_size=512,
out_channels=192,
kernel_size=5,
n_heads=4,
ge_layer=2,
):
super(MRTE, self).__init__()
self.cross_attention = MultiHeadAttention(hidden_size, hidden_size, n_heads)
self.c_pre = nn.Conv1d(content_enc_channels, hidden_size, 1)
self.text_pre = nn.Conv1d(content_enc_channels, hidden_size, 1)
self.c_post = nn.Conv1d(hidden_size, out_channels, 1)
def forward(self, ssl_enc, ssl_mask, text, text_mask, ge):
attn_mask = text_mask.unsqueeze(2) * ssl_mask.unsqueeze(-1)
ssl_enc = self.c_pre(ssl_enc * ssl_mask)
text_enc = self.text_pre(text * text_mask)
x = (
self.cross_attention(
ssl_enc * ssl_mask, text_enc * text_mask, attn_mask
)
+ ssl_enc
+ ge
)
x = self.c_post(x * ssl_mask)
return x

View File

@@ -13,10 +13,10 @@ def get_padding(kernel_size, dilation=1):
return int((kernel_size * dilation - dilation) / 2)
def convert_pad_shape(pad_shape):
l = pad_shape[::-1]
pad_shape = [item for sublist in l for item in sublist]
return pad_shape
# def convert_pad_shape(pad_shape):
# l = pad_shape[::-1]
# pad_shape = [item for sublist in l for item in sublist]
# return pad_shape
def intersperse(lst, item):

View File

@@ -1,5 +1,6 @@
import copy
import math
from typing import Optional
import torch
from torch import nn
from torch.nn import functional as F
@@ -11,7 +12,6 @@ from module import attentions_onnx as attentions
from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
from module.commons import init_weights, get_padding
from module.mrte_model import MRTE
from module.quantize import ResidualVectorQuantizer
# from text import symbols
from text import symbols as symbols_v1
@@ -218,7 +218,7 @@ class TextEncoder(nn.Module):
symbols = symbols_v2.symbols
self.text_embedding = nn.Embedding(len(symbols), hidden_channels)
self.mrte = MRTE()
self.mrte = attentions.MRTE()
self.encoder2 = attentions.Encoder(
hidden_channels,
@@ -249,25 +249,6 @@ class TextEncoder(nn.Module):
m, logs = torch.split(stats, self.out_channels, dim=1)
return y, m, logs, y_mask
def extract_latent(self, x):
x = self.ssl_proj(x)
quantized, codes, commit_loss, quantized_list = self.quantizer(x)
return codes.transpose(0, 1)
def decode_latent(self, codes, y_mask, refer, refer_mask, ge):
quantized = self.quantizer.decode(codes)
y = self.vq_proj(quantized) * y_mask
y = self.encoder_ssl(y * y_mask, y_mask)
y = self.mrte(y, y_mask, refer, refer_mask, ge)
y = self.encoder2(y * y_mask, y_mask)
stats = self.proj(y) * y_mask
m, logs = torch.split(stats, self.out_channels, dim=1)
return y, m, logs, y_mask, quantized
class ResidualCouplingBlock(nn.Module):
def __init__(
@@ -448,7 +429,7 @@ class Generator(torch.nn.Module):
if gin_channels != 0:
self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
def forward(self, x, g=None):
def forward(self, x, g:Optional[torch.Tensor]=None):
x = self.conv_pre(x)
if g is not None:
x = x + self.cond(g)
@@ -870,15 +851,15 @@ class SynthesizerTrn(nn.Module):
upsample_kernel_sizes,
gin_channels=gin_channels,
)
self.enc_q = PosteriorEncoder(
spec_channels,
inter_channels,
hidden_channels,
5,
1,
16,
gin_channels=gin_channels,
)
# self.enc_q = PosteriorEncoder(
# spec_channels,
# inter_channels,
# hidden_channels,
# 5,
# 1,
# 16,
# gin_channels=gin_channels,
# )
self.flow = ResidualCouplingBlock(
inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels
)