hot update
This commit is contained in:
@@ -5,7 +5,7 @@ Author: John
|
||||
Email: johnjim0816@gmail.com
|
||||
Date: 2021-03-12 16:02:24
|
||||
LastEditor: John
|
||||
LastEditTime: 2022-08-22 17:41:28
|
||||
LastEditTime: 2022-08-24 10:31:30
|
||||
Discription:
|
||||
Environment:
|
||||
'''
|
||||
@@ -64,14 +64,14 @@ def smooth(data, weight=0.9):
|
||||
def plot_rewards(rewards,cfg,path=None,tag='train'):
|
||||
sns.set()
|
||||
plt.figure() # 创建一个图形实例,方便同时多画几个图
|
||||
plt.title(f"{tag}ing curve on {cfg.device} of {cfg.algo_name} for {cfg.env_name}")
|
||||
plt.title(f"{tag}ing curve on {cfg['device']} of {cfg['algo_name']} for {cfg['env_name']}")
|
||||
plt.xlabel('epsiodes')
|
||||
plt.plot(rewards, label='rewards')
|
||||
plt.plot(smooth(rewards), label='smoothed')
|
||||
plt.legend()
|
||||
if cfg.save_fig:
|
||||
if cfg['save_fig']:
|
||||
plt.savefig(f"{path}/{tag}ing_curve.png")
|
||||
if cfg.show_fig:
|
||||
if cfg['show_fig']:
|
||||
plt.show()
|
||||
|
||||
def plot_losses(losses, algo="DQN", save=True, path='./'):
|
||||
@@ -110,12 +110,21 @@ def del_empty_dir(*paths):
|
||||
if not os.listdir(os.path.join(path, dir)):
|
||||
os.removedirs(os.path.join(path, dir))
|
||||
|
||||
class NpEncoder(json.JSONEncoder):
|
||||
def default(self, obj):
|
||||
if isinstance(obj, np.integer):
|
||||
return int(obj)
|
||||
if isinstance(obj, np.floating):
|
||||
return float(obj)
|
||||
if isinstance(obj, np.ndarray):
|
||||
return obj.tolist()
|
||||
return json.JSONEncoder.default(self, obj)
|
||||
|
||||
def save_args(args,path=None):
|
||||
# 保存参数
|
||||
args_dict = vars(args)
|
||||
# save parameters
|
||||
Path(path).mkdir(parents=True, exist_ok=True)
|
||||
with open(f"{path}/params.json", 'w') as fp:
|
||||
json.dump(args_dict, fp)
|
||||
json.dump(args, fp,cls=NpEncoder)
|
||||
print("Parameters saved!")
|
||||
|
||||
def all_seed(env,seed = 1):
|
||||
|
||||
Reference in New Issue
Block a user