hot update

This commit is contained in:
johnjim0816
2022-08-24 11:33:06 +08:00
parent ad65dd17cd
commit 62a7364c72
40 changed files with 2129 additions and 179 deletions

View File

@@ -5,7 +5,7 @@ Author: John
Email: johnjim0816@gmail.com
Date: 2021-03-12 16:02:24
LastEditor: John
LastEditTime: 2022-08-22 17:41:28
LastEditTime: 2022-08-24 10:31:30
Discription:
Environment:
'''
@@ -64,14 +64,14 @@ def smooth(data, weight=0.9):
def plot_rewards(rewards,cfg,path=None,tag='train'):
sns.set()
plt.figure() # 创建一个图形实例,方便同时多画几个图
plt.title(f"{tag}ing curve on {cfg.device} of {cfg.algo_name} for {cfg.env_name}")
plt.title(f"{tag}ing curve on {cfg['device']} of {cfg['algo_name']} for {cfg['env_name']}")
plt.xlabel('epsiodes')
plt.plot(rewards, label='rewards')
plt.plot(smooth(rewards), label='smoothed')
plt.legend()
if cfg.save_fig:
if cfg['save_fig']:
plt.savefig(f"{path}/{tag}ing_curve.png")
if cfg.show_fig:
if cfg['show_fig']:
plt.show()
def plot_losses(losses, algo="DQN", save=True, path='./'):
@@ -110,12 +110,21 @@ def del_empty_dir(*paths):
if not os.listdir(os.path.join(path, dir)):
os.removedirs(os.path.join(path, dir))
class NpEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, np.integer):
return int(obj)
if isinstance(obj, np.floating):
return float(obj)
if isinstance(obj, np.ndarray):
return obj.tolist()
return json.JSONEncoder.default(self, obj)
def save_args(args,path=None):
# 保存参数
args_dict = vars(args)
# save parameters
Path(path).mkdir(parents=True, exist_ok=True)
with open(f"{path}/params.json", 'w') as fp:
json.dump(args_dict, fp)
json.dump(args, fp,cls=NpEncoder)
print("Parameters saved!")
def all_seed(env,seed = 1):