update Sarsa

JohnJim0816
2021-03-12 17:19:04 +08:00
parent 1fdcfbbd66
commit f1394feb65
10 changed files with 147 additions and 357 deletions


@@ -1,74 +1,52 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -*- coding: utf-8 -*-
#!/usr/bin/env python
# coding=utf-8
'''
Author: John
Email: johnjim0816@gmail.com
Date: 2021-03-12 16:58:16
LastEditor: John
LastEditTime: 2021-03-12 17:03:05
Description:
Environment:
'''
import numpy as np
# Choose actions according to the Q-table
class SarsaAgent(object):
from collections import defaultdict
import torch
class Sarsa(object):
def __init__(self,
obs_n,
act_n,
learning_rate=0.01,
gamma=0.9,
e_greed=0.1):
self.act_n = act_n # action dimension, i.e. how many actions are available
self.lr = learning_rate # learning rate
self.gamma = gamma # reward discount factor
self.epsilon = e_greed # probability of picking a random action
self.Q = np.zeros((obs_n, act_n)) # initialize the Q-table
# Sample an action for the given observation, with exploration (epsilon-greedy; used during training)
def sample(self, obs):
if np.random.uniform(0, 1) < (1.0 - self.epsilon): # pick the action according to the Q-values in the table
action = self.predict(obs)
else:
action = np.random.choice(self.act_n) # with some probability, explore by picking a random action
n_actions, sarsa_cfg):
self.n_actions = n_actions # number of actions
self.lr = sarsa_cfg.lr # learning rate
self.gamma = sarsa_cfg.gamma # reward discount factor
self.epsilon = sarsa_cfg.epsilon # epsilon-greedy exploration rate
self.Q = defaultdict(lambda: np.zeros(n_actions)) # Q-table, one row created lazily per visited state
# self.Q = np.zeros((n_states, n_actions)) # alternative: dense Q-table
def choose_action(self, state):
best_action = np.argmax(self.Q[state])
# epsilon-greedy: every action gets probability epsilon / n_actions,
# and the greedy action gets the remaining 1 - epsilon
action_probs = np.ones(self.n_actions, dtype=float) * self.epsilon / self.n_actions
action_probs[best_action] += (1.0 - self.epsilon)
action = np.random.choice(np.arange(len(action_probs)), p=action_probs)
return action
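# A worked example of the epsilon-greedy distribution above (illustration only, not part of the commit):
# with n_actions = 4 and epsilon = 0.1, every action first gets 0.1 / 4 = 0.025, and the greedy action
# gets an extra 1 - 0.1 = 0.9, so the sampling probabilities are [0.925, 0.025, 0.025, 0.025]
# when action 0 is the greedy one.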
# Predict the action for the given observation by picking the one with the largest Q-value (greedy: exploitation only, no exploration)
def predict(self, obs):
Q_list = self.Q[obs, :]
maxQ = np.max(Q_list) # find the maximum Q-value
action_list = np.where(Q_list == maxQ)[0] # maxQ may correspond to several actions
action = np.random.choice(action_list) # break ties by picking one of them at random
return action
# Learning method, i.e. how the Q-table gets updated
def learn(self, obs, action, reward, next_obs, next_action, done):
""" on-policy
obs: observation before the interaction, s_t
action: action chosen for this interaction, a_t
reward: reward r received for this action
next_obs: observation after the interaction, s_t+1
next_action: action that will be chosen for next_obs under the current Q-table, a_t+1
done: whether the episode has ended
"""
predict_Q = self.Q[obs, action]
if done: # done == True means this is the last state of the episode
target_Q = reward # there is no next state
def update(self, state, action, reward, next_state, next_action, done):
Q_predict = self.Q[state][action]
if done:
Q_target = reward # terminal state
else:
target_Q = reward + self.gamma * self.Q[next_obs, next_action] # Sarsa target
self.Q[obs, action] += self.lr * (target_Q - predict_Q) # update Q towards the target
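# The two branches above implement the standard Sarsa update; in equation form,
# for a non-terminal next state (sketch for reference, not part of the commit):
#   Q(s_t, a_t) <- Q(s_t, a_t) + lr * (r + gamma * Q(s_t+1, a_t+1) - Q(s_t, a_t))
# Unlike Q-learning, the target uses the action a_t+1 actually selected by the
# behaviour policy, which is what makes Sarsa on-policy.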
def save(self):
npy_file = './q_table.npy'
np.save(npy_file, self.Q)
print(npy_file + ' saved.')
def restore(self, npy_file='./q_table.npy'):
self.Q = np.load(npy_file)
print(npy_file + ' loaded.')
Q_target = reward + self.gamma * self.Q[next_state][next_action] # on-policy: use the action actually taken next
self.Q[state][action] += self.lr * (Q_target - Q_predict) # move Q(s,a) towards the TD target
def save(self, path):
'''Save the Q-table data to a file
'''
import dill # dill can pickle the defaultdict's lambda default factory, which the standard pickle cannot
torch.save(
obj=self.Q,
f=path+"Sarsa_model.pkl",
pickle_module=dill
)
def load(self, path):
'''Load the data from the file back into the Q-table
'''
import dill
self.Q = torch.load(f=path+'Sarsa_model.pkl', pickle_module=dill)
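For context, a minimal sketch of how this agent could be driven in an on-policy training loop. It assumes a gym-style discrete environment such as CliffWalking-v0 and a small config object exposing the lr, gamma and epsilon fields that sarsa_cfg is read for above; the environment name and the SarsaConfig class are illustrative, not part of this commit:

import gym

class SarsaConfig:  # hypothetical container for the fields used above
    lr = 0.1
    gamma = 0.9
    epsilon = 0.1

env = gym.make('CliffWalking-v0')
agent = Sarsa(n_actions=env.action_space.n, sarsa_cfg=SarsaConfig())
for i_episode in range(300):
    state = env.reset()
    action = agent.choose_action(state)  # a_t from the epsilon-greedy policy
    done = False
    while not done:
        next_state, reward, done, _ = env.step(action)
        next_action = agent.choose_action(next_state)  # a_t+1 from the same policy (on-policy)
        agent.update(state, action, reward, next_state, next_action, done)
        state, action = next_state, next_action
agent.save(path='./')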