Refactor: Format Code with Ruff and Update Deprecated G2PW Link (#2255)
* ruff check --fix
* ruff format --line-length 120 --target-version py39
* Update the deprecated download link for the G2PW model
* Update the PyTorch version and the Colab notebook
@@ -1,11 +1,9 @@
 from torch.nn.functional import *
 from torch.nn.functional import (
     _mha_shape_check,
     _canonical_mask,
     _none_or_dtype,
     _in_projection_packed,
 )


 def multi_head_attention_forward_patched(
     query,
     key,
@@ -34,7 +32,6 @@ def multi_head_attention_forward_patched(
     is_causal: bool = False,
     cache=None,
 ) -> Tuple[Tensor, Optional[Tensor]]:
-
     # set up shape vars
     _, _, embed_dim = query.shape
     attn_mask = _canonical_mask(
@@ -80,12 +77,8 @@ def multi_head_attention_forward_patched(
     q = q.view(num_heads, -1, head_dim).unsqueeze(0)
     k = k.view(num_heads, -1, head_dim).unsqueeze(0)
     v = v.view(num_heads, -1, head_dim).unsqueeze(0)
-    attn_output = scaled_dot_product_attention(
-        q, k, v, attn_mask, dropout_p, is_causal
-    )
-    attn_output = (
-        attn_output.permute(2, 0, 1, 3).contiguous().view(-1, embed_dim)
-    )
+    attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)
+    attn_output = attn_output.permute(2, 0, 1, 3).contiguous().view(-1, embed_dim)
     attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
     attn_output = attn_output.view(-1, 1, attn_output.size(1))
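
The two `+` lines are pure reformatting: each collapsed statement fits within the 120-column limit, so the behavior is unchanged. For readers following the tensor shapes, the pattern is: unsqueeze(0) adds the batch dimension that scaled_dot_product_attention expects in its (batch, heads, seq, head_dim) layout, and the permute(...).view(...) afterwards folds the heads back into the embedding axis. Below is a minimal, self-contained sketch of that round trip; the concrete sizes, and the omission of the mask and dropout arguments, are illustrative choices and not taken from this commit.

# Minimal shape-round-trip sketch (illustrative sizes, not from the commit).
import torch
from torch.nn.functional import scaled_dot_product_attention

num_heads, seq_len, head_dim = 8, 16, 64
embed_dim = num_heads * head_dim

# (num_heads, seq_len, head_dim) -> (1, num_heads, seq_len, head_dim):
# unsqueeze(0) adds the batch dimension SDPA expects in batch-first layout.
q = torch.randn(num_heads, seq_len, head_dim).unsqueeze(0)
k = torch.randn(num_heads, seq_len, head_dim).unsqueeze(0)
v = torch.randn(num_heads, seq_len, head_dim).unsqueeze(0)

attn_output = scaled_dot_product_attention(q, k, v)  # (1, heads, seq, head_dim)

# permute(2, 0, 1, 3): (1, heads, seq, head_dim) -> (seq, 1, heads, head_dim);
# the final view folds heads * head_dim back into embed_dim.
attn_output = attn_output.permute(2, 0, 1, 3).contiguous().view(-1, embed_dim)
assert attn_output.shape == (seq_len, embed_dim)

Note that the .contiguous() call before .view(...) matters: permute returns a non-contiguous tensor, which view would otherwise reject.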