JohnJim0816
2021-04-16 14:59:23 +08:00
parent 312b57fdff
commit e4690ac89f
71 changed files with 805 additions and 153 deletions

View File

@@ -14,16 +14,16 @@ import gym
 from A2C.multiprocessing_env import SubprocVecEnv
 # num_envs = 16
-# env_name = "Pendulum-v0"
+# env = "Pendulum-v0"
-def make_envs(num_envs=16,env_name="Pendulum-v0"):
+def make_envs(num_envs=16,env="Pendulum-v0"):
     ''' Create multiple sub-environments
     '''
     num_envs = 16
-    env_name = "CartPole-v0"
+    env = "CartPole-v0"
     def make_env():
         def _thunk():
-            env = gym.make(env_name)
+            env = gym.make(env)
             return env
         return _thunk
@@ -34,10 +34,10 @@ def make_envs(num_envs=16,env_name="Pendulum-v0"):
 # if __name__ == "__main__":
 #     num_envs = 16
-#     env_name = "CartPole-v0"
+#     env = "CartPole-v0"
 #     def make_env():
 #         def _thunk():
-#             env = gym.make(env_name)
+#             env = gym.make(env)
 #             return env
 #         return _thunk
@@ -45,4 +45,4 @@ def make_envs(num_envs=16,env_name="Pendulum-v0"):
 # envs = [make_env() for i in range(num_envs)]
 # envs = SubprocVecEnv(envs)
 if __name__ == "__main__":
-    envs = make_envs(num_envs=16,env_name="CartPole-v0")
+    envs = make_envs(num_envs=16,env="CartPole-v0")
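Note that reusing the name `env` for both the environment id string and the created environment object breaks `_thunk`: because `env` is assigned inside `_thunk`, Python treats it as a local name throughout that function, so the right-hand side of `env = gym.make(env)` raises UnboundLocalError at call time. A minimal sketch that avoids the shadowing (keeping the old `env_name` spelling for the string parameter, which is an assumption about intent, not part of this commit):

import gym
from A2C.multiprocessing_env import SubprocVecEnv

def make_envs(num_envs=16, env_name="Pendulum-v0"):
    ''' Create multiple sub-environments '''
    def make_env():
        def _thunk():
            # The id string and the env object keep distinct names, so the
            # closed-over env_name stays readable inside this nested function.
            return gym.make(env_name)
        return _thunk
    return SubprocVecEnv([make_env() for _ in range(num_envs)])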

View File

@@ -5,16 +5,20 @@
 @Email: johnjim0816@gmail.com
 @Date: 2020-06-11 20:58:21
 @LastEditor: John
-LastEditTime: 2021-03-20 16:58:04
+LastEditTime: 2021-04-05 11:14:39
 @Discription:
 @Environment: python 3.7.9
 '''
 import sys,os
-sys.path.append(os.getcwd()) # add current terminal path
+curr_path = os.path.dirname(__file__)
+parent_path = os.path.dirname(curr_path)
+sys.path.append(parent_path) # add parent path to sys.path
 import torch
 import gym
+import datetime
 from A2C.agent import A2C
 from common.utils import save_results,make_dir,del_empty_dir
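This change anchors sys.path to the file's location instead of the shell's working directory, so `from A2C.agent import A2C` and `from common.utils import ...` resolve no matter where the script is launched, assuming A2C/ and common/ sit side by side under that parent, as the imports suggest. One caveat: `os.path.dirname(__file__)` can return an empty string when the script is run from its own directory, so a hardened sketch of the same idea (not part of this commit) is:

import os
import sys

# Resolve the repo root relative to this file, not os.getcwd(), so the
# script works from any working directory. abspath() guards against
# os.path.dirname(__file__) returning '' when run as `python main.py`.
curr_path = os.path.dirname(os.path.abspath(__file__))  # .../A2C
parent_path = os.path.dirname(curr_path)                # directory holding A2C/ and common/
sys.path.append(parent_path)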

View File

@@ -5,7 +5,7 @@ Author: John
 Email: johnjim0816@gmail.com
 Date: 2021-03-20 17:43:17
 LastEditor: John
-LastEditTime: 2021-03-20 19:36:24
+LastEditTime: 2021-04-05 11:19:20
 Discription:
 Environment:
 '''
@@ -40,7 +40,7 @@ class A2CConfig:
         self.eval_eps = 200
         self.eval_steps = 200
         self.target_update = 4
-        self.hidden_dim=256
+        self.hidden_dim = 256
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -76,7 +76,7 @@ class A2C:
 def train(cfg,env,agent):
     n_states = env.observation_space.shape[0]
     n_actions = env.action_space.n
-    actor_critic = ActorCritic(n_states, n_actions, hidden_dim)
+    actor_critic = ActorCritic(n_states, n_actions, cfg.hidden_dim)
     ac_optimizer = optim.Adam(actor_critic.parameters(), lr=learning_rate)
     all_lengths = []
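The fix swaps the bare `hidden_dim` for `cfg.hidden_dim`, but the next context line still reads a bare `learning_rate`, which would raise NameError for the same reason unless it is defined elsewhere in the file. If the config carries the learning rate too, the same pattern would extend to it; `cfg.lr` below is an assumed attribute name, not visible in these hunks:

# Hypothetical continuation of the same fix; cfg.lr is an assumed
# attribute name, not shown anywhere in this commit.
actor_critic = ActorCritic(n_states, n_actions, cfg.hidden_dim)
ac_optimizer = optim.Adam(actor_critic.parameters(), lr=cfg.lr)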
@@ -112,7 +112,7 @@ def train(cfg,env,agent):
             all_lengths.append(steps)
             average_lengths.append(np.mean(all_lengths[-10:]))
             if episode % 10 == 0:
-                sys.stdout.write("episode: {}, reward: {}, total length: {}, average length: {} \n".format(episode, np.sum(rewards), steps, average_lengths[-1]))
+                sys.stdout.write("episode: {}, reward: {}, total length: {}, average length: {} \n".format(episode, np.sum(rewards), steps+1, average_lengths[-1]))
             break
     # compute Q values
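Logging `steps+1` corrects an off-by-one: assuming `steps` is the zero-based index of an inner `for steps in range(...)` loop (the loop header is outside these hunks), an episode that breaks at index `steps` actually ran `steps + 1` transitions. A standalone illustration:

# Assumes a zero-based loop index, e.g. `for steps in range(num_steps):`.
for steps in range(5):
    pass
print(steps)      # 4 -> last index, under-reports the episode length
print(steps + 1)  # 5 -> number of iterations actually executed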
@@ -154,7 +154,7 @@ def train(cfg,env,agent):
     plt.show()
 if __name__ == "__main__":
-    cfg = A2CConfig
+    cfg = A2CConfig()
     env = gym.make("CartPole-v0")
     n_states = env.observation_space.shape[0]
     n_actions = env.action_space.n
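The added parentheses matter because `A2CConfig` without a call binds the class object itself, so `__init__` never runs and instance attributes such as `hidden_dim` do not exist. A minimal illustration of the difference:

class A2CConfig:
    def __init__(self):
        self.hidden_dim = 256

cfg = A2CConfig       # the class object: __init__ never ran
# cfg.hidden_dim      # AttributeError: no instance attributes exist yet
cfg = A2CConfig()     # an instance: __init__ set hidden_dim
assert cfg.hidden_dim == 256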