{
 "metadata": {
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.10"
  },
  "orig_nbformat": 2,
  "kernelspec": {
   "name": "python3710jvsc74a57bd0366e1054dee9d4501b0eb8f87335afd3c67fc62db6ee611bbc7f8f5a1fefe232",
   "display_name": "Python 3.7.10 64-bit ('py37': conda)"
  },
  "metadata": {
   "interpreter": {
    "hash": "366e1054dee9d4501b0eb8f87335afd3c67fc62db6ee611bbc7f8f5a1fefe232"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2,
"cells": [
|
||
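  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This notebook trains and then evaluates a **Double DQN** agent on gym's `CartPole-v0`. It assumes the repository layout implied by the imports: a `DoubleDQN` package (providing `agent.DoubleDQN`) and a `common` package (providing `plot` and `utils`) in the parent directory, which is why the first cell appends the parent path to `sys.path`."
   ]
  },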
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "from pathlib import Path\n",
    "curr_path = str(Path().absolute())\n",
    "parent_path = str(Path().absolute().parent)\n",
    "sys.path.append(parent_path)  # add parent path to sys.path so the DoubleDQN and common packages can be imported"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import gym\n",
    "import torch\n",
    "import datetime\n",
    "from DoubleDQN.agent import DoubleDQN\n",
    "from common.plot import plot_rewards\n",
    "from common.utils import save_results, make_dir\n",
    "\n",
    "curr_time = datetime.datetime.now().strftime(\n",
    "    \"%Y%m%d-%H%M%S\")  # current time, used to stamp the output directories"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class DoubleDQNConfig:\n",
    "    def __init__(self):\n",
    "        self.algo = \"DoubleDQN\"  # name of the algorithm\n",
    "        self.env = 'CartPole-v0'  # env name\n",
    "        self.result_path = curr_path + \"/outputs/\" + self.env + \\\n",
    "            '/' + curr_time + '/results/'  # path to save results\n",
    "        self.model_path = curr_path + \"/outputs/\" + self.env + \\\n",
    "            '/' + curr_time + '/models/'  # path to save models\n",
    "        self.train_eps = 200  # max training episodes\n",
    "        self.eval_eps = 50  # max evaluation episodes\n",
    "        self.gamma = 0.95  # discount factor\n",
    "        self.epsilon_start = 1  # start epsilon of e-greedy policy\n",
    "        self.epsilon_end = 0.01  # final epsilon after decay\n",
    "        self.epsilon_decay = 500  # decay rate of epsilon\n",
    "        self.lr = 0.001  # learning rate\n",
    "        self.memory_capacity = 100000  # capacity of replay memory\n",
    "        self.batch_size = 64\n",
    "        self.target_update = 2  # update frequency (in episodes) of target net\n",
    "        self.device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")  # use GPU if available\n",
    "        self.hidden_dim = 256  # hidden layer size of the network"
   ]
  },
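  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A note on the exploration parameters: `epsilon_start`, `epsilon_end`, and `epsilon_decay` suggest the agent anneals its e-greedy exploration exponentially with the step count, presumably something like\n",
    "\n",
    "$$\\epsilon_t = \\epsilon_{end} + (\\epsilon_{start} - \\epsilon_{end}) \\, e^{-t/\\text{decay}},$$\n",
    "\n",
    "so exploration falls from 1 toward 0.01 over roughly `epsilon_decay` steps. The exact schedule lives in `DoubleDQN.agent` and is not shown in this notebook."
   ]
  },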
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def env_agent_config(cfg, seed=1):\n",
    "    env = gym.make(cfg.env)\n",
    "    env.seed(seed)  # seed the env for reproducibility\n",
    "    state_dim = env.observation_space.shape[0]\n",
    "    action_dim = env.action_space.n\n",
    "    agent = DoubleDQN(state_dim, action_dim, cfg)\n",
    "    return env, agent"
   ]
  },
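  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note that `env.seed(seed)` and the four-value return of `env.step` target the classic gym API (pre-0.26). On newer gym/gymnasium releases, seeding moved into `env.reset(seed=...)` and `step` returns five values, so this notebook should be run with an older gym matching the pinned Python 3.7 environment."
   ]
  },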
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(cfg, env, agent):\n",
    "    print('Start training!')\n",
    "    print(f'Env:{cfg.env}, Algorithm:{cfg.algo}, Device:{cfg.device}')\n",
    "    rewards, ma_rewards = [], []  # raw episode rewards and their moving average\n",
    "    for i_ep in range(cfg.train_eps):\n",
    "        state = env.reset()\n",
    "        ep_reward = 0\n",
    "        while True:\n",
    "            action = agent.choose_action(state)  # e-greedy action\n",
    "            next_state, reward, done, _ = env.step(action)\n",
    "            ep_reward += reward\n",
    "            agent.memory.push(state, action, reward, next_state, done)  # store transition\n",
    "            state = next_state\n",
    "            agent.update()  # one learning step on a sampled batch\n",
    "            if done:\n",
    "                break\n",
    "        if i_ep % cfg.target_update == 0:  # periodically sync target net with policy net\n",
    "            agent.target_net.load_state_dict(agent.policy_net.state_dict())\n",
    "        if (i_ep + 1) % 10 == 0:\n",
    "            print(f'Episode:{i_ep+1}/{cfg.train_eps}, Reward:{ep_reward}')\n",
    "        rewards.append(ep_reward)\n",
    "        if ma_rewards:\n",
    "            ma_rewards.append(0.9 * ma_rewards[-1] + 0.1 * ep_reward)\n",
    "        else:\n",
    "            ma_rewards.append(ep_reward)\n",
    "    print('Training complete!')\n",
    "    return rewards, ma_rewards"
   ]
  },
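  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The Double DQN idea itself is hidden inside `agent.update()`: instead of the vanilla DQN target $r + \\gamma \\max_a Q_{target}(s', a)$, Double DQN decouples action *selection* from action *evaluation*,\n",
    "\n",
    "$$y = r + \\gamma \\, Q_{target}\\big(s', \\arg\\max_a Q_{policy}(s', a)\\big),$$\n",
    "\n",
    "which reduces the overestimation bias of the max operator. The cell below is a minimal, self-contained sketch of that target computation, assuming batched tensors and a 0/1 `done` mask; the actual implementation lives in `DoubleDQN.agent` and may differ in details. The `ma_rewards` series returned by `train` is simply an exponential moving average, $ma_t = 0.9 \\, ma_{t-1} + 0.1 \\, r_t$, kept for smoother reward curves."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative sketch only: what agent.update() is assumed to compute as its\n",
    "# regression target (function and argument names here are hypothetical).\n",
    "def double_dqn_target(policy_net, target_net, reward_batch, next_state_batch,\n",
    "                      done_batch, gamma):\n",
    "    with torch.no_grad():\n",
    "        # select the next action with the online (policy) network ...\n",
    "        next_actions = policy_net(next_state_batch).argmax(dim=1, keepdim=True)\n",
    "        # ... but evaluate it with the target network\n",
    "        next_q = target_net(next_state_batch).gather(1, next_actions).squeeze(1)\n",
    "        # terminal transitions get no bootstrapped value\n",
    "        return reward_batch + gamma * next_q * (1 - done_batch)"
   ]
  },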
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def eval(cfg, env, agent):\n",
    "    print('Start evaluating!')\n",
    "    print(f'Env:{cfg.env}, Algorithm:{cfg.algo}, Device:{cfg.device}')\n",
    "    rewards = []\n",
    "    ma_rewards = []\n",
    "    for i_ep in range(cfg.eval_eps):\n",
    "        state = env.reset()\n",
    "        ep_reward = 0\n",
    "        while True:\n",
    "            action = agent.predict(state)\n",
    "            next_state, reward, done, _ = env.step(action)\n",
    "            state = next_state\n",
    "            ep_reward += reward\n",
    "            if done:\n",
    "                break\n",
    "        rewards.append(ep_reward)\n",
    "        if ma_rewards:\n",
    "            ma_rewards.append(0.9 * ma_rewards[-1] + 0.1 * ep_reward)\n",
    "        else:\n",
    "            ma_rewards.append(ep_reward)\n",
    "        print(f\"Episode:{i_ep+1}/{cfg.eval_eps}, Reward:{ep_reward:.1f}\")\n",
    "    print('Evaluation complete!')\n",
    "    return rewards, ma_rewards"
   ]
  },
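  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Evaluation calls `agent.predict(state)` rather than `agent.choose_action(state)`; presumably `predict` always takes the greedy (highest-Q) action while `choose_action` is e-greedy during training. Also note that `eval` here shadows Python's built-in `eval`, which is harmless in this script but worth knowing."
   ]
  },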
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if __name__ == \"__main__\":\n",
    "    cfg = DoubleDQNConfig()\n",
    "\n",
    "    # train\n",
    "    env, agent = env_agent_config(cfg, seed=1)\n",
    "    rewards, ma_rewards = train(cfg, env, agent)\n",
    "    make_dir(cfg.result_path, cfg.model_path)\n",
    "    agent.save(path=cfg.model_path)\n",
    "    save_results(rewards, ma_rewards, tag='train', path=cfg.result_path)\n",
    "    plot_rewards(rewards, ma_rewards, tag=\"train\", env=cfg.env,\n",
    "                 algo=cfg.algo, path=cfg.result_path)\n",
    "\n",
    "    # eval with a different seed and the saved model\n",
    "    env, agent = env_agent_config(cfg, seed=10)\n",
    "    agent.load(path=cfg.model_path)\n",
    "    rewards, ma_rewards = eval(cfg, env, agent)\n",
    "    save_results(rewards, ma_rewards, tag='eval', path=cfg.result_path)\n",
    "    plot_rewards(rewards, ma_rewards, tag=\"eval\", env=cfg.env,\n",
    "                 algo=cfg.algo, path=cfg.result_path)"
   ]
  }
 ]
}