update
@@ -1,3 +1,14 @@
+#!/usr/bin/env python
+# coding=utf-8
+'''
+Author: John
+Email: johnjim0816@gmail.com
+Date: 2020-09-11 23:03:00
+LastEditor: John
+LastEditTime: 2020-10-07 20:48:29
+Description:
+Environment:
+'''
 # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -12,64 +23,72 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

 # -*- coding: utf-8 -*-

 import numpy as np
+import math


-class QLearningAgent(object):
+class QLearning(object):
     def __init__(self,
-                 obs_n,
-                 act_n,
+                 obs_dim,
+                 action_dim,
                  learning_rate=0.01,
                  gamma=0.9,
-                 e_greed=0.1):
-        self.act_n = act_n  # action dimension, i.e. how many actions can be chosen
+                 epsilon_start=0.9, epsilon_end=0.1, epsilon_decay=200):
+        self.action_dim = action_dim  # action dimension, i.e. how many actions can be chosen
         self.lr = learning_rate  # learning rate
-        self.gamma = gamma  # discount factor for rewards
-        self.epsilon = e_greed  # probability of choosing a random action
-        self.Q = np.zeros((obs_n, act_n))
+        self.gamma = gamma  # discount factor for rewards
+        self.epsilon = 0  # probability of choosing a random action (e-greedy policy); epsilon decays over time
+        self.sample_count = 0  # epsilon decays with the number of training samples, so keep a counter
+        self.epsilon_start = epsilon_start
+        self.epsilon_end = epsilon_end
+        self.epsilon_decay = epsilon_decay
+        self.Q_table = np.zeros((obs_dim, action_dim))  # Q table
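For reference, the reworked constructor would now be created along these lines; the 48-state / 4-action sizes are only an assumed CliffWalking-style layout for illustration, not something this commit fixes:

agent = QLearning(obs_dim=48,        # e.g. a 4x12 grid world, assumed for illustration
                  action_dim=4,      # up / down / left / right
                  learning_rate=0.1,
                  gamma=0.9,
                  epsilon_start=0.9, epsilon_end=0.1, epsilon_decay=200)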
-    # sample an action for the given observation, with exploration
     def sample(self, obs):
-        if np.random.uniform(0, 1) < (1.0 - self.epsilon):  # choose the action from the Q-table values
+        '''Sample an action for the given observation, with exploration; used while training the model.
+        '''
+        self.sample_count += 1
+        self.epsilon = self.epsilon_end + (self.epsilon_start - self.epsilon_end) * \
+            math.exp(-1. * self.sample_count / self.epsilon_decay)
+        if np.random.uniform(0, 1) > self.epsilon:  # draw a value in [0, 1); if it exceeds epsilon act greedily, otherwise act randomly
             action = self.predict(obs)
         else:
-            action = np.random.choice(self.act_n)  # with some probability, explore by picking a random action
+            action = np.random.choice(self.action_dim)  # with some probability, explore by picking a random action
         return action
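The fixed e_greed probability is replaced by an exponential decay from epsilon_start down to epsilon_end. A standalone sketch of the same rule, with illustrative sample counts (the helper name is made up):

import math

def epsilon_at(t, eps_start=0.9, eps_end=0.1, eps_decay=200):
    # identical to the expression used in sample()
    return eps_end + (eps_start - eps_end) * math.exp(-1. * t / eps_decay)

for t in (1, 100, 200, 500, 1000):
    print(t, round(epsilon_at(t), 3))
# exploration starts near 0.9, drops to about 0.39 after 200 samples,
# and settles towards the 0.1 floor as training goes on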
-    # predict the action for the given observation, without exploration
     def predict(self, obs):
-        Q_list = self.Q[obs, :]
-        maxQ = np.max(Q_list)
-        action_list = np.where(Q_list == maxQ)[0]  # maxQ may correspond to several actions
-        action = np.random.choice(action_list)
+        '''Predict the action for the given observation, without exploration; used while testing the model.
+        '''
+        Q_list = self.Q_table[obs, :]
+        Q_max = np.max(Q_list)
+        action_list = np.where(Q_list == Q_max)[0]
+        action = np.random.choice(action_list)  # Q_max may correspond to several actions, so pick one of them at random
         return action
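predict() keeps the old tie-breaking behaviour: when several actions share the maximum Q value, one of them is chosen at random rather than always taking the first argmax. A quick check on a hand-made Q row (toy numbers):

import numpy as np

Q_row = np.array([0.5, 0.2, 0.5, 0.1])              # actions 0 and 2 tie for the maximum
candidates = np.where(Q_row == np.max(Q_row))[0]    # -> array([0, 2])
action = np.random.choice(candidates)               # 0 or 2, each with probability 0.5
print(candidates, action)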
-    # learning method, i.e. how the Q table is updated
     def learn(self, obs, action, reward, next_obs, done):
-        """ off-policy
-            obs: observation before the interaction, s_t
-            action: action chosen in this interaction, a_t
-            reward: reward obtained for this action, r
-            next_obs: observation after the interaction, s_t+1
-            done: whether the episode has ended
-        """
-        predict_Q = self.Q[obs, action]
+        '''Learning method (off-policy), i.e. how the Q table is updated
+        Args:
+            obs: observation before the interaction, s_t
+            action: action chosen in this interaction, a_t
+            reward: reward obtained for this action, r
+            next_obs: observation after the interaction, s_t+1
+            done: whether the episode has ended
+        '''
+        Q_predict = self.Q_table[obs, action]
         if done:
-            target_Q = reward  # there is no next state
+            Q_target = reward  # there is no next state
         else:
-            target_Q = reward + self.gamma * np.max(
-                self.Q[next_obs, :])  # Q-learning
-        self.Q[obs, action] += self.lr * (target_Q - predict_Q)  # correct Q
+            Q_target = reward + self.gamma * np.max(
+                self.Q_table[next_obs, :])  # Q-learning
+        self.Q_table[obs, action] += self.lr * (Q_target - Q_predict)  # correct Q
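The body of learn() is the standard tabular Q-learning update, Q(s,a) <- Q(s,a) + lr * (r + gamma * max_a' Q(s',a') - Q(s,a)). One hand-worked step with made-up numbers and the defaults lr=0.01, gamma=0.9:

lr, gamma = 0.01, 0.9
Q_sa = 0.0          # current Q_table[obs, action]
reward = 1.0        # reward observed for this step (made up)
max_next = 0.5      # np.max(Q_table[next_obs, :]) (made up)

Q_target = reward + gamma * max_next    # 1.0 + 0.9 * 0.5 = 1.45
Q_sa += lr * (Q_target - Q_sa)          # 0.0 + 0.01 * 1.45 = 0.0145
print(Q_sa)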
-    # save the Q table to a file
     def save(self):
-        npy_file = './q_table.npy'
-        np.save(npy_file, self.Q)
+        '''Save the Q table to a file
+        '''
+        npy_file = './Q_table.npy'
+        np.save(npy_file, self.Q_table)
         print(npy_file + ' saved.')
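Together with the load() method below, save() is a plain np.save / np.load round trip on the whole table, roughly:

agent.save()                                  # writes ./Q_table.npy
fresh = QLearning(obs_dim=48, action_dim=4)   # dimensions assumed, as above
fresh.load('./Q_table.npy')                   # restores the trained table into a new agent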
-    # read the Q table back from a file
-    def restore(self, npy_file='./q_table.npy'):
-        self.Q = np.load(npy_file)
-        print(npy_file + ' loaded.')
+    def load(self, npy_file='./Q_table.npy'):
+        '''Read the Q table back from a file
+        '''
+        self.Q_table = np.load(npy_file)
+        print(npy_file + ' loaded.')
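Putting the renamed pieces together, a minimal training-loop sketch; the gym environment, its old 4-tuple step() API, and the episode count are assumptions for illustration and are not part of this commit:

import gym

env = gym.make('CliffWalking-v0')      # assumed environment
agent = QLearning(obs_dim=env.observation_space.n,
                  action_dim=env.action_space.n,
                  learning_rate=0.1, gamma=0.9,
                  epsilon_start=0.9, epsilon_end=0.1, epsilon_decay=200)

for episode in range(300):
    obs = env.reset()
    done = False
    while not done:
        action = agent.sample(obs)                    # epsilon-greedy action while training
        next_obs, reward, done, _ = env.step(action)  # old gym API returning a 4-tuple
        agent.learn(obs, action, reward, next_obs, done)
        obs = next_obs

agent.save()   # persists ./Q_table.npy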