Refactor: Format Code with Ruff and Update Deprecated G2PW Link (#2255)
* ruff check --fix
* ruff format --line-length 120 --target-version py39
* Change the link for G2PW Model
* Update PyTorch version and Colab
This commit is contained in:
@@ -11,7 +11,6 @@ from __future__ import annotations
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
|
||||
from x_transformers.x_transformers import RotaryEmbedding
|
||||
@@ -28,6 +27,7 @@ from GPT_SoVITS.f5_tts.model.modules import (
|
||||
|
||||
from module.commons import sequence_mask
|
||||
|
||||
|
||||
class TextEmbedding(nn.Module):
|
||||
def __init__(self, text_dim, conv_layers=0, conv_mult=2):
|
||||
super().__init__()
|
||||
@@ -130,26 +130,24 @@ class DiT(nn.Module):
|
||||
|
||||
return ckpt_forward
|
||||
|
||||
def forward(#x, prompt_x, x_lens, t, style,cond
|
||||
self,#d is channel,n is T
|
||||
def forward( # x, prompt_x, x_lens, t, style,cond
|
||||
self, # d is channel,n is T
|
||||
x0: float["b n d"], # nosied input audio # noqa: F722
|
||||
cond0: float["b n d"], # masked cond audio # noqa: F722
|
||||
x_lens,
|
||||
time: float["b"] | float[""], # time step # noqa: F821 F722
|
||||
dt_base_bootstrap,
|
||||
dt_base_bootstrap,
|
||||
text0, # : int["b nt"] # noqa: F722#####condition feature
|
||||
use_grad_ckpt=False, # bool
|
||||
###no-use
|
||||
drop_audio_cond=False, # cfg for cond audio
|
||||
drop_text=False, # cfg for text
|
||||
# mask: bool["b n"] | None = None, # noqa: F722
|
||||
|
||||
):
|
||||
|
||||
x=x0.transpose(2,1)
|
||||
cond=cond0.transpose(2,1)
|
||||
text=text0.transpose(2,1)
|
||||
mask = sequence_mask(x_lens,max_length=x.size(1)).to(x.device)
|
||||
x = x0.transpose(2, 1)
|
||||
cond = cond0.transpose(2, 1)
|
||||
text = text0.transpose(2, 1)
|
||||
mask = sequence_mask(x_lens, max_length=x.size(1)).to(x.device)
|
||||
|
||||
batch, seq_len = x.shape[0], x.shape[1]
|
||||
if time.ndim == 0:
|
||||
@@ -158,8 +156,8 @@ class DiT(nn.Module):
|
||||
# t: conditioning time, c: context (text + masked cond audio), x: noised input audio
|
||||
t = self.time_embed(time)
|
||||
dt = self.d_embed(dt_base_bootstrap)
|
||||
t+=dt
|
||||
text_embed = self.text_embed(text, seq_len, drop_text=drop_text)###need to change
|
||||
t += dt
|
||||
text_embed = self.text_embed(text, seq_len, drop_text=drop_text) ###need to change
|
||||
x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond)
|
||||
|
||||
rope = self.rotary_embed.forward_from_seq_len(seq_len)
|
||||
@@ -179,4 +177,4 @@ class DiT(nn.Module):
|
||||
x = self.norm_out(x, t)
|
||||
output = self.proj_out(x)
|
||||
|
||||
return output
|
||||
return output
|
||||
|
||||
@@ -391,6 +391,7 @@ class Attention(nn.Module):
|
||||
|
||||
# Attention processor
|
||||
|
||||
|
||||
# from torch.nn.attention import SDPBackend
|
||||
# torch.backends.cuda.enable_flash_sdp(True)
|
||||
class AttnProcessor:
|
||||
@@ -545,6 +546,7 @@ class JointAttnProcessor:
|
||||
|
||||
# DiT Block
|
||||
|
||||
|
||||
class DiTBlock(nn.Module):
|
||||
def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1):
|
||||
super().__init__()
|
||||
|
||||
Reference in New Issue
Block a user