add Sarsa agent code (docs/code/Sarsa/agent.py)
74
docs/code/Sarsa/agent.py
Normal file
@@ -0,0 +1,74 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# -*- coding: utf-8 -*-

import numpy as np


# Selects actions according to the Q-table
class SarsaAgent(object):
    def __init__(self,
                 obs_n,
                 act_n,
                 learning_rate=0.01,
                 gamma=0.9,
                 e_greed=0.1):
        self.act_n = act_n  # dimension of the action space (number of available actions)
        self.lr = learning_rate  # learning rate
        self.gamma = gamma  # discount factor for rewards
        self.epsilon = e_greed  # probability of choosing a random action
        self.Q = np.zeros((obs_n, act_n))  # initialize the Q-table

    # Sample an action for the given observation, with exploration (epsilon-greedy; use this method during training)
    def sample(self, obs):
        if np.random.uniform(0, 1) < (1.0 - self.epsilon):  # choose the action from the Q-table
            action = self.predict(obs)
        else:
            action = np.random.choice(self.act_n)  # with probability epsilon, explore by picking a random action
        return action

    # Predict the action for the given observation (greedy: pick the action with the largest Q value; exploitation only, no exploration)
    def predict(self, obs):
        Q_list = self.Q[obs, :]
        maxQ = np.max(Q_list)  # the largest Q value in this row
        action_list = np.where(Q_list == maxQ)[0]  # maxQ may correspond to several actions
        action = np.random.choice(action_list)  # pick one of these actions at random (try printing them to see)
        return action

    # Learning method, i.e. how the Q-table is updated
    def learn(self, obs, action, reward, next_obs, next_action, done):
        """ on-policy
            obs: observation before the interaction, s_t
            action: action chosen for this interaction, a_t
            reward: reward obtained from this action, r
            next_obs: observation after the interaction, s_t+1
            next_action: action that will be chosen for next_obs under the current Q-table, a_t+1
            done: whether the episode has ended
        """
        predict_Q = self.Q[obs, action]
        if done:  # if done is True, this is the last state of the episode
            target_Q = reward  # there is no next state
        else:
            target_Q = reward + self.gamma * self.Q[next_obs,
                                                    next_action]  # Sarsa
        self.Q[obs, action] += self.lr * (target_Q - predict_Q)  # update the Q value

    def save(self):
        npy_file = './q_table.npy'
        np.save(npy_file, self.Q)
        print(npy_file + ' saved.')

    def restore(self, npy_file='./q_table.npy'):
        self.Q = np.load(npy_file)
        print(npy_file + ' loaded.')
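For context, a minimal usage sketch (not part of this commit) showing how this SarsaAgent could be driven in an on-policy training loop. The gym dependency, the CliffWalking-v0 environment, the episode count, and the learning_rate=0.1 setting are illustrative assumptions rather than anything taken from the repository; it also assumes the classic gym API where reset() returns an integer observation and step() returns (obs, reward, done, info).

# Usage sketch (NOT part of this commit): drive SarsaAgent in a training loop.
# Assumes the classic `gym` API and a discrete environment such as CliffWalking-v0;
# any env with discrete observation/action spaces works the same way.
import gym

from agent import SarsaAgent  # assumes this script sits next to agent.py

env = gym.make('CliffWalking-v0')
agent = SarsaAgent(
    obs_n=env.observation_space.n,
    act_n=env.action_space.n,
    learning_rate=0.1,
    gamma=0.9,
    e_greed=0.1)

for episode in range(500):
    obs = env.reset()
    action = agent.sample(obs)  # epsilon-greedy action for the first state, a_t
    total_reward = 0
    while True:
        next_obs, reward, done, _ = env.step(action)
        next_action = agent.sample(next_obs)  # a_t+1 is also sampled on-policy
        # Sarsa moves Q(s_t, a_t) towards r + gamma * Q(s_t+1, a_t+1)
        agent.learn(obs, action, reward, next_obs, next_action, done)
        obs, action = next_obs, next_action
        total_reward += reward
        if done:
            break
    print('Episode %d, total reward %.1f' % (episode, total_reward))

agent.save()  # writes ./q_table.npy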