This commit is contained in:
johnjim0816
2021-12-22 11:19:13 +08:00
parent c257313d5b
commit 75df999258
55 changed files with 605 additions and 403 deletions

51
codes/QLearning/train.py Normal file
View File

@@ -0,0 +1,51 @@
def train(cfg,env,agent):
print('开始训练!')
print(f'环境:{cfg.env_name}, 算法:{cfg.algo_name}, 设备:{cfg.device}')
rewards = [] # 记录奖励
ma_rewards = [] # 记录滑动平均奖励
for i_ep in range(cfg.train_eps):
ep_reward = 0 # 记录每个回合的奖励
state = env.reset() # 重置环境,即开始新的回合
while True:
action = agent.choose_action(state) # 根据算法选择一个动作
next_state, reward, done, _ = env.step(action) # 与环境进行一次动作交互
agent.update(state, action, reward, next_state, done) # Q学习算法更新
state = next_state # 更新状态
ep_reward += reward
if done:
break
rewards.append(ep_reward)
if ma_rewards:
ma_rewards.append(ma_rewards[-1]*0.9+ep_reward*0.1)
else:
ma_rewards.append(ep_reward)
if ()
print("回合数:{}/{},奖励{:.1f}".format(i_ep+1, cfg.train_eps,ep_reward))
print('完成训练!')
return rewards,ma_rewards
def test(cfg,env,agent):
print('开始测试!')
print(f'环境:{cfg.env_name}, 算法:{cfg.algo_name}, 设备:{cfg.device}')
for item in agent.Q_table.items():
print(item)
rewards = [] # 记录所有回合的奖励
ma_rewards = [] # 滑动平均的奖励
for i_ep in range(cfg.test_eps):
ep_reward = 0 # 记录每个episode的reward
state = env.reset() # 重置环境, 重新开一局(即开始新的一个回合)
while True:
action = agent.predict(state) # 根据算法选择一个动作
next_state, reward, done, _ = env.step(action) # 与环境进行一个交互
state = next_state # 更新状态
ep_reward += reward
if done:
break
rewards.append(ep_reward)
if ma_rewards:
ma_rewards.append(ma_rewards[-1]*0.9+ep_reward*0.1)
else:
ma_rewards.append(ep_reward)
print(f"回合数:{i_ep+1}/{cfg.test_eps}, 奖励:{ep_reward:.1f}")
print('完成测试!')
return rewards,ma_rewards