update_infer

Watchtower-Liu
2024-02-16 16:53:57 +08:00
parent 41041715a4
commit 1803729360
6 changed files with 88 additions and 56 deletions


@@ -5,8 +5,8 @@ from torch.nn.functional import (
    _none_or_dtype,
    _in_projection_packed,
)
# import torch
from torch.nn import functional as F
import torch
# Tensor = torch.Tensor
# from typing import Callable, List, Optional, Tuple, Union
@@ -448,9 +448,11 @@ def multi_head_attention_forward_patched(
    k = k.view(bsz, num_heads, src_len, head_dim)
    v = v.view(bsz, num_heads, src_len, head_dim)
    # with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=True):
    attn_output = scaled_dot_product_attention(
        q, k, v, attn_mask, dropout_p, is_causal
    )
    attn_output = (
        attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim)
    )
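For reference, a minimal standalone sketch (not part of this commit) of the same call pattern: q, k and v shaped (bsz, num_heads, seq_len, head_dim) go through torch.nn.functional.scaled_dot_product_attention, and the result is permuted and flattened back to (bsz * tgt_len, embed_dim) as in the patched forward above. The tensor sizes below are made-up example values.

    # Sketch only: same reshape / attention / permute pattern, with dummy sizes.
    import torch
    import torch.nn.functional as F

    bsz, num_heads, tgt_len, head_dim = 2, 8, 16, 64
    embed_dim = num_heads * head_dim

    q = torch.randn(bsz, num_heads, tgt_len, head_dim)
    k = torch.randn(bsz, num_heads, tgt_len, head_dim)
    v = torch.randn(bsz, num_heads, tgt_len, head_dim)

    # Fused attention; PyTorch picks the flash / mem-efficient / math backend.
    attn_output = F.scaled_dot_product_attention(
        q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False
    )

    # (bsz, num_heads, tgt_len, head_dim) -> (tgt_len, bsz, num_heads, head_dim)
    # -> (bsz * tgt_len, embed_dim)
    attn_output = (
        attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim)
    )
    print(attn_output.shape)  # torch.Size([32, 512])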