Refactor: Format Code with Ruff and Update Deprecated G2PW Link (#2255)

* ruff check --fix

* ruff format --line-length 120 --target-version py39

* Change the link for the G2PW model

* Update PyTorch version and Colab
Authored by XXXXRT666 on 2025-04-07 09:42:47 +01:00; committed by GitHub
parent 9da7e17efe
commit 53cac93589
132 changed files with 8185 additions and 6648 deletions
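
For reference, the two ruff commands named in the commit message can be re-run from the repository root. The snippet below is a minimal sketch and not part of this commit: it assumes ruff is installed in the active environment, and the helper script name run_ruff.py is purely illustrative.

# run_ruff.py -- illustrative helper, not part of this commit.
# Re-runs the two commands from the commit message:
#   ruff check --fix
#   ruff format --line-length 120 --target-version py39
import subprocess
import sys


def main() -> int:
    # Auto-fix lint violations that have safe fixes (the exact rule set
    # depends on the project's ruff configuration).
    check = subprocess.run(["ruff", "check", ".", "--fix"])
    # Reformat with a 120-character line limit, targeting Python 3.9 syntax.
    fmt = subprocess.run(["ruff", "format", ".", "--line-length", "120", "--target-version", "py39"])
    return check.returncode or fmt.returncode


if __name__ == "__main__":
    sys.exit(main())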


@@ -1,17 +1,15 @@
-import os
-import glob
-import sys
 import argparse
-import logging
+import glob
 import json
+import logging
+import os
 import subprocess
+import sys
 import traceback
 
 import librosa
 import numpy as np
 from scipy.io.wavfile import read
 import torch
-import logging
 
 logging.getLogger("numba").setLevel(logging.ERROR)
 logging.getLogger("matplotlib").setLevel(logging.ERROR)
@@ -27,11 +25,7 @@ def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False
     checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")
     iteration = checkpoint_dict["iteration"]
     learning_rate = checkpoint_dict["learning_rate"]
-    if (
-        optimizer is not None
-        and not skip_optimizer
-        and checkpoint_dict["optimizer"] is not None
-    ):
+    if optimizer is not None and not skip_optimizer and checkpoint_dict["optimizer"] is not None:
         optimizer.load_state_dict(checkpoint_dict["optimizer"])
     saved_state_dict = checkpoint_dict["model"]
     if hasattr(model, "module"):
@@ -50,9 +44,7 @@ def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False
             )
         except:
             traceback.print_exc()
-            print(
-                "error, %s is not in the checkpoint" % k
-            )  # a shape mismatch also lands here, e.g. text_embedding when the cleaner is modified
+            print("error, %s is not in the checkpoint" % k)  # a shape mismatch also lands here, e.g. text_embedding when the cleaner is modified
             new_state_dict[k] = v
     if hasattr(model, "module"):
         model.module.load_state_dict(new_state_dict)
@@ -60,25 +52,28 @@ def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False
         model.load_state_dict(new_state_dict)
     print("load ")
     logger.info(
-        "Loaded checkpoint '{}' (iteration {})".format(checkpoint_path, iteration)
+        "Loaded checkpoint '{}' (iteration {})".format(
+            checkpoint_path,
+            iteration,
+        )
     )
     return model, optimizer, learning_rate, iteration
 
-from time import time as ttime
 import shutil
-def my_save(fea,path):#####fix issue: torch.save doesn't support chinese path
-    dir=os.path.dirname(path)
-    name=os.path.basename(path)
-    tmp_path="%s.pth"%(ttime())
-    torch.save(fea,tmp_path)
-    shutil.move(tmp_path,"%s/%s"%(dir,name))
+from time import time as ttime
+
+
+def my_save(fea, path):  #####fix issue: torch.save doesn't support chinese path
+    dir = os.path.dirname(path)
+    name = os.path.basename(path)
+    tmp_path = "%s.pth" % (ttime())
+    torch.save(fea, tmp_path)
+    shutil.move(tmp_path, "%s/%s" % (dir, name))
+
 
 def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path):
-    logger.info(
-        "Saving model and optimizer state at iteration {} to {}".format(
-            iteration, checkpoint_path
-        )
-    )
+    logger.info("Saving model and optimizer state at iteration {} to {}".format(iteration, checkpoint_path))
     if hasattr(model, "module"):
         state_dict = model.module.state_dict()
     else:
@@ -132,7 +127,6 @@ def plot_spectrogram_to_numpy(spectrogram):
         mpl_logger = logging.getLogger("matplotlib")
         mpl_logger.setLevel(logging.WARNING)
     import matplotlib.pylab as plt
-    import numpy as np
 
     fig, ax = plt.subplots(figsize=(10, 2))
     im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
@@ -158,11 +152,13 @@ def plot_alignment_to_numpy(alignment, info=None):
         mpl_logger = logging.getLogger("matplotlib")
         mpl_logger.setLevel(logging.WARNING)
     import matplotlib.pylab as plt
-    import numpy as np
 
     fig, ax = plt.subplots(figsize=(6, 4))
     im = ax.imshow(
-        alignment.transpose(), aspect="auto", origin="lower", interpolation="none"
+        alignment.transpose(),
+        aspect="auto",
+        origin="lower",
+        interpolation="none",
     )
     fig.colorbar(im, ax=ax)
     xlabel = "Decoder timestep"
@@ -199,9 +195,7 @@ def get_hparams(init=True, stage=1):
         default="./configs/s2.json",
         help="JSON file for configuration",
     )
-    parser.add_argument(
-        "-p", "--pretrain", type=str, required=False, default=None, help="pretrain dir"
-    )
+    parser.add_argument("-p", "--pretrain", type=str, required=False, default=None, help="pretrain dir")
     parser.add_argument(
         "-rs",
         "--resume_step",
@@ -250,11 +244,7 @@ def clean_checkpoints(path_to_models="logs/44k/", n_ckpts_to_keep=2, sort_by_tim
     """
     import re
 
-    ckpts_files = [
-        f
-        for f in os.listdir(path_to_models)
-        if os.path.isfile(os.path.join(path_to_models, f))
-    ]
+    ckpts_files = [f for f in os.listdir(path_to_models) if os.path.isfile(os.path.join(path_to_models, f))]
     name_key = lambda _f: int(re.compile("._(\d+)\.pth").match(_f).group(1))
     time_key = lambda _f: os.path.getmtime(os.path.join(path_to_models, _f))
     sort_key = time_key if sort_by_time else name_key
@@ -263,8 +253,7 @@ def clean_checkpoints(path_to_models="logs/44k/", n_ckpts_to_keep=2, sort_by_tim
         key=sort_key,
     )
     to_del = [
-        os.path.join(path_to_models, fn)
-        for fn in (x_sorted("G")[:-n_ckpts_to_keep] + x_sorted("D")[:-n_ckpts_to_keep])
+        os.path.join(path_to_models, fn) for fn in (x_sorted("G")[:-n_ckpts_to_keep] + x_sorted("D")[:-n_ckpts_to_keep])
     ]
     del_info = lambda fn: logger.info(f".. Free up space by deleting ckpt {fn}")
     del_routine = lambda x: [os.remove(x), del_info(x)]
@@ -296,7 +285,7 @@ def check_git_hash(model_dir):
     if not os.path.exists(os.path.join(source_dir, ".git")):
         logger.warn(
             "{} is not a git repository, therefore hash value comparison will be ignored.".format(
-                source_dir
+                source_dir,
             )
         )
         return
@@ -309,7 +298,8 @@ def check_git_hash(model_dir):
         if saved_hash != cur_hash:
             logger.warn(
                 "git hash values are different. {}(saved) != {}(current)".format(
-                    saved_hash[:8], cur_hash[:8]
+                    saved_hash[:8],
+                    cur_hash[:8],
                 )
             )
     else:
@@ -366,6 +356,6 @@ class HParams:
 if __name__ == "__main__":
     print(
         load_wav_to_torch(
-            "/home/fish/wenetspeech/dataset_vq/Y0000022499_wHFSeHEx9CM/S00261.flac"
+            "/home/fish/wenetspeech/dataset_vq/Y0000022499_wHFSeHEx9CM/S00261.flac",
         )
     )
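
For reference, the my_save helper reformatted earlier in this diff works around torch.save failing when the destination path contains non-ASCII (for example Chinese) characters on some platforms: the object is first written to an ASCII-named temporary file and then moved into place with shutil. The following standalone sketch restates that idea; the function name save_with_unicode_path and the usage path are illustrative only, and PyTorch is assumed to be installed.

import os
import shutil
from time import time as ttime

import torch


def save_with_unicode_path(obj, path):
    # Same workaround as my_save in the diff above: write to an ASCII-only
    # temporary file in the current directory, then move it to the requested
    # (possibly non-ASCII) destination path.
    tmp_path = "%s.pth" % ttime()
    torch.save(obj, tmp_path)
    shutil.move(tmp_path, os.path.join(os.path.dirname(path) or ".", os.path.basename(path)))


# Illustrative usage (the target directory is hypothetical and must already exist):
# save_with_unicode_path({"step": 0}, "./模型/ckpt_0.pth")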