#!/usr/bin/env python
# coding=utf-8
'''
Author: John
Email: johnjim0816@gmail.com
Date: 2021-04-13 11:00:13
LastEditor: John
LastEditTime: 2021-04-15 01:25:14
Description: Fixed-capacity ring buffer of transitions, sampled as torch tensors.
Environment:
'''
import numpy as np
import torch


class ReplayBuffer(object):
    '''Fixed-capacity ring buffer of (state, action, next_state, reward, done).

    Transitions are written into preallocated numpy arrays. Once ``max_size``
    transitions have been stored, each new one overwrites the oldest slot.
    '''

    def __init__(self, n_states, n_actions, max_size=int(1e6), device=None):
        '''Allocate storage for up to ``max_size`` transitions.

        Args:
            n_states: Length of one state vector (width of the state arrays).
            n_actions: Length of one action vector (width of the action array).
            max_size: Maximum number of transitions kept at once.
            device: Device that sampled tensors are moved to. ``None`` (the
                default) selects CUDA when available, otherwise the CPU.
        '''
        self.max_size = max_size
        self.ptr = 0    # index of the slot the next push() writes to
        self.size = 0   # number of valid transitions currently stored
        self.state = np.zeros((max_size, n_states))
        self.action = np.zeros((max_size, n_actions))
        self.next_state = np.zeros((max_size, n_states))
        self.reward = np.zeros((max_size, 1))
        self.not_done = np.zeros((max_size, 1))
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)

    def push(self, state, action, next_state, reward, done):
        '''Store one transition, overwriting the oldest one when full.

        ``done`` is stored inverted (``1 - done``) in ``self.not_done``.
        '''
        slot = self.ptr
        self.state[slot] = state
        self.action[slot] = action
        self.next_state[slot] = next_state
        self.reward[slot] = reward
        self.not_done[slot] = 1. - done
        # Advance the write cursor, wrapping to 0 at the end of the arrays.
        self.ptr = (slot + 1) % self.max_size
        self.size = min(self.size + 1, self.max_size)

    def sample(self, batch_size):
        '''Draw ``batch_size`` stored transitions uniformly, with replacement.

        Returns:
            Tuple ``(state, action, next_state, reward, not_done)`` of float32
            tensors on ``self.device``, each with ``batch_size`` rows.

        Raises:
            ValueError: If nothing has been pushed yet.
        '''
        if self.size == 0:
            # np.random.randint(0, 0) would raise an opaque "low >= high".
            raise ValueError("cannot sample from an empty ReplayBuffer")
        ind = np.random.randint(0, self.size, size=batch_size)
        return (
            self._to_tensor(self.state[ind]),
            self._to_tensor(self.action[ind]),
            self._to_tensor(self.next_state[ind]),
            self._to_tensor(self.reward[ind]),
            self._to_tensor(self.not_done[ind]),
        )

    def _to_tensor(self, array):
        '''Convert a numpy batch to a float32 tensor on ``self.device``.'''
        return torch.as_tensor(array, dtype=torch.float32, device=self.device)