{
"metadata": {
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.10"
},
"orig_nbformat": 2,
"kernelspec": {
"name": "python3710jvsc74a57bd0366e1054dee9d4501b0eb8f87335afd3c67fc62db6ee611bbc7f8f5a1fefe232",
"display_name": "Python 3.7.10 64-bit ('py37': conda)"
},
"metadata": {
"interpreter": {
"hash": "366e1054dee9d4501b0eb8f87335afd3c67fc62db6ee611bbc7f8f5a1fefe232"
}
}
},
"nbformat": 4,
"nbformat_minor": 2,
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"from pathlib import Path\n",
"curr_path = str(Path().absolute())\n",
"parent_path = str(Path().absolute().parent)\n",
"sys.path.append(parent_path) # add current terminal path to sys.path"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import gym\n",
"import torch\n",
"import datetime\n",
"from DoubleDQN.agent import DoubleDQN\n",
"from common.plot import plot_rewards\n",
"from common.utils import save_results, make_dir\n",
"\n",
"curr_time = datetime.datetime.now().strftime(\n",
" \"%Y%m%d-%H%M%S\") # obtain current time"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class DoubleDQNConfig:\n",
" def __init__(self):\n",
" self.algo = \"DoubleDQN\" # name of algo\n",
" self.env = 'CartPole-v0' # env name\n",
" self.result_path = curr_path+\"/outputs/\" + self.env + \\\n",
" '/'+curr_time+'/results/' # path to save results\n",
" self.model_path = curr_path+\"/outputs/\" + self.env + \\\n",
" '/'+curr_time+'/models/' # path to save models\n",
" self.train_eps = 200 # max tranng episodes\n",
" self.eval_eps = 50 # max evaling episodes\n",
" self.gamma = 0.95\n",
" self.epsilon_start = 1 # start epsilon of e-greedy policy\n",
" self.epsilon_end = 0.01 \n",
" self.epsilon_decay = 500\n",
" self.lr = 0.001 # learning rate\n",
" self.memory_capacity = 100000 # capacity of Replay Memory\n",
" self.batch_size = 64\n",
" self.target_update = 2 # update frequency of target net\n",
" self.device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\") # check gpu\n",
" self.hidden_dim = 256 # hidden size of net"
]
},
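{
"cell_type": "markdown",
"metadata": {},
"source": [
"The three `epsilon_*` parameters above drive the epsilon-greedy exploration schedule. The exact schedule lives in `DoubleDQN/agent.py`; the next cell is only a minimal sketch of the exponential decay these parameters are typically plugged into, assuming `epsilon_decay` acts as a time constant measured in environment steps."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrative sketch only (not used by the agent): anneal epsilon exponentially\n",
"# from epsilon_start down to epsilon_end with time constant epsilon_decay.\n",
"import math\n",
"\n",
"def epsilon_by_frame(frame_idx, start=1.0, end=0.01, decay=500):\n",
"    return end + (start - end) * math.exp(-frame_idx / decay)\n",
"\n",
"# e.g. epsilon_by_frame(0) -> 1.0, epsilon_by_frame(500) -> ~0.37, epsilon_by_frame(5000) -> ~0.01"
]
},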
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def env_agent_config(cfg,seed=1):\n",
" env = gym.make(cfg.env) \n",
" env.seed(seed)\n",
" state_dim = env.observation_space.shape[0]\n",
" action_dim = env.action_space.n\n",
" agent = DoubleDQN(state_dim,action_dim,cfg)\n",
" return env,agent"
]
},
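{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note that `env.seed(seed)` and the 4-tuple returned by `env.step()` below follow the classic Gym API (gym < 0.26), which this notebook assumes. The next cell is a reference-only sketch of the equivalent calls on Gymnasium / newer gym, in case the environment setup fails on a more recent installation."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Reference sketch only: the same setup on Gymnasium (or gym >= 0.26), where\n",
"# seeding moved into reset() and step() returns five values instead of four.\n",
"try:\n",
"    import gymnasium as new_gym  # optional; not a dependency of this repo\n",
"    demo_env = new_gym.make('CartPole-v1')\n",
"    state, info = demo_env.reset(seed=1)\n",
"    next_state, reward, terminated, truncated, info = demo_env.step(demo_env.action_space.sample())\n",
"    done = terminated or truncated\n",
"except ImportError:\n",
"    pass  # fall back to the classic gym API used throughout this notebook"
]
},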
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def train(cfg,env,agent):\n",
" print('Start to train !')\n",
" rewards,ma_rewards = [],[]\n",
" for i_ep in range(cfg.train_eps):\n",
" state = env.reset() \n",
" ep_reward = 0\n",
" while True:\n",
" action = agent.choose_action(state) \n",
" next_state, reward, done, _ = env.step(action)\n",
" ep_reward += reward\n",
" agent.memory.push(state, action, reward, next_state, done) \n",
" state = next_state \n",
" agent.update() \n",
" if done:\n",
" break\n",
" if i_ep % cfg.target_update == 0:\n",
" agent.target_net.load_state_dict(agent.policy_net.state_dict())\n",
" if (i_ep+1)%10 == 0:\n",
" print(f'Episode:{i_ep+1}/{cfg.train_eps}, Reward:{ep_reward}')\n",
" rewards.append(ep_reward)\n",
" if ma_rewards:\n",
" ma_rewards.append(\n",
" 0.9*ma_rewards[-1]+0.1*ep_reward)\n",
" else:\n",
" ma_rewards.append(ep_reward) \n",
" print('Complete training')\n",
" return rewards,ma_rewards"
]
},
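{
"cell_type": "markdown",
"metadata": {},
"source": [
"`train()` delegates the learning step to `agent.update()` and only syncs `target_net` with `policy_net` every `cfg.target_update` episodes. The cell below is an illustrative sketch of the Double DQN target such an `update()` typically computes (the actual implementation lives in `DoubleDQN/agent.py` and may differ in detail): the policy net selects the greedy next action, while the target net evaluates it."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrative sketch of the Double DQN target (an assumption about agent.update()):\n",
"# r + gamma * Q_target(s', argmax_a Q_policy(s', a)) for non-terminal s'.\n",
"def double_dqn_target(policy_net, target_net, reward, next_state, done, gamma=0.95):\n",
"    with torch.no_grad():\n",
"        next_action = policy_net(next_state).argmax(dim=1, keepdim=True)   # action chosen by the policy net\n",
"        next_q = target_net(next_state).gather(1, next_action).squeeze(1)  # value estimated by the target net\n",
"    return reward + gamma * next_q * (1 - done)"
]
},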
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def eval(cfg,env,agent):\n",
" print('Start to eval !')\n",
" print(f'Env:{cfg.env}, Algorithm:{cfg.algo}, Device:{cfg.device}')\n",
" rewards = [] \n",
" ma_rewards = []\n",
" for i_ep in range(cfg.eval_eps):\n",
" state = env.reset() \n",
" ep_reward = 0 \n",
" while True:\n",
" action = agent.predict(state) \n",
" next_state, reward, done, _ = env.step(action) \n",
" state = next_state \n",
" ep_reward += reward\n",
" if done:\n",
" break\n",
" rewards.append(ep_reward)\n",
" if ma_rewards:\n",
" ma_rewards.append(ma_rewards[-1]*0.9+ep_reward*0.1)\n",
" else:\n",
" ma_rewards.append(ep_reward)\n",
" print(f\"Episode:{i_ep+1}/{cfg.eval_eps}, reward:{ep_reward:.1f}\")\n",
" print('Complete evaling')\n",
" return rewards,ma_rewards "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"if __name__ == \"__main__\":\n",
" cfg = DoubleDQNConfig()\n",
" # train\n",
" env,agent = env_agent_config(cfg,seed=1)\n",
" rewards, ma_rewards = train(cfg, env, agent)\n",
" make_dir(cfg.result_path, cfg.model_path)\n",
" agent.save(path=cfg.model_path)\n",
" save_results(rewards, ma_rewards, tag='train', path=cfg.result_path)\n",
" plot_rewards(rewards, ma_rewards, tag=\"train\",\n",
" algo=cfg.algo, path=cfg.result_path)\n",
"\n",
" # eval\n",
" env,agent = env_agent_config(cfg,seed=10)\n",
" agent.load(path=cfg.model_path)\n",
" rewards,ma_rewards = eval(cfg,env,agent)\n",
" save_results(rewards,ma_rewards,tag='eval',path=cfg.result_path)\n",
" plot_rewards(rewards,ma_rewards,tag=\"eval\",env=cfg.env,algo = cfg.algo,path=cfg.result_path)"
]
}
]
}