645 lines
125 KiB
Plaintext
645 lines
125 KiB
Plaintext
{
|
||
"cells": [
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"## 1. 定义算法\n",
|
||
"\n",
|
||
"DQN 经验回放的算法中的均匀采样算法,可能会忘记一些重要的、以后使用的经验数据。针对这样的问题,PER_DQN 提出了优先级经验回放(prioritized experience reolay)的技术来解决,这种方法应用到 DQN 获得了更好的效果。PER_DQN 成功的原因有:1. 提出了sum tree这样复杂度为O(logn)的高效数据结构。 2. 正确估计了 weighted importance sampling. "
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"### 1.1、 定义模型\n",
|
||
"\n",
|
||
"这里的 PER_DQN 的模型和 DQN 中类似,也是用的三层的MLP。"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 31,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"import torch.nn as nn\n",
|
||
"import torch.nn.functional as F\n",
|
||
"class MLP(nn.Module):\n",
|
||
" def __init__(self, n_states,n_actions,hidden_dim=128):\n",
|
||
" \"\"\" 初始化q网络,为全连接网络\n",
|
||
" \"\"\"\n",
|
||
" super(MLP, self).__init__()\n",
|
||
" self.fc1 = nn.Linear(n_states, hidden_dim) # 输入层\n",
|
||
" self.fc2 = nn.Linear(hidden_dim,hidden_dim) # 隐藏层\n",
|
||
" self.fc3 = nn.Linear(hidden_dim, n_actions) # 输出层\n",
|
||
" \n",
|
||
" def forward(self, x):\n",
|
||
" # 各层对应的激活函数\n",
|
||
" x = F.relu(self.fc1(x)) \n",
|
||
" x = F.relu(self.fc2(x))\n",
|
||
" return self.fc3(x)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"### 1.2、定义经验回放\n",
|
||
"\n",
|
||
"这里的经验回放是和DQN中最大的不同。它使用了sum tree的数据结构,它是一种特殊的二叉树,其父亲节点的值等于子节点的和。节点上的值,定义为每个样本的优先度,这里就用TDerror来衡量。叶子上的数值就是样本优先度。\n",
|
||
"\n",
|
||
"sum tree 采样过程:根据根节点的priority和采样样本数,划分采样的区间,然后在这些区间中均应采样得到所要选取的样本的优先度。从根节点开始,逐层将样本的优先度和节点的优先度进行对比,最终可以得到所要采样的叶子样本。"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 32,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"import numpy as np\n",
|
||
"import random\n",
|
||
"\n",
|
||
"class SumTree:\n",
|
||
" '''SumTree for the per(Prioritized Experience Replay) DQN. \n",
|
||
" This SumTree code is a modified version and the original code is from:\n",
|
||
" https://github.com/MorvanZhou/Reinforcement-learning-with-tensorflow/blob/master/contents/5.2_Prioritized_Replay_DQN/RL_brain.py\n",
|
||
" '''\n",
|
||
" def __init__(self, capacity: int):\n",
|
||
" self.capacity = capacity\n",
|
||
" self.data_pointer = 0\n",
|
||
" self.n_entries = 0\n",
|
||
" self.tree = np.zeros(2 * capacity - 1)\n",
|
||
" self.data = np.zeros(capacity, dtype = object)\n",
|
||
"\n",
|
||
" def update(self, tree_idx, p):\n",
|
||
" '''Update the sampling weight\n",
|
||
" '''\n",
|
||
" change = p - self.tree[tree_idx]\n",
|
||
" self.tree[tree_idx] = p\n",
|
||
"\n",
|
||
" while tree_idx != 0:\n",
|
||
" tree_idx = (tree_idx - 1) // 2\n",
|
||
" self.tree[tree_idx] += change\n",
|
||
"\n",
|
||
" def add(self, p, data):\n",
|
||
" '''Adding new data to the sumTree\n",
|
||
" '''\n",
|
||
" tree_idx = self.data_pointer + self.capacity - 1\n",
|
||
" self.data[self.data_pointer] = data\n",
|
||
" # print (\"tree_idx=\", tree_idx)\n",
|
||
" # print (\"nonzero = \", np.count_nonzero(self.tree))\n",
|
||
" self.update(tree_idx, p)\n",
|
||
"\n",
|
||
" self.data_pointer += 1\n",
|
||
" if self.data_pointer >= self.capacity:\n",
|
||
" self.data_pointer = 0\n",
|
||
"\n",
|
||
" if self.n_entries < self.capacity:\n",
|
||
" self.n_entries += 1\n",
|
||
"\n",
|
||
" def get_leaf(self, v):\n",
|
||
" '''Sampling the data\n",
|
||
" '''\n",
|
||
" parent_idx = 0\n",
|
||
" while True:\n",
|
||
" cl_idx = 2 * parent_idx + 1\n",
|
||
" cr_idx = cl_idx + 1\n",
|
||
" if cl_idx >= len(self.tree):\n",
|
||
" leaf_idx = parent_idx\n",
|
||
" break\n",
|
||
" else:\n",
|
||
" if v <= self.tree[cl_idx] :\n",
|
||
" parent_idx = cl_idx\n",
|
||
" else:\n",
|
||
" v -= self.tree[cl_idx]\n",
|
||
" parent_idx = cr_idx\n",
|
||
"\n",
|
||
" data_idx = leaf_idx - self.capacity + 1\n",
|
||
" return leaf_idx, self.tree[leaf_idx], self.data[data_idx]\n",
|
||
"\n",
|
||
" def total(self):\n",
|
||
" return int(self.tree[0])\n",
|
||
"\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 33,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"\n",
|
||
"class ReplayTree:\n",
|
||
" '''ReplayTree for the per(Prioritized Experience Replay) DQN. \n",
|
||
" '''\n",
|
||
" def __init__(self, capacity):\n",
|
||
" self.capacity = capacity # the capacity for memory replay\n",
|
||
" self.tree = SumTree(capacity)\n",
|
||
" self.abs_err_upper = 1.\n",
|
||
"\n",
|
||
" ## hyper parameter for calculating the importance sampling weight\n",
|
||
" self.beta_increment_per_sampling = 0.001\n",
|
||
" self.alpha = 0.6\n",
|
||
" self.beta = 0.4\n",
|
||
" self.epsilon = 0.01 \n",
|
||
" self.abs_err_upper = 1.\n",
|
||
"\n",
|
||
" def __len__(self):\n",
|
||
" ''' return the num of storage\n",
|
||
" '''\n",
|
||
" return self.tree.total()\n",
|
||
"\n",
|
||
" def push(self, error, sample):\n",
|
||
" '''Push the sample into the replay according to the importance sampling weight\n",
|
||
" '''\n",
|
||
" p = (np.abs(error) + self.epsilon) ** self.alpha\n",
|
||
" self.tree.add(p, sample) \n",
|
||
"\n",
|
||
"\n",
|
||
" def sample(self, batch_size):\n",
|
||
" '''This is for sampling a batch data and the original code is from:\n",
|
||
" https://github.com/rlcode/per/blob/master/prioritized_memory.py\n",
|
||
" '''\n",
|
||
" pri_segment = self.tree.total() / batch_size\n",
|
||
"\n",
|
||
" priorities = []\n",
|
||
" batch = []\n",
|
||
" idxs = []\n",
|
||
"\n",
|
||
" is_weights = []\n",
|
||
"\n",
|
||
" self.beta = np.min([1., self.beta + self.beta_increment_per_sampling])\n",
|
||
" min_prob = np.min(self.tree.tree[-self.tree.capacity:]) / self.tree.total() \n",
|
||
"\n",
|
||
" for i in range(batch_size):\n",
|
||
" a = pri_segment * i\n",
|
||
" b = pri_segment * (i+1)\n",
|
||
"\n",
|
||
" s = random.uniform(a, b)\n",
|
||
" idx, p, data = self.tree.get_leaf(s)\n",
|
||
"\n",
|
||
" priorities.append(p)\n",
|
||
" batch.append(data)\n",
|
||
" idxs.append(idx)\n",
|
||
" prob = p / self.tree.total()\n",
|
||
"\n",
|
||
" sampling_probabilities = np.array(priorities) / self.tree.total()\n",
|
||
" is_weights = np.power(self.tree.n_entries * sampling_probabilities, -self.beta)\n",
|
||
" is_weights /= is_weights.max()\n",
|
||
"\n",
|
||
" return zip(*batch), idxs, is_weights\n",
|
||
" \n",
|
||
" def batch_update(self, tree_idx, abs_errors):\n",
|
||
" '''Update the importance sampling weight\n",
|
||
" '''\n",
|
||
" abs_errors += self.epsilon\n",
|
||
"\n",
|
||
" clipped_errors = np.minimum(abs_errors, self.abs_err_upper)\n",
|
||
" ps = np.power(clipped_errors, self.alpha)\n",
|
||
"\n",
|
||
" for ti, p in zip(tree_idx, ps):\n",
|
||
" self.tree.update(ti, p)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"### 1.3、模型算法定义\n",
|
||
"\n",
|
||
"这里使用三层MLP搭建智能体agent,其中和DQN基本一致,这里不再赘述。"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 34,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"import torch\n",
|
||
"import torch.optim as optim\n",
|
||
"import math\n",
|
||
"import numpy as np\n",
|
||
"\n",
|
||
"\n",
|
||
"class PER_DQN:\n",
|
||
" def __init__(self, model, memory, cfg):\n",
|
||
"\n",
|
||
" self.n_actions = cfg.n_actions \n",
|
||
" self.device = torch.device(cfg.device) \n",
|
||
" self.gamma = cfg.gamma \n",
|
||
" ## e-greedy策略相关参数\n",
|
||
" self.sample_count = 0 # 用于epsilon的衰减计数\n",
|
||
" self.epsilon = cfg.epsilon_start\n",
|
||
" self.epsilon_start = cfg.epsilon_start\n",
|
||
" self.epsilon_end = cfg.epsilon_end\n",
|
||
" self.epsilon_decay = cfg.epsilon_decay\n",
|
||
" self.batch_size = cfg.batch_size\n",
|
||
" self.target_update = cfg.target_update\n",
|
||
"\n",
|
||
" self.policy_net = model.to(self.device)\n",
|
||
" self.target_net = model.to(self.device)\n",
|
||
" ## 复制参数到目标网络\n",
|
||
" for target_param, param in zip(self.target_net.parameters(),self.policy_net.parameters()): \n",
|
||
" target_param.data.copy_(param.data)\n",
|
||
" self.optimizer = optim.Adam(self.policy_net.parameters(), lr=cfg.lr) \n",
|
||
" self.memory = memory # SumTree 经验回放\n",
|
||
" self.update_flag = False \n",
|
||
" \n",
|
||
" def sample_action(self, state):\n",
|
||
" ''' sample action with e-greedy policy\n",
|
||
" '''\n",
|
||
" self.sample_count += 1\n",
|
||
" # epsilon 指数衰减\n",
|
||
" self.epsilon = self.epsilon_end + (self.epsilon_start - self.epsilon_end) * \\\n",
|
||
" math.exp(-1. * self.sample_count / self.epsilon_decay) \n",
|
||
" if random.random() > self.epsilon:\n",
|
||
" with torch.no_grad():\n",
|
||
" state = torch.tensor(state, device=self.device, dtype=torch.float32).unsqueeze(dim=0)\n",
|
||
" q_values = self.policy_net(state)\n",
|
||
" action = q_values.max(1)[1].item() # 根据Q值选择动作\n",
|
||
" else:\n",
|
||
" action = random.randrange(self.n_actions)\n",
|
||
" return action\n",
|
||
"\n",
|
||
" def predict_action(self,state):\n",
|
||
" ''' 预测动作\n",
|
||
" '''\n",
|
||
" with torch.no_grad():\n",
|
||
" state = torch.tensor(state, device=self.device, dtype=torch.float32).unsqueeze(dim=0)\n",
|
||
" q_values = self.policy_net(state)\n",
|
||
" action = q_values.max(1)[1].item() \n",
|
||
" return action\n",
|
||
" def update(self):\n",
|
||
" if len(self.memory) < self.batch_size: # 不满足一个批量时,不更新策略\n",
|
||
" return\n",
|
||
" else:\n",
|
||
" if not self.update_flag:\n",
|
||
" print(\"Begin to update!\")\n",
|
||
" self.update_flag = True\n",
|
||
" # 采样一个batch\n",
|
||
" (state_batch, action_batch, reward_batch, next_state_batch, done_batch), idxs_batch, is_weights_batch = self.memory.sample(\n",
|
||
" self.batch_size)\n",
|
||
" state_batch = torch.tensor(np.array(state_batch), device=self.device, dtype=torch.float) # shape(batchsize,n_states)\n",
|
||
" action_batch = torch.tensor(action_batch, device=self.device).unsqueeze(1) # shape(batchsize,1)\n",
|
||
" reward_batch = torch.tensor(reward_batch, device=self.device, dtype=torch.float).unsqueeze(1) # shape(batchsize,1)\n",
|
||
" next_state_batch = torch.tensor(np.array(next_state_batch), device=self.device, dtype=torch.float) # shape(batchsize,n_states)\n",
|
||
" done_batch = torch.tensor(np.float32(done_batch), device=self.device).unsqueeze(1) # shape(batchsize,1)\n",
|
||
" q_value_batch = self.policy_net(state_batch).gather(dim=1, index=action_batch) # shape(batchsize,1),requires_grad=True\n",
|
||
" next_max_q_value_batch = self.target_net(next_state_batch).max(1)[0].detach().unsqueeze(1) \n",
|
||
" expected_q_value_batch = reward_batch + self.gamma * next_max_q_value_batch* (1-done_batch)\n",
|
||
"\n",
|
||
" # loss中根据优先度进行了加权\n",
|
||
" loss = torch.mean(torch.pow((q_value_batch - expected_q_value_batch) * torch.from_numpy(is_weights_batch).to(self.device), 2))\n",
|
||
"\n",
|
||
" # loss = nn.MSELoss()(q_value_batch, expected_q_value_batch) \n",
|
||
"\n",
|
||
" abs_errors = np.sum(np.abs(q_value_batch.cpu().detach().numpy() - expected_q_value_batch.cpu().detach().numpy()), axis=1)\n",
|
||
" # 需要更新样本的优先度\n",
|
||
" self.memory.batch_update(idxs_batch, abs_errors) \n",
|
||
"\n",
|
||
" # 反向传播\n",
|
||
" self.optimizer.zero_grad() \n",
|
||
" loss.backward()\n",
|
||
" # 梯度截断,防止梯度爆炸\n",
|
||
" for param in self.policy_net.parameters(): \n",
|
||
" param.grad.data.clamp_(-1, 1)\n",
|
||
" self.optimizer.step() \n",
|
||
" if self.sample_count % self.target_update == 0: # 更新 target_net\n",
|
||
" self.target_net.load_state_dict(self.policy_net.state_dict()) \n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"## 2、 定义训练"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 35,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"def train(cfg, env, agent):\n",
|
||
" ''' 训练\n",
|
||
" '''\n",
|
||
" print(\"开始训练!\")\n",
|
||
" rewards = [] # 记录所有回合的奖励\n",
|
||
" steps = []\n",
|
||
" for i_ep in range(cfg.train_eps):\n",
|
||
" ep_reward = 0 # 记录一回合内的奖励\n",
|
||
" ep_step = 0\n",
|
||
" state = env.reset() # 重置环境,返回初始状态\n",
|
||
" for _ in range(cfg.max_steps):\n",
|
||
" ep_step += 1\n",
|
||
" action = agent.sample_action(state) # 选择动作\n",
|
||
" next_state, reward, done, _= env.step(action) # 更新环境,返回transition\n",
|
||
"\n",
|
||
" ## PER DQN 特有的内容\n",
|
||
" policy_val = agent.policy_net(torch.tensor(state, device = cfg.device))[action]\n",
|
||
" target_val = agent.target_net(torch.tensor(next_state, device = cfg.device))\n",
|
||
"\n",
|
||
" if done:\n",
|
||
" error = abs(policy_val - reward)\n",
|
||
" else:\n",
|
||
" error = abs(policy_val - reward - cfg.gamma * torch.max(target_val))\n",
|
||
"\n",
|
||
" agent.memory.push(error.cpu().detach().numpy(), (state, action, reward,\n",
|
||
" next_state, done)) # 保存transition\n",
|
||
" \n",
|
||
" agent.update() # 更新智能体\n",
|
||
" state = next_state # 更新下一个状态\n",
|
||
" ep_reward += reward # 累加奖励\n",
|
||
" if done:\n",
|
||
" break\n",
|
||
" if (i_ep + 1) % cfg.target_update == 0: # 智能体目标网络更新\n",
|
||
" agent.target_net.load_state_dict(agent.policy_net.state_dict())\n",
|
||
" steps.append(ep_step)\n",
|
||
" rewards.append(ep_reward)\n",
|
||
" if (i_ep + 1) % 10 == 0:\n",
|
||
" print(f\"回合:{i_ep+1}/{cfg.train_eps},奖励:{ep_reward:.2f},Epislon:{agent.epsilon:.3f}\")\n",
|
||
" print(\"完成训练!\")\n",
|
||
" env.close()\n",
|
||
" return {'rewards':rewards}\n",
|
||
"\n",
|
||
"def test(cfg, env, agent):\n",
|
||
" print(\"开始测试!\")\n",
|
||
" rewards = [] # 记录所有回合的奖励\n",
|
||
" steps = []\n",
|
||
" for i_ep in range(cfg.test_eps):\n",
|
||
" ep_reward = 0 # 记录一回合内的奖励\n",
|
||
" ep_step = 0\n",
|
||
" state = env.reset() # 重置环境,返回初始状态\n",
|
||
" for _ in range(cfg.max_steps):\n",
|
||
" ep_step+=1\n",
|
||
" action = agent.predict_action(state) # 选择动作\n",
|
||
" next_state, reward, done, _ = env.step(action) # 更新环境,返回transition\n",
|
||
" state = next_state # 更新下一个状态\n",
|
||
" ep_reward += reward # 累加奖励\n",
|
||
" if done:\n",
|
||
" break\n",
|
||
" steps.append(ep_step)\n",
|
||
" rewards.append(ep_reward)\n",
|
||
" print(f\"回合:{i_ep+1}/{cfg.test_eps},奖励:{ep_reward:.2f}\")\n",
|
||
" print(\"完成测试\")\n",
|
||
" env.close()\n",
|
||
" return {'rewards':rewards}"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"## 3. 定义环境"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 36,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"import gym\n",
|
||
"import os\n",
|
||
"def all_seed(env,seed = 1):\n",
|
||
" ''' 万能的seed函数\n",
|
||
" '''\n",
|
||
" env.seed(seed) # env config\n",
|
||
" np.random.seed(seed)\n",
|
||
" random.seed(seed)\n",
|
||
" torch.manual_seed(seed) # config for CPU\n",
|
||
" torch.cuda.manual_seed(seed) # config for GPU\n",
|
||
" os.environ['PYTHONHASHSEED'] = str(seed) # config for python scripts\n",
|
||
" # config for cudnn\n",
|
||
" torch.backends.cudnn.deterministic = True\n",
|
||
" torch.backends.cudnn.benchmark = False\n",
|
||
" torch.backends.cudnn.enabled = False\n",
|
||
"def env_agent_config(cfg):\n",
|
||
" env = gym.make(cfg.env_name) # 创建环境\n",
|
||
" if cfg.seed !=0:\n",
|
||
" all_seed(env,seed=cfg.seed)\n",
|
||
" n_states = env.observation_space.shape[0]\n",
|
||
" n_actions = env.action_space.n\n",
|
||
" print(f\"状态空间维度:{n_states},动作空间维度:{n_actions}\")\n",
|
||
"\n",
|
||
" cfg.n_actions = env.action_space.n ## set the env action space\n",
|
||
" model = MLP(n_states, n_actions, hidden_dim = cfg.hidden_dim) # 创建模型\n",
|
||
" memory = ReplayTree(cfg.buffer_size)\n",
|
||
" agent = PER_DQN(model,memory,cfg)\n",
|
||
" return env,agent"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"## 4、设置参数"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 37,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"import argparse\n",
|
||
"import matplotlib.pyplot as plt\n",
|
||
"import seaborn as sns\n",
|
||
"class Config():\n",
|
||
" def __init__(self) -> None:\n",
|
||
" self.env_name = \"CartPole-v1\" # 环境名字\n",
|
||
" self.new_step_api = True # 是否用gym的新api\n",
|
||
" self.wrapper = None \n",
|
||
" self.render = False \n",
|
||
" self.algo_name = \"PER_DQN\" # 算法名字\n",
|
||
" self.mode = \"train\" # train or test\n",
|
||
" self.seed = 0 # 随机种子\n",
|
||
" self.device = \"cpu\" # device to use\n",
|
||
" self.train_eps = 100 # 训练的回合数\n",
|
||
" self.test_eps = 20 # 测试的回合数\n",
|
||
" self.eval_eps = 10 # 评估的回合数\n",
|
||
" self.eval_per_episode = 5 # 每个回合的评估次数\n",
|
||
" self.max_steps = 200 # 每个回合的最大步数\n",
|
||
" self.load_checkpoint = False\n",
|
||
" self.load_path = \"tasks\" # 加载模型的路径\n",
|
||
" self.show_fig = False # 是否展示图片\n",
|
||
" self.save_fig = True # 是否存储图片\n",
|
||
"\n",
|
||
" # 设置epsilon值\n",
|
||
" self.epsilon_start = 0.95 # 起始的epsilon值\n",
|
||
" self.epsilon_end = 0.01 # 终止的epsilon值\n",
|
||
" self.epsilon_decay = 500 # 衰减率\n",
|
||
" self.hidden_dim = 256 \n",
|
||
" self.gamma = 0.95 \n",
|
||
" self.lr = 0.0001 \n",
|
||
" self.buffer_size = 100000 # 经验回放的buffer大小\n",
|
||
" self.batch_size = 64 # batch size\n",
|
||
" self.target_update = 4 # 目标网络更新频率\n",
|
||
" self.value_layers = [\n",
|
||
" {'layer_type': 'linear', 'layer_dim': ['n_states', 256],\n",
|
||
" 'activation': 'relu'},\n",
|
||
" {'layer_type': 'linear', 'layer_dim': [256, 256],\n",
|
||
" 'activation': 'relu'},\n",
|
||
" {'layer_type': 'linear', 'layer_dim': [256, 'n_actions'],\n",
|
||
" 'activation': 'none'}]\n",
|
||
"\n",
|
||
"def smooth(data, weight=0.9): \n",
|
||
" '''用于平滑曲线,类似于Tensorboard中的smooth曲线\n",
|
||
" '''\n",
|
||
" last = data[0] \n",
|
||
" smoothed = []\n",
|
||
" for point in data:\n",
|
||
" smoothed_val = last * weight + (1 - weight) * point # 计算平滑值\n",
|
||
" smoothed.append(smoothed_val) \n",
|
||
" last = smoothed_val \n",
|
||
" return smoothed\n",
|
||
"\n",
|
||
"def plot_rewards(rewards,cfg, tag='train'):\n",
|
||
" ''' 画图\n",
|
||
" '''\n",
|
||
" sns.set()\n",
|
||
" plt.figure() # 创建一个图形实例,方便同时多画几个图\n",
|
||
" plt.title(f\"{tag}ing curve on {cfg.device} of {cfg.algo_name} for {cfg.env_name}\")\n",
|
||
" plt.xlabel('epsiodes')\n",
|
||
" plt.plot(rewards, label='rewards')\n",
|
||
" plt.plot(smooth(rewards), label='smoothed')\n",
|
||
" plt.legend()\n",
|
||
" plt.show()\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"## 5、开始训练"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 38,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"状态空间维度:4,动作空间维度:2\n",
|
||
"开始训练!\n",
|
||
"Begin to update!\n",
|
||
"回合:10/100,奖励:10.00,Epislon:0.569\n",
|
||
"回合:20/100,奖励:12.00,Epislon:0.443\n",
|
||
"回合:30/100,奖励:10.00,Epislon:0.357\n",
|
||
"回合:40/100,奖励:15.00,Epislon:0.268\n",
|
||
"回合:50/100,奖励:125.00,Epislon:0.103\n",
|
||
"回合:60/100,奖励:67.00,Epislon:0.024\n",
|
||
"回合:70/100,奖励:200.00,Epislon:0.012\n",
|
||
"回合:80/100,奖励:200.00,Epislon:0.010\n",
|
||
"回合:90/100,奖励:200.00,Epislon:0.010\n",
|
||
"回合:100/100,奖励:200.00,Epislon:0.010\n",
|
||
"完成训练!\n"
|
||
]
|
||
},
|
||
{
|
||
"data": {
|
||
"image/png": "",
|
||
"text/plain": [
|
||
"<Figure size 640x480 with 1 Axes>"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"开始测试!\n",
|
||
"回合:1/20,奖励:200.00\n",
|
||
"回合:2/20,奖励:200.00\n",
|
||
"回合:3/20,奖励:200.00\n",
|
||
"回合:4/20,奖励:200.00\n",
|
||
"回合:5/20,奖励:200.00\n",
|
||
"回合:6/20,奖励:200.00\n",
|
||
"回合:7/20,奖励:200.00\n",
|
||
"回合:8/20,奖励:200.00\n",
|
||
"回合:9/20,奖励:200.00\n",
|
||
"回合:10/20,奖励:200.00\n",
|
||
"回合:11/20,奖励:200.00\n",
|
||
"回合:12/20,奖励:200.00\n",
|
||
"回合:13/20,奖励:200.00\n",
|
||
"回合:14/20,奖励:200.00\n",
|
||
"回合:15/20,奖励:200.00\n",
|
||
"回合:16/20,奖励:200.00\n",
|
||
"回合:17/20,奖励:200.00\n",
|
||
"回合:18/20,奖励:200.00\n",
|
||
"回合:19/20,奖励:200.00\n",
|
||
"回合:20/20,奖励:200.00\n",
|
||
"完成测试\n"
|
||
]
|
||
},
|
||
{
|
||
"data": {
|
||
"image/png": "",
|
||
"text/plain": [
|
||
"<Figure size 640x480 with 1 Axes>"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
}
|
||
],
|
||
"source": [
|
||
"# 获取参数\n",
|
||
"cfg = Config() \n",
|
||
"# 训练\n",
|
||
"env, agent = env_agent_config(cfg)\n",
|
||
"res_dic = train(cfg, env, agent)\n",
|
||
" \n",
|
||
"plot_rewards(res_dic['rewards'], cfg, tag=\"train\") \n",
|
||
"# 测试\n",
|
||
"res_dic = test(cfg, env, agent)\n",
|
||
"plot_rewards(res_dic['rewards'], cfg, tag=\"test\") # 画出结果"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": []
|
||
}
|
||
],
|
||
"metadata": {
|
||
"kernelspec": {
|
||
"display_name": "Python 3.7.13 ('joyrl')",
|
||
"language": "python",
|
||
"name": "python3"
|
||
},
|
||
"language_info": {
|
||
"codemirror_mode": {
|
||
"name": "ipython",
|
||
"version": 3
|
||
},
|
||
"file_extension": ".py",
|
||
"mimetype": "text/x-python",
|
||
"name": "python",
|
||
"nbconvert_exporter": "python",
|
||
"pygments_lexer": "ipython3",
|
||
"version": "3.7.13"
|
||
},
|
||
"orig_nbformat": 4,
|
||
"vscode": {
|
||
"interpreter": {
|
||
"hash": "996e2c1bcfa8ebbd3aba48733c28d7658f0aec7cda7e9a0e5abbef50d3f90575"
|
||
}
|
||
}
|
||
},
|
||
"nbformat": 4,
|
||
"nbformat_minor": 2
|
||
}
|