Files
easy-rl/codes/DoubleDQN/memory.py
JohnJim0816 bf0f2990cf update
2021-03-23 16:10:11 +08:00

41 lines
1.1 KiB
Python

#!/usr/bin/env python
# coding=utf-8
'''
@Author: John
@Email: johnjim0816@gmail.com
@Date: 2020-06-10 15:27:16
@LastEditor: John
LastEditTime: 2021-01-20 18:58:37
@Discription:
@Environment: python 3.7.7
'''
import random
class ReplayBuffer:
def __init__(self, capacity):
self.capacity = capacity # buffer的最大容量
self.buffer = []
self.position = 0
def push(self, state, action, reward, next_state, done):
'''以队列的方式将样本填入buffer中
'''
if len(self.buffer) < self.capacity:
self.buffer.append(None)
self.buffer[self.position] = (state, action, reward, next_state, done)
self.position = (self.position + 1) % self.capacity
def sample(self, batch_size):
'''随机采样batch_size个样本
'''
batch = random.sample(self.buffer, batch_size)
state, action, reward, next_state, done = zip(*batch)
return state, action, reward, next_state, done
def __len__(self):
'''返回buffer的长度
'''
return len(self.buffer)