more code refactor

Blaise
2024-01-16 17:14:18 +01:00
parent 0d92575115
commit 0d3d47f3c3
44 changed files with 4516 additions and 2623 deletions


@@ -26,26 +26,28 @@ class LayerNorm(nn.Module):
     elementwise_affine: bool
     def __init__(
-            self,
-            normalized_shape: _shape_t,
-            eps: float=1e-5,
-            elementwise_affine: bool=True,
-            device=None,
-            dtype=None, ) -> None:
+        self,
+        normalized_shape: _shape_t,
+        eps: float = 1e-5,
+        elementwise_affine: bool = True,
+        device=None,
+        dtype=None,
+    ) -> None:
         factory_kwargs = {"device": device, "dtype": dtype}
         super(LayerNorm, self).__init__()
         if isinstance(normalized_shape, numbers.Integral):
             # mypy error: incompatible types in assignment
-            normalized_shape = (normalized_shape, ) # type: ignore[assignment]
-        self.normalized_shape = tuple(
-            normalized_shape) # type: ignore[arg-type]
+            normalized_shape = (normalized_shape,) # type: ignore[assignment]
+        self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type]
         self.eps = eps
         self.elementwise_affine = elementwise_affine
         if self.elementwise_affine:
             self.weight = nn.Parameter(
-                torch.empty(self.normalized_shape, **factory_kwargs))
+                torch.empty(self.normalized_shape, **factory_kwargs)
+            )
             self.bias = nn.Parameter(
-                torch.empty(self.normalized_shape, **factory_kwargs))
+                torch.empty(self.normalized_shape, **factory_kwargs)
+            )
         else:
             self.register_parameter("weight", None)
             self.register_parameter("bias", None)
@@ -57,36 +59,43 @@ class LayerNorm(nn.Module):
             nn.init.ones_(self.weight)
             nn.init.zeros_(self.bias)
-    def forward(self, input: Tensor, embedding: Any=None) -> Tensor:
+    def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
         if isinstance(input, tuple):
             input, embedding = input
-            return (F.layer_norm(
-                input,
-                self.normalized_shape,
-                self.weight,
-                self.bias,
-                self.eps, ), embedding, )
+            return (
+                F.layer_norm(
+                    input,
+                    self.normalized_shape,
+                    self.weight,
+                    self.bias,
+                    self.eps,
+                ),
+                embedding,
+            )
         assert embedding is None
-        return F.layer_norm(input, self.normalized_shape, self.weight,
-                            self.bias, self.eps)
+        return F.layer_norm(
+            input, self.normalized_shape, self.weight, self.bias, self.eps
+        )
     def extra_repr(self) -> str:
         return (
             "{normalized_shape}, eps={eps}, "
-            "elementwise_affine={elementwise_affine}".format(**self.__dict__))
+            "elementwise_affine={elementwise_affine}".format(**self.__dict__)
+        )
 class IdentityNorm(nn.Module):
     def __init__(
-            self,
-            d_model: int,
-            eps: float=1e-5,
-            device=None,
-            dtype=None, ) -> None:
+        self,
+        d_model: int,
+        eps: float = 1e-5,
+        device=None,
+        dtype=None,
+    ) -> None:
         super(IdentityNorm, self).__init__()
-    def forward(self, input: Tensor, embedding: Any=None) -> Tensor:
+    def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
         if isinstance(input, tuple):
             return input
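
For reference, a minimal usage sketch of the tuple-aware forward kept by this refactor. The import path is assumed (it is not shown in the diff) and the shapes are illustrative: a tuple input carries an embedding through the norm untouched, a plain tensor behaves like an ordinary LayerNorm.

import torch

from module.transformer import LayerNorm  # hypothetical import path

norm = LayerNorm(512)
x = torch.randn(1, 188, 512)   # (batch, seq, d_model)
emb = torch.randn(1, 512)      # stage embedding that rides along unchanged

y, emb_out = norm((x, emb))    # tuple in  -> (layer-normed x, embedding) out
y_plain = norm(x)              # tensor in -> plain layer-normed tensor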
@@ -121,11 +130,13 @@ class TransformerEncoder(nn.Module):
         self.norm = norm
     def forward(
-            self,
-            src: Tensor,
-            mask: Optional[Tensor]=None,
-            src_key_padding_mask: Optional[Tensor]=None,
-            return_layer_states: bool=False,cache=None ) -> Tensor:
+        self,
+        src: Tensor,
+        mask: Optional[Tensor] = None,
+        src_key_padding_mask: Optional[Tensor] = None,
+        return_layer_states: bool = False,
+        cache=None,
+    ) -> Tensor:
         r"""Pass the input through the encoder layers in turn.
         Args:
@@ -144,7 +155,9 @@ class TransformerEncoder(nn.Module):
                 output = mod(
                     output,
                     src_mask=mask,
-                    src_key_padding_mask=src_key_padding_mask, cache=cache)
+                    src_key_padding_mask=src_key_padding_mask,
+                    cache=cache,
+                )
                 layer_states.append(output[0])
             if self.norm is not None:
@@ -154,9 +167,12 @@ class TransformerEncoder(nn.Module):
         output = src
         for mod in self.layers:
-            output = mod(output,
-                         src_mask=mask,
-                         src_key_padding_mask=src_key_padding_mask, cache=cache)
+            output = mod(
+                output,
+                src_mask=mask,
+                src_key_padding_mask=src_key_padding_mask,
+                cache=cache,
+            )
         if self.norm is not None:
             output = self.norm(output)
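
A hedged usage sketch of the refactored TransformerEncoder.forward: the new cache argument is threaded, unchanged, into every layer call. The constructor is assumed to mirror torch.nn.TransformerEncoder(encoder_layer, num_layers, norm), which this hunk does not show; d_model=512 and nhead=16 come from the "# 512 16" comment further down, num_layers=12 is purely illustrative.

import torch

from module.transformer import (  # hypothetical import path
    LayerNorm,
    TransformerEncoder,
    TransformerEncoderLayer,
)

layer = TransformerEncoderLayer(d_model=512, nhead=16, batch_first=True, norm_first=True)
encoder = TransformerEncoder(layer, num_layers=12, norm=LayerNorm(512))

src = torch.randn(1, 188, 512)  # (batch, seq, d_model) with batch_first=True
out = encoder(
    src,
    mask=None,
    src_key_padding_mask=None,
    return_layer_states=False,
    cache=None,  # opaque per-layer attention cache, passed straight through
)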
@@ -168,43 +184,47 @@ class TransformerEncoderLayer(nn.Module):
     __constants__ = ["batch_first", "norm_first"]
     def __init__(
-            self,
-            d_model: int,
-            nhead: int,
-            dim_feedforward: int=2048,
-            dropout: float=0.1,
-            activation: Union[str, Callable[[Tensor], Tensor]]=F.relu,
-            batch_first: bool=False,
-            norm_first: bool=False,
-            device=None,
-            dtype=None,
-            linear1_self_attention_cls: nn.Module=nn.Linear,
-            linear2_self_attention_cls: nn.Module=nn.Linear,
-            linear1_feedforward_cls: nn.Module=nn.Linear,
-            linear2_feedforward_cls: nn.Module=nn.Linear,
-            layer_norm_cls: nn.Module=LayerNorm,
-            layer_norm_eps: float=1e-5,
-            adaptive_layer_norm=False, ) -> None:
+        self,
+        d_model: int,
+        nhead: int,
+        dim_feedforward: int = 2048,
+        dropout: float = 0.1,
+        activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
+        batch_first: bool = False,
+        norm_first: bool = False,
+        device=None,
+        dtype=None,
+        linear1_self_attention_cls: nn.Module = nn.Linear,
+        linear2_self_attention_cls: nn.Module = nn.Linear,
+        linear1_feedforward_cls: nn.Module = nn.Linear,
+        linear2_feedforward_cls: nn.Module = nn.Linear,
+        layer_norm_cls: nn.Module = LayerNorm,
+        layer_norm_eps: float = 1e-5,
+        adaptive_layer_norm=False,
+    ) -> None:
         factory_kwargs = {"device": device, "dtype": dtype}
         super(TransformerEncoderLayer, self).__init__()
         # print(233333333333,d_model,nhead)
         # import os
         # os._exit(2333333)
         self.self_attn = MultiheadAttention(
-            d_model,#512 16
+            d_model, # 512 16
             nhead,
             dropout=dropout,
             batch_first=batch_first,
             linear1_cls=linear1_self_attention_cls,
             linear2_cls=linear2_self_attention_cls,
-            **factory_kwargs, )
+            **factory_kwargs,
+        )
         # Implementation of Feedforward model
-        self.linear1 = linear1_feedforward_cls(d_model, dim_feedforward,
-                                               **factory_kwargs)
+        self.linear1 = linear1_feedforward_cls(
+            d_model, dim_feedforward, **factory_kwargs
+        )
         self.dropout = nn.Dropout(dropout)
-        self.linear2 = linear2_feedforward_cls(dim_feedforward, d_model,
-                                               **factory_kwargs)
+        self.linear2 = linear2_feedforward_cls(
+            dim_feedforward, d_model, **factory_kwargs
+        )
         self.norm_first = norm_first
         self.dropout1 = nn.Dropout(dropout)
@@ -230,11 +250,9 @@ class TransformerEncoderLayer(nn.Module):
         norm1 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs)
         if layer_norm_cls == IdentityNorm:
-            norm2 = BalancedBasicNorm(
-                d_model, eps=layer_norm_eps, **factory_kwargs)
+            norm2 = BalancedBasicNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
         else:
-            norm2 = layer_norm_cls(
-                d_model, eps=layer_norm_eps, **factory_kwargs)
+            norm2 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs)
         if adaptive_layer_norm:
             self.norm1 = AdaptiveLayerNorm(d_model, norm1)
@@ -249,10 +267,12 @@ class TransformerEncoderLayer(nn.Module):
             self.activation = F.relu
     def forward(
-            self,
-            src: Tensor,
-            src_mask: Optional[Tensor]=None,
-            src_key_padding_mask: Optional[Tensor]=None,cache=None ) -> Tensor:
+        self,
+        src: Tensor,
+        src_mask: Optional[Tensor] = None,
+        src_key_padding_mask: Optional[Tensor] = None,
+        cache=None,
+    ) -> Tensor:
         r"""Pass the input through the encoder layer.
         Args:
@@ -272,7 +292,8 @@ class TransformerEncoderLayer(nn.Module):
         if src_key_padding_mask is not None:
             _skpm_dtype = src_key_padding_mask.dtype
             if _skpm_dtype != torch.bool and not torch.is_floating_point(
-                    src_key_padding_mask):
+                src_key_padding_mask
+            ):
                 raise AssertionError(
                     "only bool and floating types of key_padding_mask are supported"
                 )
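
The check above admits only boolean or floating key padding masks; anything else raises. A small illustration, assuming the custom MultiheadAttention keeps the stock PyTorch convention (a bool True marks a padded position, a float mask is added to the attention logits):

import torch

bool_mask = torch.zeros(1, 188, dtype=torch.bool)   # passes: True would mark padding
float_mask = torch.zeros(1, 188)                    # passes: additive mask, 0.0 keeps a position
int_mask = torch.zeros(1, 188, dtype=torch.int64)   # would trip the AssertionError above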
@@ -281,12 +302,15 @@ class TransformerEncoderLayer(nn.Module):
             x = x + self._sa_block(
                 self.norm1(x, stage_embedding),
                 src_mask,
-                src_key_padding_mask,cache=cache )
+                src_key_padding_mask,
+                cache=cache,
+            )
             x = x + self._ff_block(self.norm2(x, stage_embedding))
         else:
             x = self.norm1(
-                x + self._sa_block(x, src_mask, src_key_padding_mask,cache=cache),
-                stage_embedding, )
+                x + self._sa_block(x, src_mask, src_key_padding_mask, cache=cache),
+                stage_embedding,
+            )
             x = self.norm2(x + self._ff_block(x), stage_embedding)
         if is_src_tuple:
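
A compact restatement of the two residual orderings in the branch above. The helper names are illustrative only: sa and ff stand for self._sa_block and self._ff_block (mask and cache arguments omitted), e for stage_embedding.

from torch import Tensor

def pre_ln_step(x: Tensor, e, norm1, norm2, sa, ff) -> Tensor:
    # norm_first=True: normalize first, then add the residual.
    x = x + sa(norm1(x, e))
    return x + ff(norm2(x, e))

def post_ln_step(x: Tensor, e, norm1, norm2, sa, ff) -> Tensor:
    # norm_first=False: add the residual first, then normalize.
    x = norm1(x + sa(x), e)
    return norm2(x + ff(x), e)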
@@ -295,12 +319,14 @@ class TransformerEncoderLayer(nn.Module):
     # self-attention block
     def _sa_block(
-            self,
-            x: Tensor,
-            attn_mask: Optional[Tensor],
-            key_padding_mask: Optional[Tensor],cache=None ) -> Tensor:
+        self,
+        x: Tensor,
+        attn_mask: Optional[Tensor],
+        key_padding_mask: Optional[Tensor],
+        cache=None,
+    ) -> Tensor:
         # print(x.shape,attn_mask.shape,key_padding_mask)
-        #torch.Size([1, 188, 512]) torch.Size([188, 188]) None
+        # torch.Size([1, 188, 512]) torch.Size([188, 188]) None
         # import os
         # os._exit(23333)
         x = self.self_attn(
@@ -309,7 +335,9 @@ class TransformerEncoderLayer(nn.Module):
             x,
             attn_mask=attn_mask,
             key_padding_mask=key_padding_mask,
-            need_weights=False,cache=cache )[0]
+            need_weights=False,
+            cache=cache,
+        )[0]
         return self.dropout1(x)
     # feed forward block
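
The body of _ff_block lies outside this hunk. Based on the linear1/dropout/linear2 members created in __init__ above, it is presumably the standard position-wise MLP; a hedged sketch (a trailing dropout mirroring dropout1 may also be applied, not shown here):

from torch import Tensor, nn

def ff_block_sketch(x: Tensor, linear1: nn.Module, linear2: nn.Module,
                    dropout: nn.Module, activation) -> Tensor:
    # Presumed layout: d_model -> dim_feedforward -> activation -> dropout -> d_model.
    return linear2(dropout(activation(linear1(x))))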
@@ -328,20 +356,23 @@ class AdaptiveLayerNorm(nn.Module):
         self.d_model = d_model
         self.eps = self.norm.eps
-    def forward(self, input: Tensor, embedding: Tensor=None) -> Tensor:
+    def forward(self, input: Tensor, embedding: Tensor = None) -> Tensor:
         if isinstance(input, tuple):
             input, embedding = input
             weight, bias = torch.split(
                 self.project_layer(embedding),
                 split_size_or_sections=self.d_model,
-                dim=-1, )
+                dim=-1,
+            )
             return (weight * self.norm(input) + bias, embedding)
         weight, bias = torch.split(
             self.project_layer(embedding),
             split_size_or_sections=self.d_model,
-            dim=-1, )
+            dim=-1,
+        )
         return weight * self.norm(input) + bias
 def _get_clones(module, N):
     return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
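
Finally, a self-contained sketch of what AdaptiveLayerNorm.forward above computes: the embedding is projected to per-channel (scale, shift) pairs that modulate a plain norm. The 2 * d_model width of project_layer is inferred from the torch.split call and is not shown in this hunk; shapes are illustrative.

import torch
from torch import Tensor, nn

class AdaLNSketch(nn.Module):
    def __init__(self, d_model: int, norm: nn.Module) -> None:
        super().__init__()
        self.project_layer = nn.Linear(d_model, 2 * d_model)  # inferred width
        self.norm = norm
        self.d_model = d_model

    def forward(self, x: Tensor, embedding: Tensor) -> Tensor:
        # Split the projected embedding into a per-channel scale and shift.
        weight, bias = torch.split(self.project_layer(embedding), self.d_model, dim=-1)
        return weight * self.norm(x) + bias

ada = AdaLNSketch(512, nn.LayerNorm(512))
x = torch.randn(1, 188, 512)
emb = torch.randn(1, 1, 512)  # must broadcast against the normalized input
y = ada(x, emb)               # (1, 188, 512)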