Refactor: Format Code with Ruff and Update Deprecated G2PW Link (#2255)
* ruff check --fix
* ruff format --line-length 120 --target-version py39
* Change the link for G2PW Model
* Update PyTorch version and Colab
This commit is contained in:
@@ -1,6 +1,3 @@
|
||||
from . import cnhubert, whisper_enc
|
||||
|
||||
content_module_map = {
|
||||
'cnhubert': cnhubert,
|
||||
'whisper': whisper_enc
|
||||
}
|
||||
content_module_map = {"cnhubert": cnhubert, "whisper": whisper_enc}
|
||||
|
||||
@@ -1,14 +1,11 @@
|
||||
import time
|
||||
|
||||
import librosa
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import soundfile as sf
|
||||
import os
|
||||
from transformers import logging as tf_logging
|
||||
|
||||
tf_logging.set_verbosity_error()
|
||||
|
||||
import logging
|
||||
|
||||
logging.getLogger("numba").setLevel(logging.WARNING)
|
||||
|
||||
from transformers import (
|
||||
@@ -23,21 +20,19 @@ cnhubert_base_path = None
|
||||
|
||||
|
||||
class CNHubert(nn.Module):
|
||||
def __init__(self, base_path:str=None):
|
||||
def __init__(self, base_path: str = None):
|
||||
super().__init__()
|
||||
if base_path is None:
|
||||
base_path = cnhubert_base_path
|
||||
if os.path.exists(base_path):...
|
||||
else:raise FileNotFoundError(base_path)
|
||||
if os.path.exists(base_path):
|
||||
...
|
||||
else:
|
||||
raise FileNotFoundError(base_path)
|
||||
self.model = HubertModel.from_pretrained(base_path, local_files_only=True)
|
||||
self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
|
||||
base_path, local_files_only=True
|
||||
)
|
||||
self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(base_path, local_files_only=True)
|
||||
|
||||
def forward(self, x):
|
||||
input_values = self.feature_extractor(
|
||||
x, return_tensors="pt", sampling_rate=16000
|
||||
).input_values.to(x.device)
|
||||
input_values = self.feature_extractor(x, return_tensors="pt", sampling_rate=16000).input_values.to(x.device)
|
||||
feats = self.model(input_values)["last_hidden_state"]
|
||||
return feats
|
||||
|
||||
|
||||
@@ -19,7 +19,5 @@ def get_content(model=None, wav_16k_tensor=None):
|
||||
feature_len = mel.shape[-1] // 2
|
||||
assert mel.shape[-1] < 3000, "输入音频过长,只允许输入30以内音频"
|
||||
with torch.no_grad():
|
||||
feature = model(pad_or_trim(mel, 3000).unsqueeze(0))[
|
||||
:1, :feature_len, :
|
||||
].transpose(1, 2)
|
||||
feature = model(pad_or_trim(mel, 3000).unsqueeze(0))[:1, :feature_len, :].transpose(1, 2)
|
||||
return feature
|
||||
|
||||
Reference in New Issue
Block a user