#!/usr/bin/env python
# coding=utf-8
'''
@Author: John
@Email: johnjim0816@gmail.com
@Date: 2020-06-10 15:27:16
@LastEditor: John
@LastEditTime: 2022-08-28 23:44:06
@Description:
@Environment: python 3.7.7
'''
import random
import numpy as np
from collections import deque

class ReplayBuffer:
    def __init__(self, capacity):
        self.capacity = capacity  # capacity of the replay buffer
        self.buffer = []  # storage for transitions
        self.position = 0

    def push(self, state, action, reward, next_state, done):
        ''' The buffer works as a circular queue: once the capacity is exceeded,
        the oldest transition is overwritten.
        '''
        if len(self.buffer) < self.capacity:
            self.buffer.append(None)
        self.buffer[self.position] = (state, action, reward, next_state, done)
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        batch = random.sample(self.buffer, batch_size)  # randomly sample a mini-batch of transitions
        state, action, reward, next_state, done = zip(*batch)  # unzip into states, actions, etc.
        return state, action, reward, next_state, done

    def __len__(self):
        ''' Return the number of stored transitions.
        '''
        return len(self.buffer)

class ReplayBufferQue:
    '''Replay buffer backed by a deque; the oldest transitions are dropped automatically.
    '''
    def __init__(self, capacity: int) -> None:
        self.capacity = capacity
        self.buffer = deque(maxlen=self.capacity)

    def push(self, transitions):
        '''Append one transition to the buffer.

        Args:
            transitions (tuple): e.g. (state, action, reward, next_state, done)
        '''
        self.buffer.append(transitions)

    def sample(self, batch_size: int, sequential: bool = False):
        if batch_size > len(self.buffer):
            batch_size = len(self.buffer)
        if sequential:  # sample a contiguous slice starting from a random position
            rand = random.randint(0, len(self.buffer) - batch_size)
            batch = [self.buffer[i] for i in range(rand, rand + batch_size)]
            return zip(*batch)
        else:
            batch = random.sample(self.buffer, batch_size)
            return zip(*batch)

    def clear(self):
        self.buffer.clear()

    def __len__(self):
        return len(self.buffer)

class PGReplay(ReplayBufferQue):
    '''Replay buffer for policy-gradient based methods, which sample all stored transitions at once.
    '''
    def __init__(self):
        self.buffer = deque()

    def sample(self):
        ''' Sample all the transitions.
        '''
        batch = list(self.buffer)
        return zip(*batch)

class SumTree:
    '''SumTree for PER (Prioritized Experience Replay) DQN.
    This SumTree code is a modified version and the original code is from:
    https://github.com/MorvanZhou/Reinforcement-learning-with-tensorflow/blob/master/contents/5.2_Prioritized_Replay_DQN/RL_brain.py
    '''
    def __init__(self, capacity: int):
        self.capacity = capacity
        self.data_pointer = 0
        self.n_entries = 0
        self.tree = np.zeros(2 * capacity - 1)  # internal nodes + leaves; leaves hold the priorities
        self.data = np.zeros(capacity, dtype=object)

    def update(self, tree_idx, p):
        '''Update the priority of a leaf and propagate the change up to the root.
        '''
        change = p - self.tree[tree_idx]
        self.tree[tree_idx] = p
        while tree_idx != 0:
            tree_idx = (tree_idx - 1) // 2
            self.tree[tree_idx] += change

    def add(self, p, data):
        '''Add new data to the SumTree with priority p.
        '''
        tree_idx = self.data_pointer + self.capacity - 1
        self.data[self.data_pointer] = data
        self.update(tree_idx, p)
        self.data_pointer += 1
        if self.data_pointer >= self.capacity:
            self.data_pointer = 0
        if self.n_entries < self.capacity:
            self.n_entries += 1

    def get_leaf(self, v):
        '''Retrieve the leaf whose cumulative priority interval contains v.
        '''
        parent_idx = 0
        while True:
            cl_idx = 2 * parent_idx + 1
            cr_idx = cl_idx + 1
            if cl_idx >= len(self.tree):
                leaf_idx = parent_idx
                break
            else:
                if v <= self.tree[cl_idx]:
                    parent_idx = cl_idx
                else:
                    v -= self.tree[cl_idx]
                    parent_idx = cr_idx
        data_idx = leaf_idx - self.capacity + 1
        return leaf_idx, self.tree[leaf_idx], self.data[data_idx]

    def total(self):
        '''Total priority stored at the root.
        '''
        return self.tree[0]
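# Example (a sketch, not part of the original module): how the SumTree supports
# proportional sampling. Leaves store priorities and tree[0] holds their sum, so
# drawing v ~ Uniform(0, total) and descending with get_leaf(v) selects leaf i
# with probability p_i / total. The values below are made up for illustration:
#
#   tree = SumTree(capacity=4)
#   for prio, item in [(1.0, "a"), (2.0, "b"), (3.0, "c"), (4.0, "d")]:
#       tree.add(prio, item)
#   v = random.uniform(0, tree.total())          # tree.total() == 10.0
#   leaf_idx, priority, data = tree.get_leaf(v)  # "d" is ~4x as likely as "a"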
class ReplayTree:
    '''ReplayTree for PER (Prioritized Experience Replay) DQN.
    '''
    def __init__(self, capacity):
        self.capacity = capacity  # capacity of the replay memory
        self.tree = SumTree(capacity)
        ## hyperparameters for computing the importance sampling weights
        self.beta_increment_per_sampling = 0.001
        self.alpha = 0.6
        self.beta = 0.4
        self.epsilon = 0.01
        self.abs_err_upper = 1.

    def __len__(self):
        ''' Return the number of stored transitions.
        '''
        return self.tree.n_entries

    def push(self, error, sample):
        '''Push a sample into the tree with a priority derived from its TD error.
        '''
        p = (np.abs(error) + self.epsilon) ** self.alpha
        self.tree.add(p, sample)

    def sample(self, batch_size):
        '''Sample a batch of data; the original code is from:
        https://github.com/rlcode/per/blob/master/prioritized_memory.py
        '''
        pri_segment = self.tree.total() / batch_size

        priorities = []
        batch = []
        idxs = []

        self.beta = np.min([1., self.beta + self.beta_increment_per_sampling])

        for i in range(batch_size):
            a = pri_segment * i
            b = pri_segment * (i + 1)
            s = random.uniform(a, b)

            idx, p, data = self.tree.get_leaf(s)

            priorities.append(p)
            batch.append(data)
            idxs.append(idx)

        sampling_probabilities = np.array(priorities) / self.tree.total()
        is_weights = np.power(self.tree.n_entries * sampling_probabilities, -self.beta)
        is_weights /= is_weights.max()

        return zip(*batch), idxs, is_weights

    def batch_update(self, tree_idx, abs_errors):
        '''Update the priorities of the sampled transitions from their new TD errors.
        '''
        abs_errors += self.epsilon
        clipped_errors = np.minimum(abs_errors, self.abs_err_upper)
        ps = np.power(clipped_errors, self.alpha)
        for ti, p in zip(tree_idx, ps):
            self.tree.update(ti, p)
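if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the original training code): the
    # transitions and TD errors below are made-up values that only illustrate the
    # push -> sample -> batch_update cycle of the prioritized buffer.
    buffer = ReplayTree(capacity=8)
    for i in range(8):
        transition = (i, 0, 1.0, i + 1, False)  # (state, action, reward, next_state, done)
        buffer.push(error=1.0, sample=transition)
    (states, actions, rewards, next_states, dones), idxs, is_weights = buffer.sample(batch_size=4)
    new_errors = np.abs(np.random.randn(4))  # stand-in for TD errors from a learner
    buffer.batch_update(idxs, new_errors)
    print(len(buffer), idxs, is_weights)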