Refactor: Format Code with Ruff and Update Deprecated G2PW Link (#2255)
* ruff check --fix
* ruff format --line-length 120 --target-version py39
* Change the link for the G2PW model
* Update the PyTorch version and the Colab notebook
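For reference, the formatting pass listed above can be reproduced locally with the same flags. A minimal sketch, assuming Ruff is installed and the script is run from the repository root (the Python wrapper is illustrative, not part of the commit):

# Re-run the two Ruff passes named in the commit message.
# Assumes `ruff` is on PATH (e.g. `pip install ruff`) and the current
# working directory is the repository root.
import subprocess

commands = [
    ["ruff", "check", "--fix", "."],
    ["ruff", "format", "--line-length", "120", "--target-version", "py39", "."],
]

for cmd in commands:
    # check=True stops at the first pass that reports a failure.
    subprocess.run(cmd, check=True)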
@@ -1,17 +1,14 @@
# modified from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/activation.py
from typing import Optional
from typing import Tuple
from typing import Optional, Tuple

import torch
from torch import Tensor
from torch.nn import Linear
from torch.nn import Module
from torch.nn.init import constant_
from torch.nn.init import xavier_normal_
from torch.nn.init import xavier_uniform_
from torch.nn import Linear, Module
from torch.nn import functional as F
from torch.nn.init import constant_, xavier_normal_, xavier_uniform_
from torch.nn.modules.linear import NonDynamicallyQuantizableLinear
from torch.nn.parameter import Parameter

from torch.nn import functional as F
from AR.modules.patched_mha_with_cache import multi_head_attention_forward_patched

F.multi_head_attention_forward = multi_head_attention_forward_patched
@@ -73,6 +70,7 @@ class MultiheadAttention(Module):
>>> attn_output, attn_output_weights = multihead_attn(query, key, value)

"""

__constants__ = ["batch_first"]
bias_k: Optional[torch.Tensor]
bias_v: Optional[torch.Tensor]
@@ -104,9 +102,7 @@ class MultiheadAttention(Module):
self.dropout = dropout
self.batch_first = batch_first
self.head_dim = embed_dim // num_heads
assert (
self.head_dim * num_heads == self.embed_dim
), "embed_dim must be divisible by num_heads"
assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"

if add_bias_kv:
self.bias_k = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
@@ -117,31 +113,32 @@ class MultiheadAttention(Module):
if linear1_cls == Linear:
if not self._qkv_same_embed_dim:
self.q_proj_weight = Parameter(
torch.empty((embed_dim, embed_dim), **factory_kwargs)
torch.empty((embed_dim, embed_dim), **factory_kwargs),
)
self.k_proj_weight = Parameter(
torch.empty((embed_dim, self.kdim), **factory_kwargs)
torch.empty((embed_dim, self.kdim), **factory_kwargs),
)
self.v_proj_weight = Parameter(
torch.empty((embed_dim, self.vdim), **factory_kwargs)
torch.empty((embed_dim, self.vdim), **factory_kwargs),
)
self.register_parameter("in_proj_weight", None)
else:
self.in_proj_weight = Parameter(
torch.empty((3 * embed_dim, embed_dim), **factory_kwargs)
torch.empty((3 * embed_dim, embed_dim), **factory_kwargs),
)
self.register_parameter("q_proj_weight", None)
self.register_parameter("k_proj_weight", None)
self.register_parameter("v_proj_weight", None)

if bias:
self.in_proj_bias = Parameter(
torch.empty(3 * embed_dim, **factory_kwargs)
)
self.in_proj_bias = Parameter(torch.empty(3 * embed_dim, **factory_kwargs))
else:
self.register_parameter("in_proj_bias", None)
self.out_proj = NonDynamicallyQuantizableLinear(
embed_dim, embed_dim, bias=bias, **factory_kwargs
embed_dim,
embed_dim,
bias=bias,
**factory_kwargs,
)

self._reset_parameters()
@@ -150,7 +147,10 @@ class MultiheadAttention(Module):
raise NotImplementedError
else:
self.in_proj_linear = linear1_cls(
embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs
embed_dim,
3 * embed_dim,
bias=bias,
**factory_kwargs,
)
self.in_proj_weight = self.in_proj_linear.weight

@@ -164,7 +164,10 @@ class MultiheadAttention(Module):
self.register_parameter("in_proj_bias", None)

self.out_proj = linear2_cls(
embed_dim, embed_dim, bias=bias, **factory_kwargs
embed_dim,
embed_dim,
bias=bias,
**factory_kwargs,
)

if self.bias_k is not None:
@@ -261,28 +264,26 @@ class MultiheadAttention(Module):
if key_padding_mask is not None:
_kpm_dtype = key_padding_mask.dtype
if _kpm_dtype != torch.bool and not torch.is_floating_point(
key_padding_mask
key_padding_mask,
):
raise AssertionError(
"only bool and floating types of key_padding_mask are supported"
)
raise AssertionError("only bool and floating types of key_padding_mask are supported")
why_not_fast_path = ""
if not is_batched:
why_not_fast_path = (
f"input not batched; expected query.dim() of 3 but got {query.dim()}"
)
why_not_fast_path = f"input not batched; expected query.dim() of 3 but got {query.dim()}"
elif query is not key or key is not value:
# When lifting this restriction, don't forget to either
# enforce that the dtypes all match or test cases where
# they don't!
why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)"
elif self.in_proj_bias is not None and query.dtype != self.in_proj_bias.dtype:
why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match"
elif (
self.in_proj_weight is not None and query.dtype != self.in_proj_weight.dtype
):
why_not_fast_path = (
f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match"
)
elif self.in_proj_weight is not None and query.dtype != self.in_proj_weight.dtype:
# this case will fail anyway, but at least they'll get a useful error message.
why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match"
why_not_fast_path = (
f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match"
)
elif self.training:
why_not_fast_path = "training is enabled"
elif not self.batch_first:
@@ -300,9 +301,7 @@ class MultiheadAttention(Module):
elif attn_mask is not None:
why_not_fast_path = "attn_mask was not None"
elif query.is_nested and key_padding_mask is not None:
why_not_fast_path = (
"key_padding_mask is not supported with NestedTensor input"
)
why_not_fast_path = "key_padding_mask is not supported with NestedTensor input"
elif self.num_heads % 2 == 1:
why_not_fast_path = "num_heads is odd"
elif torch.is_autocast_enabled():
@@ -322,20 +321,10 @@ class MultiheadAttention(Module):
# generator expressions.
if torch.overrides.has_torch_function(tensor_args):
why_not_fast_path = "some Tensor argument has_torch_function"
elif not all(
[
(x is None or x.is_cuda or "cpu" in str(x.device))
for x in tensor_args
]
):
elif not all([(x is None or x.is_cuda or "cpu" in str(x.device)) for x in tensor_args]):
why_not_fast_path = "some Tensor argument is neither CUDA nor CPU"
elif torch.is_grad_enabled() and any(
[x is not None and x.requires_grad for x in tensor_args]
):
why_not_fast_path = (
"grad is enabled and at least one of query or the "
"input/output projection weights or biases requires_grad"
)
elif torch.is_grad_enabled() and any([x is not None and x.requires_grad for x in tensor_args]):
why_not_fast_path = "grad is enabled and at least one of query or the input/output projection weights or biases requires_grad"
if not why_not_fast_path:
return torch._native_multi_head_attention(
query,
@@ -350,11 +339,7 @@ class MultiheadAttention(Module):
key_padding_mask if key_padding_mask is not None else attn_mask,
need_weights,
average_attn_weights,
1
if key_padding_mask is not None
else 0
if attn_mask is not None
else None,
1 if key_padding_mask is not None else 0 if attn_mask is not None else None,
)

any_nested = query.is_nested or key.is_nested or value.is_nested
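The first hunk above keeps the module-level monkey-patch that routes PyTorch's functional attention through the project's implementation from AR.modules.patched_mha_with_cache. A minimal sketch of that pattern, with a placeholder patched function standing in for the real one (names other than F.multi_head_attention_forward and multi_head_attention_forward_patched are illustrative):

from torch.nn import functional as F

# Keep a handle to the stock implementation so it can be restored or reused.
_original_mha_forward = F.multi_head_attention_forward

def multi_head_attention_forward_patched(*args, **kwargs):
    # Placeholder body: the real function presumably adds the project's
    # cache handling (per the module name); here it simply delegates.
    return _original_mha_forward(*args, **kwargs)

# Module-level reassignment, mirroring the line in the diff: later calls to
# F.multi_head_attention_forward now go through the patched function.
F.multi_head_attention_forward = multi_head_attention_forward_patched

Because nn.MultiheadAttention looks up F.multi_head_attention_forward at call time (outside its native fast path), the reassignment also affects layers constructed before the patch, as long as it runs before the forward pass.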