Files
easy-rl/codes/PPO/memory.py
johnjim0816 fb2affb69e update
2021-09-27 03:44:29 +08:00

44 lines
1.3 KiB
Python

#!/usr/bin/env python
# coding=utf-8
'''
Author: John
Email: johnjim0816@gmail.com
Date: 2021-03-23 15:30:46
LastEditor: John
LastEditTime: 2021-09-26 22:00:07
Description:
Environment:
'''
import numpy as np
class PPOMemory:
    """Buffer of transitions kept as six parallel lists.

    Each call to :meth:`push` appends one entry to every list, so index
    ``i`` in ``states``, ``actions``, ``probs``, ``vals``, ``rewards`` and
    ``dones`` refers to the same stored step.  :meth:`sample` returns the
    whole buffer as arrays together with shuffled mini-batch index groups,
    and :meth:`clear` empties the buffer.

    NOTE(review): ``probs`` and ``vals`` are stored verbatim; presumably
    they are the policy's action (log-)probability and the critic's value
    estimate -- confirm against the caller.
    """

    def __init__(self, batch_size):
        """Create an empty buffer.

        Args:
            batch_size: Maximum number of indices in each mini-batch
                produced by :meth:`sample`.  Validated when sampling.
        """
        self.batch_size = batch_size
        # Reuse clear() so construction and reset can never drift apart.
        self.clear()

    def sample(self):
        """Return all stored data plus shuffled mini-batch indices.

        Returns:
            A 7-tuple ``(states, actions, probs, vals, rewards, dones,
            batches)``.  The first six items are ``np.array`` conversions
            of the stored lists, in insertion order (they are NOT
            shuffled).  ``batches`` is a list of int64 index arrays that
            together form a random permutation of ``range(len(states))``;
            every array holds ``batch_size`` indices except possibly the
            last, which holds the remainder.  An empty buffer yields empty
            arrays and an empty ``batches`` list.

        Raises:
            ValueError: If ``batch_size`` is not positive.  Without this
                check a zero step fails inside numpy with an unrelated
                message and a negative step silently produces no batches.
        """
        if self.batch_size <= 0:
            raise ValueError(
                "batch_size must be a positive integer, got {!r}".format(
                    self.batch_size
                )
            )

        n_stored = len(self.states)
        # Shuffle the indices, not the data: callers index the returned
        # arrays with each batch, so the arrays must stay aligned.
        indices = np.arange(n_stored, dtype=np.int64)
        np.random.shuffle(indices)
        batches = [
            indices[start:start + self.batch_size]
            for start in range(0, n_stored, self.batch_size)
        ]
        return (
            np.array(self.states),
            np.array(self.actions),
            np.array(self.probs),
            np.array(self.vals),
            np.array(self.rewards),
            np.array(self.dones),
            batches,
        )

    def push(self, state, action, probs, vals, reward, done):
        """Append one transition; every argument goes to its own list."""
        self.states.append(state)
        self.actions.append(action)
        self.probs.append(probs)
        self.vals.append(vals)
        self.rewards.append(reward)
        self.dones.append(done)

    def clear(self):
        """Drop all stored transitions; ``batch_size`` is left untouched."""
        # Fresh list objects (rather than list.clear()) so references a
        # caller may still hold to the old lists are not emptied.
        self.states = []
        self.probs = []
        self.actions = []
        self.rewards = []
        self.dones = []
        self.vals = []