{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## 定义模型\n", "\n" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "import paddle\n", "import paddle.nn as nn\n", "import paddle.nn.functional as F\n", "import parl\n", "\n", "class CartpoleModel(parl.Model):\n", " \"\"\" Linear network to solve Cartpole problem.\n", " Args:\n", " n_states (int): Dimension of observation space.\n", " n_actions (int): Dimension of action space.\n", " \"\"\"\n", "\n", " def __init__(self, n_states, n_actions):\n", " super(CartpoleModel, self).__init__()\n", " hid1_size = 128\n", " hid2_size = 128\n", " self.fc1 = nn.Linear(n_states, hid1_size)\n", " self.fc2 = nn.Linear(hid1_size, hid2_size)\n", " self.fc3 = nn.Linear(hid2_size, n_actions)\n", "\n", " def forward(self, obs):\n", " h1 = F.relu(self.fc1(obs))\n", " h2 = F.relu(self.fc2(h1))\n", " Q = self.fc3(h2)\n", " return Q" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "import parl\n", "import paddle\n", "import numpy as np\n", "\n", "\n", "class CartpoleAgent(parl.Agent):\n", " \"\"\"Agent of Cartpole env.\n", " Args:\n", " algorithm(parl.Algorithm): algorithm used to solve the problem.\n", " \"\"\"\n", "\n", " def __init__(self, algorithm, n_actions, e_greed=0.1, e_greed_decrement=0):\n", " super(CartpoleAgent, self).__init__(algorithm)\n", " assert isinstance(n_actions, int)\n", " self.n_actions = n_actions\n", "\n", " self.global_step = 0\n", " self.update_target_steps = 200\n", "\n", " self.e_greed = e_greed\n", " self.e_greed_decrement = e_greed_decrement\n", "\n", " def sample(self, obs):\n", " \"\"\"Sample an action `for exploration` when given an observation\n", " Args:\n", " obs(np.float32): shape of (n_states,)\n", " Returns:\n", " act(int): action\n", " \"\"\"\n", " sample = np.random.random()\n", " if sample < self.e_greed:\n", " act = np.random.randint(self.n_actions)\n", " else:\n", " if np.random.random() < 0.01:\n", " act = np.random.randint(self.n_actions)\n", " else:\n", " act = self.predict(obs)\n", " self.e_greed = max(0.01, self.e_greed - self.e_greed_decrement)\n", " return act\n", "\n", " def predict(self, obs):\n", " \"\"\"Predict an action when given an observation\n", " Args:\n", " obs(np.float32): shape of (n_states,)\n", " Returns:\n", " act(int): action\n", " \"\"\"\n", " obs = paddle.to_tensor(obs, dtype='float32')\n", " pred_q = self.alg.predict(obs)\n", " act = pred_q.argmax().numpy()[0]\n", " return act\n", "\n", " def learn(self, obs, act, reward, next_obs, terminal):\n", " \"\"\"Update model with an episode data\n", " Args:\n", " obs(np.float32): shape of (batch_size, n_states)\n", " act(np.int32): shape of (batch_size)\n", " reward(np.float32): shape of (batch_size)\n", " next_obs(np.float32): shape of (batch_size, n_states)\n", " terminal(np.float32): shape of (batch_size)\n", " Returns:\n", " loss(float)\n", " \"\"\"\n", " if self.global_step % self.update_target_steps == 0:\n", " self.alg.sync_target()\n", " self.global_step += 1\n", "\n", " act = np.expand_dims(act, axis=-1)\n", " reward = np.expand_dims(reward, axis=-1)\n", " terminal = np.expand_dims(terminal, axis=-1)\n", "\n", " obs = paddle.to_tensor(obs, dtype='float32')\n", " act = paddle.to_tensor(act, dtype='int32')\n", " reward = paddle.to_tensor(reward, dtype='float32')\n", " next_obs = paddle.to_tensor(next_obs, dtype='float32')\n", " terminal = paddle.to_tensor(terminal, dtype='float32')\n", " loss = self.alg.learn(obs, act, reward, next_obs, 
terminal)\n", " return loss.numpy()[0]" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": [ "import os\n", "import gym\n", "import numpy as np\n", "import parl\n", "\n", "from parl.utils import logger, ReplayMemory\n", "from parl.algorithms import DQN\n", "\n", "LEARN_FREQ = 5 # training frequency\n", "MEMORY_SIZE = 200000\n", "MEMORY_WARMUP_SIZE = 200\n", "BATCH_SIZE = 64\n", "LEARNING_RATE = 0.0005\n", "GAMMA = 0.99\n", "\n", "# train an episode\n", "def run_train_episode(agent, env, rpm):\n", " total_reward = 0\n", " obs = env.reset()\n", " step = 0\n", " while True:\n", " step += 1\n", " action = agent.sample(obs)\n", " next_obs, reward, done, _ = env.step(action)\n", " rpm.append(obs, action, reward, next_obs, done)\n", "\n", " # train model\n", " if (len(rpm) > MEMORY_WARMUP_SIZE) and (step % LEARN_FREQ == 0):\n", " # s,a,r,s',done\n", " (batch_obs, batch_action, batch_reward, batch_next_obs,\n", " batch_done) = rpm.sample_batch(BATCH_SIZE)\n", " train_loss = agent.learn(batch_obs, batch_action, batch_reward,\n", " batch_next_obs, batch_done)\n", "\n", " total_reward += reward\n", " obs = next_obs\n", " if done:\n", " break\n", " return total_reward\n", "\n", "\n", "# evaluate 5 episodes\n", "def run_evaluate_episodes(agent, env, eval_episodes=5, render=False):\n", " eval_reward = []\n", " for i in range(eval_episodes):\n", " obs = env.reset()\n", " episode_reward = 0\n", " while True:\n", " action = agent.predict(obs)\n", " obs, reward, done, _ = env.step(action)\n", " episode_reward += reward\n", " if render:\n", " env.render()\n", " if done:\n", " break\n", " eval_reward.append(episode_reward)\n", " return np.mean(eval_reward)\n", "\n", "\n", "def main(args):\n", " env = gym.make('CartPole-v0')\n", " n_states = env.observation_space.shape[0]\n", " n_actions = env.action_space.n\n", " logger.info('n_states {}, n_actions {}'.format(n_states, n_actions))\n", "\n", " # set action_shape = 0 while in discrete control environment\n", " rpm = ReplayMemory(MEMORY_SIZE, n_states, 0)\n", "\n", " # build an agent\n", " model = CartpoleModel(n_states=n_states, n_actions=n_actions)\n", " alg = DQN(model, gamma=GAMMA, lr=LEARNING_RATE)\n", " agent = CartpoleAgent(\n", " alg, n_actions=n_actions, e_greed=0.1, e_greed_decrement=1e-6)\n", "\n", " # warmup memory\n", " while len(rpm) < MEMORY_WARMUP_SIZE:\n", " run_train_episode(agent, env, rpm)\n", "\n", " max_episode = args.max_episode\n", "\n", " # start training\n", " episode = 0\n", " while episode < max_episode:\n", " # train part\n", " for i in range(50):\n", " total_reward = run_train_episode(agent, env, rpm)\n", " episode += 1\n", "\n", " # test part\n", " eval_reward = run_evaluate_episodes(agent, env, render=False)\n", " logger.info('episode:{} e_greed:{} Test reward:{}'.format(\n", " episode, agent.e_greed, eval_reward))\n", "\n", " # save the parameters to ./model.ckpt\n", " save_path = './model.ckpt'\n", " agent.save(save_path)\n", "\n", " # save the model and parameters of policy network for inference\n", " save_inference_path = './inference_model'\n", " input_shapes = [[None, env.observation_space.shape[0]]]\n", " input_dtypes = ['float32']\n", " agent.save_inference_model(save_inference_path, input_shapes, input_dtypes)\n", "\n", "\n", "\n", " " ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[32m[08-01 21:48:19 MainThread @3996942455.py:64]\u001b[0m obs_dim 4, act_dim 2\n", "\u001b[32m[08-01 
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[32m[08-01 21:48:19 MainThread @3996942455.py:64]\u001b[0m obs_dim 4, act_dim 2\n",
      "\u001b[32m[08-01 21:48:19 MainThread @3996942455.py:92]\u001b[0m episode:50 e_greed:0.0988929999999989 Test reward:18.4\n",
      "\u001b[32m[08-01 21:48:20 MainThread @3996942455.py:92]\u001b[0m episode:100 e_greed:0.09794799999999795 Test reward:9.6\n",
      "\u001b[32m[08-01 21:48:20 MainThread @3996942455.py:92]\u001b[0m episode:150 e_greed:0.0973899999999974 Test reward:37.8\n",
      "\u001b[32m[08-01 21:48:20 MainThread @3996942455.py:92]\u001b[0m episode:200 e_greed:0.09684299999999685 Test reward:8.8\n",
      "\u001b[32m[08-01 21:48:20 MainThread @3996942455.py:92]\u001b[0m episode:250 e_greed:0.09635499999999636 Test reward:9.4\n",
      "\u001b[32m[08-01 21:48:21 MainThread @3996942455.py:92]\u001b[0m episode:300 e_greed:0.09585299999999586 Test reward:9.2\n",
      "\u001b[32m[08-01 21:48:21 MainThread @3996942455.py:92]\u001b[0m episode:350 e_greed:0.09535799999999536 Test reward:9.2\n",
      "\u001b[32m[08-01 21:48:21 MainThread @3996942455.py:92]\u001b[0m episode:400 e_greed:0.09486399999999487 Test reward:10.0\n",
      "\u001b[32m[08-01 21:48:21 MainThread @3996942455.py:92]\u001b[0m episode:450 e_greed:0.09435299999999436 Test reward:9.2\n",
      "\u001b[32m[08-01 21:48:22 MainThread @3996942455.py:92]\u001b[0m episode:500 e_greed:0.09384899999999385 Test reward:9.4\n",
      "\u001b[32m[08-01 21:48:22 MainThread @3996942455.py:92]\u001b[0m episode:550 e_greed:0.09302299999999303 Test reward:69.0\n",
      "\u001b[32m[08-01 21:48:25 MainThread @3996942455.py:92]\u001b[0m episode:600 e_greed:0.08774199999998775 Test reward:141.2\n",
      "\u001b[32m[08-01 21:48:30 MainThread @3996942455.py:92]\u001b[0m episode:650 e_greed:0.0791019999999791 Test reward:184.0\n",
      "\u001b[32m[08-01 21:48:35 MainThread @3996942455.py:92]\u001b[0m episode:700 e_greed:0.07011299999997012 Test reward:182.0\n",
      "\u001b[32m[08-01 21:48:40 MainThread @3996942455.py:92]\u001b[0m episode:750 e_greed:0.06089099999996089 Test reward:197.4\n",
      "\u001b[32m[08-01 21:48:45 MainThread @3996942455.py:92]\u001b[0m episode:800 e_greed:0.05139199999995139 Test reward:183.4\n",
      "\u001b[32m[08-01 21:48:50 MainThread @3996942455.py:92]\u001b[0m episode:850 e_greed:0.042255999999942256 Test reward:153.0\n",
      "\u001b[32m[08-01 21:48:55 MainThread @3996942455.py:92]\u001b[0m episode:900 e_greed:0.033495999999933496 Test reward:192.6\n",
      "\u001b[32m[08-01 21:49:00 MainThread @3996942455.py:92]\u001b[0m episode:950 e_greed:0.024318999999924318 Test reward:166.6\n",
      "\u001b[32m[08-01 21:49:06 MainThread @3996942455.py:92]\u001b[0m episode:1000 e_greed:0.014873999999916176 Test reward:187.0\n"
     ]
    }
   ],
   "source": [
    "import argparse\n",
    "parser = argparse.ArgumentParser()\n",
    "parser.add_argument(\n",
    "    '--max_episode',\n",
    "    type=int,\n",
    "    default=1000,\n",
    "    help='stop condition: maximum number of training episodes')\n",
    "args = parser.parse_args(args=[])\n",
    "\n",
    "main(args)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.7.12 ('rl_tutorials')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.12"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "4f613f1ab80ec98dc1b91d6e720de51301598a187317378e53e49b773c1123dd"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}