This commit is contained in:
johnjim0816
2022-07-21 22:12:19 +08:00
parent 6b3121fcff
commit e9b3e92141
21 changed files with 99 additions and 85 deletions

View File

@@ -5,7 +5,7 @@
@Email: johnjim0816@gmail.com
@Date: 2020-06-11 20:58:21
@LastEditor: John
LastEditTime: 2022-07-21 00:05:41
LastEditTime: 2022-07-21 21:51:34
@Discription:
@Environment: python 3.7.7
'''
@@ -86,7 +86,7 @@ def train(cfg, env, agent):
else:
ma_rewards.append(ep_reward)
print('Finish training!')
return rewards, ma_rewards
return {'rewards':rewards,'ma_rewards':ma_rewards}
def test(cfg, env, agent):
print('Start testing')
@@ -111,21 +111,23 @@ def test(cfg, env, agent):
ma_rewards.append(ep_reward)
print(f"Epside:{i_ep+1}/{cfg.test_eps}, Reward:{ep_reward:.1f}")
print('Finish testing!')
return rewards, ma_rewards
return {'rewards':rewards,'ma_rewards':ma_rewards}
if __name__ == "__main__":
cfg = get_args()
# training
env,agent = env_agent_config(cfg,seed=1)
rewards, ma_rewards = train(cfg, env, agent)
res_dic = train(cfg, env, agent)
make_dir(cfg.result_path, cfg.model_path)
save_args(cfg)
agent.save(path=cfg.model_path)
save_results(rewards, ma_rewards, tag='train', path=cfg.result_path)
plot_rewards(rewards, ma_rewards, cfg, tag="train")
save_results(res_dic, tag='train',
path=cfg.result_path)
plot_rewards(res_dic['rewards'], res_dic['ma_rewards'], cfg, tag="train")
# testing
env,agent = env_agent_config(cfg,seed=10)
agent.load(path=cfg.model_path)
rewards,ma_rewards = test(cfg,env,agent)
save_results(rewards,ma_rewards,tag = 'test',path = cfg.result_path)
plot_rewards(rewards, ma_rewards, cfg, tag="test")
res_dic = test(cfg,env,agent)
save_results(res_dic, tag='test',
path=cfg.result_path)
plot_rewards(res_dic['rewards'], res_dic['ma_rewards'], cfg, tag="test")