更新蘑菇书附书代码

This commit is contained in:
johnjim0816
2022-12-04 20:54:36 +08:00
parent f030fe283d
commit dc8d13a13e
23 changed files with 10784 additions and 0 deletions

View File

View File

@@ -0,0 +1,232 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 值迭代算法\n",
"作者stzhao\n",
"github: https://github.com/zhaoshitian"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 一、定义环境\n"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import sys,os\n",
"curr_path = os.path.abspath('')\n",
"parent_path = os.path.dirname(curr_path)\n",
"sys.path.append(parent_path)\n",
"from envs.simple_grid import DrunkenWalkEnv"
]
},
{
"cell_type": "code",
"execution_count": 63,
"metadata": {},
"outputs": [],
"source": [
"def all_seed(env,seed = 1):\n",
" ## 这个函数主要是为了固定随机种子\n",
" import numpy as np\n",
" import random\n",
" import os\n",
" env.seed(seed) \n",
" np.random.seed(seed)\n",
" random.seed(seed)\n",
" os.environ['PYTHONHASHSEED'] = str(seed) \n"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [],
"source": [
"env = DrunkenWalkEnv(map_name=\"theAlley\")\n",
"all_seed(env, seed = 1) # 设置随机种子为1"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 二、价值迭代算法\n"
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {},
"outputs": [],
"source": [
"def value_iteration(env, theta=0.005, discount_factor=0.9):\n",
" Q = np.zeros((env.nS, env.nA)) # 初始化一个Q表格\n",
" count = 0\n",
" while True:\n",
" delta = 0.0\n",
" Q_tmp = np.zeros((env.nS, env.nA))\n",
" for state in range(env.nS):\n",
" for a in range(env.nA):\n",
" accum = 0.0\n",
" reward_total = 0.0\n",
" for prob, next_state, reward, done in env.P[state][a]:\n",
" accum += prob* np.max(Q[next_state, :])\n",
" reward_total += prob * reward\n",
" Q_tmp[state, a] = reward_total + discount_factor * accum\n",
" delta = max(delta, abs(Q_tmp[state, a] - Q[state, a]))\n",
" Q = Q_tmp\n",
" \n",
" count += 1\n",
" if delta < theta or count > 100: # 这里设置了即使算法没有收敛跑100次也退出循环\n",
" break \n",
" return Q"
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[2.25015697e+22 2.53142659e+22 4.50031394e+22 2.53142659e+22]\n",
" [2.81269621e+22 5.41444021e+22 1.01257064e+23 5.41444021e+22]\n",
" [6.32856648e+22 1.21824905e+23 2.27828393e+23 1.21824905e+23]\n",
" [1.42392746e+23 2.74106036e+23 5.12613885e+23 2.74106036e+23]\n",
" [3.20383678e+23 5.76690620e+23 1.15338124e+24 5.76690620e+23]\n",
" [7.20863276e+23 1.38766181e+24 2.59510779e+24 1.38766181e+24]\n",
" [1.62194237e+24 3.12223906e+24 5.83899253e+24 3.12223906e+24]\n",
" [3.64937033e+24 7.02503789e+24 1.31377332e+25 7.02503789e+24]\n",
" [8.21108325e+24 1.47799498e+25 2.95598997e+25 1.47799498e+25]\n",
" [1.84749373e+25 3.55642543e+25 6.65097743e+25 3.55642543e+25]\n",
" [4.15686089e+25 8.00195722e+25 1.49646992e+26 8.00195722e+25]\n",
" [9.35293701e+25 1.80044037e+26 3.36705732e+26 1.80044037e+26]\n",
" [5.89235032e+26 7.36543790e+26 7.57587898e+26 7.36543790e+26]]\n"
]
}
],
"source": [
"Q = value_iteration(env)\n",
"print(Q)"
]
},
{
"cell_type": "code",
"execution_count": 69,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]\n"
]
}
],
"source": [
"policy = np.zeros([env.nS, env.nA]) # 初始化一个策略表格\n",
"for state in range(env.nS):\n",
" best_action = np.argmax(Q[state, :]) #根据价值迭代算法得到的Q表格选择出策略\n",
" policy[state, best_action] = 1\n",
"\n",
"policy = [int(np.argwhere(policy[i]==1)) for i in range(env.nS) ]\n",
"print(policy)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 三、测试"
]
},
{
"cell_type": "code",
"execution_count": 78,
"metadata": {},
"outputs": [],
"source": [
"num_episode = 1000 # 测试1000次\n",
"def test(env,policy):\n",
" \n",
" rewards = [] # 记录所有回合的奖励\n",
" success = [] # 记录该回合是否成功走到终点\n",
" for i_ep in range(num_episode):\n",
" ep_reward = 0 # 记录每个episode的reward\n",
" state = env.reset() # 重置环境, 重新开一局(即开始新的一个回合) 这里state=0\n",
" while True:\n",
" action = policy[state] # 根据算法选择一个动作\n",
" next_state, reward, done, _ = env.step(action) # 与环境进行一个交互\n",
" state = next_state # 更新状态\n",
" ep_reward += reward\n",
" if done:\n",
" break\n",
" if state==12: # 即走到终点\n",
" success.append(1)\n",
" else:\n",
" success.append(0)\n",
" rewards.append(ep_reward)\n",
" acc_suc = np.array(success).sum()/num_episode\n",
" print(\"测试的成功率是:\", acc_suc)\n",
" "
]
},
{
"cell_type": "code",
"execution_count": 79,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"测试的成功率是: 0.64\n"
]
}
],
"source": [
"test(env, policy)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.10.6 ('RL')",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.8"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "88a829278351aa402b7d6303191a511008218041c5cfdb889d81328a3ea60fbc"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}