Files
easy-rl/codes/SAC/task0_train.ipynb
johnjim0816 75df999258 update
2021-12-22 11:19:13 +08:00

222 lines
14 KiB
Plaintext
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"from pathlib import Path\n",
"curr_path = str(Path().absolute())  # directory the notebook is run from\n",
"parent_path = str(Path().absolute().parent)  # its parent directory\n",
"# Put the PARENT directory (not the current one) on sys.path so the project-local\n",
"# `SAC.*` and `common.*` imports in the next cell can resolve. The membership check\n",
"# keeps the cell idempotent: re-running it no longer appends duplicate entries.\n",
"if parent_path not in sys.path:\n",
"    sys.path.append(parent_path)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import gym\n",
"import torch\n",
"import datetime\n",
"\n",
"from SAC.env import NormalizedActions\n",
"from SAC.agent import SAC\n",
"from common.utils import save_results, make_dir\n",
"from common.plot import plot_rewards\n",
"\n",
"curr_time = f\"{datetime.datetime.now():%Y%m%d-%H%M%S}\"  # run timestamp, formatted YYYYMMDD-HHMMSS"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"class SACConfig:\n",
"    \"\"\"Settings for this notebook's SAC run: labels, output paths, schedule and hyper-parameters.\n",
"\n",
"    The whole object is handed to ``SAC(state_dim, action_dim, cfg)`` and to the\n",
"    ``train`` / ``eval`` loops; how the agent consumes the individual\n",
"    hyper-parameters is defined in ``SAC/agent.py``, which is not part of this notebook.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self) -> None:\n",
"        self.algo = 'SAC'  # algorithm label shown in log lines and plots\n",
"        # 'Pendulum-v0' is rejected by the gym version this notebook was last executed\n",
"        # with (DeprecatedEnv: valid versions include ['Pendulum-v1']), so target v1.\n",
"        self.env = 'Pendulum-v1'\n",
"        self.result_path = curr_path+\"/outputs/\" +self.env+'/'+curr_time+'/results/'  # path to save results\n",
"        self.model_path = curr_path+\"/outputs/\" +self.env+'/'+curr_time+'/models/'  # path to save models\n",
"        # Schedule: number of episodes and per-episode step cap for train() and eval().\n",
"        self.train_eps = 300\n",
"        self.train_steps = 500\n",
"        self.test_eps = 50\n",
"        self.eval_steps = 500\n",
"        # Agent hyper-parameters (read by the SAC implementation, not by this notebook).\n",
"        self.gamma = 0.99\n",
"        self.mean_lambda=1e-3\n",
"        self.std_lambda=1e-3\n",
"        self.z_lambda=0.0\n",
"        self.soft_tau=1e-2\n",
"        self.value_lr = 3e-4\n",
"        self.soft_q_lr = 3e-4\n",
"        self.policy_lr = 3e-4\n",
"        self.capacity = 1000000\n",
"        self.hidden_dim = 256\n",
"        self.batch_size = 128\n",
"        # Use the GPU when torch can see one, otherwise fall back to the CPU.\n",
"        self.device=torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"def env_agent_config(cfg,seed=1):\n",
"    \"\"\"Create the gym environment (wrapped in ``NormalizedActions``) and a new SAC agent.\n",
"\n",
"    Args:\n",
"        cfg: config object; ``cfg.env`` is the gym environment id and the whole\n",
"            object is forwarded to ``SAC``.\n",
"        seed: value passed to ``env.seed``.\n",
"\n",
"    Returns:\n",
"        The ``(env, agent)`` pair.\n",
"    \"\"\"\n",
"    # Build the environment from cfg.env rather than a hard-coded id, so the name\n",
"    # printed in the logs, the output paths derived from it and the environment\n",
"    # that is actually run cannot drift apart.\n",
"    env = NormalizedActions(gym.make(cfg.env))\n",
"    env.seed(seed)\n",
"    # Sizes are taken from the first dimension of the action / observation spaces.\n",
"    action_dim = env.action_space.shape[0]\n",
"    state_dim = env.observation_space.shape[0]\n",
"    agent = SAC(state_dim,action_dim,cfg)\n",
"    return env,agent"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"def train(cfg,env,agent):\n",
"    \"\"\"Run ``cfg.train_eps`` training episodes, updating the agent after every step.\n",
"\n",
"    Each episode lasts at most ``cfg.train_steps`` steps and stops early when the\n",
"    environment reports ``done``. Every transition is pushed to ``agent.memory``\n",
"    and followed by one ``agent.update()`` call.\n",
"\n",
"    Returns:\n",
"        ``(rewards, ma_rewards)``: per-episode reward sums and their moving average.\n",
"    \"\"\"\n",
"    print('Start to train !')\n",
"    print(f'Env: {cfg.env}, Algorithm: {cfg.algo}, Device: {cfg.device}')\n",
"    rewards, ma_rewards = [], []  # episode rewards and their moving average\n",
"    for i_ep in range(cfg.train_eps):\n",
"        obs = env.reset()\n",
"        ep_reward = 0\n",
"        for _step in range(cfg.train_steps):\n",
"            act = agent.policy_net.get_action(obs)\n",
"            next_obs, step_reward, finished, _info = env.step(act)\n",
"            agent.memory.push(obs, act, step_reward, next_obs, finished)\n",
"            agent.update()\n",
"            obs = next_obs\n",
"            ep_reward += step_reward\n",
"            if finished:\n",
"                break\n",
"        # Report progress on every 10th episode only.\n",
"        if (i_ep+1)%10==0:\n",
"            print(f\"Episode:{i_ep+1}/{cfg.train_eps}, Reward:{ep_reward:.3f}\")\n",
"        rewards.append(ep_reward)\n",
"        # Moving average: 90% previous value, 10% new episode; the first episode seeds it.\n",
"        smoothed = 0.9*ma_rewards[-1]+0.1*ep_reward if ma_rewards else ep_reward\n",
"        ma_rewards.append(smoothed)\n",
"    print('Complete training')\n",
"    return rewards, ma_rewards"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"def eval(cfg,env,agent):\n",
"    \"\"\"Roll the agent out for ``cfg.test_eps`` episodes without any learning step.\n",
"\n",
"    Each episode lasts at most ``cfg.eval_steps`` steps and stops early when the\n",
"    environment reports ``done``. Unlike ``train``, nothing is pushed to the replay\n",
"    memory and ``agent.update()`` is never called.\n",
"\n",
"    NOTE(review): the name shadows the built-in ``eval`` for the rest of the\n",
"    notebook; it is kept because the driver cell calls it by this name.\n",
"\n",
"    Returns:\n",
"        ``(rewards, ma_rewards)``: per-episode reward sums and their moving average.\n",
"    \"\"\"\n",
"    print('Start to eval !')\n",
"    print(f'Env: {cfg.env}, Algorithm: {cfg.algo}, Device: {cfg.device}')\n",
"    rewards = []\n",
"    ma_rewards = []  # moving average reward\n",
"    for i_ep in range(cfg.test_eps):\n",
"        state = env.reset()\n",
"        ep_reward = 0\n",
"        for i_step in range(cfg.eval_steps):\n",
"            action = agent.policy_net.get_action(state)\n",
"            next_state, reward, done, _ = env.step(action)\n",
"            state = next_state\n",
"            ep_reward += reward\n",
"            if done:\n",
"                break\n",
"        if (i_ep+1)%10==0:\n",
"            # The denominator is the number of evaluation episodes (test_eps),\n",
"            # not train_eps as it was before.\n",
"            print(f\"Episode:{i_ep+1}/{cfg.test_eps}, Reward:{ep_reward:.3f}\")\n",
"        rewards.append(ep_reward)\n",
"        # Moving average: 90% previous value, 10% new episode; the first episode seeds it.\n",
"        if ma_rewards:\n",
"            ma_rewards.append(0.9*ma_rewards[-1]+0.1*ep_reward)\n",
"        else:\n",
"            ma_rewards.append(ep_reward)\n",
"    print('Complete evaling')\n",
"    return rewards, ma_rewards\n"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"ename": "DeprecatedEnv",
"evalue": "Env Pendulum-v0 not found (valid versions include ['Pendulum-v1'])",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mKeyError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m~/anaconda3/envs/py37/lib/python3.7/site-packages/gym/envs/registration.py\u001b[0m in \u001b[0;36mspec\u001b[0;34m(self, path)\u001b[0m\n\u001b[1;32m 157\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 158\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0menv_specs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mid\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 159\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mKeyError\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mKeyError\u001b[0m: 'Pendulum-v0'",
"\nDuring handling of the above exception, another exception occurred:\n",
"\u001b[0;31mDeprecatedEnv\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-7-91b1038013e4>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;31m# train\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 5\u001b[0;31m \u001b[0menv\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0magent\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0menv_agent_config\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcfg\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mseed\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 6\u001b[0m \u001b[0mrewards\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mma_rewards\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcfg\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0menv\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0magent\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[0mmake_dir\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcfg\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mresult_path\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcfg\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodel_path\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m<ipython-input-4-040773221550>\u001b[0m in \u001b[0;36menv_agent_config\u001b[0;34m(cfg, seed)\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0menv_agent_config\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcfg\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mseed\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0menv\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mNormalizedActions\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mgym\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmake\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Pendulum-v0\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 3\u001b[0m \u001b[0menv\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mseed\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mseed\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0maction_dim\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0menv\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0maction_space\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mstate_dim\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0menv\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mobservation_space\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/anaconda3/envs/py37/lib/python3.7/site-packages/gym/envs/registration.py\u001b[0m in \u001b[0;36mmake\u001b[0;34m(id, **kwargs)\u001b[0m\n\u001b[1;32m 233\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 234\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mmake\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mid\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 235\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mregistry\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmake\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mid\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 236\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 237\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/anaconda3/envs/py37/lib/python3.7/site-packages/gym/envs/registration.py\u001b[0m in \u001b[0;36mmake\u001b[0;34m(self, path, **kwargs)\u001b[0m\n\u001b[1;32m 126\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 127\u001b[0m \u001b[0mlogger\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minfo\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Making new env: %s\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpath\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 128\u001b[0;31m \u001b[0mspec\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mspec\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 129\u001b[0m \u001b[0menv\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mspec\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmake\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 130\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0menv\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/anaconda3/envs/py37/lib/python3.7/site-packages/gym/envs/registration.py\u001b[0m in \u001b[0;36mspec\u001b[0;34m(self, path)\u001b[0m\n\u001b[1;32m 185\u001b[0m raise error.DeprecatedEnv(\n\u001b[1;32m 186\u001b[0m \"Env {} not found (valid versions include {})\".format(\n\u001b[0;32m--> 187\u001b[0;31m \u001b[0mid\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmatching_envs\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 188\u001b[0m )\n\u001b[1;32m 189\u001b[0m )\n",
"\u001b[0;31mDeprecatedEnv\u001b[0m: Env Pendulum-v0 not found (valid versions include ['Pendulum-v1'])"
]
}
],
"source": [
"if __name__ == \"__main__\":\n",
"    cfg=SACConfig()\n",
"\n",
"    # --- train: fresh env/agent (seed 1), then save the model, reward curves and plot ---\n",
"    env,agent = env_agent_config(cfg,seed=1)\n",
"    rewards, ma_rewards = train(cfg, env, agent)\n",
"    make_dir(cfg.result_path, cfg.model_path)  # presumably creates both output directories (see common.utils)\n",
"    agent.save(path=cfg.model_path)\n",
"    save_results(rewards, ma_rewards, tag='train', path=cfg.result_path)\n",
"    # Pass env explicitly, as the eval plot below already does, so both figures are\n",
"    # labelled from cfg.env instead of whatever default plot_rewards falls back to.\n",
"    plot_rewards(rewards, ma_rewards, tag=\"train\", env=cfg.env,\n",
"                 algo=cfg.algo, path=cfg.result_path)\n",
"    # --- eval: rebuild env/agent with a different seed and reload the saved model ---\n",
"    env,agent = env_agent_config(cfg,seed=10)\n",
"    agent.load(path=cfg.model_path)\n",
"    rewards,ma_rewards = eval(cfg,env,agent)\n",
"    save_results(rewards,ma_rewards,tag='eval',path=cfg.result_path)\n",
"    plot_rewards(rewards,ma_rewards,tag=\"eval\",env=cfg.env,algo = cfg.algo,path=cfg.result_path)\n"
]
}
],
"metadata": {
"interpreter": {
"hash": "fe38df673a99c62a9fea33a7aceda74c9b65b12ee9d076c5851d98b692a4989a"
},
"kernelspec": {
"display_name": "Python 3.7.10 64-bit ('mujoco': conda)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.10"
},
"metadata": {
"interpreter": {
"hash": "fd81e6a9e450d5c245c1a0b5da0b03c89c450f614a13afa2acb1654375922756"
}
},
"orig_nbformat": 2
},
"nbformat": 4,
"nbformat_minor": 2
}