更新算法模版
This commit is contained in:
@@ -5,7 +5,7 @@ Author: John
|
||||
Email: johnjim0816@gmail.com
|
||||
Date: 2021-03-12 16:02:24
|
||||
LastEditor: John
|
||||
LastEditTime: 2022-08-24 10:31:30
|
||||
LastEditTime: 2022-10-26 07:38:17
|
||||
Description:
|
||||
Environment:
|
||||
'''
|
||||
@@ -14,8 +14,13 @@ import numpy as np
|
||||
from pathlib import Path
|
||||
import matplotlib.pyplot as plt
|
||||
import seaborn as sns
|
||||
import json
|
||||
import yaml
|
||||
import pandas as pd
|
||||
from functools import wraps
|
||||
from time import time
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
from matplotlib.font_manager import FontProperties # 导入字体模块
|
||||
|
||||
@@ -61,17 +66,17 @@ def smooth(data, weight=0.9):
|
||||
last = smoothed_val
|
||||
return smoothed
|
||||
|
||||
def plot_rewards(rewards,cfg,path=None,tag='train'):
|
||||
def plot_rewards(rewards,title="learning curve",fpath=None,save_fig=True,show_fig=False):
|
||||
sns.set()
|
||||
plt.figure() # 创建一个图形实例,方便同时多画几个图
|
||||
plt.title(f"{tag}ing curve on {cfg['device']} of {cfg['algo_name']} for {cfg['env_name']}")
|
||||
plt.title(f"{title}")
|
||||
plt.xlabel('epsiodes')
|
||||
plt.plot(rewards, label='rewards')
|
||||
plt.plot(smooth(rewards), label='smoothed')
|
||||
plt.legend()
|
||||
if cfg['save_fig']:
|
||||
plt.savefig(f"{path}/{tag}ing_curve.png")
|
||||
if cfg['show_fig']:
|
||||
if save_fig:
|
||||
plt.savefig(f"{fpath}/learning_curve.png")
|
||||
if show_fig:
|
||||
plt.show()
|
||||
|
||||
def plot_losses(losses, algo="DQN", save=True, path='./'):
|
||||
@@ -85,48 +90,86 @@ def plot_losses(losses, algo="DQN", save=True, path='./'):
|
||||
plt.savefig(path+"losses_curve")
|
||||
plt.show()
|
||||
|
||||
def save_results(res_dic, tag='train', path = None):
|
||||
''' 保存奖励
|
||||
def save_results(res_dic,fpath = None):
|
||||
''' save results
|
||||
'''
|
||||
Path(path).mkdir(parents=True, exist_ok=True)
|
||||
Path(fpath).mkdir(parents=True, exist_ok=True)
|
||||
df = pd.DataFrame(res_dic)
|
||||
df.to_csv(f"{path}/{tag}ing_results.csv",index=None)
|
||||
print('Results saved!')
|
||||
|
||||
|
||||
def make_dir(*paths):
|
||||
''' 创建文件夹
|
||||
df.to_csv(f"{fpath}/res.csv",index=None)
|
||||
def merge_class_attrs(ob1, ob2):
    ''' Copy every instance attribute of ob2 onto ob1 and hand ob1 back.

    Attributes present on both objects end up with ob2's value.
    '''
    target_attrs = ob1.__dict__
    source_attrs = ob2.__dict__
    target_attrs.update(source_attrs)
    return ob1
|
||||
def get_logger(fpath):
    ''' Configure and return the logger named 'r', writing to a file and to the console.

    Args:
        fpath: directory that holds ``log.txt``; it is created when missing.
            A trailing path separator is optional.
    Returns:
        logging.Logger: logger at DEBUG level with one file handler and one
        stream handler, both at DEBUG level.
    '''
    log_dir = Path(fpath)
    log_dir.mkdir(parents=True, exist_ok=True)
    logger = logging.getLogger(name='r')  # named logger; without a name the root logger would be configured
    logger.setLevel(logging.DEBUG)
    # getLogger returns the same object for the same name, so handlers left
    # over from an earlier call must go, otherwise every message is emitted
    # once per call made so far
    for stale_handler in list(logger.handlers):
        logger.removeHandler(stale_handler)
        stale_handler.close()
    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s: - %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S')
    # output to file by using FileHandler; joining through Path keeps the file
    # inside the directory even when fpath has no trailing separator
    fh = logging.FileHandler(str(log_dir / "log.txt"))
    fh.setLevel(logging.DEBUG)
    fh.setFormatter(formatter)
    # output to screen by using StreamHandler
    ch = logging.StreamHandler()
    ch.setLevel(logging.DEBUG)
    ch.setFormatter(formatter)
    # add Handler
    logger.addHandler(ch)
    logger.addHandler(fh)
    return logger
|
||||
def save_cfgs(cfgs, fpath):
|
||||
''' save config
|
||||
'''
|
||||
for path in paths:
|
||||
Path(path).mkdir(parents=True, exist_ok=True)
|
||||
Path(fpath).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
with open(f"{fpath}/config.yaml", 'w') as f:
|
||||
for cfg_type in cfgs:
|
||||
yaml.dump({cfg_type: cfgs[cfg_type].__dict__}, f, default_flow_style=False)
|
||||
def load_cfgs(cfgs, fpath):
    ''' Read a YAML file and copy its values onto the matching config objects.

    Args:
        cfgs: mapping from section name to the object whose attributes are
            overwritten with the values stored under that section.
        fpath: path of the YAML file to read.
    '''
    with open(fpath) as f:
        # NOTE(review): FullLoader can build more than plain data types;
        # presumably the file is one this project wrote - confirm before
        # pointing this at untrusted YAML
        loaded = yaml.load(f, Loader=yaml.FullLoader)
        for section in cfgs:
            section_values = loaded[section]
            for attr_name, attr_value in section_values.items():
                setattr(cfgs[section], attr_name, attr_value)
|
||||
# def del_empty_dir(*paths):
|
||||
# ''' 删除目录下所有空文件夹
|
||||
# '''
|
||||
# for path in paths:
|
||||
# dirs = os.listdir(path)
|
||||
# for dir in dirs:
|
||||
# if not os.listdir(os.path.join(path, dir)):
|
||||
# os.removedirs(os.path.join(path, dir))
|
||||
|
||||
|
||||
def del_empty_dir(*paths):
    ''' Delete the empty folders found directly inside each given directory.

    Args:
        *paths: directories whose immediate children are inspected. Plain
            files, symlinks and non-empty folders are left untouched, and
            the given directories themselves are never removed.
    '''
    for path in paths:
        # snapshot the listing first: entries are removed while we walk it
        for entry in list(Path(path).iterdir()):
            if entry.is_symlink() or not entry.is_dir():
                # listing a plain file would raise, so skip anything that is
                # not a real directory
                continue
            if not any(entry.iterdir()):
                # rmdir removes this folder only; os.removedirs would go on to
                # prune `path` and its ancestors once they became empty
                entry.rmdir()
|
||||
|
||||
class NpEncoder(json.JSONEncoder):
    ''' JSON encoder that turns NumPy scalars and arrays into built-in Python values.

    Pass it as ``cls=NpEncoder`` to ``json.dump`` / ``json.dumps``.
    '''

    def default(self, obj):
        ''' Convert a value the stock encoder cannot serialise.

        Args:
            obj: the value json could not encode on its own.
        Returns:
            bool, int, float or list for NumPy booleans, integers, floats
            and arrays respectively.
        Raises:
            TypeError: from the base encoder, for any other type.
        '''
        # np.bool_ is not a subclass of np.integer, so it needs its own branch
        if isinstance(obj, np.bool_):
            return bool(obj)
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return super().default(obj)
|
||||
# class NpEncoder(json.JSONEncoder):
|
||||
# def default(self, obj):
|
||||
# if isinstance(obj, np.integer):
|
||||
# return int(obj)
|
||||
# if isinstance(obj, np.floating):
|
||||
# return float(obj)
|
||||
# if isinstance(obj, np.ndarray):
|
||||
# return obj.tolist()
|
||||
# return json.JSONEncoder.default(self, obj)
|
||||
|
||||
def save_args(args,path=None):
    ''' Write the parameters to ``<path>/params.json``, creating the folder first.

    Args:
        args: the parameters to serialise; NumPy values are converted by NpEncoder.
        path: output directory.
    '''
    out_dir = Path(path)
    out_dir.mkdir(parents=True, exist_ok=True)
    with open(f"{path}/params.json", 'w') as params_file:
        json.dump(args, params_file, cls=NpEncoder)
    print("Parameters saved!")
|
||||
# def save_args(args,path=None):
|
||||
# # save parameters
|
||||
# Path(path).mkdir(parents=True, exist_ok=True)
|
||||
# with open(f"{path}/params.json", 'w') as fp:
|
||||
# json.dump(args, fp,cls=NpEncoder)
|
||||
# print("Parameters saved!")
|
||||
|
||||
|
||||
def timing(func):
    ''' Decorator that prints how long each call of the wrapped function took.

    The wrapped function's return value is passed through unchanged.
    '''
    @wraps(func)
    def timed_call(*args, **kw):
        started = time()
        outcome = func(*args, **kw)
        finished = time()
        print(f"func: {func.__name__}, took: {finished-started:2.4f} seconds")
        return outcome
    return timed_call
|
||||
def all_seed(env,seed = 1):
|
||||
''' seed every random source used in RL; mind where you call it - ideally right after the environment is created
|
||||
Args:
|
||||
@@ -136,7 +179,7 @@ def all_seed(env,seed = 1):
|
||||
import torch
|
||||
import numpy as np
|
||||
import random
|
||||
print(f"seed = {seed}")
|
||||
# print(f"seed = {seed}")
|
||||
env.seed(seed) # env config
|
||||
np.random.seed(seed)
|
||||
random.seed(seed)
|
||||
|
||||
Reference in New Issue
Block a user