update Gradient Checkpointing to reduce VRAM usage (#2040)
* update Gradient Checkpointing to reduce VRAM usage
* fix inference
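For context: gradient checkpointing trades compute for memory. Activations inside a checkpointed block are not kept for the backward pass; they are recomputed on the fly, so peak VRAM drops at the cost of one extra forward pass per block. A minimal sketch of the mechanism with torch.utils.checkpoint (illustrative layer and shapes, not code from this commit):

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

# Illustrative block; stands in for one transformer layer.
block = nn.Sequential(nn.Linear(512, 2048), nn.GELU(), nn.Linear(2048, 512))
x = torch.randn(8, 512, requires_grad=True)

y = block(x)                                    # plain call: activations stored for backward
y = checkpoint(block, x, use_reentrant=False)   # checkpointed: activations recomputed during backward
y.sum().backward()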
@@ -12,6 +12,7 @@ from __future__ import annotations
 import torch
 from torch import nn
 import torch.nn.functional as F
+from torch.utils.checkpoint import checkpoint
 
 from x_transformers.x_transformers import RotaryEmbedding
 
@@ -121,6 +122,14 @@ class DiT(nn.Module):
         self.norm_out = AdaLayerNormZero_Final(dim)  # final modulation
         self.proj_out = nn.Linear(dim, mel_dim)
 
+    def ckpt_wrapper(self, module):
+        # https://github.com/chuanyangjin/fast-DiT/blob/main/models.py
+        def ckpt_forward(*inputs):
+            outputs = module(*inputs)
+            return outputs
+
+        return ckpt_forward
+
     def forward(  # x, prompt_x, x_lens, t, style, cond
         self,  # d is channel, n is T
         x0: float["b n d"],  # noised input audio  # noqa: F722
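A note on the wrapper added above: torch.utils.checkpoint.checkpoint forwards its extra arguments to the wrapped callable positionally, which is why the call site later in this diff passes mask and rope positionally rather than as keywords. A self-contained sketch of the same pattern (ToyBlock and its shapes are assumptions, not the DiT block from this repo):

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

class ToyBlock(nn.Module):
    # Stand-in for one DiT transformer block with the same calling convention.
    def forward(self, x, t, mask=None, rope=None):
        return x + t

def ckpt_wrapper(module):
    # Same pattern as the diff: expose the module as a plain function of positional args.
    def ckpt_forward(*inputs):
        return module(*inputs)
    return ckpt_forward

block = ToyBlock()
x = torch.randn(2, 16, 8, requires_grad=True)
t = torch.randn(2, 1, 8)

# checkpoint() passes its extra arguments positionally, so mask and rope
# are supplied positionally (None here), matching the diff's call site.
out = checkpoint(ckpt_wrapper(block), x, t, None, None, use_reentrant=False)
out.sum().backward()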
@@ -129,11 +138,12 @@ class DiT(nn.Module):
         time: float["b"] | float[""],  # time step  # noqa: F821 F722
         dt_base_bootstrap,
         text0,  # : int["b nt"]  # noqa: F722  # condition feature
 
+        use_grad_ckpt,  # bool
         ### no-use
         drop_audio_cond=False,  # cfg for cond audio
         drop_text=False,  # cfg for text
         # mask: bool["b n"] | None = None,  # noqa: F722
     ):
 
         x = x0.transpose(2, 1)
@@ -158,7 +168,10 @@ class DiT(nn.Module):
         residual = x
 
         for block in self.transformer_blocks:
-            x = block(x, t, mask=mask, rope=rope)
+            if use_grad_ckpt:
+                x = checkpoint(self.ckpt_wrapper(block), x, t, mask, rope, use_reentrant=False)
+            else:
+                x = block(x, t, mask=mask, rope=rope)
 
         if self.long_skip_connection is not None:
             x = self.long_skip_connection(torch.cat((x, residual), dim=-1))
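Note that use_grad_ckpt is added to forward() without a default value, so every call site has to pass it explicitly; presumably the "fix inference" follow-up in this commit updates the inference path to pass False, skipping the recomputation machinery there. The commit message claims reduced VRAM usage; a rough way to observe the effect of this kind of per-block checkpointing (illustrative blocks and sizes, not the DiT above; requires a CUDA device):

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

def peak_mem(use_ckpt: bool) -> int:
    torch.cuda.reset_peak_memory_stats()
    blocks = nn.ModuleList(
        nn.Sequential(nn.Linear(1024, 4096), nn.GELU(), nn.Linear(4096, 1024))
        for _ in range(8)
    ).cuda()
    x = torch.randn(64, 256, 1024, device="cuda", requires_grad=True)
    for block in blocks:
        # Mirrors the loop in the diff: checkpoint each block only when requested.
        x = checkpoint(block, x, use_reentrant=False) if use_ckpt else block(x)
    x.sum().backward()
    return torch.cuda.max_memory_allocated()

if torch.cuda.is_available():
    print("plain:       ", peak_mem(False))
    print("checkpointed:", peak_mem(True))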