hot update

This commit is contained in:
johnjim0816
2022-08-22 17:50:11 +08:00
parent 0a54840828
commit ad65dd17cd
54 changed files with 1639 additions and 503 deletions

View File

@@ -5,11 +5,12 @@
@Email: johnjim0816@gmail.com
@Date: 2020-06-10 15:27:16
@LastEditor: John
LastEditTime: 2021-09-15 14:52:37
LastEditTime: 2022-08-22 17:23:21
@Discription:
@Environment: python 3.7.7
'''
import random
from collections import deque
class ReplayBuffer:
def __init__(self, capacity):
self.capacity = capacity # 经验回放的容量
@@ -34,3 +35,40 @@ class ReplayBuffer:
'''
return len(self.buffer)
class ReplayBufferQue:
def __init__(self, capacity: int) -> None:
self.capacity = capacity
self.buffer = deque(maxlen=self.capacity)
def push(self,trainsitions):
'''_summary_
Args:
trainsitions (tuple): _description_
'''
self.buffer.append(trainsitions)
def sample(self, batch_size: int, sequential: bool = False):
if batch_size > len(self.buffer):
batch_size = len(self.buffer)
if sequential: # sequential sampling
rand = random.randint(0, len(self.buffer) - batch_size)
batch = [self.buffer[i] for i in range(rand, rand + batch_size)]
return zip(*batch)
else:
batch = random.sample(self.buffer, batch_size)
return zip(*batch)
def clear(self):
self.buffer.clear()
def __len__(self):
return len(self.buffer)
class PGReplay(ReplayBufferQue):
'''replay buffer for policy gradient based methods, each time these methods will sample all transitions
Args:
ReplayBufferQue (_type_): _description_
'''
def __init__(self):
self.buffer = deque()
def sample(self):
''' sample all the transitions
'''
batch = list(self.buffer)
return zip(*batch)