This commit is contained in:
johnjim0816
2022-07-21 22:12:19 +08:00
parent 6b3121fcff
commit e9b3e92141
21 changed files with 99 additions and 85 deletions

View File

@@ -5,7 +5,7 @@ Author: JiangJi
Email: johnjim0816@gmail.com
Date: 2021-11-07 18:10:37
LastEditor: JiangJi
LastEditTime: 2022-07-21 00:08:38
LastEditTime: 2022-07-21 21:52:31
Discription:
'''
import sys,os
@@ -86,7 +86,7 @@ def train(cfg,env,agent):
else:
ma_rewards.append(ep_reward)
print('Finish training!')
return rewards,ma_rewards
return {'rewards':rewards,'ma_rewards':ma_rewards}
def test(cfg,env,agent):
print('Start testing')
@@ -115,22 +115,24 @@ def test(cfg,env,agent):
ma_rewards.append(ep_reward)
print(f"Epside:{i_ep+1}/{cfg.test_eps}, Reward:{ep_reward:.1f}")
print('Finish testing!')
return rewards,ma_rewards
return {'rewards':rewards,'ma_rewards':ma_rewards}
if __name__ == "__main__":
cfg = get_args()
print(cfg.device)
# training
env,agent = env_agent_config(cfg,seed=1)
rewards, ma_rewards = train(cfg, env, agent)
res_dic = train(cfg, env, agent)
make_dir(cfg.result_path, cfg.model_path)
save_args(cfg)
agent.save(path=cfg.model_path)
save_results(rewards, ma_rewards, tag='train', path=cfg.result_path)
plot_rewards(rewards, ma_rewards, cfg, tag="train")
save_results(res_dic, tag='train',
path=cfg.result_path)
plot_rewards(res_dic['rewards'], res_dic['ma_rewards'], cfg, tag="train")
# testing
env,agent = env_agent_config(cfg,seed=10)
agent.load(path=cfg.model_path)
rewards,ma_rewards = test(cfg,env,agent)
save_results(rewards,ma_rewards,tag = 'test',path = cfg.result_path)
plot_rewards(rewards, ma_rewards, cfg, tag="test")
res_dic = test(cfg,env,agent)
save_results(res_dic, tag='test',
path=cfg.result_path)
plot_rewards(res_dic['rewards'], res_dic['ma_rewards'], cfg, tag="test")