{ "metadata": { "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.10-final" }, "orig_nbformat": 2, "kernelspec": { "name": "python3", "display_name": "Python 3.7.10 64-bit ('py37': conda)", "metadata": { "interpreter": { "hash": "fbea1422c2cf61ed9c0cfc03f38f71cc9083cc288606edc4170b5309b352ce27" } } } }, "nbformat": 4, "nbformat_minor": 2, "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import sys,os\n", "from pathlib import Path\n", "curr_path = str(Path().absolute())\n", "parent_path = str(Path().absolute().parent)\n", "sys.path.append(parent_path) # add current terminal path to sys.path\n", "\n", "import gym\n", "import torch\n", "import numpy as np\n", "import datetime\n", "\n", "from HierarchicalDQN.agent import HierarchicalDQN\n", "from common.plot import plot_rewards\n", "from common.utils import save_results" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "SEQUENCE = datetime.datetime.now().strftime(\n", " \"%Y%m%d-%H%M%S\") # obtain current time\n", "SAVED_MODEL_PATH = curr_path+\"/saved_model/\"+SEQUENCE+'/' # path to save model\n", "if not os.path.exists(curr_path+\"/saved_model/\"):\n", " os.mkdir(curr_path+\"/saved_model/\")\n", "if not os.path.exists(SAVED_MODEL_PATH):\n", " os.mkdir(SAVED_MODEL_PATH)\n", "RESULT_PATH = curr_path+\"/results/\"+SEQUENCE+'/' # path to save rewards\n", "if not os.path.exists(curr_path+\"/results/\"):\n", " os.mkdir(curr_path+\"/results/\")\n", "if not os.path.exists(RESULT_PATH):\n", " os.mkdir(RESULT_PATH)" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "class HierarchicalDQNConfig:\n", " def __init__(self):\n", " self.algo = \"H-DQN\" # name of algo\n", " self.gamma = 0.95\n", " self.epsilon_start = 1 # start epsilon of e-greedy policy\n", " self.epsilon_end = 0.01\n", " self.epsilon_decay = 500\n", " self.lr = 0.0001 # learning rate\n", " self.memory_capacity = 20000 # Replay Memory capacity\n", " self.batch_size = 64\n", " self.train_eps = 300 # 训练的episode数目\n", " self.target_update = 2 # target net的更新频率\n", " self.eval_eps = 20 # 测试的episode数目\n", " self.device = torch.device(\n", " \"cuda\" if torch.cuda.is_available() else \"cpu\") # 检测gpu\n", " self.hidden_dim = 256 # dimension of hidden layer" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "def train(cfg, env, agent):\n", " print('Start to train !')\n", " rewards = []\n", " ma_rewards = [] # moveing average reward\n", " for i_episode in range(cfg.train_eps):\n", " state = env.reset()\n", " done = False\n", " ep_reward = 0\n", " while not done:\n", " goal = agent.set_goal(state)\n", " onehot_goal = agent.to_onehot(goal)\n", " meta_state = state\n", " extrinsic_reward = 0\n", " while not done and goal != np.argmax(state):\n", " goal_state = np.concatenate([state, onehot_goal])\n", " action = agent.choose_action(goal_state)\n", " next_state, reward, done, _ = env.step(action)\n", " ep_reward += reward\n", " extrinsic_reward += reward\n", " intrinsic_reward = 1.0 if goal == np.argmax(\n", " next_state) else 0.0\n", " agent.memory.push(goal_state, action, intrinsic_reward, np.concatenate(\n", " [next_state, onehot_goal]), done)\n", " state = next_state\n", " agent.update()\n", " agent.meta_memory.push(meta_state, goal, extrinsic_reward, state, done)\n", " print('Episode:{}/{}, Reward:{}'.format(i_episode+1, cfg.train_eps, ep_reward))\n", " rewards.append(ep_reward)\n", " if ma_rewards:\n", " ma_rewards.append(\n", " 0.9*ma_rewards[-1]+0.1*ep_reward)\n", " else:\n", " ma_rewards.append(ep_reward)\n", " print('Complete training!')\n", " return rewards, ma_rewards" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Start to train !\n", "Episode:1/300, Reward:25.0\n", "Episode:2/300, Reward:26.0\n", "Episode:3/300, Reward:23.0\n", "Episode:4/300, Reward:19.0\n", "Episode:5/300, Reward:23.0\n", "Episode:6/300, Reward:21.0\n", "Episode:7/300, Reward:21.0\n", "Episode:8/300, Reward:22.0\n", "Episode:9/300, Reward:15.0\n", "Episode:10/300, Reward:12.0\n", "Episode:11/300, Reward:39.0\n", "Episode:12/300, Reward:42.0\n", "Episode:13/300, Reward:79.0\n", "Episode:14/300, Reward:54.0\n", "Episode:15/300, Reward:28.0\n", "Episode:16/300, Reward:85.0\n", "Episode:17/300, Reward:46.0\n", "Episode:18/300, Reward:37.0\n", "Episode:19/300, Reward:45.0\n", "Episode:20/300, Reward:79.0\n", "Episode:21/300, Reward:80.0\n", "Episode:22/300, Reward:154.0\n", "Episode:23/300, Reward:74.0\n", "Episode:24/300, Reward:129.0\n", "Episode:25/300, Reward:185.0\n", "Episode:26/300, Reward:200.0\n", "Episode:27/300, Reward:115.0\n", "Episode:28/300, Reward:104.0\n", "Episode:29/300, Reward:200.0\n", "Episode:30/300, Reward:118.0\n", "Episode:31/300, Reward:200.0\n", "Episode:32/300, Reward:200.0\n", "Episode:33/300, Reward:83.0\n", "Episode:34/300, Reward:75.0\n", "Episode:35/300, Reward:46.0\n", "Episode:36/300, Reward:96.0\n", "Episode:37/300, Reward:78.0\n", "Episode:38/300, Reward:150.0\n", "Episode:39/300, Reward:147.0\n", "Episode:40/300, Reward:74.0\n", "Episode:41/300, Reward:137.0\n", "Episode:42/300, Reward:182.0\n", "Episode:43/300, Reward:200.0\n", "Episode:44/300, Reward:200.0\n", "Episode:45/300, Reward:200.0\n", "Episode:46/300, Reward:184.0\n", "Episode:47/300, Reward:200.0\n", "Episode:48/300, Reward:200.0\n", "Episode:49/300, Reward:200.0\n", "Episode:50/300, Reward:61.0\n", "Episode:51/300, Reward:9.0\n", "Episode:52/300, Reward:9.0\n", "Episode:53/300, Reward:200.0\n", "Episode:54/300, Reward:200.0\n", "Episode:55/300, Reward:200.0\n", "Episode:56/300, Reward:200.0\n", "Episode:57/300, Reward:200.0\n", "Episode:58/300, Reward:200.0\n", "Episode:59/300, Reward:200.0\n", "Episode:60/300, Reward:167.0\n", "Episode:61/300, Reward:200.0\n", "Episode:62/300, Reward:200.0\n", "Episode:63/300, Reward:200.0\n", "Episode:64/300, Reward:200.0\n", "Episode:65/300, Reward:200.0\n", "Episode:66/300, Reward:200.0\n", "Episode:67/300, Reward:200.0\n", "Episode:68/300, Reward:200.0\n", "Episode:69/300, Reward:197.0\n", "Episode:70/300, Reward:200.0\n", "Episode:71/300, Reward:200.0\n", "Episode:72/300, Reward:200.0\n", "Episode:73/300, Reward:200.0\n", "Episode:74/300, Reward:200.0\n", "Episode:75/300, Reward:200.0\n", "Episode:76/300, Reward:200.0\n", "Episode:77/300, Reward:200.0\n", "Episode:78/300, Reward:200.0\n", "Episode:79/300, Reward:200.0\n", "Episode:80/300, Reward:200.0\n", "Episode:81/300, Reward:181.0\n", "Episode:82/300, Reward:200.0\n", "Episode:83/300, Reward:200.0\n", "Episode:84/300, Reward:200.0\n", "Episode:85/300, Reward:200.0\n", "Episode:86/300, Reward:200.0\n", "Episode:87/300, Reward:200.0\n", "Episode:88/300, Reward:200.0\n", "Episode:89/300, Reward:200.0\n", "Episode:90/300, Reward:200.0\n", "Episode:91/300, Reward:200.0\n", "Episode:92/300, Reward:200.0\n", "Episode:93/300, Reward:200.0\n", "Episode:94/300, Reward:200.0\n", "Episode:95/300, Reward:200.0\n", "Episode:96/300, Reward:200.0\n", "Episode:97/300, Reward:200.0\n", "Episode:98/300, Reward:200.0\n", "Episode:99/300, Reward:192.0\n", "Episode:100/300, Reward:183.0\n", "Episode:101/300, Reward:200.0\n", "Episode:102/300, Reward:200.0\n", "Episode:103/300, Reward:200.0\n", "Episode:104/300, Reward:200.0\n", "Episode:105/300, Reward:200.0\n", "Episode:106/300, Reward:200.0\n", "Episode:107/300, Reward:200.0\n", "Episode:108/300, Reward:200.0\n", "Episode:109/300, Reward:200.0\n", "Episode:110/300, Reward:200.0\n", "Episode:111/300, Reward:200.0\n", "Episode:112/300, Reward:200.0\n", "Episode:113/300, Reward:200.0\n", "Episode:114/300, Reward:200.0\n", "Episode:115/300, Reward:200.0\n", "Episode:116/300, Reward:200.0\n", "Episode:117/300, Reward:200.0\n", "Episode:118/300, Reward:200.0\n", "Episode:119/300, Reward:200.0\n", "Episode:120/300, Reward:196.0\n", "Episode:121/300, Reward:200.0\n", "Episode:122/300, Reward:200.0\n", "Episode:123/300, Reward:200.0\n", "Episode:124/300, Reward:200.0\n", "Episode:125/300, Reward:200.0\n", "Episode:126/300, Reward:189.0\n", "Episode:127/300, Reward:193.0\n", "Episode:128/300, Reward:200.0\n", "Episode:129/300, Reward:200.0\n", "Episode:130/300, Reward:193.0\n", "Episode:131/300, Reward:183.0\n", "Episode:132/300, Reward:183.0\n", "Episode:133/300, Reward:200.0\n", "Episode:134/300, Reward:200.0\n", "Episode:135/300, Reward:200.0\n", "Episode:136/300, Reward:200.0\n", "Episode:137/300, Reward:200.0\n", "Episode:138/300, Reward:200.0\n", "Episode:139/300, Reward:100.0\n", "Episode:140/300, Reward:118.0\n", "Episode:141/300, Reward:99.0\n", "Episode:142/300, Reward:185.0\n", "Episode:143/300, Reward:41.0\n", "Episode:144/300, Reward:11.0\n", "Episode:145/300, Reward:9.0\n", "Episode:146/300, Reward:152.0\n", "Episode:147/300, Reward:155.0\n", "Episode:148/300, Reward:181.0\n", "Episode:149/300, Reward:197.0\n", "Episode:150/300, Reward:200.0\n", "Episode:151/300, Reward:200.0\n", "Episode:152/300, Reward:200.0\n", "Episode:153/300, Reward:200.0\n", "Episode:154/300, Reward:200.0\n", "Episode:155/300, Reward:200.0\n", "Episode:156/300, Reward:123.0\n", "Episode:157/300, Reward:11.0\n", "Episode:158/300, Reward:8.0\n", "Episode:159/300, Reward:9.0\n", "Episode:160/300, Reward:10.0\n", "Episode:161/300, Reward:9.0\n", "Episode:162/300, Reward:10.0\n", "Episode:163/300, Reward:9.0\n", "Episode:164/300, Reward:9.0\n", "Episode:165/300, Reward:10.0\n", "Episode:166/300, Reward:9.0\n", "Episode:167/300, Reward:9.0\n", "Episode:168/300, Reward:9.0\n", "Episode:169/300, Reward:9.0\n", "Episode:170/300, Reward:10.0\n", "Episode:171/300, Reward:9.0\n", "Episode:172/300, Reward:9.0\n", "Episode:173/300, Reward:11.0\n", "Episode:174/300, Reward:11.0\n", "Episode:175/300, Reward:10.0\n", "Episode:176/300, Reward:9.0\n", "Episode:177/300, Reward:10.0\n", "Episode:178/300, Reward:8.0\n", "Episode:179/300, Reward:9.0\n", "Episode:180/300, Reward:9.0\n", "Episode:181/300, Reward:10.0\n", "Episode:182/300, Reward:10.0\n", "Episode:183/300, Reward:9.0\n", "Episode:184/300, Reward:10.0\n", "Episode:185/300, Reward:10.0\n", "Episode:186/300, Reward:13.0\n", "Episode:187/300, Reward:16.0\n", "Episode:188/300, Reward:117.0\n", "Episode:189/300, Reward:13.0\n", "Episode:190/300, Reward:16.0\n", "Episode:191/300, Reward:11.0\n", "Episode:192/300, Reward:11.0\n", "Episode:193/300, Reward:13.0\n", "Episode:194/300, Reward:13.0\n", "Episode:195/300, Reward:9.0\n", "Episode:196/300, Reward:20.0\n", "Episode:197/300, Reward:12.0\n", "Episode:198/300, Reward:10.0\n", "Episode:199/300, Reward:14.0\n", "Episode:200/300, Reward:12.0\n", "Episode:201/300, Reward:14.0\n", "Episode:202/300, Reward:12.0\n", "Episode:203/300, Reward:11.0\n", "Episode:204/300, Reward:10.0\n", "Episode:205/300, Reward:13.0\n", "Episode:206/300, Reward:10.0\n", "Episode:207/300, Reward:10.0\n", "Episode:208/300, Reward:13.0\n", "Episode:209/300, Reward:9.0\n", "Episode:210/300, Reward:11.0\n", "Episode:211/300, Reward:14.0\n", "Episode:212/300, Reward:10.0\n", "Episode:213/300, Reward:20.0\n", "Episode:214/300, Reward:12.0\n", "Episode:215/300, Reward:13.0\n", "Episode:216/300, Reward:17.0\n", "Episode:217/300, Reward:17.0\n", "Episode:218/300, Reward:11.0\n", "Episode:219/300, Reward:15.0\n", "Episode:220/300, Reward:26.0\n", "Episode:221/300, Reward:73.0\n", "Episode:222/300, Reward:44.0\n", "Episode:223/300, Reward:48.0\n", "Episode:224/300, Reward:102.0\n", "Episode:225/300, Reward:162.0\n", "Episode:226/300, Reward:123.0\n", "Episode:227/300, Reward:200.0\n", "Episode:228/300, Reward:200.0\n", "Episode:229/300, Reward:120.0\n", "Episode:230/300, Reward:173.0\n", "Episode:231/300, Reward:138.0\n", "Episode:232/300, Reward:106.0\n", "Episode:233/300, Reward:193.0\n", "Episode:234/300, Reward:117.0\n", "Episode:235/300, Reward:120.0\n", "Episode:236/300, Reward:98.0\n", "Episode:237/300, Reward:98.0\n", "Episode:238/300, Reward:200.0\n", "Episode:239/300, Reward:96.0\n", "Episode:240/300, Reward:170.0\n", "Episode:241/300, Reward:107.0\n", "Episode:242/300, Reward:107.0\n", "Episode:243/300, Reward:200.0\n", "Episode:244/300, Reward:128.0\n", "Episode:245/300, Reward:165.0\n", "Episode:246/300, Reward:168.0\n", "Episode:247/300, Reward:200.0\n", "Episode:248/300, Reward:200.0\n", "Episode:249/300, Reward:200.0\n", "Episode:250/300, Reward:200.0\n", "Episode:251/300, Reward:200.0\n", "Episode:252/300, Reward:200.0\n", "Episode:253/300, Reward:200.0\n", "Episode:254/300, Reward:200.0\n", "Episode:255/300, Reward:200.0\n", "Episode:256/300, Reward:200.0\n", "Episode:257/300, Reward:164.0\n", "Episode:258/300, Reward:200.0\n", "Episode:259/300, Reward:190.0\n", "Episode:260/300, Reward:185.0\n", "Episode:261/300, Reward:200.0\n", "Episode:262/300, Reward:200.0\n", "Episode:263/300, Reward:200.0\n", "Episode:264/300, Reward:200.0\n", "Episode:265/300, Reward:168.0\n", "Episode:266/300, Reward:200.0\n", "Episode:267/300, Reward:200.0\n", "Episode:268/300, Reward:200.0\n", "Episode:269/300, Reward:200.0\n", "Episode:270/300, Reward:200.0\n", "Episode:271/300, Reward:200.0\n", "Episode:272/300, Reward:200.0\n", "Episode:273/300, Reward:200.0\n", "Episode:274/300, Reward:200.0\n", "Episode:275/300, Reward:188.0\n", "Episode:276/300, Reward:200.0\n", "Episode:277/300, Reward:177.0\n", "Episode:278/300, Reward:200.0\n", "Episode:279/300, Reward:200.0\n", "Episode:280/300, Reward:200.0\n", "Episode:281/300, Reward:200.0\n", "Episode:282/300, Reward:200.0\n", "Episode:283/300, Reward:200.0\n", "Episode:284/300, Reward:189.0\n", "Episode:285/300, Reward:200.0\n", "Episode:286/300, Reward:200.0\n", "Episode:287/300, Reward:200.0\n", "Episode:288/300, Reward:200.0\n", "Episode:289/300, Reward:200.0\n", "Episode:290/300, Reward:200.0\n", "Episode:291/300, Reward:200.0\n", "Episode:292/300, Reward:200.0\n", "Episode:293/300, Reward:200.0\n", "Episode:294/300, Reward:200.0\n", "Episode:295/300, Reward:200.0\n", "Episode:296/300, Reward:200.0\n", "Episode:297/300, Reward:200.0\n", "Episode:298/300, Reward:200.0\n", "Episode:299/300, Reward:200.0\n", "Episode:300/300, Reward:200.0\n", "Complete training!\n", "results saved!\n" ] }, { "output_type": "display_data", "data": { "text/plain": "
", "image/svg+xml": "\n\n\n\n \n \n \n \n 2021-03-31T14:01:15.395751\n image/svg+xml\n \n \n Matplotlib v3.3.4, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n", "image/png": "\n" }, "metadata": {} } ], "source": [ "env = gym.make('CartPole-v0')\n", "env.seed(1)\n", "cfg = HierarchicalDQNConfig()\n", "state_dim = env.observation_space.shape[0]\n", "action_dim = env.action_space.n\n", "agent = HierarchicalDQN(state_dim, action_dim, cfg)\n", "rewards, ma_rewards = train(cfg, env, agent)\n", "agent.save(path=SAVED_MODEL_PATH)\n", "save_results(rewards, ma_rewards, tag='train', path=RESULT_PATH)\n", "plot_rewards(rewards, ma_rewards, tag=\"train\",\n", " algo=cfg.algo, path=RESULT_PATH)" ] } ] }