233 lines
6.5 KiB
Plaintext
233 lines
6.5 KiB
Plaintext
{
|
||
"cells": [
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"# 值迭代算法\n",
|
||
"作者:stzhao\n",
|
||
"github: https://github.com/zhaoshitian"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"# 一、定义环境\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 30,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"import numpy as np\n",
|
||
"import sys,os\n",
|
||
"curr_path = os.path.abspath('')\n",
|
||
"parent_path = os.path.dirname(curr_path)\n",
|
||
"sys.path.append(parent_path)\n",
|
||
"from envs.simple_grid import DrunkenWalkEnv"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 63,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"def all_seed(env,seed = 1):\n",
|
||
" ## 这个函数主要是为了固定随机种子\n",
|
||
" import numpy as np\n",
|
||
" import random\n",
|
||
" import os\n",
|
||
" env.seed(seed) \n",
|
||
" np.random.seed(seed)\n",
|
||
" random.seed(seed)\n",
|
||
" os.environ['PYTHONHASHSEED'] = str(seed) \n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 32,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
"env = DrunkenWalkEnv(map_name=\"theAlley\")  # environment used throughout this notebook\n",
"all_seed(env=env, seed=1)  # fix all random seeds to 1 for reproducibility"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"# 二、价值迭代算法\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 40,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"def value_iteration(env, theta=0.005, discount_factor=0.9):\n",
|
||
" Q = np.zeros((env.nS, env.nA)) # 初始化一个Q表格\n",
|
||
" count = 0\n",
|
||
" while True:\n",
|
||
" delta = 0.0\n",
|
||
" Q_tmp = np.zeros((env.nS, env.nA))\n",
|
||
" for state in range(env.nS):\n",
|
||
" for a in range(env.nA):\n",
|
||
" accum = 0.0\n",
|
||
" reward_total = 0.0\n",
|
||
" for prob, next_state, reward, done in env.P[state][a]:\n",
|
||
" accum += prob* np.max(Q[next_state, :])\n",
|
||
" reward_total += prob * reward\n",
|
||
" Q_tmp[state, a] = reward_total + discount_factor * accum\n",
|
||
" delta = max(delta, abs(Q_tmp[state, a] - Q[state, a]))\n",
|
||
" Q = Q_tmp\n",
|
||
" \n",
|
||
" count += 1\n",
|
||
" if delta < theta or count > 100: # 这里设置了即使算法没有收敛,跑100次也退出循环\n",
|
||
" break \n",
|
||
" return Q"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 41,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"[[2.25015697e+22 2.53142659e+22 4.50031394e+22 2.53142659e+22]\n",
|
||
" [2.81269621e+22 5.41444021e+22 1.01257064e+23 5.41444021e+22]\n",
|
||
" [6.32856648e+22 1.21824905e+23 2.27828393e+23 1.21824905e+23]\n",
|
||
" [1.42392746e+23 2.74106036e+23 5.12613885e+23 2.74106036e+23]\n",
|
||
" [3.20383678e+23 5.76690620e+23 1.15338124e+24 5.76690620e+23]\n",
|
||
" [7.20863276e+23 1.38766181e+24 2.59510779e+24 1.38766181e+24]\n",
|
||
" [1.62194237e+24 3.12223906e+24 5.83899253e+24 3.12223906e+24]\n",
|
||
" [3.64937033e+24 7.02503789e+24 1.31377332e+25 7.02503789e+24]\n",
|
||
" [8.21108325e+24 1.47799498e+25 2.95598997e+25 1.47799498e+25]\n",
|
||
" [1.84749373e+25 3.55642543e+25 6.65097743e+25 3.55642543e+25]\n",
|
||
" [4.15686089e+25 8.00195722e+25 1.49646992e+26 8.00195722e+25]\n",
|
||
" [9.35293701e+25 1.80044037e+26 3.36705732e+26 1.80044037e+26]\n",
|
||
" [5.89235032e+26 7.36543790e+26 7.57587898e+26 7.36543790e+26]]\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
"# Explicit values equal value_iteration's defaults (theta=0.005, discount_factor=0.9).\n",
"Q = value_iteration(env, theta=0.005, discount_factor=0.9)\n",
"print(Q)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 69,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"[2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
"# Greedy policy: in each state take the action with the largest Q value.\n",
"# np.argmax returns the first maximal index, so ties go to the lowest action.\n",
"# (Replaces int(np.argwhere(...)) on a 1x1 array, deprecated since NumPy 1.25.)\n",
"policy = [int(best_action) for best_action in np.argmax(Q, axis=1)]\n",
"print(policy)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"# 三、测试"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 78,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"num_episode = 1000 # 测试1000次\n",
|
||
"def test(env,policy):\n",
|
||
" \n",
|
||
" rewards = [] # 记录所有回合的奖励\n",
|
||
" success = [] # 记录该回合是否成功走到终点\n",
|
||
" for i_ep in range(num_episode):\n",
|
||
" ep_reward = 0 # 记录每个episode的reward\n",
|
||
" state = env.reset() # 重置环境, 重新开一局(即开始新的一个回合) 这里state=0\n",
|
||
" while True:\n",
|
||
" action = policy[state] # 根据算法选择一个动作\n",
|
||
" next_state, reward, done, _ = env.step(action) # 与环境进行一个交互\n",
|
||
" state = next_state # 更新状态\n",
|
||
" ep_reward += reward\n",
|
||
" if done:\n",
|
||
" break\n",
|
||
" if state==12: # 即走到终点\n",
|
||
" success.append(1)\n",
|
||
" else:\n",
|
||
" success.append(0)\n",
|
||
" rewards.append(ep_reward)\n",
|
||
" acc_suc = np.array(success).sum()/num_episode\n",
|
||
" print(\"测试的成功率是:\", acc_suc)\n",
|
||
" "
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 79,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"测试的成功率是: 0.64\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
"# Evaluate the greedy policy extracted from Q\n",
"test(env=env, policy=policy)"
|
||
]
|
||
}
|
||
],
|
||
"metadata": {
|
||
"kernelspec": {
|
||
"display_name": "Python 3.10.6 ('RL')",
|
||
"language": "python",
|
||
"name": "python3"
|
||
},
|
||
"language_info": {
|
||
"codemirror_mode": {
|
||
"name": "ipython",
|
||
"version": 3
|
||
},
|
||
"file_extension": ".py",
|
||
"mimetype": "text/x-python",
|
||
"name": "python",
|
||
"nbconvert_exporter": "python",
|
||
"pygments_lexer": "ipython3",
|
||
"version": "3.10.8"
|
||
},
|
||
"orig_nbformat": 4,
|
||
"vscode": {
|
||
"interpreter": {
|
||
"hash": "88a829278351aa402b7d6303191a511008218041c5cfdb889d81328a3ea60fbc"
|
||
}
|
||
}
|
||
},
|
||
"nbformat": 4,
|
||
"nbformat_minor": 2
|
||
}
|