update Gradient Checkpointing to reduce VRAM usage (#2040)
* Update gradient checkpointing to reduce VRAM usage
* Fix inference
This commit is contained in:
@@ -18,7 +18,8 @@
|
||||
"warmup_epochs": 0,
|
||||
"c_mel": 45,
|
||||
"c_kl": 1.0,
|
||||
"text_low_lr_rate": 0.4
|
||||
"text_low_lr_rate": 0.4,
|
||||
"grad_ckpt": false
|
||||
},
|
||||
"data": {
|
||||
"max_wav_value": 32768.0,
|
||||
|
||||
@@ -12,6 +12,7 @@ from __future__ import annotations
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
|
||||
from x_transformers.x_transformers import RotaryEmbedding
|
||||
|
||||
@@ -121,6 +122,14 @@ class DiT(nn.Module):
|
||||
self.norm_out = AdaLayerNormZero_Final(dim) # final modulation
|
||||
self.proj_out = nn.Linear(dim, mel_dim)
|
||||
|
||||
def ckpt_wrapper(self, module):
    """Wrap ``module`` in a plain function that forwards positional arguments.

    ``forward`` passes the returned callable to
    ``torch.utils.checkpoint.checkpoint`` once per transformer block when
    gradient checkpointing is enabled. Pattern adapted from
    https://github.com/chuanyangjin/fast-DiT/blob/main/models.py

    Args:
        module: Any callable; in this file, a transformer block.

    Returns:
        A function ``ckpt_forward(*inputs)`` that returns
        ``module(*inputs)`` unchanged.
    """

    def ckpt_forward(*inputs):
        # Pure pass-through: every argument is forwarded positionally and
        # the module's result is returned as-is.
        return module(*inputs)

    return ckpt_forward
|
||||
|
||||
def forward(#x, prompt_x, x_lens, t, style,cond
|
||||
self,#d is channel,n is T
|
||||
x0: float["b n d"], # nosied input audio # noqa: F722
|
||||
@@ -129,11 +138,12 @@ class DiT(nn.Module):
|
||||
time: float["b"] | float[""], # time step # noqa: F821 F722
|
||||
dt_base_bootstrap,
|
||||
text0, # : int["b nt"] # noqa: F722#####condition feature
|
||||
|
||||
use_grad_ckpt, # bool
|
||||
###no-use
|
||||
drop_audio_cond=False, # cfg for cond audio
|
||||
drop_text=False, # cfg for text
|
||||
# mask: bool["b n"] | None = None, # noqa: F722
|
||||
|
||||
):
|
||||
|
||||
x=x0.transpose(2,1)
|
||||
@@ -158,7 +168,10 @@ class DiT(nn.Module):
|
||||
residual = x
|
||||
|
||||
for block in self.transformer_blocks:
|
||||
x = block(x, t, mask=mask, rope=rope)
|
||||
if use_grad_ckpt:
|
||||
x = checkpoint(self.ckpt_wrapper(block), x, t, mask, rope, use_reentrant=False)
|
||||
else:
|
||||
x = block(x, t, mask=mask, rope=rope)
|
||||
|
||||
if self.long_skip_connection is not None:
|
||||
x = self.long_skip_connection(torch.cat((x, residual), dim=-1))
|
||||
|
||||
@@ -1089,15 +1089,15 @@ class CFM(torch.nn.Module):
|
||||
t_tensor = torch.ones(x.shape[0], device=x.device,dtype=mu.dtype) * t
|
||||
d_tensor = torch.ones(x.shape[0], device=x.device,dtype=mu.dtype) * d
|
||||
# v_pred = model(x, t_tensor, d_tensor, **extra_args)
|
||||
v_pred = self.estimator(x, prompt_x, x_lens, t_tensor,d_tensor, mu,drop_audio_cond=False,drop_text=False).transpose(2, 1)
|
||||
v_pred = self.estimator(x, prompt_x, x_lens, t_tensor,d_tensor, mu, use_grad_ckpt=False,drop_audio_cond=False,drop_text=False).transpose(2, 1)
|
||||
if inference_cfg_rate>1e-5:
|
||||
neg = self.estimator(x, prompt_x, x_lens, t_tensor, d_tensor, mu, drop_audio_cond=True, drop_text=True).transpose(2, 1)
|
||||
neg = self.estimator(x, prompt_x, x_lens, t_tensor, d_tensor, mu, use_grad_ckpt=False, drop_audio_cond=True, drop_text=True).transpose(2, 1)
|
||||
v_pred=v_pred+(v_pred-neg)*inference_cfg_rate
|
||||
x = x + d * v_pred
|
||||
t = t + d
|
||||
x[:, :, :prompt_len] = 0
|
||||
return x
|
||||
def forward(self, x1, x_lens, prompt_lens, mu):
|
||||
def forward(self, x1, x_lens, prompt_lens, mu, use_grad_ckpt):
|
||||
b, _, t = x1.shape
|
||||
|
||||
# random timestep
|
||||
@@ -1117,16 +1117,16 @@ class CFM(torch.nn.Module):
|
||||
d_input = d.clone()
|
||||
d_input[d_input < 1e-2] = 0
|
||||
# with torch.no_grad():
|
||||
v_pred_1 = self.estimator(xt, prompt, x_lens, t, d_input, mu).transpose(2, 1).detach()
|
||||
v_pred_1 = self.estimator(xt, prompt, x_lens, t, d_input, mu, use_grad_ckpt).transpose(2, 1).detach()
|
||||
# v_pred_1 = self.diffusion(xt, t, d_input, cond=conditioning).detach()
|
||||
x_mid = xt + d[:, None, None] * v_pred_1
|
||||
# v_pred_2 = self.diffusion(x_mid, t+d, d_input, cond=conditioning).detach()
|
||||
v_pred_2 = self.estimator(x_mid, prompt, x_lens, t+d, d_input, mu).transpose(2, 1).detach()
|
||||
v_pred_2 = self.estimator(x_mid, prompt, x_lens, t+d, d_input, mu, use_grad_ckpt).transpose(2, 1).detach()
|
||||
vt = (v_pred_1 + v_pred_2) / 2
|
||||
vt = vt.detach()
|
||||
dt = 2*d
|
||||
|
||||
vt_pred = self.estimator(xt, prompt, x_lens, t,dt, mu).transpose(2,1)
|
||||
vt_pred = self.estimator(xt, prompt, x_lens, t,dt, mu, use_grad_ckpt).transpose(2,1)
|
||||
loss = 0
|
||||
|
||||
# print(45555555,estimator_out.shape,u.shape,x_lens,prompt_lens)#45555555 torch.Size([7, 465, 100]) torch.Size([7, 100, 465]) tensor([461, 461, 451, 451, 442, 442, 442], device='cuda:0') tensor([ 96, 93, 185, 59, 244, 262, 294], device='cuda:0')
|
||||
@@ -1220,7 +1220,7 @@ class SynthesizerTrnV3(nn.Module):
|
||||
self.linear_mel=nn.Conv1d(inter_channels2,100,1,stride=1)
|
||||
self.cfm = CFM(100,DiT(**dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=inter_channels2, conv_layers=4)),)#text_dim is condition feature dim
|
||||
|
||||
def forward(self, ssl, y, mel,ssl_lengths,y_lengths, text, text_lengths,mel_lengths):#ssl_lengths no need now
|
||||
def forward(self, ssl, y, mel,ssl_lengths,y_lengths, text, text_lengths,mel_lengths, use_grad_ckpt):#ssl_lengths no need now
|
||||
with autocast(enabled=False):
|
||||
y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(y.dtype)
|
||||
ge = self.ref_enc(y[:,:704] * y_mask, y_mask)
|
||||
@@ -1245,7 +1245,7 @@ class SynthesizerTrnV3(nn.Module):
|
||||
minn=min(mel.shape[-1],fea.shape[-1])
|
||||
mel=mel[:,:,:minn]
|
||||
fea=fea[:,:,:minn]
|
||||
cfm_loss= self.cfm(mel, mel_lengths, prompt_len, fea)
|
||||
cfm_loss= self.cfm(mel, mel_lengths, prompt_len, fea, use_grad_ckpt)
|
||||
return cfm_loss
|
||||
|
||||
@torch.no_grad()
|
||||
|
||||
@@ -304,7 +304,7 @@ def train_and_evaluate(
|
||||
text, text_lengths = text.to(device), text_lengths.to(device)
|
||||
|
||||
with autocast(enabled=hps.train.fp16_run):
|
||||
cfm_loss = net_g(ssl, spec, mel,ssl_lengths,spec_lengths, text, text_lengths,mel_lengths)
|
||||
cfm_loss = net_g(ssl, spec, mel,ssl_lengths,spec_lengths, text, text_lengths,mel_lengths, use_grad_ckpt=hps.train.grad_ckpt)
|
||||
loss_gen_all=cfm_loss
|
||||
optim_g.zero_grad()
|
||||
scaler.scale(loss_gen_all).backward()
|
||||
|
||||
Reference in New Issue
Block a user