Code refactor + remove unused imports
@@ -2,56 +2,84 @@
 import os
-import pdb

-if("_CUDA_VISIBLE_DEVICES"in os.environ):
-    os.environ["CUDA_VISIBLE_DEVICES"]=os.environ["_CUDA_VISIBLE_DEVICES"]
+if "_CUDA_VISIBLE_DEVICES" in os.environ:
+    os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
 import argparse
 import logging
 from pathlib import Path

-import torch,platform
+import torch, platform
 from pytorch_lightning import seed_everything
 from pytorch_lightning import Trainer
 from pytorch_lightning.callbacks import ModelCheckpoint
-from pytorch_lightning.loggers import TensorBoardLogger#WandbLogger
+from pytorch_lightning.loggers import TensorBoardLogger  # WandbLogger
 from pytorch_lightning.strategies import DDPStrategy
 from AR.data.data_module import Text2SemanticDataModule
 from AR.models.t2s_lightning_module import Text2SemanticLightningModule
 from AR.utils.io import load_yaml_config
-logging.getLogger('numba').setLevel(logging.WARNING)
-logging.getLogger('matplotlib').setLevel(logging.WARNING)
-torch.set_float32_matmul_precision('high')
+
+logging.getLogger("numba").setLevel(logging.WARNING)
+logging.getLogger("matplotlib").setLevel(logging.WARNING)
+torch.set_float32_matmul_precision("high")
 from AR.utils import get_newest_ckpt

 from collections import OrderedDict


 class my_model_ckpt(ModelCheckpoint):
-    def __init__(self,config,if_save_latest,if_save_every_weights,half_weights_save_dir,exp_name,**kwargs):
+    def __init__(
+        self,
+        config,
+        if_save_latest,
+        if_save_every_weights,
+        half_weights_save_dir,
+        exp_name,
+        **kwargs
+    ):
         super().__init__(**kwargs)
-        self.if_save_latest=if_save_latest
-        self.if_save_every_weights=if_save_every_weights
-        self.half_weights_save_dir=half_weights_save_dir
-        self.exp_name=exp_name
-        self.config=config
+        self.if_save_latest = if_save_latest
+        self.if_save_every_weights = if_save_every_weights
+        self.half_weights_save_dir = half_weights_save_dir
+        self.exp_name = exp_name
+        self.config = config

     def on_train_epoch_end(self, trainer, pl_module):
-        if not self._should_skip_saving_checkpoint(trainer) and self._should_save_on_train_epoch_end(trainer):
+        if not self._should_skip_saving_checkpoint(
+            trainer
+        ) and self._should_save_on_train_epoch_end(trainer):
             monitor_candidates = self._monitor_candidates(trainer)
-            if self._every_n_epochs >= 1 and (trainer.current_epoch + 1) % self._every_n_epochs == 0:
-                if(self.if_save_latest==True):  #### if only the latest ckpt is kept, clean up all earlier ckpts after the next one is saved
-                    to_clean=list(os.listdir(self.dirpath))
+            if (
+                self._every_n_epochs >= 1
+                and (trainer.current_epoch + 1) % self._every_n_epochs == 0
+            ):
+                if (
+                    self.if_save_latest == True
+                ):  #### if only the latest ckpt is kept, clean up all earlier ckpts after the next one is saved
+                    to_clean = list(os.listdir(self.dirpath))
                 self._save_topk_checkpoint(trainer, monitor_candidates)
-                if (self.if_save_latest == True):
+                if self.if_save_latest == True:
                     for name in to_clean:
                         try:
-                            os.remove("%s/%s"%(self.dirpath,name))
-                        except:pass
-                if(self.if_save_every_weights==True):
-                    to_save_od=OrderedDict()
-                    to_save_od["weight"]=OrderedDict()
-                    dictt=trainer.strategy._lightning_module.state_dict()
-                    for key in dictt:to_save_od["weight"][key]=dictt[key].half()
-                    to_save_od["config"]=self.config
-                    to_save_od["info"]="GPT-e%s"%(trainer.current_epoch+1)
-                    torch.save(to_save_od,"%s/%s-e%s.ckpt"%(self.half_weights_save_dir,self.exp_name,trainer.current_epoch+1))
+                            os.remove("%s/%s" % (self.dirpath, name))
+                        except:
+                            pass
+                if self.if_save_every_weights == True:
+                    to_save_od = OrderedDict()
+                    to_save_od["weight"] = OrderedDict()
+                    dictt = trainer.strategy._lightning_module.state_dict()
+                    for key in dictt:
+                        to_save_od["weight"][key] = dictt[key].half()
+                    to_save_od["config"] = self.config
+                    to_save_od["info"] = "GPT-e%s" % (trainer.current_epoch + 1)
+                    torch.save(
+                        to_save_od,
+                        "%s/%s-e%s.ckpt"
+                        % (
+                            self.half_weights_save_dir,
+                            self.exp_name,
+                            trainer.current_epoch + 1,
+                        ),
+                    )
             self._save_last_checkpoint(trainer, monitor_candidates)
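Reviewer note on the hunk above: the `if_save_every_weights` branch exports a compact fp16-only copy of the model weights alongside the regular Lightning checkpoints, bundling the training config and an epoch tag into the same file. A minimal standalone sketch of that export pattern; the helper name `export_half_weights` and its signature are illustrative, not part of this repo:

from collections import OrderedDict

import torch


def export_half_weights(state_dict, config, save_path, tag):
    # Hypothetical helper mirroring the if_save_every_weights branch:
    # cast every parameter tensor to fp16 and bundle it with the config.
    payload = OrderedDict()
    payload["weight"] = OrderedDict(
        (key, tensor.half()) for key, tensor in state_dict.items()
    )
    payload["config"] = config  # keep the training config next to the weights
    payload["info"] = tag       # e.g. "GPT-e12" for epoch 12
    torch.save(payload, save_path)

Casting with `.half()` roughly halves the file size; full-precision weights and optimizer state remain only in the regular ckpt files in `dirpath`.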
@@ -61,41 +89,45 @@ def main(args):
     output_dir = Path(config["output_dir"])
     output_dir.mkdir(parents=True, exist_ok=True)

-    ckpt_dir = output_dir / 'ckpt'
+    ckpt_dir = output_dir / "ckpt"
     ckpt_dir.mkdir(parents=True, exist_ok=True)

     seed_everything(config["train"]["seed"], workers=True)
     ckpt_callback: ModelCheckpoint = my_model_ckpt(
         config=config,
-        if_save_latest=config["train"]["if_save_latest"], if_save_every_weights=config["train"]["if_save_every_weights"], half_weights_save_dir=config["train"]["half_weights_save_dir"], exp_name=config["train"]["exp_name"],
+        if_save_latest=config["train"]["if_save_latest"],
+        if_save_every_weights=config["train"]["if_save_every_weights"],
+        half_weights_save_dir=config["train"]["half_weights_save_dir"],
+        exp_name=config["train"]["exp_name"],
         save_top_k=-1,
-        monitor='top_3_acc',
-        mode='max',
+        monitor="top_3_acc",
+        mode="max",
         save_on_train_epoch_end=True,
         every_n_epochs=config["train"]["save_every_n_epoch"],
         dirpath=ckpt_dir,
     )
-    logger = TensorBoardLogger(
-        name=output_dir.stem,
-        save_dir=output_dir
-    )
+    logger = TensorBoardLogger(name=output_dir.stem, save_dir=output_dir)
     trainer: Trainer = Trainer(
         max_epochs=config["train"]["epochs"],
-        accelerator='gpu',
+        accelerator="gpu",
         # val_check_interval=9999999999999999999999,  ### no validation
         # check_val_every_n_epoch=None,
         limit_val_batches=0,
         devices=-1,
         benchmark=False,
         fast_dev_run=False,
-        strategy=DDPStrategy(process_group_backend="nccl"if platform.system()!="Windows"else "gloo"),
+        strategy=DDPStrategy(
+            process_group_backend="nccl" if platform.system() != "Windows" else "gloo"
+        ),
         precision=config["train"]["precision"],
-        logger=logger,num_sanity_val_steps=0,
-        callbacks=[ckpt_callback])
+        logger=logger,
+        num_sanity_val_steps=0,
+        callbacks=[ckpt_callback],
+    )
     model: Text2SemanticLightningModule = Text2SemanticLightningModule(
-        config, output_dir)
+        config, output_dir
+    )

     data_module: Text2SemanticDataModule = Text2SemanticDataModule(
         config,
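Note on the `Trainer` wiring in this hunk: the reflowed strategy line picks the DDP process-group backend by platform, because NCCL is not available on Windows, and the surrounding flags disable validation entirely. A trimmed sketch of just the options this diff touches, with the config lookups replaced by literals for illustration:

import platform

from pytorch_lightning import Trainer
from pytorch_lightning.strategies import DDPStrategy

# NCCL is the fast multi-GPU backend on Linux; Windows must fall back to Gloo.
backend = "nccl" if platform.system() != "Windows" else "gloo"

trainer = Trainer(
    accelerator="gpu",
    devices=-1,              # use every visible GPU
    strategy=DDPStrategy(process_group_backend=backend),
    limit_val_batches=0,     # validation disabled entirely
    num_sanity_val_steps=0,  # also skip the pre-training sanity pass
)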
@@ -116,14 +148,15 @@ def main(args):


 # srun --gpus-per-node=1 --ntasks-per-node=1 python train.py --path-to-configuration configurations/default.yaml
-if __name__ == '__main__':
+if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument(
-        '-c',
-        '--config_file',
+        "-c",
+        "--config_file",
         type=str,
-        default='configs/s1longer.yaml',
-        help='path of config file')
+        default="configs/s1longer.yaml",
+        help="path of config file",
+    )
     # args for dataset
     # parser.add_argument('--train_semantic_path',type=str,default='/data/docker/liujing04/gpt-vits/fine_tune_dataset/xuangou/6-name2semantic.tsv')
     # parser.add_argument('--train_phoneme_path', type=str, default='/data/docker/liujing04/gpt-vits/fine_tune_dataset/xuangou/2-name2text.txt')
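For completeness, the entry point's CLI is a single `-c/--config_file` flag whose value feeds the YAML loader imported at the top of the file. A sketch of how the pieces fit together, assuming `load_yaml_config` returns the parsed YAML as a dict, as the `config[...]` lookups in `main` suggest:

import argparse

from AR.utils.io import load_yaml_config

parser = argparse.ArgumentParser()
parser.add_argument(
    "-c",
    "--config_file",
    type=str,
    default="configs/s1longer.yaml",
    help="path of config file",
)
args = parser.parse_args()
config = load_yaml_config(args.config_file)  # dict with "train", "output_dir", ...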