update
This commit is contained in:
@@ -80,9 +80,9 @@ if __name__ == "__main__":
|
||||
cfg = PGConfig()
|
||||
env = gym.make('CartPole-v0') # you can look up why gym envs are sometimes unwrapped; generally not needed here
|
||||
env.seed(1) # set the env random seed
|
||||
n_states = env.observation_space.shape[0]
|
||||
n_actions = env.action_space.n
|
||||
agent = PolicyGradient(n_states,cfg)
|
||||
state_dim = env.observation_space.shape[0]
|
||||
action_dim = env.action_space.n
|
||||
agent = PolicyGradient(state_dim,cfg)
|
||||
rewards, ma_rewards = train(cfg,env,agent)
|
||||
agent.save_model(SAVED_MODEL_PATH)
|
||||
save_results(rewards,ma_rewards,tag='train',path=RESULT_PATH)
|
||||
|
||||
Reference in New Issue
Block a user