Refactor: Format Code with Ruff and Update Deprecated G2PW Link (#2255)

* ruff check --fix

* ruff format --line-length 120 --target-version py39

* Change the link for G2PW Model

* update pytorch version and colab
This commit is contained in:
XXXXRT666
2025-04-07 09:42:47 +01:00
committed by GitHub
parent 9da7e17efe
commit 53cac93589
132 changed files with 8185 additions and 6648 deletions

View File

@@ -1,14 +1,11 @@
import time
import librosa
import torch
import torch.nn.functional as F
import soundfile as sf
import os
from transformers import logging as tf_logging
tf_logging.set_verbosity_error()
import logging
logging.getLogger("numba").setLevel(logging.WARNING)
from transformers import (
@@ -23,21 +20,19 @@ cnhubert_base_path = None
class CNHubert(nn.Module):
def __init__(self, base_path:str=None):
def __init__(self, base_path: str = None):
super().__init__()
if base_path is None:
base_path = cnhubert_base_path
if os.path.exists(base_path):...
else:raise FileNotFoundError(base_path)
if os.path.exists(base_path):
...
else:
raise FileNotFoundError(base_path)
self.model = HubertModel.from_pretrained(base_path, local_files_only=True)
self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
base_path, local_files_only=True
)
self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(base_path, local_files_only=True)
def forward(self, x):
input_values = self.feature_extractor(
x, return_tensors="pt", sampling_rate=16000
).input_values.to(x.device)
input_values = self.feature_extractor(x, return_tensors="pt", sampling_rate=16000).input_values.to(x.device)
feats = self.model(input_values)["last_hidden_state"]
return feats