{
"metadata": {
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.10"
},
"orig_nbformat": 2,
"kernelspec": {
"name": "python3710jvsc74a57bd0fd81e6a9e450d5c245c1a0b5da0b03c89c450f614a13afa2acb1654375922756",
"display_name": "Python 3.7.10 64-bit ('mujoco': conda)"
},
"metadata": {
"interpreter": {
"hash": "fd81e6a9e450d5c245c1a0b5da0b03c89c450f614a13afa2acb1654375922756"
}
}
},
"nbformat": 4,
"nbformat_minor": 2,
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"from pathlib import Path\n",
"curr_path = str(Path().absolute())\n",
"parent_path = str(Path().absolute().parent)\n",
"sys.path.append(parent_path)  # add the parent path to sys.path so the SAC and common packages can be imported"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import gym\n",
"import torch\n",
"import datetime\n",
"\n",
"from SAC.env import NormalizedActions\n",
"from SAC.agent import SAC\n",
"from common.utils import save_results, make_dir\n",
"from common.plot import plot_rewards\n",
"\n",
"curr_time = datetime.datetime.now().strftime(\"%Y%m%d-%H%M%S\") # obtain current time"
]
},
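{
"cell_type": "markdown",
"metadata": {},
"source": [
"`NormalizedActions` is imported from the `SAC` package, whose source is not shown in this notebook. A common way to write such a wrapper (a sketch under that assumption, not necessarily the code in `SAC/env.py`) is a `gym.ActionWrapper` that rescales actions from `[-1, 1]` to the environment's native bounds:\n",
"\n",
"```python\n",
"import gym\n",
"import numpy as np\n",
"\n",
"class NormalizedActions(gym.ActionWrapper):\n",
"    \"\"\"Rescale agent actions from [-1, 1] to the env's action bounds (hypothetical sketch).\"\"\"\n",
"    def action(self, action):\n",
"        low, high = self.action_space.low, self.action_space.high\n",
"        action = low + (action + 1.0) * 0.5 * (high - low)  # map [-1, 1] -> [low, high]\n",
"        return np.clip(action, low, high)\n",
"\n",
"    def reverse_action(self, action):\n",
"        low, high = self.action_space.low, self.action_space.high\n",
"        action = 2.0 * (action - low) / (high - low) - 1.0  # map [low, high] -> [-1, 1]\n",
"        return np.clip(action, -1.0, 1.0)\n",
"```"
]
},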
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"class SACConfig:\n",
"    def __init__(self) -> None:\n",
"        self.algo = 'SAC'\n",
"        self.env = 'Pendulum-v0'\n",
"        self.result_path = curr_path + \"/outputs/\" + self.env + '/' + curr_time + '/results/'  # path to save results\n",
"        self.model_path = curr_path + \"/outputs/\" + self.env + '/' + curr_time + '/models/'  # path to save models\n",
"        self.train_eps = 300  # number of training episodes\n",
"        self.train_steps = 500  # max steps per training episode\n",
"        self.eval_eps = 50  # number of evaluation episodes\n",
"        self.eval_steps = 500  # max steps per evaluation episode\n",
"        self.gamma = 0.99  # discount factor\n",
"        self.mean_lambda = 1e-3  # weight of the policy-mean regularizer\n",
"        self.std_lambda = 1e-3  # weight of the policy log-std regularizer\n",
"        self.z_lambda = 0.0  # weight of the pre-tanh sample regularizer\n",
"        self.soft_tau = 1e-2  # soft (Polyak) update coefficient for the target network\n",
"        self.value_lr = 3e-4  # learning rate of the value network\n",
"        self.soft_q_lr = 3e-4  # learning rate of the soft Q network\n",
"        self.policy_lr = 3e-4  # learning rate of the policy network\n",
"        self.capacity = 1000000  # replay buffer capacity\n",
"        self.hidden_dim = 256  # hidden layer size of the networks\n",
"        self.batch_size = 128\n",
"        self.device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
]
},
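{
"cell_type": "markdown",
"metadata": {},
"source": [
"`soft_tau` is the Polyak averaging coefficient SAC typically uses to blend the online value network into its target network. A minimal sketch of that soft update, assuming the agent keeps a `value_net`/`target_value_net` pair (the actual logic lives in `SAC/agent.py`):\n",
"\n",
"```python\n",
"def soft_update(net, target_net, soft_tau=1e-2):\n",
"    \"\"\"Polyak-average online parameters into the target network (hypothetical helper).\"\"\"\n",
"    for param, target_param in zip(net.parameters(), target_net.parameters()):\n",
"        target_param.data.copy_(target_param.data * (1.0 - soft_tau) + param.data * soft_tau)\n",
"```"
]
},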
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"def env_agent_config(cfg, seed=1):\n",
"    env = NormalizedActions(gym.make(cfg.env))  # use the env name from the config instead of a hard-coded string\n",
"    env.seed(seed)\n",
"    action_dim = env.action_space.shape[0]\n",
"    state_dim = env.observation_space.shape[0]\n",
"    agent = SAC(state_dim, action_dim, cfg)\n",
"    return env, agent"
]
},
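{
"cell_type": "markdown",
"metadata": {},
"source": [
"The agent built here is later fed transitions through `agent.memory.push(...)`, which suggests a standard replay buffer sized by `cfg.capacity`. A minimal sketch of such a buffer (an assumption; the real class ships with the `SAC` package):\n",
"\n",
"```python\n",
"import random\n",
"from collections import deque\n",
"\n",
"class ReplayBuffer:\n",
"    \"\"\"FIFO experience replay (hypothetical stand-in for agent.memory).\"\"\"\n",
"    def __init__(self, capacity):\n",
"        self.buffer = deque(maxlen=capacity)\n",
"\n",
"    def push(self, state, action, reward, next_state, done):\n",
"        self.buffer.append((state, action, reward, next_state, done))\n",
"\n",
"    def sample(self, batch_size):\n",
"        batch = random.sample(self.buffer, batch_size)\n",
"        states, actions, rewards, next_states, dones = map(list, zip(*batch))\n",
"        return states, actions, rewards, next_states, dones\n",
"\n",
"    def __len__(self):\n",
"        return len(self.buffer)\n",
"```"
]
},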
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"def train(cfg, env, agent):\n",
"    print('Start training!')\n",
"    print(f'Env: {cfg.env}, Algorithm: {cfg.algo}, Device: {cfg.device}')\n",
"    rewards = []\n",
"    ma_rewards = []  # moving average of episode rewards\n",
"    for i_ep in range(cfg.train_eps):\n",
"        state = env.reset()\n",
"        ep_reward = 0\n",
"        for i_step in range(cfg.train_steps):\n",
"            action = agent.policy_net.get_action(state)\n",
"            next_state, reward, done, _ = env.step(action)\n",
"            agent.memory.push(state, action, reward, next_state, done)\n",
"            agent.update()\n",
"            state = next_state\n",
"            ep_reward += reward\n",
"            if done:\n",
"                break\n",
"        if (i_ep+1) % 10 == 0:\n",
"            print(f\"Episode:{i_ep+1}/{cfg.train_eps}, Reward:{ep_reward:.3f}\")\n",
"        rewards.append(ep_reward)\n",
"        if ma_rewards:\n",
"            ma_rewards.append(0.9*ma_rewards[-1] + 0.1*ep_reward)\n",
"        else:\n",
"            ma_rewards.append(ep_reward)\n",
"    print('Training complete!')\n",
"    return rewards, ma_rewards"
]
},
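{
"cell_type": "markdown",
"metadata": {},
"source": [
"`agent.update()` performs the SAC gradient step. In the value-network style of SAC that `mean_lambda`, `std_lambda` and `z_lambda` hint at, those weights penalize the policy's pre-squash mean, log-std and sampled `z`. A self-contained toy illustration of those regularizers (an assumption about `SAC/agent.py`, shown only for intuition):\n",
"\n",
"```python\n",
"import torch\n",
"\n",
"mean, log_std = torch.randn(128, 1), torch.randn(128, 1)\n",
"z = mean + log_std.exp() * torch.randn_like(mean)  # reparameterized pre-tanh sample\n",
"reg_loss = (1e-3 * mean.pow(2).mean()        # mean_lambda term\n",
"            + 1e-3 * log_std.pow(2).mean()   # std_lambda term\n",
"            + 0.0 * z.pow(2).sum(1).mean())  # z_lambda term\n",
"```"
]
},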
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"def eval(cfg, env, agent):\n",
"    print('Start evaluating!')\n",
"    print(f'Env: {cfg.env}, Algorithm: {cfg.algo}, Device: {cfg.device}')\n",
"    rewards = []\n",
"    ma_rewards = []  # moving average of episode rewards\n",
"    for i_ep in range(cfg.eval_eps):\n",
"        state = env.reset()\n",
"        ep_reward = 0\n",
"        for i_step in range(cfg.eval_steps):\n",
"            action = agent.policy_net.get_action(state)\n",
"            next_state, reward, done, _ = env.step(action)\n",
"            state = next_state\n",
"            ep_reward += reward\n",
"            if done:\n",
"                break\n",
"        if (i_ep+1) % 10 == 0:\n",
"            print(f\"Episode:{i_ep+1}/{cfg.eval_eps}, Reward:{ep_reward:.3f}\")\n",
"        rewards.append(ep_reward)\n",
"        if ma_rewards:\n",
"            ma_rewards.append(0.9*ma_rewards[-1] + 0.1*ep_reward)\n",
"        else:\n",
"            ma_rewards.append(ep_reward)\n",
"    print('Evaluation complete!')\n",
"    return rewards, ma_rewards\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"if __name__ == \"__main__\":\n",
"    cfg = SACConfig()\n",
"\n",
"    # train\n",
"    env, agent = env_agent_config(cfg, seed=1)\n",
"    rewards, ma_rewards = train(cfg, env, agent)\n",
"    make_dir(cfg.result_path, cfg.model_path)\n",
"    agent.save(path=cfg.model_path)\n",
"    save_results(rewards, ma_rewards, tag='train', path=cfg.result_path)\n",
"    plot_rewards(rewards, ma_rewards, tag=\"train\", env=cfg.env, algo=cfg.algo, path=cfg.result_path)\n",
"\n",
"    # eval\n",
"    env, agent = env_agent_config(cfg, seed=10)\n",
"    agent.load(path=cfg.model_path)\n",
"    rewards, ma_rewards = eval(cfg, env, agent)\n",
"    save_results(rewards, ma_rewards, tag='eval', path=cfg.result_path)\n",
"    plot_rewards(rewards, ma_rewards, tag=\"eval\", env=cfg.env, algo=cfg.algo, path=cfg.result_path)\n"
]
}
]
}