Refactor: Format Code with Ruff and Update Deprecated G2PW Link (#2255)
* ruff check --fix
* ruff format --line-length 120 --target-version py39
* Change the link for the G2PW Model
* Update the PyTorch version and the Colab notebook
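The whole-tree pass can be reproduced from the repository root; a minimal sketch, assuming ruff is installed and mirroring the two commands above:

import subprocess

# Lint autofixes first (unused imports, import sorting, ...), then the
# black-style formatter at a 120-column limit targeting Python 3.9.
subprocess.run(["ruff", "check", "--fix", "."], check=True)
subprocess.run(["ruff", "format", "--line-length", "120", "--target-version", "py39", "."], check=True)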
@@ -1,31 +1,28 @@
 # modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/train_t2s.py
 import os
-import pdb
 
 if "_CUDA_VISIBLE_DEVICES" in os.environ:
     os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
 import argparse
 import logging
+import platform
 from pathlib import Path
 
-import torch, platform
-from pytorch_lightning import seed_everything
-from pytorch_lightning import Trainer
-from pytorch_lightning.callbacks import ModelCheckpoint
-from pytorch_lightning.loggers import TensorBoardLogger  # WandbLogger
-from pytorch_lightning.strategies import DDPStrategy
+import torch
 from AR.data.data_module import Text2SemanticDataModule
 from AR.models.t2s_lightning_module import Text2SemanticLightningModule
 from AR.utils.io import load_yaml_config
+from pytorch_lightning import Trainer, seed_everything
+from pytorch_lightning.callbacks import ModelCheckpoint
+from pytorch_lightning.loggers import TensorBoardLogger  # WandbLogger
+from pytorch_lightning.strategies import DDPStrategy
 
 logging.getLogger("numba").setLevel(logging.WARNING)
 logging.getLogger("matplotlib").setLevel(logging.WARNING)
 torch.set_float32_matmul_precision("high")
-from AR.utils import get_newest_ckpt
-
 from collections import OrderedDict
 from time import time as ttime
 import shutil
+
+from AR.utils import get_newest_ckpt
 from process_ckpt import my_save
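Most of this hunk is the lint pass rather than the formatter: `ruff check --fix` drops the unused `import pdb` (F401), splits `import torch, platform` onto separate lines (E401), and regroups the block isort-style into stdlib, third-party, and first-party sections (I001). A hypothetical before/after sketch of those three fixes:

# Before: F401 (pdb imported but unused), E401 (multiple imports on one line),
# I001 (unsorted import block).
import pdb
import torch, platform

# After `ruff check --fix`: unused import gone, one import per line, groups sorted.
import platform

import torch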
@@ -37,7 +34,7 @@ class my_model_ckpt(ModelCheckpoint):
         if_save_every_weights,
         half_weights_save_dir,
         exp_name,
-        **kwargs
+        **kwargs,
     ):
         super().__init__(**kwargs)
         self.if_save_latest = if_save_latest
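The only change in this hunk is the comma after `**kwargs`: ruff format, like black, appends a "magic" trailing comma to the last element of any signature or call it keeps split across multiple lines, which keeps future diffs minimal. A sketch with hypothetical names:

def build_checkpoint_name(
    exp_name,
    epoch,
    **kwargs,  # trailing comma added by the formatter on multi-line signatures
):
    return "%s-e%s.ckpt" % (exp_name, epoch)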
@@ -50,10 +47,7 @@ class my_model_ckpt(ModelCheckpoint):
         # if not self._should_skip_saving_checkpoint(trainer) and self._should_save_on_train_epoch_end(trainer):
         if self._should_save_on_train_epoch_end(trainer):
             monitor_candidates = self._monitor_candidates(trainer)
-            if (
-                self._every_n_epochs >= 1
-                and (trainer.current_epoch + 1) % self._every_n_epochs == 0
-            ):
+            if self._every_n_epochs >= 1 and (trainer.current_epoch + 1) % self._every_n_epochs == 0:
                 if (
                     self.if_save_latest == True
                 ):  #### If set to keep only the latest ckpt, clean up all previous ckpts after saving the new one
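Two things are notable here: ruff collapses the first multi-line `if` once it fits in 120 columns, but leaves the second one, whose trailing comment pins its layout; and the non-idiomatic `self.if_save_latest == True` survives the lint pass, likely because E712's autofix is marked unsafe, so `--fix` will not apply it (an inference, not stated in the PR). The translated comment describes a keep-only-latest policy; a minimal sketch of that kind of cleanup, with a hypothetical helper name (the repo's actual deletion logic is in the elided body below this hunk):

from pathlib import Path


def keep_only_latest_ckpt(ckpt_dir, newest_name):
    # Delete every checkpoint in ckpt_dir except the one that was just saved.
    for path in Path(ckpt_dir).glob("*.ckpt"):
        if path.name != newest_name:
            path.unlink()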
@@ -75,7 +69,7 @@ class my_model_ckpt(ModelCheckpoint):
                         to_save_od["info"] = "GPT-e%s" % (trainer.current_epoch + 1)
                         # torch.save(
                         # print(os.environ)
-                        if(os.environ.get("LOCAL_RANK","0")=="0"):
+                        if os.environ.get("LOCAL_RANK", "0") == "0":
                             my_save(
                                 to_save_od,
                                 "%s/%s-e%s.ckpt"
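Besides the whitespace fix, this guard is what keeps DDP runs safe: every rank executes the epoch-end hook, and without it each process would write the same weights file concurrently. A generic sketch of the pattern (the function name is hypothetical):

import os

import torch


def save_on_rank_zero(state, path):
    # torchrun and Lightning set LOCAL_RANK per process; plain single-process
    # runs have no such variable, hence the "0" default.
    if os.environ.get("LOCAL_RANK", "0") == "0":
        torch.save(state, path)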
@@ -112,7 +106,7 @@ def main(args):
         dirpath=ckpt_dir,
     )
     logger = TensorBoardLogger(name=output_dir.stem, save_dir=output_dir)
-    os.environ["MASTER_ADDR"]="localhost"
+    os.environ["MASTER_ADDR"] = "localhost"
     os.environ["USE_LIBUV"] = "0"
     trainer: Trainer = Trainer(
         max_epochs=config["train"]["epochs"],
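Only the spacing around `=` changes, but both environment lines deserve a gloss. `MASTER_ADDR` is the rendezvous host for the DDP process group; `USE_LIBUV=0` disables the libuv TCPStore backend that PyTorch 2.4+ enables by default and that some Windows builds lack, plausibly why it sits next to this PR's PyTorch version bump (an assumption; the commit message does not say). An annotated sketch:

import os

# All DDP ranks on this single node rendezvous at localhost.
os.environ["MASTER_ADDR"] = "localhost"
# Fall back to the legacy (non-libuv) TCPStore; PyTorch 2.4+ defaults to libuv,
# which some Windows builds do not support (assumption: tied to this PR's
# PyTorch bump).
os.environ["USE_LIBUV"] = "0"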
@@ -123,9 +117,9 @@ def main(args):
         devices=-1 if torch.cuda.is_available() else 1,
         benchmark=False,
         fast_dev_run=False,
-        strategy = DDPStrategy(
-            process_group_backend="nccl" if platform.system() != "Windows" else "gloo"
-        ) if torch.cuda.is_available() else "auto",
+        strategy=DDPStrategy(process_group_backend="nccl" if platform.system() != "Windows" else "gloo")
+        if torch.cuda.is_available()
+        else "auto",
         precision=config["train"]["precision"],
         logger=logger,
         num_sanity_val_steps=0,
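ruff only re-wraps the conditional expression here; the logic is unchanged. Pulled into a helper it reads like this (hypothetical function, same backend selection as the diff, assuming pytorch_lightning is installed):

import platform

import torch
from pytorch_lightning.strategies import DDPStrategy


def pick_strategy():
    # NCCL needs CUDA and is unavailable on Windows, where Gloo is the
    # fallback; CPU-only runs let Lightning choose a strategy automatically.
    if not torch.cuda.is_available():
        return "auto"
    return DDPStrategy(process_group_backend="nccl" if platform.system() != "Windows" else "gloo")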
@@ -133,9 +127,7 @@ def main(args):
         use_distributed_sampler=False,  # A very simple change, but it fixes the inconsistent training-step counts when using the custom bucket_sampler!
     )
 
-    model: Text2SemanticLightningModule = Text2SemanticLightningModule(
-        config, output_dir
-    )
+    model: Text2SemanticLightningModule = Text2SemanticLightningModule(config, output_dir)
 
     data_module: Text2SemanticDataModule = Text2SemanticDataModule(
         config,
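The hunk cuts off inside the `Text2SemanticDataModule(...)` call, so its remaining arguments are not shown. For orientation, the usual Lightning continuation of a script like this is a single call (a sketch, not taken from this diff):

# Hand both pieces to the Trainer; Lightning drives the training loop from here.
trainer.fit(model, datamodule=data_module)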