Refactor: Format Code with Ruff and Update Deprecated G2PW Link (#2255)

* ruff check --fix

* ruff format --line-length 120 --target-version py39

* Change the link for G2PW Model

* update pytorch version and colab
This commit is contained in:
XXXXRT666
2025-04-07 09:42:47 +01:00
committed by GitHub
parent 9da7e17efe
commit 53cac93589
132 changed files with 8185 additions and 6648 deletions

View File

@@ -1,6 +1,3 @@
from . import cnhubert, whisper_enc
content_module_map = {
'cnhubert': cnhubert,
'whisper': whisper_enc
}
content_module_map = {"cnhubert": cnhubert, "whisper": whisper_enc}

View File

@@ -1,14 +1,11 @@
import time
import librosa
import torch
import torch.nn.functional as F
import soundfile as sf
import os
from transformers import logging as tf_logging
tf_logging.set_verbosity_error()
import logging
logging.getLogger("numba").setLevel(logging.WARNING)
from transformers import (
@@ -23,21 +20,19 @@ cnhubert_base_path = None
class CNHubert(nn.Module):
def __init__(self, base_path:str=None):
def __init__(self, base_path: str = None):
super().__init__()
if base_path is None:
base_path = cnhubert_base_path
if os.path.exists(base_path):...
else:raise FileNotFoundError(base_path)
if os.path.exists(base_path):
...
else:
raise FileNotFoundError(base_path)
self.model = HubertModel.from_pretrained(base_path, local_files_only=True)
self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
base_path, local_files_only=True
)
self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(base_path, local_files_only=True)
def forward(self, x):
input_values = self.feature_extractor(
x, return_tensors="pt", sampling_rate=16000
).input_values.to(x.device)
input_values = self.feature_extractor(x, return_tensors="pt", sampling_rate=16000).input_values.to(x.device)
feats = self.model(input_values)["last_hidden_state"]
return feats

View File

@@ -19,7 +19,5 @@ def get_content(model=None, wav_16k_tensor=None):
feature_len = mel.shape[-1] // 2
assert mel.shape[-1] < 3000, "输入音频过长只允许输入30以内音频"
with torch.no_grad():
feature = model(pad_or_trim(mel, 3000).unsqueeze(0))[
:1, :feature_len, :
].transpose(1, 2)
feature = model(pad_or_trim(mel, 3000).unsqueeze(0))[:1, :feature_len, :].transpose(1, 2)
return feature