condition cache (#2377)

Adds an optional condition cache to DiT.forward for inference: three new keyword arguments (infer, text_cache, dt_cache) let a caller reuse the text embedding and the dt embedding computed on an earlier call instead of recomputing them at every step, and when infer=True the forward pass returns both tensors alongside the output so they can be fed back in.
@@ -143,6 +143,9 @@ class DiT(nn.Module):
         drop_audio_cond=False,  # cfg for cond audio
         drop_text=False,  # cfg for text
         # mask: bool["b n"] | None = None,  # noqa: F722
+        infer=False,  # bool
+        text_cache=None,  # torch tensor as text_embed
+        dt_cache=None,  # torch tensor as dt
     ):
         x = x0.transpose(2, 1)
         cond = cond0.transpose(2, 1)
@@ -155,9 +158,17 @@ class DiT(nn.Module):
+
         # t: conditioning time, c: context (text + masked cond audio), x: noised input audio
         t = self.time_embed(time)
-        dt = self.d_embed(dt_base_bootstrap)
+        if infer and dt_cache is not None:
+            dt = dt_cache
+        else:
+            dt = self.d_embed(dt_base_bootstrap)
         t += dt
-        text_embed = self.text_embed(text, seq_len, drop_text=drop_text)  ###need to change
+
+        if infer and text_cache is not None:
+            text_embed = text_cache
+        else:
+            text_embed = self.text_embed(text, seq_len, drop_text=drop_text)  ###need to change

         x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond)

         rope = self.rotary_embed.forward_from_seq_len(seq_len)
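Note that the caches are only valid while the inputs that produced them are unchanged: text_cache assumes the same text, seq_len, and drop_text as the call that produced it, and dt_cache assumes the same dt_base_bootstrap. The forward pass does not verify this; the caller is responsible for invalidating the caches when those inputs change.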
@@ -177,4 +188,7 @@ class DiT(nn.Module):
         x = self.norm_out(x, t)
         output = self.proj_out(x)

-        return output
+        if infer:
+            return output, text_embed, dt
+        else:
+            return output
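A minimal sketch of the intended call pattern, using a toy stand-in module (the real DiT has a different constructor and more inputs; only the infer/text_cache/dt_cache contract comes from this diff, everything else below is illustrative):

import torch
import torch.nn as nn


class ToyDiT(nn.Module):
    # Toy stand-in for DiT: just enough structure to exercise the cache contract.
    def __init__(self, dim=8, text_dim=4):
        super().__init__()
        self.text_embed = nn.Linear(text_dim, dim)
        self.d_embed = nn.Linear(1, dim)
        self.proj_out = nn.Linear(dim, dim)

    def forward(self, x, text, dt_base, infer=False, text_cache=None, dt_cache=None):
        # Mirror the diff: reuse cached embeddings on repeat inference calls.
        if infer and dt_cache is not None:
            dt = dt_cache
        else:
            dt = self.d_embed(dt_base)
        if infer and text_cache is not None:
            text_embed = text_cache
        else:
            text_embed = self.text_embed(text)
        output = self.proj_out(x + text_embed + dt)
        if infer:
            return output, text_embed, dt
        return output


model = ToyDiT()
x = torch.randn(2, 8)
text = torch.randn(2, 4)
dt_base = torch.randn(2, 1)

# First inference call computes the embeddings and returns them as caches.
out, text_cache, dt_cache = model(x, text, dt_base, infer=True)

# Subsequent calls with the same text and dt_base reuse the cached tensors,
# skipping self.text_embed and self.d_embed.
out2, _, _ = model(x, text, dt_base, infer=True, text_cache=text_cache, dt_cache=dt_cache)
assert torch.allclose(out, out2)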