Files
easy-rl/projects/parl_tutorials/DQN.ipynb
2022-11-06 12:15:36 +08:00

539 lines
93 KiB
Plaintext
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 1、定义算法\n",
"相比于Q learning，DQN本质上是为了适应更为复杂的环境，并且经过不断的改良迭代，到了Nature DQN（即Volodymyr Mnih发表的Nature论文）这里才算是基本完善。DQN主要改动的点有三个：\n",
"* 使用深度神经网络替代原来的Q表：这个很容易理解原因。\n",
"* 使用了经验回放（Replay Buffer）：这个好处有很多，一个是使用一堆历史数据去训练，比之前用一次就扔掉好多了，大大提高样本效率；另外一个是面试常提到的，减少样本之间的相关性。原则上获取经验跟学习阶段是分开的，原来时序的训练数据有可能是不稳定的，打乱之后再学习有助于提高训练的稳定性，跟深度学习中划分训练测试集时打乱样本是一个道理。\n",
"* 使用了两个网络：即策略网络和目标网络，每隔若干步才把每步更新的策略网络参数复制给目标网络，这样做也是为了训练的稳定，避免Q值的估计发散。想象一下，如果当前有个transition（这个Q learning中提过的，一定要记住）样本导致对Q值进行了较差的过估计，如果接下来从经验回放中提取到的样本正好连续几个都这样的，很有可能导致Q值的发散（它的青春小鸟一去不回来了）。再打个比方，我们玩RPG或者闯关类游戏，有些人为了破纪录经常Save和Load，只要我出了错，我不满意，我就加载之前的存档。假设不允许加载呢？就像DQN算法一样，训练过程中回退不了，这时候是不是搞两个档，一个档每帧都存一下，另外一个档打了不错的结果再存，也就是若干个间隔再存一下，到最后用间隔若干步数再存的档一般都比每帧都存的档好些呢？当然你也可以再搞更多个档，也就是DQN增加多个目标网络，但是对于DQN则没有多大必要，多几个网络效果不见得会好很多。"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 1.1 定义模型"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[32m[09-26 17:18:11 MainThread @utils.py:73]\u001b[0m paddlepaddle version: 2.3.2.\n"
]
}
],
"source": [
"\n",
"import paddle\n",
"import paddle.nn as nn\n",
"import paddle.nn.functional as F\n",
"import parl\n",
"\n",
"class MLP(parl.Model):\n",
"    \"\"\"Fully connected network with two 256-unit ReLU hidden layers.\n",
"\n",
"    Args:\n",
"        input_dim (int): Size of the observation vector fed to the first layer.\n",
"        output_dim (int): Number of outputs of the last layer (one per action).\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, input_dim, output_dim):\n",
"        super().__init__()\n",
"        first_width, second_width = 256, 256\n",
"        self.fc1 = nn.Linear(input_dim, first_width)\n",
"        self.fc2 = nn.Linear(first_width, second_width)\n",
"        self.fc3 = nn.Linear(second_width, output_dim)\n",
"\n",
"    def forward(self, state):\n",
"        \"\"\"Return the output of the last linear layer for `state` (no final activation).\"\"\"\n",
"        hidden = F.relu(self.fc1(state))\n",
"        hidden = F.relu(self.fc2(hidden))\n",
"        return self.fc3(hidden)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 1.2 定义经验回放"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from collections import deque\n",
"class ReplayBuffer:\n",
"    \"\"\"Bounded FIFO store of transitions used for experience replay.\n",
"\n",
"    Args:\n",
"        capacity (int): Maximum number of stored transitions; once reached,\n",
"            appending a new one discards the oldest (deque `maxlen` behaviour).\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, capacity: int) -> None:\n",
"        self.capacity = capacity\n",
"        self.buffer = deque(maxlen=self.capacity)\n",
"\n",
"    def push(self, transitions):\n",
"        \"\"\"Append one transition to the buffer.\n",
"\n",
"        Args:\n",
"            transitions (tuple): One record; all records should have the same\n",
"                number of fields because `sample` regroups them field by field.\n",
"        \"\"\"\n",
"        self.buffer.append(transitions)\n",
"\n",
"    def sample(self, batch_size: int, sequential: bool = False):\n",
"        \"\"\"Draw a batch and return it regrouped by field.\n",
"\n",
"        Args:\n",
"            batch_size (int): Requested number of transitions; clamped to the\n",
"                current buffer length.\n",
"            sequential (bool): If True, take one contiguous run starting at a\n",
"                random offset; otherwise sample without replacement.\n",
"\n",
"        Returns:\n",
"            A `zip` iterator yielding one tuple per transition field.\n",
"        \"\"\"\n",
"        # Imported locally: this cell never imports `random`, and another cell\n",
"        # binds the global name to a function via `from random import random`.\n",
"        import random\n",
"\n",
"        batch_size = min(batch_size, len(self.buffer))\n",
"        if sequential:\n",
"            start = random.randint(0, len(self.buffer) - batch_size)\n",
"            batch = [self.buffer[i] for i in range(start, start + batch_size)]\n",
"        else:\n",
"            batch = random.sample(self.buffer, batch_size)\n",
"        return zip(*batch)\n",
"\n",
"    def clear(self):\n",
"        \"\"\"Remove every stored transition.\"\"\"\n",
"        self.buffer.clear()\n",
"\n",
"    def __len__(self):\n",
"        \"\"\"Number of transitions currently stored.\"\"\"\n",
"        return len(self.buffer)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 1.3 定义智能体"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"from random import random\n",
"import parl\n",
"import paddle\n",
"import math\n",
"import numpy as np\n",
"\n",
"\n",
"class DQNAgent(parl.Agent):\n",
"    \"\"\"Epsilon-greedy DQN agent that learns from a replay buffer.\n",
"\n",
"    Args:\n",
"        algorithm: Object handed to `parl.Agent`; used here through `self.alg`\n",
"            (`predict`, `learn` and `sync_target`).\n",
"        memory: Replay buffer providing `__len__` and `sample(batch_size)`.\n",
"        cfg: Mapping with the keys 'n_actions', 'epsilon_start', 'epsilon_end',\n",
"            'epsilon_decay' and 'batch_size'.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, algorithm, memory, cfg):\n",
"        super(DQNAgent, self).__init__(algorithm)\n",
"        self.n_actions = cfg['n_actions']\n",
"        self.epsilon = cfg['epsilon_start']\n",
"        self.sample_count = 0  # number of sample_action() calls so far\n",
"        self.epsilon_start = cfg['epsilon_start']\n",
"        self.epsilon_end = cfg['epsilon_end']\n",
"        self.epsilon_decay = cfg['epsilon_decay']\n",
"        self.batch_size = cfg['batch_size']\n",
"        self.global_step = 0  # number of learning steps performed\n",
"        # NOTE(review): fixed at 600; the 'targe_update_fre' entry produced by\n",
"        # get_args() is not read here - confirm which value is intended.\n",
"        self.update_target_steps = 600\n",
"        self.memory = memory  # replay buffer\n",
"\n",
"    def sample_action(self, state):\n",
"        \"\"\"Pick an action epsilon-greedily and decay epsilon.\n",
"\n",
"        Epsilon decays exponentially from `epsilon_start` towards\n",
"        `epsilon_end` as `sample_count` grows.\n",
"        \"\"\"\n",
"        # Imported locally: `from random import random` at the top of this cell\n",
"        # binds `random` to a function, on which `random.random()` would fail.\n",
"        import random\n",
"\n",
"        self.sample_count += 1\n",
"        decay = math.exp(-1. * self.sample_count / self.epsilon_decay)\n",
"        self.epsilon = self.epsilon_end + (self.epsilon_start - self.epsilon_end) * decay\n",
"        if random.random() < self.epsilon:\n",
"            return np.random.randint(self.n_actions)\n",
"        return self.predict_action(state)\n",
"\n",
"    def predict_action(self, state):\n",
"        \"\"\"Return the index of the largest value predicted by `self.alg` for `state`.\"\"\"\n",
"        state = paddle.to_tensor(state, dtype='float32')\n",
"        q_values = self.alg.predict(state)\n",
"        # flatten() makes the lookup work for a 0-D as well as a one-element result.\n",
"        return q_values.argmax().numpy().flatten()[0]\n",
"\n",
"    def update(self):\n",
"        \"\"\"Run one learning step on a batch sampled from the replay buffer.\n",
"\n",
"        `self.alg.sync_target()` is called before the first learning step and\n",
"        then again every `update_target_steps` learning steps.\n",
"\n",
"        Returns:\n",
"            Whatever `self.alg.learn(...)` returns, or None when the buffer holds\n",
"            fewer than `batch_size` transitions and no learning step is run.\n",
"        \"\"\"\n",
"        if len(self.memory) < self.batch_size:\n",
"            return None\n",
"\n",
"        if self.global_step % self.update_target_steps == 0:\n",
"            self.alg.sync_target()\n",
"        self.global_step += 1\n",
"\n",
"        state_batch, action_batch, reward_batch, next_state_batch, done_batch = \\\n",
"            self.memory.sample(self.batch_size)\n",
"        # Add a trailing axis to the per-transition action/reward/done values.\n",
"        action_batch = np.expand_dims(action_batch, axis=-1)\n",
"        reward_batch = np.expand_dims(reward_batch, axis=-1)\n",
"        done_batch = np.expand_dims(done_batch, axis=-1)\n",
"\n",
"        state_batch = paddle.to_tensor(state_batch, dtype='float32')\n",
"        action_batch = paddle.to_tensor(action_batch, dtype='int32')\n",
"        reward_batch = paddle.to_tensor(reward_batch, dtype='float32')\n",
"        next_state_batch = paddle.to_tensor(next_state_batch, dtype='float32')\n",
"        done_batch = paddle.to_tensor(done_batch, dtype='float32')\n",
"        return self.alg.learn(state_batch, action_batch, reward_batch, next_state_batch, done_batch)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 2. 定义训练"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"def train(cfg, env, agent):\n",
"    \"\"\"Train `agent` on `env` for cfg['train_eps'] episodes.\n",
"\n",
"    Every step samples an action, pushes the transition into `agent.memory`\n",
"    and calls `agent.update()`. A progress line is printed every 10 episodes\n",
"    and `env.close()` is called before returning.\n",
"\n",
"    Args:\n",
"        cfg: Mapping with 'env_name', 'algo_name', 'device', 'train_eps'\n",
"            and 'ep_max_steps'.\n",
"        env: Environment whose `reset()` returns a state and whose\n",
"            `step(action)` returns `(next_state, reward, done, info)`.\n",
"        agent: Object exposing `sample_action`, `update`, `memory.push`\n",
"            and `epsilon`.\n",
"\n",
"    Returns:\n",
"        dict: 'episodes' (a range), 'rewards' and 'steps' (one entry per episode).\n",
"    \"\"\"\n",
"    print(f\"开始训练!\")\n",
"    print(f\"环境:{cfg['env_name']},算法:{cfg['algo_name']},设备:{cfg['device']}\")\n",
"    rewards = []  # total reward of every episode\n",
"    steps = []  # number of steps of every episode\n",
"    for i_ep in range(cfg[\"train_eps\"]):\n",
"        ep_reward = 0\n",
"        ep_step = 0\n",
"        state = env.reset()\n",
"        for _ in range(cfg['ep_max_steps']):\n",
"            ep_step += 1\n",
"            action = agent.sample_action(state)\n",
"            next_state, reward, done, _info = env.step(action)\n",
"            agent.memory.push((state, action, reward, next_state, done))\n",
"            state = next_state\n",
"            agent.update()\n",
"            ep_reward += reward\n",
"            if done:\n",
"                break\n",
"        steps.append(ep_step)\n",
"        rewards.append(ep_reward)\n",
"        if (i_ep + 1) % 10 == 0:\n",
"            # Label spelled 'Epsilon' and separated from the reward field.\n",
"            print(f\"回合:{i_ep+1}/{cfg['train_eps']},奖励:{ep_reward:.2f},Epsilon: {agent.epsilon:.3f}\")\n",
"    print(\"完成训练!\")\n",
"    env.close()\n",
"    res_dic = {'episodes': range(len(rewards)), 'rewards': rewards, 'steps': steps}\n",
"    return res_dic\n",
"\n",
"def test(cfg, env, agent):\n",
"    \"\"\"Evaluate `agent` on `env` for cfg['test_eps'] episodes without learning.\n",
"\n",
"    Actions come from `agent.predict_action`; nothing is stored or updated.\n",
"    One line is printed per episode and `env.close()` is called at the end.\n",
"\n",
"    Args:\n",
"        cfg: Mapping with 'env_name', 'algo_name', 'device', 'test_eps'\n",
"            and 'ep_max_steps'.\n",
"        env: Environment whose `reset()` returns a state and whose\n",
"            `step(action)` returns `(next_state, reward, done, info)`.\n",
"        agent: Object exposing `predict_action(state)`.\n",
"\n",
"    Returns:\n",
"        dict: 'episodes' (a range), 'rewards' and 'steps' (one entry per episode).\n",
"    \"\"\"\n",
"    print(\"开始测试!\")\n",
"    print(f\"环境:{cfg['env_name']},算法:{cfg['algo_name']},设备:{cfg['device']}\")\n",
"    episode_rewards, episode_steps = [], []\n",
"    for episode in range(cfg['test_eps']):\n",
"        total_reward, n_steps = 0, 0\n",
"        obs = env.reset()\n",
"        step_limit = cfg['ep_max_steps']\n",
"        while n_steps < step_limit:\n",
"            n_steps += 1\n",
"            obs, reward, done, _info = env.step(agent.predict_action(obs))\n",
"            total_reward += reward\n",
"            if done:\n",
"                break\n",
"        episode_steps.append(n_steps)\n",
"        episode_rewards.append(total_reward)\n",
"        print(f\"回合:{episode+1}/{cfg['test_eps']},奖励:{total_reward:.2f}\")\n",
"    print(\"完成测试!\")\n",
"    env.close()\n",
"    return {'episodes': range(len(episode_rewards)), 'rewards': episode_rewards, 'steps': episode_steps}\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 3. 定义环境"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/jj/opt/anaconda3/envs/easyrl/lib/python3.7/site-packages/gym/envs/registration.py:250: DeprecationWarning: SelectableGroups dict interface is deprecated. Use select.\n",
" for plugin in metadata.entry_points().get(entry_point, []):\n"
]
}
],
"source": [
"import gym\n",
"import paddle\n",
"import numpy as np\n",
"import random\n",
"import os\n",
"from parl.algorithms import DQN\n",
"def all_seed(env, seed=1):\n",
"    \"\"\"Seed the environment, NumPy, `random` and Paddle with one value.\n",
"\n",
"    Call it right after the environment is created so that everything that\n",
"    follows draws from the seeded generators.\n",
"\n",
"    Args:\n",
"        env: Environment exposing a `seed(seed)` method.\n",
"        seed (int, optional): Value passed to every generator. Defaults to 1.\n",
"    \"\"\"\n",
"    print(f\"seed = {seed}\")\n",
"    # The four generators are independent, so the order of these calls is free.\n",
"    random.seed(seed)\n",
"    np.random.seed(seed)\n",
"    paddle.seed(seed)\n",
"    env.seed(seed)\n",
" \n",
"def env_agent_config(cfg):\n",
"    \"\"\"Create the gym environment and the DQN agent described by `cfg`.\n",
"\n",
"    Side effects: seeds everything through `all_seed` when cfg['seed'] is not 0,\n",
"    and writes 'n_states' and 'n_actions' back into `cfg`.\n",
"\n",
"    Args:\n",
"        cfg (dict): Needs 'env_name', 'seed', 'gamma', 'lr', 'memory_capacity'\n",
"            plus the keys read by `DQNAgent`.\n",
"\n",
"    Returns:\n",
"        tuple: `(env, agent)`.\n",
"    \"\"\"\n",
"    env = gym.make(cfg['env_name'])\n",
"    seed = cfg['seed']\n",
"    if seed != 0:\n",
"        all_seed(env, seed=seed)\n",
"    n_states = env.observation_space.shape[0]  # length of the observation vector\n",
"    n_actions = env.action_space.n  # number of discrete actions\n",
"    print(f\"n_states: {n_states}, n_actions: {n_actions}\")\n",
"    cfg[\"n_states\"] = n_states\n",
"    cfg[\"n_actions\"] = n_actions\n",
"    algo = DQN(MLP(n_states, n_actions), gamma=cfg['gamma'], lr=cfg['lr'])\n",
"    replay = ReplayBuffer(cfg[\"memory_capacity\"])\n",
"    return env, DQNAgent(algo, replay, cfg)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 4. 设置参数"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/jj/opt/anaconda3/envs/easyrl/lib/python3.7/site-packages/seaborn/rcmod.py:82: DeprecationWarning: distutils Version classes are deprecated. Use packaging.version instead.\n",
" if LooseVersion(mpl.__version__) >= \"3.0\":\n",
"/Users/jj/opt/anaconda3/envs/easyrl/lib/python3.7/site-packages/setuptools/_distutils/version.py:351: DeprecationWarning: distutils Version classes are deprecated. Use packaging.version instead.\n",
" other = LooseVersion(other)\n"
]
}
],
"source": [
"import argparse\n",
"import seaborn as sns\n",
"import matplotlib.pyplot as plt\n",
"def get_args():\n",
"    \"\"\"Return the default hyperparameters as a plain dict.\n",
"\n",
"    The options are declared on an `argparse` parser, but an empty argument\n",
"    list is parsed, so the result always holds the defaults listed below.\n",
"    \"\"\"\n",
"    # (flag name, default, type, help text or None when the flag has no help)\n",
"    option_table = [\n",
"        ('algo_name', 'DQN', str, \"name of algorithm\"),\n",
"        ('env_name', 'CartPole-v0', str, \"name of environment\"),\n",
"        ('train_eps', 200, int, \"episodes of training\"),\n",
"        ('test_eps', 20, int, \"episodes of testing\"),\n",
"        ('ep_max_steps', 100000, int, \"steps per episode, much larger value can simulate infinite steps\"),\n",
"        ('gamma', 0.99, float, \"discounted factor\"),\n",
"        ('epsilon_start', 0.95, float, \"initial value of epsilon\"),\n",
"        ('epsilon_end', 0.01, float, \"final value of epsilon\"),\n",
"        ('epsilon_decay', 200, int, \"decay rate of epsilon\"),\n",
"        ('memory_capacity', 200000, int, None),  # capacity of the replay memory\n",
"        ('memory_warmup_size', 200, int, None),  # warm-up size of the replay memory\n",
"        ('batch_size', 64, int, \"batch size of training\"),\n",
"        ('targe_update_fre', 200, int, \"frequency of target network update\"),\n",
"        ('seed', 10, int, \"seed\"),\n",
"        ('lr', 0.0001, float, \"learning rate\"),\n",
"        ('device', 'cpu', str, \"cpu or gpu\"),\n",
"    ]\n",
"    parser = argparse.ArgumentParser(description=\"hyperparameters\")\n",
"    for flag, default, kind, help_text in option_table:\n",
"        extra = {} if help_text is None else {'help': help_text}\n",
"        parser.add_argument(f'--{flag}', default=default, type=kind, **extra)\n",
"    namespace = parser.parse_args([])\n",
"    return dict(vars(namespace))\n",
"def smooth(data, weight=0.9):\n",
"    \"\"\"Exponentially smooth a sequence, similar to TensorBoard's smoothing slider.\n",
"\n",
"    Args:\n",
"        data (list): Values to smooth; must support `len()` and indexing.\n",
"        weight (float): Smoothing weight between 0 and 1; larger means\n",
"            smoother. Defaults to 0.9.\n",
"\n",
"    Returns:\n",
"        list: Smoothed values, same length as `data` (empty for empty input).\n",
"    \"\"\"\n",
"    # An empty curve has no first point to start from; return it unchanged in size.\n",
"    if len(data) == 0:\n",
"        return []\n",
"    last = data[0]  # the running average starts at the first value\n",
"    smoothed = []\n",
"    for point in data:\n",
"        last = last * weight + (1 - weight) * point\n",
"        smoothed.append(last)\n",
"    return smoothed\n",
"\n",
"def plot_rewards(rewards, cfg, path=None, tag='train'):\n",
"    \"\"\"Plot the per-episode rewards together with their smoothed curve.\n",
"\n",
"    Args:\n",
"        rewards (list): One reward per episode.\n",
"        cfg (dict): Needs 'device', 'algo_name' and 'env_name' for the title.\n",
"        path: Not used by this function; kept so existing calls keep working.\n",
"        tag (str): Prefix of the title, e.g. 'train' or 'test'.\n",
"    \"\"\"\n",
"    sns.set()\n",
"    plt.figure()  # new figure, so several curves can be drawn one after another\n",
"    plt.title(f\"{tag}ing curve on {cfg['device']} of {cfg['algo_name']} for {cfg['env_name']}\")\n",
"    plt.xlabel('episodes')  # axis label spelling fixed (was 'epsiodes')\n",
"    plt.plot(rewards, label='rewards')\n",
"    plt.plot(smooth(rewards), label='smoothed')\n",
"    plt.legend()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 5. 收获成果!"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"seed = 10\n",
"n_states: 4, n_actions: 2\n",
"开始训练!\n",
"环境CartPole-v0算法DQN设备cpu\n",
"回合10/200奖励10.00Epislon: 0.062\n",
"回合20/200奖励85.00Epislon: 0.014\n",
"回合30/200奖励41.00Epislon: 0.011\n",
"回合40/200奖励31.00Epislon: 0.010\n",
"回合50/200奖励22.00Epislon: 0.010\n",
"回合60/200奖励10.00Epislon: 0.010\n",
"回合70/200奖励10.00Epislon: 0.010\n",
"回合80/200奖励22.00Epislon: 0.010\n",
"回合90/200奖励30.00Epislon: 0.010\n",
"回合100/200奖励20.00Epislon: 0.010\n",
"回合110/200奖励15.00Epislon: 0.010\n",
"回合120/200奖励45.00Epislon: 0.010\n",
"回合130/200奖励73.00Epislon: 0.010\n",
"回合140/200奖励180.00Epislon: 0.010\n",
"回合150/200奖励163.00Epislon: 0.010\n",
"回合160/200奖励191.00Epislon: 0.010\n",
"回合170/200奖励200.00Epislon: 0.010\n",
"回合180/200奖励200.00Epislon: 0.010\n",
"回合190/200奖励200.00Epislon: 0.010\n",
"回合200/200奖励200.00Epislon: 0.010\n",
"完成训练!\n",
"开始测试!\n",
"环境CartPole-v0算法DQN设备cpu\n",
"回合1/20奖励200.00\n",
"回合2/20奖励200.00\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/jj/opt/anaconda3/envs/easyrl/lib/python3.7/site-packages/seaborn/rcmod.py:400: DeprecationWarning: distutils Version classes are deprecated. Use packaging.version instead.\n",
" if LooseVersion(mpl.__version__) >= \"3.0\":\n",
"/Users/jj/opt/anaconda3/envs/easyrl/lib/python3.7/site-packages/setuptools/_distutils/version.py:351: DeprecationWarning: distutils Version classes are deprecated. Use packaging.version instead.\n",
" other = LooseVersion(other)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"回合3/20奖励200.00\n",
"回合4/20奖励200.00\n",
"回合5/20奖励200.00\n",
"回合6/20奖励200.00\n",
"回合7/20奖励200.00\n",
"回合8/20奖励193.00\n",
"回合9/20奖励200.00\n",
"回合10/20奖励200.00\n",
"回合11/20奖励200.00\n",
"回合12/20奖励200.00\n",
"回合13/20奖励200.00\n",
"回合14/20奖励194.00\n",
"回合15/20奖励200.00\n",
"回合16/20奖励200.00\n",
"回合17/20奖励200.00\n",
"回合18/20奖励200.00\n",
"回合19/20奖励199.00\n",
"回合20/20奖励200.00\n",
"完成测试!\n"
]
},
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
},
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Load the default hyperparameters\n",
"cfg = get_args() \n",
"# Build the environment and the agent, then train\n",
"env, agent = env_agent_config(cfg)\n",
"res_dic = train(cfg, env, agent)\n",
" \n",
"plot_rewards(res_dic['rewards'], cfg, tag=\"train\") \n",
"# Evaluate the trained agent. NOTE(review): train() has already called env.close() on this env; res_dic is overwritten with the test results\n",
"res_dic = test(cfg, env, agent)\n",
"plot_rewards(res_dic['rewards'], cfg, tag=\"test\") # plot the results"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.7.13 ('easyrl')",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.13"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "8994a120d39b6e6a2ecc94b4007f5314b68aa69fc88a7f00edf21be39b41f49c"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}