Simplify BSR (#1356)

This commit is contained in:
KamioRinn
2024-07-30 10:32:37 +08:00
committed by GitHub
parent 8abc0342d7
commit 7670bc77c3
3 changed files with 13 additions and 13 deletions

View File

@@ -7,8 +7,9 @@ import torch.nn.functional as F
from bs_roformer.attend import Attend
from beartype.typing import Tuple, Optional, List, Callable
from beartype import beartype
from typing import Tuple, Optional, List, Callable
# from beartype.typing import Tuple, Optional, List, Callable
# from beartype import beartype
from rotary_embedding_torch import RotaryEmbedding
@@ -125,7 +126,7 @@ class LinearAttention(Module):
this flavor of linear attention proposed in https://arxiv.org/abs/2106.09681 by El-Nouby et al.
"""
@beartype
# @beartype
def __init__(
self,
*,
@@ -219,7 +220,7 @@ class Transformer(Module):
# bandsplit module
class BandSplit(Module):
@beartype
# @beartype
def __init__(
self,
dim,
@@ -274,7 +275,7 @@ def MLP(
class MaskEstimator(Module):
@beartype
# @beartype
def __init__(
self,
dim,
@@ -325,7 +326,7 @@ DEFAULT_FREQS_PER_BANDS = (
class BSRoformer(Module):
@beartype
# @beartype
def __init__(
self,
dim,