update Gradient Checkpointing to reduce VRAM usage (#2040)

* update Gradient Checkpointing to reduce VRAM usage

* fix inference
Author: KakaruHayate
Date: 2025-02-12 23:00:34 +08:00
Committed by: GitHub
Parent: 86acb7a89d
Commit: c2b3298bed
5 changed files with 33 additions and 16 deletions


@@ -304,7 +304,7 @@ def train_and_evaluate(
     text, text_lengths = text.to(device), text_lengths.to(device)
     with autocast(enabled=hps.train.fp16_run):
-        cfm_loss = net_g(ssl, spec, mel, ssl_lengths, spec_lengths, text, text_lengths, mel_lengths)
+        cfm_loss = net_g(ssl, spec, mel, ssl_lengths, spec_lengths, text, text_lengths, mel_lengths, use_grad_ckpt=hps.train.grad_ckpt)
     loss_gen_all = cfm_loss
     optim_g.zero_grad()
     scaler.scale(loss_gen_all).backward()
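
For context on the change: gradient checkpointing trades compute for memory by discarding intermediate activations during the forward pass and recomputing them during backward, which lowers peak VRAM at the cost of extra forward work. Below is a minimal sketch of how a use_grad_ckpt flag like the one threaded through net_g above might gate torch.utils.checkpoint inside a module's forward; CFMBlockStack and its layer sizes are hypothetical stand-ins, not this repository's classes.

    import torch
    import torch.nn as nn
    from torch.utils.checkpoint import checkpoint

    class CFMBlockStack(nn.Module):
        # Hypothetical stand-in for the block stack inside net_g.
        def __init__(self, dim: int = 256, n_layers: int = 4):
            super().__init__()
            self.blocks = nn.ModuleList(
                nn.Sequential(nn.Linear(dim, dim), nn.GELU()) for _ in range(n_layers)
            )

        def forward(self, x: torch.Tensor, use_grad_ckpt: bool = False) -> torch.Tensor:
            for block in self.blocks:
                if use_grad_ckpt and self.training:
                    # Drop this block's activations and recompute them in backward,
                    # trading extra compute for lower peak VRAM.
                    x = checkpoint(block, x, use_reentrant=False)
                else:
                    x = block(x)
            return x

Gating on self.training keeps checkpointing out of the inference path, where there is no backward pass to recompute for; that is consistent with the "fix inference" line in the commit message, though the repository's actual fix may differ.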