more code refactor
This commit is contained in:
@@ -13,7 +13,9 @@ from torch.nn.parameter import Parameter
|
||||
|
||||
from torch.nn import functional as F
|
||||
from AR.modules.patched_mha_with_cache import multi_head_attention_forward_patched
|
||||
F.multi_head_attention_forward=multi_head_attention_forward_patched
|
||||
|
||||
F.multi_head_attention_forward = multi_head_attention_forward_patched
|
||||
|
||||
|
||||
class MultiheadAttention(Module):
|
||||
r"""Allows the model to jointly attend to information
|
||||
@@ -76,66 +78,71 @@ class MultiheadAttention(Module):
|
||||
bias_v: Optional[torch.Tensor]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embed_dim,
|
||||
num_heads,
|
||||
dropout=0.0,
|
||||
bias=True,
|
||||
add_bias_kv=False,
|
||||
add_zero_attn=False,
|
||||
kdim=None,
|
||||
vdim=None,
|
||||
batch_first=False,
|
||||
linear1_cls=Linear,
|
||||
linear2_cls=Linear,
|
||||
device=None,
|
||||
dtype=None, ) -> None:
|
||||
self,
|
||||
embed_dim,
|
||||
num_heads,
|
||||
dropout=0.0,
|
||||
bias=True,
|
||||
add_bias_kv=False,
|
||||
add_zero_attn=False,
|
||||
kdim=None,
|
||||
vdim=None,
|
||||
batch_first=False,
|
||||
linear1_cls=Linear,
|
||||
linear2_cls=Linear,
|
||||
device=None,
|
||||
dtype=None,
|
||||
) -> None:
|
||||
factory_kwargs = {"device": device, "dtype": dtype}
|
||||
super(MultiheadAttention, self).__init__()
|
||||
self.embed_dim = embed_dim
|
||||
self.kdim = kdim if kdim is not None else embed_dim
|
||||
self.vdim = vdim if vdim is not None else embed_dim
|
||||
self._qkv_same_embed_dim = (self.kdim == embed_dim and
|
||||
self.vdim == embed_dim)
|
||||
self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim
|
||||
|
||||
self.num_heads = num_heads
|
||||
self.dropout = dropout
|
||||
self.batch_first = batch_first
|
||||
self.head_dim = embed_dim // num_heads
|
||||
assert (self.head_dim * num_heads == self.embed_dim
|
||||
), "embed_dim must be divisible by num_heads"
|
||||
assert (
|
||||
self.head_dim * num_heads == self.embed_dim
|
||||
), "embed_dim must be divisible by num_heads"
|
||||
|
||||
if add_bias_kv:
|
||||
self.bias_k = Parameter(
|
||||
torch.empty((1, 1, embed_dim), **factory_kwargs))
|
||||
self.bias_v = Parameter(
|
||||
torch.empty((1, 1, embed_dim), **factory_kwargs))
|
||||
self.bias_k = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
|
||||
self.bias_v = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
|
||||
else:
|
||||
self.bias_k = self.bias_v = None
|
||||
|
||||
if linear1_cls == Linear:
|
||||
if not self._qkv_same_embed_dim:
|
||||
self.q_proj_weight = Parameter(
|
||||
torch.empty((embed_dim, embed_dim), **factory_kwargs))
|
||||
torch.empty((embed_dim, embed_dim), **factory_kwargs)
|
||||
)
|
||||
self.k_proj_weight = Parameter(
|
||||
torch.empty((embed_dim, self.kdim), **factory_kwargs))
|
||||
torch.empty((embed_dim, self.kdim), **factory_kwargs)
|
||||
)
|
||||
self.v_proj_weight = Parameter(
|
||||
torch.empty((embed_dim, self.vdim), **factory_kwargs))
|
||||
torch.empty((embed_dim, self.vdim), **factory_kwargs)
|
||||
)
|
||||
self.register_parameter("in_proj_weight", None)
|
||||
else:
|
||||
self.in_proj_weight = Parameter(
|
||||
torch.empty((3 * embed_dim, embed_dim), **factory_kwargs))
|
||||
torch.empty((3 * embed_dim, embed_dim), **factory_kwargs)
|
||||
)
|
||||
self.register_parameter("q_proj_weight", None)
|
||||
self.register_parameter("k_proj_weight", None)
|
||||
self.register_parameter("v_proj_weight", None)
|
||||
|
||||
if bias:
|
||||
self.in_proj_bias = Parameter(
|
||||
torch.empty(3 * embed_dim, **factory_kwargs))
|
||||
torch.empty(3 * embed_dim, **factory_kwargs)
|
||||
)
|
||||
else:
|
||||
self.register_parameter("in_proj_bias", None)
|
||||
self.out_proj = NonDynamicallyQuantizableLinear(
|
||||
embed_dim, embed_dim, bias=bias, **factory_kwargs)
|
||||
embed_dim, embed_dim, bias=bias, **factory_kwargs
|
||||
)
|
||||
|
||||
self._reset_parameters()
|
||||
else:
|
||||
@@ -143,7 +150,8 @@ class MultiheadAttention(Module):
|
||||
raise NotImplementedError
|
||||
else:
|
||||
self.in_proj_linear = linear1_cls(
|
||||
embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs)
|
||||
embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs
|
||||
)
|
||||
self.in_proj_weight = self.in_proj_linear.weight
|
||||
|
||||
self.register_parameter("q_proj_weight", None)
|
||||
@@ -156,7 +164,8 @@ class MultiheadAttention(Module):
|
||||
self.register_parameter("in_proj_bias", None)
|
||||
|
||||
self.out_proj = linear2_cls(
|
||||
embed_dim, embed_dim, bias=bias, **factory_kwargs)
|
||||
embed_dim, embed_dim, bias=bias, **factory_kwargs
|
||||
)
|
||||
|
||||
if self.bias_k is not None:
|
||||
xavier_normal_(self.bias_k)
|
||||
@@ -190,14 +199,15 @@ class MultiheadAttention(Module):
|
||||
super(MultiheadAttention, self).__setstate__(state)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
query: Tensor,
|
||||
key: Tensor,
|
||||
value: Tensor,
|
||||
key_padding_mask: Optional[Tensor]=None,
|
||||
need_weights: bool=True,
|
||||
attn_mask: Optional[Tensor]=None,
|
||||
average_attn_weights: bool=True,cache=None
|
||||
self,
|
||||
query: Tensor,
|
||||
key: Tensor,
|
||||
value: Tensor,
|
||||
key_padding_mask: Optional[Tensor] = None,
|
||||
need_weights: bool = True,
|
||||
attn_mask: Optional[Tensor] = None,
|
||||
average_attn_weights: bool = True,
|
||||
cache=None,
|
||||
) -> Tuple[Tensor, Optional[Tensor]]:
|
||||
r"""
|
||||
Args:
|
||||
@@ -251,23 +261,26 @@ class MultiheadAttention(Module):
|
||||
if key_padding_mask is not None:
|
||||
_kpm_dtype = key_padding_mask.dtype
|
||||
if _kpm_dtype != torch.bool and not torch.is_floating_point(
|
||||
key_padding_mask):
|
||||
key_padding_mask
|
||||
):
|
||||
raise AssertionError(
|
||||
"only bool and floating types of key_padding_mask are supported"
|
||||
)
|
||||
why_not_fast_path = ""
|
||||
if not is_batched:
|
||||
why_not_fast_path = f"input not batched; expected query.dim() of 3 but got {query.dim()}"
|
||||
why_not_fast_path = (
|
||||
f"input not batched; expected query.dim() of 3 but got {query.dim()}"
|
||||
)
|
||||
elif query is not key or key is not value:
|
||||
# When lifting this restriction, don't forget to either
|
||||
# enforce that the dtypes all match or test cases where
|
||||
# they don't!
|
||||
why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)"
|
||||
elif (self.in_proj_bias is not None and
|
||||
query.dtype != self.in_proj_bias.dtype):
|
||||
elif self.in_proj_bias is not None and query.dtype != self.in_proj_bias.dtype:
|
||||
why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match"
|
||||
elif (self.in_proj_weight is not None and
|
||||
query.dtype != self.in_proj_weight.dtype):
|
||||
elif (
|
||||
self.in_proj_weight is not None and query.dtype != self.in_proj_weight.dtype
|
||||
):
|
||||
# this case will fail anyway, but at least they'll get a useful error message.
|
||||
why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match"
|
||||
elif self.training:
|
||||
@@ -288,29 +301,41 @@ class MultiheadAttention(Module):
|
||||
why_not_fast_path = "attn_mask was not None"
|
||||
elif query.is_nested and key_padding_mask is not None:
|
||||
why_not_fast_path = (
|
||||
"key_padding_mask is not supported with NestedTensor input")
|
||||
"key_padding_mask is not supported with NestedTensor input"
|
||||
)
|
||||
elif self.num_heads % 2 == 1:
|
||||
why_not_fast_path = "num_heads is odd"
|
||||
elif torch.is_autocast_enabled():
|
||||
why_not_fast_path = "autocast is enabled"
|
||||
|
||||
if not why_not_fast_path:
|
||||
tensor_args = (query, key, value, self.in_proj_weight,
|
||||
self.in_proj_bias, self.out_proj.weight,
|
||||
self.out_proj.bias, )
|
||||
tensor_args = (
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
self.in_proj_weight,
|
||||
self.in_proj_bias,
|
||||
self.out_proj.weight,
|
||||
self.out_proj.bias,
|
||||
)
|
||||
# We have to use list comprehensions below because TorchScript does not support
|
||||
# generator expressions.
|
||||
if torch.overrides.has_torch_function(tensor_args):
|
||||
why_not_fast_path = "some Tensor argument has_torch_function"
|
||||
elif not all([(x is None or x.is_cuda or "cpu" in str(x.device))
|
||||
for x in tensor_args]):
|
||||
why_not_fast_path = (
|
||||
"some Tensor argument is neither CUDA nor CPU")
|
||||
elif not all(
|
||||
[
|
||||
(x is None or x.is_cuda or "cpu" in str(x.device))
|
||||
for x in tensor_args
|
||||
]
|
||||
):
|
||||
why_not_fast_path = "some Tensor argument is neither CUDA nor CPU"
|
||||
elif torch.is_grad_enabled() and any(
|
||||
[x is not None and x.requires_grad for x in tensor_args]):
|
||||
[x is not None and x.requires_grad for x in tensor_args]
|
||||
):
|
||||
why_not_fast_path = (
|
||||
"grad is enabled and at least one of query or the "
|
||||
"input/output projection weights or biases requires_grad")
|
||||
"input/output projection weights or biases requires_grad"
|
||||
)
|
||||
if not why_not_fast_path:
|
||||
return torch._native_multi_head_attention(
|
||||
query,
|
||||
@@ -322,17 +347,21 @@ class MultiheadAttention(Module):
|
||||
self.in_proj_bias,
|
||||
self.out_proj.weight,
|
||||
self.out_proj.bias,
|
||||
key_padding_mask
|
||||
if key_padding_mask is not None else attn_mask,
|
||||
key_padding_mask if key_padding_mask is not None else attn_mask,
|
||||
need_weights,
|
||||
average_attn_weights,
|
||||
1 if key_padding_mask is not None else 0
|
||||
if attn_mask is not None else None, )
|
||||
1
|
||||
if key_padding_mask is not None
|
||||
else 0
|
||||
if attn_mask is not None
|
||||
else None,
|
||||
)
|
||||
|
||||
any_nested = query.is_nested or key.is_nested or value.is_nested
|
||||
assert not any_nested, (
|
||||
"MultiheadAttention does not support NestedTensor outside of its fast path. "
|
||||
+ f"The fast path was not hit because {why_not_fast_path}")
|
||||
+ f"The fast path was not hit because {why_not_fast_path}"
|
||||
)
|
||||
|
||||
if self.batch_first and is_batched:
|
||||
# make sure that the transpose op does not affect the "is" property
|
||||
@@ -343,9 +372,7 @@ class MultiheadAttention(Module):
|
||||
query, key = [x.transpose(1, 0) for x in (query, key)]
|
||||
value = key
|
||||
else:
|
||||
query, key, value = [
|
||||
x.transpose(1, 0) for x in (query, key, value)
|
||||
]
|
||||
query, key, value = [x.transpose(1, 0) for x in (query, key, value)]
|
||||
|
||||
if not self._qkv_same_embed_dim:
|
||||
attn_output, attn_output_weights = F.multi_head_attention_forward(
|
||||
@@ -370,7 +397,9 @@ class MultiheadAttention(Module):
|
||||
q_proj_weight=self.q_proj_weight,
|
||||
k_proj_weight=self.k_proj_weight,
|
||||
v_proj_weight=self.v_proj_weight,
|
||||
average_attn_weights=average_attn_weights,cache=cache )
|
||||
average_attn_weights=average_attn_weights,
|
||||
cache=cache,
|
||||
)
|
||||
else:
|
||||
attn_output, attn_output_weights = F.multi_head_attention_forward(
|
||||
query,
|
||||
@@ -390,7 +419,9 @@ class MultiheadAttention(Module):
|
||||
key_padding_mask=key_padding_mask,
|
||||
need_weights=need_weights,
|
||||
attn_mask=attn_mask,
|
||||
average_attn_weights=average_attn_weights,cache=cache )
|
||||
average_attn_weights=average_attn_weights,
|
||||
cache=cache,
|
||||
)
|
||||
if self.batch_first and is_batched:
|
||||
return attn_output.transpose(1, 0), attn_output_weights
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user