#!/usr/bin/env python
# coding=utf-8
'''
@Author: John
@Email: johnjim0816@gmail.com
@Date: 2020-06-10 15:27:16
@LastEditor: John
@LastEditTime: 2022-08-28 23:44:06
@Description:
@Environment: python 3.7.7
'''
import random
import numpy as np
from collections import deque

class ReplayBuffer:
    def __init__(self, capacity):
        self.capacity = capacity # capacity of the replay buffer
        self.buffer = [] # list holding the stored transitions
        self.position = 0

    def push(self, state, action, reward, next_state, done):
        ''' The buffer works as a circular queue: once capacity is reached,
            the oldest transition is overwritten.
        '''
        if len(self.buffer) < self.capacity:
            self.buffer.append(None)
        self.buffer[self.position] = (state, action, reward, next_state, done)
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        batch = random.sample(self.buffer, batch_size) # randomly sample a mini-batch of transitions
        state, action, reward, next_state, done = zip(*batch) # unzip into states, actions, etc.
        return state, action, reward, next_state, done

    def __len__(self):
        ''' Return the current number of stored transitions.
        '''
        return len(self.buffer)
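
# A minimal usage sketch for ReplayBuffer (illustrative only; the transition values
# below are placeholders, not outputs of any particular environment).
def _demo_replay_buffer():
    buffer = ReplayBuffer(capacity=100)
    # push a few dummy transitions (state, action, reward, next_state, done)
    for i in range(5):
        buffer.push(state=[float(i)], action=0, reward=1.0, next_state=[float(i + 1)], done=False)
    states, actions, rewards, next_states, dones = buffer.sample(batch_size=3)
    print(len(buffer), states, actions)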

class ReplayBufferQue:
    '''Replay buffer backed by a deque; when full, the oldest transitions are dropped automatically.
    '''
    def __init__(self, capacity: int) -> None:
        self.capacity = capacity
        self.buffer = deque(maxlen=self.capacity)

    def push(self, transitions):
        '''Append one transition to the buffer.

        Args:
            transitions (tuple): a transition (state, action, reward, next_state, done)
        '''
        self.buffer.append(transitions)

    def sample(self, batch_size: int, sequential: bool = False):
        if batch_size > len(self.buffer):
            batch_size = len(self.buffer)
        if sequential: # sequential sampling: take a random contiguous slice
            rand = random.randint(0, len(self.buffer) - batch_size)
            batch = [self.buffer[i] for i in range(rand, rand + batch_size)]
            return zip(*batch)
        else:
            batch = random.sample(self.buffer, batch_size)
            return zip(*batch)

    def clear(self):
        self.buffer.clear()

    def __len__(self):
        return len(self.buffer)
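
# A minimal usage sketch for ReplayBufferQue (illustrative only; the transition values
# below are placeholders, not outputs of any particular environment).
def _demo_replay_buffer_que():
    buffer = ReplayBufferQue(capacity=50)
    for i in range(5):
        buffer.push((i, 0, 1.0, i + 1, False)) # (state, action, reward, next_state, done)
    states, actions, rewards, next_states, dones = buffer.sample(batch_size=3)
    print(states, actions, dones)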

class PGReplay(ReplayBufferQue):
    '''Replay buffer for policy-gradient based methods; these methods sample all stored transitions at once.
    '''
    def __init__(self):
        self.buffer = deque()

    def sample(self):
        ''' Sample all of the stored transitions.
        '''
        batch = list(self.buffer)
        return zip(*batch)
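
# A minimal usage sketch for PGReplay (illustrative only). Policy-gradient methods typically
# consume every stored transition once per update and then clear the buffer.
def _demo_pg_replay():
    buffer = PGReplay()
    for i in range(4):
        buffer.push((i, 0, 1.0, i + 1, i == 3)) # (state, action, reward, next_state, done)
    states, actions, rewards, next_states, dones = buffer.sample() # all transitions at once
    buffer.clear() # start collecting a fresh on-policy batch
    print(states, rewards, dones)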

class SumTree:
    '''SumTree for PER (Prioritized Experience Replay) DQN.

    This SumTree code is a modified version and the original code is from:
    https://github.com/MorvanZhou/Reinforcement-learning-with-tensorflow/blob/master/contents/5.2_Prioritized_Replay_DQN/RL_brain.py
    '''
    def __init__(self, capacity: int):
        self.capacity = capacity # number of leaves, i.e. the maximum number of stored transitions
        self.data_pointer = 0 # index of the next leaf to write to
        self.n_entries = 0 # number of transitions currently stored
        self.tree = np.zeros(2 * capacity - 1) # binary tree: internal nodes hold sums, leaves hold priorities
        self.data = np.zeros(capacity, dtype=object) # the stored transitions

    def update(self, tree_idx, p):
        '''Update the priority of a leaf and propagate the change up to the root.
        '''
        change = p - self.tree[tree_idx]
        self.tree[tree_idx] = p

        while tree_idx != 0:
            tree_idx = (tree_idx - 1) // 2 # move to the parent node
            self.tree[tree_idx] += change

    def add(self, p, data):
        '''Add new data to the SumTree with priority p.
        '''
        tree_idx = self.data_pointer + self.capacity - 1 # leaf index in the tree array
        self.data[self.data_pointer] = data
        self.update(tree_idx, p)

        self.data_pointer += 1
        if self.data_pointer >= self.capacity: # wrap around and overwrite the oldest entry
            self.data_pointer = 0

        if self.n_entries < self.capacity:
            self.n_entries += 1

    def get_leaf(self, v):
        '''Sample a leaf: descend the tree, choosing the child whose cumulative priority contains v.
        '''
        parent_idx = 0
        while True:
            cl_idx = 2 * parent_idx + 1 # left child
            cr_idx = cl_idx + 1 # right child
            if cl_idx >= len(self.tree): # reached a leaf
                leaf_idx = parent_idx
                break
            else:
                if v <= self.tree[cl_idx]:
                    parent_idx = cl_idx
                else:
                    v -= self.tree[cl_idx]
                    parent_idx = cr_idx

        data_idx = leaf_idx - self.capacity + 1
        return leaf_idx, self.tree[leaf_idx], self.data[data_idx]

    def total(self):
        '''Total priority, i.e. the value stored at the root node.
        '''
        return self.tree[0]
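
# A minimal usage sketch for SumTree (illustrative only; the priorities and payloads are made up).
# Leaves hold priorities; get_leaf() walks down the tree so that items with higher priority
# are returned proportionally more often.
def _demo_sum_tree():
    tree = SumTree(capacity=4)
    for i, priority in enumerate([1.0, 2.0, 3.0, 4.0]):
        tree.add(priority, {"id": i}) # any Python object can be stored alongside its priority
    leaf_idx, priority, data = tree.get_leaf(random.uniform(0, tree.total()))
    print(leaf_idx, priority, data)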

class ReplayTree:
    '''ReplayTree for PER (Prioritized Experience Replay) DQN.
    '''
    def __init__(self, capacity):
        self.capacity = capacity # capacity of the replay memory
        self.tree = SumTree(capacity)

        ## hyperparameters for computing the importance-sampling weights
        self.beta_increment_per_sampling = 0.001
        self.alpha = 0.6
        self.beta = 0.4
        self.epsilon = 0.01
        self.abs_err_upper = 1. # upper bound used to clip the absolute TD errors

    def __len__(self):
        ''' Return the number of stored transitions.
        '''
        return self.tree.n_entries

    def push(self, error, sample):
        '''Push a sample into the replay buffer with a priority derived from its absolute TD error.
        '''
        p = (np.abs(error) + self.epsilon) ** self.alpha
        self.tree.add(p, sample)

    def sample(self, batch_size):
        '''Sample a batch of data; the original code is from:
        https://github.com/rlcode/per/blob/master/prioritized_memory.py
        '''
        pri_segment = self.tree.total() / batch_size # split the total priority into equal segments

        priorities = []
        batch = []
        idxs = []

        self.beta = np.min([1., self.beta + self.beta_increment_per_sampling]) # anneal beta towards 1

        for i in range(batch_size):
            a = pri_segment * i
            b = pri_segment * (i + 1)
            s = random.uniform(a, b) # draw one value from each segment
            idx, p, data = self.tree.get_leaf(s)

            priorities.append(p)
            batch.append(data)
            idxs.append(idx)

        # importance-sampling weights, normalized by the largest weight in the batch
        sampling_probabilities = np.array(priorities) / self.tree.total()
        is_weights = np.power(self.tree.n_entries * sampling_probabilities, -self.beta)
        is_weights /= is_weights.max()

        return zip(*batch), idxs, is_weights

    def batch_update(self, tree_idx, abs_errors):
        '''Update the priorities of the sampled transitions from their new absolute TD errors.
        '''
        abs_errors += self.epsilon

        clipped_errors = np.minimum(abs_errors, self.abs_err_upper)
        ps = np.power(clipped_errors, self.alpha)

        for ti, p in zip(tree_idx, ps):
            self.tree.update(ti, p)
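
# A minimal end-to-end sketch (illustrative only): fill a prioritized replay buffer with
# placeholder transitions, sample a batch, and refresh priorities as a training loop would.
if __name__ == "__main__":
    per_buffer = ReplayTree(capacity=8)
    for i in range(8):
        td_error = random.random() # placeholder for a real TD error magnitude
        per_buffer.push(td_error, (i, 0, 1.0, i + 1, False)) # (state, action, reward, next_state, done)
    (states, actions, rewards, next_states, dones), idxs, is_weights = per_buffer.sample(batch_size=4)
    per_buffer.batch_update(idxs, np.abs(np.random.randn(4))) # new |TD errors| after a learning step
    print(states, idxs, is_weights)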