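update: return training results from train() as a dict ('rewards', 'ma_rewards') and adapt the main script to use it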
This commit is contained in:
@@ -123,14 +123,15 @@ def train(cfg,envs):
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
print('Finish training!')
|
||||
return test_rewards, test_ma_rewards
|
||||
return {'rewards':test_rewards,'ma_rewards':test_ma_rewards}
|
||||
if __name__ == "__main__":
|
||||
cfg = get_args()
|
||||
envs = [make_envs(cfg.env_name) for i in range(cfg.n_envs)]
|
||||
envs = SubprocVecEnv(envs)
|
||||
# training
|
||||
rewards,ma_rewards = train(cfg,envs)
|
||||
res_dic = train(cfg,envs)
|
||||
make_dir(cfg.result_path,cfg.model_path)
|
||||
save_args(cfg)
|
||||
save_results(rewards, ma_rewards, tag='train', path=cfg.result_path) # 保存结果
|
||||
plot_rewards(rewards, ma_rewards, cfg, tag="train") # 画出结果
|
||||
save_results(res_dic, tag='train',
|
||||
path=cfg.result_path)
|
||||
plot_rewards(res_dic['rewards'], res_dic['ma_rewards'], cfg, tag="train") # 画出结果
|
||||
|
||||
Reference in New Issue
Block a user