This commit is contained in:
johnjim0816
2021-12-22 11:19:13 +08:00
parent c257313d5b
commit 75df999258
55 changed files with 605 additions and 403 deletions

View File

@@ -42,7 +42,7 @@ def test(cfg, env, agent):
print(f'环境:{cfg.env_name}, 算法:{cfg.algo}, 设备:{cfg.device}')
rewards = [] # 记录所有回合的奖励
ma_rewards = [] # 记录所有回合的滑动平均奖励
for i_ep in range(cfg.eval_eps):
for i_ep in range(cfg.test_eps):
state = env.reset()
done = False
ep_reward = 0
@@ -59,6 +59,6 @@ def test(cfg, env, agent):
ma_rewards.append(0.9*ma_rewards[-1]+0.1*ep_reward)
else:
ma_rewards.append(ep_reward)
print(f"回合:{i_ep+1}/{cfg.eval_eps},奖励:{ep_reward:.1f}")
print(f"回合:{i_ep+1}/{cfg.test_eps},奖励:{ep_reward:.1f}")
print('完成测试!')
return rewards, ma_rewards