修复resume epoch数识别错,每次resume都要多训一轮的问题
修复resume epoch数识别错,每次resume都要多训一轮的问题
This commit is contained in:
@@ -161,6 +161,7 @@ def run(rank, n_gpus, hps):
|
||||
net_g,
|
||||
optim_g,
|
||||
)
|
||||
epoch_str+=1
|
||||
global_step = (epoch_str - 1) * len(train_loader)
|
||||
except: # 如果首次不能加载,加载pretrain
|
||||
# traceback.print_exc()
|
||||
@@ -170,7 +171,7 @@ def run(rank, n_gpus, hps):
|
||||
if hps.train.pretrained_s2G != ""and hps.train.pretrained_s2G != None and os.path.exists(hps.train.pretrained_s2G):
|
||||
if rank == 0:
|
||||
logger.info("loaded pretrained %s" % hps.train.pretrained_s2G)
|
||||
print(
|
||||
print("loaded pretrained %s" % hps.train.pretrained_s2G,
|
||||
net_g.load_state_dict(
|
||||
torch.load(hps.train.pretrained_s2G, map_location="cpu")["weight"],
|
||||
strict=False,
|
||||
@@ -198,6 +199,7 @@ def run(rank, n_gpus, hps):
|
||||
|
||||
net_d=optim_d=scheduler_d=None
|
||||
for epoch in range(epoch_str, hps.train.epochs + 1):
|
||||
print("start training from epoch %s"%epoch)
|
||||
if rank == 0:
|
||||
train_and_evaluate(
|
||||
rank,
|
||||
@@ -226,6 +228,7 @@ def run(rank, n_gpus, hps):
|
||||
None,
|
||||
)
|
||||
scheduler_g.step()
|
||||
print("training done")
|
||||
|
||||
def train_and_evaluate(
|
||||
rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers
|
||||
|
||||
Reference in New Issue
Block a user