update
39
codes/DoubleDQN/README.md
Normal file
@@ -0,0 +1,39 @@
|
||||
食用本篇之前,需要有DQN算法的基础,参考[DQN算法实战](../DQN)。
|
||||
|
||||
## 原理简介
|
||||
|
||||
Double-DQN是2016年提出的算法,灵感源自2010年的Double-Qlearning,可参考论文[Deep Reinforcement Learning with Double Q-learning](https://arxiv.org/abs/1509.06461)。
|
||||
跟Nature DQN一样,Double-DQN也用了两个网络,一个当前网络(对应用$Q$表示),一个目标网络(对应一般用$Q'$表示,为方便区分,以下用$Q_{tar}$代替)。我们先回忆一下,对于非终止状态,目标$Q_{tar}$值计算如下
|
||||

|
||||
|
||||
而在Double-DQN中,不再是直接从目标$Q_{tar}$网络中选择各个动作中的最大$Q_{tar}$值,而是先从当前$Q$网络选择$Q$值最大对应的动作,然后代入到目标网络中计算对应的值:
|
||||

|
||||
Double-DQN的好处是Nature DQN中使用max虽然可以快速让Q值向可能的优化目标靠拢,但是很容易过犹不及,导致过度估计(Over Estimation),所谓过度估计就是最终我们得到的算法模型有很大的偏差(bias)。为了解决这个问题, DDQN通过解耦目标Q值动作的选择和目标Q值的计算这两步,来达到消除过度估计的问题,感兴趣可以阅读原论文。
|
||||
|
||||
伪代码如下:
|
||||

|
||||
当然也可以两个网络可以同时为当前网络和目标网络,如下:
|
||||

|
||||
或者这样更好理解如何同时为当前网络和目标网络:
|
||||

|
||||
|
||||
## 代码实战
|
||||
完整程序见[github](https://github.com/JohnJim0816/reinforcement-learning-tutorials/tree/master/DoubleDQN)。结合上面的原理,其实Double DQN改进来很简单,基本只需要在```update```中修改几行代码,如下:
|
||||
```python
|
||||
'''以下是Nature DQN的q_target计算方式
|
||||
next_q_state_value = self.target_net(
|
||||
next_state_batch).max(1)[0].detach() # # 计算所有next states的Q'(s_{t+1})的最大值,Q'为目标网络的q函数,比如tensor([ 0.0060, -0.0171,...,])
|
||||
#计算 q_target
|
||||
#对于终止状态,此时done_batch[0]=1, 对应的expected_q_value等于reward
|
||||
q_target = reward_batch + self.gamma * next_q_state_value * (1-done_batch[0])
|
||||
'''
|
||||
'''以下是Double DQNq_target计算方式,与NatureDQN稍有不同'''
|
||||
next_target_values = self.target_net(
|
||||
next_state_batch)
|
||||
#选出Q(s_t‘, a)对应的action,代入到next_target_values获得target net对应的next_q_value,即Q’(s_t|a=argmax Q(s_t‘, a))
|
||||
next_target_q_value = next_target_values.gather(1, torch.max(next_q_values, 1)[1].unsqueeze(1)).squeeze(1)
|
||||
q_target = reward_batch + self.gamma * next_target_q_value * (1-done_batch[0])
|
||||
```
|
||||
reward变化结果如下:
|
||||

|
||||
其中下边蓝色和红色分别表示Double DQN和Nature DQN在训练中的reward变化图,而上面蓝色和绿色则表示Double DQN和Nature DQN在测试中的reward变化图。
|
||||
@@ -5,7 +5,7 @@
|
||||
@Email: johnjim0816@gmail.com
|
||||
@Date: 2020-06-12 00:50:49
|
||||
@LastEditor: John
|
||||
LastEditTime: 2021-03-13 15:01:27
|
||||
LastEditTime: 2021-03-28 11:07:35
|
||||
@Discription:
|
||||
@Environment: python 3.7.7
|
||||
'''
|
||||
@@ -16,16 +16,15 @@ LastEditTime: 2021-03-13 15:01:27
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
import torch.nn.functional as F
|
||||
import random
|
||||
import math
|
||||
import numpy as np
|
||||
from common.memory import ReplayBuffer
|
||||
from common.model import MLP2
|
||||
from common.model import MLP
|
||||
class DoubleDQN:
|
||||
def __init__(self, n_states, n_actions, cfg):
|
||||
def __init__(self, state_dim, action_dim, cfg):
|
||||
|
||||
self.n_actions = n_actions # 总的动作个数
|
||||
self.action_dim = action_dim # 总的动作个数
|
||||
self.device = cfg.device # 设备,cpu或gpu等
|
||||
self.gamma = cfg.gamma
|
||||
# e-greedy策略相关参数
|
||||
@@ -34,8 +33,8 @@ class DoubleDQN:
|
||||
self.epsilon_end = cfg.epsilon_end
|
||||
self.epsilon_decay = cfg.epsilon_decay
|
||||
self.batch_size = cfg.batch_size
|
||||
self.policy_net = MLP2(n_states, n_actions,hidden_dim=cfg.hidden_dim).to(self.device)
|
||||
self.target_net = MLP2(n_states, n_actions,hidden_dim=cfg.hidden_dim).to(self.device)
|
||||
self.policy_net = MLP(state_dim, action_dim,hidden_dim=cfg.hidden_dim).to(self.device)
|
||||
self.target_net = MLP(state_dim, action_dim,hidden_dim=cfg.hidden_dim).to(self.device)
|
||||
# target_net的初始模型参数完全复制policy_net
|
||||
self.target_net.load_state_dict(self.policy_net.state_dict())
|
||||
self.target_net.eval() # 不启用 BatchNormalization 和 Dropout
|
||||
@@ -63,7 +62,7 @@ class DoubleDQN:
|
||||
# 所以tensor.max(1)[1]返回最大值对应的下标,即action
|
||||
action = q_value.max(1)[1].item()
|
||||
else:
|
||||
action = random.randrange(self.n_actions)
|
||||
action = random.randrange(self.action_dim)
|
||||
return action
|
||||
def update(self):
|
||||
|
||||
|
||||
BIN
codes/DoubleDQN/assets/20201222145725907.png
Normal file
|
After Width: | Height: | Size: 17 KiB |
BIN
codes/DoubleDQN/assets/20201222150225327.png
Normal file
|
After Width: | Height: | Size: 24 KiB |
|
After Width: | Height: | Size: 105 KiB |
|
After Width: | Height: | Size: 74 KiB |
|
After Width: | Height: | Size: 185 KiB |
|
After Width: | Height: | Size: 75 KiB |
@@ -5,7 +5,7 @@
|
||||
@Email: johnjim0816@gmail.com
|
||||
@Date: 2020-06-12 00:48:57
|
||||
@LastEditor: John
|
||||
LastEditTime: 2021-03-17 20:11:19
|
||||
LastEditTime: 2021-03-28 11:05:14
|
||||
@Discription:
|
||||
@Environment: python 3.7.7
|
||||
'''
|
||||
@@ -32,7 +32,7 @@ if not os.path.exists(RESULT_PATH):
|
||||
|
||||
class DoubleDQNConfig:
|
||||
def __init__(self):
|
||||
self.algo = "Double DQN" # 算法名称
|
||||
self.algo = "Double DQN" # name of algo
|
||||
self.gamma = 0.99
|
||||
self.epsilon_start = 0.9 # e-greedy策略的初始epsilon
|
||||
self.epsilon_end = 0.01
|
||||
@@ -40,7 +40,7 @@ class DoubleDQNConfig:
|
||||
self.lr = 0.01 # 学习率
|
||||
self.memory_capacity = 10000 # Replay Memory容量
|
||||
self.batch_size = 128
|
||||
self.train_eps = 250 # 训练的episode数目
|
||||
self.train_eps = 300 # 训练的episode数目
|
||||
self.train_steps = 200 # 训练每个episode的最大长度
|
||||
self.target_update = 2 # target net的更新频率
|
||||
self.eval_eps = 20 # 测试的episode数目
|
||||
@@ -84,9 +84,9 @@ if __name__ == "__main__":
|
||||
cfg = DoubleDQNConfig()
|
||||
env = gym.make('CartPole-v0').unwrapped # 可google为什么unwrapped gym,此处一般不需要
|
||||
env.seed(1) # 设置env随机种子
|
||||
n_states = env.observation_space.shape[0]
|
||||
n_actions = env.action_space.n
|
||||
agent = DoubleDQN(n_states,n_actions,cfg)
|
||||
state_dim = env.observation_space.shape[0]
|
||||
action_dim = env.action_space.n
|
||||
agent = DoubleDQN(state_dim,action_dim,cfg)
|
||||
rewards,ma_rewards = train(cfg,env,agent)
|
||||
agent.save(path=SAVED_MODEL_PATH)
|
||||
save_results(rewards,ma_rewards,tag='train',path=RESULT_PATH)
|
||||
|
||||
BIN
codes/DoubleDQN/results/20210328-110516/ma_rewards_train.npy
Normal file
BIN
codes/DoubleDQN/results/20210328-110516/rewards_curve_train.png
Normal file
|
After Width: | Height: | Size: 55 KiB |