{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# 值迭代算法\n", "作者:stzhao\n", "github: https://github.com/zhaoshitian" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 一、定义环境\n" ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import sys,os\n", "curr_path = os.path.abspath('')\n", "parent_path = os.path.dirname(curr_path)\n", "sys.path.append(parent_path)\n", "from envs.simple_grid import DrunkenWalkEnv" ] }, { "cell_type": "code", "execution_count": 63, "metadata": {}, "outputs": [], "source": [ "def all_seed(env,seed = 1):\n", " ## 这个函数主要是为了固定随机种子\n", " import numpy as np\n", " import random\n", " import os\n", " env.seed(seed) \n", " np.random.seed(seed)\n", " random.seed(seed)\n", " os.environ['PYTHONHASHSEED'] = str(seed) \n" ] }, { "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [], "source": [ "env = DrunkenWalkEnv(map_name=\"theAlley\")\n", "all_seed(env, seed = 1) # 设置随机种子为1" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 二、价值迭代算法\n" ] }, { "cell_type": "code", "execution_count": 40, "metadata": {}, "outputs": [], "source": [ "def value_iteration(env, theta=0.005, discount_factor=0.9):\n", " Q = np.zeros((env.nS, env.nA)) # 初始化一个Q表格\n", " count = 0\n", " while True:\n", " delta = 0.0\n", " Q_tmp = np.zeros((env.nS, env.nA))\n", " for state in range(env.nS):\n", " for a in range(env.nA):\n", " accum = 0.0\n", " reward_total = 0.0\n", " for prob, next_state, reward, done in env.P[state][a]:\n", " accum += prob* np.max(Q[next_state, :])\n", " reward_total += prob * reward\n", " Q_tmp[state, a] = reward_total + discount_factor * accum\n", " delta = max(delta, abs(Q_tmp[state, a] - Q[state, a]))\n", " Q = Q_tmp\n", " \n", " count += 1\n", " if delta < theta or count > 100: # 这里设置了即使算法没有收敛,跑100次也退出循环\n", " break \n", " return Q" ] }, { "cell_type": "code", "execution_count": 41, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[[2.25015697e+22 2.53142659e+22 4.50031394e+22 2.53142659e+22]\n", " [2.81269621e+22 5.41444021e+22 1.01257064e+23 5.41444021e+22]\n", " [6.32856648e+22 1.21824905e+23 2.27828393e+23 1.21824905e+23]\n", " [1.42392746e+23 2.74106036e+23 5.12613885e+23 2.74106036e+23]\n", " [3.20383678e+23 5.76690620e+23 1.15338124e+24 5.76690620e+23]\n", " [7.20863276e+23 1.38766181e+24 2.59510779e+24 1.38766181e+24]\n", " [1.62194237e+24 3.12223906e+24 5.83899253e+24 3.12223906e+24]\n", " [3.64937033e+24 7.02503789e+24 1.31377332e+25 7.02503789e+24]\n", " [8.21108325e+24 1.47799498e+25 2.95598997e+25 1.47799498e+25]\n", " [1.84749373e+25 3.55642543e+25 6.65097743e+25 3.55642543e+25]\n", " [4.15686089e+25 8.00195722e+25 1.49646992e+26 8.00195722e+25]\n", " [9.35293701e+25 1.80044037e+26 3.36705732e+26 1.80044037e+26]\n", " [5.89235032e+26 7.36543790e+26 7.57587898e+26 7.36543790e+26]]\n" ] } ], "source": [ "Q = value_iteration(env)\n", "print(Q)" ] }, { "cell_type": "code", "execution_count": 69, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]\n" ] } ], "source": [ "policy = np.zeros([env.nS, env.nA]) # 初始化一个策略表格\n", "for state in range(env.nS):\n", " best_action = np.argmax(Q[state, :]) #根据价值迭代算法得到的Q表格选择出策略\n", " policy[state, best_action] = 1\n", "\n", "policy = [int(np.argwhere(policy[i]==1)) for i in range(env.nS) ]\n", "print(policy)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 三、测试" ] }, { "cell_type": "code", "execution_count": 78, "metadata": {}, "outputs": [], "source": [ "num_episode = 1000 # 测试1000次\n", "def test(env,policy):\n", " \n", " rewards = [] # 记录所有回合的奖励\n", " success = [] # 记录该回合是否成功走到终点\n", " for i_ep in range(num_episode):\n", " ep_reward = 0 # 记录每个episode的reward\n", " state = env.reset() # 重置环境, 重新开一局(即开始新的一个回合) 这里state=0\n", " while True:\n", " action = policy[state] # 根据算法选择一个动作\n", " next_state, reward, done, _ = env.step(action) # 与环境进行一个交互\n", " state = next_state # 更新状态\n", " ep_reward += reward\n", " if done:\n", " break\n", " if state==12: # 即走到终点\n", " success.append(1)\n", " else:\n", " success.append(0)\n", " rewards.append(ep_reward)\n", " acc_suc = np.array(success).sum()/num_episode\n", " print(\"测试的成功率是:\", acc_suc)\n", " " ] }, { "cell_type": "code", "execution_count": 79, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "测试的成功率是: 0.64\n" ] } ], "source": [ "test(env, policy)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3.10.6 ('RL')", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.8" }, "orig_nbformat": 4, "vscode": { "interpreter": { "hash": "88a829278351aa402b7d6303191a511008218041c5cfdb889d81328a3ea60fbc" } } }, "nbformat": 4, "nbformat_minor": 2 }