update
This commit is contained in:
@@ -70,8 +70,8 @@ def sarsa_train(cfg,env,agent):
|
||||
if __name__ == "__main__":
|
||||
sarsa_cfg = SarsaConfig()
|
||||
env = RacetrackEnv()
|
||||
n_actions=9
|
||||
agent = Sarsa(n_actions,sarsa_cfg)
|
||||
action_dim=9
|
||||
agent = Sarsa(action_dim,sarsa_cfg)
|
||||
rewards,ma_rewards = sarsa_train(sarsa_cfg,env,agent)
|
||||
agent.save(path=SAVED_MODEL_PATH)
|
||||
save_results(rewards,ma_rewards,tag='train',path=RESULT_PATH)
|
||||
|
||||
Reference in New Issue
Block a user