{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "from pathlib import Path\n",
    "\n",
    "curr_path = str(Path().absolute())  # directory the notebook is run from; used as the root for outputs\n",
    "parent_path = str(Path().absolute().parent)\n",
    "# Append the PARENT directory (not the current one) to sys.path;\n",
    "# presumably this is where the SAC and common packages imported below live.\n",
    "sys.path.append(parent_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import gym\n",
    "import torch\n",
    "import datetime\n",
    "\n",
    "from SAC.env import NormalizedActions\n",
    "from SAC.agent import SAC\n",
    "from common.utils import save_results, make_dir\n",
    "from common.plot import plot_rewards\n",
    "\n",
    "# Timestamp of this run; it becomes part of the result/model output paths.\n",
    "curr_time = datetime.datetime.now().strftime(\"%Y%m%d-%H%M%S\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "class SACConfig:\n",
    "    \"\"\"Settings for one SAC run: environment id, output paths, run lengths and agent hyperparameters.\"\"\"\n",
    "\n",
    "    def __init__(self) -> None:\n",
    "        self.algo = 'SAC'\n",
    "        # gym no longer registers 'Pendulum-v0': gym.make raises DeprecatedEnv\n",
    "        # and points to 'Pendulum-v1', so use the registered id.\n",
    "        self.env = 'Pendulum-v1'\n",
    "        self.result_path = curr_path+\"/outputs/\" +self.env+'/'+curr_time+'/results/'  # path to save results\n",
    "        self.model_path = curr_path+\"/outputs/\" +self.env+'/'+curr_time+'/models/'  # path to save models\n",
    "        self.train_eps = 300  # number of training episodes\n",
    "        self.train_steps = 500  # step cap per training episode\n",
    "        self.eval_eps = 50  # number of evaluation episodes\n",
    "        self.eval_steps = 500  # step cap per evaluation episode\n",
    "        # The values below are not read by the notebook cells themselves;\n",
    "        # they reach the agent through SAC(state_dim, action_dim, cfg).\n",
    "        self.gamma = 0.99\n",
    "        self.mean_lambda = 1e-3\n",
    "        self.std_lambda = 1e-3\n",
    "        self.z_lambda = 0.0\n",
    "        self.soft_tau = 1e-2\n",
    "        self.value_lr = 3e-4\n",
    "        self.soft_q_lr = 3e-4\n",
    "        self.policy_lr = 3e-4\n",
    "        self.capacity = 1000000\n",
    "        self.hidden_dim = 256\n",
    "        self.batch_size = 128\n",
    "        # Use the GPU when torch can see one, otherwise fall back to the CPU.\n",
    "        self.device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def env_agent_config(cfg, seed=1):\n",
    "    \"\"\"Build the action-normalized environment named by cfg.env and a SAC agent for it.\n",
    "\n",
    "    Args:\n",
    "        cfg: configuration object; cfg.env is the gym environment id and cfg itself is forwarded to SAC.\n",
    "        seed: value passed to env.seed().\n",
    "\n",
    "    Returns:\n",
    "        tuple: (env, agent).\n",
    "    \"\"\"\n",
    "    # Create the environment selected in the config instead of a hard-coded id,\n",
    "    # so the env that is trained on always matches cfg.env (also used for paths and plots).\n",
    "    env = NormalizedActions(gym.make(cfg.env))\n",
    "    env.seed(seed)\n",
    "    action_dim = env.action_space.shape[0]\n",
    "    state_dim = env.observation_space.shape[0]\n",
    "    agent = SAC(state_dim, action_dim, cfg)\n",
    "    return env, agent"
   ]
  },
{
 "cell_type": "code",
 "execution_count": 5,
 "metadata": {},
 "outputs": [],
 "source": [
  "def train(cfg, env, agent):\n",
  "    \"\"\"Run cfg.train_eps training episodes, storing every transition and updating the agent each step.\n",
  "\n",
  "    Args:\n",
  "        cfg: configuration providing env, algo, device, train_eps and train_steps.\n",
  "        env: environment exposing reset() and step(action).\n",
  "        agent: agent exposing policy_net.get_action(state), memory.push(...) and update().\n",
  "\n",
  "    Returns:\n",
  "        tuple: (rewards, ma_rewards) - per-episode reward sums and their moving average.\n",
  "    \"\"\"\n",
  "    print('Start to train !')\n",
  "    print(f'Env: {cfg.env}, Algorithm: {cfg.algo}, Device: {cfg.device}')\n",
  "    rewards = []\n",
  "    ma_rewards = []  # moving average reward\n",
  "    for i_ep in range(cfg.train_eps):\n",
  "        state = env.reset()\n",
  "        ep_reward = 0\n",
  "        for _ in range(cfg.train_steps):\n",
  "            action = agent.policy_net.get_action(state)\n",
  "            next_state, reward, done, _info = env.step(action)\n",
  "            agent.memory.push(state, action, reward, next_state, done)\n",
  "            agent.update()\n",
  "            state = next_state\n",
  "            ep_reward += reward\n",
  "            if done:\n",
  "                break\n",
  "        if (i_ep+1) % 10 == 0:  # report every 10th episode\n",
  "            print(f\"Episode:{i_ep+1}/{cfg.train_eps}, Reward:{ep_reward:.3f}\")\n",
  "        rewards.append(ep_reward)\n",
  "        if ma_rewards:\n",
  "            # 0.9 weight on the previous average, 0.1 on the new episode\n",
  "            ma_rewards.append(0.9*ma_rewards[-1]+0.1*ep_reward)\n",
  "        else:\n",
  "            ma_rewards.append(ep_reward)  # first episode seeds the average\n",
  "    print('Complete training!')\n",
  "    return rewards, ma_rewards"
 ]
},
{
 "cell_type": "code",
 "execution_count": 6,
 "metadata": {},
 "outputs": [],
 "source": [
  "def eval(cfg, env, agent):\n",
  "    \"\"\"Run cfg.eval_eps evaluation episodes without storing transitions or updating the agent.\n",
  "\n",
  "    NOTE: this name shadows the builtin eval(); it is kept because the driver cell calls it.\n",
  "\n",
  "    Args:\n",
  "        cfg: configuration providing env, algo, device, eval_eps and eval_steps.\n",
  "        env: environment exposing reset() and step(action).\n",
  "        agent: agent exposing policy_net.get_action(state).\n",
  "\n",
  "    Returns:\n",
  "        tuple: (rewards, ma_rewards) - per-episode reward sums and their moving average.\n",
  "    \"\"\"\n",
  "    print('Start to eval !')\n",
  "    print(f'Env: {cfg.env}, Algorithm: {cfg.algo}, Device: {cfg.device}')\n",
  "    rewards = []\n",
  "    ma_rewards = []  # moving average reward\n",
  "    for i_ep in range(cfg.eval_eps):\n",
  "        state = env.reset()\n",
  "        ep_reward = 0\n",
  "        for _ in range(cfg.eval_steps):\n",
  "            action = agent.policy_net.get_action(state)\n",
  "            next_state, reward, done, _info = env.step(action)\n",
  "            state = next_state\n",
  "            ep_reward += reward\n",
  "            if done:\n",
  "                break\n",
  "        if (i_ep+1) % 10 == 0:  # report every 10th episode\n",
  "            # The denominator is the number of EVALUATION episodes (was cfg.train_eps).\n",
  "            print(f\"Episode:{i_ep+1}/{cfg.eval_eps}, Reward:{ep_reward:.3f}\")\n",
  "        rewards.append(ep_reward)\n",
  "        if ma_rewards:\n",
  "            # 0.9 weight on the previous average, 0.1 on the new episode\n",
  "            ma_rewards.append(0.9*ma_rewards[-1]+0.1*ep_reward)\n",
  "        else:\n",
  "            ma_rewards.append(ep_reward)  # first episode seeds the average\n",
  "    print('Complete evaling!')\n",
  "    return rewards, ma_rewards\n"
 ]
},
{
 "cell_type": "code",
 "execution_count": 7,
"metadata": {},
 "outputs": [],
 "source": [
  "if __name__ == \"__main__\":\n",
  "    cfg = SACConfig()\n",
  "\n",
  "    # train: build env/agent, train, then persist the model, the reward curves and the plot\n",
  "    env, agent = env_agent_config(cfg, seed=1)\n",
  "    rewards, ma_rewards = train(cfg, env, agent)\n",
  "    make_dir(cfg.result_path, cfg.model_path)\n",
  "    agent.save(path=cfg.model_path)\n",
  "    save_results(rewards, ma_rewards, tag='train', path=cfg.result_path)\n",
  "    # Pass env=cfg.env here as well, matching the evaluation plot call below.\n",
  "    plot_rewards(rewards, ma_rewards, tag=\"train\", env=cfg.env,\n",
  "                 algo=cfg.algo, path=cfg.result_path)\n",
  "\n",
  "    # eval: fresh env/agent with a different seed, weights reloaded from the files saved above\n",
  "    env, agent = env_agent_config(cfg, seed=10)\n",
  "    agent.load(path=cfg.model_path)\n",
  "    rewards, ma_rewards = eval(cfg, env, agent)\n",
  "    save_results(rewards, ma_rewards, tag='eval', path=cfg.result_path)\n",
  "    plot_rewards(rewards, ma_rewards, tag=\"eval\", env=cfg.env, algo=cfg.algo, path=cfg.result_path)\n"
] } ], "metadata": { "interpreter": { "hash": "fe38df673a99c62a9fea33a7aceda74c9b65b12ee9d076c5851d98b692a4989a" }, "kernelspec": { "display_name": "Python 3.7.10 64-bit ('mujoco': conda)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.10" }, "metadata": { "interpreter": { "hash": "fd81e6a9e450d5c245c1a0b5da0b03c89c450f614a13afa2acb1654375922756" } }, "orig_nbformat": 2 }, "nbformat": 4, "nbformat_minor": 2 }