Update PPO, add PER DQN

This commit is contained in:
johnjim0816
2022-11-14 21:35:28 +08:00
parent dc78698262
commit b8aec4c188
34 changed files with 1993 additions and 476 deletions


@@ -36,11 +36,11 @@ class Launcher:
ep_reward = 0
ep_step = 0
return agent,ep_reward,ep_step
def test_one_episode(self,env, agent, cfg):
def test_one_episode(self, env, agent, cfg):
ep_reward = 0
ep_step = 0
return agent,ep_reward,ep_step
def evaluate(self,env, agent, cfg):
def evaluate(self, env, agent, cfg):
sum_eval_reward = 0
for _ in range(cfg.eval_eps):
_,eval_ep_reward,_ = self.test_one_episode(env, agent, cfg)


@@ -10,6 +10,7 @@ LastEditTime: 2022-08-28 23:44:06
@Environment: python 3.7.7
'''
import random
import numpy as np
from collections import deque
class ReplayBuffer:
def __init__(self, capacity):
@@ -71,4 +72,136 @@ class PGReplay(ReplayBufferQue):
''' sample all the transitions
'''
batch = list(self.buffer)
return zip(*batch)
return zip(*batch)
class SumTree:
    '''SumTree for PER (Prioritized Experience Replay) DQN: leaves hold transition priorities and each internal node holds the sum of its children.
    This SumTree code is a modified version and the original code is from:
    https://github.com/MorvanZhou/Reinforcement-learning-with-tensorflow/blob/master/contents/5.2_Prioritized_Replay_DQN/RL_brain.py
    '''
    def __init__(self, capacity: int):
        self.capacity = capacity # max number of stored transitions (number of leaves)
        self.data_pointer = 0 # index of the next leaf to write
        self.n_entries = 0 # current number of stored transitions
        # binary tree stored as an array: capacity - 1 internal nodes followed by capacity leaves
        self.tree = np.zeros(2 * capacity - 1)
        self.data = np.zeros(capacity, dtype=object) # transition data aligned with the leaves
    def update(self, tree_idx, p):
        '''Set the priority of leaf tree_idx to p and propagate the change up to the root.
        '''
        change = p - self.tree[tree_idx]
        self.tree[tree_idx] = p
        while tree_idx != 0:
            tree_idx = (tree_idx - 1) // 2 # move to the parent node
            self.tree[tree_idx] += change
    def add(self, p, data):
        '''Add new data with priority p, overwriting the oldest entry when the tree is full.
        '''
        tree_idx = self.data_pointer + self.capacity - 1 # leaf index in the tree array
        self.data[self.data_pointer] = data
        self.update(tree_idx, p)
        self.data_pointer += 1
        if self.data_pointer >= self.capacity: # wrap around and overwrite the oldest data
            self.data_pointer = 0
        if self.n_entries < self.capacity:
            self.n_entries += 1
    def get_leaf(self, v):
        '''Find the leaf whose cumulative priority interval contains v; return its tree index, priority and data.
        '''
        parent_idx = 0
        while True:
            cl_idx = 2 * parent_idx + 1 # left child
            cr_idx = cl_idx + 1 # right child
            if cl_idx >= len(self.tree): # reached a leaf
                leaf_idx = parent_idx
                break
            else:
                if v <= self.tree[cl_idx]:
                    parent_idx = cl_idx
                else:
                    v -= self.tree[cl_idx]
                    parent_idx = cr_idx
        data_idx = leaf_idx - self.capacity + 1
        return leaf_idx, self.tree[leaf_idx], self.data[data_idx]
    def total(self):
        '''Return the total priority stored at the root.
        '''
        return self.tree[0]
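# Editor's illustrative sketch, not part of this commit: a tiny demo of how a SumTree
# yields samples proportional to priority. The function name, priorities and dummy
# transitions below are made up for the example and rely on the `random` import above.
def _sumtree_sampling_demo():
    tree = SumTree(capacity=4)
    for i, p in enumerate([1.0, 2.0, 3.0, 4.0]): # four dummy transitions with priorities 1..4
        tree.add(p, {"transition": i})
    counts = {i: 0 for i in range(4)}
    for _ in range(10000):
        v = random.uniform(0, tree.total()) # draw a value uniformly over the total priority mass
        _, _, data = tree.get_leaf(v) # descend to the leaf whose interval contains v
        counts[data["transition"]] += 1
    return counts # counts come out roughly proportional to 1 : 2 : 3 : 4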
class ReplayTree:
    '''Replay buffer for PER (Prioritized Experience Replay) DQN, backed by a SumTree.
    '''
    def __init__(self, capacity):
        self.capacity = capacity # capacity of the replay memory
        self.tree = SumTree(capacity)
        # hyperparameters for prioritized sampling and importance-sampling weights
        self.alpha = 0.6 # how strongly priorities shape the sampling distribution
        self.beta = 0.4 # initial importance-sampling correction, annealed towards 1
        self.beta_increment_per_sampling = 0.001
        self.epsilon = 0.01 # small constant so no transition ever has zero priority
        self.abs_err_upper = 1. # clip absolute TD errors to this value when updating priorities
    def __len__(self):
        '''Return the number of stored transitions.
        '''
        return self.tree.n_entries
    def push(self, error, sample):
        '''Push a sample into the buffer with a priority derived from its TD error.
        '''
        p = (np.abs(error) + self.epsilon) ** self.alpha
        self.tree.add(p, sample)
    def sample(self, batch_size):
        '''Sample a batch of transitions together with their tree indices and importance-sampling weights.
        The original code is from:
        https://github.com/rlcode/per/blob/master/prioritized_memory.py
        '''
        pri_segment = self.tree.total() / batch_size
        priorities = []
        batch = []
        idxs = []
        # anneal beta towards 1 so the importance-sampling correction becomes exact over time
        self.beta = np.min([1., self.beta + self.beta_increment_per_sampling])
        for i in range(batch_size):
            # draw one sample uniformly from each priority segment (stratified sampling)
            a = pri_segment * i
            b = pri_segment * (i + 1)
            s = random.uniform(a, b)
            idx, p, data = self.tree.get_leaf(s)
            priorities.append(p)
            batch.append(data)
            idxs.append(idx)
        sampling_probabilities = np.array(priorities) / self.tree.total()
        is_weights = np.power(self.tree.n_entries * sampling_probabilities, -self.beta)
        is_weights /= is_weights.max()
        return zip(*batch), idxs, is_weights
    def batch_update(self, tree_idx, abs_errors):
        '''Update the priorities of the sampled transitions from their new absolute TD errors.
        '''
        abs_errors += self.epsilon
        clipped_errors = np.minimum(abs_errors, self.abs_err_upper)
        ps = np.power(clipped_errors, self.alpha)
        for ti, p in zip(tree_idx, ps):
            self.tree.update(ti, p)
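# Editor's illustrative sketch, not part of this commit: one PER DQN update step built
# around ReplayTree. It assumes transitions were pushed as (state, action, reward,
# next_state, done) tuples and that `q_net`, `target_net` and `optimizer` are ordinary
# PyTorch modules/optimizer on CPU; none of these names come from this repository.
def _per_dqn_update_sketch(memory, q_net, target_net, optimizer, batch_size=32, gamma=0.99):
    import torch # local import keeps the sketch self-contained
    (states, actions, rewards, next_states, dones), idxs, is_weights = memory.sample(batch_size)
    states = torch.tensor(np.array(states), dtype=torch.float32)
    actions = torch.tensor(actions, dtype=torch.int64).unsqueeze(1)
    rewards = torch.tensor(rewards, dtype=torch.float32).unsqueeze(1)
    next_states = torch.tensor(np.array(next_states), dtype=torch.float32)
    dones = torch.tensor(dones, dtype=torch.float32).unsqueeze(1)
    is_weights = torch.tensor(is_weights, dtype=torch.float32).unsqueeze(1)
    q_values = q_net(states).gather(1, actions) # Q(s, a) for the actions actually taken
    with torch.no_grad():
        next_q = target_net(next_states).max(1, keepdim=True)[0]
        target = rewards + gamma * next_q * (1 - dones)
    # the importance-sampling weights correct the bias introduced by prioritized sampling
    loss = (is_weights * (q_values - target) ** 2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # keep stored priorities in sync with the new absolute TD errors
    abs_td_errors = (target - q_values).detach().abs().squeeze(1).numpy()
    memory.batch_update(idxs, abs_td_errors)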


@@ -5,7 +5,7 @@ Author: John
Email: johnjim0816@gmail.com
Date: 2021-03-12 16:02:24
LastEditor: John
LastEditTime: 2022-10-26 07:38:17
LastEditTime: 2022-11-14 10:27:43
Description:
Environment:
'''
@@ -179,6 +179,8 @@ def all_seed(env,seed = 1):
import torch
import numpy as np
import random
    if seed == 0: # seed 0 means "do not fix the random seed"
        return
# print(f"seed = {seed}")
env.seed(seed) # env config
np.random.seed(seed)