Sync with main branch (#1134)
* api接口,修复文本切分符号设定中,中文分号错写为英文分号的问题 (#1001) * 一些小问题修复 (#1021) * fix import error. It may happen when calling api.py * Update README.md * Update gpt-sovits_kaggle.ipynb * Update gpt-sovits_kaggle.ipynb * fix path error delete useless line wraps * 删除重复的 COPY 指令 (#1073) * [优化] 1Aa-文本获取 (#1102) * Filter unsupported languages * add feedback * simplify modification * fix detail * Update english.py (#1106) copy but not ref the phones list becoz it will be extend later, if not do so,it will affect the self.cmu dict values. * Update models.py * modify freeze_quantizer mode, avoid quantizer's codebook updating (#953) --------- Co-authored-by: FengQingYunDan <pingdengjia0liu@163.com> Co-authored-by: Kenn Zhang <breakstring@hotmail.com> Co-authored-by: 蓝梦实 <36986837+SapphireLab@users.noreply.github.com> Co-authored-by: lyris <lyris@users.noreply.github.com> Co-authored-by: hcwu1993 <15855138469@163.com>
This commit is contained in:
@@ -16,6 +16,7 @@ from module.mrte_model import MRTE
|
||||
from module.quantize import ResidualVectorQuantizer
|
||||
from text import symbols
|
||||
from torch.cuda.amp import autocast
|
||||
import contextlib
|
||||
|
||||
|
||||
class StochasticDurationPredictor(nn.Module):
|
||||
@@ -891,9 +892,10 @@ class SynthesizerTrn(nn.Module):
|
||||
self.ssl_proj = nn.Conv1d(ssl_dim, ssl_dim, 1, stride=1)
|
||||
|
||||
self.quantizer = ResidualVectorQuantizer(dimension=ssl_dim, n_q=1, bins=1024)
|
||||
if freeze_quantizer:
|
||||
self.ssl_proj.requires_grad_(False)
|
||||
self.quantizer.requires_grad_(False)
|
||||
self.freeze_quantizer = freeze_quantizer
|
||||
# if freeze_quantizer:
|
||||
# self.ssl_proj.requires_grad_(False)
|
||||
# self.quantizer.requires_grad_(False)
|
||||
#self.quantizer.eval()
|
||||
# self.enc_p.text_embedding.requires_grad_(False)
|
||||
# self.enc_p.encoder_text.requires_grad_(False)
|
||||
@@ -906,6 +908,11 @@ class SynthesizerTrn(nn.Module):
|
||||
ge = self.ref_enc(y * y_mask, y_mask)
|
||||
|
||||
with autocast(enabled=False):
|
||||
maybe_no_grad = torch.no_grad() if self.freeze_quantizer else contextlib.nullcontext()
|
||||
with maybe_no_grad:
|
||||
if self.freeze_quantizer:
|
||||
self.ssl_proj.eval()
|
||||
self.quantizer.eval()
|
||||
ssl = self.ssl_proj(ssl)
|
||||
quantized, codes, commit_loss, quantized_list = self.quantizer(
|
||||
ssl, layers=[0]
|
||||
|
||||
Reference in New Issue
Block a user