update
This commit is contained in:
@@ -14,16 +14,16 @@ import gym
|
||||
from A2C.multiprocessing_env import SubprocVecEnv
|
||||
|
||||
# num_envs = 16
|
||||
# env_name = "Pendulum-v0"
|
||||
# env = "Pendulum-v0"
|
||||
|
||||
def make_envs(num_envs=16,env_name="Pendulum-v0"):
|
||||
def make_envs(num_envs=16,env="Pendulum-v0"):
|
||||
''' 创建多个子环境
|
||||
'''
|
||||
num_envs = 16
|
||||
env_name = "CartPole-v0"
|
||||
env = "CartPole-v0"
|
||||
def make_env():
|
||||
def _thunk():
|
||||
env = gym.make(env_name)
|
||||
env = gym.make(env)
|
||||
return env
|
||||
|
||||
return _thunk
|
||||
@@ -34,10 +34,10 @@ def make_envs(num_envs=16,env_name="Pendulum-v0"):
|
||||
# if __name__ == "__main__":
|
||||
|
||||
# num_envs = 16
|
||||
# env_name = "CartPole-v0"
|
||||
# env = "CartPole-v0"
|
||||
# def make_env():
|
||||
# def _thunk():
|
||||
# env = gym.make(env_name)
|
||||
# env = gym.make(env)
|
||||
# return env
|
||||
|
||||
# return _thunk
|
||||
@@ -45,4 +45,4 @@ def make_envs(num_envs=16,env_name="Pendulum-v0"):
|
||||
# envs = [make_env() for i in range(num_envs)]
|
||||
# envs = SubprocVecEnv(envs)
|
||||
if __name__ == "__main__":
|
||||
envs = make_envs(num_envs=16,env_name="CartPole-v0")
|
||||
envs = make_envs(num_envs=16,env="CartPole-v0")
|
||||
@@ -5,16 +5,20 @@
|
||||
@Email: johnjim0816@gmail.com
|
||||
@Date: 2020-06-11 20:58:21
|
||||
@LastEditor: John
|
||||
LastEditTime: 2021-03-20 16:58:04
|
||||
LastEditTime: 2021-04-05 11:14:39
|
||||
@Description:
|
||||
@Environment: python 3.7.9
|
||||
'''
|
||||
import sys,os
|
||||
sys.path.append(os.getcwd()) # add current terminal path
|
||||
curr_path = os.path.dirname(__file__)
|
||||
parent_path=os.path.dirname(curr_path)
|
||||
sys.path.append(parent_path) # add the parent directory of this script's folder to sys.path
|
||||
|
||||
import torch
|
||||
import gym
|
||||
import datetime
|
||||
from A2C.agent import A2C
|
||||
from common.utils import save_results,make_dir,del_empty_dir
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ Author: John
|
||||
Email: johnjim0816@gmail.com
|
||||
Date: 2021-03-20 17:43:17
|
||||
LastEditor: John
|
||||
LastEditTime: 2021-03-20 19:36:24
|
||||
LastEditTime: 2021-04-05 11:19:20
|
||||
Description:
|
||||
Environment:
|
||||
'''
|
||||
@@ -40,7 +40,7 @@ class A2CConfig:
|
||||
self.eval_eps = 200
|
||||
self.eval_steps = 200
|
||||
self.target_update = 4
|
||||
self.hidden_dim=256
|
||||
self.hidden_dim = 256
|
||||
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
|
||||
@@ -76,7 +76,7 @@ class A2C:
|
||||
def train(cfg,env,agent):
|
||||
n_states = env.observation_space.shape[0]
|
||||
n_actions = env.action_space.n
|
||||
actor_critic = ActorCritic(n_states, n_actions, hidden_dim)
|
||||
actor_critic = ActorCritic(n_states, n_actions, cfg.hidden_dim)
|
||||
ac_optimizer = optim.Adam(actor_critic.parameters(), lr=learning_rate)
|
||||
|
||||
all_lengths = []
|
||||
@@ -112,7 +112,7 @@ def train(cfg,env,agent):
|
||||
all_lengths.append(steps)
|
||||
average_lengths.append(np.mean(all_lengths[-10:]))
|
||||
if episode % 10 == 0:
|
||||
sys.stdout.write("episode: {}, reward: {}, total length: {}, average length: {} \n".format(episode, np.sum(rewards), steps, average_lengths[-1]))
|
||||
sys.stdout.write("episode: {}, reward: {}, total length: {}, average length: {} \n".format(episode, np.sum(rewards), steps+1, average_lengths[-1]))
|
||||
break
|
||||
|
||||
# compute Q values
|
||||
@@ -154,7 +154,7 @@ def train(cfg,env,agent):
|
||||
plt.show()
|
||||
|
||||
if __name__ == "__main__":
|
||||
cfg = A2CConfig
|
||||
cfg = A2CConfig()
|
||||
env = gym.make("CartPole-v0")
|
||||
n_states = env.observation_space.shape[0]
|
||||
n_actions = env.action_space.n
|
||||
|
||||
Reference in New Issue
Block a user