From 900d5689c9628c8338569eeeccea017f424d1adc Mon Sep 17 00:00:00 2001 From: johnjim0816 Date: Wed, 14 Jul 2021 17:37:24 +0800 Subject: [PATCH] update --- codes/MonteCarlo/task0_train.py | 8 ++++---- codes/Sarsa/task0_train.py | 6 +++--- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/codes/MonteCarlo/task0_train.py b/codes/MonteCarlo/task0_train.py index dae0c95..a0a3f32 100644 --- a/codes/MonteCarlo/task0_train.py +++ b/codes/MonteCarlo/task0_train.py @@ -5,7 +5,7 @@ Author: John Email: johnjim0816@gmail.com Date: 2021-03-11 14:26:44 LastEditor: John -LastEditTime: 2021-05-05 17:27:50 +LastEditTime: 2021-07-14 17:25:16 Discription: Environment: ''' @@ -48,7 +48,7 @@ def env_agent_config(cfg,seed=1): return env,agent def train(cfg, env, agent): - print('Start to eval !') + print('Start to training !') print(f'Env:{cfg.env}, Algorithm:{cfg.algo}, Device:{cfg.device}') rewards = [] ma_rewards = [] # moving average rewards @@ -102,7 +102,7 @@ def eval(cfg, env, agent): if __name__ == "__main__": cfg = MCConfig() - # train + ''' train ''' env,agent = env_agent_config(cfg,seed=1) rewards, ma_rewards = train(cfg, env, agent) make_dir(cfg.result_path, cfg.model_path) @@ -110,7 +110,7 @@ if __name__ == "__main__": save_results(rewards, ma_rewards, tag='train', path=cfg.result_path) plot_rewards(rewards, ma_rewards, tag="train", algo=cfg.algo, path=cfg.result_path) - # eval + ''' eval ''' env,agent = env_agent_config(cfg,seed=10) agent.load(path=cfg.model_path) rewards,ma_rewards = eval(cfg,env,agent) diff --git a/codes/Sarsa/task0_train.py b/codes/Sarsa/task0_train.py index d21db17..7fad1ab 100644 --- a/codes/Sarsa/task0_train.py +++ b/codes/Sarsa/task0_train.py @@ -5,7 +5,7 @@ Author: John Email: johnjim0816@gmail.com Date: 2021-03-11 17:59:16 LastEditor: John -LastEditTime: 2021-05-06 17:12:37 +LastEditTime: 2021-07-14 17:27:40 Discription: Environment: ''' @@ -26,8 +26,8 @@ class SarsaConfig: ''' parameters for Sarsa ''' def __init__(self): - self.algo = 'Qlearning' - self.env = 'CliffWalking-v0' # 0 up, 1 right, 2 down, 3 left + self.algo = 'Sarsa' + self.env = 'RacetrackEnv' # 0 up, 1 right, 2 down, 3 left self.result_path = curr_path+"/outputs/" +self.env+'/'+curr_time+'/results/' # path to save results self.model_path = curr_path+"/outputs/" +self.env+'/'+curr_time+'/models/' # path to save models self.train_eps = 200