update
This commit is contained in:
@@ -5,7 +5,7 @@
|
||||
@Email: johnjim0816@gmail.com
|
||||
@Date: 2020-06-10 15:28:30
|
||||
@LastEditor: John
|
||||
LastEditTime: 2021-03-19 19:56:46
|
||||
LastEditTime: 2021-09-16 00:52:30
|
||||
@Description:
|
||||
@Environment: python 3.7.7
|
||||
'''
|
||||
@@ -32,12 +32,12 @@ class NormalizedActions(gym.ActionWrapper):
|
||||
return action
|
||||
|
||||
class OUNoise(object):
|
||||
'''Ornstein–Uhlenbeck
|
||||
'''Ornstein–Uhlenbeck noise
|
||||
'''
|
||||
def __init__(self, action_space, mu=0.0, theta=0.15, max_sigma=0.3, min_sigma=0.3, decay_period=100000):
|
||||
self.mu = mu
|
||||
self.theta = theta
|
||||
self.sigma = max_sigma
|
||||
self.mu = mu # OU噪声的参数
|
||||
self.theta = theta # OU噪声的参数
|
||||
self.sigma = max_sigma # OU噪声的参数
|
||||
self.max_sigma = max_sigma
|
||||
self.min_sigma = min_sigma
|
||||
self.decay_period = decay_period
|
||||
@@ -45,17 +45,14 @@ class OUNoise(object):
|
||||
self.low = action_space.low
|
||||
self.high = action_space.high
|
||||
self.reset()
|
||||
|
||||
def reset(self):
    '''Reset the noise state so every component equals ``self.mu``.

    The state is stored in ``self.obs`` as an array built from
    ``np.ones(self.action_dim)``.
    '''
    baseline = np.ones(self.action_dim)
    self.obs = baseline * self.mu
|
||||
|
||||
def evolve_obs(self):
    '''Advance the stored noise state by one step and return it.

    The increment is a pull toward ``self.mu`` scaled by ``self.theta``
    plus standard-normal samples scaled by ``self.sigma``. The updated
    state is written back to ``self.obs``.

    Returns:
        The new value of ``self.obs``.
    '''
    # Mean-reverting term: moves the current state toward mu.
    drift = self.theta * (self.mu - self.obs)
    # Random term: one Gaussian sample per action dimension.
    diffusion = self.sigma * np.random.randn(self.action_dim)
    self.obs = self.obs + (drift + diffusion)
    return self.obs
|
||||
|
||||
def get_action(self, action, t=0):
    '''Add noise to ``action`` and clip the result to the action bounds.

    Also updates ``self.sigma`` for the given step ``t``, after the noise
    for this call has been drawn.

    Args:
        action: Action to perturb.
        t: Current step, used to anneal ``self.sigma``.

    Returns:
        ``action`` plus the current noise state, clipped to
        ``[self.low, self.high]``.
    '''
    noise = self.evolve_obs()
    # sigma decays gradually: linear interpolation from max_sigma toward
    # min_sigma, saturating once t reaches decay_period.
    progress = min(1.0, t / self.decay_period)
    sigma_span = self.max_sigma - self.min_sigma
    self.sigma = self.max_sigma - sigma_span * progress
    # Clip the action after the noise has been added.
    noisy_action = action + noise
    return np.clip(noisy_action, self.low, self.high)
|
||||
Reference in New Issue
Block a user