This commit is contained in:
JohnJim0816
2021-04-16 14:59:23 +08:00
parent 312b57fdff
commit e4690ac89f
71 changed files with 805 additions and 153 deletions

View File

@@ -5,7 +5,7 @@ Author: John
Email: johnjim0816@gmail.com
Date: 2021-03-20 17:43:17
LastEditor: John
LastEditTime: 2021-03-20 19:36:24
LastEditTime: 2021-04-05 11:19:20
Description:
Environment:
'''
@@ -40,7 +40,7 @@ class A2CConfig:
self.eval_eps = 200
self.eval_steps = 200
self.target_update = 4
self.hidden_dim=256
self.hidden_dim = 256
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -76,7 +76,7 @@ class A2C:
def train(cfg,env,agent):
n_states = env.observation_space.shape[0]
n_actions = env.action_space.n
actor_critic = ActorCritic(n_states, n_actions, hidden_dim)
actor_critic = ActorCritic(n_states, n_actions, cfg.hidden_dim)
ac_optimizer = optim.Adam(actor_critic.parameters(), lr=learning_rate)
all_lengths = []
@@ -112,7 +112,7 @@ def train(cfg,env,agent):
all_lengths.append(steps)
average_lengths.append(np.mean(all_lengths[-10:]))
if episode % 10 == 0:
sys.stdout.write("episode: {}, reward: {}, total length: {}, average length: {} \n".format(episode, np.sum(rewards), steps, average_lengths[-1]))
sys.stdout.write("episode: {}, reward: {}, total length: {}, average length: {} \n".format(episode, np.sum(rewards), steps+1, average_lengths[-1]))
break
# compute Q values
@@ -154,7 +154,7 @@ def train(cfg,env,agent):
plt.show()
if __name__ == "__main__":
cfg = A2CConfig
cfg = A2CConfig()
env = gym.make("CartPole-v0")
n_states = env.observation_space.shape[0]
n_actions = env.action_space.n