修复了一些bug,优化了一些代码
This commit is contained in:
@@ -229,10 +229,15 @@ class Text2SemanticDecoder(nn.Module):
|
||||
ignore_index=self.EOS,
|
||||
)
|
||||
|
||||
if not flash_attn_enabled:
|
||||
self.enable_flash_attn(flash_attn_enabled)
|
||||
|
||||
def enable_flash_attn(self, enable:bool=True):
|
||||
|
||||
if not enable:
|
||||
print("Not Using Flash Attention")
|
||||
self.infer_panel = self.infer_panel_batch_only
|
||||
else:
|
||||
self.infer_panel = self.infer_panel_batch_infer_with_flash_attn
|
||||
print("Using Flash Attention")
|
||||
blocks = []
|
||||
|
||||
@@ -497,7 +502,7 @@ class Text2SemanticDecoder(nn.Module):
|
||||
# 错位
|
||||
return targets[:, :-1], targets[:, 1:]
|
||||
|
||||
def infer_panel(
|
||||
def infer_panel_batch_infer_with_flash_attn(
|
||||
self,
|
||||
x, #####全部文本token
|
||||
x_lens,
|
||||
@@ -508,8 +513,10 @@ class Text2SemanticDecoder(nn.Module):
|
||||
early_stop_num: int = -1,
|
||||
temperature: float = 1.0,
|
||||
):
|
||||
|
||||
bert_feature = self.bert_proj(bert_feature.transpose(1, 2))
|
||||
x = self.ar_text_embedding(x)
|
||||
x = x + self.bert_proj(bert_feature.transpose(1, 2))
|
||||
x = x + bert_feature
|
||||
x = self.ar_text_position(x)
|
||||
|
||||
# AR Decoder
|
||||
@@ -546,30 +553,28 @@ class Text2SemanticDecoder(nn.Module):
|
||||
y_lens = torch.LongTensor([y_len]*bsz).to(x.device)
|
||||
y_mask = make_pad_mask(y_lens)
|
||||
x_mask = make_pad_mask(x_lens)
|
||||
|
||||
|
||||
# (bsz, x_len + y_len)
|
||||
xy_padding_mask = torch.concat([x_mask, y_mask], dim=1)
|
||||
_xy_padding_mask = (
|
||||
xy_padding_mask.view(bsz, 1, 1, src_len).expand(-1, self.num_head, -1, -1)
|
||||
)
|
||||
|
||||
x_attn_mask_pad = F.pad(
|
||||
x_mask = F.pad(
|
||||
x_attn_mask,
|
||||
(0, y_len), ###xx的纯0扩展到xx纯0+xy纯1,(x,x+y)
|
||||
value=True,
|
||||
)
|
||||
y_attn_mask = F.pad( ###yy的右上1扩展到左边xy的0,(y,x+y)
|
||||
y_mask = F.pad( ###yy的右上1扩展到左边xy的0,(y,x+y)
|
||||
torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
|
||||
(x_len, 0),
|
||||
value=False,
|
||||
)
|
||||
xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0).to(
|
||||
x.device
|
||||
)
|
||||
xy_attn_mask = xy_attn_mask.logical_or(_xy_padding_mask)
|
||||
|
||||
xy_mask = torch.concat([x_mask, y_mask], dim=0).view(1 , src_len, src_len).expand(bsz, -1, -1).to(x.device)
|
||||
# xy_mask = torch.triu(torch.ones(src_len, src_len, dtype=torch.bool, device=x.device), diagonal=1)
|
||||
xy_padding_mask = xy_padding_mask.view(bsz, 1, src_len).expand(-1, src_len, src_len)
|
||||
xy_attn_mask = xy_mask.logical_or(xy_padding_mask)
|
||||
xy_attn_mask = xy_attn_mask.unsqueeze(1).expand(-1, self.num_head, -1, -1)
|
||||
new_attn_mask = torch.zeros_like(xy_attn_mask, dtype=x.dtype)
|
||||
new_attn_mask.masked_fill_(xy_attn_mask, float("-inf"))
|
||||
xy_attn_mask = new_attn_mask
|
||||
xy_attn_mask = new_attn_mask.masked_fill(xy_attn_mask, float("-inf"))
|
||||
|
||||
###### decode #####
|
||||
y_list = [None]*y.shape[0]
|
||||
@@ -641,7 +646,7 @@ class Text2SemanticDecoder(nn.Module):
|
||||
####################### update next step ###################################
|
||||
y_emb = self.ar_audio_embedding(y[:, -1:])
|
||||
xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[:, y_len + idx].to( dtype= y_emb.dtype,device=y_emb.device)
|
||||
|
||||
|
||||
if (None in idx_list):
|
||||
for i in range(x.shape[0]):
|
||||
if idx_list[i] is None:
|
||||
|
||||
Reference in New Issue
Block a user