Simplify BSR (#1356)
This commit is contained in:
@@ -7,8 +7,9 @@ import torch.nn.functional as F
|
||||
|
||||
from bs_roformer.attend import Attend
|
||||
|
||||
from beartype.typing import Tuple, Optional, List, Callable
|
||||
from beartype import beartype
|
||||
from typing import Tuple, Optional, List, Callable
|
||||
# from beartype.typing import Tuple, Optional, List, Callable
|
||||
# from beartype import beartype
|
||||
|
||||
from rotary_embedding_torch import RotaryEmbedding
|
||||
|
||||
@@ -125,7 +126,7 @@ class LinearAttention(Module):
|
||||
this flavor of linear attention proposed in https://arxiv.org/abs/2106.09681 by El-Nouby et al.
|
||||
"""
|
||||
|
||||
@beartype
|
||||
# @beartype
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
@@ -219,7 +220,7 @@ class Transformer(Module):
|
||||
# bandsplit module
|
||||
|
||||
class BandSplit(Module):
|
||||
@beartype
|
||||
# @beartype
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
@@ -274,7 +275,7 @@ def MLP(
|
||||
|
||||
|
||||
class MaskEstimator(Module):
|
||||
@beartype
|
||||
# @beartype
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
@@ -325,7 +326,7 @@ DEFAULT_FREQS_PER_BANDS = (
|
||||
|
||||
class BSRoformer(Module):
|
||||
|
||||
@beartype
|
||||
# @beartype
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
|
||||
Reference in New Issue
Block a user