
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Define the Model\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"import paddle\n",
"import paddle.nn as nn\n",
"import paddle.nn.functional as F\n",
"import parl\n",
"\n",
"class CartpoleModel(parl.Model):\n",
"    \"\"\" Linear network to solve Cartpole problem.\n",
"    Args:\n",
"        n_states (int): Dimension of observation space.\n",
"        n_actions (int): Dimension of action space.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, n_states, n_actions):\n",
"        super(CartpoleModel, self).__init__()\n",
"        hid1_size = 128\n",
"        hid2_size = 128\n",
"        self.fc1 = nn.Linear(n_states, hid1_size)\n",
"        self.fc2 = nn.Linear(hid1_size, hid2_size)\n",
"        self.fc3 = nn.Linear(hid2_size, n_actions)\n",
"\n",
"    def forward(self, obs):\n",
"        h1 = F.relu(self.fc1(obs))\n",
"        h2 = F.relu(self.fc2(h1))\n",
"        Q = self.fc3(h2)\n",
"        return Q"
]
},
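{
"cell_type": "markdown",
"metadata": {},
"source": [
"`parl.algorithms.DQN`, used later in this notebook, wraps this Q-network together with a copy that serves as the target network and computes the temporal-difference loss. As a rough standalone sketch of that update (the standard DQN formulation, written here for illustration rather than taken from PARL's source), the target is `r + gamma * (1 - done) * max_a' Q_target(s', a')` and the loss is the squared error against `Q(s, a)`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hedged sketch of a DQN-style update written with plain paddle ops.\n",
"# The names q_net, target_net and gamma are illustrative; this is not PARL's API.\n",
"import paddle\n",
"import paddle.nn.functional as F\n",
"\n",
"def dqn_loss(q_net, target_net, obs, act, reward, next_obs, terminal, gamma=0.99):\n",
"    # Q(s, a) for the actions that were actually taken, selected via a one-hot mask\n",
"    q_values = q_net(obs)\n",
"    act_onehot = F.one_hot(paddle.squeeze(act, axis=-1), num_classes=q_values.shape[-1])\n",
"    pred_q = paddle.sum(q_values * act_onehot, axis=-1, keepdim=True)\n",
"    # Bellman target from the frozen target network; no gradient flows through it\n",
"    with paddle.no_grad():\n",
"        max_next_q = paddle.max(target_net(next_obs), axis=1, keepdim=True)\n",
"        target_q = reward + gamma * (1.0 - terminal) * max_next_q\n",
"    return F.mse_loss(pred_q, target_q)"
]
},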
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"import parl\n",
"import paddle\n",
"import numpy as np\n",
"\n",
"\n",
"class CartpoleAgent(parl.Agent):\n",
"    \"\"\"Agent of Cartpole env.\n",
"    Args:\n",
"        algorithm(parl.Algorithm): algorithm used to solve the problem.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, algorithm, n_actions, e_greed=0.1, e_greed_decrement=0):\n",
"        super(CartpoleAgent, self).__init__(algorithm)\n",
"        assert isinstance(n_actions, int)\n",
"        self.n_actions = n_actions\n",
"\n",
"        self.global_step = 0\n",
"        self.update_target_steps = 200\n",
"\n",
"        self.e_greed = e_greed\n",
"        self.e_greed_decrement = e_greed_decrement\n",
"\n",
"    def sample(self, obs):\n",
"        \"\"\"Sample an action `for exploration` when given an observation\n",
"        Args:\n",
"            obs(np.float32): shape of (n_states,)\n",
"        Returns:\n",
"            act(int): action\n",
"        \"\"\"\n",
"        sample = np.random.random()\n",
"        if sample < self.e_greed:\n",
"            act = np.random.randint(self.n_actions)\n",
"        else:\n",
"            # keep a small amount of exploration even when acting greedily\n",
"            if np.random.random() < 0.01:\n",
"                act = np.random.randint(self.n_actions)\n",
"            else:\n",
"                act = self.predict(obs)\n",
"        self.e_greed = max(0.01, self.e_greed - self.e_greed_decrement)\n",
"        return act\n",
"\n",
"    def predict(self, obs):\n",
"        \"\"\"Predict an action when given an observation\n",
"        Args:\n",
"            obs(np.float32): shape of (n_states,)\n",
"        Returns:\n",
"            act(int): action\n",
"        \"\"\"\n",
"        obs = paddle.to_tensor(obs, dtype='float32')\n",
"        pred_q = self.alg.predict(obs)\n",
"        act = pred_q.argmax().numpy()[0]\n",
"        return act\n",
"\n",
"    def learn(self, obs, act, reward, next_obs, terminal):\n",
"        \"\"\"Update the model with a batch of sampled transitions\n",
"        Args:\n",
"            obs(np.float32): shape of (batch_size, n_states)\n",
"            act(np.int32): shape of (batch_size)\n",
"            reward(np.float32): shape of (batch_size)\n",
"            next_obs(np.float32): shape of (batch_size, n_states)\n",
"            terminal(np.float32): shape of (batch_size)\n",
"        Returns:\n",
"            loss(float)\n",
"        \"\"\"\n",
"        # sync the target network every update_target_steps calls to learn()\n",
"        if self.global_step % self.update_target_steps == 0:\n",
"            self.alg.sync_target()\n",
"        self.global_step += 1\n",
"\n",
"        act = np.expand_dims(act, axis=-1)\n",
"        reward = np.expand_dims(reward, axis=-1)\n",
"        terminal = np.expand_dims(terminal, axis=-1)\n",
"\n",
"        obs = paddle.to_tensor(obs, dtype='float32')\n",
"        act = paddle.to_tensor(act, dtype='int32')\n",
"        reward = paddle.to_tensor(reward, dtype='float32')\n",
"        next_obs = paddle.to_tensor(next_obs, dtype='float32')\n",
"        terminal = paddle.to_tensor(terminal, dtype='float32')\n",
"        loss = self.alg.learn(obs, act, reward, next_obs, terminal)\n",
"        return loss.numpy()[0]"
]
},
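{
"cell_type": "markdown",
"metadata": {},
"source": [
"The training code below stores transitions in `parl.utils.ReplayMemory` and samples minibatches from it. As a simplified stand-in (not PARL's actual implementation), a replay buffer only needs `append`, `sample_batch` and `__len__`, which is how `rpm` is used in the next cell:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrative replay buffer sketch; the notebook itself uses parl.utils.ReplayMemory.\n",
"import collections\n",
"import random\n",
"import numpy as np\n",
"\n",
"class SimpleReplayMemory:\n",
"    def __init__(self, max_size):\n",
"        # a bounded deque drops the oldest transition once max_size is reached\n",
"        self.buffer = collections.deque(maxlen=max_size)\n",
"\n",
"    def append(self, obs, act, reward, next_obs, done):\n",
"        self.buffer.append((obs, act, reward, next_obs, done))\n",
"\n",
"    def sample_batch(self, batch_size):\n",
"        # sample transitions uniformly at random and stack them into arrays\n",
"        batch = random.sample(list(self.buffer), batch_size)\n",
"        obs, act, reward, next_obs, done = zip(*batch)\n",
"        return (np.array(obs, dtype='float32'), np.array(act, dtype='int32'),\n",
"                np.array(reward, dtype='float32'), np.array(next_obs, dtype='float32'),\n",
"                np.array(done, dtype='float32'))\n",
"\n",
"    def __len__(self):\n",
"        return len(self.buffer)"
]
},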
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import gym\n",
"import numpy as np\n",
"import parl\n",
"\n",
"from parl.utils import logger, ReplayMemory\n",
"from parl.algorithms import DQN\n",
"\n",
"LEARN_FREQ = 5  # call agent.learn() every LEARN_FREQ environment steps\n",
"MEMORY_SIZE = 200000  # capacity of the replay memory\n",
"MEMORY_WARMUP_SIZE = 200  # collect this many transitions before learning starts\n",
"BATCH_SIZE = 64\n",
"LEARNING_RATE = 0.0005\n",
"GAMMA = 0.99\n",
"\n",
"# train an episode\n",
"def run_train_episode(agent, env, rpm):\n",
"    total_reward = 0\n",
"    obs = env.reset()\n",
"    step = 0\n",
"    while True:\n",
"        step += 1\n",
"        action = agent.sample(obs)\n",
"        next_obs, reward, done, _ = env.step(action)\n",
"        rpm.append(obs, action, reward, next_obs, done)\n",
"\n",
"        # train model\n",
"        if (len(rpm) > MEMORY_WARMUP_SIZE) and (step % LEARN_FREQ == 0):\n",
"            # s,a,r,s',done\n",
"            (batch_obs, batch_action, batch_reward, batch_next_obs,\n",
"             batch_done) = rpm.sample_batch(BATCH_SIZE)\n",
"            train_loss = agent.learn(batch_obs, batch_action, batch_reward,\n",
"                                     batch_next_obs, batch_done)\n",
"\n",
"        total_reward += reward\n",
"        obs = next_obs\n",
"        if done:\n",
"            break\n",
"    return total_reward\n",
"\n",
"\n",
"# evaluate 5 episodes\n",
"def run_evaluate_episodes(agent, env, eval_episodes=5, render=False):\n",
"    eval_reward = []\n",
"    for i in range(eval_episodes):\n",
"        obs = env.reset()\n",
"        episode_reward = 0\n",
"        while True:\n",
"            action = agent.predict(obs)\n",
"            obs, reward, done, _ = env.step(action)\n",
"            episode_reward += reward\n",
"            if render:\n",
"                env.render()\n",
"            if done:\n",
"                break\n",
"        eval_reward.append(episode_reward)\n",
"    return np.mean(eval_reward)\n",
"\n",
"\n",
"def main(args):\n",
"    env = gym.make('CartPole-v0')\n",
"    n_states = env.observation_space.shape[0]\n",
"    n_actions = env.action_space.n\n",
"    logger.info('n_states {}, n_actions {}'.format(n_states, n_actions))\n",
"\n",
"    # set the action dimension to 0 for discrete control environments\n",
"    rpm = ReplayMemory(MEMORY_SIZE, n_states, 0)\n",
"\n",
"    # build an agent\n",
"    model = CartpoleModel(n_states=n_states, n_actions=n_actions)\n",
"    alg = DQN(model, gamma=GAMMA, lr=LEARNING_RATE)\n",
"    agent = CartpoleAgent(\n",
"        alg, n_actions=n_actions, e_greed=0.1, e_greed_decrement=1e-6)\n",
"\n",
"    # warm up the replay memory before learning starts\n",
"    while len(rpm) < MEMORY_WARMUP_SIZE:\n",
"        run_train_episode(agent, env, rpm)\n",
"\n",
"    max_episode = args.max_episode\n",
"\n",
"    # start training\n",
"    episode = 0\n",
"    while episode < max_episode:\n",
"        # train part\n",
"        for i in range(50):\n",
"            total_reward = run_train_episode(agent, env, rpm)\n",
"            episode += 1\n",
"\n",
"        # test part\n",
"        eval_reward = run_evaluate_episodes(agent, env, render=False)\n",
"        logger.info('episode:{} e_greed:{} Test reward:{}'.format(\n",
"            episode, agent.e_greed, eval_reward))\n",
"\n",
"    # save the parameters to ./model.ckpt\n",
"    save_path = './model.ckpt'\n",
"    agent.save(save_path)\n",
"\n",
"    # save the model and parameters of the policy network for inference\n",
"    save_inference_path = './inference_model'\n",
"    input_shapes = [[None, env.observation_space.shape[0]]]\n",
"    input_dtypes = ['float32']\n",
"    agent.save_inference_model(save_inference_path, input_shapes, input_dtypes)"
]
},
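{
"cell_type": "markdown",
"metadata": {},
"source": [
"After `main` finishes in the next cell, the trained parameters are stored in `./model.ckpt` by `agent.save`. A short sketch of how they could be reloaded for a rendered evaluation run (it assumes the checkpoint exists and relies on `parl.Agent.restore`; it is not executed here):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrative only: rebuild the agent, load the saved checkpoint and watch the greedy policy.\n",
"# Requires ./model.ckpt (written by main below) and a display for env.render().\n",
"def evaluate_saved_model(ckpt_path='./model.ckpt'):\n",
"    env = gym.make('CartPole-v0')\n",
"    n_states = env.observation_space.shape[0]\n",
"    n_actions = env.action_space.n\n",
"    model = CartpoleModel(n_states=n_states, n_actions=n_actions)\n",
"    alg = DQN(model, gamma=GAMMA, lr=LEARNING_RATE)\n",
"    agent = CartpoleAgent(alg, n_actions=n_actions)\n",
"    agent.restore(ckpt_path)  # counterpart of agent.save in parl.Agent\n",
"    return run_evaluate_episodes(agent, env, eval_episodes=1, render=True)\n",
"\n",
"# evaluate_saved_model()  # uncomment to run after training"
]
},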
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[32m[08-01 21:48:19 MainThread @3996942455.py:64]\u001b[0m obs_dim 4, act_dim 2\n",
"\u001b[32m[08-01 21:48:19 MainThread @3996942455.py:92]\u001b[0m episode:50 e_greed:0.0988929999999989 Test reward:18.4\n",
"\u001b[32m[08-01 21:48:20 MainThread @3996942455.py:92]\u001b[0m episode:100 e_greed:0.09794799999999795 Test reward:9.6\n",
"\u001b[32m[08-01 21:48:20 MainThread @3996942455.py:92]\u001b[0m episode:150 e_greed:0.0973899999999974 Test reward:37.8\n",
"\u001b[32m[08-01 21:48:20 MainThread @3996942455.py:92]\u001b[0m episode:200 e_greed:0.09684299999999685 Test reward:8.8\n",
"\u001b[32m[08-01 21:48:20 MainThread @3996942455.py:92]\u001b[0m episode:250 e_greed:0.09635499999999636 Test reward:9.4\n",
"\u001b[32m[08-01 21:48:21 MainThread @3996942455.py:92]\u001b[0m episode:300 e_greed:0.09585299999999586 Test reward:9.2\n",
"\u001b[32m[08-01 21:48:21 MainThread @3996942455.py:92]\u001b[0m episode:350 e_greed:0.09535799999999536 Test reward:9.2\n",
"\u001b[32m[08-01 21:48:21 MainThread @3996942455.py:92]\u001b[0m episode:400 e_greed:0.09486399999999487 Test reward:10.0\n",
"\u001b[32m[08-01 21:48:21 MainThread @3996942455.py:92]\u001b[0m episode:450 e_greed:0.09435299999999436 Test reward:9.2\n",
"\u001b[32m[08-01 21:48:22 MainThread @3996942455.py:92]\u001b[0m episode:500 e_greed:0.09384899999999385 Test reward:9.4\n",
"\u001b[32m[08-01 21:48:22 MainThread @3996942455.py:92]\u001b[0m episode:550 e_greed:0.09302299999999303 Test reward:69.0\n",
"\u001b[32m[08-01 21:48:25 MainThread @3996942455.py:92]\u001b[0m episode:600 e_greed:0.08774199999998775 Test reward:141.2\n",
"\u001b[32m[08-01 21:48:30 MainThread @3996942455.py:92]\u001b[0m episode:650 e_greed:0.0791019999999791 Test reward:184.0\n",
"\u001b[32m[08-01 21:48:35 MainThread @3996942455.py:92]\u001b[0m episode:700 e_greed:0.07011299999997012 Test reward:182.0\n",
"\u001b[32m[08-01 21:48:40 MainThread @3996942455.py:92]\u001b[0m episode:750 e_greed:0.06089099999996089 Test reward:197.4\n",
"\u001b[32m[08-01 21:48:45 MainThread @3996942455.py:92]\u001b[0m episode:800 e_greed:0.05139199999995139 Test reward:183.4\n",
"\u001b[32m[08-01 21:48:50 MainThread @3996942455.py:92]\u001b[0m episode:850 e_greed:0.042255999999942256 Test reward:153.0\n",
"\u001b[32m[08-01 21:48:55 MainThread @3996942455.py:92]\u001b[0m episode:900 e_greed:0.033495999999933496 Test reward:192.6\n",
"\u001b[32m[08-01 21:49:00 MainThread @3996942455.py:92]\u001b[0m episode:950 e_greed:0.024318999999924318 Test reward:166.6\n",
"\u001b[32m[08-01 21:49:06 MainThread @3996942455.py:92]\u001b[0m episode:1000 e_greed:0.014873999999916176 Test reward:187.0\n"
]
}
],
"source": [
"import argparse\n",
"parser = argparse.ArgumentParser()\n",
"parser.add_argument(\n",
" '--max_episode',\n",
" type=int,\n",
" default=1000,\n",
" help='stop condition: number of max episode')\n",
"args = parser.parse_args(args=[])\n",
"\n",
"main(args)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.7.12 ('rl_tutorials')",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.12"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "4f613f1ab80ec98dc1b91d6e720de51301598a187317378e53e49b773c1123dd"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}