From 8447a28dcb539ffcae147ed9db5a671314a3dfad Mon Sep 17 00:00:00 2001
From: qiwang067
Date: Tue, 21 Jul 2020 00:02:45 +0800
Subject: [PATCH] fix some errors

---
 docs/code/Q-learning/agent.py     |  75 ------------
 docs/code/Q-learning/gridworld.py | 195 ------------------------------
 docs/code/Q-learning/train.py     |  90 --------------
 docs/code/Sarsa/agent.py          |  74 ------------
 docs/code/Sarsa/gridworld.py      | 195 ------------------------------
 docs/code/Sarsa/train.py          |  92 --------------
 6 files changed, 721 deletions(-)
 delete mode 100644 docs/code/Q-learning/agent.py
 delete mode 100644 docs/code/Q-learning/gridworld.py
 delete mode 100644 docs/code/Q-learning/train.py
 delete mode 100644 docs/code/Sarsa/agent.py
 delete mode 100644 docs/code/Sarsa/gridworld.py
 delete mode 100644 docs/code/Sarsa/train.py

diff --git a/docs/code/Q-learning/agent.py b/docs/code/Q-learning/agent.py
deleted file mode 100644
index 729c6d8..0000000
--- a/docs/code/Q-learning/agent.py
+++ /dev/null
@@ -1,75 +0,0 @@
-# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-# -*- coding: utf-8 -*-
-
-import numpy as np
-
-
-class QLearningAgent(object):
-    def __init__(self,
-                 obs_n,
-                 act_n,
-                 learning_rate=0.01,
-                 gamma=0.9,
-                 e_greed=0.1):
-        self.act_n = act_n  # action dimension, i.e. number of available actions
-        self.lr = learning_rate  # learning rate
-        self.gamma = gamma  # discount factor for rewards
-        self.epsilon = e_greed  # probability of picking a random action
-        self.Q = np.zeros((obs_n, act_n))
-
-    # Sample an action for the given observation, with exploration
-    def sample(self, obs):
-        if np.random.uniform(0, 1) < (1.0 - self.epsilon):  # pick the action according to the Q-table
-            action = self.predict(obs)
-        else:
-            action = np.random.choice(self.act_n)  # with probability epsilon, explore a random action
-        return action
-
-    # Predict the action for the given observation (greedy)
-    def predict(self, obs):
-        Q_list = self.Q[obs, :]
-        maxQ = np.max(Q_list)
-        action_list = np.where(Q_list == maxQ)[0]  # maxQ may correspond to several actions
-        action = np.random.choice(action_list)
-        return action
-
-    # Learning method, i.e. how the Q-table is updated
-    def learn(self, obs, action, reward, next_obs, done):
-        """ off-policy
-            obs: observation before the interaction, s_t
-            action: action chosen in this interaction, a_t
-            reward: reward r received for this action
-            next_obs: observation after the interaction, s_t+1
-            done: whether the episode has ended
-        """
-        predict_Q = self.Q[obs, action]
-        if done:
-            target_Q = reward  # there is no next state
-        else:
-            target_Q = reward + self.gamma * np.max(
-                self.Q[next_obs, :])  # Q-learning
-        self.Q[obs, action] += self.lr * (target_Q - predict_Q)  # correct Q
-
-    # Save the Q-table to a file
-    def save(self):
-        npy_file = './q_table.npy'
-        np.save(npy_file, self.Q)
-        print(npy_file + ' saved.')
-
-    # Load the Q-table back from a file
-    def restore(self, npy_file='./q_table.npy'):
-        self.Q = np.load(npy_file)
-        print(npy_file + ' loaded.')
\ No newline at end of file
diff --git a/docs/code/Q-learning/gridworld.py b/docs/code/Q-learning/gridworld.py
deleted file mode 100644
index 31d968f..0000000
--- a/docs/code/Q-learning/gridworld.py
+++ /dev/null
@@ -1,195 +0,0 @@
-# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-# -*- coding: utf-8 -*-
-
-import gym
-import turtle
-import numpy as np
-
-# turtle tutorial : https://docs.python.org/3.3/library/turtle.html
-
-
-def GridWorld(gridmap=None, is_slippery=False):
-    if gridmap is None:
-        gridmap = ['SFFF', 'FHFH', 'FFFH', 'HFFG']
-    env = gym.make("FrozenLake-v0", desc=gridmap, is_slippery=False)
-    env = FrozenLakeWapper(env)
-    return env
-
-
-class FrozenLakeWapper(gym.Wrapper):
-    def __init__(self, env):
-        gym.Wrapper.__init__(self, env)
-        self.max_y = env.desc.shape[0]
-        self.max_x = env.desc.shape[1]
-        self.t = None
-        self.unit = 50
-
-    def draw_box(self, x, y, fillcolor='', line_color='gray'):
-        self.t.up()
-        self.t.goto(x * self.unit, y * self.unit)
-        self.t.color(line_color)
-        self.t.fillcolor(fillcolor)
-        self.t.setheading(90)
-        self.t.down()
-        self.t.begin_fill()
-        for _ in range(4):
-            self.t.forward(self.unit)
-            self.t.right(90)
-        self.t.end_fill()
-
-    def move_player(self, x, y):
-        self.t.up()
-        self.t.setheading(90)
-        self.t.fillcolor('red')
-        self.t.goto((x + 0.5) * self.unit, (y + 0.5) * self.unit)
-
-    def render(self):
-        if self.t == None:
-            self.t = turtle.Turtle()
-            self.wn = turtle.Screen()
-            self.wn.setup(self.unit * self.max_x + 100,
-                          self.unit * self.max_y + 100)
-            self.wn.setworldcoordinates(0, 0, self.unit * self.max_x,
-                                        self.unit * self.max_y)
-            self.t.shape('circle')
-            self.t.width(2)
-            self.t.speed(0)
-            self.t.color('gray')
-            for i in range(self.desc.shape[0]):
-                for j in range(self.desc.shape[1]):
-                    x = j
-                    y = self.max_y - 1 - i
-                    if self.desc[i][j] == b'S':  # Start
-                        self.draw_box(x, y, 'white')
-                    elif self.desc[i][j] == b'F':  # Frozen ice
-                        self.draw_box(x, y, 'white')
-                    elif self.desc[i][j] == b'G':  # Goal
-                        self.draw_box(x, y, 'yellow')
-                    elif self.desc[i][j] == b'H':  # Hole
-                        self.draw_box(x, y, 'black')
-                    else:
-                        self.draw_box(x, y, 'white')
-            self.t.shape('turtle')
-
-        x_pos = self.s % self.max_x
-        y_pos = self.max_y - 1 - int(self.s / self.max_x)
-        self.move_player(x_pos, y_pos)
-
-
-class CliffWalkingWapper(gym.Wrapper):
-    def __init__(self, env):
-        gym.Wrapper.__init__(self, env)
-        self.t = None
-        self.unit = 50
-        self.max_x = 12
-        self.max_y = 4
-
-    def draw_x_line(self, y, x0, x1, color='gray'):
-        assert x1 > x0
-        self.t.color(color)
-        self.t.setheading(0)
-        self.t.up()
-        self.t.goto(x0, y)
-        self.t.down()
-        self.t.forward(x1 - x0)
-
-    def draw_y_line(self, x, y0, y1, color='gray'):
-        assert y1 > y0
-        self.t.color(color)
-        self.t.setheading(90)
-        self.t.up()
-        self.t.goto(x, y0)
-        self.t.down()
-        self.t.forward(y1 - y0)
-
-    def draw_box(self, x, y, fillcolor='', line_color='gray'):
-        self.t.up()
-        self.t.goto(x * self.unit, y * self.unit)
-        self.t.color(line_color)
-        self.t.fillcolor(fillcolor)
-        self.t.setheading(90)
-        self.t.down()
-        self.t.begin_fill()
-        for i in range(4):
-            self.t.forward(self.unit)
-            self.t.right(90)
-        self.t.end_fill()
-
-    def move_player(self, x, y):
-        self.t.up()
-        self.t.setheading(90)
-        self.t.fillcolor('red')
-        self.t.goto((x + 0.5) * self.unit, (y + 0.5) * self.unit)
-
-    def render(self):
-        if self.t == None:
-            self.t = turtle.Turtle()
-            self.wn = turtle.Screen()
-            self.wn.setup(self.unit * self.max_x + 100,
-                          self.unit * self.max_y + 100)
-            self.wn.setworldcoordinates(0, 0, self.unit * self.max_x,
-                                        self.unit * self.max_y)
-            self.t.shape('circle')
-            self.t.width(2)
-            self.t.speed(0)
-            self.t.color('gray')
-            for _ in range(2):
-                self.t.forward(self.max_x * self.unit)
-                self.t.left(90)
-                self.t.forward(self.max_y * self.unit)
-                self.t.left(90)
-            for i in range(1, self.max_y):
-                self.draw_x_line(
-                    y=i * self.unit, x0=0, x1=self.max_x * self.unit)
-            for i in range(1, self.max_x):
-                self.draw_y_line(
-                    x=i * self.unit, y0=0, y1=self.max_y * self.unit)
-
-            for i in range(1, self.max_x - 1):
-                self.draw_box(i, 0, 'black')
-            self.draw_box(self.max_x - 1, 0, 'yellow')
-            self.t.shape('turtle')
-
-        x_pos = self.s % self.max_x
-        y_pos = self.max_y - 1 - int(self.s / self.max_x)
-        self.move_player(x_pos, y_pos)
-
-
-if __name__ == '__main__':
-    # Environment 1: FrozenLake, the ice can be configured to be slippery or not
-    # 0 left, 1 down, 2 right, 3 up
-    env = gym.make("FrozenLake-v0", is_slippery=False)
-    env = FrozenLakeWapper(env)
-
-    # Environment 2: CliffWalking, the cliff environment
-    # env = gym.make("CliffWalking-v0")  # 0 up, 1 right, 2 down, 3 left
-    # env = CliffWalkingWapper(env)
-
-    # Environment 3: custom grid world with a configurable map; S = Start, F = Floor, H = Hole, G = Goal
-    # gridmap = [
-    #         'SFFF',
-    #         'FHFF',
-    #         'FFFF',
-    #         'HFGF' ]
-    # env = GridWorld(gridmap)
-
-    env.reset()
-    for step in range(10):
-        action = np.random.randint(0, 4)
-        obs, reward, done, info = env.step(action)
-        print('step {}: action {}, obs {}, reward {}, done {}, info {}'.format(\
-            step, action, obs, reward, done, info))
-        # env.render()  # render one frame
\ No newline at end of file
diff --git a/docs/code/Q-learning/train.py b/docs/code/Q-learning/train.py
deleted file mode 100644
index 032e2f9..0000000
--- a/docs/code/Q-learning/train.py
+++ /dev/null
@@ -1,90 +0,0 @@
-# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-# -*- coding: utf-8 -*-
-
-import gym
-from gridworld import CliffWalkingWapper, FrozenLakeWapper
-from agent import QLearningAgent
-import time
-
-
-def run_episode(env, agent, render=False):
-    total_steps = 0  # number of steps taken in this episode
-    total_reward = 0
-
-    obs = env.reset()  # reset the environment, i.e. start a new episode
-
-    while True:
-        action = agent.sample(obs)  # pick an action according to the algorithm
-        next_obs, reward, done, _ = env.step(action)  # one interaction with the environment
-        # train the Q-learning algorithm
-        agent.learn(obs, action, reward, next_obs, done)  # the next action is not needed
-
-        obs = next_obs  # store the previous observation
-        total_reward += reward
-        total_steps += 1  # count the steps
-        if render:
-            env.render()  # render a new frame
-        if done:
-            break
-    return total_reward, total_steps
-
-
-def test_episode(env, agent):
-    total_reward = 0
-    obs = env.reset()
-    while True:
-        action = agent.predict(obs)  # greedy
-        next_obs, reward, done, _ = env.step(action)
-        total_reward += reward
-        obs = next_obs
-        time.sleep(0.5)
-        env.render()
-        if done:
-            print('test reward = %.1f' % (total_reward))
-            break
-
-
-def main():
-    # env = gym.make("FrozenLake-v0", is_slippery=False)  # 0 left, 1 down, 2 right, 3 up
-    # env = FrozenLakeWapper(env)
-
-    env = gym.make("CliffWalking-v0")  # 0 up, 1 right, 2 down, 3 left
-    env = CliffWalkingWapper(env)
-
-    agent = QLearningAgent(
-        obs_n=env.observation_space.n,
-        act_n=env.action_space.n,
-        learning_rate=0.1,
-        gamma=0.9,
-        e_greed=0.1)
-
-    is_render = False
-    for episode in range(500):
-        ep_reward, ep_steps = run_episode(env, agent, is_render)
-        print('Episode %s: steps = %s , reward = %.1f' % (episode, ep_steps,
-                                                          ep_reward))
-
-        # render every 20 episodes to check the result
-        if episode % 20 == 0:
-            is_render = True
-        else:
-            is_render = False
-    # training finished, check how well the algorithm does
-    test_episode(env, agent)
-
-
-if __name__ == "__main__":
-    main()
\ No newline at end of file
diff --git a/docs/code/Sarsa/agent.py b/docs/code/Sarsa/agent.py
deleted file mode 100644
index f791293..0000000
--- a/docs/code/Sarsa/agent.py
+++ /dev/null
@@ -1,74 +0,0 @@
-# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-# -*- coding: utf-8 -*-
-
-import numpy as np
-
-# chooses actions according to the Q-table
-class SarsaAgent(object):
-    def __init__(self,
-                 obs_n,
-                 act_n,
-                 learning_rate=0.01,
-                 gamma=0.9,
-                 e_greed=0.1):
-        self.act_n = act_n  # action dimension, i.e. number of available actions
-        self.lr = learning_rate  # learning rate
-        self.gamma = gamma  # discount factor for rewards
-        self.epsilon = e_greed  # probability of picking a random action
-        self.Q = np.zeros((obs_n, act_n))  # initialize the Q-table
-
-    # Sample an action for the given observation, with exploration (epsilon-greedy; used during training)
-    def sample(self, obs):
-        if np.random.uniform(0, 1) < (1.0 - self.epsilon):  # pick the action according to the Q-table
-            action = self.predict(obs)
-        else:
-            action = np.random.choice(self.act_n)  # with probability epsilon, explore a random action
-        return action
-
-    # Predict the action for the given observation (pick the largest known Q-value; greedy, exploitation only, no exploration)
-    def predict(self, obs):
-        Q_list = self.Q[obs, :]
-        maxQ = np.max(Q_list)  # find the maximum Q-value
-        action_list = np.where(Q_list == maxQ)[0]  # maxQ may correspond to several actions
-        action = np.random.choice(action_list)  # pick one of these actions at random (feel free to print them out)
-        return action
-
-    # Learning method, i.e. how the Q-table is updated
-    def learn(self, obs, action, reward, next_obs, next_action, done):
-        """ on-policy
-            obs: observation before the interaction, s_t
-            action: action chosen in this interaction, a_t
-            reward: reward r received for this action
-            next_obs: observation after the interaction, s_t+1
-            next_action: action a_t+1 that will be chosen for next_obs according to the current Q-table
-            done: whether the episode has ended
-        """
-        predict_Q = self.Q[obs, action]
-        if done:  # if done is true, this is the last state of the episode
-            target_Q = reward  # there is no next state
-        else:
-            target_Q = reward + self.gamma * self.Q[next_obs,
-                                                    next_action]  # Sarsa
-        self.Q[obs, action] += self.lr * (target_Q - predict_Q)  # correct Q
-
-    def save(self):
-        npy_file = './q_table.npy'
-        np.save(npy_file, self.Q)
-        print(npy_file + ' saved.')
-
-    def restore(self, npy_file='./q_table.npy'):
-        self.Q = np.load(npy_file)
-        print(npy_file + ' loaded.')
\ No newline at end of file
diff --git a/docs/code/Sarsa/gridworld.py b/docs/code/Sarsa/gridworld.py
deleted file mode 100644
index 31d968f..0000000
--- a/docs/code/Sarsa/gridworld.py
+++ /dev/null
@@ -1,195 +0,0 @@
-# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-# -*- coding: utf-8 -*-
-
-import gym
-import turtle
-import numpy as np
-
-# turtle tutorial : https://docs.python.org/3.3/library/turtle.html
-
-
-def GridWorld(gridmap=None, is_slippery=False):
-    if gridmap is None:
-        gridmap = ['SFFF', 'FHFH', 'FFFH', 'HFFG']
-    env = gym.make("FrozenLake-v0", desc=gridmap, is_slippery=False)
-    env = FrozenLakeWapper(env)
-    return env
-
-
-class FrozenLakeWapper(gym.Wrapper):
-    def __init__(self, env):
-        gym.Wrapper.__init__(self, env)
-        self.max_y = env.desc.shape[0]
-        self.max_x = env.desc.shape[1]
-        self.t = None
-        self.unit = 50
-
-    def draw_box(self, x, y, fillcolor='', line_color='gray'):
-        self.t.up()
-        self.t.goto(x * self.unit, y * self.unit)
-        self.t.color(line_color)
-        self.t.fillcolor(fillcolor)
-        self.t.setheading(90)
-        self.t.down()
-        self.t.begin_fill()
-        for _ in range(4):
-            self.t.forward(self.unit)
-            self.t.right(90)
-        self.t.end_fill()
-
-    def move_player(self, x, y):
-        self.t.up()
-        self.t.setheading(90)
-        self.t.fillcolor('red')
-        self.t.goto((x + 0.5) * self.unit, (y + 0.5) * self.unit)
-
-    def render(self):
-        if self.t == None:
-            self.t = turtle.Turtle()
-            self.wn = turtle.Screen()
-            self.wn.setup(self.unit * self.max_x + 100,
-                          self.unit * self.max_y + 100)
-            self.wn.setworldcoordinates(0, 0, self.unit * self.max_x,
-                                        self.unit * self.max_y)
-            self.t.shape('circle')
-            self.t.width(2)
-            self.t.speed(0)
-            self.t.color('gray')
-            for i in range(self.desc.shape[0]):
-                for j in range(self.desc.shape[1]):
-                    x = j
-                    y = self.max_y - 1 - i
-                    if self.desc[i][j] == b'S':  # Start
-                        self.draw_box(x, y, 'white')
-                    elif self.desc[i][j] == b'F':  # Frozen ice
-                        self.draw_box(x, y, 'white')
-                    elif self.desc[i][j] == b'G':  # Goal
-                        self.draw_box(x, y, 'yellow')
-                    elif self.desc[i][j] == b'H':  # Hole
-                        self.draw_box(x, y, 'black')
-                    else:
-                        self.draw_box(x, y, 'white')
-            self.t.shape('turtle')
-
-        x_pos = self.s % self.max_x
-        y_pos = self.max_y - 1 - int(self.s / self.max_x)
-        self.move_player(x_pos, y_pos)
-
-
-class CliffWalkingWapper(gym.Wrapper):
-    def __init__(self, env):
-        gym.Wrapper.__init__(self, env)
-        self.t = None
-        self.unit = 50
-        self.max_x = 12
-        self.max_y = 4
-
-    def draw_x_line(self, y, x0, x1, color='gray'):
-        assert x1 > x0
-        self.t.color(color)
-        self.t.setheading(0)
-        self.t.up()
-        self.t.goto(x0, y)
-        self.t.down()
-        self.t.forward(x1 - x0)
-
-    def draw_y_line(self, x, y0, y1, color='gray'):
-        assert y1 > y0
-        self.t.color(color)
-        self.t.setheading(90)
-        self.t.up()
-        self.t.goto(x, y0)
-        self.t.down()
-        self.t.forward(y1 - y0)
-
-    def draw_box(self, x, y, fillcolor='', line_color='gray'):
-        self.t.up()
-        self.t.goto(x * self.unit, y * self.unit)
-        self.t.color(line_color)
-        self.t.fillcolor(fillcolor)
-        self.t.setheading(90)
-        self.t.down()
-        self.t.begin_fill()
-        for i in range(4):
-            self.t.forward(self.unit)
-            self.t.right(90)
-        self.t.end_fill()
-
-    def move_player(self, x, y):
-        self.t.up()
-        self.t.setheading(90)
-        self.t.fillcolor('red')
-        self.t.goto((x + 0.5) * self.unit, (y + 0.5) * self.unit)
-
-    def render(self):
-        if self.t == None:
-            self.t = turtle.Turtle()
-            self.wn = turtle.Screen()
-            self.wn.setup(self.unit * self.max_x + 100,
-                          self.unit * self.max_y + 100)
-            self.wn.setworldcoordinates(0, 0, self.unit * self.max_x,
-                                        self.unit * self.max_y)
-            self.t.shape('circle')
-            self.t.width(2)
-            self.t.speed(0)
-            self.t.color('gray')
-            for _ in range(2):
-                self.t.forward(self.max_x * self.unit)
-                self.t.left(90)
-                self.t.forward(self.max_y * self.unit)
-                self.t.left(90)
-            for i in range(1, self.max_y):
-                self.draw_x_line(
-                    y=i * self.unit, x0=0, x1=self.max_x * self.unit)
-            for i in range(1, self.max_x):
-                self.draw_y_line(
-                    x=i * self.unit, y0=0, y1=self.max_y * self.unit)
-
-            for i in range(1, self.max_x - 1):
-                self.draw_box(i, 0, 'black')
-            self.draw_box(self.max_x - 1, 0, 'yellow')
-            self.t.shape('turtle')
-
-        x_pos = self.s % self.max_x
-        y_pos = self.max_y - 1 - int(self.s / self.max_x)
-        self.move_player(x_pos, y_pos)
-
-
-if __name__ == '__main__':
-    # Environment 1: FrozenLake, the ice can be configured to be slippery or not
-    # 0 left, 1 down, 2 right, 3 up
-    env = gym.make("FrozenLake-v0", is_slippery=False)
-    env = FrozenLakeWapper(env)
-
-    # Environment 2: CliffWalking, the cliff environment
-    # env = gym.make("CliffWalking-v0")  # 0 up, 1 right, 2 down, 3 left
-    # env = CliffWalkingWapper(env)
-
-    # Environment 3: custom grid world with a configurable map; S = Start, F = Floor, H = Hole, G = Goal
-    # gridmap = [
-    #         'SFFF',
-    #         'FHFF',
-    #         'FFFF',
-    #         'HFGF' ]
-    # env = GridWorld(gridmap)
-
-    env.reset()
-    for step in range(10):
-        action = np.random.randint(0, 4)
-        obs, reward, done, info = env.step(action)
-        print('step {}: action {}, obs {}, reward {}, done {}, info {}'.format(\
-            step, action, obs, reward, done, info))
-        # env.render()  # render one frame
\ No newline at end of file
diff --git a/docs/code/Sarsa/train.py b/docs/code/Sarsa/train.py
deleted file mode 100644
index d390b9d..0000000
--- a/docs/code/Sarsa/train.py
+++ /dev/null
@@ -1,92 +0,0 @@
-# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-# -*- coding: utf-8 -*-
-
-import gym
-from gridworld import CliffWalkingWapper, FrozenLakeWapper
-from agent import SarsaAgent
-import time
-
-
-def run_episode(env, agent, render=False):
-    total_steps = 0  # number of steps taken in this episode
-    total_reward = 0
-
-    obs = env.reset()  # reset the environment, i.e. start a new episode
-    action = agent.sample(obs)  # pick an action according to the algorithm
-
-    while True:
-        next_obs, reward, done, _ = env.step(action)  # one interaction with the environment
-        next_action = agent.sample(next_obs)  # pick the next action according to the algorithm
-        # train the Sarsa algorithm
-        agent.learn(obs, action, reward, next_obs, next_action, done)
-
-        action = next_action
-        obs = next_obs  # store the previous observation
-        total_reward += reward
-        total_steps += 1  # count the steps
-        if render:
-            env.render()  # render a new frame
-        if done:
-            break
-    return total_reward, total_steps
-
-
-def test_episode(env, agent):
-    total_reward = 0
-    obs = env.reset()
-    while True:
-        action = agent.predict(obs)  # greedy: only take the best action
-        next_obs, reward, done, _ = env.step(action)
-        total_reward += reward
-        obs = next_obs
-        time.sleep(0.5)  # wait 0.5 s per step so the behaviour is easier to watch
-        env.render()
-        if done:
-            print('test reward = %.1f' % (total_reward))
-            break
-
-
-def main():
-    # env = gym.make("FrozenLake-v0", is_slippery=False)  # 0 left, 1 down, 2 right, 3 up
-    # env = FrozenLakeWapper(env)
-
-    env = gym.make("CliffWalking-v0")  # 0 up, 1 right, 2 down, 3 left
-    env = CliffWalkingWapper(env)  # optional; it only makes the rendering look nicer
-
-    agent = SarsaAgent(
-        obs_n=env.observation_space.n,
-        act_n=env.action_space.n,
-        learning_rate=0.1,
-        gamma=0.9,
-        e_greed=0.1)
-
-    is_render = False
-    for episode in range(500):
-        ep_reward, ep_steps = run_episode(env, agent, is_render)
-        print('Episode %s: steps = %s , reward = %.1f' % (episode, ep_steps,
-                                                          ep_reward))
-
-        # render every 20 episodes to check the result (rendering every episode would take quite long)
-        if episode % 20 == 0:
-            is_render = True
-        else:
-            is_render = False
-    # training finished, check how well the algorithm does
-    test_episode(env, agent)
-
-
-if __name__ == "__main__":
-    main()
\ No newline at end of file