update Sarsa
This commit is contained in:
@@ -1,74 +1,52 @@
-# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-# -*- coding: utf-8 -*-
-
+#!/usr/bin/env python
+# coding=utf-8
+'''
+Author: John
+Email: johnjim0816@gmail.com
+Date: 2021-03-12 16:58:16
+LastEditor: John
+LastEditTime: 2021-03-12 17:03:05
+Description:
+Environment:
+'''
 import numpy as np
-
-# choose actions according to the Q table
-class SarsaAgent(object):
+from collections import defaultdict
+import torch
+class Sarsa(object):
     def __init__(self,
-                 obs_n,
-                 act_n,
-                 learning_rate=0.01,
-                 gamma=0.9,
-                 e_greed=0.1):
-        self.act_n = act_n  # number of available actions
-        self.lr = learning_rate  # learning rate
-        self.gamma = gamma  # reward discount factor
-        self.epsilon = e_greed  # probability of picking a random action
-        self.Q = np.zeros((obs_n, act_n))  # initialize the Q table
-
-    # Sample an action for the given observation, with exploration (epsilon-greedy; used during training)
-    def sample(self, obs):
-        if np.random.uniform(0, 1) < (1.0 - self.epsilon):  # pick the action from the Q table
-            action = self.predict(obs)
-        else:
-            action = np.random.choice(self.act_n)  # with some probability, explore by picking a random action
+                 n_actions, sarsa_cfg):
+        self.n_actions = n_actions  # number of actions
+        self.lr = sarsa_cfg.lr  # learning rate
+        self.gamma = sarsa_cfg.gamma
+        self.epsilon = sarsa_cfg.epsilon
+        self.Q = defaultdict(lambda: np.zeros(n_actions))
+        # self.Q = np.zeros((n_states, n_actions))  # Q table
+    def choose_action(self, state):
+        best_action = np.argmax(self.Q[state])
+        # action = best_action
+        action_probs = np.ones(self.n_actions, dtype=float) * self.epsilon / self.n_actions
+        action_probs[best_action] += (1.0 - self.epsilon)
+        action = np.random.choice(np.arange(len(action_probs)), p=action_probs)
         return action
-
-    # Predict the action for the given observation greedily (take the argmax of the existing Q values: exploitation only, no exploration)
-    def predict(self, obs):
-        Q_list = self.Q[obs, :]
-        maxQ = np.max(Q_list)  # maximum Q value
-        action_list = np.where(Q_list == maxQ)[0]  # maxQ may correspond to several actions
-        action = np.random.choice(action_list)  # pick one of those actions at random
-        return action
-
-    # Learning method, i.e. how the Q-table is updated
-    def learn(self, obs, action, reward, next_obs, next_action, done):
-        """ on-policy
-        obs: observation before the interaction, s_t
-        action: action chosen in this interaction, a_t
-        reward: reward r obtained for this action
-        next_obs: observation after the interaction, s_t+1
-        next_action: action that will be chosen for next_obs according to the current Q table, a_t+1
-        done: whether the episode has ended
-        """
-        predict_Q = self.Q[obs, action]
-        if done:  # done == True means this is the last state of the episode
-            target_Q = reward  # there is no next state
+    def update(self, state, action, reward, next_state, next_action, done):
+        Q_predict = self.Q[state][action]
+        if done:
+            Q_target = reward  # terminal state
         else:
-            target_Q = reward + self.gamma * self.Q[next_obs,
-                                                    next_action]  # Sarsa
-        self.Q[obs, action] += self.lr * (target_Q - predict_Q)  # correct Q
-
-    def save(self):
-        npy_file = './q_table.npy'
-        np.save(npy_file, self.Q)
-        print(npy_file + ' saved.')
-
-    def restore(self, npy_file='./q_table.npy'):
-        self.Q = np.load(npy_file)
-        print(npy_file + ' loaded.')
+            Q_target = reward + self.gamma * self.Q[next_state][next_action]
+        self.Q[state][action] += self.lr * (Q_target - Q_predict)
+    def save(self, path):
+        '''Save the Q table to a file
+        '''
+        import dill
+        torch.save(
+            obj=self.Q,
+            f=path + "Sarsa_model.pkl",
+            pickle_module=dill
+        )
+    def load(self, path):
+        '''Load the Q table from a file
+        '''
+        import dill
+        self.Q = torch.load(f=path + 'Sarsa_model.pkl', pickle_module=dill)
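For context, a minimal sketch of how the new Sarsa class might be driven in a training loop. The gym environment, the SimpleNamespace config object, the hyperparameter values, and the episode count are illustrative assumptions and are not part of this commit; the classic gym API (reset returning a state, step returning a 4-tuple) is also assumed.

import gym
from types import SimpleNamespace

env = gym.make('CliffWalking-v0')  # assumed: any env with discrete states and actions
cfg = SimpleNamespace(lr=0.1, gamma=0.9, epsilon=0.1)  # fields read by Sarsa.__init__ (values assumed)
agent = Sarsa(n_actions=env.action_space.n, sarsa_cfg=cfg)

for episode in range(300):
    state = env.reset()
    action = agent.choose_action(state)  # epsilon-greedy w.r.t. the current Q table
    done = False
    while not done:
        next_state, reward, done, _ = env.step(action)
        next_action = agent.choose_action(next_state)  # on-policy: a' drawn from the same policy
        # Sarsa update: Q(s,a) += lr * (r + gamma * Q(s',a') - Q(s,a))
        agent.update(state, action, reward, next_state, next_action, done)
        state, action = next_state, next_action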