hot update PG
@@ -1,318 +0,0 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define the model\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "import paddle\n",
    "import paddle.nn as nn\n",
    "import paddle.nn.functional as F\n",
    "import parl\n",
    "\n",
    "class CartpoleModel(parl.Model):\n",
    "    \"\"\" Linear network to solve the Cartpole problem.\n",
    "    Args:\n",
    "        n_states (int): Dimension of observation space.\n",
    "        n_actions (int): Dimension of action space.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, n_states, n_actions):\n",
    "        super(CartpoleModel, self).__init__()\n",
    "        hid1_size = 128\n",
    "        hid2_size = 128\n",
    "        self.fc1 = nn.Linear(n_states, hid1_size)\n",
    "        self.fc2 = nn.Linear(hid1_size, hid2_size)\n",
    "        self.fc3 = nn.Linear(hid2_size, n_actions)\n",
    "\n",
    "    def forward(self, obs):\n",
    "        h1 = F.relu(self.fc1(obs))\n",
    "        h2 = F.relu(self.fc2(h1))\n",
    "        Q = self.fc3(h2)\n",
    "        return Q"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "import parl\n",
    "import paddle\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "class CartpoleAgent(parl.Agent):\n",
    "    \"\"\"Agent for the Cartpole env.\n",
    "    Args:\n",
    "        algorithm (parl.Algorithm): algorithm used to solve the problem.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, algorithm, n_actions, e_greed=0.1, e_greed_decrement=0):\n",
    "        super(CartpoleAgent, self).__init__(algorithm)\n",
    "        assert isinstance(n_actions, int)\n",
    "        self.n_actions = n_actions\n",
    "\n",
    "        self.global_step = 0\n",
    "        self.update_target_steps = 200\n",
    "\n",
    "        self.e_greed = e_greed\n",
    "        self.e_greed_decrement = e_greed_decrement\n",
    "\n",
    "    def sample(self, obs):\n",
    "        \"\"\"Sample an action (for exploration) given an observation\n",
    "        Args:\n",
    "            obs (np.float32): shape of (n_states,)\n",
    "        Returns:\n",
    "            act (int): action\n",
    "        \"\"\"\n",
    "        sample = np.random.random()\n",
    "        if sample < self.e_greed:\n",
    "            act = np.random.randint(self.n_actions)\n",
    "        else:\n",
    "            if np.random.random() < 0.01:\n",
    "                act = np.random.randint(self.n_actions)\n",
    "            else:\n",
    "                act = self.predict(obs)\n",
    "        self.e_greed = max(0.01, self.e_greed - self.e_greed_decrement)\n",
    "        return act\n",
    "\n",
    "    def predict(self, obs):\n",
    "        \"\"\"Predict an action given an observation\n",
    "        Args:\n",
    "            obs (np.float32): shape of (n_states,)\n",
    "        Returns:\n",
    "            act (int): action\n",
    "        \"\"\"\n",
    "        obs = paddle.to_tensor(obs, dtype='float32')\n",
    "        pred_q = self.alg.predict(obs)\n",
    "        act = pred_q.argmax().numpy()[0]\n",
    "        return act\n",
    "\n",
    "    def learn(self, obs, act, reward, next_obs, terminal):\n",
    "        \"\"\"Update the model with a batch of transitions\n",
    "        Args:\n",
    "            obs (np.float32): shape of (batch_size, n_states)\n",
    "            act (np.int32): shape of (batch_size)\n",
    "            reward (np.float32): shape of (batch_size)\n",
    "            next_obs (np.float32): shape of (batch_size, n_states)\n",
    "            terminal (np.float32): shape of (batch_size)\n",
    "        Returns:\n",
    "            loss (float)\n",
    "        \"\"\"\n",
    "        if self.global_step % self.update_target_steps == 0:\n",
    "            self.alg.sync_target()\n",
    "        self.global_step += 1\n",
    "\n",
    "        act = np.expand_dims(act, axis=-1)\n",
    "        reward = np.expand_dims(reward, axis=-1)\n",
    "        terminal = np.expand_dims(terminal, axis=-1)\n",
    "\n",
    "        obs = paddle.to_tensor(obs, dtype='float32')\n",
    "        act = paddle.to_tensor(act, dtype='int32')\n",
    "        reward = paddle.to_tensor(reward, dtype='float32')\n",
    "        next_obs = paddle.to_tensor(next_obs, dtype='float32')\n",
    "        terminal = paddle.to_tensor(terminal, dtype='float32')\n",
    "        loss = self.alg.learn(obs, act, reward, next_obs, terminal)\n",
    "        return loss.numpy()[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import gym\n",
    "import numpy as np\n",
    "import parl\n",
    "\n",
    "from parl.utils import logger, ReplayMemory\n",
    "from parl.algorithms import DQN\n",
    "\n",
    "LEARN_FREQ = 5  # training frequency\n",
    "MEMORY_SIZE = 200000\n",
    "MEMORY_WARMUP_SIZE = 200\n",
    "BATCH_SIZE = 64\n",
    "LEARNING_RATE = 0.0005\n",
    "GAMMA = 0.99\n",
    "\n",
    "# train one episode\n",
    "def run_train_episode(agent, env, rpm):\n",
    "    total_reward = 0\n",
    "    obs = env.reset()\n",
    "    step = 0\n",
    "    while True:\n",
    "        step += 1\n",
    "        action = agent.sample(obs)\n",
    "        next_obs, reward, done, _ = env.step(action)\n",
    "        rpm.append(obs, action, reward, next_obs, done)\n",
    "\n",
    "        # train model\n",
    "        if (len(rpm) > MEMORY_WARMUP_SIZE) and (step % LEARN_FREQ == 0):\n",
    "            # s, a, r, s', done\n",
    "            (batch_obs, batch_action, batch_reward, batch_next_obs,\n",
    "             batch_done) = rpm.sample_batch(BATCH_SIZE)\n",
    "            train_loss = agent.learn(batch_obs, batch_action, batch_reward,\n",
    "                                     batch_next_obs, batch_done)\n",
    "\n",
    "        total_reward += reward\n",
    "        obs = next_obs\n",
    "        if done:\n",
    "            break\n",
    "    return total_reward\n",
    "\n",
    "\n",
    "# evaluate 5 episodes\n",
    "def run_evaluate_episodes(agent, env, eval_episodes=5, render=False):\n",
    "    eval_reward = []\n",
    "    for i in range(eval_episodes):\n",
    "        obs = env.reset()\n",
    "        episode_reward = 0\n",
    "        while True:\n",
    "            action = agent.predict(obs)\n",
    "            obs, reward, done, _ = env.step(action)\n",
    "            episode_reward += reward\n",
    "            if render:\n",
    "                env.render()\n",
    "            if done:\n",
    "                break\n",
    "        eval_reward.append(episode_reward)\n",
    "    return np.mean(eval_reward)\n",
    "\n",
    "\n",
    "def main(args):\n",
    "    env = gym.make('CartPole-v0')\n",
    "    n_states = env.observation_space.shape[0]\n",
    "    n_actions = env.action_space.n\n",
    "    logger.info('n_states {}, n_actions {}'.format(n_states, n_actions))\n",
    "\n",
    "    # set action_shape = 0 in a discrete control environment\n",
    "    rpm = ReplayMemory(MEMORY_SIZE, n_states, 0)\n",
    "\n",
    "    # build an agent\n",
    "    model = CartpoleModel(n_states=n_states, n_actions=n_actions)\n",
    "    alg = DQN(model, gamma=GAMMA, lr=LEARNING_RATE)\n",
    "    agent = CartpoleAgent(\n",
    "        alg, n_actions=n_actions, e_greed=0.1, e_greed_decrement=1e-6)\n",
    "\n",
    "    # warm up the replay memory\n",
    "    while len(rpm) < MEMORY_WARMUP_SIZE:\n",
    "        run_train_episode(agent, env, rpm)\n",
    "\n",
    "    max_episode = args.max_episode\n",
    "\n",
    "    # start training\n",
    "    episode = 0\n",
    "    while episode < max_episode:\n",
    "        # train part\n",
    "        for i in range(50):\n",
    "            total_reward = run_train_episode(agent, env, rpm)\n",
    "            episode += 1\n",
    "\n",
    "        # test part\n",
    "        eval_reward = run_evaluate_episodes(agent, env, render=False)\n",
    "        logger.info('episode:{} e_greed:{} Test reward:{}'.format(\n",
    "            episode, agent.e_greed, eval_reward))\n",
    "\n",
    "    # save the parameters to ./model.ckpt\n",
    "    save_path = './model.ckpt'\n",
    "    agent.save(save_path)\n",
    "\n",
    "    # save the model and parameters of the policy network for inference\n",
    "    save_inference_path = './inference_model'\n",
    "    input_shapes = [[None, env.observation_space.shape[0]]]\n",
    "    input_dtypes = ['float32']\n",
    "    agent.save_inference_model(save_inference_path, input_shapes, input_dtypes)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[08-01 21:48:19 MainThread @3996942455.py:64] obs_dim 4, act_dim 2\n",
      "[08-01 21:48:19 MainThread @3996942455.py:92] episode:50 e_greed:0.0988929999999989 Test reward:18.4\n",
      "[08-01 21:48:20 MainThread @3996942455.py:92] episode:100 e_greed:0.09794799999999795 Test reward:9.6\n",
      "[08-01 21:48:20 MainThread @3996942455.py:92] episode:150 e_greed:0.0973899999999974 Test reward:37.8\n",
      "[08-01 21:48:20 MainThread @3996942455.py:92] episode:200 e_greed:0.09684299999999685 Test reward:8.8\n",
      "[08-01 21:48:20 MainThread @3996942455.py:92] episode:250 e_greed:0.09635499999999636 Test reward:9.4\n",
      "[08-01 21:48:21 MainThread @3996942455.py:92] episode:300 e_greed:0.09585299999999586 Test reward:9.2\n",
      "[08-01 21:48:21 MainThread @3996942455.py:92] episode:350 e_greed:0.09535799999999536 Test reward:9.2\n",
      "[08-01 21:48:21 MainThread @3996942455.py:92] episode:400 e_greed:0.09486399999999487 Test reward:10.0\n",
      "[08-01 21:48:21 MainThread @3996942455.py:92] episode:450 e_greed:0.09435299999999436 Test reward:9.2\n",
      "[08-01 21:48:22 MainThread @3996942455.py:92] episode:500 e_greed:0.09384899999999385 Test reward:9.4\n",
      "[08-01 21:48:22 MainThread @3996942455.py:92] episode:550 e_greed:0.09302299999999303 Test reward:69.0\n",
      "[08-01 21:48:25 MainThread @3996942455.py:92] episode:600 e_greed:0.08774199999998775 Test reward:141.2\n",
      "[08-01 21:48:30 MainThread @3996942455.py:92] episode:650 e_greed:0.0791019999999791 Test reward:184.0\n",
      "[08-01 21:48:35 MainThread @3996942455.py:92] episode:700 e_greed:0.07011299999997012 Test reward:182.0\n",
      "[08-01 21:48:40 MainThread @3996942455.py:92] episode:750 e_greed:0.06089099999996089 Test reward:197.4\n",
      "[08-01 21:48:45 MainThread @3996942455.py:92] episode:800 e_greed:0.05139199999995139 Test reward:183.4\n",
      "[08-01 21:48:50 MainThread @3996942455.py:92] episode:850 e_greed:0.042255999999942256 Test reward:153.0\n",
      "[08-01 21:48:55 MainThread @3996942455.py:92] episode:900 e_greed:0.033495999999933496 Test reward:192.6\n",
      "[08-01 21:49:00 MainThread @3996942455.py:92] episode:950 e_greed:0.024318999999924318 Test reward:166.6\n",
      "[08-01 21:49:06 MainThread @3996942455.py:92] episode:1000 e_greed:0.014873999999916176 Test reward:187.0\n"
     ]
    }
   ],
   "source": [
    "import argparse\n",
    "parser = argparse.ArgumentParser()\n",
    "parser.add_argument(\n",
    "    '--max_episode',\n",
    "    type=int,\n",
    "    default=1000,\n",
    "    help='stop condition: number of max episodes')\n",
    "args = parser.parse_args(args=[])\n",
    "\n",
    "main(args)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.7.12 ('rl_tutorials')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.12"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "4f613f1ab80ec98dc1b91d6e720de51301598a187317378e53e49b773c1123dd"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
@@ -1,11 +0,0 @@
[PARL](https://github.com/PaddlePaddle/PARL) is a high-performance, flexible reinforcement learning framework developed by Baidu.

## Installation

1. Install parl by following the [PARL Github](https://github.com/PaddlePaddle/PARL) instructions.
2. Install paddlepaddle: ```pip install paddlepaddle```

## FAQ

```jupyter-client 7.3.1 requires pyzmq>=22.3, but you have pyzmq 18.1.1 which is incompatible.```:
```pip install -U pyzmq```
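After installing, a quick sanity check like the following confirms both packages import and resolve correctly (a minimal sketch, not part of the repository; `paddle.utils.run_check()` is assumed to be available in recent PaddlePaddle releases):

```python
# Minimal post-install check for the two dependencies used in this PR.
import paddle
import parl

print("paddle:", paddle.__version__)
print("parl:", parl.__version__)

# Optional: PaddlePaddle's built-in self-check runs a tiny program to verify the install.
paddle.utils.run_check()
```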
@@ -11,7 +11,6 @@
The project mainly contains the following parts:
* [Jupyter Notebook](./notebooks/): algorithms written as notebooks, with fairly detailed hands-on walkthroughs; recommended for beginners
* [codes](./codes/): algorithms written as Python scripts, in a style closer to real-world projects; recommended for readers with some coding experience (the overall architecture is described below)
* [parl](./PARL/): RL examples based on the Baidu PaddlePaddle platform and the ```parl``` module, written to meet business needs
* [assets](./assets/): currently contains Chinese pseudocode for the various reinforcement learning algorithms
projects/codes/PolicyGradient/main.py
@@ -0,0 +1,129 @@
#!/usr/bin/env python
# coding=utf-8
'''
Author: John
Email: johnjim0816@gmail.com
Date: 2020-11-22 23:21:53
LastEditor: John
LastEditTime: 2022-08-25 20:59:23
Description:
Environment:
'''
import sys, os
curr_path = os.path.dirname(os.path.abspath(__file__))  # current path
parent_path = os.path.dirname(curr_path)  # parent path
sys.path.append(parent_path)  # add to system path

import gym
import torch
import datetime
import argparse
from itertools import count
import torch.nn.functional as F
from pg import PolicyGradient
from common.utils import save_results, make_dir, all_seed, save_args, plot_rewards
from common.models import MLP
from common.memories import PGReplay
from common.launcher import Launcher
from envs.register import register_env


class PGNet(MLP):
    ''' Instead of outputting an action, the policy-gradient net outputs action probabilities,
        so we can reuse the MLP base class here.
    '''
    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.sigmoid(self.fc3(x))
        return x

class Main(Launcher):
    def get_args(self):
        """ Hyperparameters
        """
        curr_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")  # obtain current time
        parser = argparse.ArgumentParser(description="hyperparameters")
        parser.add_argument('--algo_name', default='PolicyGradient', type=str, help="name of algorithm")
        parser.add_argument('--env_name', default='CartPole-v0', type=str, help="name of environment")
        parser.add_argument('--train_eps', default=200, type=int, help="episodes of training")
        parser.add_argument('--test_eps', default=20, type=int, help="episodes of testing")
        parser.add_argument('--gamma', default=0.99, type=float, help="discount factor")
        parser.add_argument('--lr', default=0.005, type=float, help="learning rate")
        parser.add_argument('--update_fre', default=8, type=int)
        parser.add_argument('--hidden_dim', default=36, type=int)
        parser.add_argument('--device', default='cpu', type=str, help="cpu or cuda")
        parser.add_argument('--seed', default=1, type=int, help="seed")
        parser.add_argument('--save_fig', default=True, type=bool, help="if save figure or not")
        parser.add_argument('--show_fig', default=False, type=bool, help="if show figure or not")
        args = parser.parse_args()
        default_args = {'result_path': f"{curr_path}/outputs/{args.env_name}/{curr_time}/results/",
                        'model_path': f"{curr_path}/outputs/{args.env_name}/{curr_time}/models/",
                        }
        args = {**vars(args), **default_args}  # type(dict)
        return args

    def env_agent_config(self, cfg):
        register_env(cfg['env_name'])
        env = gym.make(cfg['env_name'])
        if cfg['seed'] != 0:  # set random seed
            all_seed(env, seed=cfg['seed'])
        n_states = env.observation_space.shape[0]
        n_actions = env.action_space.n  # action dimension
        print(f"state dim: {n_states}, action dim: {n_actions}")
        cfg.update({"n_states": n_states, "n_actions": n_actions})  # update cfg parameters
        model = PGNet(n_states, 1, hidden_dim=cfg['hidden_dim'])
        memory = PGReplay()
        agent = PolicyGradient(model, memory, cfg)
        return env, agent

    def train(self, cfg, env, agent):
        print("Start training!")
        print(f"Env: {cfg['env_name']}, Algorithm: {cfg['algo_name']}, Device: {cfg['device']}")
        rewards = []
        for i_ep in range(cfg['train_eps']):
            state = env.reset()
            ep_reward = 0
            for _ in count():
                action = agent.sample_action(state)  # sample action
                next_state, reward, done, _ = env.step(action)
                ep_reward += reward
                if done:
                    reward = 0
                agent.memory.push((state, float(action), reward))
                state = next_state
                if done:
                    print(f"Episode: {i_ep+1}/{cfg['train_eps']}, Reward: {ep_reward:.2f}")
                    break
            if (i_ep+1) % cfg['update_fre'] == 0:
                agent.update()
            rewards.append(ep_reward)
        print('Finish training!')
        env.close()  # close environment
        res_dic = {'episodes': range(len(rewards)), 'rewards': rewards}
        return res_dic

    def test(self, cfg, env, agent):
        print("Start testing!")
        print(f"Env: {cfg['env_name']}, Algorithm: {cfg['algo_name']}, Device: {cfg['device']}")
        rewards = []
        for i_ep in range(cfg['test_eps']):
            state = env.reset()
            ep_reward = 0
            for _ in count():
                action = agent.predict_action(state)
                next_state, reward, done, _ = env.step(action)
                ep_reward += reward
                if done:
                    reward = 0
                state = next_state
                if done:
                    print(f"Episode: {i_ep+1}/{cfg['test_eps']}, Reward: {ep_reward:.2f}")
                    break
            rewards.append(ep_reward)
        print("Finish testing!")
        env.close()
        return {'episodes': range(len(rewards)), 'rewards': rewards}

if __name__ == "__main__":
    main = Main()
    main.run()
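PGNet above only overrides `forward()`, so it relies on a base MLP from `common.models` that exposes `fc1`, `fc2`, `fc3` and a `(input_dim, output_dim, hidden_dim=...)` constructor. A minimal sketch of what that assumed base looks like (illustrative only; the repository's own `common/models.py` is authoritative):

```python
import torch.nn as nn

class MLP(nn.Module):
    """Assumed three-layer MLP base: MLP(input_dim, output_dim, hidden_dim=...)."""
    def __init__(self, input_dim, output_dim, hidden_dim=128):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)   # used by PGNet.forward
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, output_dim)  # output_dim=1 here: P(action = 1)

    def forward(self, x):
        x = nn.functional.relu(self.fc1(x))
        x = nn.functional.relu(self.fc2(x))
        return self.fc3(x)
```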
@@ -1,16 +0,0 @@
{
    "algo_name": "PolicyGradient",
    "env_name": "CartPole-v0",
    "train_eps": 200,
    "test_eps": 20,
    "gamma": 0.99,
    "lr": 0.005,
    "update_fre": 8,
    "hidden_dim": 36,
    "device": "cpu",
    "seed": 1,
    "result_path": "/Users/jj/Desktop/rl-tutorials/codes/PolicyGradient/outputs/CartPole-v0/20220822-174059/results/",
    "model_path": "/Users/jj/Desktop/rl-tutorials/codes/PolicyGradient/outputs/CartPole-v0/20220822-174059/models/",
    "save_fig": true,
    "show_fig": false
}
@@ -0,0 +1 @@
{"algo_name": "PolicyGradient", "env_name": "CartPole-v0", "train_eps": 200, "test_eps": 20, "gamma": 0.99, "lr": 0.005, "update_fre": 8, "hidden_dim": 36, "device": "cpu", "seed": 1, "save_fig": true, "show_fig": false, "result_path": "/Users/jj/Desktop/rl-tutorials/codes/PolicyGradient/outputs/CartPole-v0/20220825-205930/results/", "model_path": "/Users/jj/Desktop/rl-tutorials/codes/PolicyGradient/outputs/CartPole-v0/20220825-205930/models/", "n_states": 4, "n_actions": 2}
@@ -5,7 +5,7 @@ Author: John
Email: johnjim0816@gmail.com
Date: 2020-11-22 23:27:44
LastEditor: John
LastEditTime: 2022-08-22 17:35:34
LastEditTime: 2022-08-25 20:58:59
Description:
Environment:
'''
@@ -19,12 +19,12 @@ import numpy as np

class PolicyGradient:

    def __init__(self, n_states, model, memory, cfg):
        self.gamma = cfg.gamma
        self.device = torch.device(cfg.device)
    def __init__(self, model, memory, cfg):
        self.gamma = cfg['gamma']
        self.device = torch.device(cfg['device'])
        self.memory = memory
        self.policy_net = model.to(self.device)
        self.optimizer = torch.optim.RMSprop(self.policy_net.parameters(), lr=cfg.lr)
        self.optimizer = torch.optim.RMSprop(self.policy_net.parameters(), lr=cfg['lr'])

    def sample_action(self, state):
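The hunk above only shows the constructor and the `sample_action` signature of `PolicyGradient`. For orientation, here is a minimal sketch of the Monte-Carlo policy-gradient (REINFORCE) update such an agent typically performs, assuming the memory yields the (state, action, reward) tuples pushed by the training loop in main.py; names and details are illustrative, not the repository's exact implementation:

```python
import torch
from torch.distributions import Bernoulli

def reinforce_update(policy_net, optimizer, batch, gamma=0.99, device="cpu"):
    """One REINFORCE step over a batch of (state, action, reward) transitions
    collected from several complete episodes (illustrative sketch)."""
    states, actions, rewards = zip(*batch)

    # Discounted returns, computed backwards through the batch.
    returns = []
    running = 0.0
    for r in reversed(rewards):
        if r == 0:          # terminal transitions are pushed with reward 0 above
            running = 0.0   # (CartPole rewards are otherwise +1), marking an episode boundary
        running = r + gamma * running
        returns.insert(0, running)
    returns = torch.tensor(returns, dtype=torch.float32, device=device)
    returns = (returns - returns.mean()) / (returns.std() + 1e-8)  # variance reduction

    optimizer.zero_grad()
    loss = 0.0
    for state, action, G in zip(states, actions, returns):
        state = torch.as_tensor(state, dtype=torch.float32, device=device)
        prob = policy_net(state)              # P(action = 1), from the sigmoid head
        dist = Bernoulli(prob)
        log_prob = dist.log_prob(torch.tensor(action, dtype=torch.float32, device=device))
        loss = loss - log_prob * G            # gradient ascent on expected return
    loss.backward()
    optimizer.step()
```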
@@ -1,139 +0,0 @@
#!/usr/bin/env python
# coding=utf-8
'''
Author: John
Email: johnjim0816@gmail.com
Date: 2020-11-22 23:21:53
LastEditor: John
LastEditTime: 2022-08-22 17:40:07
Description:
Environment:
'''
import sys, os
curr_path = os.path.dirname(os.path.abspath(__file__))  # current path
parent_path = os.path.dirname(curr_path)  # parent path
sys.path.append(parent_path)  # add to system path

import gym
import torch
import datetime
import argparse
from itertools import count
import torch.nn.functional as F
from pg import PolicyGradient
from common.utils import save_results, make_dir, all_seed, save_args, plot_rewards
from common.models import MLP
from common.memories import PGReplay


def get_args():
    """ Hyperparameters
    """
    curr_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")  # obtain current time
    parser = argparse.ArgumentParser(description="hyperparameters")
    parser.add_argument('--algo_name', default='PolicyGradient', type=str, help="name of algorithm")
    parser.add_argument('--env_name', default='CartPole-v0', type=str, help="name of environment")
    parser.add_argument('--train_eps', default=200, type=int, help="episodes of training")
    parser.add_argument('--test_eps', default=20, type=int, help="episodes of testing")
    parser.add_argument('--gamma', default=0.99, type=float, help="discount factor")
    parser.add_argument('--lr', default=0.005, type=float, help="learning rate")
    parser.add_argument('--update_fre', default=8, type=int)
    parser.add_argument('--hidden_dim', default=36, type=int)
    parser.add_argument('--device', default='cpu', type=str, help="cpu or cuda")
    parser.add_argument('--seed', default=1, type=int, help="seed")
    parser.add_argument('--result_path', default=curr_path + "/outputs/" + parser.parse_args().env_name + \
                        '/' + curr_time + '/results/')
    parser.add_argument('--model_path', default=curr_path + "/outputs/" + parser.parse_args().env_name + \
                        '/' + curr_time + '/models/')  # path to save models
    parser.add_argument('--save_fig', default=True, type=bool, help="if save figure or not")
    parser.add_argument('--show_fig', default=False, type=bool, help="if show figure or not")
    args = parser.parse_args([])
    return args


class PGNet(MLP):
    ''' Instead of outputting an action, the policy-gradient net outputs action probabilities,
        so we can reuse the MLP base class here.
    '''
    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.sigmoid(self.fc3(x))
        return x


def env_agent_config(cfg):
    env = gym.make(cfg.env_name)
    if cfg.seed != 0:  # set random seed
        all_seed(env, seed=cfg.seed)
    n_states = env.observation_space.shape[0]
    n_actions = env.action_space.n  # action dimension
    print(f"state dim: {n_states}, action dim: {n_actions}")
    model = PGNet(n_states, 1, hidden_dim=cfg.hidden_dim)
    memory = PGReplay()
    agent = PolicyGradient(n_states, model, memory, cfg)
    return env, agent


def train(cfg, env, agent):
    print('Start training!')
    print(f'Env: {cfg.env_name}, Algo: {cfg.algo_name}, Device: {cfg.device}')
    rewards = []
    for i_ep in range(cfg.train_eps):
        state = env.reset()
        ep_reward = 0
        for _ in count():
            action = agent.sample_action(state)  # sample action
            next_state, reward, done, _ = env.step(action)
            ep_reward += reward
            if done:
                reward = 0
            agent.memory.push((state, float(action), reward))
            state = next_state
            if done:
                print(f'Episode: {i_ep+1}/{cfg.train_eps}, Reward: {ep_reward:.2f}')
                break
        if (i_ep+1) % cfg.update_fre == 0:
            agent.update()
        rewards.append(ep_reward)
    print('Finish training!')
    env.close()  # close environment
    res_dic = {'episodes': range(len(rewards)), 'rewards': rewards}
    return res_dic


def test(cfg, env, agent):
    print("Start testing!")
    print(f"Env: {cfg.env_name}, Algo: {cfg.algo_name}, Device: {cfg.device}")
    rewards = []
    for i_ep in range(cfg.test_eps):
        state = env.reset()
        ep_reward = 0
        for _ in count():
            action = agent.predict_action(state)
            next_state, reward, done, _ = env.step(action)
            ep_reward += reward
            if done:
                reward = 0
            state = next_state
            if done:
                print(f'Episode: {i_ep+1}/{cfg.test_eps}, Reward: {ep_reward:.2f}')
                break
        rewards.append(ep_reward)
    print("Finish testing!")
    env.close()
    return {'episodes': range(len(rewards)), 'rewards': rewards}


if __name__ == "__main__":
    cfg = get_args()
    env, agent = env_agent_config(cfg)
    res_dic = train(cfg, env, agent)
    save_args(cfg, path=cfg.result_path)  # save parameters
    agent.save_model(path=cfg.model_path)  # save models
    save_results(res_dic, tag='train', path=cfg.result_path)  # save results
    plot_rewards(res_dic['rewards'], cfg, path=cfg.result_path, tag="train")  # plot results
    # testing
    env, agent = env_agent_config(cfg)  # create a new env for testing; this step can sometimes be skipped
    agent.load_model(path=cfg.model_path)  # load model
    res_dic = test(cfg, env, agent)
    save_results(res_dic, tag='test',
                 path=cfg.result_path)
    plot_rewards(res_dic['rewards'], cfg, path=cfg.result_path, tag="test")
@@ -5,7 +5,7 @@ Author: John
Email: johnjim0816@gmail.com
Date: 2020-09-11 23:03:00
LastEditor: John
LastEditTime: 2022-08-24 11:27:01
LastEditTime: 2022-08-25 14:59:15
Description:
Environment:
'''
@@ -18,136 +18,102 @@ sys.path.append(parent_path)  # add path to system path
import gym
import datetime
import argparse
from envs.gridworld_env import CliffWalkingWapper, FrozenLakeWapper
from envs.gridworld_env import FrozenLakeWapper
from envs.wrappers import CliffWalkingWapper
from envs.register import register_env
from qlearning import QLearning
from common.utils import plot_rewards, save_args, all_seed
from common.utils import save_results, make_dir

def get_args():
    curr_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")  # obtain current time
    parser = argparse.ArgumentParser(description="hyperparameters")
    parser.add_argument('--algo_name', default='Q-learning', type=str, help="name of algorithm")
    parser.add_argument('--env_name', default='CliffWalking-v0', type=str, help="name of environment")
    parser.add_argument('--train_eps', default=400, type=int, help="episodes of training")
    parser.add_argument('--test_eps', default=20, type=int, help="episodes of testing")
    parser.add_argument('--gamma', default=0.90, type=float, help="discount factor")
    parser.add_argument('--epsilon_start', default=0.95, type=float, help="initial value of epsilon")
    parser.add_argument('--epsilon_end', default=0.01, type=float, help="final value of epsilon")
    parser.add_argument('--epsilon_decay', default=300, type=int, help="decay rate of epsilon")
    parser.add_argument('--lr', default=0.1, type=float, help="learning rate")
    parser.add_argument('--device', default='cpu', type=str, help="cpu or cuda")
    parser.add_argument('--seed', default=10, type=int, help="seed")
    parser.add_argument('--show_fig', default=False, type=bool, help="if show figure or not")
    parser.add_argument('--save_fig', default=True, type=bool, help="if save figure or not")
    args = parser.parse_args()
    default_args = {'result_path': f"{curr_path}/outputs/{args.env_name}/{curr_time}/results/",
                    'model_path': f"{curr_path}/outputs/{args.env_name}/{curr_time}/models/",
                    }
    args = {**vars(args), **default_args}  # type(dict)
    return args
def env_agent_config(cfg):
    ''' create env and agent
    '''
    if cfg['env_name'] == 'CliffWalking-v0':
        env = gym.make(cfg['env_name'])
        env = CliffWalkingWapper(env)
    if cfg['env_name'] == 'FrozenLake-v1':
        env = gym.make(cfg['env_name'], is_slippery=False)
    if cfg['seed'] != 0:  # set random seed
        all_seed(env, seed=cfg["seed"])
    n_states = env.observation_space.n  # state dimension
    n_actions = env.action_space.n  # action dimension
    print(f"n_states: {n_states}, n_actions: {n_actions}")
    cfg.update({"n_states": n_states, "n_actions": n_actions})  # update cfg parameters
    agent = QLearning(cfg)
    return env, agent

def main(cfg, env, agent, tag='train'):
    print(f"Start {tag}ing!")
    print(f"Env: {cfg['env_name']}, Algorithm: {cfg['algo_name']}, Device: {cfg['device']}")
    rewards = []  # record rewards
    for i_ep in range(cfg.train_eps):
        ep_reward = 0  # reward per episode
        state = env.reset()  # reset the environment, i.e. start a new episode
        while True:
            if tag == 'train': action = agent.sample_action(state)  # sample an action from the algorithm
            else: action = agent.predict_action(state)
            next_state, reward, done, _ = env.step(action)  # take one step in the environment
            if tag == 'train': agent.update(state, action, reward, next_state, done)  # Q-learning update
            state = next_state  # update state
            ep_reward += reward
            if done:
                break
        rewards.append(ep_reward)
        print(f"Episode: {i_ep+1}/{cfg.train_eps}, Reward: {ep_reward:.1f}, Epsilon: {agent.epsilon}")
    print(f"Finish {tag}ing!")
    return {"rewards": rewards}

def train(cfg, env, agent):
    print("Start training!")
    print(f"Env: {cfg['env_name']}, Algorithm: {cfg['algo_name']}, Device: {cfg['device']}")
    rewards = []  # record rewards for all episodes
    steps = []  # record steps for all episodes
    for i_ep in range(cfg['train_eps']):
        ep_reward = 0  # reward per episode
        ep_step = 0  # steps per episode
        state = env.reset()  # reset and obtain the initial state
        while True:
            action = agent.sample_action(state)  # sample action
            next_state, reward, done, _ = env.step(action)  # update env and return transition
            agent.update(state, action, reward, next_state, done)  # update agent
            state = next_state  # update state
            ep_reward += reward
            ep_step += 1
            if done:
                break
        rewards.append(ep_reward)
        steps.append(ep_step)
        if (i_ep+1) % 10 == 0:
            print(f'Episode: {i_ep+1}/{cfg["train_eps"]}, Reward: {ep_reward:.2f}, Steps: {ep_step}, Epsilon: {agent.epsilon:.3f}')
    print("Finish training!")
    return {'episodes': range(len(rewards)), 'rewards': rewards, 'steps': steps}

def test(cfg, env, agent):
    print("Start testing!")
    print(f"Env: {cfg['env_name']}, Algorithm: {cfg['algo_name']}, Device: {cfg['device']}")
    rewards = []  # record rewards for all episodes
    steps = []  # record steps for all episodes
    for i_ep in range(cfg['test_eps']):
        ep_reward = 0  # reward per episode
        ep_step = 0
        state = env.reset()  # reset and obtain the initial state
        while True:
            action = agent.predict_action(state)  # predict action
            next_state, reward, done, _ = env.step(action)
            state = next_state
            ep_reward += reward
            ep_step += 1
            if done:
                break
        rewards.append(ep_reward)
        steps.append(ep_step)
        print(f"Episode: {i_ep+1}/{cfg['test_eps']}, Steps: {ep_step}, Reward: {ep_reward:.2f}")
    print("Finish testing!")
    return {'episodes': range(len(rewards)), 'rewards': rewards, 'steps': steps}
from common.utils import all_seed
from common.launcher import Launcher

class Main(Launcher):
    def get_args(self):
        curr_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")  # obtain current time
        parser = argparse.ArgumentParser(description="hyperparameters")
        parser.add_argument('--algo_name', default='Q-learning', type=str, help="name of algorithm")
        parser.add_argument('--env_name', default='CliffWalking-v0', type=str, help="name of environment")
        parser.add_argument('--train_eps', default=400, type=int, help="episodes of training")
        parser.add_argument('--test_eps', default=20, type=int, help="episodes of testing")
        parser.add_argument('--gamma', default=0.90, type=float, help="discount factor")
        parser.add_argument('--epsilon_start', default=0.95, type=float, help="initial value of epsilon")
        parser.add_argument('--epsilon_end', default=0.01, type=float, help="final value of epsilon")
        parser.add_argument('--epsilon_decay', default=300, type=int, help="decay rate of epsilon")
        parser.add_argument('--lr', default=0.1, type=float, help="learning rate")
        parser.add_argument('--device', default='cpu', type=str, help="cpu or cuda")
        parser.add_argument('--seed', default=10, type=int, help="seed")
        parser.add_argument('--show_fig', default=False, type=bool, help="if show figure or not")
        parser.add_argument('--save_fig', default=True, type=bool, help="if save figure or not")
        args = parser.parse_args()
        default_args = {'result_path': f"{curr_path}/outputs/{args.env_name}/{curr_time}/results/",
                        'model_path': f"{curr_path}/outputs/{args.env_name}/{curr_time}/models/",
                        }
        args = {**vars(args), **default_args}  # type(dict)
        return args
    def env_agent_config(self, cfg):
        ''' create env and agent
        '''
        register_env(cfg['env_name'])
        env = gym.make(cfg['env_name'])
        if cfg['env_name'] == 'CliffWalking-v0':
            env = CliffWalkingWapper(env)
        if cfg['seed'] != 0:  # set random seed
            all_seed(env, seed=cfg["seed"])
        n_states = env.observation_space.n  # state dimension
        n_actions = env.action_space.n  # action dimension
        print(f"n_states: {n_states}, n_actions: {n_actions}")
        cfg.update({"n_states": n_states, "n_actions": n_actions})  # update cfg parameters
        agent = QLearning(cfg)
        return env, agent
    def train(self, cfg, env, agent):
        print("Start training!")
        print(f"Env: {cfg['env_name']}, Algorithm: {cfg['algo_name']}, Device: {cfg['device']}")
        rewards = []  # record rewards for all episodes
        steps = []  # record steps for all episodes
        for i_ep in range(cfg['train_eps']):
            ep_reward = 0  # reward per episode
            ep_step = 0  # steps per episode
            state = env.reset()  # reset and obtain the initial state
            while True:
                action = agent.sample_action(state)  # sample action
                next_state, reward, done, _ = env.step(action)  # update env and return transition
                agent.update(state, action, reward, next_state, done)  # update agent
                state = next_state  # update state
                ep_reward += reward
                ep_step += 1
                if done:
                    break
            rewards.append(ep_reward)
            steps.append(ep_step)
            if (i_ep+1) % 10 == 0:
                print(f'Episode: {i_ep+1}/{cfg["train_eps"]}, Reward: {ep_reward:.2f}, Steps: {ep_step}, Epsilon: {agent.epsilon:.3f}')
        print("Finish training!")
        return {'episodes': range(len(rewards)), 'rewards': rewards, 'steps': steps}
    def test(self, cfg, env, agent):
        print("Start testing!")
        print(f"Env: {cfg['env_name']}, Algorithm: {cfg['algo_name']}, Device: {cfg['device']}")
        rewards = []  # record rewards for all episodes
        steps = []  # record steps for all episodes
        for i_ep in range(cfg['test_eps']):
            ep_reward = 0  # reward per episode
            ep_step = 0
            state = env.reset()  # reset and obtain the initial state
            while True:
                action = agent.predict_action(state)  # predict action
                next_state, reward, done, _ = env.step(action)
                state = next_state
                ep_reward += reward
                ep_step += 1
                if done:
                    break
            rewards.append(ep_reward)
            steps.append(ep_step)
            print(f"Episode: {i_ep+1}/{cfg['test_eps']}, Steps: {ep_step}, Reward: {ep_reward:.2f}")
        print("Finish testing!")
        return {'episodes': range(len(rewards)), 'rewards': rewards, 'steps': steps}

if __name__ == "__main__":
    cfg = get_args()
    # training
    env, agent = env_agent_config(cfg)
    res_dic = train(cfg, env, agent)
    save_args(cfg, path=cfg['result_path'])  # save parameters
    agent.save_model(path=cfg['model_path'])  # save models
    save_results(res_dic, tag='train', path=cfg['result_path'])  # save results
    plot_rewards(res_dic['rewards'], cfg, path=cfg['result_path'], tag="train")  # plot results
    # testing
    env, agent = env_agent_config(cfg)  # create a new env for testing; this step can sometimes be skipped
    agent.load_model(path=cfg['model_path'])  # load model
    res_dic = test(cfg, env, agent)
    save_results(res_dic, tag='test',
                 path=cfg['result_path'])
    plot_rewards(res_dic['rewards'], cfg, path=cfg['result_path'], tag="test")
    main = Main()
    main.run()
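Both the old and new versions of this script only use the QLearning agent through `sample_action`, `predict_action`, `update`, and `agent.epsilon`. A minimal sketch of a tabular epsilon-greedy agent consistent with the cfg keys used here (`n_states`, `n_actions`, `lr`, `gamma`, `epsilon_start/end/decay`); this is illustrative only, not the repository's `qlearning.py`:

```python
import math
from collections import defaultdict
import numpy as np

class TabularQLearning:
    """Sketch of an epsilon-greedy tabular Q-learning agent driven by the cfg dict above."""
    def __init__(self, cfg):
        self.n_actions = cfg['n_actions']
        self.lr = cfg['lr']
        self.gamma = cfg['gamma']
        self.sample_count = 0
        self.epsilon_start = cfg['epsilon_start']
        self.epsilon_end = cfg['epsilon_end']
        self.epsilon_decay = cfg['epsilon_decay']
        self.epsilon = self.epsilon_start
        self.Q_table = defaultdict(lambda: np.zeros(self.n_actions))

    def sample_action(self, state):
        # Exponentially decay epsilon from epsilon_start towards epsilon_end.
        self.sample_count += 1
        self.epsilon = self.epsilon_end + (self.epsilon_start - self.epsilon_end) * \
            math.exp(-1.0 * self.sample_count / self.epsilon_decay)
        if np.random.uniform(0, 1) > self.epsilon:
            return int(np.argmax(self.Q_table[state]))   # exploit
        return np.random.choice(self.n_actions)           # explore

    def predict_action(self, state):
        return int(np.argmax(self.Q_table[state]))

    def update(self, state, action, reward, next_state, done):
        # One-step TD target: r + gamma * max_a' Q(s', a'), truncated at terminal states.
        q_predict = self.Q_table[state][action]
        q_target = reward if done else reward + self.gamma * np.max(self.Q_table[next_state])
        self.Q_table[state][action] += self.lr * (q_target - q_predict)
```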
@@ -1,801 +0,0 @@
episodes,rewards,steps
[Results CSV, 801 lines: header `episodes,rewards,steps` followed by 800 rows of per-episode reward (0.0 or 1.0) and step count.]
@@ -1,6 +1,6 @@
{
    "algo_name": "Q-learning",
    "env_name": "FrozenLake-v1",
    "env_name": "FrozenLakeNoSlippery-v1",
    "train_eps": 800,
    "test_eps": 20,
    "gamma": 0.9,
@@ -12,8 +12,8 @@
    "seed": 10,
    "show_fig": false,
    "save_fig": true,
    "result_path": "/Users/jj/Desktop/rl-tutorials/codes/QLearning/outputs/FrozenLake-v1/20220824-112735/results/",
    "model_path": "/Users/jj/Desktop/rl-tutorials/codes/QLearning/outputs/FrozenLake-v1/20220824-112735/models/",
    "result_path": "/Users/jj/Desktop/rl-tutorials/codes/QLearning/outputs/FrozenLakeNoSlippery-v1/20220825-114335/results/",
    "model_path": "/Users/jj/Desktop/rl-tutorials/codes/QLearning/outputs/FrozenLakeNoSlippery-v1/20220825-114335/models/",
    "n_states": 16,
    "n_actions": 4
}
@@ -0,0 +1,801 @@
episodes,rewards,steps
[Results CSV, 801 lines: header `episodes,rewards,steps` followed by 800 rows of per-episode reward (0.0 or 1.0) and step count.]
|
||||
412,1.0,6
|
||||
413,1.0,6
|
||||
414,1.0,7
|
||||
415,1.0,6
|
||||
416,1.0,6
|
||||
417,1.0,6
|
||||
418,1.0,6
|
||||
419,1.0,6
|
||||
420,1.0,6
|
||||
421,1.0,6
|
||||
422,1.0,8
|
||||
423,1.0,6
|
||||
424,1.0,8
|
||||
425,1.0,7
|
||||
426,1.0,6
|
||||
427,0.0,3
|
||||
428,1.0,6
|
||||
429,1.0,7
|
||||
430,1.0,6
|
||||
431,1.0,6
|
||||
432,1.0,6
|
||||
433,1.0,10
|
||||
434,1.0,6
|
||||
435,1.0,6
|
||||
436,1.0,6
|
||||
437,1.0,6
|
||||
438,1.0,10
|
||||
439,1.0,6
|
||||
440,1.0,8
|
||||
441,1.0,8
|
||||
442,1.0,7
|
||||
443,1.0,6
|
||||
444,0.0,5
|
||||
445,0.0,2
|
||||
446,1.0,8
|
||||
447,1.0,6
|
||||
448,1.0,10
|
||||
449,1.0,6
|
||||
450,1.0,8
|
||||
451,1.0,10
|
||||
452,1.0,6
|
||||
453,1.0,6
|
||||
454,1.0,6
|
||||
455,1.0,10
|
||||
456,1.0,6
|
||||
457,0.0,4
|
||||
458,1.0,6
|
||||
459,1.0,6
|
||||
460,1.0,6
|
||||
461,1.0,15
|
||||
462,1.0,6
|
||||
463,1.0,6
|
||||
464,1.0,6
|
||||
465,1.0,6
|
||||
466,1.0,6
|
||||
467,1.0,6
|
||||
468,1.0,8
|
||||
469,1.0,6
|
||||
470,1.0,7
|
||||
471,1.0,6
|
||||
472,1.0,6
|
||||
473,1.0,8
|
||||
474,1.0,6
|
||||
475,1.0,6
|
||||
476,1.0,8
|
||||
477,1.0,8
|
||||
478,1.0,6
|
||||
479,1.0,6
|
||||
480,1.0,6
|
||||
481,1.0,10
|
||||
482,1.0,6
|
||||
483,1.0,6
|
||||
484,1.0,6
|
||||
485,1.0,6
|
||||
486,1.0,6
|
||||
487,1.0,6
|
||||
488,1.0,6
|
||||
489,1.0,8
|
||||
490,1.0,8
|
||||
491,1.0,6
|
||||
492,1.0,6
|
||||
493,0.0,2
|
||||
494,1.0,6
|
||||
495,1.0,6
|
||||
496,1.0,6
|
||||
497,1.0,8
|
||||
498,1.0,6
|
||||
499,1.0,6
|
||||
500,1.0,6
|
||||
501,1.0,6
|
||||
502,1.0,6
|
||||
503,1.0,6
|
||||
504,1.0,6
|
||||
505,1.0,6
|
||||
506,1.0,6
|
||||
507,1.0,7
|
||||
508,0.0,3
|
||||
509,1.0,7
|
||||
510,1.0,6
|
||||
511,1.0,6
|
||||
512,1.0,6
|
||||
513,0.0,2
|
||||
514,1.0,6
|
||||
515,1.0,8
|
||||
516,1.0,6
|
||||
517,1.0,6
|
||||
518,1.0,6
|
||||
519,1.0,6
|
||||
520,1.0,9
|
||||
521,1.0,6
|
||||
522,1.0,6
|
||||
523,1.0,6
|
||||
524,1.0,6
|
||||
525,1.0,6
|
||||
526,1.0,6
|
||||
527,1.0,9
|
||||
528,1.0,7
|
||||
529,0.0,4
|
||||
530,1.0,6
|
||||
531,1.0,8
|
||||
532,1.0,11
|
||||
533,1.0,6
|
||||
534,1.0,6
|
||||
535,1.0,6
|
||||
536,1.0,6
|
||||
537,1.0,6
|
||||
538,1.0,8
|
||||
539,1.0,6
|
||||
540,1.0,6
|
||||
541,1.0,8
|
||||
542,1.0,7
|
||||
543,1.0,6
|
||||
544,1.0,8
|
||||
545,1.0,6
|
||||
546,0.0,5
|
||||
547,1.0,9
|
||||
548,1.0,8
|
||||
549,1.0,8
|
||||
550,1.0,6
|
||||
551,1.0,8
|
||||
552,1.0,8
|
||||
553,1.0,6
|
||||
554,0.0,5
|
||||
555,0.0,3
|
||||
556,0.0,2
|
||||
557,1.0,8
|
||||
558,1.0,6
|
||||
559,1.0,6
|
||||
560,1.0,6
|
||||
561,1.0,6
|
||||
562,1.0,6
|
||||
563,1.0,6
|
||||
564,1.0,6
|
||||
565,1.0,6
|
||||
566,1.0,6
|
||||
567,1.0,6
|
||||
568,1.0,6
|
||||
569,1.0,6
|
||||
570,1.0,6
|
||||
571,1.0,6
|
||||
572,0.0,2
|
||||
573,1.0,6
|
||||
574,0.0,4
|
||||
575,1.0,6
|
||||
576,1.0,6
|
||||
577,1.0,6
|
||||
578,1.0,6
|
||||
579,1.0,6
|
||||
580,1.0,8
|
||||
581,0.0,5
|
||||
582,1.0,6
|
||||
583,1.0,6
|
||||
584,1.0,6
|
||||
585,1.0,6
|
||||
586,1.0,6
|
||||
587,1.0,6
|
||||
588,0.0,3
|
||||
589,1.0,6
|
||||
590,1.0,6
|
||||
591,1.0,6
|
||||
592,0.0,2
|
||||
593,1.0,6
|
||||
594,0.0,4
|
||||
595,1.0,6
|
||||
596,1.0,6
|
||||
597,1.0,6
|
||||
598,1.0,6
|
||||
599,1.0,8
|
||||
600,1.0,6
|
||||
601,1.0,7
|
||||
602,1.0,6
|
||||
603,1.0,7
|
||||
604,1.0,6
|
||||
605,0.0,2
|
||||
606,1.0,6
|
||||
607,1.0,6
|
||||
608,0.0,5
|
||||
609,0.0,3
|
||||
610,0.0,3
|
||||
611,1.0,6
|
||||
612,0.0,5
|
||||
613,1.0,8
|
||||
614,1.0,8
|
||||
615,1.0,6
|
||||
616,1.0,6
|
||||
617,1.0,7
|
||||
618,1.0,6
|
||||
619,1.0,6
|
||||
620,1.0,6
|
||||
621,1.0,6
|
||||
622,1.0,6
|
||||
623,1.0,8
|
||||
624,0.0,2
|
||||
625,1.0,6
|
||||
626,1.0,6
|
||||
627,1.0,6
|
||||
628,1.0,6
|
||||
629,1.0,6
|
||||
630,1.0,6
|
||||
631,1.0,6
|
||||
632,1.0,8
|
||||
633,1.0,6
|
||||
634,1.0,8
|
||||
635,1.0,6
|
||||
636,1.0,6
|
||||
637,1.0,8
|
||||
638,1.0,8
|
||||
639,0.0,5
|
||||
640,0.0,4
|
||||
641,0.0,4
|
||||
642,1.0,6
|
||||
643,1.0,6
|
||||
644,1.0,6
|
||||
645,1.0,6
|
||||
646,1.0,8
|
||||
647,1.0,6
|
||||
648,0.0,4
|
||||
649,1.0,6
|
||||
650,1.0,8
|
||||
651,1.0,6
|
||||
652,1.0,6
|
||||
653,1.0,6
|
||||
654,1.0,6
|
||||
655,1.0,6
|
||||
656,1.0,6
|
||||
657,1.0,6
|
||||
658,1.0,8
|
||||
659,1.0,8
|
||||
660,1.0,6
|
||||
661,1.0,8
|
||||
662,1.0,9
|
||||
663,1.0,6
|
||||
664,1.0,6
|
||||
665,1.0,6
|
||||
666,1.0,6
|
||||
667,1.0,10
|
||||
668,1.0,6
|
||||
669,1.0,6
|
||||
670,1.0,6
|
||||
671,1.0,11
|
||||
672,1.0,10
|
||||
673,1.0,8
|
||||
674,1.0,6
|
||||
675,1.0,6
|
||||
676,1.0,6
|
||||
677,0.0,5
|
||||
678,1.0,6
|
||||
679,0.0,2
|
||||
680,1.0,9
|
||||
681,1.0,6
|
||||
682,1.0,8
|
||||
683,1.0,7
|
||||
684,1.0,6
|
||||
685,1.0,6
|
||||
686,1.0,7
|
||||
687,0.0,3
|
||||
688,1.0,7
|
||||
689,0.0,2
|
||||
690,1.0,6
|
||||
691,1.0,6
|
||||
692,1.0,8
|
||||
693,1.0,8
|
||||
694,1.0,6
|
||||
695,1.0,6
|
||||
696,0.0,2
|
||||
697,1.0,8
|
||||
698,1.0,6
|
||||
699,1.0,8
|
||||
700,1.0,6
|
||||
701,1.0,6
|
||||
702,1.0,9
|
||||
703,1.0,6
|
||||
704,1.0,8
|
||||
705,1.0,11
|
||||
706,1.0,6
|
||||
707,1.0,6
|
||||
708,1.0,6
|
||||
709,1.0,6
|
||||
710,1.0,8
|
||||
711,1.0,6
|
||||
712,1.0,6
|
||||
713,1.0,6
|
||||
714,0.0,5
|
||||
715,1.0,6
|
||||
716,1.0,6
|
||||
717,1.0,6
|
||||
718,1.0,6
|
||||
719,1.0,6
|
||||
720,1.0,7
|
||||
721,1.0,6
|
||||
722,1.0,6
|
||||
723,1.0,6
|
||||
724,1.0,6
|
||||
725,1.0,10
|
||||
726,1.0,6
|
||||
727,1.0,6
|
||||
728,1.0,6
|
||||
729,1.0,6
|
||||
730,1.0,6
|
||||
731,1.0,7
|
||||
732,1.0,6
|
||||
733,1.0,8
|
||||
734,1.0,7
|
||||
735,1.0,6
|
||||
736,1.0,6
|
||||
737,1.0,14
|
||||
738,1.0,6
|
||||
739,1.0,6
|
||||
740,1.0,12
|
||||
741,1.0,6
|
||||
742,1.0,6
|
||||
743,1.0,6
|
||||
744,1.0,6
|
||||
745,1.0,6
|
||||
746,1.0,6
|
||||
747,0.0,3
|
||||
748,1.0,6
|
||||
749,1.0,6
|
||||
750,1.0,6
|
||||
751,1.0,7
|
||||
752,1.0,6
|
||||
753,1.0,6
|
||||
754,1.0,6
|
||||
755,1.0,8
|
||||
756,0.0,2
|
||||
757,1.0,6
|
||||
758,1.0,6
|
||||
759,1.0,6
|
||||
760,1.0,6
|
||||
761,1.0,6
|
||||
762,1.0,6
|
||||
763,1.0,6
|
||||
764,1.0,6
|
||||
765,1.0,6
|
||||
766,0.0,4
|
||||
767,1.0,8
|
||||
768,1.0,6
|
||||
769,0.0,2
|
||||
770,1.0,10
|
||||
771,1.0,8
|
||||
772,1.0,6
|
||||
773,1.0,6
|
||||
774,1.0,6
|
||||
775,0.0,3
|
||||
776,1.0,6
|
||||
777,1.0,6
|
||||
778,0.0,6
|
||||
779,1.0,8
|
||||
780,1.0,6
|
||||
781,1.0,9
|
||||
782,1.0,6
|
||||
783,1.0,6
|
||||
784,1.0,8
|
||||
785,1.0,8
|
||||
786,1.0,6
|
||||
787,0.0,5
|
||||
788,1.0,6
|
||||
789,1.0,6
|
||||
790,1.0,6
|
||||
791,1.0,6
|
||||
792,1.0,6
|
||||
793,1.0,6
|
||||
794,1.0,8
|
||||
795,1.0,6
|
||||
796,0.0,2
|
||||
797,1.0,8
|
||||
798,1.0,7
|
||||
799,1.0,6
|
||||
|
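A results file in this shape can be sanity-checked quickly; the sketch below assumes the dump above has no header row and that the columns are episode index, reward, and steps (the file path is only a placeholder):

    import numpy as np

    # hypothetical file name; columns assumed to be (episode, reward, steps)
    data = np.loadtxt("outputs/results.csv", delimiter=",")
    rewards = data[:, 1]
    print(rewards.mean())        # average reward over all episodes
    print(rewards[-20:].mean())  # average reward over the last 20 episodes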
136
projects/codes/Sarsa/main.py
Normal file
@@ -0,0 +1,136 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
'''
|
||||
Author: John
|
||||
Email: johnjim0816@gmail.com
|
||||
Date: 2021-03-11 17:59:16
|
||||
LastEditor: John
|
||||
LastEditTime: 2022-08-25 14:26:36
|
||||
Discription:
|
||||
Environment:
|
||||
'''
|
||||
import sys,os
|
||||
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE" # avoid "OMP: Error #15: Initializing libiomp5md.dll, but found libiomp5md.dll already initialized."
|
||||
curr_path = os.path.dirname(os.path.abspath(__file__)) # current path
|
||||
parent_path = os.path.dirname(curr_path) # parent path
|
||||
sys.path.append(parent_path) # add path to system path
|
||||
import gym
|
||||
import datetime
|
||||
import argparse
|
||||
from envs.register import register_env
|
||||
from envs.wrappers import CliffWalkingWapper
|
||||
from Sarsa.sarsa import Sarsa
|
||||
from common.utils import save_results,make_dir,plot_rewards,save_args,all_seed
|
||||
|
||||
def get_args():
|
||||
curr_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") # obtain current time
|
||||
parser = argparse.ArgumentParser(description="hyperparameters")
|
||||
parser.add_argument('--algo_name',default='Sarsa',type=str,help="name of algorithm")
|
||||
parser.add_argument('--env_name',default='Racetrack-v0',type=str,help="name of environment")
|
||||
parser.add_argument('--train_eps',default=300,type=int,help="episodes of training")
|
||||
parser.add_argument('--test_eps',default=20,type=int,help="episodes of testing")
|
||||
parser.add_argument('--gamma',default=0.99,type=float,help="discounted factor")
|
||||
parser.add_argument('--epsilon_start',default=0.90,type=float,help="initial value of epsilon")
|
||||
parser.add_argument('--epsilon_end',default=0.01,type=float,help="final value of epsilon")
|
||||
parser.add_argument('--epsilon_decay',default=200,type=int,help="decay rate of epsilon")
|
||||
parser.add_argument('--lr',default=0.2,type=float,help="learning rate")
|
||||
parser.add_argument('--device',default='cpu',type=str,help="cpu or cuda")
|
||||
parser.add_argument('--seed',default=10,type=int,help="seed")
|
||||
parser.add_argument('--show_fig',default=False,type=bool,help="if show figure or not")
|
||||
parser.add_argument('--save_fig',default=True,type=bool,help="if save figure or not")
|
||||
args = parser.parse_args()
|
||||
default_args = {'result_path':f"{curr_path}/outputs/{args.env_name}/{curr_time}/results/",
|
||||
'model_path':f"{curr_path}/outputs/{args.env_name}/{curr_time}/models/",
|
||||
}
|
||||
args = {**vars(args),**default_args} # type(dict)
|
||||
return args
|
||||
|
||||
def env_agent_config(cfg):
|
||||
register_env(cfg['env_name'])
|
||||
env = gym.make(cfg['env_name'])
|
||||
if cfg['seed'] !=0: # set random seed
|
||||
all_seed(env,seed= cfg['seed'])
|
||||
if cfg['env_name'] == 'CliffWalking-v0':
|
||||
env = CliffWalkingWapper(env)
|
||||
try: # state dimension
|
||||
n_states = env.observation_space.n # print(hasattr(env.observation_space, 'n'))
|
||||
except AttributeError:
|
||||
n_states = env.observation_space.shape[0] # print(hasattr(env.observation_space, 'shape'))
|
||||
n_actions = env.action_space.n # action dimension
|
||||
print(f"n_states: {n_states}, n_actions: {n_actions}")
|
||||
cfg.update({"n_states":n_states,"n_actions":n_actions}) # update cfg parameters
|
||||
agent = Sarsa(cfg)
|
||||
return env,agent
|
||||
|
||||
def train(cfg,env,agent):
|
||||
print("Start training!")
|
||||
print(f"Env: {cfg['env_name']}, Algorithm: {cfg['algo_name']}, Device: {cfg['device']}")
|
||||
rewards = [] # record rewards for all episodes
|
||||
steps = [] # record steps for all episodes
|
||||
for i_ep in range(cfg['train_eps']):
|
||||
ep_reward = 0 # reward per episode
|
||||
ep_step = 0 # step per episode
|
||||
state = env.reset() # reset and obtain initial state
|
||||
action = agent.sample_action(state)
|
||||
while True:
|
||||
# for _ in range(cfg.ep_max_steps):
|
||||
next_state, reward, done, _ = env.step(action) # update env and return transitions
|
||||
next_action = agent.sample_action(next_state)
|
||||
agent.update(state, action, reward, next_state, next_action,done) # update agent
|
||||
state = next_state # update state
|
||||
action = next_action
|
||||
ep_reward += reward
|
||||
ep_step += 1
|
||||
if done:
|
||||
break
|
||||
rewards.append(ep_reward)
|
||||
steps.append(ep_step)
|
||||
if (i_ep+1)%10==0:
|
||||
print(f'Episode: {i_ep+1}/{cfg["train_eps"]}, Reward: {ep_reward:.2f}, Steps: {ep_step}, Epsilon: {agent.epsilon:.3f}')
|
||||
print("Finish training!")
|
||||
return {'episodes':range(len(rewards)),'rewards':rewards,'steps':steps}
|
||||
|
||||
def test(cfg,env,agent):
print("Start testing!")
print(f"Env: {cfg['env_name']}, Algorithm: {cfg['algo_name']}, Device: {cfg['device']}")
rewards = [] # record rewards for all episodes
steps = [] # record steps for all episodes
for i_ep in range(cfg['test_eps']):
ep_reward = 0 # reward per episode
ep_step = 0 # step per episode
state = env.reset() # reset and obtain initial state
while True:
# for _ in range(cfg.ep_max_steps):
action = agent.predict_action(state)
next_state, reward, done, _ = env.step(action)
state = next_state
ep_reward += reward
ep_step += 1
if done:
break
rewards.append(ep_reward)
steps.append(ep_step)
print(f"Episode: {i_ep+1}/{cfg['test_eps']}, Steps: {ep_step}, Reward: {ep_reward:.2f}")
print("Finish testing!")
return {'episodes':range(len(rewards)),'rewards':rewards,'steps':steps}

if __name__ == "__main__":
cfg = get_args()
# training
env, agent = env_agent_config(cfg)
res_dic = train(cfg, env, agent)
make_dir(cfg['result_path'], cfg['model_path'])
save_args(cfg) # save parameters
agent.save_model(path=cfg['model_path']) # save model
save_results(res_dic, tag='train',
path=cfg['result_path'])
plot_rewards(res_dic['rewards'], cfg, tag="train")
# testing
env, agent = env_agent_config(cfg)
agent.load_model(path=cfg['model_path']) # load model
res_dic = test(cfg, env, agent)
save_results(res_dic, tag='test',
path=cfg['result_path']) # save results
plot_rewards(res_dic['rewards'], cfg, tag="test") # plot results
|
||||
|
||||
|
||||
|
||||
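One detail worth noting in get_args above is that it returns a plain dict rather than an argparse Namespace: the parsed flags are merged with the default output paths via dictionary unpacking. A minimal, self-contained sketch of that pattern (the flag and path below are illustrative only):

    import argparse

    parser = argparse.ArgumentParser(description="hyperparameters")
    parser.add_argument('--lr', default=0.2, type=float, help="learning rate")
    args = parser.parse_args([])  # parse no CLI arguments so the sketch runs as-is
    default_args = {'result_path': './outputs/results/'}  # illustrative default path
    cfg = {**vars(args), **default_args}  # merge Namespace attributes and defaults into one dict
    print(cfg)  # {'lr': 0.2, 'result_path': './outputs/results/'}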
@@ -5,7 +5,7 @@ Author: John
|
||||
Email: johnjim0816@gmail.com
|
||||
Date: 2021-03-12 16:58:16
|
||||
LastEditor: John
|
||||
LastEditTime: 2022-08-04 22:22:16
|
||||
LastEditTime: 2022-08-25 00:23:22
|
||||
Discription:
|
||||
Environment:
|
||||
'''
|
||||
@@ -14,45 +14,51 @@ from collections import defaultdict
|
||||
import torch
|
||||
import math
|
||||
class Sarsa(object):
|
||||
def __init__(self,
|
||||
n_actions,cfg):
|
||||
self.n_actions = n_actions
|
||||
self.lr = cfg.lr
|
||||
self.gamma = cfg.gamma
|
||||
self.sample_count = 0
|
||||
self.epsilon_start = cfg.epsilon_start
|
||||
self.epsilon_end = cfg.epsilon_end
|
||||
self.epsilon_decay = cfg.epsilon_decay
|
||||
self.Q = defaultdict(lambda: np.zeros(n_actions)) # Q table
|
||||
def sample(self, state):
|
||||
def __init__(self,cfg):
|
||||
self.n_actions = cfg['n_actions']
|
||||
self.lr = cfg['lr']
|
||||
self.gamma = cfg['gamma']
|
||||
self.epsilon = cfg['epsilon_start']
|
||||
self.sample_count = 0
|
||||
self.epsilon_start = cfg['epsilon_start']
|
||||
self.epsilon_end = cfg['epsilon_end']
|
||||
self.epsilon_decay = cfg['epsilon_decay']
|
||||
self.Q_table = defaultdict(lambda: np.zeros(self.n_actions)) # Q table
|
||||
def sample_action(self, state):
|
||||
''' another way to represent e-greedy policy
|
||||
'''
|
||||
self.sample_count += 1
|
||||
self.epsilon = self.epsilon_end + (self.epsilon_start - self.epsilon_end) * \
|
||||
math.exp(-1. * self.sample_count / self.epsilon_decay) # the probability of selecting a random action decays exponentially with the sample count
|
||||
best_action = np.argmax(self.Q[state])
|
||||
best_action = np.argmax(self.Q_table[state])
|
||||
action_probs = np.ones(self.n_actions, dtype=float) * self.epsilon / self.n_actions
|
||||
action_probs[best_action] += (1.0 - self.epsilon)
|
||||
action = np.random.choice(np.arange(len(action_probs)), p=action_probs)
|
||||
return action
|
||||
def predict(self,state):
|
||||
return np.argmax(self.Q[state])
|
||||
def update(self, state, action, reward, next_state, next_action,done):
|
||||
Q_predict = self.Q[state][action]
|
||||
if done:
|
||||
Q_target = reward # terminal state
|
||||
else:
|
||||
Q_target = reward + self.gamma * self.Q[next_state][next_action] # unlike Q-learning, Sarsa updates with the Q value of the next action actually taken
|
||||
self.Q[state][action] += self.lr * (Q_target - Q_predict)
|
||||
def save(self,path):
|
||||
'''Save the Q table to a file
|
||||
def predict_action(self,state):
|
||||
''' predict action while testing
|
||||
'''
|
||||
action = np.argmax(self.Q_table[state])
|
||||
return action
|
||||
def update(self, state, action, reward, next_state, next_action,done):
|
||||
Q_predict = self.Q_table[state][action]
|
||||
if done:
|
||||
Q_target = reward # terminal state
|
||||
else:
|
||||
Q_target = reward + self.gamma * self.Q_table[next_state][next_action] # the only difference from Q learning
|
||||
self.Q_table[state][action] += self.lr * (Q_target - Q_predict)
|
||||
def save_model(self,path):
|
||||
import dill
|
||||
from pathlib import Path
|
||||
# create path
|
||||
Path(path).mkdir(parents=True, exist_ok=True)
|
||||
torch.save(
|
||||
obj=self.Q,
|
||||
f=path+"sarsa_model.pkl",
|
||||
obj=self.Q_table,
|
||||
f=path+"checkpoint.pkl",
|
||||
pickle_module=dill
|
||||
)
|
||||
def load(self, path):
|
||||
'''Load the Q table from a file
|
||||
'''
|
||||
print("Model saved!")
|
||||
def load_model(self, path):
|
||||
import dill
|
||||
self.Q =torch.load(f=path+'sarsa_model.pkl',pickle_module=dill)
|
||||
self.Q_table = torch.load(f=path+'checkpoint.pkl', pickle_module=dill)
|
||||
print("Mode loaded!")
|
||||
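Since the refactored Sarsa class is configured entirely through the cfg dict, a minimal usage sketch looks roughly as follows (the transition values are made up purely for illustration, and the import path mirrors the one used in Sarsa/main.py):

    from Sarsa.sarsa import Sarsa

    cfg = {'n_actions': 4, 'lr': 0.2, 'gamma': 0.99,
           'epsilon_start': 0.90, 'epsilon_end': 0.01, 'epsilon_decay': 200}
    agent = Sarsa(cfg)
    state, next_state = 0, 1                        # toy integer states
    action = agent.sample_action(state)             # e-greedy action
    next_action = agent.sample_action(next_state)   # Sarsa also samples the next action
    agent.update(state, action, -1.0, next_state, next_action, False)  # on-policy TD update
    print(agent.predict_action(state))              # greedy action from the Q table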
@@ -1,118 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
'''
|
||||
Author: John
|
||||
Email: johnjim0816@gmail.com
|
||||
Date: 2021-03-11 17:59:16
|
||||
LastEditor: John
|
||||
LastEditTime: 2022-08-04 22:28:51
|
||||
Discription:
|
||||
Environment:
|
||||
'''
|
||||
import sys,os
|
||||
curr_path = os.path.dirname(os.path.abspath(__file__)) # absolute path of the current file
parent_path = os.path.dirname(curr_path) # parent path
sys.path.append(parent_path) # add path to system path
|
||||
|
||||
import datetime
|
||||
import argparse
|
||||
from envs.racetrack_env import RacetrackEnv
|
||||
from Sarsa.sarsa import Sarsa
|
||||
from common.utils import save_results,make_dir,plot_rewards,save_args
|
||||
|
||||
def get_args():
|
||||
""" 超参数
|
||||
"""
|
||||
curr_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") # 获取当前时间
|
||||
parser = argparse.ArgumentParser(description="hyperparameters")
|
||||
parser.add_argument('--algo_name',default='Sarsa',type=str,help="name of algorithm")
|
||||
parser.add_argument('--env_name',default='CliffWalking-v0',type=str,help="name of environment")
|
||||
parser.add_argument('--train_eps',default=300,type=int,help="episodes of training") # number of training episodes
parser.add_argument('--test_eps',default=20,type=int,help="episodes of testing") # number of testing episodes
parser.add_argument('--ep_max_steps',default=200,type=int) # max steps per episode
parser.add_argument('--gamma',default=0.99,type=float,help="discounted factor") # discount factor
parser.add_argument('--epsilon_start',default=0.90,type=float,help="initial value of epsilon") # initial epsilon in the e-greedy policy
parser.add_argument('--epsilon_end',default=0.01,type=float,help="final value of epsilon") # final epsilon in the e-greedy policy
parser.add_argument('--epsilon_decay',default=200,type=int,help="decay rate of epsilon") # decay rate of epsilon in the e-greedy policy
|
||||
parser.add_argument('--lr',default=0.2,type=float,help="learning rate")
|
||||
parser.add_argument('--device',default='cpu',type=str,help="cpu or cuda")
|
||||
parser.add_argument('--result_path',default=curr_path + "/outputs/" + parser.parse_args().env_name + \
|
||||
'/' + curr_time + '/results/' )
|
||||
parser.add_argument('--model_path',default=curr_path + "/outputs/" + parser.parse_args().env_name + \
|
||||
'/' + curr_time + '/models/' ) # path to save models
|
||||
parser.add_argument('--save_fig',default=True,type=bool,help="if save figure or not")
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
def env_agent_config(cfg,seed=1):
|
||||
env = RacetrackEnv()
|
||||
n_actions = 9 # number of actions
|
||||
agent = Sarsa(n_actions,cfg)
|
||||
return env,agent
|
||||
|
||||
def train(cfg,env,agent):
|
||||
print('Start training!')
print(f'Env: {cfg.env_name}, Algorithm: {cfg.algo_name}, Device: {cfg.device}')
rewards = [] # record rewards
|
||||
for i_ep in range(cfg.train_eps):
|
||||
state = env.reset()
|
||||
action = agent.sample(state)
|
||||
ep_reward = 0
|
||||
# while True:
|
||||
for _ in range(cfg.ep_max_steps):
|
||||
next_state, reward, done = env.step(action)
|
||||
ep_reward+=reward
|
||||
next_action = agent.sample(next_state)
|
||||
agent.update(state, action, reward, next_state, next_action,done)
|
||||
state = next_state
|
||||
action = next_action
|
||||
if done:
|
||||
break
|
||||
rewards.append(ep_reward)
|
||||
if (i_ep+1)%2==0:
|
||||
print(f"回合:{i_ep+1}/{cfg.train_eps},奖励:{ep_reward:.1f},Epsilon:{agent.epsilon}")
|
||||
print('完成训练!')
|
||||
return {"rewards":rewards}
|
||||
|
||||
def test(cfg,env,agent):
|
||||
print('Start testing!')
print(f'Env: {cfg.env_name}, Algorithm: {cfg.algo_name}, Device: {cfg.device}')
|
||||
rewards = []
|
||||
for i_ep in range(cfg.test_eps):
|
||||
state = env.reset()
|
||||
ep_reward = 0
|
||||
# while True:
|
||||
for _ in range(cfg.ep_max_steps):
|
||||
action = agent.predict(state)
|
||||
next_state, reward, done = env.step(action)
|
||||
ep_reward+=reward
|
||||
state = next_state
|
||||
if done:
|
||||
break
|
||||
rewards.append(ep_reward)
|
||||
print(f"回合数:{i_ep+1}/{cfg.test_eps}, 奖励:{ep_reward:.1f}")
|
||||
print('完成测试!')
|
||||
return {"rewards":rewards}
|
||||
|
||||
if __name__ == "__main__":
|
||||
cfg = get_args()
|
||||
# training
|
||||
env, agent = env_agent_config(cfg)
|
||||
res_dic = train(cfg, env, agent)
|
||||
make_dir(cfg.result_path, cfg.model_path)
|
||||
save_args(cfg) # save parameters
|
||||
agent.save(path=cfg.model_path) # save model
|
||||
save_results(res_dic, tag='train',
|
||||
path=cfg.result_path)
|
||||
plot_rewards(res_dic['rewards'], cfg, tag="train")
|
||||
# testing
|
||||
env, agent = env_agent_config(cfg)
|
||||
agent.load(path=cfg.model_path) # load model
|
||||
res_dic = test(cfg, env, agent)
|
||||
save_results(res_dic, tag='test',
|
||||
path=cfg.result_path) # save results
|
||||
plot_rewards(res_dic['rewards'], cfg, tag="test") # 画出结果
|
||||
|
||||
|
||||
|
||||
32
projects/codes/common/launcher.py
Normal file
@@ -0,0 +1,32 @@
|
||||
from common.utils import save_args,save_results,plot_rewards
|
||||
class Launcher:
|
||||
def __init__(self) -> None:
|
||||
pass
|
||||
def get_args(self):
|
||||
cfg = {}
|
||||
return cfg
|
||||
def env_agent_config(self,cfg):
|
||||
env,agent = None,None
|
||||
return env,agent
|
||||
def train(self,cfg, env, agent):
|
||||
res_dic = {}
|
||||
return res_dic
|
||||
def test(self,cfg, env, agent):
|
||||
res_dic = {}
|
||||
return res_dic
|
||||
|
||||
def run(self):
|
||||
cfg = self.get_args()
|
||||
env, agent = self.env_agent_config(cfg)
|
||||
res_dic = self.train(cfg, env, agent)
|
||||
save_args(cfg,path = cfg['result_path']) # save parameters
|
||||
agent.save_model(path = cfg['model_path']) # save models
|
||||
save_results(res_dic, tag = 'train', path = cfg['result_path']) # save results
|
||||
plot_rewards(res_dic['rewards'], cfg, path = cfg['result_path'],tag = "train") # plot results
|
||||
# testing
|
||||
env, agent = self.env_agent_config(cfg) # create new env for testing, sometimes can ignore this step
|
||||
agent.load_model(path = cfg['model_path']) # load model
|
||||
res_dic = self.test(cfg, env, agent)
|
||||
save_results(res_dic, tag='test',
|
||||
path = cfg['result_path'])
|
||||
plot_rewards(res_dic['rewards'], cfg, path = cfg['result_path'],tag = "test")
|
||||
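Launcher is only a skeleton: run() wires together get_args, env_agent_config, train, model/result saving, and test, and a concrete algorithm is expected to subclass it and override the four hooks. A rough sketch of such a subclass (the class name and hook bodies are hypothetical, not part of the repo; the import path is inferred from the file location):

    from common.launcher import Launcher

    class SarsaLauncher(Launcher):  # hypothetical name, for illustration only
        def get_args(self):
            # must include the paths that run() relies on
            return {'result_path': './outputs/results/', 'model_path': './outputs/models/'}

        def env_agent_config(self, cfg):
            env, agent = None, None  # build the gym env and the agent here
            return env, agent

        def train(self, cfg, env, agent):
            return {'rewards': [], 'steps': []}  # replace with the real training loop

        def test(self, cfg, env, agent):
            return {'rewards': [], 'steps': []}  # replace with the real evaluation loop

    # SarsaLauncher().run()  # train -> save args/model/results -> plot -> test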
@@ -72,84 +72,6 @@ class FrozenLakeWapper(gym.Wrapper):
|
||||
self.move_player(x_pos, y_pos)
|
||||
|
||||
|
||||
class CliffWalkingWapper(gym.Wrapper):
|
||||
def __init__(self, env):
|
||||
gym.Wrapper.__init__(self, env)
|
||||
self.t = None
|
||||
self.unit = 50
|
||||
self.max_x = 12
|
||||
self.max_y = 4
|
||||
|
||||
def draw_x_line(self, y, x0, x1, color='gray'):
|
||||
assert x1 > x0
|
||||
self.t.color(color)
|
||||
self.t.setheading(0)
|
||||
self.t.up()
|
||||
self.t.goto(x0, y)
|
||||
self.t.down()
|
||||
self.t.forward(x1 - x0)
|
||||
|
||||
def draw_y_line(self, x, y0, y1, color='gray'):
|
||||
assert y1 > y0
|
||||
self.t.color(color)
|
||||
self.t.setheading(90)
|
||||
self.t.up()
|
||||
self.t.goto(x, y0)
|
||||
self.t.down()
|
||||
self.t.forward(y1 - y0)
|
||||
|
||||
def draw_box(self, x, y, fillcolor='', line_color='gray'):
|
||||
self.t.up()
|
||||
self.t.goto(x * self.unit, y * self.unit)
|
||||
self.t.color(line_color)
|
||||
self.t.fillcolor(fillcolor)
|
||||
self.t.setheading(90)
|
||||
self.t.down()
|
||||
self.t.begin_fill()
|
||||
for i in range(4):
|
||||
self.t.forward(self.unit)
|
||||
self.t.right(90)
|
||||
self.t.end_fill()
|
||||
|
||||
def move_player(self, x, y):
|
||||
self.t.up()
|
||||
self.t.setheading(90)
|
||||
self.t.fillcolor('red')
|
||||
self.t.goto((x + 0.5) * self.unit, (y + 0.5) * self.unit)
|
||||
|
||||
def render(self):
|
||||
if self.t == None:
|
||||
self.t = turtle.Turtle()
|
||||
self.wn = turtle.Screen()
|
||||
self.wn.setup(self.unit * self.max_x + 100,
|
||||
self.unit * self.max_y + 100)
|
||||
self.wn.setworldcoordinates(0, 0, self.unit * self.max_x,
|
||||
self.unit * self.max_y)
|
||||
self.t.shape('circle')
|
||||
self.t.width(2)
|
||||
self.t.speed(0)
|
||||
self.t.color('gray')
|
||||
for _ in range(2):
|
||||
self.t.forward(self.max_x * self.unit)
|
||||
self.t.left(90)
|
||||
self.t.forward(self.max_y * self.unit)
|
||||
self.t.left(90)
|
||||
for i in range(1, self.max_y):
|
||||
self.draw_x_line(
|
||||
y=i * self.unit, x0=0, x1=self.max_x * self.unit)
|
||||
for i in range(1, self.max_x):
|
||||
self.draw_y_line(
|
||||
x=i * self.unit, y0=0, y1=self.max_y * self.unit)
|
||||
|
||||
for i in range(1, self.max_x - 1):
|
||||
self.draw_box(i, 0, 'black')
|
||||
self.draw_box(self.max_x - 1, 0, 'yellow')
|
||||
self.t.shape('turtle')
|
||||
|
||||
x_pos = self.s % self.max_x
|
||||
y_pos = self.max_y - 1 - int(self.s / self.max_x)
|
||||
self.move_player(x_pos, y_pos)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# Env 1: FrozenLake, whether the ice surface is slippery can be configured
|
||||
|
||||
@@ -1,10 +1,3 @@
|
||||
# Please do not make changes to this file - it will be overwritten with a clean
|
||||
# version when your work is marked.
|
||||
#
|
||||
# This file contains code for the racetrack environment that you will be using
|
||||
# as part of the second part of the CM50270: Reinforcement Learning coursework.
|
||||
|
||||
import imp
|
||||
import time
|
||||
import random
|
||||
import numpy as np
|
||||
@@ -12,23 +5,20 @@ import os
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib.patheffects as pe
|
||||
from IPython.display import clear_output
|
||||
from gym.spaces import Discrete
|
||||
from gym.spaces import Discrete,Box
|
||||
from matplotlib import colors
|
||||
import gym
|
||||
|
||||
class RacetrackEnv(object) :
|
||||
class RacetrackEnv(gym.Env) :
|
||||
"""
|
||||
Class representing a race-track environment inspired by exercise 5.12 in Sutton & Barto 2018 (p.111).
|
||||
Please do not make changes to this class - it will be overwritten with a clean version when it comes to marking.
|
||||
|
||||
The dynamics of this environment are detailed in this coursework exercise's jupyter notebook, although I have
|
||||
included rather verbose comments here for those of you who are interested in how the environment has been
|
||||
implemented (though this should not impact your solution code).
|
||||
|
||||
If you find any *bugs* with this code, please let me know immediately - thank you for finding them, sorry that I didn't!
|
||||
However, please do not suggest optimisations - some things have been purposely simplified for readability's sake.
|
||||
implemented (though this should not impact your solution code).
|
||||
"""
|
||||
|
||||
|
||||
ACTIONS_DICT = {
|
||||
0 : (1, -1), # Acc Vert., Brake Horiz.
|
||||
1 : (1, 0), # Acc Vert., Hold Horiz.
|
||||
@@ -61,18 +51,15 @@ class RacetrackEnv(object) :
|
||||
for x in range(self.track.shape[1]) :
|
||||
if (self.CELL_TYPES_DICT[self.track[y, x]] == "start") :
|
||||
self.initial_states.append((y, x))
|
||||
|
||||
high= np.array([np.finfo(np.float32).max, np.finfo(np.float32).max, np.finfo(np.float32).max, np.finfo(np.float32).max])
|
||||
self.observation_space = Box(low=-high, high=high, shape=(4,), dtype=np.float32)
|
||||
self.action_space = Discrete(9)
|
||||
self.is_reset = False
|
||||
|
||||
#print("Racetrack Environment File Loaded Successfully.")
|
||||
#print("Be sure to call .reset() before starting to initialise the environment and get an initial state!")
|
||||
|
||||
|
||||
def step(self, action : int) :
|
||||
"""
|
||||
Takes a given action in the environment's current state, and returns a next state,
|
||||
reward, and whether the next state is terminal or not.
|
||||
reward, and whether the next state is done or not.
|
||||
|
||||
Arguments:
|
||||
action {int} -- The action to take in the environment's current state. Should be an integer in the range [0-8].
|
||||
@@ -86,7 +73,7 @@ class RacetrackEnv(object) :
|
||||
A tuple of:\n
|
||||
{(int, int, int, int)} -- The next state, a tuple of (y_pos, x_pos, y_velocity, x_velocity).\n
|
||||
{int} -- The reward earned by taking the given action in the current environment state.\n
|
||||
{bool} -- Whether the environment's next state is terminal or not.\n
|
||||
{bool} -- Whether the environment's next state is done or not.\n
|
||||
|
||||
"""
|
||||
|
||||
@@ -131,7 +118,7 @@ class RacetrackEnv(object) :
|
||||
new_position = (self.position[0] + self.velocity[0], self.position[1] + self.velocity[1])
|
||||
|
||||
reward = 0
|
||||
terminal = False
|
||||
done = False
|
||||
|
||||
# If position is out-of-bounds, return to start and set velocity components to zero.
|
||||
if (new_position[0] < 0 or new_position[1] < 0 or new_position[0] >= self.track.shape[0] or new_position[1] >= self.track.shape[1]) :
|
||||
@@ -150,7 +137,7 @@ class RacetrackEnv(object) :
|
||||
elif (self.CELL_TYPES_DICT[self.track[new_position]] == "goal") :
|
||||
self.position = new_position
|
||||
reward += 10
|
||||
terminal = True
|
||||
done = True
|
||||
# If this gets reached, then the student has touched something they shouldn't have. Naughty!
|
||||
else :
|
||||
raise RuntimeError("You've met with a terrible fate, haven't you?\nDon't modify things you shouldn't!")
|
||||
@@ -158,12 +145,12 @@ class RacetrackEnv(object) :
|
||||
# Penalise every timestep.
|
||||
reward -= 1
|
||||
|
||||
# Require a reset if the current state is terminal.
|
||||
if (terminal) :
|
||||
# Require a reset if the current state is done.
|
||||
if (done) :
|
||||
self.is_reset = False
|
||||
|
||||
# Return next state, reward, and whether the episode has ended.
|
||||
return (self.position[0], self.position[1], self.velocity[0], self.velocity[1]), reward, terminal
|
||||
return np.array([self.position[0], self.position[1], self.velocity[0], self.velocity[1]]), reward, done,{}
|
||||
|
||||
|
||||
def reset(self) :
|
||||
@@ -184,10 +171,10 @@ class RacetrackEnv(object) :
|
||||
|
||||
self.is_reset = True
|
||||
|
||||
return (self.position[0], self.position[1], self.velocity[0], self.velocity[1])
|
||||
return np.array([self.position[0], self.position[1], self.velocity[0], self.velocity[1]])
|
||||
|
||||
|
||||
def render(self, sleep_time : float = 0.1) :
|
||||
def render(self, mode = 'human') :
|
||||
"""
|
||||
Renders a pretty matplotlib plot representing the current state of the environment.
|
||||
Calling this method on subsequent timesteps will update the plot.
|
||||
@@ -230,13 +217,9 @@ class RacetrackEnv(object) :
|
||||
# Draw everything.
|
||||
#fig.canvas.draw()
|
||||
#fig.canvas.flush_events()
|
||||
|
||||
plt.show()
|
||||
|
||||
# Sleep if desired.
|
||||
if (sleep_time > 0) :
|
||||
time.sleep(sleep_time)
|
||||
|
||||
# time sleep
|
||||
time.sleep(0.1)
|
||||
|
||||
def get_actions(self) :
|
||||
"""
|
||||
@@ -244,18 +227,16 @@ class RacetrackEnv(object) :
|
||||
of integers in the range [0-8].
|
||||
"""
|
||||
return [*self.ACTIONS_DICT]
|
||||
if __name__ == "__main__":
|
||||
num_steps = 1000000
|
||||
env = RacetrackEnv()
|
||||
state = env.reset()
|
||||
print(state)
|
||||
for _ in range(num_steps) :
|
||||
|
||||
# num_steps = 1000000
|
||||
next_state, reward, done,_ = env.step(random.choice(env.get_actions()))
|
||||
print(next_state)
|
||||
env.render()
|
||||
|
||||
# env = RacetrackEnv()
|
||||
# state = env.reset()
|
||||
# print(state)
|
||||
|
||||
# for _ in range(num_steps) :
|
||||
|
||||
# next_state, reward, terminal = env.step(random.choice(env.get_actions()))
|
||||
# print(next_state)
|
||||
# env.render()
|
||||
|
||||
# if (terminal) :
|
||||
# _ = env.reset()
|
||||
if (done) :
|
||||
_ = env.reset()
|
||||
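After this change RacetrackEnv behaves like a standard gym environment: reset() returns a numpy observation and step() returns the usual (obs, reward, done, info) 4-tuple. A minimal interaction loop under that assumption (the module name follows the entry_point used in envs/register.py):

    import random
    from racetrack import RacetrackEnv  # module path assumed from envs/register.py

    env = RacetrackEnv()
    state = env.reset()
    done = False
    while not done:
        action = random.choice(env.get_actions())     # random action in [0-8]
        state, reward, done, info = env.step(action)  # gym-style 4-tuple
    print("episode finished with reward:", reward)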
34
projects/codes/envs/register.py
Normal file
@@ -0,0 +1,34 @@
|
||||
|
||||
from gym.envs.registration import register
|
||||
|
||||
def register_env(env_name):
|
||||
if env_name == 'Racetrack-v0':
|
||||
register(
|
||||
id='Racetrack-v0',
|
||||
entry_point='racetrack:RacetrackEnv',
|
||||
max_episode_steps=1000,
|
||||
kwargs={}
|
||||
)
|
||||
elif env_name == 'FrozenLakeNoSlippery-v1':
|
||||
register(
|
||||
id='FrozenLakeNoSlippery-v1',
|
||||
entry_point='gym.envs.toy_text.frozen_lake:FrozenLakeEnv',
|
||||
kwargs={'map_name':"4x4",'is_slippery':False},
|
||||
)
|
||||
else:
|
||||
print("The env name must be wrong or the environment donot need to register!")
|
||||
|
||||
# if __name__ == "__main__":
|
||||
# import random
|
||||
# import gym
|
||||
# env = gym.make('FrozenLakeNoSlippery-v1')
|
||||
# num_steps = 1000000
|
||||
# state = env.reset()
|
||||
# n_actions = env.action_space.n
|
||||
# print(state)
|
||||
# for _ in range(num_steps) :
|
||||
# next_state, reward, done,_ = env.step(random.choice(range(n_actions)))
|
||||
# print(next_state)
|
||||
# if (done) :
|
||||
# _ = env.reset()
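In other words, register_env only needs to be called once before gym.make can resolve the custom id. A hedged usage sketch, assuming the racetrack module is importable (Sarsa/main.py arranges the path for this):

    import gym
    from envs.register import register_env

    register_env('Racetrack-v0')    # registers racetrack:RacetrackEnv under this id
    env = gym.make('Racetrack-v0')  # now resolvable like any built-in gym env
    state = env.reset()
    print(env.action_space, env.observation_space)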
|
||||
|
||||
78
projects/codes/envs/wrappers.py
Normal file
@@ -0,0 +1,78 @@
|
||||
import gym
import turtle  # used by CliffWalkingWapper.render()
|
||||
class CliffWalkingWapper(gym.Wrapper):
|
||||
def __init__(self, env):
|
||||
gym.Wrapper.__init__(self, env)
|
||||
self.t = None
|
||||
self.unit = 50
|
||||
self.max_x = 12
|
||||
self.max_y = 4
|
||||
|
||||
def draw_x_line(self, y, x0, x1, color='gray'):
|
||||
assert x1 > x0
|
||||
self.t.color(color)
|
||||
self.t.setheading(0)
|
||||
self.t.up()
|
||||
self.t.goto(x0, y)
|
||||
self.t.down()
|
||||
self.t.forward(x1 - x0)
|
||||
|
||||
def draw_y_line(self, x, y0, y1, color='gray'):
|
||||
assert y1 > y0
|
||||
self.t.color(color)
|
||||
self.t.setheading(90)
|
||||
self.t.up()
|
||||
self.t.goto(x, y0)
|
||||
self.t.down()
|
||||
self.t.forward(y1 - y0)
|
||||
|
||||
def draw_box(self, x, y, fillcolor='', line_color='gray'):
|
||||
self.t.up()
|
||||
self.t.goto(x * self.unit, y * self.unit)
|
||||
self.t.color(line_color)
|
||||
self.t.fillcolor(fillcolor)
|
||||
self.t.setheading(90)
|
||||
self.t.down()
|
||||
self.t.begin_fill()
|
||||
for i in range(4):
|
||||
self.t.forward(self.unit)
|
||||
self.t.right(90)
|
||||
self.t.end_fill()
|
||||
|
||||
def move_player(self, x, y):
|
||||
self.t.up()
|
||||
self.t.setheading(90)
|
||||
self.t.fillcolor('red')
|
||||
self.t.goto((x + 0.5) * self.unit, (y + 0.5) * self.unit)
|
||||
|
||||
def render(self):
|
||||
if self.t == None:
|
||||
self.t = turtle.Turtle()
|
||||
self.wn = turtle.Screen()
|
||||
self.wn.setup(self.unit * self.max_x + 100,
|
||||
self.unit * self.max_y + 100)
|
||||
self.wn.setworldcoordinates(0, 0, self.unit * self.max_x,
|
||||
self.unit * self.max_y)
|
||||
self.t.shape('circle')
|
||||
self.t.width(2)
|
||||
self.t.speed(0)
|
||||
self.t.color('gray')
|
||||
for _ in range(2):
|
||||
self.t.forward(self.max_x * self.unit)
|
||||
self.t.left(90)
|
||||
self.t.forward(self.max_y * self.unit)
|
||||
self.t.left(90)
|
||||
for i in range(1, self.max_y):
|
||||
self.draw_x_line(
|
||||
y=i * self.unit, x0=0, x1=self.max_x * self.unit)
|
||||
for i in range(1, self.max_x):
|
||||
self.draw_y_line(
|
||||
x=i * self.unit, y0=0, y1=self.max_y * self.unit)
|
||||
|
||||
for i in range(1, self.max_x - 1):
|
||||
self.draw_box(i, 0, 'black')
|
||||
self.draw_box(self.max_x - 1, 0, 'yellow')
|
||||
self.t.shape('turtle')
|
||||
|
||||
x_pos = self.s % self.max_x
|
||||
y_pos = self.max_y - 1 - int(self.s / self.max_x)
|
||||
self.move_player(x_pos, y_pos)
|
||||
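The wrapper is used in Sarsa/main.py by wrapping gym's CliffWalking-v0 so that render() draws the 12x4 grid with turtle; a minimal sketch of that usage:

    import gym
    from envs.wrappers import CliffWalkingWapper

    env = gym.make('CliffWalking-v0')
    env = CliffWalkingWapper(env)
    state = env.reset()
    next_state, reward, done, info = env.step(env.action_space.sample())
    env.render()  # turtle-based rendering of the grid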
@@ -11,4 +11,5 @@ else
|
||||
fi
|
||||
conda activate easyrl # easyrl here can be changed to another name of conda env that you have created
|
||||
codes_dir=$(dirname $(dirname $(readlink -f "$0"))) # "codes" path
|
||||
python $codes_dir/QLearning/main.py --env_name FrozenLake-v1 --train_eps 800 --epsilon_start 0.70 --epsilon_end 0.1 --epsilon_decay 2000 --gamma 0.9 --lr 0.9 --device cpu
|
||||
python $codes_dir/envs/register.py # register environment
|
||||
python $codes_dir/QLearning/main.py --env_name FrozenLakeNoSlippery-v1 --train_eps 800 --epsilon_start 0.70 --epsilon_end 0.1 --epsilon_decay 2000 --gamma 0.9 --lr 0.9 --device cpu
|
||||
13
projects/codes/scripts/Sarsa_task0.sh
Normal file
@@ -0,0 +1,13 @@
|
||||
if [ -f "$HOME/anaconda3/etc/profile.d/conda.sh" ]; then
|
||||
echo "source file at ~/anaconda3/etc/profile.d/conda.sh"
|
||||
source ~/anaconda3/etc/profile.d/conda.sh
|
||||
elif [ -f "$HOME/opt/anaconda3/etc/profile.d/conda.sh" ]; then
|
||||
echo "source file at ~/opt/anaconda3/etc/profile.d/conda.sh"
|
||||
source ~/opt/anaconda3/etc/profile.d/conda.sh
|
||||
else
|
||||
echo 'please manually configure the conda source path'
|
||||
fi
|
||||
conda activate easyrl # easyrl here can be changed to another name of conda env that you have created
|
||||
codes_dir=$(dirname $(dirname $(readlink -f "$0"))) # "codes" path
|
||||
python $codes_dir/envs/register.py # register environment
|
||||
python $codes_dir/Sarsa/main.py
|
||||