{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## 1、定义算法\n", "\n", "在阅读该教程之前,请先阅读Q learning教程。Sarsa算法跟Q learning算法基本模式相同,但是根本的区别在于,Sarsa是先做出动作然后拿这个做的动作去更新,而Q learning是假定下一步最大奖励对应的动作拿去更新,然后再使用$\\varepsilon$-greedy策略,也就是说Sarsa是on-policy的,而Q learning是off-policy的。如下方代码所示,只有在更新的地方Sarsa与Q learning有着细微的区别。" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "from collections import defaultdict\n", "import torch\n", "import math\n", "class Sarsa(object):\n", " def __init__(self,\n", " n_actions,cfg):\n", " self.n_actions = n_actions \n", " self.lr = cfg.lr \n", " self.gamma = cfg.gamma \n", " self.sample_count = 0 \n", " self.epsilon_start = cfg.epsilon_start\n", " self.epsilon_end = cfg.epsilon_end\n", " self.epsilon_decay = cfg.epsilon_decay \n", " self.Q = defaultdict(lambda: np.zeros(n_actions)) # Q table\n", " def sample(self, state):\n", " self.sample_count += 1\n", " self.epsilon = self.epsilon_end + (self.epsilon_start - self.epsilon_end) * \\\n", " math.exp(-1. * self.sample_count / self.epsilon_decay) # The probability to select a random action, is is log decayed\n", " best_action = np.argmax(self.Q[state])\n", " action_probs = np.ones(self.n_actions, dtype=float) * self.epsilon / self.n_actions\n", " action_probs[best_action] += (1.0 - self.epsilon)\n", " action = np.random.choice(np.arange(len(action_probs)), p=action_probs) \n", " return action\n", " def predict(self,state):\n", " return np.argmax(self.Q[state])\n", " def update(self, state, action, reward, next_state, next_action,done):\n", " Q_predict = self.Q[state][action]\n", " if done:\n", " Q_target = reward # 终止状态\n", " else:\n", " Q_target = reward + self.gamma * self.Q[next_state][next_action] # 与Q learning不同,Sarsa是拿下一步动作对应的Q值去更新\n", " self.Q[state][action] += self.lr * (Q_target - Q_predict) \n", " def save(self,path):\n", " '''把 Q表格 的数据保存到文件中\n", " '''\n", " import dill\n", " torch.save(\n", " obj=self.Q,\n", " f=path+\"sarsa_model.pkl\",\n", " pickle_module=dill\n", " )\n", " def load(self, path):\n", " '''从文件中读取数据到 Q表格\n", " '''\n", " import dill\n", " self.Q =torch.load(f=path+'sarsa_model.pkl',pickle_module=dill)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 2、定义训练\n", "\n", "同样地,跟Q learning差别也不大" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "def train(cfg,env,agent):\n", " print('开始训练!')\n", " print(f'环境:{cfg.env_name}, 算法:{cfg.algo_name}, 设备:{cfg.device}')\n", " rewards = [] # 记录奖励\n", " for i_ep in range(cfg.train_eps):\n", " ep_reward = 0 # 记录每个回合的奖励\n", " state = env.reset() # 重置环境,即开始新的回合\n", " action = agent.sample(state)\n", " while True:\n", " action = agent.sample(state) # 根据算法采样一个动作\n", " next_state, reward, done, _ = env.step(action) # 与环境进行一次动作交互\n", " next_action = agent.sample(next_state)\n", " agent.update(state, action, reward, next_state, next_action,done) # 算法更新\n", " state = next_state # 更新状态\n", " action = next_action\n", " ep_reward += reward\n", " if done:\n", " break\n", " rewards.append(ep_reward)\n", " print(f\"回合:{i_ep+1}/{cfg.train_eps},奖励:{ep_reward:.1f},Epsilon:{agent.epsilon}\")\n", " print('完成训练!')\n", " return {\"rewards\":rewards}\n", " \n", "def test(cfg,env,agent):\n", " print('开始测试!')\n", " print(f'环境:{cfg.env_name}, 算法:{cfg.algo_name}, 设备:{cfg.device}')\n", " rewards = [] # 记录所有回合的奖励\n", " for i_ep in range(cfg.test_eps):\n", " ep_reward = 0 # 记录每个episode的reward\n", " state = env.reset() # 重置环境, 重新开一局(即开始新的一个回合)\n", " while True:\n", " action = agent.predict(state) # 根据算法选择一个动作\n", " next_state, reward, done, _ = env.step(action) # 与环境进行一个交互\n", " state = next_state # 更新状态\n", " ep_reward += reward\n", " if done:\n", " break\n", " rewards.append(ep_reward)\n", " print(f\"回合数:{i_ep+1}/{cfg.test_eps}, 奖励:{ep_reward:.1f}\")\n", " print('完成测试!')\n", " return {\"rewards\":rewards}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 3、定义环境\n", "为了具体看看Q learning和Sarsa的不同,笔者决定跟Q learning使用相同的环境\n" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "import gym\n", "import turtle\n", "import numpy as np\n", "\n", "# turtle tutorial : https://docs.python.org/3.3/library/turtle.html\n", "\n", "def GridWorld(gridmap=None, is_slippery=False):\n", " if gridmap is None:\n", " gridmap = ['SFFF', 'FHFH', 'FFFH', 'HFFG']\n", " env = gym.make(\"FrozenLake-v0\", desc=gridmap, is_slippery=False)\n", " env = FrozenLakeWapper(env)\n", " return env\n", "\n", "\n", "class FrozenLakeWapper(gym.Wrapper):\n", " def __init__(self, env):\n", " gym.Wrapper.__init__(self, env)\n", " self.max_y = env.desc.shape[0]\n", " self.max_x = env.desc.shape[1]\n", " self.t = None\n", " self.unit = 50\n", "\n", " def draw_box(self, x, y, fillcolor='', line_color='gray'):\n", " self.t.up()\n", " self.t.goto(x * self.unit, y * self.unit)\n", " self.t.color(line_color)\n", " self.t.fillcolor(fillcolor)\n", " self.t.setheading(90)\n", " self.t.down()\n", " self.t.begin_fill()\n", " for _ in range(4):\n", " self.t.forward(self.unit)\n", " self.t.right(90)\n", " self.t.end_fill()\n", "\n", " def move_player(self, x, y):\n", " self.t.up()\n", " self.t.setheading(90)\n", " self.t.fillcolor('red')\n", " self.t.goto((x + 0.5) * self.unit, (y + 0.5) * self.unit)\n", "\n", " def render(self):\n", " if self.t == None:\n", " self.t = turtle.Turtle()\n", " self.wn = turtle.Screen()\n", " self.wn.setup(self.unit * self.max_x + 100,\n", " self.unit * self.max_y + 100)\n", " self.wn.setworldcoordinates(0, 0, self.unit * self.max_x,\n", " self.unit * self.max_y)\n", " self.t.shape('circle')\n", " self.t.width(2)\n", " self.t.speed(0)\n", " self.t.color('gray')\n", " for i in range(self.desc.shape[0]):\n", " for j in range(self.desc.shape[1]):\n", " x = j\n", " y = self.max_y - 1 - i\n", " if self.desc[i][j] == b'S': # Start\n", " self.draw_box(x, y, 'white')\n", " elif self.desc[i][j] == b'F': # Frozen ice\n", " self.draw_box(x, y, 'white')\n", " elif self.desc[i][j] == b'G': # Goal\n", " self.draw_box(x, y, 'yellow')\n", " elif self.desc[i][j] == b'H': # Hole\n", " self.draw_box(x, y, 'black')\n", " else:\n", " self.draw_box(x, y, 'white')\n", " self.t.shape('turtle')\n", "\n", " x_pos = self.s % self.max_x\n", " y_pos = self.max_y - 1 - int(self.s / self.max_x)\n", " self.move_player(x_pos, y_pos)\n", "\n", "\n", "class CliffWalkingWapper(gym.Wrapper):\n", " def __init__(self, env):\n", " gym.Wrapper.__init__(self, env)\n", " self.t = None\n", " self.unit = 50\n", " self.max_x = 12\n", " self.max_y = 4\n", "\n", " def draw_x_line(self, y, x0, x1, color='gray'):\n", " assert x1 > x0\n", " self.t.color(color)\n", " self.t.setheading(0)\n", " self.t.up()\n", " self.t.goto(x0, y)\n", " self.t.down()\n", " self.t.forward(x1 - x0)\n", "\n", " def draw_y_line(self, x, y0, y1, color='gray'):\n", " assert y1 > y0\n", " self.t.color(color)\n", " self.t.setheading(90)\n", " self.t.up()\n", " self.t.goto(x, y0)\n", " self.t.down()\n", " self.t.forward(y1 - y0)\n", "\n", " def draw_box(self, x, y, fillcolor='', line_color='gray'):\n", " self.t.up()\n", " self.t.goto(x * self.unit, y * self.unit)\n", " self.t.color(line_color)\n", " self.t.fillcolor(fillcolor)\n", " self.t.setheading(90)\n", " self.t.down()\n", " self.t.begin_fill()\n", " for i in range(4):\n", " self.t.forward(self.unit)\n", " self.t.right(90)\n", " self.t.end_fill()\n", "\n", " def move_player(self, x, y):\n", " self.t.up()\n", " self.t.setheading(90)\n", " self.t.fillcolor('red')\n", " self.t.goto((x + 0.5) * self.unit, (y + 0.5) * self.unit)\n", "\n", " def render(self):\n", " if self.t == None:\n", " self.t = turtle.Turtle()\n", " self.wn = turtle.Screen()\n", " self.wn.setup(self.unit * self.max_x + 100,\n", " self.unit * self.max_y + 100)\n", " self.wn.setworldcoordinates(0, 0, self.unit * self.max_x,\n", " self.unit * self.max_y)\n", " self.t.shape('circle')\n", " self.t.width(2)\n", " self.t.speed(0)\n", " self.t.color('gray')\n", " for _ in range(2):\n", " self.t.forward(self.max_x * self.unit)\n", " self.t.left(90)\n", " self.t.forward(self.max_y * self.unit)\n", " self.t.left(90)\n", " for i in range(1, self.max_y):\n", " self.draw_x_line(\n", " y=i * self.unit, x0=0, x1=self.max_x * self.unit)\n", " for i in range(1, self.max_x):\n", " self.draw_y_line(\n", " x=i * self.unit, y0=0, y1=self.max_y * self.unit)\n", "\n", " for i in range(1, self.max_x - 1):\n", " self.draw_box(i, 0, 'black')\n", " self.draw_box(self.max_x - 1, 0, 'yellow')\n", " self.t.shape('turtle')\n", "\n", " x_pos = self.s % self.max_x\n", " y_pos = self.max_y - 1 - int(self.s / self.max_x)\n", " self.move_player(x_pos, y_pos)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "def env_agent_config(cfg,seed=1):\n", " '''创建环境和智能体\n", " Args:\n", " cfg ([type]): [description]\n", " seed (int, optional): 随机种子. Defaults to 1.\n", " Returns:\n", " env [type]: 环境\n", " agent : 智能体\n", " ''' \n", " env = gym.make(cfg.env_name) \n", " env = CliffWalkingWapper(env)\n", " env.seed(seed) # 设置随机种子\n", " n_states = env.observation_space.n # 状态维度\n", " n_actions = env.action_space.n # 动作维度\n", " print(f\"状态数:{n_states},动作数:{n_actions}\")\n", " agent = Sarsa(n_actions,cfg)\n", " return env,agent" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 4、设置参数\n", "同样的参数也是一样" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "import datetime\n", "import argparse\n", "import matplotlib.pyplot as plt\n", "import seaborn as sns\n", "def get_args():\n", " \"\"\" \n", " \"\"\"\n", " curr_time = datetime.datetime.now().strftime(\"%Y%m%d-%H%M%S\") # 获取当前时间\n", " parser = argparse.ArgumentParser(description=\"hyperparameters\") \n", " parser.add_argument('--algo_name',default='Sarsa',type=str,help=\"name of algorithm\")\n", " parser.add_argument('--env_name',default='CliffWalking-v0',type=str,help=\"name of environment\")\n", " parser.add_argument('--train_eps',default=400,type=int,help=\"episodes of training\") # 训练的回合数\n", " parser.add_argument('--test_eps',default=20,type=int,help=\"episodes of testing\") # 测试的回合数\n", " parser.add_argument('--gamma',default=0.90,type=float,help=\"discounted factor\") # 折扣因子\n", " parser.add_argument('--epsilon_start',default=0.95,type=float,help=\"initial value of epsilon\") # e-greedy策略中初始epsilon\n", " parser.add_argument('--epsilon_end',default=0.01,type=float,help=\"final value of epsilon\") # e-greedy策略中的终止epsilon\n", " parser.add_argument('--epsilon_decay',default=300,type=int,help=\"decay rate of epsilon\") # e-greedy策略中epsilon的衰减率\n", " parser.add_argument('--lr',default=0.1,type=float,help=\"learning rate\")\n", " parser.add_argument('--device',default='cpu',type=str,help=\"cpu or cuda\") \n", " args = parser.parse_args([]) \n", " return args\n", "\n", "def smooth(data, weight=0.9): \n", " '''用于平滑曲线,类似于Tensorboard中的smooth\n", "\n", " Args:\n", " data (List):输入数据\n", " weight (Float): 平滑权重,处于0-1之间,数值越高说明越平滑,一般取0.9\n", "\n", " Returns:\n", " smoothed (List): 平滑后的数据\n", " '''\n", " last = data[0] # First value in the plot (first timestep)\n", " smoothed = list()\n", " for point in data:\n", " smoothed_val = last * weight + (1 - weight) * point # 计算平滑值\n", " smoothed.append(smoothed_val) \n", " last = smoothed_val \n", " return smoothed\n", "\n", "def plot_rewards(rewards,cfg, tag='train'):\n", " sns.set()\n", " plt.figure() # 创建一个图形实例,方便同时多画几个图\n", " plt.title(f\"{tag}ing curve on {cfg.device} of {cfg.algo_name} for {cfg.env_name}\")\n", " plt.xlabel('epsiodes')\n", " plt.plot(rewards, label='rewards')\n", " plt.plot(smooth(rewards), label='smoothed')\n", " plt.legend()\n", " plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 5、开始训练\n", "训练结果会发现,Sarsa收敛速度更快,但收敛值会比Q-learning低" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "状态数:48,动作数:4\n", "开始训练!\n", "环境:CliffWalking-v0, 算法:Sarsa, 设备:cpu\n", "回合:1/400,奖励:-1524.0,Epsilon:0.2029722781251147\n", "回合:2/400,奖励:-1294.0,Epsilon:0.011808588201828951\n", "回合:3/400,奖励:-192.0,Epsilon:0.01050118158853445\n", "回合:4/400,奖励:-346.0,Epsilon:0.010049747911736582\n", "回合:5/400,奖励:-252.0,Epsilon:0.010009240861841986\n", "回合:6/400,奖励:-168.0,Epsilon:0.010003005072880926\n", "回合:7/400,奖励:-393.0,Epsilon:0.01000042188120369\n", "回合:8/400,奖励:-169.0,Epsilon:0.010000136281659052\n", "回合:9/400,奖励:-97.0,Epsilon:0.010000071145264558\n", "回合:10/400,奖励:-134.0,Epsilon:0.010000029022085234\n", "回合:11/400,奖励:-124.0,Epsilon:0.010000012655059554\n", "回合:12/400,奖励:-74.0,Epsilon:0.010000007701309915\n", "回合:13/400,奖励:-135.0,Epsilon:0.010000003120699265\n", "回合:14/400,奖励:-84.0,Epsilon:0.010000001776639691\n", "回合:15/400,奖励:-101.0,Epsilon:0.010000000903081117\n", "回合:16/400,奖励:-111.0,Epsilon:0.010000000429438717\n", "回合:17/400,奖励:-114.0,Epsilon:0.010000000200165738\n", "回合:18/400,奖励:-114.0,Epsilon:0.010000000093299278\n", "回合:19/400,奖励:-82.0,Epsilon:0.010000000053829002\n", "回合:20/400,奖励:-85.0,Epsilon:0.01000000003044167\n", "回合:21/400,奖励:-108.0,Epsilon:0.010000000014768242\n", "回合:22/400,奖励:-66.0,Epsilon:0.010000000009479634\n", "回合:23/400,奖励:-74.0,Epsilon:0.010000000005768887\n", "回合:24/400,奖励:-114.0,Epsilon:0.010000000002688936\n", "回合:25/400,奖励:-98.0,Epsilon:0.010000000001394421\n", "回合:26/400,奖励:-94.0,Epsilon:0.010000000000742658\n", "回合:27/400,奖励:-58.0,Epsilon:0.010000000000502822\n", "回合:28/400,奖励:-100.0,Epsilon:0.010000000000257298\n", "回合:29/400,奖励:-208.0,Epsilon:0.010000000000123995\n", "回合:30/400,奖励:-184.0,Epsilon:0.010000000000070121\n", "回合:31/400,奖励:-62.0,Epsilon:0.010000000000046227\n", "回合:32/400,奖励:-117.0,Epsilon:0.01000000000002112\n", "回合:33/400,奖励:-47.0,Epsilon:0.010000000000015387\n", "回合:34/400,奖励:-54.0,Epsilon:0.0100000000000107\n", "回合:35/400,奖励:-120.0,Epsilon:0.010000000000004792\n", "回合:36/400,奖励:-75.0,Epsilon:0.010000000000002897\n", "回合:37/400,奖励:-62.0,Epsilon:0.01000000000000191\n", "回合:38/400,奖励:-70.0,Epsilon:0.010000000000001194\n", "回合:39/400,奖励:-67.0,Epsilon:0.010000000000000762\n", "回合:40/400,奖励:-87.0,Epsilon:0.010000000000000425\n", "回合:41/400,奖励:-92.0,Epsilon:0.01000000000000023\n", "回合:42/400,奖励:-79.0,Epsilon:0.010000000000000136\n", "回合:43/400,奖励:-49.0,Epsilon:0.010000000000000097\n", "回合:44/400,奖励:-103.0,Epsilon:0.010000000000000049\n", "回合:45/400,奖励:-40.0,Epsilon:0.010000000000000037\n", "回合:46/400,奖励:-214.0,Epsilon:0.010000000000000018\n", "回合:47/400,奖励:-83.0,Epsilon:0.01000000000000001\n", "回合:48/400,奖励:-62.0,Epsilon:0.010000000000000007\n", "回合:49/400,奖励:-37.0,Epsilon:0.010000000000000005\n", "回合:50/400,奖励:-73.0,Epsilon:0.010000000000000004\n", "回合:51/400,奖励:-66.0,Epsilon:0.010000000000000002\n", "回合:52/400,奖励:-48.0,Epsilon:0.010000000000000002\n", "回合:53/400,奖励:-96.0,Epsilon:0.01\n", "回合:54/400,奖励:-189.0,Epsilon:0.01\n", "回合:55/400,奖励:-42.0,Epsilon:0.01\n", "回合:56/400,奖励:-46.0,Epsilon:0.01\n", "回合:57/400,奖励:-85.0,Epsilon:0.01\n", "回合:58/400,奖励:-52.0,Epsilon:0.01\n", "回合:59/400,奖励:-86.0,Epsilon:0.01\n", "回合:60/400,奖励:-41.0,Epsilon:0.01\n", "回合:61/400,奖励:-51.0,Epsilon:0.01\n", "回合:62/400,奖励:-59.0,Epsilon:0.01\n", "回合:63/400,奖励:-145.0,Epsilon:0.01\n", "回合:64/400,奖励:-76.0,Epsilon:0.01\n", "回合:65/400,奖励:-43.0,Epsilon:0.01\n", "回合:66/400,奖励:-49.0,Epsilon:0.01\n", "回合:67/400,奖励:-36.0,Epsilon:0.01\n", "回合:68/400,奖励:-41.0,Epsilon:0.01\n", "回合:69/400,奖励:-69.0,Epsilon:0.01\n", "回合:70/400,奖励:-38.0,Epsilon:0.01\n", "回合:71/400,奖励:-63.0,Epsilon:0.01\n", "回合:72/400,奖励:-46.0,Epsilon:0.01\n", "回合:73/400,奖励:-30.0,Epsilon:0.01\n", "回合:74/400,奖励:-45.0,Epsilon:0.01\n", "回合:75/400,奖励:-38.0,Epsilon:0.01\n", "回合:76/400,奖励:-88.0,Epsilon:0.01\n", "回合:77/400,奖励:-19.0,Epsilon:0.01\n", "回合:78/400,奖励:-40.0,Epsilon:0.01\n", "回合:79/400,奖励:-62.0,Epsilon:0.01\n", "回合:80/400,奖励:-25.0,Epsilon:0.01\n", "回合:81/400,奖励:-54.0,Epsilon:0.01\n", "回合:82/400,奖励:-41.0,Epsilon:0.01\n", "回合:83/400,奖励:-57.0,Epsilon:0.01\n", "回合:84/400,奖励:-52.0,Epsilon:0.01\n", "回合:85/400,奖励:-42.0,Epsilon:0.01\n", "回合:86/400,奖励:-51.0,Epsilon:0.01\n", "回合:87/400,奖励:-53.0,Epsilon:0.01\n", "回合:88/400,奖励:-42.0,Epsilon:0.01\n", "回合:89/400,奖励:-53.0,Epsilon:0.01\n", "回合:90/400,奖励:-31.0,Epsilon:0.01\n", "回合:91/400,奖励:-75.0,Epsilon:0.01\n", "回合:92/400,奖励:-148.0,Epsilon:0.01\n", "回合:93/400,奖励:-41.0,Epsilon:0.01\n", "回合:94/400,奖励:-47.0,Epsilon:0.01\n", "回合:95/400,奖励:-184.0,Epsilon:0.01\n", "回合:96/400,奖励:-34.0,Epsilon:0.01\n", "回合:97/400,奖励:-45.0,Epsilon:0.01\n", "回合:98/400,奖励:-52.0,Epsilon:0.01\n", "回合:99/400,奖励:-44.0,Epsilon:0.01\n", "回合:100/400,奖励:-49.0,Epsilon:0.01\n", "回合:101/400,奖励:-30.0,Epsilon:0.01\n", "回合:102/400,奖励:-49.0,Epsilon:0.01\n", "回合:103/400,奖励:-23.0,Epsilon:0.01\n", "回合:104/400,奖励:-37.0,Epsilon:0.01\n", "回合:105/400,奖励:-37.0,Epsilon:0.01\n", "回合:106/400,奖励:-44.0,Epsilon:0.01\n", "回合:107/400,奖励:-40.0,Epsilon:0.01\n", "回合:108/400,奖励:-28.0,Epsilon:0.01\n", "回合:109/400,奖励:-50.0,Epsilon:0.01\n", "回合:110/400,奖励:-46.0,Epsilon:0.01\n", "回合:111/400,奖励:-28.0,Epsilon:0.01\n", "回合:112/400,奖励:-35.0,Epsilon:0.01\n", "回合:113/400,奖励:-35.0,Epsilon:0.01\n", "回合:114/400,奖励:-45.0,Epsilon:0.01\n", "回合:115/400,奖励:-38.0,Epsilon:0.01\n", "回合:116/400,奖励:-39.0,Epsilon:0.01\n", "回合:117/400,奖励:-27.0,Epsilon:0.01\n", "回合:118/400,奖励:-49.0,Epsilon:0.01\n", "回合:119/400,奖励:-27.0,Epsilon:0.01\n", "回合:120/400,奖励:-25.0,Epsilon:0.01\n", "回合:121/400,奖励:-50.0,Epsilon:0.01\n", "回合:122/400,奖励:-41.0,Epsilon:0.01\n", "回合:123/400,奖励:-22.0,Epsilon:0.01\n", "回合:124/400,奖励:-38.0,Epsilon:0.01\n", "回合:125/400,奖励:-125.0,Epsilon:0.01\n", "回合:126/400,奖励:-25.0,Epsilon:0.01\n", "回合:127/400,奖励:-40.0,Epsilon:0.01\n", "回合:128/400,奖励:-33.0,Epsilon:0.01\n", "回合:129/400,奖励:-56.0,Epsilon:0.01\n", "回合:130/400,奖励:-32.0,Epsilon:0.01\n", "回合:131/400,奖励:-21.0,Epsilon:0.01\n", "回合:132/400,奖励:-33.0,Epsilon:0.01\n", "回合:133/400,奖励:-23.0,Epsilon:0.01\n", "回合:134/400,奖励:-33.0,Epsilon:0.01\n", "回合:135/400,奖励:-34.0,Epsilon:0.01\n", "回合:136/400,奖励:-33.0,Epsilon:0.01\n", "回合:137/400,奖励:-21.0,Epsilon:0.01\n", "回合:138/400,奖励:-40.0,Epsilon:0.01\n", "回合:139/400,奖励:-23.0,Epsilon:0.01\n", "回合:140/400,奖励:-31.0,Epsilon:0.01\n", "回合:141/400,奖励:-31.0,Epsilon:0.01\n", "回合:142/400,奖励:-26.0,Epsilon:0.01\n", "回合:143/400,奖励:-26.0,Epsilon:0.01\n", "回合:144/400,奖励:-32.0,Epsilon:0.01\n", "回合:145/400,奖励:-27.0,Epsilon:0.01\n", "回合:146/400,奖励:-33.0,Epsilon:0.01\n", "回合:147/400,奖励:-35.0,Epsilon:0.01\n", "回合:148/400,奖励:-21.0,Epsilon:0.01\n", "回合:149/400,奖励:-23.0,Epsilon:0.01\n", "回合:150/400,奖励:-33.0,Epsilon:0.01\n", "回合:151/400,奖励:-25.0,Epsilon:0.01\n", "回合:152/400,奖励:-41.0,Epsilon:0.01\n", "回合:153/400,奖励:-31.0,Epsilon:0.01\n", "回合:154/400,奖励:-28.0,Epsilon:0.01\n", "回合:155/400,奖励:-133.0,Epsilon:0.01\n", "回合:156/400,奖励:-22.0,Epsilon:0.01\n", "回合:157/400,奖励:-21.0,Epsilon:0.01\n", "回合:158/400,奖励:-33.0,Epsilon:0.01\n", "回合:159/400,奖励:-33.0,Epsilon:0.01\n", "回合:160/400,奖励:-24.0,Epsilon:0.01\n", "回合:161/400,奖励:-34.0,Epsilon:0.01\n", "回合:162/400,奖励:-20.0,Epsilon:0.01\n", "回合:163/400,奖励:-21.0,Epsilon:0.01\n", "回合:164/400,奖励:-126.0,Epsilon:0.01\n", "回合:165/400,奖励:-36.0,Epsilon:0.01\n", "回合:166/400,奖励:-18.0,Epsilon:0.01\n", "回合:167/400,奖励:-35.0,Epsilon:0.01\n", "回合:168/400,奖励:-26.0,Epsilon:0.01\n", "回合:169/400,奖励:-24.0,Epsilon:0.01\n", "回合:170/400,奖励:-33.0,Epsilon:0.01\n", "回合:171/400,奖励:-17.0,Epsilon:0.01\n", "回合:172/400,奖励:-23.0,Epsilon:0.01\n", "回合:173/400,奖励:-26.0,Epsilon:0.01\n", "回合:174/400,奖励:-23.0,Epsilon:0.01\n", "回合:175/400,奖励:-21.0,Epsilon:0.01\n", "回合:176/400,奖励:-35.0,Epsilon:0.01\n", "回合:177/400,奖励:-26.0,Epsilon:0.01\n", "回合:178/400,奖励:-17.0,Epsilon:0.01\n", "回合:179/400,奖励:-20.0,Epsilon:0.01\n", "回合:180/400,奖励:-28.0,Epsilon:0.01\n", "回合:181/400,奖励:-34.0,Epsilon:0.01\n", "回合:182/400,奖励:-27.0,Epsilon:0.01\n", "回合:183/400,奖励:-22.0,Epsilon:0.01\n", "回合:184/400,奖励:-24.0,Epsilon:0.01\n", "回合:185/400,奖励:-26.0,Epsilon:0.01\n", "回合:186/400,奖励:-20.0,Epsilon:0.01\n", "回合:187/400,奖励:-30.0,Epsilon:0.01\n", "回合:188/400,奖励:-28.0,Epsilon:0.01\n", "回合:189/400,奖励:-15.0,Epsilon:0.01\n", "回合:190/400,奖励:-30.0,Epsilon:0.01\n", "回合:191/400,奖励:-29.0,Epsilon:0.01\n", "回合:192/400,奖励:-22.0,Epsilon:0.01\n", "回合:193/400,奖励:-25.0,Epsilon:0.01\n", "回合:194/400,奖励:-21.0,Epsilon:0.01\n", "回合:195/400,奖励:-19.0,Epsilon:0.01\n", "回合:196/400,奖励:-23.0,Epsilon:0.01\n", "回合:197/400,奖励:-21.0,Epsilon:0.01\n", "回合:198/400,奖励:-32.0,Epsilon:0.01\n", "回合:199/400,奖励:-30.0,Epsilon:0.01\n", "回合:200/400,奖励:-22.0,Epsilon:0.01\n", "回合:201/400,奖励:-20.0,Epsilon:0.01\n", "回合:202/400,奖励:-27.0,Epsilon:0.01\n", "回合:203/400,奖励:-21.0,Epsilon:0.01\n", "回合:204/400,奖励:-26.0,Epsilon:0.01\n", "回合:205/400,奖励:-19.0,Epsilon:0.01\n", "回合:206/400,奖励:-17.0,Epsilon:0.01\n", "回合:207/400,奖励:-31.0,Epsilon:0.01\n", "回合:208/400,奖励:-18.0,Epsilon:0.01\n", "回合:209/400,奖励:-24.0,Epsilon:0.01\n", "回合:210/400,奖励:-17.0,Epsilon:0.01\n", "回合:211/400,奖励:-26.0,Epsilon:0.01\n", "回合:212/400,奖励:-27.0,Epsilon:0.01\n", "回合:213/400,奖励:-33.0,Epsilon:0.01\n", "回合:214/400,奖励:-16.0,Epsilon:0.01\n", "回合:215/400,奖励:-32.0,Epsilon:0.01\n", "回合:216/400,奖励:-19.0,Epsilon:0.01\n", "回合:217/400,奖励:-20.0,Epsilon:0.01\n", "回合:218/400,奖励:-15.0,Epsilon:0.01\n", "回合:219/400,奖励:-119.0,Epsilon:0.01\n", "回合:220/400,奖励:-26.0,Epsilon:0.01\n", "回合:221/400,奖励:-26.0,Epsilon:0.01\n", "回合:222/400,奖励:-22.0,Epsilon:0.01\n", "回合:223/400,奖励:-22.0,Epsilon:0.01\n", "回合:224/400,奖励:-15.0,Epsilon:0.01\n", "回合:225/400,奖励:-24.0,Epsilon:0.01\n", "回合:226/400,奖励:-15.0,Epsilon:0.01\n", "回合:227/400,奖励:-31.0,Epsilon:0.01\n", "回合:228/400,奖励:-24.0,Epsilon:0.01\n", "回合:229/400,奖励:-20.0,Epsilon:0.01\n", "回合:230/400,奖励:-20.0,Epsilon:0.01\n", "回合:231/400,奖励:-22.0,Epsilon:0.01\n", "回合:232/400,奖励:-15.0,Epsilon:0.01\n", "回合:233/400,奖励:-19.0,Epsilon:0.01\n", "回合:234/400,奖励:-21.0,Epsilon:0.01\n", "回合:235/400,奖励:-27.0,Epsilon:0.01\n", "回合:236/400,奖励:-15.0,Epsilon:0.01\n", "回合:237/400,奖励:-25.0,Epsilon:0.01\n", "回合:238/400,奖励:-22.0,Epsilon:0.01\n", "回合:239/400,奖励:-16.0,Epsilon:0.01\n", "回合:240/400,奖励:-18.0,Epsilon:0.01\n", "回合:241/400,奖励:-13.0,Epsilon:0.01\n", "回合:242/400,奖励:-13.0,Epsilon:0.01\n", "回合:243/400,奖励:-13.0,Epsilon:0.01\n", "回合:244/400,奖励:-23.0,Epsilon:0.01\n", "回合:245/400,奖励:-29.0,Epsilon:0.01\n", "回合:246/400,奖励:-26.0,Epsilon:0.01\n", "回合:247/400,奖励:-19.0,Epsilon:0.01\n", "回合:248/400,奖励:-21.0,Epsilon:0.01\n", "回合:249/400,奖励:-17.0,Epsilon:0.01\n", "回合:250/400,奖励:-17.0,Epsilon:0.01\n", "回合:251/400,奖励:-15.0,Epsilon:0.01\n", "回合:252/400,奖励:-20.0,Epsilon:0.01\n", "回合:253/400,奖励:-23.0,Epsilon:0.01\n", "回合:254/400,奖励:-19.0,Epsilon:0.01\n", "回合:255/400,奖励:-21.0,Epsilon:0.01\n", "回合:256/400,奖励:-19.0,Epsilon:0.01\n", "回合:257/400,奖励:-17.0,Epsilon:0.01\n", "回合:258/400,奖励:-17.0,Epsilon:0.01\n", "回合:259/400,奖励:-15.0,Epsilon:0.01\n", "回合:260/400,奖励:-21.0,Epsilon:0.01\n", "回合:261/400,奖励:-17.0,Epsilon:0.01\n", "回合:262/400,奖励:-19.0,Epsilon:0.01\n", "回合:263/400,奖励:-19.0,Epsilon:0.01\n", "回合:264/400,奖励:-15.0,Epsilon:0.01\n", "回合:265/400,奖励:-19.0,Epsilon:0.01\n", "回合:266/400,奖励:-17.0,Epsilon:0.01\n", "回合:267/400,奖励:-15.0,Epsilon:0.01\n", "回合:268/400,奖励:-19.0,Epsilon:0.01\n", "回合:269/400,奖励:-27.0,Epsilon:0.01\n", "回合:270/400,奖励:-15.0,Epsilon:0.01\n", "回合:271/400,奖励:-17.0,Epsilon:0.01\n", "回合:272/400,奖励:-17.0,Epsilon:0.01\n", "回合:273/400,奖励:-25.0,Epsilon:0.01\n", "回合:274/400,奖励:-19.0,Epsilon:0.01\n", "回合:275/400,奖励:-22.0,Epsilon:0.01\n", "回合:276/400,奖励:-23.0,Epsilon:0.01\n", "回合:277/400,奖励:-18.0,Epsilon:0.01\n", "回合:278/400,奖励:-23.0,Epsilon:0.01\n", "回合:279/400,奖励:-21.0,Epsilon:0.01\n", "回合:280/400,奖励:-21.0,Epsilon:0.01\n", "回合:281/400,奖励:-21.0,Epsilon:0.01\n", "回合:282/400,奖励:-19.0,Epsilon:0.01\n", "回合:283/400,奖励:-18.0,Epsilon:0.01\n", "回合:284/400,奖励:-15.0,Epsilon:0.01\n", "回合:285/400,奖励:-19.0,Epsilon:0.01\n", "回合:286/400,奖励:-19.0,Epsilon:0.01\n", "回合:287/400,奖励:-21.0,Epsilon:0.01\n", "回合:288/400,奖励:-15.0,Epsilon:0.01\n", "回合:289/400,奖励:-32.0,Epsilon:0.01\n", "回合:290/400,奖励:-18.0,Epsilon:0.01\n", "回合:291/400,奖励:-17.0,Epsilon:0.01\n", "回合:292/400,奖励:-15.0,Epsilon:0.01\n", "回合:293/400,奖励:-24.0,Epsilon:0.01\n", "回合:294/400,奖励:-22.0,Epsilon:0.01\n", "回合:295/400,奖励:-31.0,Epsilon:0.01\n", "回合:296/400,奖励:-17.0,Epsilon:0.01\n", "回合:297/400,奖励:-19.0,Epsilon:0.01\n", "回合:298/400,奖励:-19.0,Epsilon:0.01\n", "回合:299/400,奖励:-20.0,Epsilon:0.01\n", "回合:300/400,奖励:-21.0,Epsilon:0.01\n", "回合:301/400,奖励:-26.0,Epsilon:0.01\n", "回合:302/400,奖励:-20.0,Epsilon:0.01\n", "回合:303/400,奖励:-16.0,Epsilon:0.01\n", "回合:304/400,奖励:-20.0,Epsilon:0.01\n", "回合:305/400,奖励:-21.0,Epsilon:0.01\n", "回合:306/400,奖励:-16.0,Epsilon:0.01\n", "回合:307/400,奖励:-19.0,Epsilon:0.01\n", "回合:308/400,奖励:-24.0,Epsilon:0.01\n", "回合:309/400,奖励:-20.0,Epsilon:0.01\n", "回合:310/400,奖励:-17.0,Epsilon:0.01\n", "回合:311/400,奖励:-16.0,Epsilon:0.01\n", "回合:312/400,奖励:-25.0,Epsilon:0.01\n", "回合:313/400,奖励:-16.0,Epsilon:0.01\n", "回合:314/400,奖励:-19.0,Epsilon:0.01\n", "回合:315/400,奖励:-19.0,Epsilon:0.01\n", "回合:316/400,奖励:-27.0,Epsilon:0.01\n", "回合:317/400,奖励:-15.0,Epsilon:0.01\n", "回合:318/400,奖励:-15.0,Epsilon:0.01\n", "回合:319/400,奖励:-15.0,Epsilon:0.01\n", "回合:320/400,奖励:-19.0,Epsilon:0.01\n", "回合:321/400,奖励:-23.0,Epsilon:0.01\n", "回合:322/400,奖励:-24.0,Epsilon:0.01\n", "回合:323/400,奖励:-15.0,Epsilon:0.01\n", "回合:324/400,奖励:-20.0,Epsilon:0.01\n", "回合:325/400,奖励:-18.0,Epsilon:0.01\n", "回合:326/400,奖励:-19.0,Epsilon:0.01\n", "回合:327/400,奖励:-19.0,Epsilon:0.01\n", "回合:328/400,奖励:-26.0,Epsilon:0.01\n", "回合:329/400,奖励:-16.0,Epsilon:0.01\n", "回合:330/400,奖励:-18.0,Epsilon:0.01\n", "回合:331/400,奖励:-15.0,Epsilon:0.01\n", "回合:332/400,奖励:-15.0,Epsilon:0.01\n", "回合:333/400,奖励:-17.0,Epsilon:0.01\n", "回合:334/400,奖励:-17.0,Epsilon:0.01\n", "回合:335/400,奖励:-16.0,Epsilon:0.01\n", "回合:336/400,奖励:-24.0,Epsilon:0.01\n", "回合:337/400,奖励:-15.0,Epsilon:0.01\n", "回合:338/400,奖励:-18.0,Epsilon:0.01\n", "回合:339/400,奖励:-16.0,Epsilon:0.01\n", "回合:340/400,奖励:-15.0,Epsilon:0.01\n", "回合:341/400,奖励:-18.0,Epsilon:0.01\n", "回合:342/400,奖励:-15.0,Epsilon:0.01\n", "回合:343/400,奖励:-20.0,Epsilon:0.01\n", "回合:344/400,奖励:-18.0,Epsilon:0.01\n", "回合:345/400,奖励:-17.0,Epsilon:0.01\n", "回合:346/400,奖励:-19.0,Epsilon:0.01\n", "回合:347/400,奖励:-15.0,Epsilon:0.01\n", "回合:348/400,奖励:-15.0,Epsilon:0.01\n", "回合:349/400,奖励:-15.0,Epsilon:0.01\n", "回合:350/400,奖励:-18.0,Epsilon:0.01\n", "回合:351/400,奖励:-16.0,Epsilon:0.01\n", "回合:352/400,奖励:-16.0,Epsilon:0.01\n", "回合:353/400,奖励:-15.0,Epsilon:0.01\n", "回合:354/400,奖励:-20.0,Epsilon:0.01\n", "回合:355/400,奖励:-15.0,Epsilon:0.01\n", "回合:356/400,奖励:-17.0,Epsilon:0.01\n", "回合:357/400,奖励:-15.0,Epsilon:0.01\n", "回合:358/400,奖励:-17.0,Epsilon:0.01\n", "回合:359/400,奖励:-15.0,Epsilon:0.01\n", "回合:360/400,奖励:-16.0,Epsilon:0.01\n", "回合:361/400,奖励:-15.0,Epsilon:0.01\n", "回合:362/400,奖励:-18.0,Epsilon:0.01\n", "回合:363/400,奖励:-17.0,Epsilon:0.01\n", "回合:364/400,奖励:-22.0,Epsilon:0.01\n", "回合:365/400,奖励:-15.0,Epsilon:0.01\n", "回合:366/400,奖励:-15.0,Epsilon:0.01\n", "回合:367/400,奖励:-15.0,Epsilon:0.01\n", "回合:368/400,奖励:-16.0,Epsilon:0.01\n", "回合:369/400,奖励:-16.0,Epsilon:0.01\n", "回合:370/400,奖励:-15.0,Epsilon:0.01\n", "回合:371/400,奖励:-20.0,Epsilon:0.01\n", "回合:372/400,奖励:-15.0,Epsilon:0.01\n", "回合:373/400,奖励:-15.0,Epsilon:0.01\n", "回合:374/400,奖励:-15.0,Epsilon:0.01\n", "回合:375/400,奖励:-16.0,Epsilon:0.01\n", "回合:376/400,奖励:-15.0,Epsilon:0.01\n", "回合:377/400,奖励:-15.0,Epsilon:0.01\n", "回合:378/400,奖励:-17.0,Epsilon:0.01\n", "回合:379/400,奖励:-20.0,Epsilon:0.01\n", "回合:380/400,奖励:-17.0,Epsilon:0.01\n", "回合:381/400,奖励:-15.0,Epsilon:0.01\n", "回合:382/400,奖励:-15.0,Epsilon:0.01\n", "回合:383/400,奖励:-15.0,Epsilon:0.01\n", "回合:384/400,奖励:-15.0,Epsilon:0.01\n", "回合:385/400,奖励:-16.0,Epsilon:0.01\n", "回合:386/400,奖励:-15.0,Epsilon:0.01\n", "回合:387/400,奖励:-18.0,Epsilon:0.01\n", "回合:388/400,奖励:-15.0,Epsilon:0.01\n", "回合:389/400,奖励:-15.0,Epsilon:0.01\n", "回合:390/400,奖励:-15.0,Epsilon:0.01\n", "回合:391/400,奖励:-16.0,Epsilon:0.01\n", "回合:392/400,奖励:-18.0,Epsilon:0.01\n", "回合:393/400,奖励:-15.0,Epsilon:0.01\n", "回合:394/400,奖励:-15.0,Epsilon:0.01\n", "回合:395/400,奖励:-15.0,Epsilon:0.01\n", "回合:396/400,奖励:-20.0,Epsilon:0.01\n", "回合:397/400,奖励:-15.0,Epsilon:0.01\n", "回合:398/400,奖励:-15.0,Epsilon:0.01\n", "回合:399/400,奖励:-15.0,Epsilon:0.01\n", "回合:400/400,奖励:-15.0,Epsilon:0.01\n", "完成训练!\n" ] }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "开始测试!\n", "环境:CliffWalking-v0, 算法:Sarsa, 设备:cpu\n", "回合数:1/20, 奖励:-15.0\n", "回合数:2/20, 奖励:-15.0\n", "回合数:3/20, 奖励:-15.0\n", "回合数:4/20, 奖励:-15.0\n", "回合数:5/20, 奖励:-15.0\n", "回合数:6/20, 奖励:-15.0\n", "回合数:7/20, 奖励:-15.0\n", "回合数:8/20, 奖励:-15.0\n", "回合数:9/20, 奖励:-15.0\n", "回合数:10/20, 奖励:-15.0\n", "回合数:11/20, 奖励:-15.0\n", "回合数:12/20, 奖励:-15.0\n", "回合数:13/20, 奖励:-15.0\n", "回合数:14/20, 奖励:-15.0\n", "回合数:15/20, 奖励:-15.0\n", "回合数:16/20, 奖励:-15.0\n", "回合数:17/20, 奖励:-15.0\n", "回合数:18/20, 奖励:-15.0\n", "回合数:19/20, 奖励:-15.0\n", "回合数:20/20, 奖励:-15.0\n", "完成测试!\n" ] }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# 获取参数\n", "cfg = get_args() \n", "# 训练\n", "env, agent = env_agent_config(cfg)\n", "res_dic = train(cfg, env, agent)\n", " \n", "plot_rewards(res_dic['rewards'], cfg, tag=\"train\") \n", "# 测试\n", "res_dic = test(cfg, env, agent)\n", "plot_rewards(res_dic['rewards'], cfg, tag=\"test\") # 画出结果" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3.7.12 ('rl_tutorials')", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.12" }, "orig_nbformat": 4, "vscode": { "interpreter": { "hash": "4f613f1ab80ec98dc1b91d6e720de51301598a187317378e53e49b773c1123dd" } } }, "nbformat": 4, "nbformat_minor": 2 }