update
This commit is contained in:
@@ -5,7 +5,7 @@ Author: John
|
||||
Email: johnjim0816@gmail.com
|
||||
Date: 2021-03-20 17:43:17
|
||||
LastEditor: John
|
||||
LastEditTime: 2021-03-20 19:36:24
|
||||
LastEditTime: 2021-04-05 11:19:20
|
||||
Discription:
|
||||
Environment:
|
||||
'''
|
||||
@@ -40,7 +40,7 @@ class A2CConfig:
|
||||
self.eval_eps = 200
|
||||
self.eval_steps = 200
|
||||
self.target_update = 4
|
||||
self.hidden_dim=256
|
||||
self.hidden_dim = 256
|
||||
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
|
||||
@@ -76,7 +76,7 @@ class A2C:
|
||||
def train(cfg,env,agent):
|
||||
n_states = env.observation_space.shape[0]
|
||||
n_actions = env.action_space.n
|
||||
actor_critic = ActorCritic(n_states, n_actions, hidden_dim)
|
||||
actor_critic = ActorCritic(n_states, n_actions, cfg.hidden_dim)
|
||||
ac_optimizer = optim.Adam(actor_critic.parameters(), lr=learning_rate)
|
||||
|
||||
all_lengths = []
|
||||
@@ -112,7 +112,7 @@ def train(cfg,env,agent):
|
||||
all_lengths.append(steps)
|
||||
average_lengths.append(np.mean(all_lengths[-10:]))
|
||||
if episode % 10 == 0:
|
||||
sys.stdout.write("episode: {}, reward: {}, total length: {}, average length: {} \n".format(episode, np.sum(rewards), steps, average_lengths[-1]))
|
||||
sys.stdout.write("episode: {}, reward: {}, total length: {}, average length: {} \n".format(episode, np.sum(rewards), steps+1, average_lengths[-1]))
|
||||
break
|
||||
|
||||
# compute Q values
|
||||
@@ -154,7 +154,7 @@ def train(cfg,env,agent):
|
||||
plt.show()
|
||||
|
||||
if __name__ == "__main__":
|
||||
cfg = A2CConfig
|
||||
cfg = A2CConfig()
|
||||
env = gym.make("CartPole-v0")
|
||||
n_states = env.observation_space.shape[0]
|
||||
n_actions = env.action_space.n
|
||||
|
||||
Reference in New Issue
Block a user