{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Value Iteration\n",
"Author: stzhao\n",
"GitHub: https://github.com/zhaoshitian"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 1. Define the environment\n"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import sys, os\n",
"curr_path = os.path.abspath('')  # current working directory\n",
"parent_path = os.path.dirname(curr_path)  # parent directory, which contains the envs package\n",
"sys.path.append(parent_path)\n",
"from envs.simple_grid import DrunkenWalkEnv"
]
},
{
"cell_type": "code",
"execution_count": 63,
"metadata": {},
"outputs": [],
"source": [
"def all_seed(env, seed=1):\n",
"    ## fix all random seeds for reproducibility\n",
"    import numpy as np\n",
"    import random\n",
"    import os\n",
"    env.seed(seed)\n",
"    np.random.seed(seed)\n",
"    random.seed(seed)\n",
"    os.environ['PYTHONHASHSEED'] = str(seed)\n"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [],
"source": [
"env = DrunkenWalkEnv(map_name=\"theAlley\")\n",
"all_seed(env, seed=1)  # set the random seed to 1"
]
},
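{
"cell_type": "markdown",
"metadata": {},
"source": [
"As can be read off from the outputs further down, this environment exposes `env.nS = 13` states and `env.nA = 4` actions, the test section treats state 12 as the goal, and the transition model is available as `env.P[state][action]`, a list of `(prob, next_state, reward, done)` tuples that value iteration sweeps over."
]
},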
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 2. Value iteration algorithm\n"
]
},
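{
"cell_type": "markdown",
"metadata": {},
"source": [
"The cell below keeps a table $Q(s,a)$ and repeatedly sweeps over every state-action pair, applying the Bellman optimality backup that the loop body implements:\n",
"\n",
"$$Q(s,a) \\leftarrow \\sum_{s'} P(s' \\mid s,a)\\left[r(s,a,s') + \\gamma \\max_{a'} Q(s',a')\\right]$$\n",
"\n",
"where $\\gamma$ is `discount_factor`. The sweep stops once the largest per-entry change $\\delta$ drops below `theta`, or after 100 sweeps as a safeguard."
]
},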
{
"cell_type": "code",
"execution_count": 40,
"metadata": {},
"outputs": [],
"source": [
"def value_iteration(env, theta=0.005, discount_factor=0.9):\n",
"    Q = np.zeros((env.nS, env.nA))  # initialize the Q table\n",
"    count = 0\n",
"    while True:\n",
"        delta = 0.0\n",
"        Q_tmp = np.zeros((env.nS, env.nA))\n",
"        for state in range(env.nS):\n",
"            for a in range(env.nA):\n",
"                accum = 0.0\n",
"                reward_total = 0.0\n",
"                for prob, next_state, reward, done in env.P[state][a]:\n",
"                    accum += prob * np.max(Q[next_state, :])\n",
"                    reward_total += prob * reward\n",
"                Q_tmp[state, a] = reward_total + discount_factor * accum\n",
"                delta = max(delta, abs(Q_tmp[state, a] - Q[state, a]))\n",
"        Q = Q_tmp\n",
"        count += 1\n",
"        if delta < theta or count > 100:  # stop after 100 sweeps even if the algorithm has not converged\n",
"            break\n",
"    return Q"
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[2.25015697e+22 2.53142659e+22 4.50031394e+22 2.53142659e+22]\n",
" [2.81269621e+22 5.41444021e+22 1.01257064e+23 5.41444021e+22]\n",
" [6.32856648e+22 1.21824905e+23 2.27828393e+23 1.21824905e+23]\n",
" [1.42392746e+23 2.74106036e+23 5.12613885e+23 2.74106036e+23]\n",
" [3.20383678e+23 5.76690620e+23 1.15338124e+24 5.76690620e+23]\n",
" [7.20863276e+23 1.38766181e+24 2.59510779e+24 1.38766181e+24]\n",
" [1.62194237e+24 3.12223906e+24 5.83899253e+24 3.12223906e+24]\n",
" [3.64937033e+24 7.02503789e+24 1.31377332e+25 7.02503789e+24]\n",
" [8.21108325e+24 1.47799498e+25 2.95598997e+25 1.47799498e+25]\n",
" [1.84749373e+25 3.55642543e+25 6.65097743e+25 3.55642543e+25]\n",
" [4.15686089e+25 8.00195722e+25 1.49646992e+26 8.00195722e+25]\n",
" [9.35293701e+25 1.80044037e+26 3.36705732e+26 1.80044037e+26]\n",
" [5.89235032e+26 7.36543790e+26 7.57587898e+26 7.36543790e+26]]\n"
]
}
],
"source": [
"Q = value_iteration(env)\n",
"print(Q)"
]
},
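{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note that the printed Q values are astronomically large, so the loop above must have exited through the 100-sweep cap rather than the `delta < theta` test, i.e. the backup did not converge for this environment. This does not affect the next step: extracting a greedy policy only uses the ordering of the entries within each row."
]
},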
{
"cell_type": "code",
"execution_count": 69,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]\n"
]
}
],
"source": [
"policy = np.zeros([env.nS, env.nA])  # initialize a policy table\n",
"for state in range(env.nS):\n",
"    best_action = np.argmax(Q[state, :])  # greedy action according to the Q table from value iteration\n",
"    policy[state, best_action] = 1\n",
"\n",
"policy = [int(np.argwhere(policy[i] == 1)) for i in range(env.nS)]  # one action index per state\n",
"print(policy)"
]
},
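{
"cell_type": "markdown",
"metadata": {},
"source": [
"Equivalently, the greedy policy is simply the row-wise argmax of the Q table. The cell below is an illustrative sketch (not part of the original notebook; `policy_alt` is just a throwaway name) and should reproduce the list printed above."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# A more compact, equivalent extraction of the greedy policy (illustrative sketch)\n",
"policy_alt = [int(a) for a in np.argmax(Q, axis=1)]\n",
"assert policy_alt == policy  # same greedy policy as the table-based construction above\n",
"print(policy_alt)"
]
},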
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 3. Testing"
]
},
{
"cell_type": "code",
"execution_count": 78,
"metadata": {},
"outputs": [],
"source": [
"num_episode = 1000  # number of test episodes\n",
"def test(env, policy):\n",
"    rewards = []  # record the reward of every episode\n",
"    success = []  # record whether each episode reaches the goal\n",
"    for i_ep in range(num_episode):\n",
"        ep_reward = 0  # cumulative reward of this episode\n",
"        state = env.reset()  # reset the environment and start a new episode; the initial state is 0\n",
"        while True:\n",
"            action = policy[state]  # pick the action given by the policy\n",
"            next_state, reward, done, _ = env.step(action)  # take one step in the environment\n",
"            state = next_state  # update the state\n",
"            ep_reward += reward\n",
"            if done:\n",
"                break\n",
"        if state == 12:  # state 12 is the goal\n",
"            success.append(1)\n",
"        else:\n",
"            success.append(0)\n",
"        rewards.append(ep_reward)\n",
"    acc_suc = np.array(success).sum() / num_episode\n",
"    print(\"Success rate over the test episodes:\", acc_suc)"
]
},
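{
"cell_type": "markdown",
"metadata": {},
"source": [
"Since `DrunkenWalkEnv` is stochastic (each `env.P[state][action]` lists several possible outcomes with their probabilities), even the greedy policy is not guaranteed to reach the goal in every episode, so the success rate printed below is expected to be below 1."
]
},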
{
"cell_type": "code",
"execution_count": 79,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Success rate over the test episodes: 0.64\n"
]
}
],
"source": [
"test(env, policy)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.10.6 ('RL')",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.8"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "88a829278351aa402b7d6303191a511008218041c5cfdb889d81328a3ea60fbc"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}