@@ -64,6 +64,14 @@ def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False
|
||||
)
|
||||
return model, optimizer, learning_rate, iteration
|
||||
|
||||
from time import time as ttime
|
||||
import shutil
|
||||
def my_save(fea,path):#####fix issue: torch.save doesn't support chinese path
|
||||
dir=os.path.dirname(path)
|
||||
name=os.path.basename(path)
|
||||
tmp_path="%s.pth"%(ttime())
|
||||
torch.save(fea,tmp_path)
|
||||
shutil.move(tmp_path,"%s/%s"%(dir,name))
|
||||
|
||||
def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path):
|
||||
logger.info(
|
||||
@@ -75,7 +83,8 @@ def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path)
|
||||
state_dict = model.module.state_dict()
|
||||
else:
|
||||
state_dict = model.state_dict()
|
||||
torch.save(
|
||||
# torch.save(
|
||||
my_save(
|
||||
{
|
||||
"model": state_dict,
|
||||
"iteration": iteration,
|
||||
|
||||
Reference in New Issue
Block a user