Files
easy-rl/codes/TD3/memory.py
2022-05-31 01:20:58 +08:00

44 lines
1.4 KiB
Python

#!/usr/bin/env python
# coding=utf-8
'''
Author: John
Email: johnjim0816@gmail.com
Date: 2021-04-13 11:00:13
LastEditor: John
LastEditTime: 2021-04-15 01:25:14
Description:
Environment:
'''
import numpy as np
import torch
class ReplayBuffer(object):
    """Fixed-capacity ring buffer of transitions.

    Transitions are stored column-wise in pre-allocated NumPy arrays
    (``state``, ``action``, ``next_state``, ``reward``, ``not_done``), each
    with ``max_size`` rows. Once the buffer is full, new transitions
    overwrite the oldest ones.

    Attributes:
        max_size: Capacity of the buffer (number of transitions).
        ptr: Row index where the next transition will be written.
        size: Number of valid transitions currently stored.
        device: ``torch.device`` that sampled tensors are placed on.
    """

    def __init__(self, n_states, n_actions, max_size=int(1e6), device=None):
        """Allocate storage for ``max_size`` transitions.

        Args:
            n_states: Number of columns of a state row.
            n_actions: Number of columns of an action row.
            max_size: Buffer capacity; coerced to ``int`` so that values
                such as ``1e6`` are accepted.
            device: Device for sampled tensors (``str`` or
                ``torch.device``). Defaults to CUDA when available,
                otherwise CPU.

        Raises:
            ValueError: If ``max_size`` is not positive.
        """
        # A float capacity would break both np.zeros and the ring-pointer
        # arithmetic (float indices), so normalise it once up front.
        max_size = int(max_size)
        if max_size <= 0:
            raise ValueError("max_size must be a positive integer")
        self.max_size = max_size
        self.ptr = 0
        self.size = 0
        self.state = np.zeros((max_size, n_states))
        self.action = np.zeros((max_size, n_actions))
        self.next_state = np.zeros((max_size, n_states))
        self.reward = np.zeros((max_size, 1))
        self.not_done = np.zeros((max_size, 1))
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)

    def __len__(self):
        """Return the number of transitions currently stored."""
        return self.size

    def push(self, state, action, next_state, reward, done):
        """Store one transition, overwriting the oldest when full.

        Args:
            state: Values written to the current row of ``self.state``.
            action: Values written to the current row of ``self.action``.
            next_state: Values written to the current row of
                ``self.next_state``.
            reward: Value written to the current row of ``self.reward``.
            done: Termination flag; stored inverted as ``1 - done``.
        """
        row = self.ptr
        self.state[row] = state
        self.action[row] = action
        self.next_state[row] = next_state
        self.reward[row] = reward
        # Stored as a continuation mask so it can multiply a bootstrap term.
        self.not_done[row] = 1. - done
        # Advance the write pointer, wrapping around at capacity.
        self.ptr = (row + 1) % self.max_size
        self.size = min(self.size + 1, self.max_size)

    def sample(self, batch_size):
        """Draw ``batch_size`` stored transitions uniformly, with replacement.

        Args:
            batch_size: Number of transitions to draw.

        Returns:
            Tuple ``(state, action, next_state, reward, not_done)`` of
            float32 tensors on ``self.device``, each with ``batch_size`` rows.

        Raises:
            ValueError: If the buffer is empty.
        """
        if self.size == 0:
            raise ValueError("cannot sample from an empty ReplayBuffer")
        # Only the first `size` rows hold valid data.
        ind = np.random.randint(0, self.size, size=batch_size)
        columns = (
            self.state,
            self.action,
            self.next_state,
            self.reward,
            self.not_done,
        )
        return tuple(
            torch.as_tensor(col[ind], dtype=torch.float32, device=self.device)
            for col in columns
        )