Refactor: Format Code with Ruff and Update Deprecated G2PW Link (#2255)

* ruff check --fix

* ruff format --line-length 120 --target-version py39

* Change the link for G2PW Model

* update pytorch version and colab
Author: XXXXRT666
Date: 2025-04-07 09:42:47 +01:00
Committed by: GitHub
Parent: 9da7e17efe
Commit: 53cac93589
132 changed files with 8185 additions and 6648 deletions
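The two ruff passes named in the commit message account for essentially all of the mechanical churn below: `ruff check --fix` (with the project's lint rules, presumably including import-sorting and trailing-comma rules) merges and sorts imports and adds trailing commas, and `ruff format --line-length 120` re-wraps code to the wider limit. An illustrative before/after on a made-up snippet, not from this repository:

```python
# Before: split imports and a wrapped assert, as in the old file.
from typing import Optional
from typing import Tuple

def check(a):
    assert (
        a > 0
    ), "a must be positive"
    return Optional, Tuple

# After `ruff check --fix` + `ruff format --line-length 120` (assuming those rules are enabled):
# imports are merged and sorted, and the assert fits on one 120-column line.
from typing import Optional, Tuple

def check_formatted(a):
    assert a > 0, "a must be positive"
    return Optional, Tuple
```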


@@ -1,17 +1,14 @@
# modified from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/activation.py
-from typing import Optional
-from typing import Tuple
+from typing import Optional, Tuple
import torch
from torch import Tensor
-from torch.nn import Linear
-from torch.nn import Module
-from torch.nn.init import constant_
-from torch.nn.init import xavier_normal_
-from torch.nn.init import xavier_uniform_
+from torch.nn import Linear, Module
+from torch.nn import functional as F
+from torch.nn.init import constant_, xavier_normal_, xavier_uniform_
from torch.nn.modules.linear import NonDynamicallyQuantizableLinear
from torch.nn.parameter import Parameter
-from torch.nn import functional as F
from AR.modules.patched_mha_with_cache import multi_head_attention_forward_patched
F.multi_head_attention_forward = multi_head_attention_forward_patched
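The two unchanged lines closing this hunk are the module-level monkey patch: `F.multi_head_attention_forward` is rebound so that every later caller that looks the function up through `torch.nn.functional` gets the project's cached implementation. A minimal, self-contained sketch of the same pattern (the wrapper below is hypothetical, not the code from `patched_mha_with_cache`):

```python
import torch.nn.functional as F

# Keep a handle to the stock implementation so it can be delegated to or restored.
_original_mha_forward = F.multi_head_attention_forward

def _patched_mha_forward(*args, **kwargs):
    # Hypothetical wrapper; the real patched function adds its own caching logic.
    return _original_mha_forward(*args, **kwargs)

# Rebinding the module attribute is enough: nn.MultiheadAttention and any other code
# that resolves F.multi_head_attention_forward at call time now goes through the wrapper.
F.multi_head_attention_forward = _patched_mha_forward
```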
@@ -73,6 +70,7 @@ class MultiheadAttention(Module):
>>> attn_output, attn_output_weights = multihead_attn(query, key, value)
"""
__constants__ = ["batch_first"]
bias_k: Optional[torch.Tensor]
bias_v: Optional[torch.Tensor]
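This hunk only touches the class docstring and the attribute annotations, so the documented call pattern is unchanged. A runnable sketch of that docstring example, shown with the stock `torch.nn.MultiheadAttention` (whose interface this class mirrors) and assumed sizes:

```python
import torch
from torch import nn

multihead_attn = nn.MultiheadAttention(embed_dim=512, num_heads=8)
# Default batch_first=False, so inputs are (seq_len, batch, embed_dim).
query = key = value = torch.randn(10, 2, 512)
attn_output, attn_output_weights = multihead_attn(query, key, value)
print(attn_output.shape)          # torch.Size([10, 2, 512])
print(attn_output_weights.shape)  # torch.Size([2, 10, 10]), averaged over heads
```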
@@ -104,9 +102,7 @@ class MultiheadAttention(Module):
self.dropout = dropout
self.batch_first = batch_first
self.head_dim = embed_dim // num_heads
-assert (
-self.head_dim * num_heads == self.embed_dim
-), "embed_dim must be divisible by num_heads"
+assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
if add_bias_kv:
self.bias_k = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
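Collapsing the assert is purely cosmetic; the constraint it enforces is that `embed_dim` splits evenly across the attention heads. A quick worked example with made-up sizes:

```python
embed_dim, num_heads = 512, 8
head_dim = embed_dim // num_heads  # 64
assert head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"  # 64 * 8 == 512, passes

# With num_heads = 7 instead: 512 // 7 == 73 and 73 * 7 == 511 != 512, so the assert would fire.
```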
@@ -117,31 +113,32 @@ class MultiheadAttention(Module):
if linear1_cls == Linear:
if not self._qkv_same_embed_dim:
self.q_proj_weight = Parameter(
-torch.empty((embed_dim, embed_dim), **factory_kwargs)
+torch.empty((embed_dim, embed_dim), **factory_kwargs),
)
self.k_proj_weight = Parameter(
-torch.empty((embed_dim, self.kdim), **factory_kwargs)
+torch.empty((embed_dim, self.kdim), **factory_kwargs),
)
self.v_proj_weight = Parameter(
-torch.empty((embed_dim, self.vdim), **factory_kwargs)
+torch.empty((embed_dim, self.vdim), **factory_kwargs),
)
self.register_parameter("in_proj_weight", None)
else:
self.in_proj_weight = Parameter(
-torch.empty((3 * embed_dim, embed_dim), **factory_kwargs)
+torch.empty((3 * embed_dim, embed_dim), **factory_kwargs),
)
self.register_parameter("q_proj_weight", None)
self.register_parameter("k_proj_weight", None)
self.register_parameter("v_proj_weight", None)
if bias:
-self.in_proj_bias = Parameter(
-torch.empty(3 * embed_dim, **factory_kwargs)
-)
+self.in_proj_bias = Parameter(torch.empty(3 * embed_dim, **factory_kwargs))
else:
self.register_parameter("in_proj_bias", None)
self.out_proj = NonDynamicallyQuantizableLinear(
-embed_dim, embed_dim, bias=bias, **factory_kwargs
+embed_dim,
+embed_dim,
+bias=bias,
+**factory_kwargs,
)
self._reset_parameters()
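Aside from trailing commas and re-wrapping, nothing changes here, but the hunk shows the two projection layouts the constructor registers: a single packed `(3 * embed_dim, embed_dim)` weight when query/key/value share an embedding size, or separate `q/k/v` weights when `kdim`/`vdim` differ. An illustrative sketch with assumed sizes, not the module's code:

```python
import torch

embed_dim, kdim, vdim = 512, 256, 256  # assumed sizes

# Same embedding dim everywhere: one packed weight, split into three equal blocks at use time.
in_proj_weight = torch.empty(3 * embed_dim, embed_dim)
q_w, k_w, v_w = in_proj_weight.chunk(3, dim=0)  # each block is (embed_dim, embed_dim)

# Different kdim/vdim: separate projection weights, matching the Parameter shapes above.
q_proj_weight = torch.empty(embed_dim, embed_dim)
k_proj_weight = torch.empty(embed_dim, kdim)
v_proj_weight = torch.empty(embed_dim, vdim)
print(q_w.shape, k_proj_weight.shape, v_proj_weight.shape)
```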
@@ -150,7 +147,10 @@ class MultiheadAttention(Module):
raise NotImplementedError
else:
self.in_proj_linear = linear1_cls(
-embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs
+embed_dim,
+3 * embed_dim,
+bias=bias,
+**factory_kwargs,
)
self.in_proj_weight = self.in_proj_linear.weight
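The `**factory_kwargs` being re-wrapped in this and the neighbouring hunks is the usual device/dtype forwarding pattern: the constructor collects `device` and `dtype` once and hands them to every tensor and layer it builds. A minimal sketch with assumed values:

```python
import torch
from torch.nn import Linear

factory_kwargs = {"device": "cpu", "dtype": torch.float32}  # assumed settings

in_proj_linear = Linear(512, 3 * 512, bias=True, **factory_kwargs)  # same call shape as above
in_proj_weight = torch.empty((3 * 512, 512), **factory_kwargs)
print(in_proj_linear.weight.dtype, in_proj_weight.device)
```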
@@ -164,7 +164,10 @@ class MultiheadAttention(Module):
self.register_parameter("in_proj_bias", None)
self.out_proj = linear2_cls(
-embed_dim, embed_dim, bias=bias, **factory_kwargs
+embed_dim,
+embed_dim,
+bias=bias,
+**factory_kwargs,
)
if self.bias_k is not None:
@@ -261,28 +264,26 @@ class MultiheadAttention(Module):
if key_padding_mask is not None:
_kpm_dtype = key_padding_mask.dtype
if _kpm_dtype != torch.bool and not torch.is_floating_point(
-key_padding_mask
+key_padding_mask,
):
-raise AssertionError(
-"only bool and floating types of key_padding_mask are supported"
-)
+raise AssertionError("only bool and floating types of key_padding_mask are supported")
why_not_fast_path = ""
if not is_batched:
-why_not_fast_path = (
-f"input not batched; expected query.dim() of 3 but got {query.dim()}"
-)
+why_not_fast_path = f"input not batched; expected query.dim() of 3 but got {query.dim()}"
elif query is not key or key is not value:
# When lifting this restriction, don't forget to either
# enforce that the dtypes all match or test cases where
# they don't!
why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)"
elif self.in_proj_bias is not None and query.dtype != self.in_proj_bias.dtype:
why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match"
elif (
self.in_proj_weight is not None and query.dtype != self.in_proj_weight.dtype
):
why_not_fast_path = (
f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match"
)
elif self.in_proj_weight is not None and query.dtype != self.in_proj_weight.dtype:
# this case will fail anyway, but at least they'll get a useful error message.
why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match"
why_not_fast_path = (
f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match"
)
elif self.training:
why_not_fast_path = "training is enabled"
elif not self.batch_first:
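The `AssertionError` collapsed at the top of this hunk encodes a real constraint: `key_padding_mask` must be boolean (True marks padded positions) or floating point (added to the attention scores); any other dtype is rejected. A small illustration with made-up masks:

```python
import torch

masks = {
    "bool": torch.tensor([[False, False, True]]),        # supported
    "float": torch.tensor([[0.0, 0.0, float("-inf")]]),  # supported
    "int": torch.tensor([[0, 0, 1]]),                    # would trigger the AssertionError
}
for name, mask in masks.items():
    ok = mask.dtype == torch.bool or torch.is_floating_point(mask)
    print(name, mask.dtype, "supported" if ok else "rejected")
```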
@@ -300,9 +301,7 @@ class MultiheadAttention(Module):
elif attn_mask is not None:
why_not_fast_path = "attn_mask was not None"
elif query.is_nested and key_padding_mask is not None:
-why_not_fast_path = (
-"key_padding_mask is not supported with NestedTensor input"
-)
+why_not_fast_path = "key_padding_mask is not supported with NestedTensor input"
elif self.num_heads % 2 == 1:
why_not_fast_path = "num_heads is odd"
elif torch.is_autocast_enabled():
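All of these `elif` branches feed one gate: record the first reason the fused fast path cannot be used, and take it only if the reason string stays empty (the `if not why_not_fast_path:` check in the next hunk). A standalone sketch of the pattern with hypothetical conditions:

```python
def choose_path(training: bool, has_attn_mask: bool, num_heads: int) -> str:
    why_not_fast_path = ""
    if training:
        why_not_fast_path = "training is enabled"
    elif has_attn_mask:
        why_not_fast_path = "attn_mask was not None"
    elif num_heads % 2 == 1:
        why_not_fast_path = "num_heads is odd"
    # An empty string means every check passed, so the fast path is taken.
    return "fast path" if not why_not_fast_path else f"slow path: {why_not_fast_path}"

print(choose_path(training=False, has_attn_mask=False, num_heads=8))  # fast path
print(choose_path(training=True, has_attn_mask=False, num_heads=8))   # slow path: training is enabled
```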
@@ -322,20 +321,10 @@ class MultiheadAttention(Module):
# generator expressions.
if torch.overrides.has_torch_function(tensor_args):
why_not_fast_path = "some Tensor argument has_torch_function"
-elif not all(
-[
-(x is None or x.is_cuda or "cpu" in str(x.device))
-for x in tensor_args
-]
-):
+elif not all([(x is None or x.is_cuda or "cpu" in str(x.device)) for x in tensor_args]):
why_not_fast_path = "some Tensor argument is neither CUDA nor CPU"
-elif torch.is_grad_enabled() and any(
-[x is not None and x.requires_grad for x in tensor_args]
-):
-why_not_fast_path = (
-"grad is enabled and at least one of query or the "
-"input/output projection weights or biases requires_grad"
-)
+elif torch.is_grad_enabled() and any([x is not None and x.requires_grad for x in tensor_args]):
+why_not_fast_path = "grad is enabled and at least one of query or the input/output projection weights or biases requires_grad"
if not why_not_fast_path:
return torch._native_multi_head_attention(
query,
@@ -350,11 +339,7 @@ class MultiheadAttention(Module):
key_padding_mask if key_padding_mask is not None else attn_mask,
need_weights,
average_attn_weights,
-1
-if key_padding_mask is not None
-else 0
-if attn_mask is not None
-else None,
+1 if key_padding_mask is not None else 0 if attn_mask is not None else None,
)
any_nested = query.is_nested or key.is_nested or value.is_nested
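The collapsed expression passed as the last argument to `torch._native_multi_head_attention` above is a chained conditional: 1 when a `key_padding_mask` is present, otherwise 0 when an `attn_mask` is present, otherwise None, telling the native call which kind of mask it was handed. A standalone illustration:

```python
def mask_type(key_padding_mask, attn_mask):
    # Same chained conditional as in the diff, read left to right.
    return 1 if key_padding_mask is not None else 0 if attn_mask is not None else None

print(mask_type(object(), None))   # 1
print(mask_type(None, object()))   # 0
print(mask_type(None, None))       # None
```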