This commit is contained in:
JohnJim0816
2020-10-07 21:47:25 +08:00
parent 07b835663a
commit 5fe8bfc6c1
23 changed files with 378 additions and 139 deletions


@@ -1,3 +1,14 @@
#!/usr/bin/env python
# coding=utf-8
'''
Author: John
Email: johnjim0816@gmail.com
Date: 2020-09-11 23:03:00
LastEditor: John
LastEditTime: 2020-10-07 20:48:29
Description:
Environment:
'''
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -12,64 +23,72 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# -*- coding: utf-8 -*-
import numpy as np
import math
class QLearningAgent(object):
class QLearning(object):
def __init__(self,
obs_n,
act_n,
obs_dim,
action_dim,
learning_rate=0.01,
gamma=0.9,
e_greed=0.1):
self.act_n = act_n  # action dimension: number of available actions
epsilon_start=0.9, epsilon_end=0.1, epsilon_decay=200):
self.action_dim = action_dim  # action dimension: number of available actions
self.lr = learning_rate  # learning rate
self.gamma = gamma  # discount factor for rewards
self.epsilon = e_greed  # probability of choosing a random action
self.Q = np.zeros((obs_n, act_n))
self.gamma = gamma  # discount factor for rewards
self.epsilon = 0  # probability of choosing a random action, i.e. the e-greedy policy; epsilon decays over time
self.sample_count = 0  # epsilon decays with the number of training samples, so the samples are counted
self.epsilon_start = epsilon_start
self.epsilon_end = epsilon_end
self.epsilon_decay = epsilon_decay
self.Q_table = np.zeros((obs_dim, action_dim))  # Q table
# sample an action for the given observation, with exploration
def sample(self, obs):
if np.random.uniform(0, 1) < (1.0 - self.epsilon): # choose the action from the Q table
'''Sample an action for the given observation, with exploration; used while training the model
'''
self.sample_count += 1
self.epsilon = self.epsilon_end + (self.epsilon_start - self.epsilon_end) * \
math.exp(-1. * self.sample_count / self.epsilon_decay)
if np.random.uniform(0, 1) > self.epsilon: # draw a value in [0, 1]; if it exceeds epsilon, act greedily, otherwise act randomly
action = self.predict(obs)
else:
action = np.random.choice(self.act_n) # with some probability, explore by picking a random action
action = np.random.choice(self.action_dim) # with some probability, explore by picking a random action
return action
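# Rough sketch of the decay schedule, using the default arguments above
# (epsilon_start=0.9, epsilon_end=0.1, epsilon_decay=200): epsilon follows
# 0.1 + 0.8 * exp(-sample_count / 200), i.e. about 0.9 on the first sample,
# about 0.39 after 200 samples and about 0.21 after 400 samples, so exploration
# fades smoothly toward the epsilon_end floor instead of staying fixed.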
# predict the greedy action for the given observation
def predict(self, obs):
Q_list = self.Q[obs, :]
maxQ = np.max(Q_list)
action_list = np.where(Q_list == maxQ)[0] # maxQ may correspond to several actions
action = np.random.choice(action_list)
'''Pick the greedy action for the given observation, without exploration; used when testing the model
'''
Q_list = self.Q_table[obs, :]
Q_max = np.max(Q_list)
action_list = np.where(Q_list == Q_max)[0]
action = np.random.choice(action_list) # Q_max may correspond to several actions; pick one of them at random
return action
# learning method, i.e. how the Q-table is updated
def learn(self, obs, action, reward, next_obs, done):
""" off-policy
obs: observation before the interaction, s_t
action: action chosen in this interaction, a_t
reward: reward obtained from this action, r
next_obs: observation after the interaction, s_t+1
done: whether the episode has ended
"""
predict_Q = self.Q[obs, action]
'''Learning method (off-policy), i.e. how the Q-table is updated
Args:
obs (int): observation before the interaction, s_t
action (int): action chosen in this interaction, a_t
reward (float): reward obtained from this action, r
next_obs (int): observation after the interaction, s_t+1
done (bool): whether the episode has ended
'''
Q_predict = self.Q_table[obs, action]
if done:
target_Q = reward # there is no next state
Q_target = reward # there is no next state
else:
target_Q = reward + self.gamma * np.max(
self.Q[next_obs, :]) # Q-learning
self.Q[obs, action] += self.lr * (target_Q - predict_Q) # update Q
Q_target = reward + self.gamma * np.max(
self.Q_table[next_obs, :]) # Q-learning
self.Q_table[obs, action] += self.lr * (Q_target - Q_predict) # update Q
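# Rough worked example of one update, assuming the defaults above (learning_rate=0.01,
# gamma=0.9) and hypothetical table values: if Q_table[obs, action] = 0.5, reward = 1.0
# and max(Q_table[next_obs, :]) = 2.0, then Q_target = 1.0 + 0.9 * 2.0 = 2.8 and the
# entry moves to 0.5 + 0.01 * (2.8 - 0.5) = 0.523, a small step toward the TD target.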
# save the Q table to a file
def save(self):
npy_file = './q_table.npy'
np.save(npy_file, self.Q)
'''Save the Q table to a file
'''
npy_file = './Q_table.npy'
np.save(npy_file, self.Q_table)
print(npy_file + ' saved.')
# load the Q table from a file
def restore(self, npy_file='./q_table.npy'):
self.Q = np.load(npy_file)
print(npy_file + ' loaded.')
def load(self, npy_file='./Q_table.npy'):
'''Load the Q table from a file
'''
self.Q_table = np.load(npy_file)
print(npy_file + ' loaded.')
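A minimal usage sketch of the class above, assuming the QLearning class is in scope, a pre-0.26 Gym API, and an environment with discrete observations and actions; the environment name and episode count are illustrative assumptions:

import gym

env = gym.make('CliffWalking-v0')   # assumed discrete-state toy environment
agent = QLearning(obs_dim=env.observation_space.n,
                  action_dim=env.action_space.n)

for episode in range(300):          # illustrative episode count
    obs = env.reset()
    done = False
    ep_reward = 0
    while not done:
        action = agent.sample(obs)  # epsilon-greedy action with decaying epsilon
        next_obs, reward, done, _ = env.step(action)
        agent.learn(obs, action, reward, next_obs, done)  # one Q-table update
        obs = next_obs
        ep_reward += reward
    print('episode', episode, 'reward', ep_reward)
agent.save()                        # writes ./Q_table.npy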