hot update
This commit is contained in:
@@ -5,11 +5,12 @@
|
||||
@Email: johnjim0816@gmail.com
|
||||
@Date: 2020-06-10 15:27:16
|
||||
@LastEditor: John
|
||||
LastEditTime: 2021-09-15 14:52:37
|
||||
LastEditTime: 2022-08-22 17:23:21
|
||||
@Discription:
|
||||
@Environment: python 3.7.7
|
||||
'''
|
||||
import random
|
||||
from collections import deque
|
||||
class ReplayBuffer:
|
||||
def __init__(self, capacity):
|
||||
self.capacity = capacity # 经验回放的容量
|
||||
@@ -34,3 +35,40 @@ class ReplayBuffer:
|
||||
'''
|
||||
return len(self.buffer)
|
||||
|
||||
class ReplayBufferQue:
|
||||
def __init__(self, capacity: int) -> None:
|
||||
self.capacity = capacity
|
||||
self.buffer = deque(maxlen=self.capacity)
|
||||
def push(self,trainsitions):
|
||||
'''_summary_
|
||||
Args:
|
||||
trainsitions (tuple): _description_
|
||||
'''
|
||||
self.buffer.append(trainsitions)
|
||||
def sample(self, batch_size: int, sequential: bool = False):
|
||||
if batch_size > len(self.buffer):
|
||||
batch_size = len(self.buffer)
|
||||
if sequential: # sequential sampling
|
||||
rand = random.randint(0, len(self.buffer) - batch_size)
|
||||
batch = [self.buffer[i] for i in range(rand, rand + batch_size)]
|
||||
return zip(*batch)
|
||||
else:
|
||||
batch = random.sample(self.buffer, batch_size)
|
||||
return zip(*batch)
|
||||
def clear(self):
|
||||
self.buffer.clear()
|
||||
def __len__(self):
|
||||
return len(self.buffer)
|
||||
|
||||
class PGReplay(ReplayBufferQue):
|
||||
'''replay buffer for policy gradient based methods, each time these methods will sample all transitions
|
||||
Args:
|
||||
ReplayBufferQue (_type_): _description_
|
||||
'''
|
||||
def __init__(self):
|
||||
self.buffer = deque()
|
||||
def sample(self):
|
||||
''' sample all the transitions
|
||||
'''
|
||||
batch = list(self.buffer)
|
||||
return zip(*batch)
|
||||
@@ -5,7 +5,7 @@ Author: John
|
||||
Email: johnjim0816@gmail.com
|
||||
Date: 2021-03-12 16:02:24
|
||||
LastEditor: John
|
||||
LastEditTime: 2022-08-15 18:11:27
|
||||
LastEditTime: 2022-08-22 17:41:28
|
||||
Discription:
|
||||
Environment:
|
||||
'''
|
||||
@@ -15,6 +15,7 @@ from pathlib import Path
|
||||
import matplotlib.pyplot as plt
|
||||
import seaborn as sns
|
||||
import json
|
||||
import pandas as pd
|
||||
|
||||
from matplotlib.font_manager import FontProperties # 导入字体模块
|
||||
|
||||
@@ -84,12 +85,12 @@ def plot_losses(losses, algo="DQN", save=True, path='./'):
|
||||
plt.savefig(path+"losses_curve")
|
||||
plt.show()
|
||||
|
||||
def save_results(dic, tag='train', path = None):
|
||||
def save_results(res_dic, tag='train', path = None):
|
||||
''' 保存奖励
|
||||
'''
|
||||
Path(path).mkdir(parents=True, exist_ok=True)
|
||||
for key,value in dic.items():
|
||||
np.save(path+'{}_{}.npy'.format(tag,key),value)
|
||||
df = pd.DataFrame(res_dic)
|
||||
df.to_csv(f"{path}/{tag}ing_results.csv",index=None)
|
||||
print('Results saved!')
|
||||
|
||||
|
||||
@@ -115,4 +116,26 @@ def save_args(args,path=None):
|
||||
Path(path).mkdir(parents=True, exist_ok=True)
|
||||
with open(f"{path}/params.json", 'w') as fp:
|
||||
json.dump(args_dict, fp)
|
||||
print("参数已保存!")
|
||||
print("Parameters saved!")
|
||||
|
||||
def all_seed(env,seed = 1):
|
||||
''' omnipotent seed for RL, attention the position of seed function, you'd better put it just following the env create function
|
||||
Args:
|
||||
env (_type_):
|
||||
seed (int, optional): _description_. Defaults to 1.
|
||||
'''
|
||||
import torch
|
||||
import numpy as np
|
||||
import random
|
||||
print(f"seed = {seed}")
|
||||
env.seed(seed) # env config
|
||||
np.random.seed(seed)
|
||||
random.seed(seed)
|
||||
torch.manual_seed(seed) # config for CPU
|
||||
torch.cuda.manual_seed(seed) # config for GPU
|
||||
os.environ['PYTHONHASHSEED'] = str(seed) # config for python scripts
|
||||
# config for cudnn
|
||||
torch.backends.cudnn.deterministic = True
|
||||
torch.backends.cudnn.benchmark = False
|
||||
torch.backends.cudnn.enabled = False
|
||||
|
||||
Reference in New Issue
Block a user