hot update

This commit is contained in:
johnjim0816
2022-08-22 17:50:11 +08:00
parent 0a54840828
commit ad65dd17cd
54 changed files with 1639 additions and 503 deletions

View File

@@ -5,7 +5,7 @@ Author: John
Email: johnjim0816@gmail.com
Date: 2021-03-12 16:02:24
LastEditor: John
LastEditTime: 2022-08-15 18:11:27
LastEditTime: 2022-08-22 17:41:28
Discription:
Environment:
'''
@@ -15,6 +15,7 @@ from pathlib import Path
import matplotlib.pyplot as plt
import seaborn as sns
import json
import pandas as pd
from matplotlib.font_manager import FontProperties # 导入字体模块
@@ -84,12 +85,12 @@ def plot_losses(losses, algo="DQN", save=True, path='./'):
plt.savefig(path+"losses_curve")
plt.show()
def save_results(dic, tag='train', path = None):
def save_results(res_dic, tag='train', path = None):
''' 保存奖励
'''
Path(path).mkdir(parents=True, exist_ok=True)
for key,value in dic.items():
np.save(path+'{}_{}.npy'.format(tag,key),value)
df = pd.DataFrame(res_dic)
df.to_csv(f"{path}/{tag}ing_results.csv",index=None)
print('Results saved')
@@ -115,4 +116,26 @@ def save_args(args,path=None):
Path(path).mkdir(parents=True, exist_ok=True)
with open(f"{path}/params.json", 'w') as fp:
json.dump(args_dict, fp)
print("参数已保存!")
print("Parameters saved!")
def all_seed(env,seed = 1):
''' omnipotent seed for RL, attention the position of seed function, you'd better put it just following the env create function
Args:
env (_type_):
seed (int, optional): _description_. Defaults to 1.
'''
import torch
import numpy as np
import random
print(f"seed = {seed}")
env.seed(seed) # env config
np.random.seed(seed)
random.seed(seed)
torch.manual_seed(seed) # config for CPU
torch.cuda.manual_seed(seed) # config for GPU
os.environ['PYTHONHASHSEED'] = str(seed) # config for python scripts
# config for cudnn
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.enabled = False