Files
easy-rl/codes/A2C/task0.ipynb
johnjim0816 3b712e8815 update codes
2021-12-21 20:14:13 +08:00

266 lines
28 KiB
Plaintext

{
"cells": [
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"from pathlib import Path\n",
"# Absolute path of the current working directory, as a plain string.\n",
"curr_path = str(Path.cwd())\n",
"# Directory one level above the working directory.\n",
"parent_path = str(Path.cwd().parent)\n",
"# Put the parent directory on the import search path (presumably so that the\n",
"# `common` package imported in a later cell resolves -- confirm repo layout).\n",
"sys.path.append(parent_path)\n",
"import math\n",
"import random\n",
"\n",
"import gym\n",
"import numpy as np\n",
"\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.optim as optim\n",
"import torch.nn.functional as F\n",
"from torch.distributions import Categorical\n",
"\n",
"from IPython.display import clear_output\n",
"import matplotlib.pyplot as plt\n",
"%matplotlib inline"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.\n",
"use_cuda = torch.cuda.is_available()\n",
"device = torch.device(\"cuda\") if use_cuda else torch.device(\"cpu\")"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"from common.multiprocessing_env import SubprocVecEnv\n",
"\n",
"# Gym environment id, used for the vectorised training envs and the evaluation env.\n",
"env_name = \"CartPole-v0\"\n",
"# Number of environment factories handed to SubprocVecEnv.\n",
"num_envs = 16\n",
"\n",
"def make_env():\n",
"    \"\"\"Return a zero-argument factory that builds one `env_name` environment.\n",
"\n",
"    The environment is not created here; it is created each time the returned\n",
"    callable is invoked.\n",
"    \"\"\"\n",
"    def _thunk():\n",
"        return gym.make(env_name)\n",
"\n",
"    return _thunk\n",
"\n",
"# One factory per environment copy; SubprocVecEnv receives the list of callables.\n",
"envs = SubprocVecEnv([make_env() for _ in range(num_envs)])\n",
"\n",
"# Separate single environment, stepped by test_env() for evaluation episodes.\n",
"env = gym.make(env_name)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"class ActorCritic(nn.Module):\n",
"    \"\"\"Actor and critic as two separate one-hidden-layer MLPs over the same input.\n",
"\n",
"    Args:\n",
"        num_inputs: number of input features per observation.\n",
"        num_outputs: number of discrete actions (size of the actor output).\n",
"        hidden_size: width of the hidden layer in each network.\n",
"        std: accepted for signature compatibility; not referenced in this class.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, num_inputs, num_outputs, hidden_size, std=0.0):\n",
"        super().__init__()\n",
"\n",
"        # The critic is created before the actor, as in the original, so that\n",
"        # seeded parameter initialisation is unchanged.\n",
"        value_layers = [\n",
"            nn.Linear(num_inputs, hidden_size),\n",
"            nn.ReLU(),\n",
"            nn.Linear(hidden_size, 1),\n",
"        ]\n",
"        self.critic = nn.Sequential(*value_layers)\n",
"\n",
"        policy_layers = [\n",
"            nn.Linear(num_inputs, hidden_size),\n",
"            nn.ReLU(),\n",
"            nn.Linear(hidden_size, num_outputs),\n",
"            # Softmax runs over dim 1, so inputs must be batched (at least 2-D).\n",
"            nn.Softmax(dim=1),\n",
"        ]\n",
"        self.actor = nn.Sequential(*policy_layers)\n",
"\n",
"    def forward(self, x):\n",
"        \"\"\"Return `(dist, value)` for a batch of observations `x`.\n",
"\n",
"        `dist` is a Categorical built from the actor's softmax probabilities and\n",
"        `value` is the critic output with a trailing dimension of size 1.\n",
"        \"\"\"\n",
"        value = self.critic(x)\n",
"        dist = Categorical(self.actor(x))\n",
"        return dist, value"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"def plot(frame_idx, rewards):\n",
"    \"\"\"Replace the cell output with a fresh plot of the reward history.\n",
"\n",
"    Args:\n",
"        frame_idx: step counter shown in the figure title.\n",
"        rewards: sequence of rewards to plot; its last element is shown in the\n",
"            title, so it must not be empty.\n",
"    \"\"\"\n",
"    clear_output(wait=True)\n",
"    plt.figure(figsize=(20, 5))\n",
"    # First panel of a 1x3 grid, i.e. the left third of the figure.\n",
"    plt.subplot(1, 3, 1)\n",
"    latest = rewards[-1]\n",
"    plt.title('frame %s. reward: %s' % (frame_idx, latest))\n",
"    plt.plot(rewards)\n",
"    plt.show()\n",
" \n",
"def test_env(vis=False):\n",
"    \"\"\"Play one episode in the global `env` with the global `model`.\n",
"\n",
"    Actions are sampled from the policy distribution returned by the model.\n",
"\n",
"    Args:\n",
"        vis: when true, render the environment after reset and after each step.\n",
"\n",
"    Returns:\n",
"        The sum of the per-step rewards collected until `done` is reported.\n",
"    \"\"\"\n",
"    state = env.reset()\n",
"    if vis:\n",
"        env.render()\n",
"    done = False\n",
"    total_reward = 0\n",
"    while not done:\n",
"        # Add a leading batch axis: the actor's softmax runs over dim 1.\n",
"        obs = torch.FloatTensor(state).unsqueeze(0).to(device)\n",
"        # Evaluation never backpropagates, so skip building the autograd graph.\n",
"        with torch.no_grad():\n",
"            dist, _ = model(obs)\n",
"        action = dist.sample().cpu().numpy()[0]\n",
"        state, reward, done, _ = env.step(action)\n",
"        if vis:\n",
"            env.render()\n",
"        total_reward += reward\n",
"    return total_reward"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"def compute_returns(next_value, rewards, masks, gamma=0.99):\n",
"    \"\"\"Compute discounted returns for a rollout, bootstrapped from `next_value`.\n",
"\n",
"    Works backwards through the rollout with\n",
"    `R[t] = rewards[t] + gamma * R[t + 1] * masks[t]`, where the value after the\n",
"    last step is `next_value`. A mask of 0 cuts the bootstrap at that step.\n",
"\n",
"    Args:\n",
"        next_value: value estimate for the state following the last step.\n",
"        rewards: per-step rewards, oldest first.\n",
"        masks: per-step multipliers applied to the discounted future return.\n",
"        gamma: discount factor.\n",
"\n",
"    Returns:\n",
"        A list with one return per step, in the same order as `rewards`.\n",
"    \"\"\"\n",
"    horizon = len(rewards)\n",
"    returns = [None] * horizon\n",
"    running = next_value\n",
"    for step in range(horizon - 1, -1, -1):\n",
"        running = rewards[step] + gamma * running * masks[step]\n",
"        returns[step] = running\n",
"    return returns"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"# Network input/output sizes are read from the vectorised environment's spaces.\n",
"num_inputs = envs.observation_space.shape[0]\n",
"num_outputs = envs.action_space.n\n",
"\n",
"# Hyper params:\n",
"hidden_size = 256  # hidden-layer width of both actor and critic\n",
"lr = 3e-4  # Adam learning rate\n",
"num_steps = 5  # environment steps collected per update\n",
"\n",
"model = ActorCritic(num_inputs, num_outputs, hidden_size).to(device)\n",
"# Pass `lr` explicitly: it was previously declared but never used, so Adam\n",
"# silently ran with its library default instead of the value configured above.\n",
"optimizer = optim.Adam(model.parameters(), lr=lr)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"# Mean evaluation reward recorded at each evaluation point.\n",
"test_rewards = []\n",
"# Number of training-loop steps taken so far; re-run this cell to restart training.\n",
"frame_idx = 0\n",
"# Training stops once frame_idx reaches this value.\n",
"max_frames = 20000"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 1440x360 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"state = envs.reset()\n",
"\n",
"while frame_idx < max_frames:\n",
"\n",
"    # Rollout buffers: one entry per time step, each covering every environment.\n",
"    log_probs = []\n",
"    values = []\n",
"    rewards = []\n",
"    masks = []\n",
"    entropy = 0\n",
"\n",
"    for _ in range(num_steps):\n",
"        state = torch.FloatTensor(state).to(device)\n",
"        dist, value = model(state)\n",
"\n",
"        action = dist.sample()\n",
"        next_state, reward, done, _ = envs.step(action.cpu().numpy())\n",
"\n",
"        log_prob = dist.log_prob(action)\n",
"        entropy += dist.entropy().mean()\n",
"\n",
"        # log_prob has one entry per environment (1-D), whereas value, reward\n",
"        # and mask are column vectors. Add the trailing axis so that log_probs\n",
"        # and advantage multiply element-wise below instead of broadcasting to\n",
"        # a square matrix, which would reduce the actor loss to\n",
"        # mean(log_prob) * mean(advantage).\n",
"        log_probs.append(log_prob.unsqueeze(1))\n",
"        values.append(value)\n",
"        rewards.append(torch.FloatTensor(reward).unsqueeze(1).to(device))\n",
"        # Mask is 0 where `done` is set, which stops compute_returns from\n",
"        # bootstrapping across that step.\n",
"        masks.append(torch.FloatTensor(1 - done).unsqueeze(1).to(device))\n",
"\n",
"        state = next_state\n",
"        frame_idx += 1\n",
"\n",
"        # Every 1000 steps: average 10 evaluation episodes and redraw the curve.\n",
"        if frame_idx % 1000 == 0:\n",
"            test_rewards.append(np.mean([test_env() for _ in range(10)]))\n",
"            plot(frame_idx, test_rewards)\n",
"\n",
"    # Bootstrap value for the state after the last collected step. It only\n",
"    # feeds the returns, which are detached below, so no graph is needed.\n",
"    next_state = torch.FloatTensor(next_state).to(device)\n",
"    with torch.no_grad():\n",
"        _, next_value = model(next_state)\n",
"    returns = compute_returns(next_value, rewards, masks)\n",
"\n",
"    log_probs = torch.cat(log_probs)\n",
"    returns = torch.cat(returns).detach()\n",
"    values = torch.cat(values)\n",
"\n",
"    advantage = returns - values\n",
"\n",
"    # The advantage is detached in the actor term so the policy gradient does\n",
"    # not flow into the critic; the critic is trained by the squared error.\n",
"    actor_loss = -(log_probs * advantage.detach()).mean()\n",
"    critic_loss = advantage.pow(2).mean()\n",
"\n",
"    # Entropy is subtracted, so higher policy entropy lowers the loss.\n",
"    loss = actor_loss + 0.5 * critic_loss - 0.001 * entropy\n",
"\n",
"    optimizer.zero_grad()\n",
"    loss.backward()\n",
"    optimizer.step()"
]
}
],
"metadata": {
"interpreter": {
"hash": "fe38df673a99c62a9fea33a7aceda74c9b65b12ee9d076c5851d98b692a4989a"
},
"kernelspec": {
"display_name": "Python 3.7.9 64-bit ('py37': conda)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.9"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}