This commit is contained in:
JohnJim0816
2021-03-28 11:18:52 +08:00
parent 2df8d965d2
commit 6e4d966e1f
56 changed files with 497 additions and 165 deletions

View File

@@ -70,8 +70,8 @@ def sarsa_train(cfg,env,agent):
if __name__ == "__main__":
sarsa_cfg = SarsaConfig()
env = RacetrackEnv()
n_actions=9
agent = Sarsa(n_actions,sarsa_cfg)
action_dim=9
agent = Sarsa(action_dim,sarsa_cfg)
rewards,ma_rewards = sarsa_train(sarsa_cfg,env,agent)
agent.save(path=SAVED_MODEL_PATH)
save_results(rewards,ma_rewards,tag='train',path=RESULT_PATH)