update Gradient Checkpointing to reduce VRAM usage (#2040)
* update Gradient Checkpointing to reduce VRAM usage
* fix inference
@@ -1089,15 +1089,15 @@ class CFM(torch.nn.Module):
             t_tensor = torch.ones(x.shape[0], device=x.device, dtype=mu.dtype) * t
             d_tensor = torch.ones(x.shape[0], device=x.device, dtype=mu.dtype) * d
             # v_pred = model(x, t_tensor, d_tensor, **extra_args)
-            v_pred = self.estimator(x, prompt_x, x_lens, t_tensor, d_tensor, mu, drop_audio_cond=False, drop_text=False).transpose(2, 1)
+            v_pred = self.estimator(x, prompt_x, x_lens, t_tensor, d_tensor, mu, use_grad_ckpt=False, drop_audio_cond=False, drop_text=False).transpose(2, 1)
             if inference_cfg_rate > 1e-5:
-                neg = self.estimator(x, prompt_x, x_lens, t_tensor, d_tensor, mu, drop_audio_cond=True, drop_text=True).transpose(2, 1)
+                neg = self.estimator(x, prompt_x, x_lens, t_tensor, d_tensor, mu, use_grad_ckpt=False, drop_audio_cond=True, drop_text=True).transpose(2, 1)
                 v_pred = v_pred + (v_pred - neg) * inference_cfg_rate
             x = x + d * v_pred
             t = t + d
         x[:, :, :prompt_len] = 0
         return x
-    def forward(self, x1, x_lens, prompt_lens, mu):
+    def forward(self, x1, x_lens, prompt_lens, mu, use_grad_ckpt):
         b, _, t = x1.shape

         # random timestep
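The inference path now pins `use_grad_ckpt=False` on both estimator calls: activation checkpointing trades recompute time for not storing intermediates until backward, so with no backward pass it would only slow sampling down. This hunk doesn't show the estimator side; below is a minimal sketch, assuming a DiT-style block loop, of how such a flag is commonly consumed via `torch.utils.checkpoint` (class and attribute names here are illustrative, not the repo's actual code):

```python
import torch
from torch.utils.checkpoint import checkpoint

class CheckpointedBlocks(torch.nn.Module):
    """Hypothetical stand-in for the estimator's transformer block loop."""

    def __init__(self, blocks: torch.nn.ModuleList):
        super().__init__()
        self.blocks = blocks

    def forward(self, x, use_grad_ckpt: bool = False):
        for block in self.blocks:
            if use_grad_ckpt and self.training:
                # Don't keep this block's activations; recompute them
                # during backward. Saves VRAM at the cost of extra compute.
                x = checkpoint(block, x, use_reentrant=False)
            else:
                # Inference (or checkpointing off): plain forward pass.
                x = block(x)
        return x
```

Also visible in this hunk: `v_pred = v_pred + (v_pred - neg) * inference_cfg_rate` is the classifier-free-guidance blend, with `neg` being the fully unconditional prediction (`drop_audio_cond=True, drop_text=True`).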
@@ -1117,16 +1117,16 @@ class CFM(torch.nn.Module):
             d_input = d.clone()
             d_input[d_input < 1e-2] = 0
             # with torch.no_grad():
-            v_pred_1 = self.estimator(xt, prompt, x_lens, t, d_input, mu).transpose(2, 1).detach()
+            v_pred_1 = self.estimator(xt, prompt, x_lens, t, d_input, mu, use_grad_ckpt).transpose(2, 1).detach()
             # v_pred_1 = self.diffusion(xt, t, d_input, cond=conditioning).detach()
             x_mid = xt + d[:, None, None] * v_pred_1
             # v_pred_2 = self.diffusion(x_mid, t+d, d_input, cond=conditioning).detach()
-            v_pred_2 = self.estimator(x_mid, prompt, x_lens, t + d, d_input, mu).transpose(2, 1).detach()
+            v_pred_2 = self.estimator(x_mid, prompt, x_lens, t + d, d_input, mu, use_grad_ckpt).transpose(2, 1).detach()
             vt = (v_pred_1 + v_pred_2) / 2
             vt = vt.detach()
             dt = 2 * d

-        vt_pred = self.estimator(xt, prompt, x_lens, t, dt, mu).transpose(2, 1)
+        vt_pred = self.estimator(xt, prompt, x_lens, t, dt, mu, use_grad_ckpt).transpose(2, 1)
         loss = 0

         # print(45555555, estimator_out.shape, u.shape, x_lens, prompt_lens)  # 45555555 torch.Size([7, 465, 100]) torch.Size([7, 100, 465]) tensor([461, 461, 451, 451, 442, 442, 442], device='cuda:0') tensor([ 96, 93, 185, 59, 244, 262, 294], device='cuda:0')
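All three training-time estimator calls now thread the caller's `use_grad_ckpt` through. The surrounding logic builds a midpoint target: `v_pred_1` and `v_pred_2` are two Euler half-steps of size `d` (evaluated at `t` and at `t + d`), their average `vt` is detached, and `vt_pred` takes a single step of size `dt = 2*d`, so gradients flow only through the big-step prediction. A minimal sketch of that target construction, with a simplified `estimator(x, t, step)` callable standing in for the full signature (which also takes `prompt`, `x_lens`, `mu`, and `use_grad_ckpt`):

```python
import torch

def midpoint_target(estimator, xt, t, d):
    """Average of two half-step velocities, detached, as in the hunk above.

    estimator: hypothetical callable (x, t, step) -> velocity. Per-sample
    step sizes in the real code broadcast via d[:, None, None]; a scalar
    d is used here for brevity.
    """
    with torch.no_grad():                # stands in for .detach() above
        v1 = estimator(xt, t, d)         # velocity at t
        x_mid = xt + d * v1              # Euler half-step to t + d
        v2 = estimator(x_mid, t + d, d)  # velocity at the midpoint
    return (v1 + v2) / 2                 # target for one step of size 2*d
```

The loss accumulated below then pulls the one-step prediction toward this two-step target, a shortcut-style self-consistency objective.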
@@ -1220,7 +1220,7 @@ class SynthesizerTrnV3(nn.Module):
         self.linear_mel = nn.Conv1d(inter_channels2, 100, 1, stride=1)
         self.cfm = CFM(100, DiT(**dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=inter_channels2, conv_layers=4)))  # text_dim is the condition feature dim

-    def forward(self, ssl, y, mel, ssl_lengths, y_lengths, text, text_lengths, mel_lengths):  # ssl_lengths is not needed now
+    def forward(self, ssl, y, mel, ssl_lengths, y_lengths, text, text_lengths, mel_lengths, use_grad_ckpt):  # ssl_lengths is not needed now
         with autocast(enabled=False):
             y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(y.dtype)
             ge = self.ref_enc(y[:, :704] * y_mask, y_mask)
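`SynthesizerTrnV3.forward` gains the same trailing parameter and (next hunk) hands it to `self.cfm`, so the choice is made once at the call site and merely threaded through. The trainer side isn't part of this diff; a hypothetical wrapper showing where the argument would enter (all names illustrative):

```python
def training_step(net_g, batch, grad_ckpt: bool):
    """Hypothetical call-site sketch; net_g is a SynthesizerTrnV3.

    The flag threads straight through:
    SynthesizerTrnV3.forward -> CFM.forward -> the DiT estimator.
    """
    ssl, spec, mel, ssl_len, spec_len, text, text_len, mel_len = batch
    cfm_loss = net_g(ssl, spec, mel, ssl_len, spec_len,
                     text, text_len, mel_len, grad_ckpt)
    cfm_loss.backward()
    return cfm_loss
```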
@@ -1245,7 +1245,7 @@ class SynthesizerTrnV3(nn.Module):
         minn = min(mel.shape[-1], fea.shape[-1])
         mel = mel[:, :, :minn]
         fea = fea[:, :, :minn]
-        cfm_loss = self.cfm(mel, mel_lengths, prompt_len, fea)
+        cfm_loss = self.cfm(mel, mel_lengths, prompt_len, fea, use_grad_ckpt)
         return cfm_loss

     @torch.no_grad()
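With the flag plumbed end to end, the payoff is lower peak activation memory during training, paid for with extra recompute in backward; inference is untouched since it hard-codes `use_grad_ckpt=False`. A small helper for A/B-ing the effect on a single GPU (illustrative, not part of the commit):

```python
import torch

def peak_vram_gib(step_fn) -> float:
    """Run one forward+backward step and report peak allocated VRAM in GiB.

    step_fn: any zero-argument callable performing the step, e.g. a closure
    over the model called with use_grad_ckpt=True, then again with False.
    """
    torch.cuda.reset_peak_memory_stats()
    step_fn()
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated() / 2**30
```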