update Gradient Checkpointing to reduce VRAM usage (#2040)
* update Gradient Checkpointing to reduce VRAM usage
* fix inference
@@ -304,7 +304,7 @@ def train_and_evaluate(
         text, text_lengths = text.to(device), text_lengths.to(device)

         with autocast(enabled=hps.train.fp16_run):
-            cfm_loss = net_g(ssl, spec, mel, ssl_lengths, spec_lengths, text, text_lengths, mel_lengths)
+            cfm_loss = net_g(ssl, spec, mel, ssl_lengths, spec_lengths, text, text_lengths, mel_lengths, use_grad_ckpt=hps.train.grad_ckpt)
             loss_gen_all = cfm_loss
         optim_g.zero_grad()
         scaler.scale(loss_gen_all).backward()
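The new use_grad_ckpt keyword lets the training loop toggle activation checkpointing from the hps.train.grad_ckpt config flag. As a rough illustration, below is a minimal sketch of how such a flag is commonly consumed inside a model's forward pass using torch.utils.checkpoint. The CFMDecoder class and its self.blocks layers are hypothetical stand-ins, not code from this commit; only the use_grad_ckpt keyword mirrors the diff above.

    import torch
    import torch.nn as nn
    from torch.utils.checkpoint import checkpoint

    class CFMDecoder(nn.Module):
        # Hypothetical stand-in for the decoder inside net_g.
        def __init__(self, dim: int = 256, n_layers: int = 4):
            super().__init__()
            self.blocks = nn.ModuleList(
                nn.TransformerEncoderLayer(d_model=dim, nhead=4, batch_first=True)
                for _ in range(n_layers)
            )

        def forward(self, x: torch.Tensor, use_grad_ckpt: bool = False) -> torch.Tensor:
            for block in self.blocks:
                if use_grad_ckpt and self.training:
                    # Recompute this block's activations during backward instead
                    # of caching them in forward: lower peak VRAM, extra compute.
                    x = checkpoint(block, x, use_reentrant=False)
                else:
                    x = block(x)
            return x

With checkpointing enabled, each block's intermediate activations are discarded after the forward pass and recomputed during backward, which reduces peak VRAM at the cost of roughly one extra forward pass per checkpointed block. Gating on self.training also keeps inference unaffected, consistent with the "fix inference" note in the commit message.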