#!/usr/bin/env python
# coding=utf-8
'''
@Author: John
@Email: johnjim0816@gmail.com
@Date: 2020-06-10 15:27:16
@LastEditor: John
LastEditTime: 2022-08-28 23:44:06
@Discription:
@Environment: python 3.7.7
'''
import random
from collections import deque
from typing import Optional


class ReplayBuffer:
    ''' Fixed-capacity experience replay backed by a list used as a ring.

    Once `capacity` transitions are stored, each new push overwrites the
    oldest stored transition.
    '''

    def __init__(self, capacity):
        ''' Create an empty buffer.

        Args:
            capacity: maximum number of transitions kept. `push` indexes
                the list modulo this value, so it must be a positive int.
        '''
        self.capacity = capacity  # maximum number of stored transitions
        self.buffer = []  # storage; grows until it holds `capacity` items
        self.position = 0  # index of the slot the next push writes to

    def push(self, state, action, reward, next_state, done):
        ''' Store one transition, overwriting the oldest one when full.

        Args:
            state, action, reward, next_state, done: the five fields of a
                transition; they are stored together as one tuple.
        '''
        if len(self.buffer) < self.capacity:
            # Still filling up: grow the list so the slot below exists.
            self.buffer.append(None)
        self.buffer[self.position] = (state, action, reward, next_state, done)
        # Wrap around so the next write lands on the oldest entry.
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        ''' Draw `batch_size` distinct stored transitions uniformly at random.

        Args:
            batch_size: number of transitions to draw.

        Returns:
            Five tuples `(state, action, reward, next_state, done)`, each
            of length `batch_size`.

        Raises:
            ValueError: if `batch_size` exceeds the number of stored
                transitions (raised by `random.sample`), or if the batch
                is empty and so cannot be unpacked into five fields.
        '''
        batch = random.sample(self.buffer, batch_size)
        # Transpose a list of transitions into per-field tuples.
        state, action, reward, next_state, done = zip(*batch)
        return state, action, reward, next_state, done

    def __len__(self):
        ''' Return the number of transitions currently stored. '''
        return len(self.buffer)


class ReplayBufferQue:
    ''' Experience replay backed by a `collections.deque`.

    When `capacity` is an int, the deque drops its oldest entry once it is
    full; when `capacity` is None, the deque is unbounded.
    '''

    def __init__(self, capacity: Optional[int]) -> None:
        ''' Create an empty buffer.

        Args:
            capacity: maximum number of transitions kept, or None for no
                limit.
        '''
        self.capacity = capacity
        self.buffer = deque(maxlen=self.capacity)

    def push(self, transitions):
        ''' Append one transition to the buffer.

        Args:
            transitions (tuple): the fields of a single transition. All
                stored tuples should have the same layout, because
                `sample` transposes them with `zip`.
        '''
        self.buffer.append(transitions)

    def sample(self, batch_size: int, sequential: bool = False):
        ''' Draw a batch of stored transitions.

        Args:
            batch_size: number of transitions requested; silently clamped
                to the number currently stored.
            sequential: if True, return a contiguous run of transitions
                starting at a random offset; otherwise draw distinct
                transitions uniformly at random.

        Returns:
            A `zip` iterator that yields one tuple per transition field.
        '''
        if batch_size > len(self.buffer):
            batch_size = len(self.buffer)
        if sequential:
            # Pick a start offset that leaves room for a full window.
            rand = random.randint(0, len(self.buffer) - batch_size)
            batch = [self.buffer[i] for i in range(rand, rand + batch_size)]
            return zip(*batch)
        else:
            batch = random.sample(self.buffer, batch_size)
            return zip(*batch)

    def clear(self):
        ''' Remove every stored transition. '''
        self.buffer.clear()

    def __len__(self):
        ''' Return the number of transitions currently stored. '''
        return len(self.buffer)


class PGReplay(ReplayBufferQue):
    ''' Replay buffer for policy gradient based methods.

    The buffer is unbounded and `sample` returns every stored transition
    instead of a random subset.
    '''

    def __init__(self):
        ''' Create an empty, unbounded buffer. '''
        # Go through the parent initialiser so `capacity` is always set;
        # None gives deque(maxlen=None), i.e. no size limit.
        super().__init__(None)

    def sample(self):
        ''' Return all stored transitions in insertion order.

        Returns:
            A `zip` iterator that yields one tuple per transition field.
        '''
        batch = list(self.buffer)
        return zip(*batch)