Support for mel_band_roformer (#2078)

* support for mel_band_roformer

* Remove unnecessary audio channel judgments

* remove context manager and fix path

* Update webui.py

* Update README.md
This commit is contained in:
Sucial
2025-02-23 20:28:53 +08:00
committed by GitHub
parent fbb9f21e53
commit e061e9d38e
10 changed files with 941 additions and 176 deletions

View File

@@ -1,18 +1,8 @@
from functools import wraps
from packaging import version
from collections import namedtuple
import torch
from torch import nn, einsum
import torch.nn.functional as F
from einops import rearrange, reduce
# constants
FlashAttentionConfig = namedtuple('FlashAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])
# helpers
def exists(val):
return val is not None
@@ -20,21 +10,6 @@ def exists(val):
def default(v, d):
return v if exists(v) else d
def once(fn):
called = False
@wraps(fn)
def inner(x):
nonlocal called
if called:
return
called = True
return fn(x)
return inner
print_once = once(print)
# main class
class Attend(nn.Module):
def __init__(
self,
@@ -50,43 +25,16 @@ class Attend(nn.Module):
self.flash = flash
assert not (flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'
# determine efficient attention configs for cuda and cpu
self.cpu_config = FlashAttentionConfig(True, True, True)
self.cuda_config = None
if not torch.cuda.is_available() or not flash:
return
device_properties = torch.cuda.get_device_properties(torch.device('cuda'))
if device_properties.major == 8 and device_properties.minor == 0:
print_once('A100 GPU detected, using flash attention if input tensor is on cuda')
self.cuda_config = FlashAttentionConfig(True, False, False)
else:
print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda')
self.cuda_config = FlashAttentionConfig(False, True, True)
def flash_attn(self, q, k, v):
_, heads, q_len, _, k_len, is_cuda, device = *q.shape, k.shape[-2], q.is_cuda, q.device
# _, heads, q_len, _, k_len, is_cuda, device = *q.shape, k.shape[-2], q.is_cuda, q.device
if exists(self.scale):
default_scale = q.shape[-1] ** -0.5
q = q * (self.scale / default_scale)
# Check if there is a compatible device for flash attention
config = self.cuda_config if is_cuda else self.cpu_config
# pytorch 2.0 flash attn: q, k, v, mask, dropout, softmax_scale
with torch.backends.cuda.sdp_kernel(**config._asdict()):
out = F.scaled_dot_product_attention(
q, k, v,
dropout_p = self.dropout if self.training else 0.
)
return out
# with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=True):
return F.scaled_dot_product_attention(q, k, v,dropout_p = self.dropout if self.training else 0.)
def forward(self, q, k, v):
"""
@@ -97,7 +45,7 @@ class Attend(nn.Module):
d - feature dimension
"""
q_len, k_len, device = q.shape[-2], k.shape[-2], q.device
# q_len, k_len, device = q.shape[-2], k.shape[-2], q.device
scale = default(self.scale, q.shape[-1] ** -0.5)