Files
easy-rl/codes/DoubleDQN/train.py
johnjim0816 3b712e8815 update codes
2021-12-21 20:14:13 +08:00

74 lines
2.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python
# coding=utf-8
'''
Author: JiangJi
Email: johnjim0816@gmail.com
Date: 2021-11-07 18:10:37
LastEditor: JiangJi
LastEditTime: 2021-11-19 18:34:05
Description:
'''
import sys,os
# Append this file's parent directory to sys.path
# (presumably so modules located there can be imported -- confirm against the imports used elsewhere).
curr_path = os.path.dirname(os.path.abspath(__file__)) # absolute path of the directory containing this file
parent_path = os.path.dirname(curr_path) # parent directory of curr_path
sys.path.append(parent_path) # add the parent directory to the module search path
def train(cfg, env, agent):
    '''Train ``agent`` on ``env`` and return the reward history.

    Args:
        cfg: Configuration object. Reads ``env_name``, ``algo_name`` and
            ``device`` (only for the banner), ``train_eps`` (number of
            episodes) and ``target_update`` (episode interval between
            target-network synchronisations).
        env: Environment whose ``reset()`` returns the initial state and
            whose ``step(action)`` returns a 4-tuple
            ``(next_state, reward, done, info)``.
        agent: Agent exposing ``choose_action(state)``,
            ``memory.push(state, action, reward, next_state, done)``,
            ``update()``, ``policy_net.state_dict()`` and
            ``target_net.load_state_dict(...)``.

    Returns:
        A tuple ``(rewards, ma_rewards)``: the total reward of every
        episode, and its exponential moving average
        (``0.9 * previous + 0.1 * current``, seeded with the first reward).
    '''
    print('开始训练!')
    print(f'环境:{cfg.env_name}, 算法:{cfg.algo_name}, 设备:{cfg.device}')
    rewards = []  # total reward of every episode
    ma_rewards = []  # moving-average reward of every episode
    for i_ep in range(cfg.train_eps):
        ep_reward = 0  # reward accumulated within the current episode
        state = env.reset()  # reset the environment and get the initial state
        while True:
            action = agent.choose_action(state)
            next_state, reward, done, _ = env.step(action)
            ep_reward += reward
            agent.memory.push(state, action, reward, next_state, done)
            state = next_state
            agent.update()
            if done:
                break
        # Copy the policy network weights into the target network every
        # ``cfg.target_update`` episodes (including episode 0).
        if i_ep % cfg.target_update == 0:
            agent.target_net.load_state_dict(agent.policy_net.state_dict())
        # Report progress every 10 episodes.
        if (i_ep + 1) % 10 == 0:
            print(f'回合:{i_ep+1}/{cfg.train_eps},奖励:{ep_reward}')
        rewards.append(ep_reward)
        if ma_rewards:
            ma_rewards.append(
                0.9*ma_rewards[-1]+0.1*ep_reward)
        else:
            # The first episode seeds the moving average.
            ma_rewards.append(ep_reward)
    print('完成训练!')
    return rewards, ma_rewards
def test(cfg,env,agent):
print('开始测试!')
print(f'环境:{cfg.env_name}, 算法:{cfg.algo_name}, 设备:{cfg.device}')
# 由于测试不需要使用epsilon-greedy策略所以相应的值设置为0
cfg.epsilon_start = 0.0 # e-greedy策略中初始epsilon
cfg.epsilon_end = 0.0 # e-greedy策略中的终止epsilon
rewards = [] # 记录所有回合的奖励
ma_rewards = [] # 记录所有回合的滑动平均奖励
for i_ep in range(cfg.test_eps):
state = env.reset()
ep_reward = 0
while True:
action = agent.choose_action(state)
next_state, reward, done, _ = env.step(action)
state = next_state
ep_reward += reward
if done:
break
rewards.append(ep_reward)
if ma_rewards:
ma_rewards.append(ma_rewards[-1]*0.9+ep_reward*0.1)
else:
ma_rewards.append(ep_reward)
print(f"回合:{i_ep+1}/{cfg.test_eps},奖励:{ep_reward:.1f}")
print('完成测试!')
return rewards,ma_rewards