class DQNConfig:
    """Hyper-parameters and output locations for the DQN CartPole run.

    The output directories are built from the notebook-level ``curr_path`` and
    ``curr_time`` values defined in the setup cells, so each run writes into
    its own timestamped folder.
    """

    def __init__(self):
        self.algo = "DQN"  # name of algo (used in logs and plots)
        self.env = 'CartPole-v0'  # gym environment id
        # Both output folders share the same timestamped run directory.
        run_dir = curr_path + "/outputs/" + self.env + '/' + curr_time
        self.result_path = run_dir + '/results/'  # path to save results
        self.model_path = run_dir + '/models/'  # path to save models
        self.train_eps = 300  # max training episodes
        self.eval_eps = 50  # number of episodes for evaluating
        self.gamma = 0.95  # discount factor (consumed by the agent)
        self.epsilon_start = 0.90  # start epsilon of e-greedy policy
        self.epsilon_end = 0.01  # final epsilon of e-greedy policy
        self.epsilon_decay = 500  # decay constant; schedule is defined by the agent
        self.lr = 0.0001  # learning rate
        self.memory_capacity = 100000  # capacity of Replay Memory
        self.batch_size = 64  # minibatch size (presumably for replay sampling)
        self.target_update = 2  # sync the target net every N training episodes
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")  # check gpu
        self.hidden_dim = 256  # hidden size of net


def env_agent_config(cfg, seed=1):
    """Build the gym environment named by ``cfg.env`` and a matching DQN agent.

    Args:
        cfg: configuration object; ``cfg.env`` is the gym id and the whole
            object is forwarded to ``DQN``.
        seed: seed applied to the environment and to torch's global RNG.

    Returns:
        ``(env, agent)`` tuple.
    """
    # Seed torch before the agent is constructed so that whatever torch-based
    # initialisation DQN performs is repeatable for a given seed.
    torch.manual_seed(seed)
    env = gym.make(cfg.env)
    env.seed(seed)
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.n
    agent = DQN(state_dim, action_dim, cfg)
    return env, agent


def _append_moving_average(ma_rewards, ep_reward):
    """Append the smoothed reward (0.9 * previous + 0.1 * new) to ``ma_rewards`` in place."""
    if ma_rewards:
        ma_rewards.append(0.9 * ma_rewards[-1] + 0.1 * ep_reward)
    else:
        # The first episode has nothing to smooth against.
        ma_rewards.append(ep_reward)


def train(cfg, env, agent):
    """Train ``agent`` on ``env`` for ``cfg.train_eps`` episodes.

    Every transition is pushed to ``agent.memory`` and followed by one
    ``agent.update()`` call. The target network is overwritten with the policy
    network's weights at the end of every episode whose index is a multiple of
    ``cfg.target_update``.

    Returns:
        ``(rewards, ma_rewards)``: per-episode total reward and its moving average.
    """
    print('Start to train !')
    print(f'Env:{cfg.env}, Algorithm:{cfg.algo}, Device:{cfg.device}')
    rewards = []
    ma_rewards = []  # moving average reward
    for i_ep in range(cfg.train_eps):
        state = env.reset()
        ep_reward = 0
        done = False
        while not done:
            action = agent.choose_action(state)
            next_state, reward, done, _ = env.step(action)
            ep_reward += reward
            agent.memory.push(state, action, reward, next_state, done)
            state = next_state
            # One learning step per environment step, including the terminal one.
            agent.update()
        if i_ep % cfg.target_update == 0:
            agent.target_net.load_state_dict(agent.policy_net.state_dict())
        if (i_ep + 1) % 10 == 0:
            print('Episode:{}/{}, Reward:{}'.format(i_ep+1, cfg.train_eps, ep_reward))
        rewards.append(ep_reward)
        _append_moving_average(ma_rewards, ep_reward)
    print('Complete training!')
    return rewards, ma_rewards


# NOTE(review): this name shadows the builtin ``eval`` for the rest of the
# notebook; it is kept because callers already use it.
def eval(cfg, env, agent):
    """Run ``agent.predict`` on ``env`` for ``cfg.eval_eps`` episodes without learning.

    Nothing is pushed to the replay memory and ``agent.update()`` is never called.

    Returns:
        ``(rewards, ma_rewards)``: per-episode total reward and its moving average.
    """
    rewards = []
    ma_rewards = []  # moving average rewards
    for i_ep in range(cfg.eval_eps):
        ep_reward = 0  # reward per episode
        state = env.reset()
        done = False
        while not done:
            action = agent.predict(state)
            next_state, reward, done, _ = env.step(action)
            state = next_state
            ep_reward += reward
        rewards.append(ep_reward)
        _append_moving_average(ma_rewards, ep_reward)
        if (i_ep + 1) % 10 == 0:
            print(f"Episode:{i_ep+1}/{cfg.eval_eps}, reward:{ep_reward:.1f}")
    return rewards, ma_rewards


if __name__ == "__main__":
    cfg = DQNConfig()

    # train: fit the agent, then persist the weights and the reward curves
    env, agent = env_agent_config(cfg, seed=1)
    rewards, ma_rewards = train(cfg, env, agent)
    make_dir(cfg.result_path, cfg.model_path)
    agent.save(path=cfg.model_path)
    save_results(rewards, ma_rewards, tag='train', path=cfg.result_path)
    plot_rewards(rewards, ma_rewards, tag="train",
                 algo=cfg.algo, path=cfg.result_path)

    # eval: rebuild env/agent with a different seed and reload the saved weights
    env, agent = env_agent_config(cfg, seed=10)
    agent.load(path=cfg.model_path)
    rewards, ma_rewards = eval(cfg, env, agent)
    save_results(rewards, ma_rewards, tag='eval', path=cfg.result_path)
    plot_rewards(rewards, ma_rewards, tag="eval", env=cfg.env,
                 algo=cfg.algo, path=cfg.result_path)