34 lines
1020 B
Python
34 lines
1020 B
Python
#!/usr/bin/env python
|
|
# coding=utf-8
|
|
'''
|
|
@Author: John
|
|
@Email: johnjim0816@gmail.com
|
|
@Date: 2020-06-10 15:27:16
|
|
@LastEditor: John
|
|
@LastEditTime: 2020-06-13 00:29:45
|
|
@Description: Fixed-capacity experience replay buffer with uniform random sampling.
|
|
@Environment: python 3.7.7
|
|
'''
|
|
import random
|
|
import numpy as np
|
|
|
|
class ReplayBuffer:
    """Fixed-capacity store of transitions with uniform random sampling.

    Each transition is kept as a ``(state, action, reward, next_state, done)``
    tuple. Once ``capacity`` transitions are stored, every further push
    overwrites the oldest one (ring-buffer behaviour).
    """

    def __init__(self, capacity):
        """Create an empty buffer.

        Args:
            capacity: Maximum number of transitions kept; must be positive.

        Raises:
            ValueError: If ``capacity`` is zero or negative. Such a buffer
                could never accept a push, so it is rejected up front.
        """
        if capacity <= 0:
            raise ValueError(
                "capacity must be positive, got {!r}".format(capacity)
            )
        self.capacity = capacity
        self.buffer = []
        # Index of the slot the next push writes to.
        self.position = 0

    def push(self, state, action, reward, next_state, done):
        """Store one transition, overwriting the oldest when the buffer is full."""
        transition = (state, action, reward, next_state, done)
        if len(self.buffer) < self.capacity:
            # Still filling: the write index equals the current length,
            # so appending puts the transition in the right slot.
            self.buffer.append(transition)
        else:
            self.buffer[self.position] = transition
        # Wrap around so the next write lands on the oldest entry.
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        """Draw ``batch_size`` distinct transitions uniformly at random.

        Args:
            batch_size: Number of transitions to draw; must be between 1 and
                ``len(self)`` inclusive.

        Returns:
            A 5-tuple ``(states, actions, rewards, next_states, dones)``,
            each produced by ``np.stack`` over the sampled transitions, so
            the first axis of every array has length ``batch_size``.

        Raises:
            ValueError: If ``batch_size`` is not in ``1..len(self)``.
        """
        stored = len(self.buffer)
        if not 0 < batch_size <= stored:
            raise ValueError(
                "batch_size must be between 1 and {} (stored transitions), "
                "got {!r}".format(stored, batch_size)
            )
        batch = random.sample(self.buffer, batch_size)
        # zip(*batch) regroups the tuples field by field; each field is then
        # stacked into one array.
        state_batch, action_batch, reward_batch, next_state_batch, done_batch = map(
            np.stack, zip(*batch)
        )
        return state_batch, action_batch, reward_batch, next_state_batch, done_batch

    def __len__(self):
        """Return the number of transitions currently stored."""
        return len(self.buffer)