添加导出成 TorchScript 的脚本用于支持python以外的语言 (#1640)

* Fix onnx_export to support v2

* Delete some unused code & add argument type annotations for the TorchScript export

* Add export_torch_script.py

* (export_torch_script.py) 整合 vits 和 t2s 成一个 model 导出

* 恢复 `t2s_model.py` 把改动移到 `export_torch_script.py`
This commit is contained in:
zzz
2024-09-29 17:28:02 +08:00
committed by GitHub
parent 78c68d46cb
commit 5efb960898
4 changed files with 825 additions and 55 deletions

View File

@@ -1,5 +1,6 @@
import copy
import math
from typing import Optional
import torch
from torch import nn
from torch.nn import functional as F
@@ -11,7 +12,6 @@ from module import attentions_onnx as attentions
from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
from module.commons import init_weights, get_padding
from module.mrte_model import MRTE
from module.quantize import ResidualVectorQuantizer
# from text import symbols
from text import symbols as symbols_v1
@@ -218,7 +218,7 @@ class TextEncoder(nn.Module):
symbols = symbols_v2.symbols
self.text_embedding = nn.Embedding(len(symbols), hidden_channels)
self.mrte = MRTE()
self.mrte = attentions.MRTE()
self.encoder2 = attentions.Encoder(
hidden_channels,
@@ -249,25 +249,6 @@ class TextEncoder(nn.Module):
m, logs = torch.split(stats, self.out_channels, dim=1)
return y, m, logs, y_mask
def extract_latent(self, x):
x = self.ssl_proj(x)
quantized, codes, commit_loss, quantized_list = self.quantizer(x)
return codes.transpose(0, 1)
def decode_latent(self, codes, y_mask, refer, refer_mask, ge):
quantized = self.quantizer.decode(codes)
y = self.vq_proj(quantized) * y_mask
y = self.encoder_ssl(y * y_mask, y_mask)
y = self.mrte(y, y_mask, refer, refer_mask, ge)
y = self.encoder2(y * y_mask, y_mask)
stats = self.proj(y) * y_mask
m, logs = torch.split(stats, self.out_channels, dim=1)
return y, m, logs, y_mask, quantized
class ResidualCouplingBlock(nn.Module):
def __init__(
@@ -448,7 +429,7 @@ class Generator(torch.nn.Module):
if gin_channels != 0:
self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
def forward(self, x, g=None):
def forward(self, x, g:Optional[torch.Tensor]=None):
x = self.conv_pre(x)
if g is not None:
x = x + self.cond(g)
@@ -870,15 +851,15 @@ class SynthesizerTrn(nn.Module):
upsample_kernel_sizes,
gin_channels=gin_channels,
)
self.enc_q = PosteriorEncoder(
spec_channels,
inter_channels,
hidden_channels,
5,
1,
16,
gin_channels=gin_channels,
)
# self.enc_q = PosteriorEncoder(
# spec_channels,
# inter_channels,
# hidden_channels,
# 5,
# 1,
# 16,
# gin_channels=gin_channels,
# )
self.flow = ResidualCouplingBlock(
inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels
)