update
This commit is contained in:
@@ -19,12 +19,12 @@ from common.memory import ReplayBuffer
|
||||
|
||||
|
||||
class DDPG:
|
||||
def __init__(self, n_states, n_actions, cfg):
|
||||
def __init__(self, state_dim, action_dim, cfg):
|
||||
self.device = cfg.device
|
||||
self.critic = Critic(n_states, n_actions, cfg.hidden_dim).to(cfg.device)
|
||||
self.actor = Actor(n_states, n_actions, cfg.hidden_dim).to(cfg.device)
|
||||
self.target_critic = Critic(n_states, n_actions, cfg.hidden_dim).to(cfg.device)
|
||||
self.target_actor = Actor(n_states, n_actions, cfg.hidden_dim).to(cfg.device)
|
||||
self.critic = Critic(state_dim, action_dim, cfg.hidden_dim).to(cfg.device)
|
||||
self.actor = Actor(state_dim, action_dim, cfg.hidden_dim).to(cfg.device)
|
||||
self.target_critic = Critic(state_dim, action_dim, cfg.hidden_dim).to(cfg.device)
|
||||
self.target_actor = Actor(state_dim, action_dim, cfg.hidden_dim).to(cfg.device)
|
||||
|
||||
for target_param, param in zip(self.target_critic.parameters(), self.critic.parameters()):
|
||||
target_param.data.copy_(param.data)
|
||||
|
||||
@@ -41,17 +41,17 @@ class OUNoise(object):
|
||||
self.max_sigma = max_sigma
|
||||
self.min_sigma = min_sigma
|
||||
self.decay_period = decay_period
|
||||
self.n_actions = action_space.shape[0]
|
||||
self.action_dim = action_space.shape[0]
|
||||
self.low = action_space.low
|
||||
self.high = action_space.high
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
self.obs = np.ones(self.n_actions) * self.mu
|
||||
self.obs = np.ones(self.action_dim) * self.mu
|
||||
|
||||
def evolve_obs(self):
|
||||
x = self.obs
|
||||
dx = self.theta * (self.mu - x) + self.sigma * np.random.randn(self.n_actions)
|
||||
dx = self.theta * (self.mu - x) + self.sigma * np.random.randn(self.action_dim)
|
||||
self.obs = x + dx
|
||||
return self.obs
|
||||
|
||||
|
||||
@@ -82,9 +82,9 @@ if __name__ == "__main__":
|
||||
cfg = DDPGConfig()
|
||||
env = NormalizedActions(gym.make("Pendulum-v0"))
|
||||
env.seed(1) # 设置env随机种子
|
||||
n_states = env.observation_space.shape[0]
|
||||
n_actions = env.action_space.shape[0]
|
||||
agent = DDPG(n_states,n_actions,cfg)
|
||||
state_dim = env.observation_space.shape[0]
|
||||
action_dim = env.action_space.shape[0]
|
||||
agent = DDPG(state_dim,action_dim,cfg)
|
||||
rewards,ma_rewards = train(cfg,env,agent)
|
||||
agent.save(path=SAVED_MODEL_PATH)
|
||||
save_results(rewards,ma_rewards,tag='train',path=RESULT_PATH)
|
||||
|
||||
Reference in New Issue
Block a user