Files
easy-rl/codes/DQN
JohnJim0816 747f3238c0 update
2021-05-04 15:30:01 +08:00
..
2021-03-23 19:42:33 +08:00
2021-05-03 23:00:01 +08:00
2021-05-03 23:00:01 +08:00
2021-03-31 15:37:09 +08:00
2021-05-03 23:00:01 +08:00
2021-05-04 15:30:01 +08:00

DQN

#TODO

原理简介

DQN是Q-leanning算法的优化和延伸Q-leaning中使用有限的Q表存储值的信息而DQN中则用神经网络替代Q表存储信息这样更适用于高维的情况相关知识基础可参考datawhale李宏毅笔记-Q学习

论文方面主要可以参考两篇一篇就是2013年谷歌DeepMind团队的Playing Atari with Deep Reinforcement Learning一篇是也是他们团队后来在Nature杂志上发表的Human-level control through deep reinforcement learning。后者在算法层面增加target q-net也可以叫做Nature DQN。

Nature DQN使用了两个Q网络一个当前Q网络𝑄用来选择动作更新模型参数另一个目标Q网络𝑄用于计算目标Q值。目标Q网络的网络参数不需要迭代更新而是每隔一段时间从当前Q网络𝑄复制过来即延时更新这样可以减少目标Q值和当前的Q值相关性。

要注意的是两个Q网络的结构是一模一样的。这样才可以复制网络参数。Nature DQN和Playing Atari with Deep Reinforcement Learning相比除了用一个新的相同结构的目标Q网络来计算目标Q值以外其余部分基本是完全相同的。细节也可参考强化学习Deep Q-Learning进阶之Nature DQN

https://blog.csdn.net/JohnJim0/article/details/109557173)

伪代码

img

代码实战

RL接口

首先是强化学习训练的基本接口,即通用的训练模式:

for i_episode in range(MAX_EPISODES):
	state = env.reset() # reset环境状态
	for i_step in range(MAX_STEPS):
		 action = agent.choose_action(state) # 根据当前环境state选择action
         next_state, reward, done, _ = env.step(action) # 更新环境参数
         agent.memory.push(state, action, reward, next_state, done) # 将state等这些transition存入memory
         agent.update() # 每步更新网络
         state = next_state # 跳转到下一个状态
         if done:
         	break        

如上首先需要循环多个episode训练在每个episode中首先需要重置环境然后开始探索每个episode加一个MAX_STEPS(也可以使用while not done, 加这个max_steps有时是因为比如gym环境训练目标就是在200个step下达到200的reward),接下来的流程如下:

  1. agent选择动作
  2. 环境根据agent的动作反馈出新的state和reward
  3. agent进行更新如有memory就会将transition(包含staterewardaction等)存入memory中
  4. 跳转到下一个状态 如果提前done了就跳出for循环进行下一个episode的训练。

两个Q网络

前面讲了Nature DQN中有两个Q网络一个是policy_net一个是延时更新的target_net两个网络的结构是一模一样的如下(见model.py)

import torch.nn as nn
import torch.nn.functional as F

class FCN(nn.Module):
    def __init__(self, state_dim=4, action_dim=18):
        """ 初始化q网络为全连接网络
            state_dim: 输入的feature即环境的state数目
            action_dim: 输出的action总个数
        """
        super(FCN, self).__init__()
        self.fc1 = nn.Linear(state_dim, 128) # 输入层
        self.fc2 = nn.Linear(128, 128) # 隐藏层
        self.fc3 = nn.Linear(128, action_dim) # 输出层
        
    def forward(self, x):
        # 各层对应的激活函数
        x = F.relu(self.fc1(x)) 
        x = F.relu(self.fc2(x))
        return self.fc3(x)

输入为state输出为action注意根据state和action的维度调整隐藏层的层数这里设为128

agent.py中我们定义强化学习算法,包括choose_actionupdate两个主要函数,初始化中:

self.policy_net = FCN(state_dim, action_dim).to(self.device)
self.target_net = FCN(state_dim, action_dim).to(self.device)
# target_net的初始模型参数完全复制policy_net
self.target_net.load_state_dict(self.policy_net.state_dict())
self.target_net.eval()  # 不启用 BatchNormalization 和 Dropout
# 可查parameters()与state_dict()的区别前者require_grad=True

可以看到policy_net跟target_net结构和初始参数一样但在更新的时候target是每隔一段episode更新的如下(见main.py)

# 更新target network复制DQN中的所有weights and biases
if i_episode % cfg.target_update == 0:
	agent.target_net.load_state_dict(agent.policy_net.state_dict())

可以调整cfg.target_update注意该变量不要调得太大否则会收敛很慢我们最后保存的模型也是这个target_net如下(见agent.py)

def save_model(self,path):
	torch.save(self.target_net.state_dict(), path)

Replay Memory

然后就是Replay Memory了如下(见memory.py)

import random
import numpy as np

class ReplayBuffer:
    
    def __init__(self, capacity):
        self.capacity = capacity
        self.buffer = []
        self.position = 0
    
    def push(self, state, action, reward, next_state, done):
        if len(self.buffer) < self.capacity:
            self.buffer.append(None)
        self.buffer[self.position] = (state, action, reward, next_state, done)
        self.position = (self.position + 1) % self.capacity
    
    def sample(self, batch_size):
        batch = random.sample(self.buffer, batch_size)
        state, action, reward, next_state, done =  zip(*batch)
        return state, action, reward, next_state, done
    
    def __len__(self):
        return len(self.buffer)

其实比较简单主要包括push和sample两个步骤push是将transitions放到memory中sample是从memory随机抽取一些transition。

最后结果如下:

rewards_curve_train

参考

with torch.no_grad()