update
@@ -1,34 +1,44 @@
#!/usr/bin/env python
# coding=utf-8
'''
-@Author: John
-@Email: johnjim0816@gmail.com
-@Date: 2020-06-10 15:27:16
-@LastEditor: John
-@LastEditTime: 2020-06-11 21:04:50
-@Discription:
-@Environment: python 3.7.7
+Author: John
+Email: johnjim0816@gmail.com
+Date: 2021-04-13 11:00:13
+LastEditor: John
+LastEditTime: 2021-04-15 01:25:14
+Discription:
+Environment:
'''
import random
import numpy as np
import torch


class ReplayBuffer:
    '''List-based experience replay buffer with a fixed capacity.'''
    def __init__(self, capacity):
        self.capacity = capacity  # maximum number of stored transitions
        self.buffer = []
        self.position = 0  # index of the next slot to overwrite

    def push(self, state, action, reward, next_state, done):
        '''Store one transition, overwriting the oldest one when full.'''
        if len(self.buffer) < self.capacity:
            self.buffer.append(None)
        self.buffer[self.position] = (state, action, reward, next_state, done)
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        '''Sample a random batch and stack each field into a numpy array.'''
        batch = random.sample(self.buffer, batch_size)
        state, action, reward, next_state, done = map(np.stack, zip(*batch))
        return state, action, reward, next_state, done

    def __len__(self):
        return len(self.buffer)
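
# Minimal usage sketch for the list-based buffer above (hypothetical helper,
# not part of the original file): push dummy transitions and sample one batch.
# The shapes are assumptions; the default argument binds the list-based class
# here, before the name ReplayBuffer is redefined further down.
def _demo_list_buffer(buffer_cls=ReplayBuffer, capacity=1000, batch_size=32):
    buffer = buffer_cls(capacity)
    for _ in range(batch_size):
        # dummy 4-dimensional state and a single discrete action
        buffer.push(np.zeros(4), 0, 0.0, np.zeros(4), False)
    # sample() returns five np.stack-ed arrays, each with batch_size rows
    return buffer.sample(batch_size)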

class ReplayBuffer(object):
    '''Array-based replay buffer that returns torch tensors.
    Note: this definition shadows the list-based ReplayBuffer above.'''
    def __init__(self, state_dim, action_dim, max_size=int(1e6)):
        self.max_size = max_size
        self.ptr = 0   # index of the next slot to overwrite
        self.size = 0  # number of transitions currently stored
        self.state = np.zeros((max_size, state_dim))
        self.action = np.zeros((max_size, action_dim))
        self.next_state = np.zeros((max_size, state_dim))
        self.reward = np.zeros((max_size, 1))
        self.not_done = np.zeros((max_size, 1))
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def push(self, state, action, next_state, reward, done):
        '''Store one transition; note the (state, action, next_state, reward, done) order.'''
        self.state[self.ptr] = state
        self.action[self.ptr] = action
        self.next_state[self.ptr] = next_state
        self.reward[self.ptr] = reward
        self.not_done[self.ptr] = 1. - done
        self.ptr = (self.ptr + 1) % self.max_size
        self.size = min(self.size + 1, self.max_size)

    def sample(self, batch_size):
        '''Sample a random batch and return it as float tensors on self.device.'''
        ind = np.random.randint(0, self.size, size=batch_size)
        return (
            torch.FloatTensor(self.state[ind]).to(self.device),
            torch.FloatTensor(self.action[ind]).to(self.device),
            torch.FloatTensor(self.next_state[ind]).to(self.device),
            torch.FloatTensor(self.reward[ind]).to(self.device),
            torch.FloatTensor(self.not_done[ind]).to(self.device)
        )
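
# Minimal usage sketch for the array-based buffer above (hypothetical helper,
# assumed state/action dimensions): push dummy continuous-control transitions
# and sample a batch of tensors.
def _demo_array_buffer(state_dim=3, action_dim=1, batch_size=32):
    buffer = ReplayBuffer(state_dim, action_dim, max_size=1000)
    for _ in range(batch_size):
        # argument order here is (state, action, next_state, reward, done)
        buffer.push(np.zeros(state_dim), np.zeros(action_dim),
                    np.zeros(state_dim), 0.0, False)
    state, action, next_state, reward, not_done = buffer.sample(batch_size)
    return state.shape  # torch.Size([batch_size, state_dim])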