44 lines
1.4 KiB
Python
44 lines
1.4 KiB
Python
#!/usr/bin/env python
|
|
# coding=utf-8
|
|
'''
|
|
Author: John
|
|
Email: johnjim0816@gmail.com
|
|
Date: 2021-04-13 11:00:13
|
|
LastEditor: John
|
|
LastEditTime: 2021-04-15 01:25:14
|
|
Description:
|
|
Environment:
|
|
'''
|
|
import numpy as np
|
|
import torch
|
|
|
|
|
|
class ReplayBuffer(object):
    """Fixed-size circular experience replay buffer for off-policy RL.

    Transitions (state, action, next_state, reward, not_done) are stored in
    pre-allocated NumPy arrays. Once the buffer is full, new transitions
    overwrite the oldest ones. ``sample`` draws a uniform random minibatch
    (with replacement) and returns float32 torch tensors on ``self.device``.
    """

    def __init__(self, n_states, n_actions, max_size=int(1e6)):
        """Pre-allocate storage for up to ``max_size`` transitions.

        Args:
            n_states: dimensionality of the (flat) state vector.
            n_actions: dimensionality of the (flat) action vector.
            max_size: buffer capacity; oldest entries are overwritten
                once this many transitions have been pushed.
        """
        self.max_size = max_size
        self.ptr = 0   # index of the next slot to write (circular)
        self.size = 0  # number of valid transitions currently stored
        self.state = np.zeros((max_size, n_states))
        self.action = np.zeros((max_size, n_actions))
        self.next_state = np.zeros((max_size, n_states))
        self.reward = np.zeros((max_size, 1))
        # Stored as (1 - done) so it can directly multiply bootstrap
        # targets: target = r + gamma * not_done * Q(s', a').
        self.not_done = np.zeros((max_size, 1))
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def push(self, state, action, next_state, reward, done):
        """Insert one transition, overwriting the oldest entry when full.

        Args:
            state: array-like of shape (n_states,).
            action: array-like of shape (n_actions,).
            next_state: array-like of shape (n_states,).
            reward: scalar reward.
            done: truthy terminal flag; stored as ``1 - done``.
        """
        self.state[self.ptr] = state
        self.action[self.ptr] = action
        self.next_state[self.ptr] = next_state
        self.reward[self.ptr] = reward
        self.not_done[self.ptr] = 1. - done
        self.ptr = (self.ptr + 1) % self.max_size
        self.size = min(self.size + 1, self.max_size)

    def sample(self, batch_size):
        """Sample ``batch_size`` transitions uniformly at random.

        Sampling is with replacement over the ``self.size`` valid entries.
        Raises ``ValueError`` (from numpy) if the buffer is empty.

        Returns:
            Tuple of float32 tensors on ``self.device``:
            (state, action, next_state, reward, not_done) with leading
            dimension ``batch_size``.
        """
        ind = np.random.randint(0, self.size, size=batch_size)
        # torch.as_tensor builds the tensor directly on the target device
        # (single conversion) instead of the legacy
        # torch.FloatTensor(...).to(device) CPU-then-copy pattern.
        return (
            torch.as_tensor(self.state[ind], dtype=torch.float32, device=self.device),
            torch.as_tensor(self.action[ind], dtype=torch.float32, device=self.device),
            torch.as_tensor(self.next_state[ind], dtype=torch.float32, device=self.device),
            torch.as_tensor(self.reward[ind], dtype=torch.float32, device=self.device),
            torch.as_tensor(self.not_done[ind], dtype=torch.float32, device=self.device),
        )

    def __len__(self):
        """Number of valid transitions currently stored."""
        return self.size