Refactor: Format Code with Ruff and Update Deprecated G2PW Link (#2255)

* ruff check --fix

* ruff format --line-length 120 --target-version py39

* Change the link for G2PW Model

* Update PyTorch version and Colab
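A minimal sketch of how the two ruff passes above could be scripted, assuming ruff is installed and the script is run from the repository root (the "." target is an assumption, not part of the commit message):

import subprocess

# Replay the lint-fix pass, then the formatter pass with the options quoted above.
subprocess.run(["ruff", "check", "--fix", "."], check=True)
subprocess.run(
    ["ruff", "format", "--line-length", "120", "--target-version", "py39", "."],
    check=True,
)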
Author: XXXXRT666
Date: 2025-04-07 09:42:47 +01:00
Committed by: GitHub
Parent: 9da7e17efe
Commit: 53cac93589
132 changed files with 8185 additions and 6648 deletions

(first changed file)

@@ -11,7 +11,6 @@ from __future__ import annotations
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
from x_transformers.x_transformers import RotaryEmbedding
@@ -28,6 +27,7 @@ from GPT_SoVITS.f5_tts.model.modules import (
from module.commons import sequence_mask
class TextEmbedding(nn.Module):
def __init__(self, text_dim, conv_layers=0, conv_mult=2):
super().__init__()
@@ -130,26 +130,24 @@ class DiT(nn.Module):
return ckpt_forward
- def forward(#x, prompt_x, x_lens, t, style,cond
- self,#d is channel,n is T
+ def forward( # x, prompt_x, x_lens, t, style,cond
+ self, # d is channel,n is T
x0: float["b n d"], # nosied input audio # noqa: F722
cond0: float["b n d"], # masked cond audio # noqa: F722
x_lens,
time: float["b"] | float[""], # time step # noqa: F821 F722
- dt_base_bootstrap,
+ dt_base_bootstrap,
text0, # : int["b nt"] # noqa: F722#####condition feature
use_grad_ckpt=False, # bool
###no-use
drop_audio_cond=False, # cfg for cond audio
drop_text=False, # cfg for text
# mask: bool["b n"] | None = None, # noqa: F722
):
- x=x0.transpose(2,1)
- cond=cond0.transpose(2,1)
- text=text0.transpose(2,1)
- mask = sequence_mask(x_lens,max_length=x.size(1)).to(x.device)
+ x = x0.transpose(2, 1)
+ cond = cond0.transpose(2, 1)
+ text = text0.transpose(2, 1)
+ mask = sequence_mask(x_lens, max_length=x.size(1)).to(x.device)
batch, seq_len = x.shape[0], x.shape[1]
if time.ndim == 0:
@@ -158,8 +156,8 @@ class DiT(nn.Module):
# t: conditioning time, c: context (text + masked cond audio), x: noised input audio
t = self.time_embed(time)
dt = self.d_embed(dt_base_bootstrap)
- t+=dt
- text_embed = self.text_embed(text, seq_len, drop_text=drop_text)###need to change
+ t += dt
+ text_embed = self.text_embed(text, seq_len, drop_text=drop_text) ###need to change
x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond)
rope = self.rotary_embed.forward_from_seq_len(seq_len)
@@ -179,4 +177,4 @@ class DiT(nn.Module):
x = self.norm_out(x, t)
output = self.proj_out(x)
- return output
+ return output
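For context, sequence_mask(x_lens, max_length=x.size(1)) in the hunk above turns per-utterance lengths into a padding mask. A minimal sketch of such a helper, as an illustration rather than the repository's module.commons implementation:

from typing import Optional

import torch


def sequence_mask(lengths: torch.Tensor, max_length: Optional[int] = None) -> torch.Tensor:
    # Boolean mask of shape [b, max_length]: True for real frames, False for padding.
    if max_length is None:
        max_length = int(lengths.max())
    positions = torch.arange(max_length, device=lengths.device)
    return positions.unsqueeze(0) < lengths.unsqueeze(1)


# Mirrors the call above: mask the padded frames of a [b, n, d] batch.
x_lens = torch.tensor([3, 5])
mask = sequence_mask(x_lens, max_length=5)  # [[True, True, True, False, False],
                                            #  [True, True, True, True, True]]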

(second changed file)

@@ -391,6 +391,7 @@ class Attention(nn.Module):
# Attention processor
# from torch.nn.attention import SDPBackend
# torch.backends.cuda.enable_flash_sdp(True)
class AttnProcessor:
@@ -545,6 +546,7 @@ class JointAttnProcessor:
# DiT Block
class DiTBlock(nn.Module):
def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1):
super().__init__()
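The commented-out SDPBackend / enable_flash_sdp lines in the hunk above reference PyTorch's scaled dot-product attention backends. As a rough, hypothetical illustration of the call an attention processor of this kind ultimately makes (not the repository's AttnProcessor code, and with placeholder tensor sizes):

import torch
import torch.nn.functional as F

# Allow the flash kernel, mirroring the commented-out toggle above; PyTorch
# falls back to other backends where flash is unavailable (e.g. on CPU).
torch.backends.cuda.enable_flash_sdp(True)

b, heads, n, dim_head = 2, 8, 128, 64  # placeholder sizes
q = torch.randn(b, heads, n, dim_head)
k = torch.randn(b, heads, n, dim_head)
v = torch.randn(b, heads, n, dim_head)

out = F.scaled_dot_product_attention(q, k, v)  # -> [b, heads, n, dim_head]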