more code refactor
@@ -13,7 +13,9 @@ from torch.nn.parameter import Parameter
from torch.nn import functional as F
from AR.modules.patched_mha_with_cache import multi_head_attention_forward_patched
F.multi_head_attention_forward=multi_head_attention_forward_patched
F.multi_head_attention_forward = multi_head_attention_forward_patched


class MultiheadAttention(Module):
r"""Allows the model to jointly attend to information
@@ -76,66 +78,71 @@ class MultiheadAttention(Module):
bias_v: Optional[torch.Tensor]

def __init__(
self,
embed_dim,
num_heads,
dropout=0.0,
bias=True,
add_bias_kv=False,
add_zero_attn=False,
kdim=None,
vdim=None,
batch_first=False,
linear1_cls=Linear,
linear2_cls=Linear,
device=None,
dtype=None, ) -> None:
self,
embed_dim,
num_heads,
dropout=0.0,
bias=True,
add_bias_kv=False,
add_zero_attn=False,
kdim=None,
vdim=None,
batch_first=False,
linear1_cls=Linear,
linear2_cls=Linear,
device=None,
dtype=None,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
super(MultiheadAttention, self).__init__()
self.embed_dim = embed_dim
self.kdim = kdim if kdim is not None else embed_dim
self.vdim = vdim if vdim is not None else embed_dim
self._qkv_same_embed_dim = (self.kdim == embed_dim and
self.vdim == embed_dim)
self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim

self.num_heads = num_heads
self.dropout = dropout
self.batch_first = batch_first
self.head_dim = embed_dim // num_heads
assert (self.head_dim * num_heads == self.embed_dim
), "embed_dim must be divisible by num_heads"
assert (
self.head_dim * num_heads == self.embed_dim
), "embed_dim must be divisible by num_heads"

if add_bias_kv:
self.bias_k = Parameter(
torch.empty((1, 1, embed_dim), **factory_kwargs))
self.bias_v = Parameter(
torch.empty((1, 1, embed_dim), **factory_kwargs))
self.bias_k = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
self.bias_v = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
else:
self.bias_k = self.bias_v = None

if linear1_cls == Linear:
if not self._qkv_same_embed_dim:
self.q_proj_weight = Parameter(
torch.empty((embed_dim, embed_dim), **factory_kwargs))
torch.empty((embed_dim, embed_dim), **factory_kwargs)
)
self.k_proj_weight = Parameter(
torch.empty((embed_dim, self.kdim), **factory_kwargs))
torch.empty((embed_dim, self.kdim), **factory_kwargs)
)
self.v_proj_weight = Parameter(
torch.empty((embed_dim, self.vdim), **factory_kwargs))
torch.empty((embed_dim, self.vdim), **factory_kwargs)
)
self.register_parameter("in_proj_weight", None)
else:
self.in_proj_weight = Parameter(
torch.empty((3 * embed_dim, embed_dim), **factory_kwargs))
torch.empty((3 * embed_dim, embed_dim), **factory_kwargs)
)
self.register_parameter("q_proj_weight", None)
self.register_parameter("k_proj_weight", None)
self.register_parameter("v_proj_weight", None)

if bias:
self.in_proj_bias = Parameter(
torch.empty(3 * embed_dim, **factory_kwargs))
torch.empty(3 * embed_dim, **factory_kwargs)
)
else:
self.register_parameter("in_proj_bias", None)
self.out_proj = NonDynamicallyQuantizableLinear(
embed_dim, embed_dim, bias=bias, **factory_kwargs)
embed_dim, embed_dim, bias=bias, **factory_kwargs
)

self._reset_parameters()
else:
@@ -143,7 +150,8 @@ class MultiheadAttention(Module):
raise NotImplementedError
else:
self.in_proj_linear = linear1_cls(
embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs)
embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs
)
self.in_proj_weight = self.in_proj_linear.weight

self.register_parameter("q_proj_weight", None)
@@ -156,7 +164,8 @@ class MultiheadAttention(Module):
self.register_parameter("in_proj_bias", None)

self.out_proj = linear2_cls(
embed_dim, embed_dim, bias=bias, **factory_kwargs)
embed_dim, embed_dim, bias=bias, **factory_kwargs
)

if self.bias_k is not None:
xavier_normal_(self.bias_k)
@@ -190,14 +199,15 @@ class MultiheadAttention(Module):
super(MultiheadAttention, self).__setstate__(state)

def forward(
self,
query: Tensor,
key: Tensor,
value: Tensor,
key_padding_mask: Optional[Tensor]=None,
need_weights: bool=True,
attn_mask: Optional[Tensor]=None,
average_attn_weights: bool=True,cache=None
self,
query: Tensor,
key: Tensor,
value: Tensor,
key_padding_mask: Optional[Tensor] = None,
need_weights: bool = True,
attn_mask: Optional[Tensor] = None,
average_attn_weights: bool = True,
cache=None,
) -> Tuple[Tensor, Optional[Tensor]]:
r"""
Args:
@@ -251,23 +261,26 @@ class MultiheadAttention(Module):
if key_padding_mask is not None:
_kpm_dtype = key_padding_mask.dtype
if _kpm_dtype != torch.bool and not torch.is_floating_point(
key_padding_mask):
key_padding_mask
):
raise AssertionError(
"only bool and floating types of key_padding_mask are supported"
)
why_not_fast_path = ""
if not is_batched:
why_not_fast_path = f"input not batched; expected query.dim() of 3 but got {query.dim()}"
why_not_fast_path = (
f"input not batched; expected query.dim() of 3 but got {query.dim()}"
)
elif query is not key or key is not value:
# When lifting this restriction, don't forget to either
# enforce that the dtypes all match or test cases where
# they don't!
why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)"
elif (self.in_proj_bias is not None and
query.dtype != self.in_proj_bias.dtype):
elif self.in_proj_bias is not None and query.dtype != self.in_proj_bias.dtype:
why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match"
elif (self.in_proj_weight is not None and
query.dtype != self.in_proj_weight.dtype):
elif (
self.in_proj_weight is not None and query.dtype != self.in_proj_weight.dtype
):
# this case will fail anyway, but at least they'll get a useful error message.
why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match"
elif self.training:
@@ -288,29 +301,41 @@ class MultiheadAttention(Module):
why_not_fast_path = "attn_mask was not None"
elif query.is_nested and key_padding_mask is not None:
why_not_fast_path = (
"key_padding_mask is not supported with NestedTensor input")
"key_padding_mask is not supported with NestedTensor input"
)
elif self.num_heads % 2 == 1:
why_not_fast_path = "num_heads is odd"
elif torch.is_autocast_enabled():
why_not_fast_path = "autocast is enabled"

if not why_not_fast_path:
tensor_args = (query, key, value, self.in_proj_weight,
self.in_proj_bias, self.out_proj.weight,
self.out_proj.bias, )
tensor_args = (
query,
key,
value,
self.in_proj_weight,
self.in_proj_bias,
self.out_proj.weight,
self.out_proj.bias,
)
# We have to use list comprehensions below because TorchScript does not support
# generator expressions.
if torch.overrides.has_torch_function(tensor_args):
why_not_fast_path = "some Tensor argument has_torch_function"
elif not all([(x is None or x.is_cuda or "cpu" in str(x.device))
for x in tensor_args]):
why_not_fast_path = (
"some Tensor argument is neither CUDA nor CPU")
elif not all(
[
(x is None or x.is_cuda or "cpu" in str(x.device))
for x in tensor_args
]
):
why_not_fast_path = "some Tensor argument is neither CUDA nor CPU"
elif torch.is_grad_enabled() and any(
[x is not None and x.requires_grad for x in tensor_args]):
[x is not None and x.requires_grad for x in tensor_args]
):
why_not_fast_path = (
"grad is enabled and at least one of query or the "
"input/output projection weights or biases requires_grad")
"input/output projection weights or biases requires_grad"
)
if not why_not_fast_path:
return torch._native_multi_head_attention(
query,
@@ -322,17 +347,21 @@ class MultiheadAttention(Module):
self.in_proj_bias,
self.out_proj.weight,
self.out_proj.bias,
key_padding_mask
if key_padding_mask is not None else attn_mask,
key_padding_mask if key_padding_mask is not None else attn_mask,
need_weights,
average_attn_weights,
1 if key_padding_mask is not None else 0
if attn_mask is not None else None, )
1
if key_padding_mask is not None
else 0
if attn_mask is not None
else None,
)

any_nested = query.is_nested or key.is_nested or value.is_nested
assert not any_nested, (
"MultiheadAttention does not support NestedTensor outside of its fast path. "
+ f"The fast path was not hit because {why_not_fast_path}")
+ f"The fast path was not hit because {why_not_fast_path}"
)

if self.batch_first and is_batched:
# make sure that the transpose op does not affect the "is" property
@@ -343,9 +372,7 @@ class MultiheadAttention(Module):
query, key = [x.transpose(1, 0) for x in (query, key)]
value = key
else:
query, key, value = [
x.transpose(1, 0) for x in (query, key, value)
]
query, key, value = [x.transpose(1, 0) for x in (query, key, value)]

if not self._qkv_same_embed_dim:
attn_output, attn_output_weights = F.multi_head_attention_forward(
@@ -370,7 +397,9 @@ class MultiheadAttention(Module):
q_proj_weight=self.q_proj_weight,
k_proj_weight=self.k_proj_weight,
v_proj_weight=self.v_proj_weight,
average_attn_weights=average_attn_weights,cache=cache )
average_attn_weights=average_attn_weights,
cache=cache,
)
else:
attn_output, attn_output_weights = F.multi_head_attention_forward(
query,
@@ -390,7 +419,9 @@ class MultiheadAttention(Module):
key_padding_mask=key_padding_mask,
need_weights=need_weights,
attn_mask=attn_mask,
average_attn_weights=average_attn_weights,cache=cache )
average_attn_weights=average_attn_weights,
cache=cache,
)
if self.batch_first and is_batched:
return attn_output.transpose(1, 0), attn_output_weights
else:
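# The assignment near the top of this diff (F.multi_head_attention_forward =
# multi_head_attention_forward_patched) swaps the functional entry point module-wide, so the
# F.multi_head_attention_forward(...) calls inside MultiheadAttention.forward above pick up the
# extra cache= keyword without further changes. A minimal, illustrative call sketch -- the sizes
# are assumptions taken from the "512 16" comment later in this commit, not fixed by this file:
attn = MultiheadAttention(512, 16, batch_first=True)
x = torch.randn(1, 188, 512)
out, _ = attn(x, x, x, need_weights=False, cache=None)  # cache=None keeps stock behaviour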
@@ -7,10 +7,11 @@ from torch import nn

class TokenEmbedding(nn.Module):
def __init__(
self,
embedding_dim: int,
vocab_size: int,
dropout: float=0.0, ):
self,
embedding_dim: int,
vocab_size: int,
dropout: float = 0.0,
):
super().__init__()

self.vocab_size = vocab_size
@@ -24,7 +25,7 @@ class TokenEmbedding(nn.Module):
return self.word_embeddings.weight

def embedding(self, index: int) -> torch.Tensor:
return self.word_embeddings.weight[index:index + 1]
return self.word_embeddings.weight[index : index + 1]

def forward(self, x: torch.Tensor):
x = self.word_embeddings(x)
@@ -34,11 +35,12 @@ class TokenEmbedding(nn.Module):

class SinePositionalEmbedding(nn.Module):
def __init__(
self,
embedding_dim: int,
dropout: float=0.0,
scale: bool=False,
alpha: bool=False, ):
self,
embedding_dim: int,
dropout: float = 0.0,
scale: bool = False,
alpha: bool = False,
):
super().__init__()
self.embedding_dim = embedding_dim
self.x_scale = math.sqrt(embedding_dim) if scale else 1.0
@@ -59,13 +61,14 @@ class SinePositionalEmbedding(nn.Module):
pe = torch.zeros(x.size(1), self.embedding_dim)
if self.reverse:
position = torch.arange(
x.size(1) - 1, -1, -1.0, dtype=torch.float32).unsqueeze(1)
x.size(1) - 1, -1, -1.0, dtype=torch.float32
).unsqueeze(1)
else:
position = torch.arange(
0, x.size(1), dtype=torch.float32).unsqueeze(1)
position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
div_term = torch.exp(
torch.arange(0, self.embedding_dim, 2, dtype=torch.float32) *
-(math.log(10000.0) / self.embedding_dim))
torch.arange(0, self.embedding_dim, 2, dtype=torch.float32)
* -(math.log(10000.0) / self.embedding_dim)
)
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0)
@@ -74,5 +77,5 @@ class SinePositionalEmbedding(nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
self.extend_pe(x)
output = x.unsqueeze(-1) if x.ndim == 2 else x
output = output * self.x_scale + self.alpha * self.pe[:, :x.size(1)]
output = output * self.x_scale + self.alpha * self.pe[:, : x.size(1)]
return self.dropout(output)
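# Quick check of the table built by extend_pe above: column 2i holds sin(pos / 10000^(2i/d)) and
# column 2i+1 holds cos(pos / 10000^(2i/d)). A tiny standalone sketch recomputing one entry,
# with example values only (not part of the commit):
import math

d, pos, two_i = 8, 3, 4  # embedding_dim, position, even channel index
div = math.exp(two_i * -(math.log(10000.0) / d))  # == 10000 ** (-two_i / d)
expected = math.sin(pos * div)  # should equal pe[0, pos, two_i] from extend_pe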
@@ -12,14 +12,16 @@ class WarmupCosineLRSchedule(torch.optim.lr_scheduler._LRScheduler):
Implements Warmup learning rate schedule until 'warmup_steps', going from 'init_lr' to 'peak_lr' for multiple optimizers.
"""

def __init__(self,
optimizer,
init_lr,
peak_lr,
end_lr,
warmup_steps=10000,
total_steps=400000,
current_step=0):
def __init__(
self,
optimizer,
init_lr,
peak_lr,
end_lr,
warmup_steps=10000,
total_steps=400000,
current_step=0,
):
self.init_lr = init_lr
self.peak_lr = peak_lr
self.end_lr = end_lr
@@ -33,10 +35,10 @@ class WarmupCosineLRSchedule(torch.optim.lr_scheduler._LRScheduler):
self._last_lr = [self.lr]

def set_lr(self, lr):
self._last_lr = [g['lr'] for g in self.optimizer.param_groups]
self._last_lr = [g["lr"] for g in self.optimizer.param_groups]
for g in self.optimizer.param_groups:
# g['lr'] = lr
g['lr'] = self.end_lr  ### locked: use linear
g["lr"] = self.end_lr  ### locked: use linear

def step(self):
if self._current_step < self.warmup_steps:
@@ -47,7 +49,8 @@ class WarmupCosineLRSchedule(torch.optim.lr_scheduler._LRScheduler):

else:
decay_ratio = (self._current_step - self.warmup_steps) / (
self.total_steps - self.warmup_steps)
self.total_steps - self.warmup_steps
)
if decay_ratio < 0.0 or decay_ratio > 1.0:
raise RuntimeError(
"Decay ratio must be in [0.0, 1.0]. Fix LR scheduler settings."
@@ -55,25 +58,19 @@ class WarmupCosineLRSchedule(torch.optim.lr_scheduler._LRScheduler):
coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
lr = self.end_lr + coeff * (self.peak_lr - self.end_lr)

self.lr=lr=self.end_lr=0.002  ### locked: use linear ### it won't behave, just lock it!
self.lr = lr = self.end_lr = 0.002  ### locked: use linear ### it won't behave, just lock it!
self.set_lr(lr)
self.lr = lr
self._current_step += 1
return self.lr



if __name__ == '__main__':
if __name__ == "__main__":
m = nn.Linear(10, 10)
opt = Adam(m.parameters(), lr=1e-4)
s = WarmupCosineLRSchedule(
opt,
1e-6,
2e-4,
1e-6,
warmup_steps=2000,
total_steps=20000,
current_step=0)
opt, 1e-6, 2e-4, 1e-6, warmup_steps=2000, total_steps=20000, current_step=0
)
lrs = []
for i in range(25000):
s.step()
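# For reference, the schedule this class implements (before the hard lock to 0.002 above) is a
# linear warmup from init_lr to peak_lr over warmup_steps, then a cosine decay from peak_lr to
# end_lr over the remaining steps. The cosine branch below mirrors the code shown in this diff;
# the warmup branch is an assumption based on the class docstring, since that hunk is not shown.
import math

def warmup_cosine_lr(step, init_lr, peak_lr, end_lr, warmup_steps, total_steps):
    if step < warmup_steps:
        return init_lr + (peak_lr - init_lr) * step / warmup_steps  # linear warmup (assumed)
    decay_ratio = (step - warmup_steps) / (total_steps - warmup_steps)
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))  # 1 -> 0 over the decay phase
    return end_lr + coeff * (peak_lr - end_lr)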
@@ -1,9 +1,16 @@
from torch.nn.functional import *
from torch.nn.functional import _mha_shape_check,_canonical_mask,_none_or_dtype,_in_projection_packed
from torch.nn.functional import (
_mha_shape_check,
_canonical_mask,
_none_or_dtype,
_in_projection_packed,
)

# import torch
# Tensor = torch.Tensor
# from typing import Callable, List, Optional, Tuple, Union


def multi_head_attention_forward_patched(
query: Tensor,
key: Tensor,
@@ -29,7 +36,8 @@ def multi_head_attention_forward_patched(
static_k: Optional[Tensor] = None,
static_v: Optional[Tensor] = None,
average_attn_weights: bool = True,
is_causal: bool = False,cache=None
is_causal: bool = False,
cache=None,
) -> Tuple[Tensor, Optional[Tensor]]:
r"""
Args:
@@ -105,7 +113,17 @@ def multi_head_attention_forward_patched(
:math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per
head of shape :math:`(num_heads, L, S)` when input is unbatched or :math:`(N, num_heads, L, S)`.
"""
tens_ops = (query, key, value, in_proj_weight, in_proj_bias, bias_k, bias_v, out_proj_weight, out_proj_bias)
tens_ops = (
query,
key,
value,
in_proj_weight,
in_proj_bias,
bias_k,
bias_v,
out_proj_weight,
out_proj_bias,
)
if has_torch_function(tens_ops):
return handle_torch_function(
multi_head_attention_forward,
@@ -134,10 +152,13 @@ def multi_head_attention_forward_patched(
v_proj_weight=v_proj_weight,
static_k=static_k,
static_v=static_v,
average_attn_weights=average_attn_weights,cache=cache
average_attn_weights=average_attn_weights,
cache=cache,
)

is_batched = _mha_shape_check(query, key, value, key_padding_mask, attn_mask, num_heads)
is_batched = _mha_shape_check(
query, key, value, key_padding_mask, attn_mask, num_heads
)

# For unbatched input, we unsqueeze at the expected batch-dim to pretend that the input
# is batched, run the computation and before returning squeeze the
@@ -159,7 +180,7 @@ def multi_head_attention_forward_patched(
mask_name="key_padding_mask",
other_type=_none_or_dtype(attn_mask),
other_name="attn_mask",
target_type=query.dtype
target_type=query.dtype,
)

if is_causal and attn_mask is None:
@@ -184,59 +205,84 @@ def multi_head_attention_forward_patched(
check_other=False,
)

if key_padding_mask is not None:
# We have the attn_mask, and use that to merge kpm into it.
# Turn off use of is_causal hint, as the merged mask is no
# longer causal.
is_causal = False

assert embed_dim == embed_dim_to_check, \
f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}"
assert (
embed_dim == embed_dim_to_check
), f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}"
if isinstance(embed_dim, torch.Tensor):
# embed_dim can be a tensor when JIT tracing
head_dim = embed_dim.div(num_heads, rounding_mode='trunc')
head_dim = embed_dim.div(num_heads, rounding_mode="trunc")
else:
head_dim = embed_dim // num_heads
assert head_dim * num_heads == embed_dim, f"embed_dim {embed_dim} not divisible by num_heads {num_heads}"
assert (
head_dim * num_heads == embed_dim
), f"embed_dim {embed_dim} not divisible by num_heads {num_heads}"
if use_separate_proj_weight:
# allow MHA to have different embedding dimensions when separate projection weights are used
assert key.shape[:2] == value.shape[:2], \
f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}"
assert (
key.shape[:2] == value.shape[:2]
), f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}"
else:
assert key.shape == value.shape, f"key shape {key.shape} does not match value shape {value.shape}"
assert (
key.shape == value.shape
), f"key shape {key.shape} does not match value shape {value.shape}"

#
# compute in-projection
#
if not use_separate_proj_weight:
assert in_proj_weight is not None, "use_separate_proj_weight is False but in_proj_weight is None"
assert (
in_proj_weight is not None
), "use_separate_proj_weight is False but in_proj_weight is None"
q, k, v = _in_projection_packed(query, key, value, in_proj_weight, in_proj_bias)
else:
assert q_proj_weight is not None, "use_separate_proj_weight is True but q_proj_weight is None"
assert k_proj_weight is not None, "use_separate_proj_weight is True but k_proj_weight is None"
assert v_proj_weight is not None, "use_separate_proj_weight is True but v_proj_weight is None"
assert (
q_proj_weight is not None
), "use_separate_proj_weight is True but q_proj_weight is None"
assert (
k_proj_weight is not None
), "use_separate_proj_weight is True but k_proj_weight is None"
assert (
v_proj_weight is not None
), "use_separate_proj_weight is True but v_proj_weight is None"
if in_proj_bias is None:
b_q = b_k = b_v = None
else:
b_q, b_k, b_v = in_proj_bias.chunk(3)
q, k, v = _in_projection(query, key, value, q_proj_weight, k_proj_weight, v_proj_weight, b_q, b_k, b_v)
if(cache!=None):
if(cache["first_infer"]==1):
cache["k"][cache["stage"]]=k
q, k, v = _in_projection(
query,
key,
value,
q_proj_weight,
k_proj_weight,
v_proj_weight,
b_q,
b_k,
b_v,
)
if cache != None:
if cache["first_infer"] == 1:
cache["k"][cache["stage"]] = k
# print(0,cache["k"].shape)
cache["v"][cache["stage"]]=v
else:  ### each of the 12 layers has to keep its own cache_kv
cache["v"][cache["stage"]] = v
else:  ### each of the 12 layers has to keep its own cache_kv
# print(1,cache["k"].shape)
cache["k"][cache["stage"]]=torch.cat([cache["k"][cache["stage"]],k],0)  ## the time axis was originally dim 1, but the projection may have transposed it, so it is now dim 0
cache["v"][cache["stage"]]=torch.cat([cache["v"][cache["stage"]],v],0)
cache["k"][cache["stage"]] = torch.cat(
[cache["k"][cache["stage"]], k], 0
)  ## the time axis was originally dim 1, but the projection may have transposed it, so it is now dim 0
cache["v"][cache["stage"]] = torch.cat([cache["v"][cache["stage"]], v], 0)
# print(2, cache["k"].shape)
src_len = cache["k"][cache["stage"]].shape[0]
k=cache["k"][cache["stage"]]
v=cache["v"][cache["stage"]]
k = cache["k"][cache["stage"]]
v = cache["v"][cache["stage"]]
# if attn_mask is not None:
# attn_mask=attn_mask[-1:,]
# print(attn_mask.shape,attn_mask)
# print(attn_mask.shape,attn_mask)
cache["stage"] = (cache["stage"] + 1) % cache["all_stage"]
# print(2333,cache)
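# The `cache` argument above is a plain dict that the caller threads through every layer's
# attention call. Its actual construction lives elsewhere in the repo; the sketch below is only
# inferred from the keys used in this function ("k", "v", "stage", "all_stage", "first_infer")
# and is an illustrative assumption, not the commit's own setup code.
num_layers = 12  # "all_stage": one k/v slot per decoder layer
cache = {
    "all_stage": num_layers,
    "stage": 0,            # which layer's slot the next attention call should use
    "first_infer": 1,      # 1 on the first decoding step (store k/v), 0 afterwards (concat)
    "k": [None] * num_layers,
    "v": [None] * num_layers,
}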
# prep attention mask
@@ -255,14 +301,20 @@ def multi_head_attention_forward_patched(
if attn_mask.dim() == 2:
correct_2d_size = (tgt_len, src_len)
if attn_mask.shape != correct_2d_size:
raise RuntimeError(f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}.")
raise RuntimeError(
f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}."
)
attn_mask = attn_mask.unsqueeze(0)
elif attn_mask.dim() == 3:
correct_3d_size = (bsz * num_heads, tgt_len, src_len)
if attn_mask.shape != correct_3d_size:
raise RuntimeError(f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}.")
raise RuntimeError(
f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}."
)
else:
raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported")
raise RuntimeError(
f"attn_mask's dimension {attn_mask.dim()} is not supported"
)

# add bias along batch dimension (currently second)
if bias_k is not None and bias_v is not None:
@@ -286,26 +338,34 @@ def multi_head_attention_forward_patched(
k = k.view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
else:
# TODO finish disentangling control flow so we don't do in-projections when statics are passed
assert static_k.size(0) == bsz * num_heads, \
f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.size(0)}"
assert static_k.size(2) == head_dim, \
f"expecting static_k.size(2) of {head_dim}, but got {static_k.size(2)}"
assert (
static_k.size(0) == bsz * num_heads
), f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.size(0)}"
assert (
static_k.size(2) == head_dim
), f"expecting static_k.size(2) of {head_dim}, but got {static_k.size(2)}"
k = static_k
if static_v is None:
v = v.view(v.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
else:
# TODO finish disentangling control flow so we don't do in-projections when statics are passed
assert static_v.size(0) == bsz * num_heads, \
f"expecting static_v.size(0) of {bsz * num_heads}, but got {static_v.size(0)}"
assert static_v.size(2) == head_dim, \
f"expecting static_v.size(2) of {head_dim}, but got {static_v.size(2)}"
assert (
static_v.size(0) == bsz * num_heads
), f"expecting static_v.size(0) of {bsz * num_heads}, but got {static_v.size(0)}"
assert (
static_v.size(2) == head_dim
), f"expecting static_v.size(2) of {head_dim}, but got {static_v.size(2)}"
v = static_v

# add zero attention along batch dimension (now first)
if add_zero_attn:
zero_attn_shape = (bsz * num_heads, 1, head_dim)
k = torch.cat([k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=1)
v = torch.cat([v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=1)
k = torch.cat(
[k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=1
)
v = torch.cat(
[v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=1
)
if attn_mask is not None:
attn_mask = pad(attn_mask, (0, 1))
if key_padding_mask is not None:
@@ -316,10 +376,15 @@ def multi_head_attention_forward_patched(

# merge key padding and attention masks
if key_padding_mask is not None:
assert key_padding_mask.shape == (bsz, src_len), \
f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}"
key_padding_mask = key_padding_mask.view(bsz, 1, 1, src_len). \
expand(-1, num_heads, -1, -1).reshape(bsz * num_heads, 1, src_len)
assert key_padding_mask.shape == (
bsz,
src_len,
), f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}"
key_padding_mask = (
key_padding_mask.view(bsz, 1, 1, src_len)
.expand(-1, num_heads, -1, -1)
.reshape(bsz * num_heads, 1, src_len)
)
if attn_mask is None:
attn_mask = key_padding_mask
else:
@@ -337,10 +402,14 @@ def multi_head_attention_forward_patched(
B, Nt, E = q.shape
q_scaled = q / math.sqrt(E)

assert not (is_causal and attn_mask is None), "FIXME: is_causal not implemented for need_weights"
assert not (
is_causal and attn_mask is None
), "FIXME: is_causal not implemented for need_weights"

if attn_mask is not None:
attn_output_weights = torch.baddbmm(attn_mask, q_scaled, k.transpose(-2, -1))
attn_output_weights = torch.baddbmm(
attn_mask, q_scaled, k.transpose(-2, -1)
)
else:
attn_output_weights = torch.bmm(q_scaled, k.transpose(-2, -1))
attn_output_weights = softmax(attn_output_weights, dim=-1)
@@ -349,7 +418,9 @@ def multi_head_attention_forward_patched(

attn_output = torch.bmm(attn_output_weights, v)

attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim)
attn_output = (
attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim)
)
attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))

@@ -377,8 +448,12 @@ def multi_head_attention_forward_patched(
k = k.view(bsz, num_heads, src_len, head_dim)
v = v.view(bsz, num_heads, src_len, head_dim)

attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)
attn_output = attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim)
attn_output = scaled_dot_product_attention(
q, k, v, attn_mask, dropout_p, is_causal
)
attn_output = (
attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim)
)

attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
@@ -61,8 +61,9 @@ class DoubleSwishFunction(torch.autograd.Function):
# floors), should be expectation-preserving.
floor = -0.043637
ceil = 1.2
d_scaled = (deriv - floor) * (255.0 / (ceil - floor)
) + torch.rand_like(deriv)
d_scaled = (deriv - floor) * (255.0 / (ceil - floor)) + torch.rand_like(
deriv
)
if __name__ == "__main__":
# for self-testing only.
assert d_scaled.min() >= 0.0
@@ -75,7 +76,7 @@ class DoubleSwishFunction(torch.autograd.Function):

@staticmethod
def backward(ctx, y_grad: Tensor) -> Tensor:
(d, ) = ctx.saved_tensors
(d,) = ctx.saved_tensors
# the same constants as used in forward pass.
floor = -0.043637
ceil = 1.2
@@ -96,11 +97,12 @@ class DoubleSwish(torch.nn.Module):
class ActivationBalancerFunction(torch.autograd.Function):
@staticmethod
def forward(
ctx,
x: Tensor,
scale_factor: Tensor,
sign_factor: Optional[Tensor],
channel_dim: int, ) -> Tensor:
ctx,
x: Tensor,
scale_factor: Tensor,
sign_factor: Optional[Tensor],
channel_dim: int,
) -> Tensor:
if channel_dim < 0:
channel_dim += x.ndim
ctx.channel_dim = channel_dim
@@ -125,16 +127,22 @@ class ActivationBalancerFunction(torch.autograd.Function):
scale_factor = scale_factor.unsqueeze(-1)
factor = scale_factor * (xgt0.to(x_grad.dtype) - 0.5)
neg_delta_grad = x_grad.abs() * factor
return (x_grad - neg_delta_grad, None, None, None, )
return (
x_grad - neg_delta_grad,
None,
None,
None,
)


def _compute_scale_factor(
x: Tensor,
channel_dim: int,
min_abs: float,
max_abs: float,
gain_factor: float,
max_factor: float, ) -> Tensor:
x: Tensor,
channel_dim: int,
min_abs: float,
max_abs: float,
gain_factor: float,
max_factor: float,
) -> Tensor:
if channel_dim < 0:
channel_dim += x.ndim
sum_dims = [d for d in range(x.ndim) if d != channel_dim]
@@ -145,23 +153,25 @@ def _compute_scale_factor(
else:
# below_threshold is 0 if x_abs_mean > min_abs, can be at most max_factor if
# x_abs_mean < min_abs.
below_threshold = (
(min_abs - x_abs_mean) * (gain_factor / min_abs)).clamp(
min=0, max=max_factor)
below_threshold = ((min_abs - x_abs_mean) * (gain_factor / min_abs)).clamp(
min=0, max=max_factor
)

above_threshold = ((x_abs_mean - max_abs) * (gain_factor / max_abs)).clamp(
min=0, max=max_factor)
min=0, max=max_factor
)

return below_threshold - above_threshold


def _compute_sign_factor(
x: Tensor,
channel_dim: int,
min_positive: float,
max_positive: float,
gain_factor: float,
max_factor: float, ) -> Tensor:
x: Tensor,
channel_dim: int,
min_positive: float,
max_positive: float,
gain_factor: float,
max_factor: float,
) -> Tensor:
if channel_dim < 0:
channel_dim += x.ndim
sum_dims = [d for d in range(x.ndim) if d != channel_dim]
@@ -171,18 +181,18 @@ def _compute_sign_factor(
else:
# 0 if proportion_positive >= min_positive, else can be
# as large as max_factor.
factor1 = ((min_positive - proportion_positive) *
(gain_factor / min_positive)).clamp_(
min=0, max=max_factor)
factor1 = (
(min_positive - proportion_positive) * (gain_factor / min_positive)
).clamp_(min=0, max=max_factor)

if max_positive == 1.0:
factor2 = 0.0
else:
# 0 if self.proportion_positive <= max_positive, else can be
# as large as -max_factor.
factor2 = ((proportion_positive - max_positive) *
(gain_factor / (1.0 - max_positive))).clamp_(
min=0, max=max_factor)
factor2 = (
(proportion_positive - max_positive) * (gain_factor / (1.0 - max_positive))
).clamp_(min=0, max=max_factor)
sign_factor = factor1 - factor2
# require min_positive != 0 or max_positive != 1:
assert not isinstance(sign_factor, float)
@@ -230,17 +240,18 @@ class ActivationBalancer(torch.nn.Module):
"""

def __init__(
self,
num_channels: int,
channel_dim: int,
min_positive: float=0.05,
max_positive: float=0.95,
max_factor: float=0.04,
sign_gain_factor: float=0.01,
scale_gain_factor: float=0.02,
min_abs: float=0.2,
max_abs: float=100.0,
min_prob: float=0.1, ):
self,
num_channels: int,
channel_dim: int,
min_positive: float = 0.05,
max_positive: float = 0.95,
max_factor: float = 0.04,
sign_gain_factor: float = 0.01,
scale_gain_factor: float = 0.02,
min_abs: float = 0.2,
max_abs: float = 100.0,
min_prob: float = 0.1,
):
super(ActivationBalancer, self).__init__()
self.num_channels = num_channels
self.channel_dim = channel_dim
@@ -260,8 +271,7 @@ class ActivationBalancer(torch.nn.Module):
self.register_buffer("count", torch.tensor(0, dtype=torch.int64))

def forward(self, x: Tensor) -> Tensor:
if (torch.jit.is_scripting() or not x.requires_grad or
torch.jit.is_tracing()):
if torch.jit.is_scripting() or not x.requires_grad or torch.jit.is_tracing():
return _no_op(x)

count = self.cpu_count
@@ -276,7 +286,7 @@ class ActivationBalancer(torch.nn.Module):

# the prob of doing some work exponentially decreases from 0.5 till it hits
# a floor at min_prob (==0.1, by default)
prob = max(self.min_prob, 0.5**(1 + (count / 4000.0)))
prob = max(self.min_prob, 0.5 ** (1 + (count / 4000.0)))

if random.random() < prob:
sign_gain_factor = 0.5
@@ -287,7 +297,8 @@ class ActivationBalancer(torch.nn.Module):
self.min_positive,
self.max_positive,
gain_factor=self.sign_gain_factor / prob,
max_factor=self.max_factor, )
max_factor=self.max_factor,
)
else:
sign_factor = None

@@ -297,23 +308,28 @@ class ActivationBalancer(torch.nn.Module):
min_abs=self.min_abs,
max_abs=self.max_abs,
gain_factor=self.scale_gain_factor / prob,
max_factor=self.max_factor, )
max_factor=self.max_factor,
)
return ActivationBalancerFunction.apply(
x,
scale_factor,
sign_factor,
self.channel_dim, )
self.channel_dim,
)
else:
return _no_op(x)


def BalancedDoubleSwish(d_model, channel_dim=-1, max_abs=10.0,
min_prob=0.25) -> nn.Sequential:
def BalancedDoubleSwish(
d_model, channel_dim=-1, max_abs=10.0, min_prob=0.25
) -> nn.Sequential:
"""
ActivationBalancer -> DoubleSwish
"""
balancer = ActivationBalancer(
d_model, channel_dim=channel_dim, max_abs=max_abs, min_prob=min_prob)
d_model, channel_dim=channel_dim, max_abs=max_abs, min_prob=min_prob
)
return nn.Sequential(
balancer,
DoubleSwish(), )
DoubleSwish(),
)
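# Illustrative usage of the helper above (not part of the commit): BalancedDoubleSwish returns an
# nn.Sequential of ActivationBalancer followed by DoubleSwish, so it drops in wherever a pointwise
# activation is expected; channel_dim=-1 balances over the feature dimension. Shapes are example
# assumptions.
act = BalancedDoubleSwish(512)
y = act(torch.randn(8, 100, 512))  # (batch, time, d_model) -> same shape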
@@ -26,26 +26,28 @@ class LayerNorm(nn.Module):
elementwise_affine: bool

def __init__(
self,
normalized_shape: _shape_t,
eps: float=1e-5,
elementwise_affine: bool=True,
device=None,
dtype=None, ) -> None:
self,
normalized_shape: _shape_t,
eps: float = 1e-5,
elementwise_affine: bool = True,
device=None,
dtype=None,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
super(LayerNorm, self).__init__()
if isinstance(normalized_shape, numbers.Integral):
# mypy error: incompatible types in assignment
normalized_shape = (normalized_shape, )  # type: ignore[assignment]
self.normalized_shape = tuple(
normalized_shape)  # type: ignore[arg-type]
normalized_shape = (normalized_shape,)  # type: ignore[assignment]
self.normalized_shape = tuple(normalized_shape)  # type: ignore[arg-type]
self.eps = eps
self.elementwise_affine = elementwise_affine
if self.elementwise_affine:
self.weight = nn.Parameter(
torch.empty(self.normalized_shape, **factory_kwargs))
torch.empty(self.normalized_shape, **factory_kwargs)
)
self.bias = nn.Parameter(
torch.empty(self.normalized_shape, **factory_kwargs))
torch.empty(self.normalized_shape, **factory_kwargs)
)
else:
self.register_parameter("weight", None)
self.register_parameter("bias", None)
@@ -57,36 +59,43 @@ class LayerNorm(nn.Module):
nn.init.ones_(self.weight)
nn.init.zeros_(self.bias)

def forward(self, input: Tensor, embedding: Any=None) -> Tensor:
def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
if isinstance(input, tuple):
input, embedding = input
return (F.layer_norm(
input,
self.normalized_shape,
self.weight,
self.bias,
self.eps, ), embedding, )
return (
F.layer_norm(
input,
self.normalized_shape,
self.weight,
self.bias,
self.eps,
),
embedding,
)

assert embedding is None
return F.layer_norm(input, self.normalized_shape, self.weight,
self.bias, self.eps)
return F.layer_norm(
input, self.normalized_shape, self.weight, self.bias, self.eps
)

def extra_repr(self) -> str:
return (
"{normalized_shape}, eps={eps}, "
"elementwise_affine={elementwise_affine}".format(**self.__dict__))
"elementwise_affine={elementwise_affine}".format(**self.__dict__)
)


class IdentityNorm(nn.Module):
def __init__(
self,
d_model: int,
eps: float=1e-5,
device=None,
dtype=None, ) -> None:
self,
d_model: int,
eps: float = 1e-5,
device=None,
dtype=None,
) -> None:
super(IdentityNorm, self).__init__()

def forward(self, input: Tensor, embedding: Any=None) -> Tensor:
def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
if isinstance(input, tuple):
return input

@@ -121,11 +130,13 @@ class TransformerEncoder(nn.Module):
self.norm = norm

def forward(
self,
src: Tensor,
mask: Optional[Tensor]=None,
src_key_padding_mask: Optional[Tensor]=None,
return_layer_states: bool=False,cache=None ) -> Tensor:
self,
src: Tensor,
mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
return_layer_states: bool = False,
cache=None,
) -> Tensor:
r"""Pass the input through the encoder layers in turn.

Args:
@@ -144,7 +155,9 @@ class TransformerEncoder(nn.Module):
output = mod(
output,
src_mask=mask,
src_key_padding_mask=src_key_padding_mask, cache=cache)
src_key_padding_mask=src_key_padding_mask,
cache=cache,
)
layer_states.append(output[0])

if self.norm is not None:
@@ -154,9 +167,12 @@ class TransformerEncoder(nn.Module):

output = src
for mod in self.layers:
output = mod(output,
src_mask=mask,
src_key_padding_mask=src_key_padding_mask, cache=cache)
output = mod(
output,
src_mask=mask,
src_key_padding_mask=src_key_padding_mask,
cache=cache,
)

if self.norm is not None:
output = self.norm(output)
@@ -168,43 +184,47 @@ class TransformerEncoderLayer(nn.Module):
__constants__ = ["batch_first", "norm_first"]

def __init__(
self,
d_model: int,
nhead: int,
dim_feedforward: int=2048,
dropout: float=0.1,
activation: Union[str, Callable[[Tensor], Tensor]]=F.relu,
batch_first: bool=False,
norm_first: bool=False,
device=None,
dtype=None,
linear1_self_attention_cls: nn.Module=nn.Linear,
linear2_self_attention_cls: nn.Module=nn.Linear,
linear1_feedforward_cls: nn.Module=nn.Linear,
linear2_feedforward_cls: nn.Module=nn.Linear,
layer_norm_cls: nn.Module=LayerNorm,
layer_norm_eps: float=1e-5,
adaptive_layer_norm=False, ) -> None:
self,
d_model: int,
nhead: int,
dim_feedforward: int = 2048,
dropout: float = 0.1,
activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
batch_first: bool = False,
norm_first: bool = False,
device=None,
dtype=None,
linear1_self_attention_cls: nn.Module = nn.Linear,
linear2_self_attention_cls: nn.Module = nn.Linear,
linear1_feedforward_cls: nn.Module = nn.Linear,
linear2_feedforward_cls: nn.Module = nn.Linear,
layer_norm_cls: nn.Module = LayerNorm,
layer_norm_eps: float = 1e-5,
adaptive_layer_norm=False,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
super(TransformerEncoderLayer, self).__init__()
# print(233333333333,d_model,nhead)
# import os
# os._exit(2333333)
self.self_attn = MultiheadAttention(
d_model,#512 16
d_model,  # 512 16
nhead,
dropout=dropout,
batch_first=batch_first,
linear1_cls=linear1_self_attention_cls,
linear2_cls=linear2_self_attention_cls,
**factory_kwargs, )
**factory_kwargs,
)

# Implementation of Feedforward model
self.linear1 = linear1_feedforward_cls(d_model, dim_feedforward,
**factory_kwargs)
self.linear1 = linear1_feedforward_cls(
d_model, dim_feedforward, **factory_kwargs
)
self.dropout = nn.Dropout(dropout)
self.linear2 = linear2_feedforward_cls(dim_feedforward, d_model,
**factory_kwargs)
self.linear2 = linear2_feedforward_cls(
dim_feedforward, d_model, **factory_kwargs
)

self.norm_first = norm_first
self.dropout1 = nn.Dropout(dropout)
@@ -230,11 +250,9 @@ class TransformerEncoderLayer(nn.Module):

norm1 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs)
if layer_norm_cls == IdentityNorm:
norm2 = BalancedBasicNorm(
d_model, eps=layer_norm_eps, **factory_kwargs)
norm2 = BalancedBasicNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
else:
norm2 = layer_norm_cls(
d_model, eps=layer_norm_eps, **factory_kwargs)
norm2 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs)

if adaptive_layer_norm:
self.norm1 = AdaptiveLayerNorm(d_model, norm1)
@@ -249,10 +267,12 @@ class TransformerEncoderLayer(nn.Module):
self.activation = F.relu

def forward(
self,
src: Tensor,
src_mask: Optional[Tensor]=None,
src_key_padding_mask: Optional[Tensor]=None,cache=None ) -> Tensor:
self,
src: Tensor,
src_mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
cache=None,
) -> Tensor:
r"""Pass the input through the encoder layer.

Args:
@@ -272,7 +292,8 @@ class TransformerEncoderLayer(nn.Module):
if src_key_padding_mask is not None:
_skpm_dtype = src_key_padding_mask.dtype
if _skpm_dtype != torch.bool and not torch.is_floating_point(
src_key_padding_mask):
src_key_padding_mask
):
raise AssertionError(
"only bool and floating types of key_padding_mask are supported"
)
@@ -281,12 +302,15 @@ class TransformerEncoderLayer(nn.Module):
x = x + self._sa_block(
self.norm1(x, stage_embedding),
src_mask,
src_key_padding_mask,cache=cache )
src_key_padding_mask,
cache=cache,
)
x = x + self._ff_block(self.norm2(x, stage_embedding))
else:
x = self.norm1(
x + self._sa_block(x, src_mask, src_key_padding_mask,cache=cache),
stage_embedding, )
x + self._sa_block(x, src_mask, src_key_padding_mask, cache=cache),
stage_embedding,
)
x = self.norm2(x + self._ff_block(x), stage_embedding)

if is_src_tuple:
@@ -295,12 +319,14 @@ class TransformerEncoderLayer(nn.Module):

# self-attention block
def _sa_block(
self,
x: Tensor,
attn_mask: Optional[Tensor],
key_padding_mask: Optional[Tensor],cache=None ) -> Tensor:
self,
x: Tensor,
attn_mask: Optional[Tensor],
key_padding_mask: Optional[Tensor],
cache=None,
) -> Tensor:
# print(x.shape,attn_mask.shape,key_padding_mask)
#torch.Size([1, 188, 512]) torch.Size([188, 188]) None
# torch.Size([1, 188, 512]) torch.Size([188, 188]) None
# import os
# os._exit(23333)
x = self.self_attn(
@@ -309,7 +335,9 @@ class TransformerEncoderLayer(nn.Module):
x,
attn_mask=attn_mask,
key_padding_mask=key_padding_mask,
need_weights=False,cache=cache )[0]
need_weights=False,
cache=cache,
)[0]
return self.dropout1(x)

# feed forward block
@@ -328,20 +356,23 @@ class AdaptiveLayerNorm(nn.Module):
self.d_model = d_model
self.eps = self.norm.eps

def forward(self, input: Tensor, embedding: Tensor=None) -> Tensor:
def forward(self, input: Tensor, embedding: Tensor = None) -> Tensor:
if isinstance(input, tuple):
input, embedding = input
weight, bias = torch.split(
self.project_layer(embedding),
split_size_or_sections=self.d_model,
dim=-1, )
dim=-1,
)
return (weight * self.norm(input) + bias, embedding)

weight, bias = torch.split(
self.project_layer(embedding),
split_size_or_sections=self.d_model,
dim=-1, )
dim=-1,
)
return weight * self.norm(input) + bias


def _get_clones(module, N):
return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
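# Minimal wiring sketch for the classes in this file. The TransformerEncoder(encoder_layer,
# num_layers) constructor and the sizes below are assumptions for illustration; only the forward
# signatures (mask, src_key_padding_mask, cache) come from the diff above.
layer = TransformerEncoderLayer(d_model=512, nhead=16, batch_first=True)
encoder = TransformerEncoder(layer, num_layers=12)
out = encoder(torch.randn(1, 188, 512), mask=None, src_key_padding_mask=None, cache=None)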