74 lines
2.4 KiB
Python
74 lines
2.4 KiB
Python
#!/usr/bin/env python
|
|
# coding=utf-8
|
|
'''
|
|
@Author: John
|
|
@Email: johnjim0816@gmail.com
|
|
@Date: 2020-06-10 15:27:16
|
|
@LastEditor: John
|
|
@LastEditTime: 2022-08-28 23:44:06
|
|
@Description:
|
|
@Environment: python 3.7.7
|
|
'''
|
|
import random
|
|
from collections import deque
|
|
class ReplayBuffer:
    """Fixed-size experience replay stored in a plain list used as a ring.

    Each stored item is a ``(state, action, reward, next_state, done)``
    tuple. Until ``capacity`` entries exist the list grows by one slot per
    push; after that every push overwrites the slot at ``position``, which
    under normal use holds the oldest transition.
    """

    def __init__(self, capacity):
        self.capacity = capacity  # maximum number of transitions kept
        self.buffer = []          # transition storage
        self.position = 0         # slot the next push writes to

    def push(self, state, action, reward, next_state, done):
        """Store one transition, overwriting the oldest once the buffer is full."""
        transition = (state, action, reward, next_state, done)
        still_growing = len(self.buffer) < self.capacity
        if still_growing:
            # Reserve a new slot; the assignment below fills it.
            self.buffer.append(None)
        self.buffer[self.position] = transition
        next_slot = self.position + 1
        self.position = next_slot % self.capacity

    def sample(self, batch_size):
        """Draw ``batch_size`` distinct stored transitions uniformly at random.

        Returns:
            A 5-tuple ``(states, actions, rewards, next_states, dones)``;
            each element is a tuple of length ``batch_size``.

        Raises:
            ValueError: if ``batch_size`` is negative or larger than the
                number of stored transitions (raised by ``random.sample``),
                or if it is 0 (an empty batch cannot be unpacked into five
                fields).
        """
        columns = tuple(zip(*random.sample(self.buffer, batch_size)))
        states, actions, rewards, next_states, dones = columns
        return states, actions, rewards, next_states, dones

    def __len__(self):
        """Return how many transitions are currently stored."""
        return len(self.buffer)
|
|
|
|
class ReplayBufferQue:
    """Experience replay backed by a bounded ``collections.deque``.

    When ``capacity`` transitions are already held, appending another one
    discards the oldest (``deque(maxlen=...)`` semantics).
    """

    def __init__(self, capacity: int) -> None:
        self.capacity = capacity
        self.buffer = deque(maxlen=capacity)

    def push(self, transitions):
        """Append one transition to the buffer.

        Args:
            transitions (tuple): a single transition record. ``sample``
                transposes records with ``zip``, so every record pushed
                should have the same number of fields.
        """
        self.buffer.append(transitions)

    def sample(self, batch_size: int, sequential: bool = False):
        """Draw a mini-batch and transpose it field-wise.

        Args:
            batch_size: requested number of transitions; silently clamped
                to the number currently stored.
            sequential: if true, return a contiguous window of transitions
                starting at a random offset; otherwise draw distinct
                transitions uniformly at random.

        Returns:
            A ``zip`` iterator yielding one tuple per transition field. It
            is empty when the buffer is empty.
        """
        size = min(batch_size, len(self.buffer))
        if not sequential:
            return zip(*random.sample(self.buffer, size))
        # Random start chosen so the whole window fits inside the buffer.
        start = random.randint(0, len(self.buffer) - size)
        window = [self.buffer[idx] for idx in range(start, start + size)]
        return zip(*window)

    def clear(self):
        """Remove every stored transition."""
        self.buffer.clear()

    def __len__(self):
        """Return how many transitions are currently stored."""
        return len(self.buffer)
|
|
|
|
class PGReplay(ReplayBufferQue):
    """Replay buffer for policy-gradient methods.

    These methods consume every stored transition at once, so ``sample``
    takes no batch size and returns the whole buffer. The underlying deque
    is unbounded (``capacity`` is ``None``).
    """

    def __init__(self):
        # Delegate to the parent initializer so the inherited ``capacity``
        # attribute exists like on every other ReplayBufferQue.
        # ``deque(maxlen=None)`` is unbounded, i.e. the same as ``deque()``.
        super().__init__(None)

    def sample(self):
        """Return all stored transitions, transposed field-wise.

        Returns:
            A ``zip`` iterator yielding one tuple per transition field, each
            collecting that field across all stored transitions in
            insertion order. It is empty when the buffer is empty.
        """
        # Snapshot the deque so the result does not depend on later pushes.
        batch = list(self.buffer)
        return zip(*batch)