Add files via upload

This commit is contained in:
RVC-Boss
2024-01-16 17:38:48 +08:00
committed by GitHub
parent 143d32f621
commit 41ca6028d6
65 changed files with 139856 additions and 0 deletions

View File

View File

@@ -0,0 +1,397 @@
# modified from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/activation.py
from typing import Optional
from typing import Tuple
import torch
from torch import Tensor
from torch.nn import Linear
from torch.nn import Module
from torch.nn.init import constant_
from torch.nn.init import xavier_normal_
from torch.nn.init import xavier_uniform_
from torch.nn.modules.linear import NonDynamicallyQuantizableLinear
from torch.nn.parameter import Parameter
from torch.nn import functional as F
from AR.modules.patched_mha_with_cache import multi_head_attention_forward_patched
F.multi_head_attention_forward=multi_head_attention_forward_patched
class MultiheadAttention(Module):
r"""Allows the model to jointly attend to information
from different representation subspaces as described in the paper:
`Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.
Multi-Head Attention is defined as:
.. math::
\text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O
where :math:`head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`.
``forward()`` will use a special optimized implementation if all of the following
conditions are met:
- self attention is being computed (i.e., ``query``, ``key``, and ``value`` are the same tensor. This
restriction will be loosened in the future.)
- Either autograd is disabled (using ``torch.inference_mode`` or ``torch.no_grad``) or no tensor argument ``requires_grad``
- training is disabled (using ``.eval()``)
- dropout is 0
- ``add_bias_kv`` is ``False``
- ``add_zero_attn`` is ``False``
- ``batch_first`` is ``True`` and the input is batched
- ``kdim`` and ``vdim`` are equal to ``embed_dim``
- at most one of ``key_padding_mask`` or ``attn_mask`` is passed
- if a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ is passed, neither ``key_padding_mask``
nor ``attn_mask`` is passed
If the optimized implementation is in use, a
`NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ can be passed for
``query``/``key``/``value`` to represent padding more efficiently than using a
padding mask. In this case, a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_
will be returned, and an additional speedup proportional to the fraction of the input
that is padding can be expected.
Args:
embed_dim: Total dimension of the model.
num_heads: Number of parallel attention heads. Note that ``embed_dim`` will be split
across ``num_heads`` (i.e. each head will have dimension ``embed_dim // num_heads``).
dropout: Dropout probability on ``attn_output_weights``. Default: ``0.0`` (no dropout).
bias: If specified, adds bias to input / output projection layers. Default: ``True``.
add_bias_kv: If specified, adds bias to the key and value sequences at dim=0. Default: ``False``.
add_zero_attn: If specified, adds a new batch of zeros to the key and value sequences at dim=1.
Default: ``False``.
kdim: Total number of features for keys. Default: ``None`` (uses ``kdim=embed_dim``).
vdim: Total number of features for values. Default: ``None`` (uses ``vdim=embed_dim``).
batch_first: If ``True``, then the input and output tensors are provided
as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
Examples::
>>> # xdoctest: +SKIP
>>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
>>> attn_output, attn_output_weights = multihead_attn(query, key, value)
"""
__constants__ = ["batch_first"]
bias_k: Optional[torch.Tensor]
bias_v: Optional[torch.Tensor]
def __init__(
self,
embed_dim,
num_heads,
dropout=0.0,
bias=True,
add_bias_kv=False,
add_zero_attn=False,
kdim=None,
vdim=None,
batch_first=False,
linear1_cls=Linear,
linear2_cls=Linear,
device=None,
dtype=None, ) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
super(MultiheadAttention, self).__init__()
self.embed_dim = embed_dim
self.kdim = kdim if kdim is not None else embed_dim
self.vdim = vdim if vdim is not None else embed_dim
self._qkv_same_embed_dim = (self.kdim == embed_dim and
self.vdim == embed_dim)
self.num_heads = num_heads
self.dropout = dropout
self.batch_first = batch_first
self.head_dim = embed_dim // num_heads
assert (self.head_dim * num_heads == self.embed_dim
), "embed_dim must be divisible by num_heads"
if add_bias_kv:
self.bias_k = Parameter(
torch.empty((1, 1, embed_dim), **factory_kwargs))
self.bias_v = Parameter(
torch.empty((1, 1, embed_dim), **factory_kwargs))
else:
self.bias_k = self.bias_v = None
if linear1_cls == Linear:
if not self._qkv_same_embed_dim:
self.q_proj_weight = Parameter(
torch.empty((embed_dim, embed_dim), **factory_kwargs))
self.k_proj_weight = Parameter(
torch.empty((embed_dim, self.kdim), **factory_kwargs))
self.v_proj_weight = Parameter(
torch.empty((embed_dim, self.vdim), **factory_kwargs))
self.register_parameter("in_proj_weight", None)
else:
self.in_proj_weight = Parameter(
torch.empty((3 * embed_dim, embed_dim), **factory_kwargs))
self.register_parameter("q_proj_weight", None)
self.register_parameter("k_proj_weight", None)
self.register_parameter("v_proj_weight", None)
if bias:
self.in_proj_bias = Parameter(
torch.empty(3 * embed_dim, **factory_kwargs))
else:
self.register_parameter("in_proj_bias", None)
self.out_proj = NonDynamicallyQuantizableLinear(
embed_dim, embed_dim, bias=bias, **factory_kwargs)
self._reset_parameters()
else:
if not self._qkv_same_embed_dim:
raise NotImplementedError
else:
self.in_proj_linear = linear1_cls(
embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs)
self.in_proj_weight = self.in_proj_linear.weight
self.register_parameter("q_proj_weight", None)
self.register_parameter("k_proj_weight", None)
self.register_parameter("v_proj_weight", None)
if bias:
self.in_proj_bias = self.in_proj_linear.bias
else:
self.register_parameter("in_proj_bias", None)
self.out_proj = linear2_cls(
embed_dim, embed_dim, bias=bias, **factory_kwargs)
if self.bias_k is not None:
xavier_normal_(self.bias_k)
if self.bias_v is not None:
xavier_normal_(self.bias_v)
self.add_zero_attn = add_zero_attn
def _reset_parameters(self):
if self._qkv_same_embed_dim:
xavier_uniform_(self.in_proj_weight)
else:
xavier_uniform_(self.q_proj_weight)
xavier_uniform_(self.k_proj_weight)
xavier_uniform_(self.v_proj_weight)
if self.in_proj_bias is not None:
constant_(self.in_proj_bias, 0.0)
constant_(self.out_proj.bias, 0.0)
if self.bias_k is not None:
xavier_normal_(self.bias_k)
if self.bias_v is not None:
xavier_normal_(self.bias_v)
def __setstate__(self, state):
# Support loading old MultiheadAttention checkpoints generated by v1.1.0
if "_qkv_same_embed_dim" not in state:
state["_qkv_same_embed_dim"] = True
super(MultiheadAttention, self).__setstate__(state)
def forward(
self,
query: Tensor,
key: Tensor,
value: Tensor,
key_padding_mask: Optional[Tensor]=None,
need_weights: bool=True,
attn_mask: Optional[Tensor]=None,
average_attn_weights: bool=True,cache=None
) -> Tuple[Tensor, Optional[Tensor]]:
r"""
Args:
query: Query embeddings of shape :math:`(L, E_q)` for unbatched input, :math:`(L, N, E_q)` when ``batch_first=False``
or :math:`(N, L, E_q)` when ``batch_first=True``, where :math:`L` is the target sequence length,
:math:`N` is the batch size, and :math:`E_q` is the query embedding dimension ``embed_dim``.
Queries are compared against key-value pairs to produce the output.
See "Attention Is All You Need" for more details.
key: Key embeddings of shape :math:`(S, E_k)` for unbatched input, :math:`(S, N, E_k)` when ``batch_first=False``
or :math:`(N, S, E_k)` when ``batch_first=True``, where :math:`S` is the source sequence length,
:math:`N` is the batch size, and :math:`E_k` is the key embedding dimension ``kdim``.
See "Attention Is All You Need" for more details.
value: Value embeddings of shape :math:`(S, E_v)` for unbatched input, :math:`(S, N, E_v)` when
``batch_first=False`` or :math:`(N, S, E_v)` when ``batch_first=True``, where :math:`S` is the source
sequence length, :math:`N` is the batch size, and :math:`E_v` is the value embedding dimension ``vdim``.
See "Attention Is All You Need" for more details.
key_padding_mask: If specified, a mask of shape :math:`(N, S)` indicating which elements within ``key``
to ignore for the purpose of attention (i.e. treat as "padding"). For unbatched `query`, shape should be :math:`(S)`.
Binary and byte masks are supported.
For a binary mask, a ``True`` value indicates that the corresponding ``key`` value will be ignored for
the purpose of attention. For a float mask, it will be directly added to the corresponding ``key`` value.
need_weights: If specified, returns ``attn_output_weights`` in addition to ``attn_outputs``.
Default: ``True``.
attn_mask: If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape
:math:`(L, S)` or :math:`(N\cdot\text{num\_heads}, L, S)`, where :math:`N` is the batch size,
:math:`L` is the target sequence length, and :math:`S` is the source sequence length. A 2D mask will be
broadcasted across the batch while a 3D mask allows for a different mask for each entry in the batch.
Binary, byte, and float masks are supported. For a binary mask, a ``True`` value indicates that the
corresponding position is not allowed to attend. For a byte mask, a non-zero value indicates that the
corresponding position is not allowed to attend. For a float mask, the mask values will be added to
the attention weight.
average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across
heads. Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an
effect when ``need_weights=True``. Default: ``True`` (i.e. average weights across heads)
Outputs:
- **attn_output** - Attention outputs of shape :math:`(L, E)` when input is unbatched,
:math:`(L, N, E)` when ``batch_first=False`` or :math:`(N, L, E)` when ``batch_first=True``,
where :math:`L` is the target sequence length, :math:`N` is the batch size, and :math:`E` is the
embedding dimension ``embed_dim``.
- **attn_output_weights** - Only returned when ``need_weights=True``. If ``average_attn_weights=True``,
returns attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or
:math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and
:math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per
head of shape :math:`(\text{num\_heads}, L, S)` when input is unbatched or :math:`(N, \text{num\_heads}, L, S)`.
.. note::
`batch_first` argument is ignored for unbatched inputs.
"""
is_batched = query.dim() == 3
if key_padding_mask is not None:
_kpm_dtype = key_padding_mask.dtype
if _kpm_dtype != torch.bool and not torch.is_floating_point(
key_padding_mask):
raise AssertionError(
"only bool and floating types of key_padding_mask are supported"
)
why_not_fast_path = ""
if not is_batched:
why_not_fast_path = f"input not batched; expected query.dim() of 3 but got {query.dim()}"
elif query is not key or key is not value:
# When lifting this restriction, don't forget to either
# enforce that the dtypes all match or test cases where
# they don't!
why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)"
elif (self.in_proj_bias is not None and
query.dtype != self.in_proj_bias.dtype):
why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match"
elif (self.in_proj_weight is not None and
query.dtype != self.in_proj_weight.dtype):
# this case will fail anyway, but at least they'll get a useful error message.
why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match"
elif self.training:
why_not_fast_path = "training is enabled"
elif not self.batch_first:
why_not_fast_path = "batch_first was not True"
elif self.bias_k is not None:
why_not_fast_path = "self.bias_k was not None"
elif self.bias_v is not None:
why_not_fast_path = "self.bias_v was not None"
elif self.dropout:
why_not_fast_path = f"dropout was {self.dropout}, required zero"
elif self.add_zero_attn:
why_not_fast_path = "add_zero_attn was enabled"
elif not self._qkv_same_embed_dim:
why_not_fast_path = "_qkv_same_embed_dim was not True"
elif attn_mask is not None:
why_not_fast_path = "attn_mask was not None"
elif query.is_nested and key_padding_mask is not None:
why_not_fast_path = (
"key_padding_mask is not supported with NestedTensor input")
elif self.num_heads % 2 == 1:
why_not_fast_path = "num_heads is odd"
elif torch.is_autocast_enabled():
why_not_fast_path = "autocast is enabled"
if not why_not_fast_path:
tensor_args = (query, key, value, self.in_proj_weight,
self.in_proj_bias, self.out_proj.weight,
self.out_proj.bias, )
# We have to use list comprehensions below because TorchScript does not support
# generator expressions.
if torch.overrides.has_torch_function(tensor_args):
why_not_fast_path = "some Tensor argument has_torch_function"
elif not all([(x is None or x.is_cuda or "cpu" in str(x.device))
for x in tensor_args]):
why_not_fast_path = (
"some Tensor argument is neither CUDA nor CPU")
elif torch.is_grad_enabled() and any(
[x is not None and x.requires_grad for x in tensor_args]):
why_not_fast_path = (
"grad is enabled and at least one of query or the "
"input/output projection weights or biases requires_grad")
if not why_not_fast_path:
return torch._native_multi_head_attention(
query,
key,
value,
self.embed_dim,
self.num_heads,
self.in_proj_weight,
self.in_proj_bias,
self.out_proj.weight,
self.out_proj.bias,
key_padding_mask
if key_padding_mask is not None else attn_mask,
need_weights,
average_attn_weights,
1 if key_padding_mask is not None else 0
if attn_mask is not None else None, )
any_nested = query.is_nested or key.is_nested or value.is_nested
assert not any_nested, (
"MultiheadAttention does not support NestedTensor outside of its fast path. "
+ f"The fast path was not hit because {why_not_fast_path}")
if self.batch_first and is_batched:
# make sure that the transpose op does not affect the "is" property
if key is value:
if query is key:
query = key = value = query.transpose(1, 0)
else:
query, key = [x.transpose(1, 0) for x in (query, key)]
value = key
else:
query, key, value = [
x.transpose(1, 0) for x in (query, key, value)
]
if not self._qkv_same_embed_dim:
attn_output, attn_output_weights = F.multi_head_attention_forward(
query,
key,
value,
self.embed_dim,
self.num_heads,
self.in_proj_weight,
self.in_proj_bias,
self.bias_k,
self.bias_v,
self.add_zero_attn,
self.dropout,
self.out_proj.weight,
self.out_proj.bias,
training=self.training,
key_padding_mask=key_padding_mask,
need_weights=need_weights,
attn_mask=attn_mask,
use_separate_proj_weight=True,
q_proj_weight=self.q_proj_weight,
k_proj_weight=self.k_proj_weight,
v_proj_weight=self.v_proj_weight,
average_attn_weights=average_attn_weights,cache=cache )
else:
attn_output, attn_output_weights = F.multi_head_attention_forward(
query,
key,
value,
self.embed_dim,
self.num_heads,
self.in_proj_weight,
self.in_proj_bias,
self.bias_k,
self.bias_v,
self.add_zero_attn,
self.dropout,
self.out_proj.weight,
self.out_proj.bias,
training=self.training,
key_padding_mask=key_padding_mask,
need_weights=need_weights,
attn_mask=attn_mask,
average_attn_weights=average_attn_weights,cache=cache )
if self.batch_first and is_batched:
return attn_output.transpose(1, 0), attn_output_weights
else:
return attn_output, attn_output_weights

View File

@@ -0,0 +1,78 @@
# modified from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/embedding.py
import math
import torch
from torch import nn
class TokenEmbedding(nn.Module):
def __init__(
self,
embedding_dim: int,
vocab_size: int,
dropout: float=0.0, ):
super().__init__()
self.vocab_size = vocab_size
self.embedding_dim = embedding_dim
self.dropout = torch.nn.Dropout(p=dropout)
self.word_embeddings = nn.Embedding(self.vocab_size, self.embedding_dim)
@property
def weight(self) -> torch.Tensor:
return self.word_embeddings.weight
def embedding(self, index: int) -> torch.Tensor:
return self.word_embeddings.weight[index:index + 1]
def forward(self, x: torch.Tensor):
x = self.word_embeddings(x)
x = self.dropout(x)
return x
class SinePositionalEmbedding(nn.Module):
def __init__(
self,
embedding_dim: int,
dropout: float=0.0,
scale: bool=False,
alpha: bool=False, ):
super().__init__()
self.embedding_dim = embedding_dim
self.x_scale = math.sqrt(embedding_dim) if scale else 1.0
self.alpha = nn.Parameter(torch.ones(1), requires_grad=alpha)
self.dropout = torch.nn.Dropout(p=dropout)
self.reverse = False
self.pe = None
self.extend_pe(torch.tensor(0.0).expand(1, 4000))
def extend_pe(self, x):
"""Reset the positional encodings."""
if self.pe is not None:
if self.pe.size(1) >= x.size(1):
if self.pe.dtype != x.dtype or self.pe.device != x.device:
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
return
pe = torch.zeros(x.size(1), self.embedding_dim)
if self.reverse:
position = torch.arange(
x.size(1) - 1, -1, -1.0, dtype=torch.float32).unsqueeze(1)
else:
position = torch.arange(
0, x.size(1), dtype=torch.float32).unsqueeze(1)
div_term = torch.exp(
torch.arange(0, self.embedding_dim, 2, dtype=torch.float32) *
-(math.log(10000.0) / self.embedding_dim))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0)
self.pe = pe.to(device=x.device, dtype=x.dtype).detach()
def forward(self, x: torch.Tensor) -> torch.Tensor:
self.extend_pe(x)
output = x.unsqueeze(-1) if x.ndim == 2 else x
output = output * self.x_scale + self.alpha * self.pe[:, :x.size(1)]
return self.dropout(output)

View File

@@ -0,0 +1,85 @@
# modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/model/lr_schedulers.py
import math
import torch
from matplotlib import pyplot as plt
from torch import nn
from torch.optim import Adam
class WarmupCosineLRSchedule(torch.optim.lr_scheduler._LRScheduler):
"""
Implements Warmup learning rate schedule until 'warmup_steps', going from 'init_lr' to 'peak_lr' for multiple optimizers.
"""
def __init__(self,
optimizer,
init_lr,
peak_lr,
end_lr,
warmup_steps=10000,
total_steps=400000,
current_step=0):
self.init_lr = init_lr
self.peak_lr = peak_lr
self.end_lr = end_lr
self.optimizer = optimizer
self._warmup_rate = (peak_lr - init_lr) / warmup_steps
self._decay_rate = (end_lr - peak_lr) / (total_steps - warmup_steps)
self._current_step = current_step
self.lr = init_lr
self.warmup_steps = warmup_steps
self.total_steps = total_steps
self._last_lr = [self.lr]
def set_lr(self, lr):
self._last_lr = [g['lr'] for g in self.optimizer.param_groups]
for g in self.optimizer.param_groups:
# g['lr'] = lr
g['lr'] = self.end_lr###锁定用线性
def step(self):
if self._current_step < self.warmup_steps:
lr = self.init_lr + self._warmup_rate * self._current_step
elif self._current_step > self.total_steps:
lr = self.end_lr
else:
decay_ratio = (self._current_step - self.warmup_steps) / (
self.total_steps - self.warmup_steps)
if decay_ratio < 0.0 or decay_ratio > 1.0:
raise RuntimeError(
"Decay ratio must be in [0.0, 1.0]. Fix LR scheduler settings."
)
coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
lr = self.end_lr + coeff * (self.peak_lr - self.end_lr)
self.lr=lr=self.end_lr=0.002###锁定用线性###不听话,直接锁定!
self.set_lr(lr)
self.lr = lr
self._current_step += 1
return self.lr
if __name__ == '__main__':
m = nn.Linear(10, 10)
opt = Adam(m.parameters(), lr=1e-4)
s = WarmupCosineLRSchedule(
opt,
1e-6,
2e-4,
1e-6,
warmup_steps=2000,
total_steps=20000,
current_step=0)
lrs = []
for i in range(25000):
s.step()
lrs.append(s.lr)
print(s.lr)
plt.plot(lrs)
plt.plot(range(0, 25000), lrs)
plt.show()

View File

@@ -0,0 +1,622 @@
# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey)
#
# See ../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import logging
from collections import defaultdict
from typing import List
from typing import Tuple
import torch
from torch import Tensor
from torch.optim import Optimizer
class BatchedOptimizer(Optimizer):
"""
This class adds to class Optimizer the capability to optimize parameters in batches:
it will stack the parameters and their grads for you so the optimizer can work
on tensors with an extra leading dimension. This is intended for speed with GPUs,
as it reduces the number of kernels launched in the optimizer.
Args:
params:
"""
def __init__(self, params, defaults):
super(BatchedOptimizer, self).__init__(params, defaults)
@contextlib.contextmanager
def batched_params(self, param_group, group_params_names):
"""
This function returns (technically, yields) a list of
of tuples (p, state), where
p is a `fake` parameter that is stacked (over axis 0) from real parameters
that share the same shape, and its gradient is also stacked;
`state` is the state corresponding to this batch of parameters
(it will be physically located in the "state" for one of the real
parameters, the last one that has any particular shape and dtype).
This function is decorated as a context manager so that it can
write parameters back to their "real" locations.
The idea is, instead of doing:
<code>
for p in group["params"]:
state = self.state[p]
...
</code>
you can do:
<code>
with self.batched_params(group["params"]) as batches:
for p, state, p_names in batches:
...
</code>
Args:
group: a parameter group, which is a list of parameters; should be
one of self.param_groups.
group_params_names: name for each parameter in group,
which is List[str].
"""
batches = defaultdict(
list
) # `batches` maps from tuple (dtype_as_str,*shape) to list of nn.Parameter
batches_names = defaultdict(
list
) # `batches` maps from tuple (dtype_as_str,*shape) to list of str
assert len(param_group) == len(group_params_names)
for p, named_p in zip(param_group, group_params_names):
key = (str(p.dtype), *p.shape)
batches[key].append(p)
batches_names[key].append(named_p)
batches_names_keys = list(batches_names.keys())
sorted_idx = sorted(
range(len(batches_names)), key=lambda i: batches_names_keys[i])
batches_names = [
batches_names[batches_names_keys[idx]] for idx in sorted_idx
]
batches = [batches[batches_names_keys[idx]] for idx in sorted_idx]
stacked_params_dict = dict()
# turn batches into a list, in deterministic order.
# tuples will contain tuples of (stacked_param, state, stacked_params_names),
# one for each batch in `batches`.
tuples = []
for batch, batch_names in zip(batches, batches_names):
p = batch[0]
# we arbitrarily store the state in the
# state corresponding to the 1st parameter in the
# group. class Optimizer will take care of saving/loading state.
state = self.state[p]
p_stacked = torch.stack(batch)
grad = torch.stack([
torch.zeros_like(p) if p.grad is None else p.grad for p in batch
])
p_stacked.grad = grad
stacked_params_dict[key] = p_stacked
tuples.append((p_stacked, state, batch_names))
yield tuples # <-- calling code will do the actual optimization here!
for ((stacked_params, _state, _names), batch) in zip(tuples, batches):
for i, p in enumerate(batch): # batch is list of Parameter
p.copy_(stacked_params[i])
class ScaledAdam(BatchedOptimizer):
"""
Implements 'Scaled Adam', a variant of Adam where we scale each parameter's update
proportional to the norm of that parameter; and also learn the scale of the parameter,
in log space, subject to upper and lower limits (as if we had factored each parameter as
param = underlying_param * log_scale.exp())
Args:
params: The parameters or param_groups to optimize (like other Optimizer subclasses)
lr: The learning rate. We will typically use a learning rate schedule that starts
at 0.03 and decreases over time, i.e. much higher than other common
optimizers.
clipping_scale: (e.g. 2.0)
A scale for gradient-clipping: if specified, the normalized gradients
over the whole model will be clipped to have 2-norm equal to
`clipping_scale` times the median 2-norm over the most recent period
of `clipping_update_period` minibatches. By "normalized gradients",
we mean after multiplying by the rms parameter value for this tensor
[for non-scalars]; this is appropriate because our update is scaled
by this quantity.
betas: beta1,beta2 are momentum constants for regular momentum, and moving sum-sq grad.
Must satisfy 0 < beta <= beta2 < 1.
scalar_lr_scale: A scaling factor on the learning rate, that we use to update the
scale of each parameter tensor and scalar parameters of the mode..
If each parameter were decomposed
as p * p_scale.exp(), where (p**2).mean().sqrt() == 1.0, scalar_lr_scale
would be a the scaling factor on the learning rate of p_scale.
eps: A general-purpose epsilon to prevent division by zero
param_min_rms: Minimum root-mean-square value of parameter tensor, for purposes of
learning the scale on the parameters (we'll constrain the rms of each non-scalar
parameter tensor to be >= this value)
param_max_rms: Maximum root-mean-square value of parameter tensor, for purposes of
learning the scale on the parameters (we'll constrain the rms of each non-scalar
parameter tensor to be <= this value)
scalar_max: Maximum absolute value for scalar parameters (applicable if your
model has any parameters with numel() == 1).
size_update_period: The periodicity, in steps, with which we update the size (scale)
of the parameter tensor. This is provided to save a little time
in the update.
clipping_update_period: if clipping_scale is specified, this is the period
"""
def __init__(
self,
params,
lr=3e-02,
clipping_scale=None,
betas=(0.9, 0.98),
scalar_lr_scale=0.1,
eps=1.0e-08,
param_min_rms=1.0e-05,
param_max_rms=3.0,
scalar_max=10.0,
size_update_period=4,
clipping_update_period=100,
parameters_names=None,
show_dominant_parameters=True, ):
assert parameters_names is not None, (
"Please prepare parameters_names,"
"which is a List[List[str]]. Each List[str] is for a group"
"and each str is for a parameter")
defaults = dict(
lr=lr,
clipping_scale=clipping_scale,
betas=betas,
scalar_lr_scale=scalar_lr_scale,
eps=eps,
param_min_rms=param_min_rms,
param_max_rms=param_max_rms,
scalar_max=scalar_max,
size_update_period=size_update_period,
clipping_update_period=clipping_update_period, )
super(ScaledAdam, self).__init__(params, defaults)
assert len(self.param_groups) == len(parameters_names)
self.parameters_names = parameters_names
self.show_dominant_parameters = show_dominant_parameters
def __setstate__(self, state):
super(ScaledAdam, self).__setstate__(state)
@torch.no_grad()
def step(self, closure=None):
"""Performs a single optimization step.
Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()
batch = True
for group, group_params_names in zip(self.param_groups,
self.parameters_names):
with self.batched_params(group["params"],
group_params_names) as batches:
# batches is list of pairs (stacked_param, state). stacked_param is like
# a regular parameter, and will have a .grad, but the 1st dim corresponds to
# a stacking dim, it is not a real dim.
if (len(batches[0][1]) ==
0): # if len(first state) == 0: not yet initialized
clipping_scale = 1
else:
clipping_scale = self._get_clipping_scale(group, batches)
for p, state, _ in batches:
# Perform optimization step.
# grad is not going to be None, we handled that when creating the batches.
grad = p.grad
if grad.is_sparse:
raise RuntimeError(
"ScaledAdam optimizer does not support sparse gradients"
)
# State initialization
if len(state) == 0:
self._init_state(group, p, state)
self._step_one_batch(group, p, state, clipping_scale)
return loss
def _init_state(self, group: dict, p: Tensor, state: dict):
"""
Initializes state dict for parameter 'p'. Assumes that dim 0 of tensor p
is actually the batch dimension, corresponding to batched-together
parameters of a given shape.
Args:
group: Dict to look up configuration values.
p: The parameter that we are initializing the state for
state: Dict from string to whatever state we are initializing
"""
size_update_period = group["size_update_period"]
state["step"] = 0
kwargs = {"device": p.device, "dtype": p.dtype}
# 'delta' implements conventional momentum. There are
# several different kinds of update going on, so rather than
# compute "exp_avg" like in Adam, we store and decay a
# parameter-change "delta", which combines all forms of
# update. this is equivalent to how it's done in Adam,
# except for the first few steps.
state["delta"] = torch.zeros_like(
p, memory_format=torch.preserve_format)
batch_size = p.shape[0]
numel = p.numel() // batch_size
numel = p.numel()
if numel > 1:
# "param_rms" just periodically records the scalar root-mean-square value of
# the parameter tensor.
# it has a shape like (batch_size, 1, 1, 1, 1)
param_rms = (
(p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt())
state["param_rms"] = param_rms
state["scale_exp_avg_sq"] = torch.zeros_like(param_rms)
state["scale_grads"] = torch.zeros(size_update_period,
*param_rms.shape, **kwargs)
# exp_avg_sq is the weighted sum of scaled gradients. as in Adam.
state["exp_avg_sq"] = torch.zeros_like(
p, memory_format=torch.preserve_format)
def _get_clipping_scale(self,
group: dict,
tuples: List[Tuple[Tensor, dict, List[str]]]
) -> float:
"""
Returns a scalar factor <= 1.0 that dictates gradient clipping, i.e. we will scale the gradients
by this amount before applying the rest of the update.
Args:
group: the parameter group, an item in self.param_groups
tuples: a list of tuples of (param, state, param_names)
where param is a batched set of parameters,
with a .grad (1st dim is batch dim)
and state is the state-dict where optimization parameters are kept.
param_names is a List[str] while each str is name for a parameter
in batched set of parameters "param".
"""
assert len(tuples) >= 1
clipping_scale = group["clipping_scale"]
(first_p, first_state, _) = tuples[0]
step = first_state["step"]
if clipping_scale is None or step == 0:
# no clipping. return early on step == 0 because the other
# parameters' state won't have been initialized yet.
return 1.0
clipping_update_period = group["clipping_update_period"]
tot_sumsq = torch.tensor(0.0, device=first_p.device)
for (p, state, param_names) in tuples:
grad = p.grad
if grad.is_sparse:
raise RuntimeError(
"ScaledAdam optimizer does not support sparse gradients")
if p.numel() == p.shape[0]: # a batch of scalars
tot_sumsq += (grad**2).sum() # sum() to change shape [1] to []
else:
tot_sumsq += ((grad * state["param_rms"])**2).sum()
tot_norm = tot_sumsq.sqrt()
if "model_norms" not in first_state:
first_state["model_norms"] = torch.zeros(
clipping_update_period, device=p.device)
first_state["model_norms"][step % clipping_update_period] = tot_norm
if step % clipping_update_period == 0:
# Print some stats.
# We don't reach here if step == 0 because we would have returned
# above.
sorted_norms = first_state["model_norms"].sort()[0].to("cpu")
quartiles = []
for n in range(0, 5):
index = min(
clipping_update_period - 1,
(clipping_update_period // 4) * n, )
quartiles.append(sorted_norms[index].item())
median = quartiles[2]
threshold = clipping_scale * median
first_state["model_norm_threshold"] = threshold
percent_clipped = (first_state["num_clipped"] * 100.0 /
clipping_update_period
if "num_clipped" in first_state else 0.0)
first_state["num_clipped"] = 0
quartiles = " ".join(["%.3e" % x for x in quartiles])
logging.info(
f"Clipping_scale={clipping_scale}, grad-norm quartiles {quartiles}, "
f"threshold={threshold:.3e}, percent-clipped={percent_clipped:.1f}"
)
if step < clipping_update_period:
return 1.0 # We have not yet estimated a norm to clip to.
else:
try:
model_norm_threshold = first_state["model_norm_threshold"]
except KeyError:
logging.info(
"Warning: model_norm_threshold not in state: possibly "
"you changed config when restarting, adding clipping_scale option?"
)
return 1.0
ans = min(1.0, (model_norm_threshold / (tot_norm + 1.0e-20)).item())
if ans < 1.0:
first_state["num_clipped"] += 1
if ans < 0.1:
logging.warn(
f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}"
)
if self.show_dominant_parameters:
assert p.shape[0] == len(param_names)
self._show_gradient_dominating_parameter(tuples, tot_sumsq)
return ans
def _show_gradient_dominating_parameter(
self, tuples: List[Tuple[Tensor, dict, List[str]]],
tot_sumsq: Tensor):
"""
Show information of parameter wihch dominanting tot_sumsq.
Args:
tuples: a list of tuples of (param, state, param_names)
where param is a batched set of parameters,
with a .grad (1st dim is batch dim)
and state is the state-dict where optimization parameters are kept.
param_names is a List[str] while each str is name for a parameter
in batched set of parameters "param".
tot_sumsq: sumsq of all parameters. Though it's could be calculated
from tuples, we still pass it to save some time.
"""
all_sumsq_orig = {}
for (p, state, batch_param_names) in tuples:
# p is a stacked batch parameters.
batch_grad = p.grad
if p.numel() == p.shape[0]: # a batch of scalars
batch_sumsq_orig = batch_grad**2
# Dummpy values used by following `zip` statement.
batch_rms_orig = torch.ones(p.shape[0])
else:
batch_rms_orig = state["param_rms"]
batch_sumsq_orig = ((batch_grad * batch_rms_orig)**2).sum(
dim=list(range(1, batch_grad.ndim)))
for name, sumsq_orig, rms, grad in zip(batch_param_names,
batch_sumsq_orig,
batch_rms_orig, batch_grad):
proportion_orig = sumsq_orig / tot_sumsq
all_sumsq_orig[name] = (proportion_orig, sumsq_orig, rms, grad)
assert torch.isclose(
sum([value[0] for value in all_sumsq_orig.values()]).cpu(),
torch.tensor(1.0), )
sorted_by_proportion = {
k: v
for k, v in sorted(
all_sumsq_orig.items(),
key=lambda item: item[1][0],
reverse=True, )
}
dominant_param_name = next(iter(sorted_by_proportion))
(dominant_proportion, dominant_sumsq, dominant_rms,
dominant_grad, ) = sorted_by_proportion[dominant_param_name]
logging.info(f"Parameter Dominanting tot_sumsq {dominant_param_name}"
f" with proportion {dominant_proportion:.2f},"
f" where dominant_sumsq=(grad_sumsq*orig_rms_sq)"
f"={dominant_sumsq:.3e},"
f" grad_sumsq = {(dominant_grad**2).sum():.3e},"
f" orig_rms_sq={(dominant_rms**2).item():.3e}")
def _step_one_batch(self,
group: dict,
p: Tensor,
state: dict,
clipping_scale: float):
"""
Do the step for one parameter, which is actually going to be a batch of
`real` parameters, with dim 0 as the batch dim.
Args:
group: dict to look up configuration values
p: parameter to update (actually multiple parameters stacked together
as a batch)
state: state-dict for p, to look up the optimizer state
"""
lr = group["lr"]
size_update_period = group["size_update_period"]
beta1 = group["betas"][0]
grad = p.grad
if clipping_scale != 1.0:
grad = grad * clipping_scale
step = state["step"]
delta = state["delta"]
delta.mul_(beta1)
batch_size = p.shape[0]
numel = p.numel() // batch_size
if numel > 1:
# Update the size/scale of p, and set param_rms
scale_grads = state["scale_grads"]
scale_grads[step % size_update_period] = (p * grad).sum(
dim=list(range(1, p.ndim)), keepdim=True)
if step % size_update_period == size_update_period - 1:
param_rms = state["param_rms"] # shape: (batch_size, 1, 1, ..)
param_rms.copy_((p**2)
.mean(dim=list(range(1, p.ndim)), keepdim=True)
.sqrt())
if step > 0:
# self._size_update() learns the overall scale on the
# parameter, by shrinking or expanding it.
self._size_update(group, scale_grads, p, state)
if numel == 1:
# For parameters with 1 element we just use regular Adam.
# Updates delta.
self._step_scalar(group, p, state)
else:
self._step(group, p, state)
state["step"] = step + 1
def _size_update(self,
group: dict,
scale_grads: Tensor,
p: Tensor,
state: dict) -> None:
"""
Called only where p.numel() > 1, this updates the scale of the parameter.
If we imagine: p = underlying_param * scale.exp(), and we are doing
gradient descent on underlying param and on scale, this function does the update
on `scale`.
Args:
group: dict to look up configuration values
scale_grads: a tensor of shape (size_update_period, batch_size, 1, 1,...) containing
grads w.r.t. the scales.
p: The parameter to update
state: The state-dict of p
"""
param_rms = state["param_rms"]
beta1, beta2 = group["betas"]
size_lr = group["lr"] * group["scalar_lr_scale"]
param_min_rms = group["param_min_rms"]
param_max_rms = group["param_max_rms"]
eps = group["eps"]
step = state["step"]
batch_size = p.shape[0]
size_update_period = scale_grads.shape[0]
# correct beta2 for the size update period: we will have
# faster decay at this level.
beta2_corr = beta2**size_update_period
scale_exp_avg_sq = state[
"scale_exp_avg_sq"] # shape: (batch_size, 1, 1, ..)
scale_exp_avg_sq.mul_(beta2_corr).add_(
(scale_grads**2).mean(dim=0), # mean over dim `size_update_period`
alpha=1 - beta2_corr, ) # shape is (batch_size, 1, 1, ...)
# The 1st time we reach here is when size_step == 1.
size_step = (step + 1) // size_update_period
bias_correction2 = 1 - beta2_corr**size_step
# we don't bother with bias_correction1; this will help prevent divergence
# at the start of training.
denom = scale_exp_avg_sq.sqrt() + eps
scale_step = (-size_lr * (bias_correction2**0.5) *
scale_grads.sum(dim=0) / denom)
is_too_small = param_rms < param_min_rms
is_too_large = param_rms > param_max_rms
# when the param gets too small, just don't shrink it any further.
scale_step.masked_fill_(is_too_small, 0.0)
# when it gets too large, stop it from getting any larger.
scale_step.masked_fill_(is_too_large, -size_lr * size_update_period)
delta = state["delta"]
# the factor of (1-beta1) relates to momentum.
delta.add_(p * scale_step, alpha=(1 - beta1))
def _step(self, group: dict, p: Tensor, state: dict):
"""
This function does the core update of self.step(), in the case where the members of
the batch have more than 1 element.
Args:
group: A dict which will be used to look up configuration values
p: The parameter to be updated
grad: The grad of p
state: The state-dict corresponding to parameter p
This function modifies p.
"""
grad = p.grad
lr = group["lr"]
beta1, beta2 = group["betas"]
eps = group["eps"]
param_min_rms = group["param_min_rms"]
step = state["step"]
exp_avg_sq = state["exp_avg_sq"]
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=(1 - beta2))
this_step = state["step"] - (state["zero_step"]
if "zero_step" in state else 0)
bias_correction2 = 1 - beta2**(this_step + 1)
if bias_correction2 < 0.99:
# note: not in-place.
exp_avg_sq = exp_avg_sq * (1.0 / bias_correction2)
denom = exp_avg_sq.sqrt()
denom += eps
grad = grad / denom
alpha = -lr * (1 - beta1) * state["param_rms"].clamp(min=param_min_rms)
delta = state["delta"]
delta.add_(grad * alpha)
p.add_(delta)
def _step_scalar(self, group: dict, p: Tensor, state: dict):
"""
A simplified form of the core update for scalar tensors, where we cannot get a good
estimate of the parameter rms.
"""
beta1, beta2 = group["betas"]
scalar_max = group["scalar_max"]
eps = group["eps"]
lr = group["lr"] * group["scalar_lr_scale"]
grad = p.grad
exp_avg_sq = state["exp_avg_sq"] # shape: (batch_size,)
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
# bias_correction2 is like in Adam. Don't bother with bias_correction1;
# slower update at the start will help stability anyway.
bias_correction2 = 1 - beta2**(state["step"] + 1)
denom = (exp_avg_sq / bias_correction2).sqrt() + eps
delta = state["delta"]
delta.add_(grad / denom, alpha=-lr * (1 - beta1))
p.clamp_(min=-scalar_max, max=scalar_max)
p.add_(delta)

View File

@@ -0,0 +1,388 @@
from torch.nn.functional import *
from torch.nn.functional import _mha_shape_check,_canonical_mask,_none_or_dtype,_in_projection_packed
# import torch
# Tensor = torch.Tensor
# from typing import Callable, List, Optional, Tuple, Union
def multi_head_attention_forward_patched(
query: Tensor,
key: Tensor,
value: Tensor,
embed_dim_to_check: int,
num_heads: int,
in_proj_weight: Optional[Tensor],
in_proj_bias: Optional[Tensor],
bias_k: Optional[Tensor],
bias_v: Optional[Tensor],
add_zero_attn: bool,
dropout_p: float,
out_proj_weight: Tensor,
out_proj_bias: Optional[Tensor],
training: bool = True,
key_padding_mask: Optional[Tensor] = None,
need_weights: bool = True,
attn_mask: Optional[Tensor] = None,
use_separate_proj_weight: bool = False,
q_proj_weight: Optional[Tensor] = None,
k_proj_weight: Optional[Tensor] = None,
v_proj_weight: Optional[Tensor] = None,
static_k: Optional[Tensor] = None,
static_v: Optional[Tensor] = None,
average_attn_weights: bool = True,
is_causal: bool = False,cache=None
) -> Tuple[Tensor, Optional[Tensor]]:
r"""
Args:
query, key, value: map a query and a set of key-value pairs to an output.
See "Attention Is All You Need" for more details.
embed_dim_to_check: total dimension of the model.
num_heads: parallel attention heads.
in_proj_weight, in_proj_bias: input projection weight and bias.
bias_k, bias_v: bias of the key and value sequences to be added at dim=0.
add_zero_attn: add a new batch of zeros to the key and
value sequences at dim=1.
dropout_p: probability of an element to be zeroed.
out_proj_weight, out_proj_bias: the output projection weight and bias.
training: apply dropout if is ``True``.
key_padding_mask: if provided, specified padding elements in the key will
be ignored by the attention. This is an binary mask. When the value is True,
the corresponding value on the attention layer will be filled with -inf.
need_weights: output attn_output_weights.
Default: `True`
Note: `needs_weight` defaults to `True`, but should be set to `False`
For best performance when attention weights are not nedeeded.
*Setting needs_weights to `True`
leads to a significant performance degradation.*
attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
the batches while a 3D mask allows to specify a different mask for the entries of each batch.
is_causal: If specified, applies a causal mask as attention mask, and ignores
attn_mask for computing scaled dot product attention.
Default: ``False``.
.. warning::
is_causal is provides a hint that the attn_mask is the
causal mask.Providing incorrect hints can result in
incorrect execution, including forward and backward
compatibility.
use_separate_proj_weight: the function accept the proj. weights for query, key,
and value in different forms. If false, in_proj_weight will be used, which is
a combination of q_proj_weight, k_proj_weight, v_proj_weight.
q_proj_weight, k_proj_weight, v_proj_weight, in_proj_bias: input projection weight and bias.
static_k, static_v: static key and value used for attention operators.
average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across heads.
Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an effect
when ``need_weights=True.``. Default: True
Shape:
Inputs:
- query: :math:`(L, E)` or :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
the embedding dimension.
- key: :math:`(S, E)` or :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
the embedding dimension.
- value: :math:`(S, E)` or :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
the embedding dimension.
- key_padding_mask: :math:`(S)` or :math:`(N, S)` where N is the batch size, S is the source sequence length.
If a FloatTensor is provided, it will be directly added to the value.
If a BoolTensor is provided, the positions with the
value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
- attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked
positions. If a BoolTensor is provided, positions with ``True``
are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
is provided, it will be added to the attention weight.
- static_k: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
- static_v: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
Outputs:
- attn_output: :math:`(L, E)` or :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
E is the embedding dimension.
- attn_output_weights: Only returned when ``need_weights=True``. If ``average_attn_weights=True``, returns
attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or
:math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and
:math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per
head of shape :math:`(num_heads, L, S)` when input is unbatched or :math:`(N, num_heads, L, S)`.
"""
tens_ops = (query, key, value, in_proj_weight, in_proj_bias, bias_k, bias_v, out_proj_weight, out_proj_bias)
if has_torch_function(tens_ops):
return handle_torch_function(
multi_head_attention_forward,
tens_ops,
query,
key,
value,
embed_dim_to_check,
num_heads,
in_proj_weight,
in_proj_bias,
bias_k,
bias_v,
add_zero_attn,
dropout_p,
out_proj_weight,
out_proj_bias,
training=training,
key_padding_mask=key_padding_mask,
need_weights=need_weights,
attn_mask=attn_mask,
is_causal=is_causal,
use_separate_proj_weight=use_separate_proj_weight,
q_proj_weight=q_proj_weight,
k_proj_weight=k_proj_weight,
v_proj_weight=v_proj_weight,
static_k=static_k,
static_v=static_v,
average_attn_weights=average_attn_weights,cache=cache
)
is_batched = _mha_shape_check(query, key, value, key_padding_mask, attn_mask, num_heads)
# For unbatched input, we unsqueeze at the expected batch-dim to pretend that the input
# is batched, run the computation and before returning squeeze the
# batch dimension so that the output doesn't carry this temporary batch dimension.
if not is_batched:
# unsqueeze if the input is unbatched
query = query.unsqueeze(1)
key = key.unsqueeze(1)
value = value.unsqueeze(1)
if key_padding_mask is not None:
key_padding_mask = key_padding_mask.unsqueeze(0)
# set up shape vars
tgt_len, bsz, embed_dim = query.shape
src_len, _, _ = key.shape
key_padding_mask = _canonical_mask(
mask=key_padding_mask,
mask_name="key_padding_mask",
other_type=_none_or_dtype(attn_mask),
other_name="attn_mask",
target_type=query.dtype
)
if is_causal and attn_mask is None:
raise RuntimeError(
"Need attn_mask if specifying the is_causal hint. "
"You may use the Transformer module method "
"`generate_square_subsequent_mask` to create this mask."
)
if is_causal and key_padding_mask is None and not need_weights:
# when we have a kpm or need weights, we need attn_mask
# Otherwise, we use the is_causal hint go as is_causal
# indicator to SDPA.
attn_mask = None
else:
attn_mask = _canonical_mask(
mask=attn_mask,
mask_name="attn_mask",
other_type=None,
other_name="",
target_type=query.dtype,
check_other=False,
)
if key_padding_mask is not None:
# We have the attn_mask, and use that to merge kpm into it.
# Turn off use of is_causal hint, as the merged mask is no
# longer causal.
is_causal = False
assert embed_dim == embed_dim_to_check, \
f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}"
if isinstance(embed_dim, torch.Tensor):
# embed_dim can be a tensor when JIT tracing
head_dim = embed_dim.div(num_heads, rounding_mode='trunc')
else:
head_dim = embed_dim // num_heads
assert head_dim * num_heads == embed_dim, f"embed_dim {embed_dim} not divisible by num_heads {num_heads}"
if use_separate_proj_weight:
# allow MHA to have different embedding dimensions when separate projection weights are used
assert key.shape[:2] == value.shape[:2], \
f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}"
else:
assert key.shape == value.shape, f"key shape {key.shape} does not match value shape {value.shape}"
#
# compute in-projection
#
if not use_separate_proj_weight:
assert in_proj_weight is not None, "use_separate_proj_weight is False but in_proj_weight is None"
q, k, v = _in_projection_packed(query, key, value, in_proj_weight, in_proj_bias)
else:
assert q_proj_weight is not None, "use_separate_proj_weight is True but q_proj_weight is None"
assert k_proj_weight is not None, "use_separate_proj_weight is True but k_proj_weight is None"
assert v_proj_weight is not None, "use_separate_proj_weight is True but v_proj_weight is None"
if in_proj_bias is None:
b_q = b_k = b_v = None
else:
b_q, b_k, b_v = in_proj_bias.chunk(3)
q, k, v = _in_projection(query, key, value, q_proj_weight, k_proj_weight, v_proj_weight, b_q, b_k, b_v)
if(cache!=None):
if(cache["first_infer"]==1):
cache["k"][cache["stage"]]=k
# print(0,cache["k"].shape)
cache["v"][cache["stage"]]=v
else:###12个layer每个都要留自己的cache_kv
# print(1,cache["k"].shape)
cache["k"][cache["stage"]]=torch.cat([cache["k"][cache["stage"]],k],0)##本来时序是1但是proj的时候可能transpose了所以时序到0维了
cache["v"][cache["stage"]]=torch.cat([cache["v"][cache["stage"]],v],0)
# print(2, cache["k"].shape)
src_len = cache["k"][cache["stage"]].shape[0]
k=cache["k"][cache["stage"]]
v=cache["v"][cache["stage"]]
# if attn_mask is not None:
# attn_mask=attn_mask[-1:,]
# print(attn_mask.shape,attn_mask)
cache["stage"] = (cache["stage"] + 1) % cache["all_stage"]
# print(2333,cache)
# prep attention mask
attn_mask = _canonical_mask(
mask=attn_mask,
mask_name="attn_mask",
other_type=None,
other_name="",
target_type=q.dtype,
check_other=False,
)
if attn_mask is not None:
# ensure attn_mask's dim is 3
if attn_mask.dim() == 2:
correct_2d_size = (tgt_len, src_len)
if attn_mask.shape != correct_2d_size:
raise RuntimeError(f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}.")
attn_mask = attn_mask.unsqueeze(0)
elif attn_mask.dim() == 3:
correct_3d_size = (bsz * num_heads, tgt_len, src_len)
if attn_mask.shape != correct_3d_size:
raise RuntimeError(f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}.")
else:
raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported")
# add bias along batch dimension (currently second)
if bias_k is not None and bias_v is not None:
assert static_k is None, "bias cannot be added to static key."
assert static_v is None, "bias cannot be added to static value."
k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
if attn_mask is not None:
attn_mask = pad(attn_mask, (0, 1))
if key_padding_mask is not None:
key_padding_mask = pad(key_padding_mask, (0, 1))
else:
assert bias_k is None
assert bias_v is None
#
# reshape q, k, v for multihead attention and make em batch first
#
q = q.view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
if static_k is None:
k = k.view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
else:
# TODO finish disentangling control flow so we don't do in-projections when statics are passed
assert static_k.size(0) == bsz * num_heads, \
f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.size(0)}"
assert static_k.size(2) == head_dim, \
f"expecting static_k.size(2) of {head_dim}, but got {static_k.size(2)}"
k = static_k
if static_v is None:
v = v.view(v.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
else:
# TODO finish disentangling control flow so we don't do in-projections when statics are passed
assert static_v.size(0) == bsz * num_heads, \
f"expecting static_v.size(0) of {bsz * num_heads}, but got {static_v.size(0)}"
assert static_v.size(2) == head_dim, \
f"expecting static_v.size(2) of {head_dim}, but got {static_v.size(2)}"
v = static_v
# add zero attention along batch dimension (now first)
if add_zero_attn:
zero_attn_shape = (bsz * num_heads, 1, head_dim)
k = torch.cat([k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=1)
v = torch.cat([v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=1)
if attn_mask is not None:
attn_mask = pad(attn_mask, (0, 1))
if key_padding_mask is not None:
key_padding_mask = pad(key_padding_mask, (0, 1))
# update source sequence length after adjustments
src_len = k.size(1)
# merge key padding and attention masks
if key_padding_mask is not None:
assert key_padding_mask.shape == (bsz, src_len), \
f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}"
key_padding_mask = key_padding_mask.view(bsz, 1, 1, src_len). \
expand(-1, num_heads, -1, -1).reshape(bsz * num_heads, 1, src_len)
if attn_mask is None:
attn_mask = key_padding_mask
else:
attn_mask = attn_mask + key_padding_mask
# adjust dropout probability
if not training:
dropout_p = 0.0
#
# (deep breath) calculate attention and out projection
#
if need_weights:
B, Nt, E = q.shape
q_scaled = q / math.sqrt(E)
assert not (is_causal and attn_mask is None), "FIXME: is_causal not implemented for need_weights"
if attn_mask is not None:
attn_output_weights = torch.baddbmm(attn_mask, q_scaled, k.transpose(-2, -1))
else:
attn_output_weights = torch.bmm(q_scaled, k.transpose(-2, -1))
attn_output_weights = softmax(attn_output_weights, dim=-1)
if dropout_p > 0.0:
attn_output_weights = dropout(attn_output_weights, p=dropout_p)
attn_output = torch.bmm(attn_output_weights, v)
attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim)
attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
# optionally average attention weights over heads
attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
if average_attn_weights:
attn_output_weights = attn_output_weights.mean(dim=1)
if not is_batched:
# squeeze the output if input was unbatched
attn_output = attn_output.squeeze(1)
attn_output_weights = attn_output_weights.squeeze(0)
return attn_output, attn_output_weights
else:
# attn_mask can be either (L,S) or (N*num_heads, L, S)
# if attn_mask's shape is (1, L, S) we need to unsqueeze to (1, 1, L, S)
# in order to match the input for SDPA of (N, num_heads, L, S)
if attn_mask is not None:
if attn_mask.size(0) == 1 and attn_mask.dim() == 3:
attn_mask = attn_mask.unsqueeze(0)
else:
attn_mask = attn_mask.view(bsz, num_heads, -1, src_len)
q = q.view(bsz, num_heads, tgt_len, head_dim)
k = k.view(bsz, num_heads, src_len, head_dim)
v = v.view(bsz, num_heads, src_len, head_dim)
attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)
attn_output = attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim)
attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
if not is_batched:
# squeeze the output if input was unbatched
attn_output = attn_output.squeeze(1)
return attn_output, None

View File

@@ -0,0 +1,319 @@
# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import math
import random
from typing import Optional
from typing import Tuple
from typing import Union
import torch
import torch.nn as nn
from torch import Tensor
class DoubleSwishFunction(torch.autograd.Function):
"""
double_swish(x) = x * torch.sigmoid(x-1)
This is a definition, originally motivated by its close numerical
similarity to swish(swish(x)), where swish(x) = x * sigmoid(x).
Memory-efficient derivative computation:
double_swish(x) = x * s, where s(x) = torch.sigmoid(x-1)
double_swish'(x) = d/dx double_swish(x) = x * s'(x) + x' * s(x) = x * s'(x) + s(x).
Now, s'(x) = s(x) * (1-s(x)).
double_swish'(x) = x * s'(x) + s(x).
= x * s(x) * (1-s(x)) + s(x).
= double_swish(x) * (1-s(x)) + s(x)
... so we just need to remember s(x) but not x itself.
"""
@staticmethod
def forward(ctx, x: Tensor) -> Tensor:
requires_grad = x.requires_grad
x_dtype = x.dtype
if x.dtype == torch.float16:
x = x.to(torch.float32)
s = torch.sigmoid(x - 1.0)
y = x * s
if requires_grad:
deriv = y * (1 - s) + s
# notes on derivative of x * sigmoid(x - 1):
# https://www.wolframalpha.com/input?i=d%2Fdx+%28x+*+sigmoid%28x-1%29%29
# min \simeq -0.043638. Take floor as -0.043637 so it's a lower bund
# max \simeq 1.1990. Take ceil to be 1.2 so it's an upper bound.
# the combination of "+ torch.rand_like(deriv)" and casting to torch.uint8 (which
# floors), should be expectation-preserving.
floor = -0.043637
ceil = 1.2
d_scaled = (deriv - floor) * (255.0 / (ceil - floor)
) + torch.rand_like(deriv)
if __name__ == "__main__":
# for self-testing only.
assert d_scaled.min() >= 0.0
assert d_scaled.max() < 256.0
d_int = d_scaled.to(torch.uint8)
ctx.save_for_backward(d_int)
if x.dtype == torch.float16 or torch.is_autocast_enabled():
y = y.to(torch.float16)
return y
@staticmethod
def backward(ctx, y_grad: Tensor) -> Tensor:
(d, ) = ctx.saved_tensors
# the same constants as used in forward pass.
floor = -0.043637
ceil = 1.2
d = d * ((ceil - floor) / 255.0) + floor
return y_grad * d
class DoubleSwish(torch.nn.Module):
def forward(self, x: Tensor) -> Tensor:
"""Return double-swish activation function which is an approximation to Swish(Swish(x)),
that we approximate closely with x * sigmoid(x-1).
"""
if torch.jit.is_scripting() or torch.jit.is_tracing():
return x * torch.sigmoid(x - 1.0)
return DoubleSwishFunction.apply(x)
class ActivationBalancerFunction(torch.autograd.Function):
@staticmethod
def forward(
ctx,
x: Tensor,
scale_factor: Tensor,
sign_factor: Optional[Tensor],
channel_dim: int, ) -> Tensor:
if channel_dim < 0:
channel_dim += x.ndim
ctx.channel_dim = channel_dim
xgt0 = x > 0
if sign_factor is None:
ctx.save_for_backward(xgt0, scale_factor)
else:
ctx.save_for_backward(xgt0, scale_factor, sign_factor)
return x
@staticmethod
def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None]:
if len(ctx.saved_tensors) == 3:
xgt0, scale_factor, sign_factor = ctx.saved_tensors
for _ in range(ctx.channel_dim, x_grad.ndim - 1):
scale_factor = scale_factor.unsqueeze(-1)
sign_factor = sign_factor.unsqueeze(-1)
factor = sign_factor + scale_factor * (xgt0.to(x_grad.dtype) - 0.5)
else:
xgt0, scale_factor = ctx.saved_tensors
for _ in range(ctx.channel_dim, x_grad.ndim - 1):
scale_factor = scale_factor.unsqueeze(-1)
factor = scale_factor * (xgt0.to(x_grad.dtype) - 0.5)
neg_delta_grad = x_grad.abs() * factor
return (x_grad - neg_delta_grad, None, None, None, )
def _compute_scale_factor(
x: Tensor,
channel_dim: int,
min_abs: float,
max_abs: float,
gain_factor: float,
max_factor: float, ) -> Tensor:
if channel_dim < 0:
channel_dim += x.ndim
sum_dims = [d for d in range(x.ndim) if d != channel_dim]
x_abs_mean = torch.mean(x.abs(), dim=sum_dims).to(torch.float32)
if min_abs == 0.0:
below_threshold = 0.0
else:
# below_threshold is 0 if x_abs_mean > min_abs, can be at most max_factor if
# x_abs)_mean , min_abs.
below_threshold = (
(min_abs - x_abs_mean) * (gain_factor / min_abs)).clamp(
min=0, max=max_factor)
above_threshold = ((x_abs_mean - max_abs) * (gain_factor / max_abs)).clamp(
min=0, max=max_factor)
return below_threshold - above_threshold
def _compute_sign_factor(
x: Tensor,
channel_dim: int,
min_positive: float,
max_positive: float,
gain_factor: float,
max_factor: float, ) -> Tensor:
if channel_dim < 0:
channel_dim += x.ndim
sum_dims = [d for d in range(x.ndim) if d != channel_dim]
proportion_positive = torch.mean((x > 0).to(torch.float32), dim=sum_dims)
if min_positive == 0.0:
factor1 = 0.0
else:
# 0 if proportion_positive >= min_positive, else can be
# as large as max_factor.
factor1 = ((min_positive - proportion_positive) *
(gain_factor / min_positive)).clamp_(
min=0, max=max_factor)
if max_positive == 1.0:
factor2 = 0.0
else:
# 0 if self.proportion_positive <= max_positive, else can be
# as large as -max_factor.
factor2 = ((proportion_positive - max_positive) *
(gain_factor / (1.0 - max_positive))).clamp_(
min=0, max=max_factor)
sign_factor = factor1 - factor2
# require min_positive != 0 or max_positive != 1:
assert not isinstance(sign_factor, float)
return sign_factor
class ActivationBalancer(torch.nn.Module):
"""
Modifies the backpropped derivatives of a function to try to encourage, for
each channel, that it is positive at least a proportion `threshold` of the
time. It does this by multiplying negative derivative values by up to
(1+max_factor), and positive derivative values by up to (1-max_factor),
interpolated from 1 at the threshold to those extremal values when none
of the inputs are positive.
Args:
num_channels: the number of channels
channel_dim: the dimension/axis corresponding to the channel, e.g.
-1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative.
min_positive: the minimum, per channel, of the proportion of the time
that (x > 0), below which we start to modify the derivatives.
max_positive: the maximum, per channel, of the proportion of the time
that (x > 0), above which we start to modify the derivatives.
max_factor: the maximum factor by which we modify the derivatives for
either the sign constraint or the magnitude constraint;
e.g. with max_factor=0.02, the the derivatives would be multiplied by
values in the range [0.98..1.02].
sign_gain_factor: determines the 'gain' with which we increase the
change in gradient once the constraints on min_positive and max_positive
are violated.
scale_gain_factor: determines the 'gain' with which we increase the
change in gradient once the constraints on min_abs and max_abs
are violated.
min_abs: the minimum average-absolute-value difference from the mean
value per channel, which we allow, before we start to modify
the derivatives to prevent this.
max_abs: the maximum average-absolute-value difference from the mean
value per channel, which we allow, before we start to modify
the derivatives to prevent this.
min_prob: determines the minimum probability with which we modify the
gradients for the {min,max}_positive and {min,max}_abs constraints,
on each forward(). This is done randomly to prevent all layers
from doing it at the same time. Early in training we may use
higher probabilities than this; it will decay to this value.
"""
def __init__(
self,
num_channels: int,
channel_dim: int,
min_positive: float=0.05,
max_positive: float=0.95,
max_factor: float=0.04,
sign_gain_factor: float=0.01,
scale_gain_factor: float=0.02,
min_abs: float=0.2,
max_abs: float=100.0,
min_prob: float=0.1, ):
super(ActivationBalancer, self).__init__()
self.num_channels = num_channels
self.channel_dim = channel_dim
self.min_positive = min_positive
self.max_positive = max_positive
self.max_factor = max_factor
self.min_abs = min_abs
self.max_abs = max_abs
self.min_prob = min_prob
self.sign_gain_factor = sign_gain_factor
self.scale_gain_factor = scale_gain_factor
# count measures how many times the forward() function has been called.
# We occasionally sync this to a tensor called `count`, that exists to
# make sure it is synced to disk when we load and save the model.
self.cpu_count = 0
self.register_buffer("count", torch.tensor(0, dtype=torch.int64))
def forward(self, x: Tensor) -> Tensor:
if (torch.jit.is_scripting() or not x.requires_grad or
torch.jit.is_tracing()):
return _no_op(x)
count = self.cpu_count
self.cpu_count += 1
if random.random() < 0.01:
# Occasionally sync self.cpu_count with self.count.
# count affects the decay of 'prob'. don't do this on every iter,
# because syncing with the GPU is slow.
self.cpu_count = max(self.cpu_count, self.count.item())
self.count.fill_(self.cpu_count)
# the prob of doing some work exponentially decreases from 0.5 till it hits
# a floor at min_prob (==0.1, by default)
prob = max(self.min_prob, 0.5**(1 + (count / 4000.0)))
if random.random() < prob:
sign_gain_factor = 0.5
if self.min_positive != 0.0 or self.max_positive != 1.0:
sign_factor = _compute_sign_factor(
x,
self.channel_dim,
self.min_positive,
self.max_positive,
gain_factor=self.sign_gain_factor / prob,
max_factor=self.max_factor, )
else:
sign_factor = None
scale_factor = _compute_scale_factor(
x.detach(),
self.channel_dim,
min_abs=self.min_abs,
max_abs=self.max_abs,
gain_factor=self.scale_gain_factor / prob,
max_factor=self.max_factor, )
return ActivationBalancerFunction.apply(
x,
scale_factor,
sign_factor,
self.channel_dim, )
else:
return _no_op(x)
def BalancedDoubleSwish(d_model, channel_dim=-1, max_abs=10.0,
min_prob=0.25) -> nn.Sequential:
"""
ActivationBalancer -> DoubleSwish
"""
balancer = ActivationBalancer(
d_model, channel_dim=channel_dim, max_abs=max_abs, min_prob=min_prob)
return nn.Sequential(
balancer,
DoubleSwish(), )

View File

@@ -0,0 +1,347 @@
# modified from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/transformer.py
import copy
import numbers
from functools import partial
from typing import Any
from typing import Callable
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union
import torch
from AR.modules.activation import MultiheadAttention
from AR.modules.scaling import BalancedDoubleSwish
from torch import nn
from torch import Tensor
from torch.nn import functional as F
_shape_t = Union[int, List[int], torch.Size]
class LayerNorm(nn.Module):
__constants__ = ["normalized_shape", "eps", "elementwise_affine"]
normalized_shape: Tuple[int, ...]
eps: float
elementwise_affine: bool
def __init__(
self,
normalized_shape: _shape_t,
eps: float=1e-5,
elementwise_affine: bool=True,
device=None,
dtype=None, ) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
super(LayerNorm, self).__init__()
if isinstance(normalized_shape, numbers.Integral):
# mypy error: incompatible types in assignment
normalized_shape = (normalized_shape, ) # type: ignore[assignment]
self.normalized_shape = tuple(
normalized_shape) # type: ignore[arg-type]
self.eps = eps
self.elementwise_affine = elementwise_affine
if self.elementwise_affine:
self.weight = nn.Parameter(
torch.empty(self.normalized_shape, **factory_kwargs))
self.bias = nn.Parameter(
torch.empty(self.normalized_shape, **factory_kwargs))
else:
self.register_parameter("weight", None)
self.register_parameter("bias", None)
self.reset_parameters()
def reset_parameters(self) -> None:
if self.elementwise_affine:
nn.init.ones_(self.weight)
nn.init.zeros_(self.bias)
def forward(self, input: Tensor, embedding: Any=None) -> Tensor:
if isinstance(input, tuple):
input, embedding = input
return (F.layer_norm(
input,
self.normalized_shape,
self.weight,
self.bias,
self.eps, ), embedding, )
assert embedding is None
return F.layer_norm(input, self.normalized_shape, self.weight,
self.bias, self.eps)
def extra_repr(self) -> str:
return (
"{normalized_shape}, eps={eps}, "
"elementwise_affine={elementwise_affine}".format(**self.__dict__))
class IdentityNorm(nn.Module):
def __init__(
self,
d_model: int,
eps: float=1e-5,
device=None,
dtype=None, ) -> None:
super(IdentityNorm, self).__init__()
def forward(self, input: Tensor, embedding: Any=None) -> Tensor:
if isinstance(input, tuple):
return input
assert embedding is None
return input
class TransformerEncoder(nn.Module):
r"""TransformerEncoder is a stack of N encoder layers. Users can build the
BERT(https://arxiv.org/abs/1810.04805) model with corresponding parameters.
Args:
encoder_layer: an instance of the TransformerEncoderLayer() class (required).
num_layers: the number of sub-encoder-layers in the encoder (required).
norm: the layer normalization component (optional).
enable_nested_tensor: if True, input will automatically convert to nested tensor
(and convert back on output). This will improve the overall performance of
TransformerEncoder when padding rate is high. Default: ``True`` (enabled).
Examples::
>>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8)
>>> transformer_encoder = TransformerEncoder(encoder_layer, num_layers=6)
>>> src = torch.rand(10, 32, 512)
>>> out = transformer_encoder(src)
"""
__constants__ = ["norm"]
def __init__(self, encoder_layer, num_layers, norm=None):
super(TransformerEncoder, self).__init__()
self.layers = _get_clones(encoder_layer, num_layers)
self.num_layers = num_layers
self.norm = norm
def forward(
self,
src: Tensor,
mask: Optional[Tensor]=None,
src_key_padding_mask: Optional[Tensor]=None,
return_layer_states: bool=False,cache=None ) -> Tensor:
r"""Pass the input through the encoder layers in turn.
Args:
src: the sequence to the encoder (required).
mask: the mask for the src sequence (optional).
src_key_padding_mask: the mask for the src keys per batch (optional).
return_layer_states: return layers' state (optional).
Shape:
see the docs in Transformer class.
"""
if return_layer_states:
layer_states = [] # layers' output
output = src
for mod in self.layers:
output = mod(
output,
src_mask=mask,
src_key_padding_mask=src_key_padding_mask, cache=cache)
layer_states.append(output[0])
if self.norm is not None:
output = self.norm(output)
return layer_states, output
output = src
for mod in self.layers:
output = mod(output,
src_mask=mask,
src_key_padding_mask=src_key_padding_mask, cache=cache)
if self.norm is not None:
output = self.norm(output)
return output
class TransformerEncoderLayer(nn.Module):
__constants__ = ["batch_first", "norm_first"]
def __init__(
self,
d_model: int,
nhead: int,
dim_feedforward: int=2048,
dropout: float=0.1,
activation: Union[str, Callable[[Tensor], Tensor]]=F.relu,
batch_first: bool=False,
norm_first: bool=False,
device=None,
dtype=None,
linear1_self_attention_cls: nn.Module=nn.Linear,
linear2_self_attention_cls: nn.Module=nn.Linear,
linear1_feedforward_cls: nn.Module=nn.Linear,
linear2_feedforward_cls: nn.Module=nn.Linear,
layer_norm_cls: nn.Module=LayerNorm,
layer_norm_eps: float=1e-5,
adaptive_layer_norm=False, ) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
super(TransformerEncoderLayer, self).__init__()
# print(233333333333,d_model,nhead)
# import os
# os._exit(2333333)
self.self_attn = MultiheadAttention(
d_model,#512 16
nhead,
dropout=dropout,
batch_first=batch_first,
linear1_cls=linear1_self_attention_cls,
linear2_cls=linear2_self_attention_cls,
**factory_kwargs, )
# Implementation of Feedforward model
self.linear1 = linear1_feedforward_cls(d_model, dim_feedforward,
**factory_kwargs)
self.dropout = nn.Dropout(dropout)
self.linear2 = linear2_feedforward_cls(dim_feedforward, d_model,
**factory_kwargs)
self.norm_first = norm_first
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
# Legacy string support for activation function.
if isinstance(activation, str):
activation = _get_activation_fn(activation)
elif isinstance(activation, partial):
activation = activation(d_model)
elif activation == BalancedDoubleSwish:
activation = BalancedDoubleSwish(d_model)
# # We can't test self.activation in forward() in TorchScript,
# # so stash some information about it instead.
# if activation is F.relu or isinstance(activation, torch.nn.ReLU):
# self.activation_relu_or_gelu = 1
# elif activation is F.gelu or isinstance(activation, torch.nn.GELU):
# self.activation_relu_or_gelu = 2
# else:
# self.activation_relu_or_gelu = 0
self.activation = activation
norm1 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs)
if layer_norm_cls == IdentityNorm:
norm2 = BalancedBasicNorm(
d_model, eps=layer_norm_eps, **factory_kwargs)
else:
norm2 = layer_norm_cls(
d_model, eps=layer_norm_eps, **factory_kwargs)
if adaptive_layer_norm:
self.norm1 = AdaptiveLayerNorm(d_model, norm1)
self.norm2 = AdaptiveLayerNorm(d_model, norm2)
else:
self.norm1 = norm1
self.norm2 = norm2
def __setstate__(self, state):
super(TransformerEncoderLayer, self).__setstate__(state)
if not hasattr(self, "activation"):
self.activation = F.relu
def forward(
self,
src: Tensor,
src_mask: Optional[Tensor]=None,
src_key_padding_mask: Optional[Tensor]=None,cache=None ) -> Tensor:
r"""Pass the input through the encoder layer.
Args:
src: the sequence to the encoder layer (required).
src_mask: the mask for the src sequence (optional).
src_key_padding_mask: the mask for the src keys per batch (optional).
Shape:
see the docs in Transformer class.
"""
x, stage_embedding = src, None
is_src_tuple = False
if isinstance(src, tuple):
x, stage_embedding = src
is_src_tuple = True
if src_key_padding_mask is not None:
_skpm_dtype = src_key_padding_mask.dtype
if _skpm_dtype != torch.bool and not torch.is_floating_point(
src_key_padding_mask):
raise AssertionError(
"only bool and floating types of key_padding_mask are supported"
)
if self.norm_first:
x = x + self._sa_block(
self.norm1(x, stage_embedding),
src_mask,
src_key_padding_mask,cache=cache )
x = x + self._ff_block(self.norm2(x, stage_embedding))
else:
x = self.norm1(
x + self._sa_block(x, src_mask, src_key_padding_mask,cache=cache),
stage_embedding, )
x = self.norm2(x + self._ff_block(x), stage_embedding)
if is_src_tuple:
return (x, stage_embedding)
return x
# self-attention block
def _sa_block(
self,
x: Tensor,
attn_mask: Optional[Tensor],
key_padding_mask: Optional[Tensor],cache=None ) -> Tensor:
# print(x.shape,attn_mask.shape,key_padding_mask)
#torch.Size([1, 188, 512]) torch.Size([188, 188]) None
# import os
# os._exit(23333)
x = self.self_attn(
x,
x,
x,
attn_mask=attn_mask,
key_padding_mask=key_padding_mask,
need_weights=False,cache=cache )[0]
return self.dropout1(x)
# feed forward block
def _ff_block(self, x: Tensor) -> Tensor:
x = self.linear2(self.dropout(self.activation(self.linear1(x))))
return self.dropout2(x)
class AdaptiveLayerNorm(nn.Module):
r"""Adaptive Layer Normalization"""
def __init__(self, d_model, norm) -> None:
super(AdaptiveLayerNorm, self).__init__()
self.project_layer = nn.Linear(d_model, 2 * d_model)
self.norm = norm
self.d_model = d_model
self.eps = self.norm.eps
def forward(self, input: Tensor, embedding: Tensor=None) -> Tensor:
if isinstance(input, tuple):
input, embedding = input
weight, bias = torch.split(
self.project_layer(embedding),
split_size_or_sections=self.d_model,
dim=-1, )
return (weight * self.norm(input) + bias, embedding)
weight, bias = torch.split(
self.project_layer(embedding),
split_size_or_sections=self.d_model,
dim=-1, )
return weight * self.norm(input) + bias
def _get_clones(module, N):
return nn.ModuleList([copy.deepcopy(module) for i in range(N)])