Files
easy-rl/notebooks/A2C.ipynb
2022-12-04 20:54:36 +08:00

371 lines
56 KiB
Plaintext
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.optim as optim\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"from torch.distributions import Categorical\n",
"import numpy as np\n",
"from multiprocessing import Process, Pipe\n",
"import argparse\n",
"import gym"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 建立Actor和Critic网络"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"class ActorCritic(nn.Module):\n",
"    '''Actor-critic network used by A2C.\n",
"\n",
"    Two independent MLP heads read the same input: the critic outputs a scalar\n",
"    state-value estimate, the actor outputs a Categorical distribution over\n",
"    discrete actions.\n",
"\n",
"    Args:\n",
"        input_dim: size of the state vector (last dimension of the input).\n",
"        output_dim: number of discrete actions.\n",
"        hidden_dim: width of the single hidden layer in each head.\n",
"    '''\n",
"    def __init__(self, input_dim, output_dim, hidden_dim):\n",
"        super(ActorCritic, self).__init__()\n",
"        self.critic = nn.Sequential(\n",
"            nn.Linear(input_dim, hidden_dim),\n",
"            nn.ReLU(),\n",
"            nn.Linear(hidden_dim, 1)\n",
"        )\n",
"\n",
"        self.actor = nn.Sequential(\n",
"            nn.Linear(input_dim, hidden_dim),\n",
"            nn.ReLU(),\n",
"            nn.Linear(hidden_dim, output_dim),\n",
"            # Normalise over the last (action) axis so both a single state\n",
"            # (input_dim,) and a batch (..., input_dim) are accepted; dim=1\n",
"            # raised an IndexError for an unbatched 1-D state.\n",
"            nn.Softmax(dim=-1),\n",
"        )\n",
"\n",
"    def forward(self, x):\n",
"        '''Return (dist, value) for state(s) x.\n",
"\n",
"        dist is a Categorical over actions with batch shape x.shape[:-1];\n",
"        value is the critic output with shape x.shape[:-1] + (1,).\n",
"        '''\n",
"        value = self.critic(x)\n",
"        probs = self.actor(x)\n",
"        dist = Categorical(probs)\n",
"        return dist, value"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"class A2C:\n",
"    '''A2C agent: holds the actor-critic model, its optimizer and the discount.\n",
"\n",
"    Args:\n",
"        n_states: dimension of the state vector.\n",
"        n_actions: number of discrete actions.\n",
"        cfg: config object; reads gamma and device, hidden_dim (falling back\n",
"            to hidden_size for configs using that name) and optionally lr\n",
"            (defaults to Adam's 1e-3).\n",
"    '''\n",
"    def __init__(self, n_states, n_actions, cfg) -> None:\n",
"        self.gamma = cfg.gamma\n",
"        self.device = cfg.device\n",
"        # The notebook's config defines `hidden_dim`; reading only\n",
"        # `cfg.hidden_size` raised AttributeError. Keep it as a fallback.\n",
"        hidden_dim = getattr(cfg, 'hidden_dim', None)\n",
"        if hidden_dim is None:\n",
"            hidden_dim = cfg.hidden_size\n",
"        self.model = ActorCritic(n_states, n_actions, hidden_dim).to(self.device)\n",
"        # Honour cfg.lr when present instead of silently using Adam's default.\n",
"        self.optimizer = optim.Adam(self.model.parameters(), lr=getattr(cfg, 'lr', 1e-3))\n",
"\n",
"    def compute_returns(self, next_value, rewards, masks):\n",
"        '''Compute bootstrapped discounted returns over one rollout.\n",
"\n",
"        Args:\n",
"            next_value: value estimate used to bootstrap after the last step.\n",
"            rewards: per-step rewards in time order.\n",
"            masks: per-step multipliers indexed in parallel with rewards; a 0\n",
"                stops the return from being carried across that step.\n",
"\n",
"        Returns:\n",
"            List of returns aligned with rewards (earliest step first).\n",
"        '''\n",
"        R = next_value\n",
"        returns = []\n",
"        for step in reversed(range(len(rewards))):\n",
"            R = rewards[step] + self.gamma * R * masks[step]\n",
"            returns.insert(0, R)\n",
"        return returns"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"def make_envs(env_name, seed=2):\n",
"    '''Return a zero-argument factory that builds and seeds one gym env.\n",
"\n",
"    The notebook passes a list of these factories to SubprocVecEnv.\n",
"\n",
"    Args:\n",
"        env_name: gym environment id, e.g. 'CartPole-v0'.\n",
"        seed: value passed to env.seed(); defaults to 2 (the previously\n",
"            hard-coded value). Pass distinct seeds (e.g. the worker index) so\n",
"            parallel envs do not all share identical initial states.\n",
"    '''\n",
"    def _thunk():\n",
"        env = gym.make(env_name)\n",
"        env.seed(seed)\n",
"        return env\n",
"    return _thunk\n",
"\n",
"\n",
"def test_env(env, model, vis=False):\n",
"    '''Run one evaluation episode with a sampled policy; return its total reward.\n",
"\n",
"    Args:\n",
"        env: a single (non-vectorised) gym env whose step returns\n",
"            (next_state, reward, done, info).\n",
"        model: module whose forward returns (dist, value) for a batched state.\n",
"        vis: if True, call env.render() after reset and after every step.\n",
"\n",
"    Returns:\n",
"        Sum of rewards collected until the env reports done.\n",
"    '''\n",
"    # Take the device from the model itself rather than the global `cfg`\n",
"    # defined in a later cell, so this works on a fresh kernel.\n",
"    device = next(model.parameters()).device\n",
"    state = env.reset()\n",
"    if vis: env.render()\n",
"    done = False\n",
"    total_reward = 0\n",
"    while not done:\n",
"        state = torch.FloatTensor(state).unsqueeze(0).to(device)\n",
"        # Evaluation only: no autograd graph is needed.\n",
"        with torch.no_grad():\n",
"            dist, _ = model(state)\n",
"        next_state, reward, done, _ = env.step(dist.sample().cpu().numpy()[0])\n",
"        state = next_state\n",
"        if vis: env.render()\n",
"        total_reward += reward\n",
"    return total_reward\n",
"\n",
"def compute_returns(next_value, rewards, masks, gamma=0.99):\n",
"    '''Discounted returns over one rollout, bootstrapped from next_value.\n",
"\n",
"    Walks the rollout backwards, accumulating\n",
"    G_t = r_t + gamma * G_{t+1} * mask_t with G_T = next_value; a zero mask\n",
"    (train passes 1 - done) stops the return from crossing that step.\n",
"    Returns a list aligned with rewards (earliest step first).\n",
"    '''\n",
"    running = next_value\n",
"    discounted = []\n",
"    for idx in range(len(rewards) - 1, -1, -1):\n",
"        running = rewards[idx] + gamma * running * masks[idx]\n",
"        discounted.append(running)\n",
"    discounted.reverse()\n",
"    return discounted\n",
"\n",
"\n",
"def train(cfg, envs):\n",
"    '''Train an ActorCritic with synchronous n-step A2C on a vectorised env.\n",
"\n",
"    One gradient update is made per cfg.n_steps vectorised steps; every 200\n",
"    steps the policy is evaluated for 10 episodes on a separate single env.\n",
"\n",
"    Args:\n",
"        cfg: config providing env_name, algo_name, device, hidden_dim,\n",
"            max_steps, n_steps, and optionally gamma (default 0.99) and lr\n",
"            (default 1e-3).\n",
"        envs: vectorised env (the notebook passes a SubprocVecEnv) whose\n",
"            reset/step return batches of states, rewards and dones.\n",
"\n",
"    Returns:\n",
"        (test_rewards, test_ma_rewards): mean evaluation reward at each\n",
"        evaluation point and its exponential moving average (0.9 / 0.1).\n",
"    '''\n",
"    print('Start training!')\n",
"    print(f'Env:{cfg.env_name}, Algorithm:{cfg.algo_name}, Device:{cfg.device}')\n",
"    env = gym.make(cfg.env_name)  # a single env used only for evaluation\n",
"    env.seed(10)\n",
"    n_states = envs.observation_space.shape[0]\n",
"    n_actions = envs.action_space.n\n",
"    model = ActorCritic(n_states, n_actions, cfg.hidden_dim).to(cfg.device)\n",
"    # Use the configured learning rate and discount; both used to be ignored.\n",
"    optimizer = optim.Adam(model.parameters(), lr=getattr(cfg, 'lr', 1e-3))\n",
"    gamma = getattr(cfg, 'gamma', 0.99)\n",
"    step_idx = 0\n",
"    test_rewards = []\n",
"    test_ma_rewards = []\n",
"    state = envs.reset()\n",
"    while step_idx < cfg.max_steps:\n",
"        log_probs = []\n",
"        values = []\n",
"        rewards = []\n",
"        masks = []\n",
"        entropy = 0\n",
"        # rollout: cfg.n_steps transitions from every parallel env\n",
"        for _ in range(cfg.n_steps):\n",
"            state = torch.FloatTensor(state).to(cfg.device)\n",
"            dist, value = model(state)\n",
"            action = dist.sample()\n",
"            next_state, reward, done, _ = envs.step(action.cpu().numpy())\n",
"            log_prob = dist.log_prob(action)\n",
"            entropy += dist.entropy().mean()\n",
"            log_probs.append(log_prob)\n",
"            values.append(value)\n",
"            rewards.append(torch.FloatTensor(reward).unsqueeze(1).to(cfg.device))\n",
"            # mask is 0 where an episode ended so returns do not bootstrap across it\n",
"            masks.append(torch.FloatTensor(1 - done).unsqueeze(1).to(cfg.device))\n",
"            state = next_state\n",
"            step_idx += 1\n",
"            if step_idx % 200 == 0:\n",
"                test_reward = np.mean([test_env(env, model) for _ in range(10)])\n",
"                print(f\"step_idx:{step_idx}, test_reward:{test_reward}\")\n",
"                test_rewards.append(test_reward)\n",
"                if test_ma_rewards:\n",
"                    test_ma_rewards.append(0.9 * test_ma_rewards[-1] + 0.1 * test_reward)\n",
"                else:\n",
"                    test_ma_rewards.append(test_reward)\n",
"        next_state = torch.FloatTensor(next_state).to(cfg.device)\n",
"        # The bootstrap value is only used as a (detached) target.\n",
"        with torch.no_grad():\n",
"            _, next_value = model(next_state)\n",
"        returns = compute_returns(next_value, rewards, masks, gamma=gamma)\n",
"        # Assuming batched states (n_envs, n_states): log_probs is (N,), while\n",
"        # returns and values are (N, 1), with N = n_steps * n_envs.\n",
"        log_probs = torch.cat(log_probs)\n",
"        returns = torch.cat(returns).detach()\n",
"        values = torch.cat(values)\n",
"        # Flatten to (N,) so it pairs element-wise with log_probs: (N,) * (N, 1)\n",
"        # broadcasts to (N, N) and the actor loss collapses to\n",
"        # mean(log_probs) * mean(advantage).\n",
"        advantage = (returns - values).squeeze(-1)\n",
"        actor_loss = -(log_probs * advantage.detach()).mean()\n",
"        critic_loss = advantage.pow(2).mean()\n",
"        loss = actor_loss + 0.5 * critic_loss - 0.001 * entropy\n",
"        optimizer.zero_grad()\n",
"        loss.backward()\n",
"        optimizer.step()\n",
"    print('Finish training')\n",
"    return test_rewards, test_ma_rewards"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"import seaborn as sns \n",
"def plot_rewards(rewards, ma_rewards, cfg, tag='train', xlabel='episodes'):\n",
"    '''Plot a reward curve and its moving average on a new figure.\n",
"\n",
"    Args:\n",
"        rewards: sequence of reward values, one per x-axis point.\n",
"        ma_rewards: moving-average rewards aligned with `rewards`.\n",
"        cfg: config providing device, algo_name and env_name for the title.\n",
"        tag: accepted for API compatibility; not used in the figure.\n",
"        xlabel: x-axis label; defaults to 'episodes' (was misspelled\n",
"            'epsiodes'). Pass e.g. 'evaluations' for periodic test rewards.\n",
"    '''\n",
"    sns.set()\n",
"    fig, ax = plt.subplots()  # a new figure per call so several plots can coexist\n",
"    ax.set_title(\"learning curve on {} of {} for {}\".format(\n",
"        cfg.device, cfg.algo_name, cfg.env_name))\n",
"    ax.set_xlabel(xlabel)\n",
"    ax.plot(rewards, label='rewards')\n",
"    ax.plot(ma_rewards, label='ma rewards')\n",
"    ax.legend()\n",
"    plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Start training!\n",
"Env:CartPole-v0, Algorithm:A2C, Device:cuda\n",
"step_idx:200, test_reward:18.6\n",
"step_idx:400, test_reward:19.7\n",
"step_idx:600, test_reward:24.2\n",
"step_idx:800, test_reward:19.5\n",
"step_idx:1000, test_reward:33.9\n",
"step_idx:1200, test_reward:36.1\n",
"step_idx:1400, test_reward:32.6\n",
"step_idx:1600, test_reward:36.3\n",
"step_idx:1800, test_reward:38.9\n",
"step_idx:2000, test_reward:60.8\n",
"step_idx:2200, test_reward:41.9\n",
"step_idx:2400, test_reward:42.2\n",
"step_idx:2600, test_reward:71.6\n",
"step_idx:2800, test_reward:123.6\n",
"step_idx:3000, test_reward:57.5\n",
"step_idx:3200, test_reward:155.4\n",
"step_idx:3400, test_reward:111.4\n",
"step_idx:3600, test_reward:133.8\n",
"step_idx:3800, test_reward:133.8\n",
"step_idx:4000, test_reward:114.3\n",
"step_idx:4200, test_reward:165.5\n",
"step_idx:4400, test_reward:119.4\n",
"step_idx:4600, test_reward:173.4\n",
"step_idx:4800, test_reward:115.4\n",
"step_idx:5000, test_reward:159.7\n",
"step_idx:5200, test_reward:178.1\n",
"step_idx:5400, test_reward:137.8\n",
"step_idx:5600, test_reward:146.0\n",
"step_idx:5800, test_reward:187.4\n",
"step_idx:6000, test_reward:200.0\n",
"step_idx:6200, test_reward:169.2\n",
"step_idx:6400, test_reward:167.8\n",
"step_idx:6600, test_reward:184.3\n",
"step_idx:6800, test_reward:162.3\n",
"step_idx:7000, test_reward:125.4\n",
"step_idx:7200, test_reward:150.6\n",
"step_idx:7400, test_reward:152.6\n",
"step_idx:7600, test_reward:122.5\n",
"step_idx:7800, test_reward:136.3\n",
"step_idx:8000, test_reward:131.4\n",
"step_idx:8200, test_reward:174.6\n",
"step_idx:8400, test_reward:91.7\n",
"step_idx:8600, test_reward:170.1\n",
"step_idx:8800, test_reward:166.0\n",
"step_idx:9000, test_reward:150.2\n",
"step_idx:9200, test_reward:104.6\n",
"step_idx:9400, test_reward:147.2\n",
"step_idx:9600, test_reward:111.8\n",
"step_idx:9800, test_reward:118.7\n",
"step_idx:10000, test_reward:102.6\n",
"step_idx:10200, test_reward:99.0\n",
"step_idx:10400, test_reward:64.6\n",
"step_idx:10600, test_reward:133.7\n",
"step_idx:10800, test_reward:119.7\n",
"step_idx:11000, test_reward:112.6\n",
"step_idx:11200, test_reward:116.1\n",
"step_idx:11400, test_reward:116.3\n",
"step_idx:11600, test_reward:116.2\n",
"step_idx:11800, test_reward:115.3\n",
"step_idx:12000, test_reward:109.7\n",
"step_idx:12200, test_reward:110.3\n",
"step_idx:12400, test_reward:131.4\n",
"step_idx:12600, test_reward:128.3\n",
"step_idx:12800, test_reward:128.8\n",
"step_idx:13000, test_reward:119.8\n",
"step_idx:13200, test_reward:108.6\n",
"step_idx:13400, test_reward:128.4\n",
"step_idx:13600, test_reward:138.2\n",
"step_idx:13800, test_reward:119.1\n",
"step_idx:14000, test_reward:140.7\n",
"step_idx:14200, test_reward:145.3\n",
"step_idx:14400, test_reward:154.1\n",
"step_idx:14600, test_reward:165.2\n",
"step_idx:14800, test_reward:138.2\n",
"step_idx:15000, test_reward:143.5\n",
"step_idx:15200, test_reward:125.4\n",
"step_idx:15400, test_reward:137.1\n",
"step_idx:15600, test_reward:150.1\n",
"step_idx:15800, test_reward:132.9\n",
"step_idx:16000, test_reward:140.4\n",
"step_idx:16200, test_reward:141.3\n",
"step_idx:16400, test_reward:135.5\n",
"step_idx:16600, test_reward:135.5\n",
"step_idx:16800, test_reward:125.6\n",
"step_idx:17000, test_reward:126.8\n",
"step_idx:17200, test_reward:124.7\n",
"step_idx:17400, test_reward:129.6\n",
"step_idx:17600, test_reward:114.3\n",
"step_idx:17800, test_reward:57.3\n",
"step_idx:18000, test_reward:164.7\n",
"step_idx:18200, test_reward:165.8\n",
"step_idx:18400, test_reward:196.7\n",
"step_idx:18600, test_reward:198.8\n",
"step_idx:18800, test_reward:200.0\n",
"step_idx:19000, test_reward:199.6\n",
"step_idx:19200, test_reward:189.5\n",
"step_idx:19400, test_reward:177.9\n",
"step_idx:19600, test_reward:159.3\n",
"step_idx:19800, test_reward:127.7\n",
"step_idx:20000, test_reward:143.6\n",
"Finish training\n"
]
},
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import easydict\n",
"from common.multiprocessing_env import SubprocVecEnv\n",
"\n",
"cfg = easydict.EasyDict({\n",
"    \"algo_name\": 'A2C',\n",
"    \"env_name\": 'CartPole-v0',\n",
"    \"n_envs\": 8,\n",
"    \"max_steps\": 20000,\n",
"    \"n_steps\": 5,\n",
"    \"gamma\": 0.99,\n",
"    \"lr\": 1e-3,\n",
"    \"hidden_dim\": 256,\n",
"    \"seed\": 1,  # seeds torch/numpy here; env seeds are set in make_envs and train\n",
"    \"device\": torch.device(\n",
"        \"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"})\n",
"# Seed the main-process RNGs (weight init, action sampling) for reproducible runs.\n",
"torch.manual_seed(cfg.seed)\n",
"np.random.seed(cfg.seed)\n",
"envs = SubprocVecEnv([make_envs(cfg.env_name) for _ in range(cfg.n_envs)])\n",
"try:\n",
"    rewards, ma_rewards = train(cfg, envs)\n",
"finally:\n",
"    # Shut down the worker subprocesses even if training raises.\n",
"    envs.close()\n",
"plot_rewards(rewards, ma_rewards, cfg, tag=\"train\")  # plot the results"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.7.12 ('rl_tutorials')",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.12"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "4f613f1ab80ec98dc1b91d6e720de51301598a187317378e53e49b773c1123dd"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}