update Gradient Checkpointing to reduce VRAM usage (#2040)

* update Gradient Checkpointing to reduce VRAM usage

* fix inference
Author: KakaruHayate (committed by GitHub)
Date: 2025-02-12 23:00:34 +08:00
Parent: 86acb7a89d
Commit: c2b3298bed
5 changed files with 33 additions and 16 deletions
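Gradient checkpointing drops intermediate activations during the forward pass and recomputes them during backward, trading roughly one extra forward pass per checkpointed block for a large cut in activation VRAM. A minimal standalone sketch of the torch.utils.checkpoint pattern this commit adopts (toy layer sizes, not the repo's DiT):

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

# toy block stack standing in for the transformer blocks (sizes are made up)
blocks = nn.ModuleList(
    [nn.Sequential(nn.Linear(512, 2048), nn.GELU(), nn.Linear(2048, 512)) for _ in range(4)]
)

x = torch.randn(8, 128, 512, requires_grad=True)
for block in blocks:
    # activations inside each block are recomputed in backward, not stored
    x = checkpoint(block, x, use_reentrant=False)
x.sum().backward()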

@@ -12,6 +12,7 @@ from __future__ import annotations
 import torch
 from torch import nn
 import torch.nn.functional as F
+from torch.utils.checkpoint import checkpoint
 from x_transformers.x_transformers import RotaryEmbedding
@@ -121,6 +122,14 @@ class DiT(nn.Module):
         self.norm_out = AdaLayerNormZero_Final(dim)  # final modulation
         self.proj_out = nn.Linear(dim, mel_dim)
 
+    def ckpt_wrapper(self, module):
+        # https://github.com/chuanyangjin/fast-DiT/blob/main/models.py
+        def ckpt_forward(*inputs):
+            outputs = module(*inputs)
+            return outputs
+
+        return ckpt_forward
+
     def forward(  # x, prompt_x, x_lens, t, style, cond
         self,  # d is channel, n is T
         x0: float["b n d"],  # noised input audio  # noqa: F722
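A note on ckpt_wrapper: checkpoint forwards its extra positional arguments straight to the callable it is given, so the wrapper's only job is to hand checkpoint a plain function closing over the block. The consequence, visible in the last hunk below, is that mask and rope are then passed positionally; a sketch of the equivalence, assuming block.forward declares them in that order:

# assuming block.forward(x, t, mask=None, rope=None): positional order must
# match the keyword call it replaces
y_plain = block(x, t, mask=mask, rope=rope)
y_ckpt = checkpoint(self.ckpt_wrapper(block), x, t, mask, rope, use_reentrant=False)
# y_plain and y_ckpt are identical; only backward-pass memory/compute differ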
@@ -129,11 +138,12 @@ class DiT(nn.Module):
         time: float["b"] | float[""],  # time step  # noqa: F821 F722
         dt_base_bootstrap,
         text0,  # : int["b nt"]  # noqa: F722  # condition feature
+        use_grad_ckpt,  # bool
         ### no-use
         drop_audio_cond=False,  # cfg for cond audio
         drop_text=False,  # cfg for text
         # mask: bool["b n"] | None = None,  # noqa: F722
     ):
         x = x0.transpose(2, 1)
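use_grad_ckpt is added without a default value, so every caller of DiT.forward must now pass it explicitly; the "fix inference" part of the commit message points at exactly that kind of call-site breakage. A self-contained toy with the same branching (TinyStack is hypothetical, not from this repo):

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

class TinyStack(nn.Module):
    # stand-in for DiT with the same use_grad_ckpt branch as the final hunk
    def __init__(self, dim=64, depth=3):
        super().__init__()
        self.blocks = nn.ModuleList([nn.Linear(dim, dim) for _ in range(depth)])

    def forward(self, x, use_grad_ckpt):
        for block in self.blocks:
            if use_grad_ckpt:
                x = checkpoint(block, x, use_reentrant=False)
            else:
                x = block(x)
        return x

model = TinyStack()
x = torch.randn(4, 64, requires_grad=True)
model(x, use_grad_ckpt=True).sum().backward()  # training path: save VRAM
with torch.no_grad():
    model(x, use_grad_ckpt=False)  # inference path: no backward, so flag stays off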
@@ -158,7 +168,10 @@ class DiT(nn.Module):
         residual = x
 
         for block in self.transformer_blocks:
-            x = block(x, t, mask=mask, rope=rope)
+            if use_grad_ckpt:
+                x = checkpoint(self.ckpt_wrapper(block), x, t, mask, rope, use_reentrant=False)
+            else:
+                x = block(x, t, mask=mask, rope=rope)
 
         if self.long_skip_connection is not None:
             x = self.long_skip_connection(torch.cat((x, residual), dim=-1))
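To verify the VRAM reduction the commit title claims, peak allocation can be compared with the flag on and off. A rough CUDA-only sketch under assumed toy sizes (peak_mib is a hypothetical helper, not repo code):

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

def peak_mib(use_ckpt, depth=6, dim=1024, batch=4096):
    # deep MLP stack so activations, not weights, dominate the memory footprint
    blocks = nn.ModuleList(
        [nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))
         for _ in range(depth)]
    ).cuda()
    x = torch.randn(batch, dim, device="cuda", requires_grad=True)
    torch.cuda.reset_peak_memory_stats()
    for block in blocks:
        x = checkpoint(block, x, use_reentrant=False) if use_ckpt else block(x)
    x.sum().backward()
    return torch.cuda.max_memory_allocated() / 2**20

if torch.cuda.is_available():
    print(f"checkpointed: {peak_mib(True):.0f} MiB  plain: {peak_mib(False):.0f} MiB")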