update
This commit is contained in:
@@ -10,7 +10,7 @@ import torch
|
||||
import datetime
|
||||
import numpy as np
|
||||
import argparse
|
||||
from common.utils import save_results_1, make_dir
|
||||
from common.utils import save_results, make_dir
|
||||
from common.utils import plot_rewards,save_args
|
||||
from dqn import DQN
|
||||
|
||||
@@ -95,8 +95,8 @@ def train(cfg, env, agent):
|
||||
|
||||
|
||||
def test(cfg, env, agent):
|
||||
print('开始测试!')
|
||||
print(f'环境:{cfg.env_name}, 算法:{cfg.algo_name}, 设备:{cfg.device}')
|
||||
print('Start testing!')
|
||||
print(f'Env:{cfg.env_name}, A{cfg.algo_name}, 设备:{cfg.device}')
|
||||
############# 由于测试不需要使用epsilon-greedy策略,所以相应的值设置为0 ###############
|
||||
cfg.epsilon_start = 0.0 # e-greedy策略中初始epsilon
|
||||
cfg.epsilon_end = 0.0 # e-greedy策略中的终止epsilon
|
||||
@@ -123,7 +123,7 @@ def test(cfg, env, agent):
|
||||
else:
|
||||
ma_rewards.append(ep_reward)
|
||||
print(f'Episode:{i_ep+1}/{cfg.test_eps}, Reward:{ep_reward:.2f}, Step:{ep_step:.2f}')
|
||||
print('完成测试!')
|
||||
print('Finish testing')
|
||||
env.close()
|
||||
return {'rewards':rewards,'ma_rewards':ma_rewards,'steps':steps}
|
||||
|
||||
@@ -133,16 +133,16 @@ if __name__ == "__main__":
|
||||
# 训练
|
||||
env, agent = env_agent_config(cfg)
|
||||
res_dic = train(cfg, env, agent)
|
||||
make_dir(cfg.result_path, cfg.model_path) # 创建保存结果和模型路径的文件夹
|
||||
save_args(cfg)
|
||||
agent.save(path=cfg.model_path) # 保存模型
|
||||
save_results_1(res_dic, tag='train',
|
||||
path=cfg.result_path) # 保存结果
|
||||
plot_rewards(res_dic['rewards'], res_dic['ma_rewards'], cfg, tag="train") # 画出结果
|
||||
make_dir(cfg.result_path, cfg.model_path)
|
||||
save_args(cfg) # save parameters
|
||||
agent.save(path=cfg.model_path) # save model
|
||||
save_results(res_dic, tag='train',
|
||||
path=cfg.result_path)
|
||||
plot_rewards(res_dic['rewards'], res_dic['ma_rewards'], cfg, tag="train")
|
||||
# 测试
|
||||
env, agent = env_agent_config(cfg)
|
||||
agent.load(path=cfg.model_path) # 导入模型
|
||||
res_dic = test(cfg, env, agent)
|
||||
save_results_1(res_dic, tag='test',
|
||||
save_results(res_dic, tag='test',
|
||||
path=cfg.result_path) # 保存结果
|
||||
plot_rewards(res_dic['rewards'], res_dic['ma_rewards'],cfg, tag="test") # 画出结果
|
||||
|
||||
Reference in New Issue
Block a user