hot update PG
@@ -1,318 +0,0 @@
-{
- "cells": [
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "## Define the model\n",
-    "\n"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 15,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "import paddle\n",
-    "import paddle.nn as nn\n",
-    "import paddle.nn.functional as F\n",
-    "import parl\n",
-    "\n",
-    "class CartpoleModel(parl.Model):\n",
-    "    \"\"\" Linear network to solve the Cartpole problem.\n",
-    "    Args:\n",
-    "        n_states (int): Dimension of observation space.\n",
-    "        n_actions (int): Dimension of action space.\n",
-    "    \"\"\"\n",
-    "\n",
-    "    def __init__(self, n_states, n_actions):\n",
-    "        super(CartpoleModel, self).__init__()\n",
-    "        hid1_size = 128\n",
-    "        hid2_size = 128\n",
-    "        self.fc1 = nn.Linear(n_states, hid1_size)\n",
-    "        self.fc2 = nn.Linear(hid1_size, hid2_size)\n",
-    "        self.fc3 = nn.Linear(hid2_size, n_actions)\n",
-    "\n",
-    "    def forward(self, obs):\n",
-    "        h1 = F.relu(self.fc1(obs))\n",
-    "        h2 = F.relu(self.fc2(h1))\n",
-    "        Q = self.fc3(h2)\n",
-    "        return Q"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 16,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "import parl\n",
-    "import paddle\n",
-    "import numpy as np\n",
-    "\n",
-    "\n",
-    "class CartpoleAgent(parl.Agent):\n",
-    "    \"\"\"Agent of Cartpole env.\n",
-    "    Args:\n",
-    "        algorithm(parl.Algorithm): algorithm used to solve the problem.\n",
-    "    \"\"\"\n",
-    "\n",
-    "    def __init__(self, algorithm, n_actions, e_greed=0.1, e_greed_decrement=0):\n",
-    "        super(CartpoleAgent, self).__init__(algorithm)\n",
-    "        assert isinstance(n_actions, int)\n",
-    "        self.n_actions = n_actions\n",
-    "\n",
-    "        self.global_step = 0\n",
-    "        self.update_target_steps = 200\n",
-    "\n",
-    "        self.e_greed = e_greed\n",
-    "        self.e_greed_decrement = e_greed_decrement\n",
-    "\n",
-    "    def sample(self, obs):\n",
-    "        \"\"\"Sample an action `for exploration` when given an observation\n",
-    "        Args:\n",
-    "            obs(np.float32): shape of (n_states,)\n",
-    "        Returns:\n",
-    "            act(int): action\n",
-    "        \"\"\"\n",
-    "        sample = np.random.random()\n",
-    "        if sample < self.e_greed:\n",
-    "            act = np.random.randint(self.n_actions)\n",
-    "        else:\n",
-    "            if np.random.random() < 0.01:\n",
-    "                act = np.random.randint(self.n_actions)\n",
-    "            else:\n",
-    "                act = self.predict(obs)\n",
-    "        self.e_greed = max(0.01, self.e_greed - self.e_greed_decrement)\n",
-    "        return act\n",
-    "\n",
-    "    def predict(self, obs):\n",
-    "        \"\"\"Predict an action when given an observation\n",
-    "        Args:\n",
-    "            obs(np.float32): shape of (n_states,)\n",
-    "        Returns:\n",
-    "            act(int): action\n",
-    "        \"\"\"\n",
-    "        obs = paddle.to_tensor(obs, dtype='float32')\n",
-    "        pred_q = self.alg.predict(obs)\n",
-    "        act = pred_q.argmax().numpy()[0]\n",
-    "        return act\n",
-    "\n",
-    "    def learn(self, obs, act, reward, next_obs, terminal):\n",
-    "        \"\"\"Update model with an episode of data\n",
-    "        Args:\n",
-    "            obs(np.float32): shape of (batch_size, n_states)\n",
-    "            act(np.int32): shape of (batch_size)\n",
-    "            reward(np.float32): shape of (batch_size)\n",
-    "            next_obs(np.float32): shape of (batch_size, n_states)\n",
-    "            terminal(np.float32): shape of (batch_size)\n",
-    "        Returns:\n",
-    "            loss(float)\n",
-    "        \"\"\"\n",
-    "        if self.global_step % self.update_target_steps == 0:\n",
-    "            self.alg.sync_target()\n",
-    "        self.global_step += 1\n",
-    "\n",
-    "        act = np.expand_dims(act, axis=-1)\n",
-    "        reward = np.expand_dims(reward, axis=-1)\n",
-    "        terminal = np.expand_dims(terminal, axis=-1)\n",
-    "\n",
-    "        obs = paddle.to_tensor(obs, dtype='float32')\n",
-    "        act = paddle.to_tensor(act, dtype='int32')\n",
-    "        reward = paddle.to_tensor(reward, dtype='float32')\n",
-    "        next_obs = paddle.to_tensor(next_obs, dtype='float32')\n",
-    "        terminal = paddle.to_tensor(terminal, dtype='float32')\n",
-    "        loss = self.alg.learn(obs, act, reward, next_obs, terminal)\n",
-    "        return loss.numpy()[0]"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 17,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "import os\n",
-    "import gym\n",
-    "import numpy as np\n",
-    "import parl\n",
-    "\n",
-    "from parl.utils import logger, ReplayMemory\n",
-    "from parl.algorithms import DQN\n",
-    "\n",
-    "LEARN_FREQ = 5  # training frequency\n",
-    "MEMORY_SIZE = 200000\n",
-    "MEMORY_WARMUP_SIZE = 200\n",
-    "BATCH_SIZE = 64\n",
-    "LEARNING_RATE = 0.0005\n",
-    "GAMMA = 0.99\n",
-    "\n",
-    "# train an episode\n",
-    "def run_train_episode(agent, env, rpm):\n",
-    "    total_reward = 0\n",
-    "    obs = env.reset()\n",
-    "    step = 0\n",
-    "    while True:\n",
-    "        step += 1\n",
-    "        action = agent.sample(obs)\n",
-    "        next_obs, reward, done, _ = env.step(action)\n",
-    "        rpm.append(obs, action, reward, next_obs, done)\n",
-    "\n",
-    "        # train model\n",
-    "        if (len(rpm) > MEMORY_WARMUP_SIZE) and (step % LEARN_FREQ == 0):\n",
-    "            # s,a,r,s',done\n",
-    "            (batch_obs, batch_action, batch_reward, batch_next_obs,\n",
-    "             batch_done) = rpm.sample_batch(BATCH_SIZE)\n",
-    "            train_loss = agent.learn(batch_obs, batch_action, batch_reward,\n",
-    "                                     batch_next_obs, batch_done)\n",
-    "\n",
-    "        total_reward += reward\n",
-    "        obs = next_obs\n",
-    "        if done:\n",
-    "            break\n",
-    "    return total_reward\n",
-    "\n",
-    "\n",
-    "# evaluate 5 episodes\n",
-    "def run_evaluate_episodes(agent, env, eval_episodes=5, render=False):\n",
-    "    eval_reward = []\n",
-    "    for i in range(eval_episodes):\n",
-    "        obs = env.reset()\n",
-    "        episode_reward = 0\n",
-    "        while True:\n",
-    "            action = agent.predict(obs)\n",
-    "            obs, reward, done, _ = env.step(action)\n",
-    "            episode_reward += reward\n",
-    "            if render:\n",
-    "                env.render()\n",
-    "            if done:\n",
-    "                break\n",
-    "        eval_reward.append(episode_reward)\n",
-    "    return np.mean(eval_reward)\n",
-    "\n",
-    "\n",
-    "def main(args):\n",
-    "    env = gym.make('CartPole-v0')\n",
-    "    n_states = env.observation_space.shape[0]\n",
-    "    n_actions = env.action_space.n\n",
-    "    logger.info('n_states {}, n_actions {}'.format(n_states, n_actions))\n",
-    "\n",
-    "    # set action_shape = 0 while in a discrete control environment\n",
-    "    rpm = ReplayMemory(MEMORY_SIZE, n_states, 0)\n",
-    "\n",
-    "    # build an agent\n",
-    "    model = CartpoleModel(n_states=n_states, n_actions=n_actions)\n",
-    "    alg = DQN(model, gamma=GAMMA, lr=LEARNING_RATE)\n",
-    "    agent = CartpoleAgent(\n",
-    "        alg, n_actions=n_actions, e_greed=0.1, e_greed_decrement=1e-6)\n",
-    "\n",
-    "    # warm up memory\n",
-    "    while len(rpm) < MEMORY_WARMUP_SIZE:\n",
-    "        run_train_episode(agent, env, rpm)\n",
-    "\n",
-    "    max_episode = args.max_episode\n",
-    "\n",
-    "    # start training\n",
-    "    episode = 0\n",
-    "    while episode < max_episode:\n",
-    "        # train part\n",
-    "        for i in range(50):\n",
-    "            total_reward = run_train_episode(agent, env, rpm)\n",
-    "            episode += 1\n",
-    "\n",
-    "        # test part\n",
-    "        eval_reward = run_evaluate_episodes(agent, env, render=False)\n",
-    "        logger.info('episode:{} e_greed:{} Test reward:{}'.format(\n",
-    "            episode, agent.e_greed, eval_reward))\n",
-    "\n",
-    "    # save the parameters to ./model.ckpt\n",
-    "    save_path = './model.ckpt'\n",
-    "    agent.save(save_path)\n",
-    "\n",
-    "    # save the model and parameters of the policy network for inference\n",
-    "    save_inference_path = './inference_model'\n",
-    "    input_shapes = [[None, env.observation_space.shape[0]]]\n",
-    "    input_dtypes = ['float32']\n",
-    "    agent.save_inference_model(save_inference_path, input_shapes, input_dtypes)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 18,
-   "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "\u001b[32m[08-01 21:48:19 MainThread @3996942455.py:64]\u001b[0m obs_dim 4, act_dim 2\n",
-      "\u001b[32m[08-01 21:48:19 MainThread @3996942455.py:92]\u001b[0m episode:50 e_greed:0.0988929999999989 Test reward:18.4\n",
-      "\u001b[32m[08-01 21:48:20 MainThread @3996942455.py:92]\u001b[0m episode:100 e_greed:0.09794799999999795 Test reward:9.6\n",
-      "\u001b[32m[08-01 21:48:20 MainThread @3996942455.py:92]\u001b[0m episode:150 e_greed:0.0973899999999974 Test reward:37.8\n",
-      "\u001b[32m[08-01 21:48:20 MainThread @3996942455.py:92]\u001b[0m episode:200 e_greed:0.09684299999999685 Test reward:8.8\n",
-      "\u001b[32m[08-01 21:48:20 MainThread @3996942455.py:92]\u001b[0m episode:250 e_greed:0.09635499999999636 Test reward:9.4\n",
-      "\u001b[32m[08-01 21:48:21 MainThread @3996942455.py:92]\u001b[0m episode:300 e_greed:0.09585299999999586 Test reward:9.2\n",
-      "\u001b[32m[08-01 21:48:21 MainThread @3996942455.py:92]\u001b[0m episode:350 e_greed:0.09535799999999536 Test reward:9.2\n",
-      "\u001b[32m[08-01 21:48:21 MainThread @3996942455.py:92]\u001b[0m episode:400 e_greed:0.09486399999999487 Test reward:10.0\n",
-      "\u001b[32m[08-01 21:48:21 MainThread @3996942455.py:92]\u001b[0m episode:450 e_greed:0.09435299999999436 Test reward:9.2\n",
-      "\u001b[32m[08-01 21:48:22 MainThread @3996942455.py:92]\u001b[0m episode:500 e_greed:0.09384899999999385 Test reward:9.4\n",
-      "\u001b[32m[08-01 21:48:22 MainThread @3996942455.py:92]\u001b[0m episode:550 e_greed:0.09302299999999303 Test reward:69.0\n",
-      "\u001b[32m[08-01 21:48:25 MainThread @3996942455.py:92]\u001b[0m episode:600 e_greed:0.08774199999998775 Test reward:141.2\n",
-      "\u001b[32m[08-01 21:48:30 MainThread @3996942455.py:92]\u001b[0m episode:650 e_greed:0.0791019999999791 Test reward:184.0\n",
-      "\u001b[32m[08-01 21:48:35 MainThread @3996942455.py:92]\u001b[0m episode:700 e_greed:0.07011299999997012 Test reward:182.0\n",
-      "\u001b[32m[08-01 21:48:40 MainThread @3996942455.py:92]\u001b[0m episode:750 e_greed:0.06089099999996089 Test reward:197.4\n",
-      "\u001b[32m[08-01 21:48:45 MainThread @3996942455.py:92]\u001b[0m episode:800 e_greed:0.05139199999995139 Test reward:183.4\n",
-      "\u001b[32m[08-01 21:48:50 MainThread @3996942455.py:92]\u001b[0m episode:850 e_greed:0.042255999999942256 Test reward:153.0\n",
-      "\u001b[32m[08-01 21:48:55 MainThread @3996942455.py:92]\u001b[0m episode:900 e_greed:0.033495999999933496 Test reward:192.6\n",
-      "\u001b[32m[08-01 21:49:00 MainThread @3996942455.py:92]\u001b[0m episode:950 e_greed:0.024318999999924318 Test reward:166.6\n",
-      "\u001b[32m[08-01 21:49:06 MainThread @3996942455.py:92]\u001b[0m episode:1000 e_greed:0.014873999999916176 Test reward:187.0\n"
-     ]
-    }
-   ],
-   "source": [
-    "import argparse\n",
-    "parser = argparse.ArgumentParser()\n",
-    "parser.add_argument(\n",
-    "    '--max_episode',\n",
-    "    type=int,\n",
-    "    default=1000,\n",
-    "    help='stop condition: number of max episodes')\n",
-    "args = parser.parse_args(args=[])\n",
-    "\n",
-    "main(args)"
-   ]
-  }
- ],
- "metadata": {
-  "kernelspec": {
-   "display_name": "Python 3.7.12 ('rl_tutorials')",
-   "language": "python",
-   "name": "python3"
-  },
-  "language_info": {
-   "codemirror_mode": {
-    "name": "ipython",
-    "version": 3
-   },
-   "file_extension": ".py",
-   "mimetype": "text/x-python",
-   "name": "python",
-   "nbconvert_exporter": "python",
-   "pygments_lexer": "ipython3",
-   "version": "3.7.12"
-  },
-  "orig_nbformat": 4,
-  "vscode": {
-   "interpreter": {
-    "hash": "4f613f1ab80ec98dc1b91d6e720de51301598a187317378e53e49b773c1123dd"
-   }
-  }
- },
- "nbformat": 4,
- "nbformat_minor": 2
-}
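The deleted notebook's `learn` hands each batch to `parl.algorithms.DQN`, so the actual temporal-difference update never appears in this diff. For orientation, a minimal sketch of what such a learn step typically computes, assuming the standard Bellman target with a target network synced by `sync_target()`; the names `model` and `target_model` are illustrative, not PARL's exact internals:

```python
import paddle
import paddle.nn.functional as F

def dqn_learn_step(model, target_model, optimizer, gamma,
                   obs, act, reward, next_obs, terminal):
    # Q(s, a) for the actions actually taken
    pred_values = model(obs)
    action_onehot = F.one_hot(act.squeeze(-1), num_classes=pred_values.shape[-1])
    pred_value = (pred_values * action_onehot).sum(axis=-1, keepdim=True)

    # Bellman target: r + gamma * max_a' Q_target(s', a'), cut off at terminal states
    with paddle.no_grad():
        max_v = target_model(next_obs).max(axis=-1, keepdim=True)
        target = reward + (1.0 - terminal) * gamma * max_v

    loss = F.mse_loss(pred_value, target)
    optimizer.clear_grad()
    loss.backward()
    optimizer.step()
    return loss
```

With `terminal` prepared as a 0/1 float batch, the `(1.0 - terminal)` factor stops bootstrapping at episode ends, matching the tensor shapes the agent's `learn` builds above.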
@@ -1,11 +0,0 @@
-[PARL](https://github.com/PaddlePaddle/PARL) is a high-performance, flexible reinforcement learning framework developed by Baidu AI Studio.
-
-## Installation
-
-1. Install parl; see the [PARL Github](https://github.com/PaddlePaddle/PARL)
-2. Install paddlepaddle: ```pip install paddlepaddle```
-
-## Common issues
-
-If you hit ```jupyter-client 7.3.1 requires pyzmq>=22.3, but you have pyzmq 18.1.1 which is incompatible.```, run:
-```pip install -U pyzmq```
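A quick sanity check after the two installs, assuming a standard pip environment (`paddle.utils.run_check()` is PaddlePaddle's built-in self test, and parl, like most packages, exposes a `__version__`):

```python
import paddle
import parl

paddle.utils.run_check()   # runs PaddlePaddle's install self-test
print(parl.__version__)    # confirms parl imports and reports its version
```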
@@ -11,7 +11,6 @@
 The project mainly consists of the following parts:
 * [Jupyter Notebook](./notebooks/): algorithms written as notebooks, with fairly detailed hands-on walkthroughs; recommended for newcomers
 * [codes](./codes/): algorithms written as Python scripts, in a style closer to real-world projects; recommended for readers with some coding experience (their architecture is described below)
-* [parl](./PARL/): RL examples based on Baidu's PaddlePaddle platform and the ```parl``` module, written to meet business needs
 * [Attachments](./assets/): currently contains Chinese-language pseudocode for each RL algorithm
projects/codes/PolicyGradient/main.py (new file, 129 lines)
@@ -0,0 +1,129 @@
+#!/usr/bin/env python
+# coding=utf-8
+'''
+Author: John
+Email: johnjim0816@gmail.com
+Date: 2020-11-22 23:21:53
+LastEditor: John
+LastEditTime: 2022-08-25 20:59:23
+Description:
+Environment:
+'''
+import sys,os
+curr_path = os.path.dirname(os.path.abspath(__file__))  # current path
+parent_path = os.path.dirname(curr_path)  # parent path
+sys.path.append(parent_path)  # add to system path
+
+import gym
+import torch
+import datetime
+import argparse
+from itertools import count
+import torch.nn.functional as F
+from pg import PolicyGradient
+from common.utils import save_results, make_dir, all_seed, save_args, plot_rewards
+from common.models import MLP
+from common.memories import PGReplay
+from common.launcher import Launcher
+from envs.register import register_env
+
+
+class PGNet(MLP):
+    ''' Instead of outputting an action, the PG net outputs probabilities of actions; we can reuse MLP via class inheritance here
+    '''
+    def forward(self, x):
+        x = F.relu(self.fc1(x))
+        x = F.relu(self.fc2(x))
+        x = F.sigmoid(self.fc3(x))
+        return x
+
+class Main(Launcher):
+    def get_args(self):
+        """ Hyperparameters
+        """
+        curr_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")  # obtain current time
+        parser = argparse.ArgumentParser(description="hyperparameters")
+        parser.add_argument('--algo_name',default='PolicyGradient',type=str,help="name of algorithm")
+        parser.add_argument('--env_name',default='CartPole-v0',type=str,help="name of environment")
+        parser.add_argument('--train_eps',default=200,type=int,help="episodes of training")
+        parser.add_argument('--test_eps',default=20,type=int,help="episodes of testing")
+        parser.add_argument('--gamma',default=0.99,type=float,help="discount factor")
+        parser.add_argument('--lr',default=0.005,type=float,help="learning rate")
+        parser.add_argument('--update_fre',default=8,type=int)
+        parser.add_argument('--hidden_dim',default=36,type=int)
+        parser.add_argument('--device',default='cpu',type=str,help="cpu or cuda")
+        parser.add_argument('--seed',default=1,type=int,help="seed")
+        parser.add_argument('--save_fig',default=True,type=bool,help="if save figure or not")
+        parser.add_argument('--show_fig',default=False,type=bool,help="if show figure or not")
+        args = parser.parse_args()
+        default_args = {'result_path':f"{curr_path}/outputs/{args.env_name}/{curr_time}/results/",
+                        'model_path':f"{curr_path}/outputs/{args.env_name}/{curr_time}/models/",
+        }
+        args = {**vars(args),**default_args}  # type(dict)
+        return args
+
+    def env_agent_config(self,cfg):
+        register_env(cfg['env_name'])
+        env = gym.make(cfg['env_name'])
+        if cfg['seed'] != 0:  # set random seed
+            all_seed(env,seed=cfg['seed'])
+        n_states = env.observation_space.shape[0]
+        n_actions = env.action_space.n  # action dimension
+        print(f"state dim: {n_states}, action dim: {n_actions}")
+        cfg.update({"n_states":n_states,"n_actions":n_actions})  # update cfg parameters
+        model = PGNet(n_states,1,hidden_dim=cfg['hidden_dim'])
+        memory = PGReplay()
+        agent = PolicyGradient(model,memory,cfg)
+        return env,agent
+
+    def train(self,cfg,env,agent):
+        print("Start training!")
+        print(f"Env: {cfg['env_name']}, Algorithm: {cfg['algo_name']}, Device: {cfg['device']}")
+        rewards = []
+        for i_ep in range(cfg['train_eps']):
+            state = env.reset()
+            ep_reward = 0
+            for _ in count():
+                action = agent.sample_action(state)  # sample action
+                next_state, reward, done, _ = env.step(action)
+                ep_reward += reward
+                if done:
+                    reward = 0
+                agent.memory.push((state,float(action),reward))
+                state = next_state
+                if done:
+                    print(f"Episode: {i_ep+1}/{cfg['train_eps']}, Reward: {ep_reward:.2f}")
+                    break
+            if (i_ep+1) % cfg['update_fre'] == 0:
+                agent.update()
+            rewards.append(ep_reward)
+        print('Finish training!')
+        env.close()  # close environment
+        res_dic = {'episodes':range(len(rewards)),'rewards':rewards}
+        return res_dic
+
+    def test(self,cfg,env,agent):
+        print("Start testing!")
+        print(f"Env: {cfg['env_name']}, Algorithm: {cfg['algo_name']}, Device: {cfg['device']}")
+        rewards = []
+        for i_ep in range(cfg['test_eps']):
+            state = env.reset()
+            ep_reward = 0
+            for _ in count():
+                action = agent.predict_action(state)
+                next_state, reward, done, _ = env.step(action)
+                ep_reward += reward
+                if done:
+                    reward = 0
+                state = next_state
+                if done:
+                    print(f"Episode: {i_ep+1}/{cfg['test_eps']}, Reward: {ep_reward:.2f}")
+                    break
+            rewards.append(ep_reward)
+        print("Finish testing!")
+        env.close()
+        return {'episodes':range(len(rewards)),'rewards':rewards}
+
+if __name__ == "__main__":
+    main = Main()
+    main.run()
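`agent.update()` runs every `update_fre` episodes, but its body lives in `pg.py` and is mostly outside this diff. For orientation, a sketch of a standard Monte-Carlo policy-gradient (REINFORCE) update over `(state, action, reward)` transitions like the ones pushed above; this is an assumption about the method, not its verbatim contents:

```python
import torch
from torch.distributions import Bernoulli

def reinforce_update(policy_net, optimizer, transitions, gamma):
    """One REINFORCE update from a list of (state, action, reward) tuples (sketch)."""
    states, actions, rewards = zip(*transitions)

    # discounted return G_t for every step, computed backwards through the episode
    returns, g = [], 0.0
    for r in reversed(rewards):
        g = r + gamma * g
        returns.insert(0, g)
    returns = torch.tensor(returns)
    # normalizing the returns keeps gradient magnitudes well-scaled
    returns = (returns - returns.mean()) / (returns.std() + 1e-8)

    loss = torch.zeros(1)
    for s, a, g in zip(states, actions, returns):
        prob = policy_net(torch.tensor(s, dtype=torch.float32))  # P(action=1 | s)
        dist = Bernoulli(prob)
        loss = loss - dist.log_prob(torch.tensor(float(a))) * g  # policy-gradient loss
    optimizer.zero_grad()
    loss.sum().backward()
    optimizer.step()
```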
@@ -1,16 +0,0 @@
-{
-    "algo_name": "PolicyGradient",
-    "env_name": "CartPole-v0",
-    "train_eps": 200,
-    "test_eps": 20,
-    "gamma": 0.99,
-    "lr": 0.005,
-    "update_fre": 8,
-    "hidden_dim": 36,
-    "device": "cpu",
-    "seed": 1,
-    "result_path": "/Users/jj/Desktop/rl-tutorials/codes/PolicyGradient/outputs/CartPole-v0/20220822-174059/results/",
-    "model_path": "/Users/jj/Desktop/rl-tutorials/codes/PolicyGradient/outputs/CartPole-v0/20220822-174059/models/",
-    "save_fig": true,
-    "show_fig": false
-}
@@ -0,0 +1 @@
+{"algo_name": "PolicyGradient", "env_name": "CartPole-v0", "train_eps": 200, "test_eps": 20, "gamma": 0.99, "lr": 0.005, "update_fre": 8, "hidden_dim": 36, "device": "cpu", "seed": 1, "save_fig": true, "show_fig": false, "result_path": "/Users/jj/Desktop/rl-tutorials/codes/PolicyGradient/outputs/CartPole-v0/20220825-205930/results/", "model_path": "/Users/jj/Desktop/rl-tutorials/codes/PolicyGradient/outputs/CartPole-v0/20220825-205930/models/", "n_states": 4, "n_actions": 2}
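The params file is now written in one pass from the merged config dict; a plausible minimal helper consistent with that single-line output (the real `save_args` in `common/utils.py` is outside this diff):

```python
import json
import os

def save_args(cfg, path):
    """Persist the hyperparameter dict next to the results (sketch)."""
    os.makedirs(path, exist_ok=True)
    with open(os.path.join(path, "params.json"), "w") as f:
        json.dump(cfg, f)  # compact, single-line JSON like the file above
```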
(two image files changed; sizes unchanged at 35 KiB and 66 KiB)
@@ -5,7 +5,7 @@ Author: John
 Email: johnjim0816@gmail.com
 Date: 2020-11-22 23:27:44
 LastEditor: John
-LastEditTime: 2022-08-22 17:35:34
+LastEditTime: 2022-08-25 20:58:59
 Description:
 Environment:
 '''
@@ -19,12 +19,12 @@ import numpy as np

 class PolicyGradient:

-    def __init__(self, n_states,model,memory,cfg):
-        self.gamma = cfg.gamma
-        self.device = torch.device(cfg.device)
+    def __init__(self, model,memory,cfg):
+        self.gamma = cfg['gamma']
+        self.device = torch.device(cfg['device'])
         self.memory = memory
         self.policy_net = model.to(self.device)
-        self.optimizer = torch.optim.RMSprop(self.policy_net.parameters(), lr=cfg.lr)
+        self.optimizer = torch.optim.RMSprop(self.policy_net.parameters(), lr=cfg['lr'])

     def sample_action(self,state):
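The hunk cuts off at `sample_action`. Since `PGNet` ends in a single sigmoid output, sampling is presumably Bernoulli over P(action=1 | state); a sketch of the usual method pair under that assumption, not the file's verbatim contents:

```python
import torch
from torch.distributions import Bernoulli

def sample_action(self, state):
    # stochastic action for training: draw from Bernoulli(P(action=1 | s))
    state = torch.tensor(state, dtype=torch.float32, device=self.device)
    prob = self.policy_net(state)
    return int(Bernoulli(prob).sample().item())

def predict_action(self, state):
    # greedy action for evaluation: threshold the probability at 0.5
    state = torch.tensor(state, dtype=torch.float32, device=self.device)
    prob = self.policy_net(state)
    return int(prob.item() > 0.5)
```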
@@ -1,139 +0,0 @@
-#!/usr/bin/env python
-# coding=utf-8
-'''
-Author: John
-Email: johnjim0816@gmail.com
-Date: 2020-11-22 23:21:53
-LastEditor: John
-LastEditTime: 2022-08-22 17:40:07
-Description:
-Environment:
-'''
-import sys,os
-curr_path = os.path.dirname(os.path.abspath(__file__))  # current path
-parent_path = os.path.dirname(curr_path)  # parent path
-sys.path.append(parent_path)  # add to system path
-
-import gym
-import torch
-import datetime
-import argparse
-from itertools import count
-import torch.nn.functional as F
-from pg import PolicyGradient
-from common.utils import save_results, make_dir,all_seed,save_args,plot_rewards
-from common.models import MLP
-from common.memories import PGReplay
-
-
-def get_args():
-    """ Hyperparameters
-    """
-    curr_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")  # obtain current time
-    parser = argparse.ArgumentParser(description="hyperparameters")
-    parser.add_argument('--algo_name',default='PolicyGradient',type=str,help="name of algorithm")
-    parser.add_argument('--env_name',default='CartPole-v0',type=str,help="name of environment")
-    parser.add_argument('--train_eps',default=200,type=int,help="episodes of training")
-    parser.add_argument('--test_eps',default=20,type=int,help="episodes of testing")
-    parser.add_argument('--gamma',default=0.99,type=float,help="discount factor")
-    parser.add_argument('--lr',default=0.005,type=float,help="learning rate")
-    parser.add_argument('--update_fre',default=8,type=int)
-    parser.add_argument('--hidden_dim',default=36,type=int)
-    parser.add_argument('--device',default='cpu',type=str,help="cpu or cuda")
-    parser.add_argument('--seed',default=1,type=int,help="seed")
-    parser.add_argument('--result_path',default=curr_path + "/outputs/" + parser.parse_args().env_name + \
-            '/' + curr_time + '/results/' )
-    parser.add_argument('--model_path',default=curr_path + "/outputs/" + parser.parse_args().env_name + \
-            '/' + curr_time + '/models/' )  # path to save models
-    parser.add_argument('--save_fig',default=True,type=bool,help="if save figure or not")
-    parser.add_argument('--show_fig',default=False,type=bool,help="if show figure or not")
-    args = parser.parse_args([])
-    return args
-
-class PGNet(MLP):
-    ''' Instead of outputting an action, the PG net outputs probabilities of actions; we can reuse MLP via class inheritance here
-    '''
-    def forward(self, x):
-        x = F.relu(self.fc1(x))
-        x = F.relu(self.fc2(x))
-        x = F.sigmoid(self.fc3(x))
-        return x
-
-def env_agent_config(cfg):
-    env = gym.make(cfg.env_name)
-    if cfg.seed != 0:  # set random seed
-        all_seed(env,seed=cfg.seed)
-    n_states = env.observation_space.shape[0]
-    n_actions = env.action_space.n  # action dimension
-    print(f"state dim: {n_states}, action dim: {n_actions}")
-    model = PGNet(n_states,1,hidden_dim=cfg.hidden_dim)
-    memory = PGReplay()
-    agent = PolicyGradient(n_states,model,memory,cfg)
-    return env,agent
-
-def train(cfg,env,agent):
-    print('Start training!')
-    print(f'Env:{cfg.env_name}, Algo:{cfg.algo_name}, Device:{cfg.device}')
-    rewards = []
-    for i_ep in range(cfg.train_eps):
-        state = env.reset()
-        ep_reward = 0
-        for _ in count():
-            action = agent.sample_action(state)  # sample action
-            next_state, reward, done, _ = env.step(action)
-            ep_reward += reward
-            if done:
-                reward = 0
-            agent.memory.push((state,float(action),reward))
-            state = next_state
-            if done:
-                print(f'Episode: {i_ep+1}/{cfg.train_eps}, Reward: {ep_reward:.2f}')
-                break
-        if (i_ep+1) % cfg.update_fre == 0:
-            agent.update()
-        rewards.append(ep_reward)
-    print('Finish training!')
-    env.close()  # close environment
-    res_dic = {'episodes':range(len(rewards)),'rewards':rewards}
-    return res_dic
-
-
-def test(cfg,env,agent):
-    print("Start testing!")
-    print(f"Env: {cfg.env_name}, Algo: {cfg.algo_name}, Device: {cfg.device}")
-    rewards = []
-    for i_ep in range(cfg.test_eps):
-        state = env.reset()
-        ep_reward = 0
-        for _ in count():
-            action = agent.predict_action(state)
-            next_state, reward, done, _ = env.step(action)
-            ep_reward += reward
-            if done:
-                reward = 0
-            state = next_state
-            if done:
-                print(f'Episode: {i_ep+1}/{cfg.test_eps}, Reward: {ep_reward:.2f}')
-                break
-        rewards.append(ep_reward)
-    print("Finish testing!")
-    env.close()
-    return {'episodes':range(len(rewards)),'rewards':rewards}
-
-if __name__ == "__main__":
-    cfg = get_args()
-    env, agent = env_agent_config(cfg)
-    res_dic = train(cfg, env, agent)
-    save_args(cfg,path = cfg.result_path)  # save parameters
-    agent.save_model(path = cfg.model_path)  # save models
-    save_results(res_dic, tag = 'train', path = cfg.result_path)  # save results
-    plot_rewards(res_dic['rewards'], cfg, path = cfg.result_path,tag = "train")  # plot results
-    # testing
-    env, agent = env_agent_config(cfg)  # create a new env for testing; this step can sometimes be skipped
-    agent.load_model(path = cfg.model_path)  # load model
-    res_dic = test(cfg, env, agent)
-    save_results(res_dic, tag='test', path = cfg.result_path)
-    plot_rewards(res_dic['rewards'], cfg, path = cfg.result_path,tag = "test")
@@ -5,7 +5,7 @@ Author: John
 Email: johnjim0816@gmail.com
 Date: 2020-09-11 23:03:00
 LastEditor: John
-LastEditTime: 2022-08-24 11:27:01
+LastEditTime: 2022-08-25 14:59:15
 Description:
 Environment:
 '''
@@ -18,136 +18,102 @@ sys.path.append(parent_path)  # add path to system path
 import gym
 import datetime
 import argparse
-from envs.gridworld_env import CliffWalkingWapper,FrozenLakeWapper
+from envs.gridworld_env import FrozenLakeWapper
+from envs.wrappers import CliffWalkingWapper
+from envs.register import register_env
 from qlearning import QLearning
-from common.utils import plot_rewards,save_args,all_seed
-from common.utils import save_results,make_dir
+from common.utils import all_seed
+from common.launcher import Launcher

-def get_args():
-    curr_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")  # obtain current time
-    parser = argparse.ArgumentParser(description="hyperparameters")
-    parser.add_argument('--algo_name',default='Q-learning',type=str,help="name of algorithm")
-    parser.add_argument('--env_name',default='CliffWalking-v0',type=str,help="name of environment")
-    parser.add_argument('--train_eps',default=400,type=int,help="episodes of training")
-    parser.add_argument('--test_eps',default=20,type=int,help="episodes of testing")
-    parser.add_argument('--gamma',default=0.90,type=float,help="discount factor")
-    parser.add_argument('--epsilon_start',default=0.95,type=float,help="initial value of epsilon")
-    parser.add_argument('--epsilon_end',default=0.01,type=float,help="final value of epsilon")
-    parser.add_argument('--epsilon_decay',default=300,type=int,help="decay rate of epsilon")
-    parser.add_argument('--lr',default=0.1,type=float,help="learning rate")
-    parser.add_argument('--device',default='cpu',type=str,help="cpu or cuda")
-    parser.add_argument('--seed',default=10,type=int,help="seed")
-    parser.add_argument('--show_fig',default=False,type=bool,help="if show figure or not")
-    parser.add_argument('--save_fig',default=True,type=bool,help="if save figure or not")
-    args = parser.parse_args()
-    default_args = {'result_path':f"{curr_path}/outputs/{args.env_name}/{curr_time}/results/",
-                    'model_path':f"{curr_path}/outputs/{args.env_name}/{curr_time}/models/",
-    }
-    args = {**vars(args),**default_args}  # type(dict)
-    return args
-
-def env_agent_config(cfg):
-    ''' create env and agent
-    '''
-    if cfg['env_name'] == 'CliffWalking-v0':
-        env = gym.make(cfg['env_name'])
-        env = CliffWalkingWapper(env)
-    if cfg['env_name'] == 'FrozenLake-v1':
-        env = gym.make(cfg['env_name'],is_slippery=False)
-    if cfg['seed'] != 0:  # set random seed
-        all_seed(env,seed=cfg["seed"])
-    n_states = env.observation_space.n  # state dimension
-    n_actions = env.action_space.n  # action dimension
-    print(f"n_states: {n_states}, n_actions: {n_actions}")
-    cfg.update({"n_states":n_states,"n_actions":n_actions})  # update cfg parameters
-    agent = QLearning(cfg)
-    return env,agent
-
-def main(cfg,env,agent,tag = 'train'):
-    print(f"Start {tag}ing!")
-    print(f"Env: {cfg['env_name']}, Algorithm: {cfg['algo_name']}, Device: {cfg['device']}")
-    rewards = []  # record rewards
-    for i_ep in range(cfg['train_eps']):
-        ep_reward = 0  # reward per episode
-        state = env.reset()  # reset the environment, i.e. start a new episode
-        while True:
-            if tag == 'train': action = agent.sample_action(state)  # sample an action from the algorithm
-            else: action = agent.predict_action(state)
-            next_state, reward, done, _ = env.step(action)  # one interaction step with the environment
-            if tag == 'train': agent.update(state, action, reward, next_state, done)  # Q-learning update
-            state = next_state  # update the state
-            ep_reward += reward
-            if done:
-                break
-        rewards.append(ep_reward)
-        print(f"Episode: {i_ep+1}/{cfg['train_eps']}, Reward: {ep_reward:.1f}, Epsilon: {agent.epsilon}")
-    print(f"Finish {tag}ing!")
-    return {"rewards":rewards}
-
-def train(cfg,env,agent):
-    print("Start training!")
-    print(f"Env: {cfg['env_name']}, Algorithm: {cfg['algo_name']}, Device: {cfg['device']}")
-    rewards = []  # record rewards for all episodes
-    steps = []  # record steps for all episodes
-    for i_ep in range(cfg['train_eps']):
-        ep_reward = 0  # reward per episode
-        ep_step = 0  # steps per episode
-        state = env.reset()  # reset and obtain the initial state
-        while True:
-            action = agent.sample_action(state)  # sample action
-            next_state, reward, done, _ = env.step(action)  # update env and return transitions
-            agent.update(state, action, reward, next_state, done)  # update agent
-            state = next_state  # update state
-            ep_reward += reward
-            ep_step += 1
-            if done:
-                break
-        rewards.append(ep_reward)
-        steps.append(ep_step)
-        if (i_ep+1) % 10 == 0:
-            print(f'Episode: {i_ep+1}/{cfg["train_eps"]}, Reward: {ep_reward:.2f}, Steps: {ep_step}, Epsilon: {agent.epsilon:.3f}')
-    print("Finish training!")
-    return {'episodes':range(len(rewards)),'rewards':rewards,'steps':steps}
-
-def test(cfg,env,agent):
-    print("Start testing!")
-    print(f"Env: {cfg['env_name']}, Algorithm: {cfg['algo_name']}, Device: {cfg['device']}")
-    rewards = []  # record rewards for all episodes
-    steps = []  # record steps for all episodes
-    for i_ep in range(cfg['test_eps']):
-        ep_reward = 0  # reward per episode
-        ep_step = 0
-        state = env.reset()  # reset and obtain the initial state
-        while True:
-            action = agent.predict_action(state)  # predict action
-            next_state, reward, done, _ = env.step(action)
-            state = next_state
-            ep_reward += reward
-            ep_step += 1
-            if done:
-                break
-        rewards.append(ep_reward)
-        steps.append(ep_step)
-        print(f"Episode: {i_ep+1}/{cfg['test_eps']}, Steps: {ep_step}, Reward: {ep_reward:.2f}")
-    print("Finish testing!")
-    return {'episodes':range(len(rewards)),'rewards':rewards,'steps':steps}
+class Main(Launcher):
+    def get_args(self):
+        curr_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")  # obtain current time
+        parser = argparse.ArgumentParser(description="hyperparameters")
+        parser.add_argument('--algo_name',default='Q-learning',type=str,help="name of algorithm")
+        parser.add_argument('--env_name',default='CliffWalking-v0',type=str,help="name of environment")
+        parser.add_argument('--train_eps',default=400,type=int,help="episodes of training")
+        parser.add_argument('--test_eps',default=20,type=int,help="episodes of testing")
+        parser.add_argument('--gamma',default=0.90,type=float,help="discount factor")
+        parser.add_argument('--epsilon_start',default=0.95,type=float,help="initial value of epsilon")
+        parser.add_argument('--epsilon_end',default=0.01,type=float,help="final value of epsilon")
+        parser.add_argument('--epsilon_decay',default=300,type=int,help="decay rate of epsilon")
+        parser.add_argument('--lr',default=0.1,type=float,help="learning rate")
+        parser.add_argument('--device',default='cpu',type=str,help="cpu or cuda")
+        parser.add_argument('--seed',default=10,type=int,help="seed")
+        parser.add_argument('--show_fig',default=False,type=bool,help="if show figure or not")
+        parser.add_argument('--save_fig',default=True,type=bool,help="if save figure or not")
+        args = parser.parse_args()
+        default_args = {'result_path':f"{curr_path}/outputs/{args.env_name}/{curr_time}/results/",
+                        'model_path':f"{curr_path}/outputs/{args.env_name}/{curr_time}/models/",
+        }
+        args = {**vars(args),**default_args}  # type(dict)
+        return args
+
+    def env_agent_config(self,cfg):
+        ''' create env and agent
+        '''
+        register_env(cfg['env_name'])
+        env = gym.make(cfg['env_name'])
+        if cfg['env_name'] == 'CliffWalking-v0':
+            env = CliffWalkingWapper(env)
+        if cfg['seed'] != 0:  # set random seed
+            all_seed(env,seed=cfg["seed"])
+        n_states = env.observation_space.n  # state dimension
+        n_actions = env.action_space.n  # action dimension
+        print(f"n_states: {n_states}, n_actions: {n_actions}")
+        cfg.update({"n_states":n_states,"n_actions":n_actions})  # update cfg parameters
+        agent = QLearning(cfg)
+        return env,agent
+
+    def train(self,cfg,env,agent):
+        print("Start training!")
+        print(f"Env: {cfg['env_name']}, Algorithm: {cfg['algo_name']}, Device: {cfg['device']}")
+        rewards = []  # record rewards for all episodes
+        steps = []  # record steps for all episodes
+        for i_ep in range(cfg['train_eps']):
+            ep_reward = 0  # reward per episode
+            ep_step = 0  # steps per episode
+            state = env.reset()  # reset and obtain the initial state
+            while True:
+                action = agent.sample_action(state)  # sample action
+                next_state, reward, done, _ = env.step(action)  # update env and return transitions
+                agent.update(state, action, reward, next_state, done)  # update agent
+                state = next_state  # update state
+                ep_reward += reward
+                ep_step += 1
+                if done:
+                    break
+            rewards.append(ep_reward)
+            steps.append(ep_step)
+            if (i_ep+1) % 10 == 0:
+                print(f'Episode: {i_ep+1}/{cfg["train_eps"]}, Reward: {ep_reward:.2f}, Steps: {ep_step}, Epsilon: {agent.epsilon:.3f}')
+        print("Finish training!")
+        return {'episodes':range(len(rewards)),'rewards':rewards,'steps':steps}
+
+    def test(self,cfg,env,agent):
+        print("Start testing!")
+        print(f"Env: {cfg['env_name']}, Algorithm: {cfg['algo_name']}, Device: {cfg['device']}")
+        rewards = []  # record rewards for all episodes
+        steps = []  # record steps for all episodes
+        for i_ep in range(cfg['test_eps']):
+            ep_reward = 0  # reward per episode
+            ep_step = 0
+            state = env.reset()  # reset and obtain the initial state
+            while True:
+                action = agent.predict_action(state)  # predict action
+                next_state, reward, done, _ = env.step(action)
+                state = next_state
+                ep_reward += reward
+                ep_step += 1
+                if done:
+                    break
+            rewards.append(ep_reward)
+            steps.append(ep_step)
+            print(f"Episode: {i_ep+1}/{cfg['test_eps']}, Steps: {ep_step}, Reward: {ep_reward:.2f}")
+        print("Finish testing!")
+        return {'episodes':range(len(rewards)),'rewards':rewards,'steps':steps}

 if __name__ == "__main__":
-    cfg = get_args()
-    # training
-    env, agent = env_agent_config(cfg)
-    res_dic = train(cfg, env, agent)
-    save_args(cfg,path = cfg['result_path'])  # save parameters
-    agent.save_model(path = cfg['model_path'])  # save models
-    save_results(res_dic, tag = 'train', path = cfg['result_path'])  # save results
-    plot_rewards(res_dic['rewards'], cfg, path = cfg['result_path'],tag = "train")  # plot results
-    # testing
-    env, agent = env_agent_config(cfg)  # create a new env for testing; this step can sometimes be skipped
-    agent.load_model(path = cfg['model_path'])  # load model
-    res_dic = test(cfg, env, agent)
-    save_results(res_dic, tag='test', path = cfg['result_path'])
-    plot_rewards(res_dic['rewards'], cfg, path = cfg['result_path'],tag = "test")
+    main = Main()
+    main.run()
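`QLearning(cfg)` encapsulates the tabular update and the epsilon-greedy schedule driven by `epsilon_start`, `epsilon_end`, and `epsilon_decay`; `qlearning.py` itself is not part of this diff, so the following is a sketch of the standard implementation those hyperparameters imply (exponential epsilon decay is assumed):

```python
import math
from collections import defaultdict
import numpy as np

class QLearningSketch:
    def __init__(self, cfg):
        self.gamma, self.lr = cfg['gamma'], cfg['lr']
        self.n_actions = cfg['n_actions']
        self.sample_count = 0
        self.epsilon = cfg['epsilon_start']
        self.epsilon_start = cfg['epsilon_start']
        self.epsilon_end = cfg['epsilon_end']
        self.epsilon_decay = cfg['epsilon_decay']
        self.Q = defaultdict(lambda: np.zeros(self.n_actions))  # tabular Q(s, a)

    def sample_action(self, state):
        self.sample_count += 1
        # exponential decay from epsilon_start towards epsilon_end
        self.epsilon = self.epsilon_end + (self.epsilon_start - self.epsilon_end) * \
            math.exp(-1.0 * self.sample_count / self.epsilon_decay)
        if np.random.uniform() > self.epsilon:
            return int(np.argmax(self.Q[state]))   # exploit
        return np.random.randint(self.n_actions)    # explore

    def update(self, state, action, reward, next_state, done):
        # TD target: r + gamma * max_a' Q(s', a'), truncated at terminal states
        target = reward + (0 if done else self.gamma * np.max(self.Q[next_state]))
        self.Q[state][action] += self.lr * (target - self.Q[state][action])
```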
(two image files removed: 22 KiB and 53 KiB)
@@ -1,801 +0,0 @@
-episodes,rewards,steps
-0,0.0,20
-1,0.0,14
-2,0.0,13
(… 800-row training-results CSV removed; rewards are almost all 0.0 until roughly episode 340, after which the agent reliably reaches the goal with reward 1.0 in about 6–10 steps; unsuccessful episodes are capped at 100 steps …)
-799,1.0,6
@@ -1,6 +1,6 @@
 {
     "algo_name": "Q-learning",
-    "env_name": "FrozenLake-v1",
+    "env_name": "FrozenLakeNoSlippery-v1",
     "train_eps": 800,
     "test_eps": 20,
     "gamma": 0.9,
@@ -12,8 +12,8 @@
     "seed": 10,
     "show_fig": false,
     "save_fig": true,
-    "result_path": "/Users/jj/Desktop/rl-tutorials/codes/QLearning/outputs/FrozenLake-v1/20220824-112735/results/",
-    "model_path": "/Users/jj/Desktop/rl-tutorials/codes/QLearning/outputs/FrozenLake-v1/20220824-112735/models/",
+    "result_path": "/Users/jj/Desktop/rl-tutorials/codes/QLearning/outputs/FrozenLakeNoSlippery-v1/20220825-114335/results/",
+    "model_path": "/Users/jj/Desktop/rl-tutorials/codes/QLearning/outputs/FrozenLakeNoSlippery-v1/20220825-114335/models/",
     "n_states": 16,
     "n_actions": 4
 }
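The functional change here is only the environment id: gym's stock FrozenLake-v1 defaults to is_slippery=True, so transitions are stochastic, while the newly registered FrozenLakeNoSlippery-v1 pins is_slippery=False (see projects/codes/envs/register.py below). A minimal sketch of the difference, using gym's standard make kwargs:

```python
import gym

slippery = gym.make('FrozenLake-v1')                          # default: is_slippery=True
deterministic = gym.make('FrozenLake-v1', is_slippery=False)  # same map, deterministic moves

slippery.reset()
deterministic.reset()
# With is_slippery=True the same action can land on different cells across runs;
# with is_slippery=False action 2 (right) always moves right when the cell allows it.
```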
(two binary image files added: 24 KiB and 55 KiB)
@@ -0,0 +1,801 @@
episodes,rewards,steps
0,0.0,20
1,0.0,14
2,0.0,13
(episodes 3-307 omitted: reward stays 0.0 while the agent explores, with step counts climbing past 100 per episode; the only early successes are episodes 48, 265, 304 and 306)
308,1.0,83
309,1.0,8
310,1.0,6
(episodes 311-798 omitted: from here the agent succeeds on nearly every episode, reward 1.0 in 6-15 steps, with occasional 0.0 failures of a few steps)
799,1.0,6
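If you want to re-plot these curves yourself, a minimal sketch (assuming pandas is installed; the filename is hypothetical, substitute whatever save_results wrote under your result_path):

```python
import pandas as pd

# hypothetical path/filename; use the CSV that save_results produced in your result_path
df = pd.read_csv("outputs/FrozenLakeNoSlippery-v1/20220825-114335/results/train_results.csv")
print(df["rewards"].rolling(20).mean().tail())  # smoothed success rate over the last episodes
```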
projects/codes/Sarsa/main.py (new file, 136 lines)
@@ -0,0 +1,136 @@
#!/usr/bin/env python
# coding=utf-8
'''
Author: John
Email: johnjim0816@gmail.com
Date: 2021-03-11 17:59:16
LastEditor: John
LastEditTime: 2022-08-25 14:26:36
Discription:
Environment:
'''
import sys, os
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"  # avoid "OMP: Error #15: Initializing libiomp5md.dll, but found libiomp5md.dll already initialized."
curr_path = os.path.dirname(os.path.abspath(__file__))  # current path
parent_path = os.path.dirname(curr_path)  # parent path
sys.path.append(parent_path)  # add path to system path

import gym
import datetime
import argparse
from envs.register import register_env
from envs.wrappers import CliffWalkingWapper
from Sarsa.sarsa import Sarsa
from common.utils import save_results, make_dir, plot_rewards, save_args, all_seed

def get_args():
    curr_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")  # obtain current time
    parser = argparse.ArgumentParser(description="hyperparameters")
    parser.add_argument('--algo_name', default='Sarsa', type=str, help="name of algorithm")
    parser.add_argument('--env_name', default='Racetrack-v0', type=str, help="name of environment")
    parser.add_argument('--train_eps', default=300, type=int, help="episodes of training")
    parser.add_argument('--test_eps', default=20, type=int, help="episodes of testing")
    parser.add_argument('--gamma', default=0.99, type=float, help="discount factor")
    parser.add_argument('--epsilon_start', default=0.90, type=float, help="initial value of epsilon")
    parser.add_argument('--epsilon_end', default=0.01, type=float, help="final value of epsilon")
    parser.add_argument('--epsilon_decay', default=200, type=int, help="decay rate of epsilon")
    parser.add_argument('--lr', default=0.2, type=float, help="learning rate")
    parser.add_argument('--device', default='cpu', type=str, help="cpu or cuda")
    parser.add_argument('--seed', default=10, type=int, help="seed")
    parser.add_argument('--show_fig', default=False, type=bool, help="if show figure or not")
    parser.add_argument('--save_fig', default=True, type=bool, help="if save figure or not")
    args = parser.parse_args()
    default_args = {'result_path': f"{curr_path}/outputs/{args.env_name}/{curr_time}/results/",
                    'model_path': f"{curr_path}/outputs/{args.env_name}/{curr_time}/models/",
                    }
    args = {**vars(args), **default_args}  # merge into a plain dict
    return args

def env_agent_config(cfg):
    register_env(cfg['env_name'])
    env = gym.make(cfg['env_name'])
    if cfg['seed'] != 0:  # set random seed
        all_seed(env, seed=cfg['seed'])
    if cfg['env_name'] == 'CliffWalking-v0':
        env = CliffWalkingWapper(env)
    try:  # state dimension
        n_states = env.observation_space.n  # discrete observation space
    except AttributeError:
        n_states = env.observation_space.shape[0]  # continuous (Box) observation space
    n_actions = env.action_space.n  # action dimension
    print(f"n_states: {n_states}, n_actions: {n_actions}")
    cfg.update({"n_states": n_states, "n_actions": n_actions})  # update cfg parameters
    agent = Sarsa(cfg)
    return env, agent

def train(cfg, env, agent):
    print("Start training!")
    print(f"Env: {cfg['env_name']}, Algorithm: {cfg['algo_name']}, Device: {cfg['device']}")
    rewards = []  # record rewards for all episodes
    steps = []  # record steps for all episodes
    for i_ep in range(cfg['train_eps']):
        ep_reward = 0  # reward per episode
        ep_step = 0  # steps per episode
        state = env.reset()  # reset and obtain the initial state
        action = agent.sample_action(state)
        while True:
            next_state, reward, done, _ = env.step(action)  # update env and return transition
            next_action = agent.sample_action(next_state)
            agent.update(state, action, reward, next_state, next_action, done)  # update agent
            state = next_state
            action = next_action
            ep_reward += reward
            ep_step += 1
            if done:
                break
        rewards.append(ep_reward)
        steps.append(ep_step)
        if (i_ep + 1) % 10 == 0:
            print(f'Episode: {i_ep+1}/{cfg["train_eps"]}, Reward: {ep_reward:.2f}, Steps: {ep_step}, Epsilon: {agent.epsilon:.3f}')
    print("Finish training!")
    return {'episodes': range(len(rewards)), 'rewards': rewards, 'steps': steps}

def test(cfg, env, agent):
    print("Start testing!")
    print(f"Env: {cfg['env_name']}, Algorithm: {cfg['algo_name']}, Device: {cfg['device']}")
    rewards = []  # record rewards for all episodes
    steps = []  # record steps for all episodes
    for i_ep in range(cfg['test_eps']):
        ep_reward = 0  # reward per episode
        ep_step = 0
        state = env.reset()  # reset and obtain the initial state (missing in the original file)
        while True:
            action = agent.predict_action(state)
            next_state, reward, done, _ = env.step(action)
            state = next_state
            ep_reward += reward
            ep_step += 1
            if done:
                break
        rewards.append(ep_reward)
        steps.append(ep_step)
        print(f"Episode: {i_ep+1}/{cfg['test_eps']}, Steps: {ep_step}, Reward: {ep_reward:.2f}")
    print("Finish testing!")
    return {'episodes': range(len(rewards)), 'rewards': rewards, 'steps': steps}

if __name__ == "__main__":
    cfg = get_args()
    # train
    env, agent = env_agent_config(cfg)
    res_dic = train(cfg, env, agent)
    make_dir(cfg['result_path'], cfg['model_path'])  # cfg is a dict, so index it with brackets
    save_args(cfg)  # save parameters
    agent.save_model(path=cfg['model_path'])  # save model
    save_results(res_dic, tag='train', path=cfg['result_path'])
    plot_rewards(res_dic['rewards'], cfg, tag="train")
    # test
    env, agent = env_agent_config(cfg)
    agent.load_model(path=cfg['model_path'])  # load model
    res_dic = test(cfg, env, agent)
    save_results(res_dic, tag='test', path=cfg['result_path'])  # save results
    plot_rewards(res_dic['rewards'], cfg, tag="test")  # plot results
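One detail worth noting in get_args: the parsed argparse namespace is folded into a plain dict together with the derived output paths, which is why everything downstream indexes cfg with brackets rather than attributes. A standalone sketch of the same pattern (the path is illustrative):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--lr', default=0.2, type=float)
args = parser.parse_args([])  # parse defaults only, for illustration
default_args = {'result_path': './outputs/results/'}  # illustrative derived path
cfg = {**vars(args), **default_args}  # namespace -> plain dict, merged with the defaults
print(cfg['lr'], cfg['result_path'])
```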
@@ -5,7 +5,7 @@ Author: John
 Email: johnjim0816@gmail.com
 Date: 2021-03-12 16:58:16
 LastEditor: John
-LastEditTime: 2022-08-04 22:22:16
+LastEditTime: 2022-08-25 00:23:22
 Discription:
 Environment:
 '''
@@ -14,45 +14,51 @@ from collections import defaultdict
 import torch
 import math
 class Sarsa(object):
-    def __init__(self, n_actions, cfg):
-        self.n_actions = n_actions
-        self.lr = cfg.lr
-        self.gamma = cfg.gamma
+    def __init__(self, cfg):
+        self.n_actions = cfg['n_actions']
+        self.lr = cfg['lr']
+        self.gamma = cfg['gamma']
+        self.epsilon = cfg['epsilon_start']
         self.sample_count = 0
-        self.epsilon_start = cfg.epsilon_start
-        self.epsilon_end = cfg.epsilon_end
-        self.epsilon_decay = cfg.epsilon_decay
-        self.Q = defaultdict(lambda: np.zeros(n_actions))  # Q table
-    def sample(self, state):
+        self.epsilon_start = cfg['epsilon_start']
+        self.epsilon_end = cfg['epsilon_end']
+        self.epsilon_decay = cfg['epsilon_decay']
+        self.Q_table = defaultdict(lambda: np.zeros(self.n_actions))  # Q table
+    def sample_action(self, state):
+        ''' another way to represent the e-greedy policy
+        '''
         self.sample_count += 1
         self.epsilon = self.epsilon_end + (self.epsilon_start - self.epsilon_end) * \
             math.exp(-1. * self.sample_count / self.epsilon_decay)  # probability of selecting a random action, decayed exponentially with the sample count
-        best_action = np.argmax(self.Q[state])
+        best_action = np.argmax(self.Q_table[state])
         action_probs = np.ones(self.n_actions, dtype=float) * self.epsilon / self.n_actions
         action_probs[best_action] += (1.0 - self.epsilon)
         action = np.random.choice(np.arange(len(action_probs)), p=action_probs)
         return action
-    def predict(self, state):
-        return np.argmax(self.Q[state])
+    def predict_action(self, state):
+        ''' predict the action while testing
+        '''
+        action = np.argmax(self.Q_table[state])
+        return action
     def update(self, state, action, reward, next_state, next_action, done):
-        Q_predict = self.Q[state][action]
+        Q_predict = self.Q_table[state][action]
         if done:
             Q_target = reward  # terminal state
         else:
-            Q_target = reward + self.gamma * self.Q[next_state][next_action]  # unlike Q-learning, Sarsa bootstraps from the Q value of the next action actually taken
-        self.Q[state][action] += self.lr * (Q_target - Q_predict)
-    def save(self, path):
-        '''save the Q table data to a file
-        '''
+            Q_target = reward + self.gamma * self.Q_table[next_state][next_action]  # the only difference from Q-learning
+        self.Q_table[state][action] += self.lr * (Q_target - Q_predict)
+    def save_model(self, path):
         import dill
+        from pathlib import Path
+        Path(path).mkdir(parents=True, exist_ok=True)  # create path
         torch.save(
-            obj=self.Q,
-            f=path + "sarsa_model.pkl",
+            obj=self.Q_table,
+            f=path + "checkpoint.pkl",
             pickle_module=dill
         )
-    def load(self, path):
-        '''load the data from a file into the Q table
-        '''
+        print("Model saved!")
+    def load_model(self, path):
         import dill
-        self.Q = torch.load(f=path + 'sarsa_model.pkl', pickle_module=dill)
+        self.Q_table = torch.load(f=path + 'checkpoint.pkl', pickle_module=dill)
+        print("Model loaded!")
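The comment on the Q_target line marks the single statement that separates the two algorithms. A minimal side-by-side sketch of both update targets on a toy Q-table (standalone, not repo code):

```python
import numpy as np
from collections import defaultdict

Q = defaultdict(lambda: np.zeros(4))  # toy Q-table with 4 actions
gamma, lr = 0.9, 0.1
state, action, reward, next_state, next_action = 0, 1, 1.0, 1, 2

# Sarsa (on-policy): bootstrap from the action the behaviour policy actually took next
sarsa_target = reward + gamma * Q[next_state][next_action]

# Q-learning (off-policy): bootstrap from the greedy action, regardless of what was taken
q_learning_target = reward + gamma * np.max(Q[next_state])

Q[state][action] += lr * (sarsa_target - Q[state][action])  # the shared TD update rule
```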
@@ -1,118 +0,0 @@ (deleted file: the previous Racetrack-based Sarsa entry script)
#!/usr/bin/env python
# coding=utf-8
'''
Author: John
Email: johnjim0816@gmail.com
Date: 2021-03-11 17:59:16
LastEditor: John
LastEditTime: 2022-08-04 22:28:51
Discription:
Environment:
'''
import sys, os
curr_path = os.path.dirname(os.path.abspath(__file__))  # absolute path of the current file
parent_path = os.path.dirname(curr_path)  # parent path
sys.path.append(parent_path)  # add path to system path

import datetime
import argparse
from envs.racetrack_env import RacetrackEnv
from Sarsa.sarsa import Sarsa
from common.utils import save_results, make_dir, plot_rewards, save_args

def get_args():
    """ hyperparameters
    """
    curr_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")  # obtain current time
    parser = argparse.ArgumentParser(description="hyperparameters")
    parser.add_argument('--algo_name', default='Sarsa', type=str, help="name of algorithm")
    parser.add_argument('--env_name', default='CliffWalking-v0', type=str, help="name of environment")
    parser.add_argument('--train_eps', default=300, type=int, help="episodes of training")
    parser.add_argument('--test_eps', default=20, type=int, help="episodes of testing")
    parser.add_argument('--ep_max_steps', default=200, type=int)  # max steps per episode
    parser.add_argument('--gamma', default=0.99, type=float, help="discount factor")
    parser.add_argument('--epsilon_start', default=0.90, type=float, help="initial epsilon for the e-greedy policy")
    parser.add_argument('--epsilon_end', default=0.01, type=float, help="final epsilon for the e-greedy policy")
    parser.add_argument('--epsilon_decay', default=200, type=int, help="decay rate of epsilon")
    parser.add_argument('--lr', default=0.2, type=float, help="learning rate")
    parser.add_argument('--device', default='cpu', type=str, help="cpu or cuda")
    parser.add_argument('--result_path', default=curr_path + "/outputs/" + parser.parse_args().env_name + \
                        '/' + curr_time + '/results/')
    parser.add_argument('--model_path', default=curr_path + "/outputs/" + parser.parse_args().env_name + \
                        '/' + curr_time + '/models/')  # path to save models
    parser.add_argument('--save_fig', default=True, type=bool, help="if save figure or not")
    args = parser.parse_args()
    return args

def env_agent_config(cfg, seed=1):
    env = RacetrackEnv()
    n_actions = 9  # number of actions
    agent = Sarsa(n_actions, cfg)
    return env, agent

def train(cfg, env, agent):
    print('Start training!')
    print(f'Env: {cfg.env_name}, Algorithm: {cfg.algo_name}, Device: {cfg.device}')
    rewards = []  # record rewards
    for i_ep in range(cfg.train_eps):
        state = env.reset()
        action = agent.sample(state)
        ep_reward = 0
        for _ in range(cfg.ep_max_steps):
            next_state, reward, done = env.step(action)
            ep_reward += reward
            next_action = agent.sample(next_state)
            agent.update(state, action, reward, next_state, next_action, done)
            state = next_state
            action = next_action
            if done:
                break
        rewards.append(ep_reward)
        if (i_ep + 1) % 2 == 0:
            print(f"Episode: {i_ep+1}/{cfg.train_eps}, Reward: {ep_reward:.1f}, Epsilon: {agent.epsilon}")
    print('Finish training!')
    return {"rewards": rewards}

def test(cfg, env, agent):
    print('Start testing!')
    print(f'Env: {cfg.env_name}, Algorithm: {cfg.algo_name}, Device: {cfg.device}')
    rewards = []
    for i_ep in range(cfg.test_eps):
        state = env.reset()
        ep_reward = 0
        for _ in range(cfg.ep_max_steps):
            action = agent.predict(state)
            next_state, reward, done = env.step(action)
            ep_reward += reward
            state = next_state
            if done:
                break
        rewards.append(ep_reward)
        print(f"Episode: {i_ep+1}/{cfg.test_eps}, Reward: {ep_reward:.1f}")
    print('Finish testing!')
    return {"rewards": rewards}

if __name__ == "__main__":
    cfg = get_args()
    # train
    env, agent = env_agent_config(cfg)
    res_dic = train(cfg, env, agent)
    make_dir(cfg.result_path, cfg.model_path)
    save_args(cfg)  # save parameters
    agent.save(path=cfg.model_path)  # save model
    save_results(res_dic, tag='train', path=cfg.result_path)
    plot_rewards(res_dic['rewards'], cfg, tag="train")
    # test
    env, agent = env_agent_config(cfg)
    agent.load(path=cfg.model_path)  # load model
    res_dic = test(cfg, env, agent)
    save_results(res_dic, tag='test', path=cfg.result_path)  # save results
    plot_rewards(res_dic['rewards'], cfg, tag="test")  # plot results
projects/codes/common/launcher.py (new file, 32 lines)
@@ -0,0 +1,32 @@
from common.utils import save_args, save_results, plot_rewards

class Launcher:
    def __init__(self) -> None:
        pass
    def get_args(self):
        cfg = {}
        return cfg
    def env_agent_config(self, cfg):
        env, agent = None, None
        return env, agent
    def train(self, cfg, env, agent):
        res_dic = {}
        return res_dic
    def test(self, cfg, env, agent):
        res_dic = {}
        return res_dic

    def run(self):
        cfg = self.get_args()
        env, agent = self.env_agent_config(cfg)
        res_dic = self.train(cfg, env, agent)
        save_args(cfg, path=cfg['result_path'])  # save parameters
        agent.save_model(path=cfg['model_path'])  # save models
        save_results(res_dic, tag='train', path=cfg['result_path'])  # save results
        plot_rewards(res_dic['rewards'], cfg, path=cfg['result_path'], tag="train")  # plot results
        # testing
        env, agent = self.env_agent_config(cfg)  # create a new env for testing; sometimes this step can be skipped
        agent.load_model(path=cfg['model_path'])  # load model
        res_dic = self.test(cfg, env, agent)
        save_results(res_dic, tag='test',
                     path=cfg['result_path'])
        plot_rewards(res_dic['rewards'], cfg, path=cfg['result_path'], tag="test")
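Launcher is a template-method base class: run() fixes the train/save/test pipeline, and concrete experiments override the four hooks. A hypothetical minimal subclass (the stub agent and paths are illustrative only; save_results and plot_rewards may expect more cfg keys than this sketch supplies):

```python
from common.launcher import Launcher

class StubAgent:
    """No-op stand-in so run()'s save/load steps can execute."""
    def save_model(self, path): pass
    def load_model(self, path): pass

class StubLauncher(Launcher):
    def get_args(self):
        # run() indexes these two keys, so every subclass must supply them
        return {'result_path': './outputs/results/', 'model_path': './outputs/models/'}
    def env_agent_config(self, cfg):
        return None, StubAgent()  # a real subclass returns a gym env and a learning agent
    def train(self, cfg, env, agent):
        return {'rewards': [0.0], 'steps': [1]}  # a real subclass runs the episodes here
    def test(self, cfg, env, agent):
        return {'rewards': [0.0], 'steps': [1]}

if __name__ == "__main__":
    StubLauncher().run()
```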
@@ -72,84 +72,6 @@ class FrozenLakeWapper(gym.Wrapper):
         self.move_player(x_pos, y_pos)

(this hunk removes the entire CliffWalkingWapper class and its turtle-drawing helpers; the identical class moves to the new projects/codes/envs/wrappers.py shown below)

 if __name__ == '__main__':
     # Env 1: FrozenLake, whether the ice is slippery is configurable
@@ -1,10 +1,3 @@
-# Please do not make changes to this file - it will be overwritten with a clean
-# version when your work is marked.
-#
-# This file contains code for the racetrack environment that you will be using
-# as part of the second part of the CM50270: Reinforcement Learning coursework.
-
-import imp
 import time
 import random
 import numpy as np
@@ -12,23 +5,20 @@ import os
 import matplotlib.pyplot as plt
 import matplotlib.patheffects as pe
 from IPython.display import clear_output
-from gym.spaces import Discrete
+from gym.spaces import Discrete, Box
 from matplotlib import colors
+import gym

-class RacetrackEnv(object) :
+class RacetrackEnv(gym.Env) :
     """
     Class representing a race-track environment inspired by exercise 5.12 in Sutton & Barto 2018 (p.111).
     Please do not make changes to this class - it will be overwritten with a clean version when it comes to marking.

     The dynamics of this environment are detailed in this coursework exercise's jupyter notebook, although I have
     included rather verbose comments here for those of you who are interested in how the environment has been
     implemented (though this should not impact your solution code).
-
-    If you find any *bugs* with this code, please let me know immediately - thank you for finding them, sorry that I didn't!
-    However, please do not suggest optimisations - some things have been purposely simplified for readability's sake.
     """

     ACTIONS_DICT = {
         0 : (1, -1), # Acc Vert., Brake Horiz.
         1 : (1, 0),  # Acc Vert., Hold Horiz.
@@ -61,18 +51,15 @@ class RacetrackEnv(object) :
         for x in range(self.track.shape[1]) :
             if (self.CELL_TYPES_DICT[self.track[y, x]] == "start") :
                 self.initial_states.append((y, x))
+        high = np.array([np.finfo(np.float32).max, np.finfo(np.float32).max, np.finfo(np.float32).max, np.finfo(np.float32).max])
+        self.observation_space = Box(low=-high, high=high, shape=(4,), dtype=np.float32)
         self.action_space = Discrete(9)
         self.is_reset = False
-
-        #print("Racetrack Environment File Loaded Successfully.")
-        #print("Be sure to call .reset() before starting to initialise the environment and get an initial state!")

     def step(self, action : int) :
         """
         Takes a given action in the environment's current state, and returns a next state,
-        reward, and whether the next state is terminal or not.
+        reward, and whether the next state is done or not.

         Arguments:
             action {int} -- The action to take in the environment's current state. Should be an integer in the range [0-8].
@@ -86,7 +73,7 @@ class RacetrackEnv(object) :
         A tuple of:\n
         {(int, int, int, int)} -- The next state, a tuple of (y_pos, x_pos, y_velocity, x_velocity).\n
         {int} -- The reward earned by taking the given action in the current environment state.\n
-        {bool} -- Whether the environment's next state is terminal or not.\n
+        {bool} -- Whether the environment's next state is done or not.\n

         """
@@ -131,7 +118,7 @@ class RacetrackEnv(object) :
         new_position = (self.position[0] + self.velocity[0], self.position[1] + self.velocity[1])

         reward = 0
-        terminal = False
+        done = False

         # If position is out-of-bounds, return to start and set velocity components to zero.
         if (new_position[0] < 0 or new_position[1] < 0 or new_position[0] >= self.track.shape[0] or new_position[1] >= self.track.shape[1]) :
@@ -150,7 +137,7 @@ class RacetrackEnv(object) :
         elif (self.CELL_TYPES_DICT[self.track[new_position]] == "goal") :
             self.position = new_position
             reward += 10
-            terminal = True
+            done = True
         # If this gets reached, then the student has touched something they shouldn't have. Naughty!
         else :
             raise RuntimeError("You've met with a terrible fate, haven't you?\nDon't modify things you shouldn't!")
@@ -158,12 +145,12 @@ class RacetrackEnv(object) :
         # Penalise every timestep.
         reward -= 1

-        # Require a reset if the current state is terminal.
-        if (terminal) :
+        # Require a reset if the current state is done.
+        if (done) :
             self.is_reset = False

         # Return next state, reward, and whether the episode has ended.
-        return (self.position[0], self.position[1], self.velocity[0], self.velocity[1]), reward, terminal
+        return np.array([self.position[0], self.position[1], self.velocity[0], self.velocity[1]]), reward, done, {}

     def reset(self) :
@@ -184,10 +171,10 @@ class RacetrackEnv(object) :
         self.is_reset = True

-        return (self.position[0], self.position[1], self.velocity[0], self.velocity[1])
+        return np.array([self.position[0], self.position[1], self.velocity[0], self.velocity[1]])

-    def render(self, sleep_time : float = 0.1) :
+    def render(self, mode = 'human') :
         """
         Renders a pretty matplotlib plot representing the current state of the environment.
         Calling this method on subsequent timesteps will update the plot.
@@ -230,13 +217,9 @@ class RacetrackEnv(object) :
         # Draw everything.
         #fig.canvas.draw()
         #fig.canvas.flush_events()

         plt.show()

-        # Sleep if desired.
-        if (sleep_time > 0) :
-            time.sleep(sleep_time)
+        # time sleep
+        time.sleep(0.1)

     def get_actions(self) :
         """
@@ -244,18 +227,16 @@ class RacetrackEnv(object) :
     of integers in the range [0-8].
     """
     return [*self.ACTIONS_DICT]

-# num_steps = 1000000
-# env = RacetrackEnv()
-# state = env.reset()
-# print(state)
-# for _ in range(num_steps) :
-#     next_state, reward, terminal = env.step(random.choice(env.get_actions()))
-#     print(next_state)
-#     env.render()
-#     if (terminal) :
-#         _ = env.reset()
+if __name__ == "__main__":
+    num_steps = 1000000
+    env = RacetrackEnv()
+    state = env.reset()
+    print(state)
+    for _ in range(num_steps) :
+        next_state, reward, done, _ = env.step(random.choice(env.get_actions()))
+        print(next_state)
+        env.render()
+        if (done) :
+            _ = env.reset()
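The net effect of these edits is to make RacetrackEnv conform to the classic gym.Env contract: reset() returns an observation array and step() returns the (obs, reward, done, info) 4-tuple. A minimal sketch of the loop that contract enables, written against the pre-0.26 gym API this repo targets (CartPole stands in for any registered env):

```python
import gym

env = gym.make('CartPole-v1')  # stand-in for any gym-style env, e.g. the registered Racetrack-v0
state = env.reset()
done = False
while not done:
    action = env.action_space.sample()  # random policy, for illustration only
    next_state, reward, done, info = env.step(action)  # the 4-tuple contract adopted above
    state = next_state
env.close()
```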
projects/codes/envs/register.py (new file, 34 lines)
@@ -0,0 +1,34 @@
from gym.envs.registration import register

def register_env(env_name):
    if env_name == 'Racetrack-v0':
        register(
            id='Racetrack-v0',
            entry_point='racetrack:RacetrackEnv',
            max_episode_steps=1000,
            kwargs={}
        )
    elif env_name == 'FrozenLakeNoSlippery-v1':
        register(
            id='FrozenLakeNoSlippery-v1',
            entry_point='gym.envs.toy_text.frozen_lake:FrozenLakeEnv',
            kwargs={'map_name': "4x4", 'is_slippery': False},
        )
    else:
        print("The env name is wrong, or this environment does not need to be registered!")

# if __name__ == "__main__":
#     import random
#     import gym
#     env = gym.make('FrozenLakeNoSlippery-v1')
#     num_steps = 1000000
#     state = env.reset()
#     n_actions = env.action_space.n
#     print(state)
#     for _ in range(num_steps) :
#         next_state, reward, done, _ = env.step(random.choice(range(n_actions)))
#         print(next_state)
#         if (done) :
#             _ = env.reset()
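Usage is two calls: register the custom id once per process, then build the env through gym as usual. A minimal sketch (assuming envs/ is importable, as main.py arranges via sys.path):

```python
import gym
from envs.register import register_env

register_env('FrozenLakeNoSlippery-v1')    # adds the custom id to gym's registry
env = gym.make('FrozenLakeNoSlippery-v1')  # deterministic 4x4 FrozenLake (is_slippery=False)
print(env.observation_space.n, env.action_space.n)  # 16 states, 4 actions
```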
projects/codes/envs/wrappers.py (new file, 78 lines)
@@ -0,0 +1,78 @@
import gym
import turtle  # required by render(); missing from the original file

class CliffWalkingWapper(gym.Wrapper):
    def __init__(self, env):
        gym.Wrapper.__init__(self, env)
        self.t = None
        self.unit = 50
        self.max_x = 12
        self.max_y = 4

    def draw_x_line(self, y, x0, x1, color='gray'):
        assert x1 > x0
        self.t.color(color)
        self.t.setheading(0)
        self.t.up()
        self.t.goto(x0, y)
        self.t.down()
        self.t.forward(x1 - x0)

    def draw_y_line(self, x, y0, y1, color='gray'):
        assert y1 > y0
        self.t.color(color)
        self.t.setheading(90)
        self.t.up()
        self.t.goto(x, y0)
        self.t.down()
        self.t.forward(y1 - y0)

    def draw_box(self, x, y, fillcolor='', line_color='gray'):
        self.t.up()
        self.t.goto(x * self.unit, y * self.unit)
        self.t.color(line_color)
        self.t.fillcolor(fillcolor)
        self.t.setheading(90)
        self.t.down()
        self.t.begin_fill()
        for i in range(4):
            self.t.forward(self.unit)
            self.t.right(90)
        self.t.end_fill()

    def move_player(self, x, y):
        self.t.up()
        self.t.setheading(90)
        self.t.fillcolor('red')
        self.t.goto((x + 0.5) * self.unit, (y + 0.5) * self.unit)

    def render(self):
        if self.t == None:
            self.t = turtle.Turtle()
            self.wn = turtle.Screen()
            self.wn.setup(self.unit * self.max_x + 100,
                          self.unit * self.max_y + 100)
            self.wn.setworldcoordinates(0, 0, self.unit * self.max_x,
                                        self.unit * self.max_y)
            self.t.shape('circle')
            self.t.width(2)
            self.t.speed(0)
            self.t.color('gray')
            for _ in range(2):
                self.t.forward(self.max_x * self.unit)
                self.t.left(90)
                self.t.forward(self.max_y * self.unit)
                self.t.left(90)
            for i in range(1, self.max_y):
                self.draw_x_line(
                    y=i * self.unit, x0=0, x1=self.max_x * self.unit)
            for i in range(1, self.max_x):
                self.draw_y_line(
                    x=i * self.unit, y0=0, y1=self.max_y * self.unit)

            for i in range(1, self.max_x - 1):
                self.draw_box(i, 0, 'black')
            self.draw_box(self.max_x - 1, 0, 'yellow')
            self.t.shape('turtle')

        x_pos = self.s % self.max_x
        y_pos = self.max_y - 1 - int(self.s / self.max_x)
        self.move_player(x_pos, y_pos)
@@ -11,4 +11,5 @@ else
 fi
 conda activate easyrl # easyrl here can be changed to another name of conda env that you have created
 codes_dir=$(dirname $(dirname $(readlink -f "$0"))) # "codes" path
-python $codes_dir/QLearning/main.py --env_name FrozenLake-v1 --train_eps 800 --epsilon_start 0.70 --epsilon_end 0.1 --epsilon_decay 2000 --gamma 0.9 --lr 0.9 --device cpu
+python $codes_dir/envs/register.py # register environment
+python $codes_dir/QLearning/main.py --env_name FrozenLakeNoSlippery-v1 --train_eps 800 --epsilon_start 0.70 --epsilon_end 0.1 --epsilon_decay 2000 --gamma 0.9 --lr 0.9 --device cpu
projects/codes/scripts/Sarsa_task0.sh (new file, 13 lines)
@@ -0,0 +1,13 @@
if [ -f "$HOME/anaconda3/etc/profile.d/conda.sh" ]; then
    echo "source file at ~/anaconda3/etc/profile.d/conda.sh"
    source ~/anaconda3/etc/profile.d/conda.sh
elif [ -f "$HOME/opt/anaconda3/etc/profile.d/conda.sh" ]; then
    echo "source file at ~/opt/anaconda3/etc/profile.d/conda.sh"
    source ~/opt/anaconda3/etc/profile.d/conda.sh
else
    echo 'please manually config the conda source path'
fi
conda activate easyrl # easyrl here can be changed to another name of conda env that you have created
codes_dir=$(dirname $(dirname $(readlink -f "$0"))) # "codes" path
python $codes_dir/envs/register.py # register environment
python $codes_dir/Sarsa/main.py
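Both task scripts follow the same pattern: locate and source conda, activate the easyrl env (rename it if yours differs), derive codes_dir from the script's own resolved path so the script can be invoked from any working directory, then run the experiment, e.g. `bash projects/codes/scripts/Sarsa_task0.sh`. One caveat: gym's registry is per-process, so the standalone `python $codes_dir/envs/register.py` line does not carry over to the next command; the registration that matters at runtime is the register_env() call inside each main.py.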