更新算法模版

This commit is contained in:
johnjim0816
2022-11-06 12:15:36 +08:00
parent 466a17707f
commit dc78698262
256 changed files with 17282 additions and 10229 deletions

View File

@@ -5,7 +5,7 @@ Author: John
Email: johnjim0816@gmail.com
Date: 2021-03-12 16:02:24
LastEditor: John
LastEditTime: 2022-08-24 10:31:30
LastEditTime: 2022-10-26 07:38:17
Description:
Environment:
'''
@@ -14,8 +14,13 @@ import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
import seaborn as sns
import json
import yaml
import pandas as pd
from functools import wraps
from time import time
import logging
from pathlib import Path
from matplotlib.font_manager import FontProperties # 导入字体模块
@@ -61,17 +66,17 @@ def smooth(data, weight=0.9):
last = smoothed_val
return smoothed
def plot_rewards(rewards, title="learning curve", fpath=None, save_fig=True, show_fig=False):
    ''' Plot the raw reward curve together with its smoothed version.

    Args:
        rewards: sequence of reward values (the x axis is labelled "episodes").
        title: text used as the figure title.
        fpath: folder in which `learning_curve.png` is written; the current
            working directory is used when it is None.
        save_fig: write the figure to disk when True.
        show_fig: display the figure with `plt.show()` when True.
    '''
    sns.set()
    plt.figure()  # open a separate figure so several plots can be drawn in one run
    plt.title(f"{title}")
    plt.xlabel('episodes')
    plt.plot(rewards, label='rewards')
    plt.plot(smooth(rewards), label='smoothed')
    plt.legend()
    if save_fig:
        # Without this fallback the default fpath=None would target a folder literally named "None".
        out_dir = Path("." if fpath is None else fpath)
        out_dir.mkdir(parents=True, exist_ok=True)  # savefig does not create missing folders
        plt.savefig(str(out_dir / "learning_curve.png"))
    if show_fig:
        plt.show()
def plot_losses(losses, algo="DQN", save=True, path='./'):
@@ -85,48 +90,86 @@ def plot_losses(losses, algo="DQN", save=True, path='./'):
plt.savefig(path+"losses_curve")
plt.show()
def save_results(res_dic, fpath=None):
    ''' Save results as `res.csv` inside `fpath`.

    Args:
        res_dic: data accepted by `pd.DataFrame`, e.g. a dict mapping column
            names to equally long lists.
        fpath: output folder, created if missing; the current working
            directory is used when it is None.
    '''
    # Path(None) raises TypeError, so the declared default needs an explicit fallback.
    out_dir = Path("." if fpath is None else fpath)
    out_dir.mkdir(parents=True, exist_ok=True)
    df = pd.DataFrame(res_dic)
    df.to_csv(out_dir / "res.csv", index=False)  # do not write the row index column
def merge_class_attrs(ob1, ob2):
    ''' Copy every instance attribute of `ob2` onto `ob1` and return `ob1`.

    `ob1` is modified in place; attributes it shares with `ob2` are overwritten
    by the values from `ob2`.
    '''
    target_attrs = ob1.__dict__
    target_attrs.update(ob2.__dict__)
    return ob1
def get_logger(fpath):
    ''' Build the logger named 'r' that writes to the screen and to `<fpath>/log.txt`.

    Args:
        fpath: folder that receives `log.txt` (str or Path); created if missing.

    Returns:
        The configured `logging.Logger`. Calling this function again replaces
        the handlers installed by the previous call, so messages are never
        emitted more than once per handler.
    '''
    log_dir = Path(fpath)
    log_dir.mkdir(parents=True, exist_ok=True)
    logger = logging.getLogger(name='r')  # getLogger() without a name would return the root logger
    logger.setLevel(logging.DEBUG)
    # getLogger returns the same object for the same name, so handlers from an
    # earlier call must be dropped or every message would be duplicated.
    for old_handler in list(logger.handlers):
        logger.removeHandler(old_handler)
        old_handler.close()
    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s: - %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S')
    # output to file by using FileHandler; joining with Path keeps the file
    # inside fpath even when fpath has no trailing separator
    fh = logging.FileHandler(str(log_dir / "log.txt"))
    fh.setLevel(logging.DEBUG)
    fh.setFormatter(formatter)
    # output to screen by using StreamHandler
    ch = logging.StreamHandler()
    ch.setLevel(logging.DEBUG)
    ch.setFormatter(formatter)
    # add Handler
    logger.addHandler(ch)
    logger.addHandler(fh)
    return logger
def save_cfgs(cfgs, fpath):
    ''' Save configs to `config.yaml` inside `fpath`.

    Args:
        cfgs: mapping from a config-type name to a config object; the
            `__dict__` of each object is dumped under its name.
        fpath: output folder, created if missing.
    '''
    Path(fpath).mkdir(parents=True, exist_ok=True)
    with open(f"{fpath}/config.yaml", 'w') as f:
        for cfg_type in cfgs:
            # one top-level key per config type, written in iteration order
            yaml.dump({cfg_type: cfgs[cfg_type].__dict__}, f, default_flow_style=False)
def load_cfgs(cfgs, fpath):
    ''' Load a YAML config file and copy its values onto the config objects.

    Args:
        cfgs: mapping from a config-type name to a config object; each object
            is updated in place with `setattr`.
        fpath: path of the YAML file to read.

    Sections that are absent from the file (or an empty file) leave the
    corresponding config objects untouched.
    '''
    with open(fpath) as f:
        # NOTE(review): FullLoader can construct more than plain YAML types; only
        # load files from trusted sources, or switch to yaml.safe_load once the
        # saved configs are confirmed to contain plain types only.
        load_cfg = yaml.load(f, Loader=yaml.FullLoader)
    if not load_cfg:  # an empty file is parsed as None
        return
    for cfg_type in cfgs:
        section = load_cfg.get(cfg_type) or {}
        for k, v in section.items():
            setattr(cfgs[cfg_type], k, v)
# def del_empty_dir(*paths):
# ''' 删除目录下所有空文件夹
# '''
# for path in paths:
# dirs = os.listdir(path)
# for dir in dirs:
# if not os.listdir(os.path.join(path, dir)):
# os.removedirs(os.path.join(path, dir))
def del_empty_dir(*paths):
    ''' Delete every empty folder located directly inside each given directory.

    Args:
        *paths: directories to clean. The directories themselves are kept,
            even when they end up empty; plain files inside them are ignored.
    '''
    for path in paths:
        # Materialise the listing first so removals do not disturb the iteration.
        for entry in list(Path(path).iterdir()):
            if entry.is_symlink() or not entry.is_dir():
                continue
            if not any(entry.iterdir()):
                # rmdir removes only this folder; os.removedirs would also
                # prune `path` and its ancestors once they become empty.
                entry.rmdir()
class NpEncoder(json.JSONEncoder):
    ''' JSON encoder that converts NumPy scalars and arrays to built-in types.
    '''
    def default(self, obj):
        ''' Return a JSON-serialisable equivalent of `obj`.

        NumPy integers, floats and booleans become `int`, `float` and `bool`;
        arrays become (nested) lists. Any other type is passed to the base
        class, which raises `TypeError`.
        '''
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.bool_):
            return bool(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return super().default(obj)
# class NpEncoder(json.JSONEncoder):
# def default(self, obj):
# if isinstance(obj, np.integer):
# return int(obj)
# if isinstance(obj, np.floating):
# return float(obj)
# if isinstance(obj, np.ndarray):
# return obj.tolist()
# return json.JSONEncoder.default(self, obj)
def save_args(args, path=None):
    ''' Save parameters as `params.json` inside `path`.

    Args:
        args: JSON-serialisable parameters; NumPy scalars and arrays are
            converted by `NpEncoder`.
        path: output folder, created if missing; the current working
            directory is used when it is None.
    '''
    # Path(None) raises TypeError, so the declared default needs an explicit fallback.
    out_dir = Path("." if path is None else path)
    out_dir.mkdir(parents=True, exist_ok=True)
    with open(out_dir / "params.json", 'w') as fp:
        json.dump(args, fp, cls=NpEncoder)
    print("Parameters saved!")
# def save_args(args,path=None):
# # save parameters
# Path(path).mkdir(parents=True, exist_ok=True)
# with open(f"{path}/params.json", 'w') as fp:
# json.dump(args, fp,cls=NpEncoder)
# print("Parameters saved!")
def timing(func):
    ''' a decorator to print the running time of a function

    The wrapped function is called with the original arguments and its result
    is returned unchanged; the elapsed time is printed after every call.
    '''
    # perf_counter is monotonic and high resolution, so the measured duration
    # is not affected by system clock adjustments the way time() is.
    from time import perf_counter

    @wraps(func)
    def wrap(*args, **kw):
        ts = perf_counter()
        result = func(*args, **kw)
        te = perf_counter()
        print(f"func: {func.__name__}, took: {te-ts:2.4f} seconds")
        return result
    return wrap
def all_seed(env,seed = 1):
''' omnipotent seed for RL, attention the position of seed function, you'd better put it just following the env create function
Args:
@@ -136,7 +179,7 @@ def all_seed(env,seed = 1):
import torch
import numpy as np
import random
print(f"seed = {seed}")
# print(f"seed = {seed}")
env.seed(seed) # env config
np.random.seed(seed)
random.seed(seed)