更新蘑菇书附书代码

This commit is contained in:
johnjim0816
2022-12-04 20:54:36 +08:00
parent f030fe283d
commit dc8d13a13e
23 changed files with 10784 additions and 0 deletions

View File

@@ -0,0 +1,202 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 1. 定义算法\n",
"\n",
"最基础的策略梯度算法就是REINFORCE算法又称作Monte-Carlo Policy Gradient算法。我们策略优化的目标如下\n",
"\n",
"$$\n",
"J_{\\theta}= \\Psi_{\\pi} \\nabla_\\theta \\log \\pi_\\theta\\left(a_t \\mid s_t\\right)\n",
"$$\n",
"\n",
"其中$\\Psi_{\\pi}$在REINFORCE算法中表示衰减的回报(具体公式见伪代码)也可以用优势来估计也就是我们熟知的A3C算法这个在后面包括GAE算法中都会讲到。\n",
"\n",
"### 1.1. 策略函数设计\n",
"\n",
"既然策略梯度是直接对策略函数进行梯度计算那么策略函数如何设计呢一般来讲有两种设计方式一个是softmax函数另外一个是高斯分布$\\mathbb{N}\\left(\\phi(\\mathbb{s})^{\\mathbb{\\pi}} \\theta, \\sigma^2\\right)$,前者用于离散动作空间,后者多用于连续动作空间。\n",
"\n",
"softmax函数可以表示为\n",
"$$\n",
"\\pi_\\theta(s, a)=\\frac{e^{\\phi(s, a)^{T_\\theta}}}{\\sum_b e^{\\phi(s, b)^{T^T}}}\n",
"$$\n",
"对应的梯度为:\n",
"$$\n",
"\\nabla_\\theta \\log \\pi_\\theta(s, a)=\\phi(s, a)-\\mathbb{E}_{\\pi_\\theta}[\\phi(s,)\n",
"$$\n",
"高斯分布对应的梯度为:\n",
"$$\n",
"\\nabla_\\theta \\log \\pi_\\theta(s, a)=\\frac{\\left(a-\\phi(s)^T \\theta\\right) \\phi(s)}{\\sigma^2}\n",
"$$\n",
"但是对于一些特殊的情况,例如在本次演示中动作维度=2且为离散空间这个时候可以用伯努利分布来实现这种方式其实是不推荐的这里给大家做演示也是为了展现一些特殊情况启发大家一些思考例如BernoulliBinomialGaussian分布之间的关系。简单说来Binomial分布$n = 1$时就是Bernoulli分布$n \\rightarrow \\infty$时就是Gaussian分布。\n",
"\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 1.2. 模型设计\n",
"\n",
"前面讲到尽管本次演示是离散空间但是由于动作维度等于2此时就可以用特殊的高斯分布来表示策略函数即伯努利分布。伯努利的分布实际上是用一个概率作为输入然后从中采样动作伯努利采样出来的动作只可能是0或1就像投掷出硬币的正反面。在这种情况下我们的策略模型就需要在MLP的基础上将状态作为输入将动作作为倒数第二层输出并在最后一层增加激活函数来输出对应动作的概率。不清楚激活函数作用的同学可以再看一遍深度学习相关的知识简单来说其作用就是增加神经网络的非线性。既然需要输出对应动作的概率那么输出的值需要处于0-1之间此时sigmoid函数刚好满足我们的需求实现代码参考如下。"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"class PGNet(nn.Module):\n",
" def __init__(self, input_dim,output_dim,hidden_dim=128):\n",
" \"\"\" 初始化q网络为全连接网络\n",
" input_dim: 输入的特征数即环境的状态维度\n",
" output_dim: 输出的动作维度\n",
" \"\"\"\n",
" super(PGNet, self).__init__()\n",
" self.fc1 = nn.Linear(input_dim, hidden_dim) # 输入层\n",
" self.fc2 = nn.Linear(hidden_dim,hidden_dim) # 隐藏层\n",
" self.fc3 = nn.Linear(hidden_dim, output_dim) # 输出层\n",
" def forward(self, x):\n",
" x = F.relu(self.fc1(x))\n",
" x = F.relu(self.fc2(x))\n",
" x = torch.sigmoid(self.fc3(x))\n",
" return x"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 1.3. 更新函数设计\n",
"\n",
"前面提到我们的优化目标也就是策略梯度算法的损失函数如下:\n",
"$$\n",
"J_{\\theta}= \\Psi_{\\pi} \\nabla_\\theta \\log \\pi_\\theta\\left(a_t \\mid s_t\\right)\n",
"$$\n",
"\n",
"我们需要拆开成两个部分$\\Psi_{\\pi}$和$\\nabla_\\theta \\log \\pi_\\theta\\left(a_t \\mid s_t\\right)$分开计算,首先看值函数部分$\\Psi_{\\pi}$在REINFORCE算法中值函数是从当前时刻开始的衰减回报如下\n",
"$$\n",
"G \\leftarrow \\sum_{k=t+1}^{T} \\gamma^{k-1} r_{k}\n",
"$$\n",
"\n",
"这个实际用代码来实现的时候可能有点绕,我们可以倒过来看,在同一回合下,我们的终止时刻是$T$,那么对应的回报$G_T=\\gamma^{T-1}r_T$,而对应的$G_{T-1}=\\gamma^{T-2}r_{T-1}+\\gamma^{T-1}r_T$,在这里代码中我们使用了一个动态规划的技巧,如下:\n",
"```python\n",
"running_add = running_add * self.gamma + reward_pool[i] # running_add初始值为0\n",
"```\n",
"这个公式也是倒过来循环的,第一次的值等于:\n",
"$$\n",
"running\\_add = r_T\n",
"$$\n",
"第二次的值则等于:\n",
"$$\n",
"running\\_add = r_T*\\gamma+r_{T-1}\n",
"$$\n",
"第三次的值等于:\n",
"$$\n",
"running\\_add = (r_T*\\gamma+r_{T-1})*\\gamma+r_{T-2} = r_T*\\gamma^2+r_{T-1}*\\gamma+r_{T-2}\n",
"$$\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torch.distributions import Bernoulli\n",
"from torch.autograd import Variable\n",
"import numpy as np\n",
"\n",
"class PolicyGradient:\n",
" \n",
" def __init__(self, model,memory,cfg):\n",
" self.gamma = cfg['gamma']\n",
" self.device = torch.device(cfg['device']) \n",
" self.memory = memory\n",
" self.policy_net = model.to(self.device)\n",
" self.optimizer = torch.optim.RMSprop(self.policy_net.parameters(), lr=cfg['lr'])\n",
"\n",
" def sample_action(self,state):\n",
"\n",
" state = torch.from_numpy(state).float()\n",
" state = Variable(state)\n",
" probs = self.policy_net(state)\n",
" m = Bernoulli(probs) # 伯努利分布\n",
" action = m.sample()\n",
" \n",
" action = action.data.numpy().astype(int)[0] # 转为标量\n",
" return action\n",
" def predict_action(self,state):\n",
"\n",
" state = torch.from_numpy(state).float()\n",
" state = Variable(state)\n",
" probs = self.policy_net(state)\n",
" m = Bernoulli(probs) # 伯努利分布\n",
" action = m.sample()\n",
" action = action.data.numpy().astype(int)[0] # 转为标量\n",
" return action\n",
" \n",
" def update(self):\n",
" state_pool,action_pool,reward_pool= self.memory.sample()\n",
" state_pool,action_pool,reward_pool = list(state_pool),list(action_pool),list(reward_pool)\n",
" # Discount reward\n",
" running_add = 0\n",
" for i in reversed(range(len(reward_pool))):\n",
" if reward_pool[i] == 0:\n",
" running_add = 0\n",
" else:\n",
" running_add = running_add * self.gamma + reward_pool[i]\n",
" reward_pool[i] = running_add\n",
"\n",
" # Normalize reward\n",
" reward_mean = np.mean(reward_pool)\n",
" reward_std = np.std(reward_pool)\n",
" for i in range(len(reward_pool)):\n",
" reward_pool[i] = (reward_pool[i] - reward_mean) / reward_std\n",
"\n",
" # Gradient Desent\n",
" self.optimizer.zero_grad()\n",
"\n",
" for i in range(len(reward_pool)):\n",
" state = state_pool[i]\n",
" action = Variable(torch.FloatTensor([action_pool[i]]))\n",
" reward = reward_pool[i]\n",
" state = Variable(torch.from_numpy(state).float())\n",
" probs = self.policy_net(state)\n",
" m = Bernoulli(probs)\n",
" loss = -m.log_prob(action) * reward # Negtive score function x reward\n",
" # print(loss)\n",
" loss.backward()\n",
" self.optimizer.step()\n",
" self.memory.clear()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.7.13 ('easyrl')",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.7.13"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "8994a120d39b6e6a2ecc94b4007f5314b68aa69fc88a7f00edf21be39b41f49c"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}