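{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Define the algorithm\n",
    "Reinforcement learning algorithms follow a fairly fixed pattern. They generally provide sample (sample an action during training), predict (choose an action during testing), update (update the algorithm), plus methods for saving and loading the model. The sample and update steps differ from one algorithm to another, while the other methods are much the same."
   ]
  },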
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import math\n",
    "from collections import defaultdict\n",
    "\n",
    "class QLearning(object):\n",
    "    def __init__(self,n_states,\n",
    "                 n_actions,cfg):\n",
    "        self.n_actions = n_actions\n",
    "        self.lr = cfg.lr # learning rate\n",
    "        self.gamma = cfg.gamma # discount factor\n",
    "        self.epsilon = cfg.epsilon_start\n",
    "        self.sample_count = 0 # number of actions sampled so far\n",
    "        self.epsilon_start = cfg.epsilon_start\n",
    "        self.epsilon_end = cfg.epsilon_end\n",
    "        self.epsilon_decay = cfg.epsilon_decay\n",
    "        self.Q_table = defaultdict(lambda: np.zeros(n_actions)) # Q-table: maps each state (as a string key) to an array of Q-values, one per action\n",
    "    def sample_action(self, state):\n",
    "        ''' Sample an action; used during training\n",
    "        '''\n",
    "        self.sample_count += 1\n",
    "        self.epsilon = self.epsilon_end + (self.epsilon_start - self.epsilon_end) * \\\n",
    "            math.exp(-1. * self.sample_count / self.epsilon_decay) # epsilon decreases over time; exponential decay is used here\n",
    "        # e-greedy policy\n",
    "        if np.random.uniform(0, 1) > self.epsilon:\n",
    "            action = np.argmax(self.Q_table[str(state)]) # choose the action with the largest Q(s,a)\n",
    "        else:\n",
    "            action = np.random.choice(self.n_actions) # choose a random action\n",
    "        return action\n",
    "    def predict_action(self,state):\n",
    "        ''' Predict (choose) an action; used during testing\n",
    "        '''\n",
    "        action = np.argmax(self.Q_table[str(state)])\n",
    "        return action\n",
    "    def update(self, state, action, reward, next_state, terminated):\n",
    "        Q_predict = self.Q_table[str(state)][action]\n",
    "        if terminated: # terminal state\n",
    "            Q_target = reward\n",
    "        else:\n",
    "            Q_target = reward + self.gamma * np.max(self.Q_table[str(next_state)])\n",
    "        self.Q_table[str(state)][action] += self.lr * (Q_target - Q_predict)"
   ]
  },
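  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a small worked example of the `update` method above (a hand-calculated illustration, not output from this notebook), assume `lr` is $\\alpha = 0.1$ and `gamma` is $\\gamma = 0.9$, the defaults set in the parameter section, and assume a non-terminal step with reward $r = -1$. Starting from an all-zero Q-table, the first update of a state-action pair gives:\n",
    "$$\n",
    "Q(s,a) \\leftarrow Q(s,a) + \\alpha\\left[r + \\gamma \\max_{a'} Q(s',a') - Q(s,a)\\right] = 0 + 0.1\\,[-1 + 0.9 \\cdot 0 - 0] = -0.1\n",
    "$$\n",
    "If the same pair is updated again later with $Q(s,a) = -0.1$ and, hypothetically, $\\max_{a'} Q(s',a') = -0.1$:\n",
    "$$\n",
    "Q(s,a) \\leftarrow -0.1 + 0.1\\,[-1 + 0.9 \\cdot (-0.1) - (-0.1)] = -0.1 + 0.1 \\cdot (-0.99) = -0.199\n",
    "$$\n",
    "For a terminal step the target is just the reward, so the bracket reduces to $r - Q(s,a)$."
   ]
  },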
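  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Define the training\n",
    "The way reinforcement learning algorithms are trained is also fairly fixed, as follows:\n",
    "```python\n",
    "for i_ep in range(train_eps): # iterate over episodes\n",
    "    state = env.reset() # reset the environment, i.e. start a new episode\n",
    "    while True: # for more complex games you can cap the steps per episode, e.g. while ep_step<100 limits each episode to 100 steps\n",
    "        action = agent.sample(state) # sample an action according to the algorithm\n",
    "        next_state, reward, terminated, _ = env.step(action) # take one step in the environment\n",
    "        agent.memory.push(state, action, reward, next_state, terminated) # store the transition in memory\n",
    "        agent.update(state, action, reward, next_state, terminated) # update the algorithm\n",
    "        state = next_state # update the state\n",
    "        if terminated:\n",
    "            break\n",
    "```\n",
    "At the start of every episode the environment must be reset, just as we start from scratch each time we begin a new game. We can cap the number of steps the agent takes per episode, especially in more complex games. One benefit is faster convergence during training: if we know a priori roughly how many steps the optimal solution takes, then in theory the converged agent should also take about that many steps, and a step cap helps the agent get close to that solution. Within each episode the agent first samples an action, that is, it applies an exploration strategy such as the common $\\varepsilon$-greedy strategy or a UCB exploration strategy. Sampling takes the current state as input and outputs an action. The environment then returns the next state, the reward, and related information for that action. Next, an agent with memory, for example DQN with its replay memory, stores the corresponding transition in memory (remember this term; it usually bundles the state, action, reward, and so on). Then comes the agent update. In deep reinforcement learning the update usually draws a random batch of transitions from memory, while Q-learning usually uses only the most recent transition. The update formula is the key part, but it is also quite generic; value-based algorithms mostly follow the same pattern:\n",
    "$$\n",
    "y_{j}= \\begin{cases}r_{j} & \\text { for terminal } s_{j+1} \\\\ r_{j}+\\gamma \\max _{a^{\\prime}} Q\\left(s_{j+1}, a^{\\prime} ; \\theta\\right) & \\text { for non-terminal } s_{j+1}\\end{cases}\n",
    "$$\n",
    "After the update, the agent usually advances the state, i.e. ```state = next_state```, and then checks whether the episode has finished, i.e. ```terminated==True```. Note that finishing does not mean the episode succeeded; the agent may equally have failed badly. You will understand this once you have built a custom reinforcement learning environment (you'll get it when you're older XD).\n",
    "If you need to record rewards, losses, and so on, you can add that as well, as in the code below. Real projects more often use tensorboard to log this data, and the author has used it in this tutorial code before, but it looked somewhat cumbersome and tended to add unnecessary learning difficulty. Readers who have spare capacity, or who need reinforcement learning for project research, can look into it; it is also simple.\n",
    "In addition, anything beyond the simplest reinforcement learning setup will not converge just because the code has been written once; we then have to become hyperparameter tuners. To check how well the parameters are tuned, we can print the reward, the loss, and epsilon in the terminal as the episodes go by. A note here: reinforcement learning training generally explores first and then converges, or in more formal terms, it trades off exploration and exploitation. The e-greedy strategy explores early and then gradually lowers the exploration rate, which is this epsilon, until it converges. The larger the value, the more the agent explores; 0.9, for example, means the agent picks a random action with 90% probability. Three values are usually set: epsilon_start, epsilon_end, and epsilon_decay, i.e. the initial value, the final value, and the decay rate. The initial value is typically fixed at 0.95 and the final value at 0.01, so even in the convergence phase the agent keeps exploring with a small probability. The reason is that the agent may have learned a decent policy, but a better one may still exist. It is like knowing that a good degree matters for getting ahead, yet \"one should still have dreams; what if they come true?\" There is always room for surprises, right? Back to the point: the critical one is epsilon_decay. In the code in this notebook it acts as the time constant of an exponential decay, measured in sampled actions, so a smaller value makes epsilon decay faster. If epsilon decays too fast, the learned policy tends to overfit by settling too early, like walking down a lane where you may pick only one flower, picking a decent-looking one early, and missing the better ones further on. If it decays too slowly, convergence slows down, like reaching the end of the lane without having picked any flower, which is even worse. Hyperparameter tuning in reinforcement learning is at least as involved as in deep learning, and epsilon is far from the only parameter, so it takes patient study.\n",
    "The test code is essentially the same as the training code, so I put both in the same code block. Compared with training, testing differs in the following ways: 1. The model is not updated during testing, which goes without saying. 2. The test code does not sample actions; it predicts them instead. Sampling may use an exploration strategy such as $\\varepsilon$-greedy, whereas prediction does not; it simply feeds the state into the learned Q-table or network model to obtain an action. 3. During testing the terminal usually only needs to show the reward, not epsilon, which is meaningless in testing anyway."
   ]
  },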
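  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the role of epsilon_decay concrete, here is the schedule used in `sample_action`, written out with the default values from the parameter section ($\\varepsilon_{\\text{start}} = 0.95$, $\\varepsilon_{\\text{end}} = 0.01$, epsilon_decay $= 300$). Here $n$ is the number of actions sampled so far (`sample_count`), and the numbers are rounded hand calculations, not notebook output:\n",
    "$$\n",
    "\\varepsilon(n) = \\varepsilon_{\\text{end}} + (\\varepsilon_{\\text{start}} - \\varepsilon_{\\text{end}})\\, e^{-n/300}\n",
    "$$\n",
    "$$\n",
    "\\varepsilon(300) \\approx 0.01 + 0.94 \\cdot 0.368 \\approx 0.356, \\qquad \\varepsilon(1000) \\approx 0.044, \\qquad \\varepsilon(3000) \\approx 0.010\n",
    "$$\n",
    "Because $n$ counts steps rather than episodes, epsilon reaches its floor after a few thousand steps in total, which may be only a handful of episodes early in training, when episodes tend to be long."
   ]
  },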
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(cfg,env,agent):\n",
    "    print('Start training!')\n",
    "    print(f'Environment: {cfg.env_name}, algorithm: {cfg.algo_name}')\n",
    "    # print(f'Environment: {cfg.env_name}, algorithm: {cfg.algo_name}, device: {cfg.device}')\n",
    "    rewards = [] # rewards of all episodes\n",
    "    for i_ep in range(cfg.train_eps):\n",
    "        ep_reward = 0 # reward accumulated in this episode\n",
    "        state = env.reset(seed=cfg.seed) # reset the environment, i.e. start a new episode\n",
    "        while True:\n",
    "            action = agent.sample_action(state) # sample an action according to the algorithm\n",
    "            next_state, reward, terminated, info = env.step(action) # take one step in the environment\n",
    "            agent.update(state, action, reward, next_state, terminated) # Q-learning update\n",
    "            state = next_state # update the state\n",
    "            ep_reward += reward\n",
    "            if terminated:\n",
    "                break\n",
    "        rewards.append(ep_reward)\n",
    "        if (i_ep+1)%20==0:\n",
    "            print(f\"Episode: {i_ep+1}/{cfg.train_eps}, reward: {ep_reward:.1f}, Epsilon: {agent.epsilon:.3f}\")\n",
    "    print('Training finished!')\n",
    "    return {\"rewards\":rewards}\n",
    "def test(cfg,env,agent):\n",
    "    print('Start testing!')\n",
    "    print(f'Environment: {cfg.env_name}, algorithm: {cfg.algo_name}')\n",
    "    rewards = [] # rewards of all episodes\n",
    "    for i_ep in range(cfg.test_eps):\n",
    "        ep_reward = 0 # reward accumulated in this episode\n",
    "        state = env.reset(seed=cfg.seed) # reset the environment, i.e. start a new episode\n",
    "        while True:\n",
    "            action = agent.predict_action(state) # choose an action according to the algorithm\n",
    "            next_state, reward, terminated, info = env.step(action) # take one step in the environment\n",
    "            state = next_state # update the state\n",
    "            ep_reward += reward\n",
    "            if terminated:\n",
    "                break\n",
    "        rewards.append(ep_reward)\n",
    "        print(f\"Episode: {i_ep+1}/{cfg.test_eps}, reward: {ep_reward:.1f}\")\n",
    "    print('Testing finished!')\n",
    "    return {\"rewards\":rewards}"
   ]
  },
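  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Define the environment\n",
    "\n",
    "OpenAI Gym already bundles many reinforcement learning environments, enough for learning purposes. In applied reinforcement learning, however, you inevitably end up creating your own. In this project, for instance, it was hard to find an environment that Q-learning can actually learn: Q-learning is quite weak and needs a sufficiently simple environment. This project therefore wrote its own environment code, and you are welcome to take a look if interested. The most important parts of an environment interface are reset and step."
   ]
  },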
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import gym\n",
    "import turtle\n",
    "import numpy as np\n",
    "\n",
    "# turtle tutorial : https://docs.python.org/3.3/library/turtle.html\n",
    "\n",
    "class CliffWalkingWapper(gym.Wrapper):\n",
    "    def __init__(self, env):\n",
    "        gym.Wrapper.__init__(self, env)\n",
    "        self.t = None\n",
    "        self.unit = 50\n",
    "        self.max_x = 12\n",
    "        self.max_y = 4\n",
    "\n",
    "    def draw_x_line(self, y, x0, x1, color='gray'):\n",
    "        assert x1 > x0\n",
    "        self.t.color(color)\n",
    "        self.t.setheading(0)\n",
    "        self.t.up()\n",
    "        self.t.goto(x0, y)\n",
    "        self.t.down()\n",
    "        self.t.forward(x1 - x0)\n",
    "\n",
    "    def draw_y_line(self, x, y0, y1, color='gray'):\n",
    "        assert y1 > y0\n",
    "        self.t.color(color)\n",
    "        self.t.setheading(90)\n",
    "        self.t.up()\n",
    "        self.t.goto(x, y0)\n",
    "        self.t.down()\n",
    "        self.t.forward(y1 - y0)\n",
    "\n",
    "    def draw_box(self, x, y, fillcolor='', line_color='gray'):\n",
    "        self.t.up()\n",
    "        self.t.goto(x * self.unit, y * self.unit)\n",
    "        self.t.color(line_color)\n",
    "        self.t.fillcolor(fillcolor)\n",
    "        self.t.setheading(90)\n",
    "        self.t.down()\n",
    "        self.t.begin_fill()\n",
    "        for i in range(4):\n",
    "            self.t.forward(self.unit)\n",
    "            self.t.right(90)\n",
    "        self.t.end_fill()\n",
    "\n",
    "    def move_player(self, x, y):\n",
    "        self.t.up()\n",
    "        self.t.setheading(90)\n",
    "        self.t.fillcolor('red')\n",
    "        self.t.goto((x + 0.5) * self.unit, (y + 0.5) * self.unit)\n",
    "\n",
    "    def render(self):\n",
    "        if self.t is None:\n",
    "            self.t = turtle.Turtle()\n",
    "            self.wn = turtle.Screen()\n",
    "            self.wn.setup(self.unit * self.max_x + 100,\n",
    "                          self.unit * self.max_y + 100)\n",
    "            self.wn.setworldcoordinates(0, 0, self.unit * self.max_x,\n",
    "                                        self.unit * self.max_y)\n",
    "            self.t.shape('circle')\n",
    "            self.t.width(2)\n",
    "            self.t.speed(0)\n",
    "            self.t.color('gray')\n",
    "            for _ in range(2):\n",
    "                self.t.forward(self.max_x * self.unit)\n",
    "                self.t.left(90)\n",
    "                self.t.forward(self.max_y * self.unit)\n",
    "                self.t.left(90)\n",
    "            for i in range(1, self.max_y):\n",
    "                self.draw_x_line(\n",
    "                    y=i * self.unit, x0=0, x1=self.max_x * self.unit)\n",
    "            for i in range(1, self.max_x):\n",
    "                self.draw_y_line(\n",
    "                    x=i * self.unit, y0=0, y1=self.max_y * self.unit)\n",
    "\n",
    "            for i in range(1, self.max_x - 1):\n",
    "                self.draw_box(i, 0, 'black')\n",
    "            self.draw_box(self.max_x - 1, 0, 'yellow')\n",
    "            self.t.shape('turtle')\n",
    "\n",
    "        x_pos = self.s % self.max_x\n",
    "        y_pos = self.max_y - 1 - int(self.s / self.max_x)\n",
    "        self.move_player(x_pos, y_pos)"
   ]
  },
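  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The last lines of `render` convert the flat state index `self.s` into grid coordinates. As a worked example, assuming the standard CliffWalking-v0 layout of 4 rows by 12 columns with the start at state 36 and the goal at state 47 (values not stated in this notebook):\n",
    "$$\n",
    "x = s \\bmod 12, \\qquad y = 4 - 1 - \\lfloor s / 12 \\rfloor\n",
    "$$\n",
    "$$\n",
    "s = 36: \\ (x, y) = (0, 0), \\qquad s = 47: \\ (x, y) = (11, 0)\n",
    "$$\n",
    "Under that assumption the start is drawn in the bottom-left cell and the goal in the bottom-right cell, which matches the yellow box drawn at column `max_x - 1` of row 0; the black boxes between them would mark the cliff."
   ]
  },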
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import gym\n",
    "def env_agent_config(cfg,seed=1):\n",
    "    '''Create the environment and the agent\n",
    "    '''\n",
    "    env = gym.make(cfg.env_name,new_step_api=True)\n",
    "    env = CliffWalkingWapper(env)\n",
    "    n_states = env.observation_space.n # number of states\n",
    "    n_actions = env.action_space.n # number of actions\n",
    "    agent = QLearning(n_states,n_actions,cfg)\n",
    "    return env,agent"
   ]
  },
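  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. Set the parameters\n",
    "\n",
    "At this point all the Q-learning modules are complete. Next we set some parameters so that everyone can do some \"alchemy\" (hyperparameter tuning); the defaults are values the author has already tuned. We also define a plotting function to show how the reward changes."
   ]
  },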
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "import datetime\n",
    "import argparse\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "class Config:\n",
    "    '''Configuration parameters\n",
    "    '''\n",
    "    def __init__(self):\n",
    "        self.env_name = 'CliffWalking-v0' # environment name\n",
    "        self.algo_name = 'Q-Learning' # algorithm name\n",
    "        self.train_eps = 400 # number of training episodes\n",
    "        self.test_eps = 20 # number of test episodes\n",
    "        self.max_steps = 200 # maximum steps per episode\n",
    "        self.epsilon_start = 0.95 # initial epsilon of the e-greedy policy\n",
    "        self.epsilon_end = 0.01 # final epsilon of the e-greedy policy\n",
    "        self.epsilon_decay = 300 # decay rate of epsilon in the e-greedy policy (time constant, in sampled actions)\n",
    "        self.gamma = 0.9 # discount factor\n",
    "        self.lr = 0.1 # learning rate\n",
    "        self.seed = 1 # random seed\n",
    "\n",
    "def smooth(data, weight=0.9):\n",
    "    '''Smooth a curve with an exponential moving average\n",
    "    '''\n",
    "    last = data[0] # First value in the plot (first timestep)\n",
    "    smoothed = list()\n",
    "    for point in data:\n",
    "        smoothed_val = last * weight + (1 - weight) * point # compute the smoothed value\n",
    "        smoothed.append(smoothed_val)\n",
    "        last = smoothed_val\n",
    "    return smoothed\n",
    "\n",
    "def plot_rewards(rewards,title=\"learning curve\"):\n",
    "    sns.set()\n",
    "    plt.figure() # create a new figure so that several plots can be drawn\n",
    "    plt.title(f\"{title}\")\n",
    "    plt.xlim(0, len(rewards)) # set the x-axis range\n",
    "    plt.xlabel('episodes')\n",
    "    plt.plot(rewards, label='rewards')\n",
    "    plt.plot(smooth(rewards), label='smoothed')\n",
    "    plt.legend()"
   ]
  },
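  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `smooth` function above is an exponential moving average. Written as a recurrence with `weight` $w$ and input points $x_t$ (the numbers below are a made-up illustration, not training data):\n",
    "$$\n",
    "s_0 = x_0, \\qquad s_t = w\\, s_{t-1} + (1 - w)\\, x_t \\quad (t \\ge 1)\n",
    "$$\n",
    "With $w = 0.9$ and $x = (-100, -50, -50)$ this gives $s_0 = -100$, $s_1 = 0.9 \\cdot (-100) + 0.1 \\cdot (-50) = -95$, and $s_2 = 0.9 \\cdot (-95) + 0.1 \\cdot (-50) = -90.5$, so the smoothed curve follows jumps in the raw reward only gradually."
   ]
  },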
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5. I'm ready!\n",
    "\n",
    "Now we really can shout \"I'm ready!\" like SpongeBob. Follow the comments and see the results."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/curiousx/miniconda3/lib/python3.11/site-packages/gym/core.py:317: DeprecationWarning: \u001b[33mWARN: Initializing wrapper in old step API which returns one bool instead of two. It is recommended to set `new_step_api=True` to use new step API. This will be the default behaviour in future.\u001b[0m\n",
      "  deprecation(\n",
      "/Users/curiousx/miniconda3/lib/python3.11/site-packages/gym/utils/passive_env_checker.py:241: DeprecationWarning: `np.bool8` is a deprecated alias for `np.bool_`.  (Deprecated NumPy 1.24)\n",
      "  if not isinstance(terminated, (bool, np.bool8)):\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Start training!\n",
      "Environment: CliffWalking-v0, algorithm: Q-Learning\n",
      "Episode: 20/400, reward: -151.0, Epsilon: 0.010\n",
      "Episode: 40/400, reward: -167.0, Epsilon: 0.010\n",
      "Episode: 60/400, reward: -61.0, Epsilon: 0.010\n",
      "Episode: 80/400, reward: -39.0, Epsilon: 0.010\n",
      "Episode: 100/400, reward: -52.0, Epsilon: 0.010\n",
      "Episode: 120/400, reward: -18.0, Epsilon: 0.010\n",
      "Episode: 140/400, reward: -36.0, Epsilon: 0.010\n",
      "Episode: 160/400, reward: -14.0, Epsilon: 0.010\n",
      "Episode: 180/400, reward: -50.0, Epsilon: 0.010\n",
      "Episode: 200/400, reward: -27.0, Epsilon: 0.010\n",
      "Episode: 220/400, reward: -35.0, Epsilon: 0.010\n",
      "Episode: 240/400, reward: -14.0, Epsilon: 0.010\n",
      "Episode: 260/400, reward: -22.0, Epsilon: 0.010\n",
      "Episode: 280/400, reward: -17.0, Epsilon: 0.010\n",
      "Episode: 300/400, reward: -13.0, Epsilon: 0.010\n",
      "Episode: 320/400, reward: -13.0, Epsilon: 0.010\n",
      "Episode: 340/400, reward: -13.0, Epsilon: 0.010\n",
      "Episode: 360/400, reward: -13.0, Epsilon: 0.010\n",
      "Episode: 380/400, reward: -18.0, Epsilon: 0.010\n",
      "Episode: 400/400, reward: -13.0, Epsilon: 0.010\n",
      "Training finished!\n",
      "Start testing!\n",
      "Environment: CliffWalking-v0, algorithm: Q-Learning\n",
      "Episode: 1/20, reward: -13.0\n",
      "Episode: 2/20, reward: -13.0\n",
      "Episode: 3/20, reward: -13.0\n",
      "Episode: 4/20, reward: -13.0\n",
      "Episode: 5/20, reward: -13.0\n",
      "Episode: 6/20, reward: -13.0\n",
      "Episode: 7/20, reward: -13.0\n",
      "Episode: 8/20, reward: -13.0\n",
      "Episode: 9/20, reward: -13.0\n",
      "Episode: 10/20, reward: -13.0\n",
      "Episode: 11/20, reward: -13.0\n",
      "Episode: 12/20, reward: -13.0\n",
      "Episode: 13/20, reward: -13.0\n",
      "Episode: 14/20, reward: -13.0\n",
      "Episode: 15/20, reward: -13.0\n",
      "Episode: 16/20, reward: -13.0\n",
      "Episode: 17/20, reward: -13.0\n",
      "Episode: 18/20, reward: -13.0\n",
      "Episode: 19/20, reward: -13.0\n",
      "Episode: 20/20, reward: -13.0\n",
      "Testing finished!\n"
     ]
    },
    {
     "data": {
      "image/png": "",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "image/png": "",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# get the parameters\n",
    "cfg = Config()\n",
    "# train\n",
    "env, agent = env_agent_config(cfg)\n",
    "res_dic = train(cfg, env, agent)\n",
    "\n",
    "plot_rewards(res_dic['rewards'], title=f\"training curve of {cfg.algo_name} for {cfg.env_name}\")\n",
    "# test\n",
    "res_dic = test(cfg, env, agent)\n",
    "\n",
    "plot_rewards(res_dic['rewards'], title=f\"testing curve of {cfg.algo_name} for {cfg.env_name}\") # plot the results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.9.13 ('gsc': conda)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.7"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "216dab6b21526179d387c06b08cb2654f2959273fc1353fb08296303e34d0db1"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}