593 lines
97 KiB
Plaintext
593 lines
97 KiB
Plaintext
{
|
||
"cells": [
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 86,
|
||
"id": "6218efae",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"import torch\n",
|
||
"import torch.nn as nn\n",
|
||
"import torch.nn.functional as F\n",
|
||
"from torch.optim import Adam\n",
|
||
"from torch.distributions import Normal\n",
|
||
"import random\n",
|
||
"import numpy as np\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "8a207689",
|
||
"metadata": {},
|
||
"source": [
|
||
"## 1.定义算法\n",
|
||
"### 1.1 建立Q网络和策略网络"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 75,
|
||
"id": "5955151d",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"class ValueNet(nn.Module):\n",
|
||
" def __init__(self, n_states, hidden_dim, init_w=3e-3):\n",
|
||
" super(ValueNet, self).__init__()\n",
|
||
" '''定义值网络\n",
|
||
" '''\n",
|
||
" self.linear1 = nn.Linear(n_states, hidden_dim) # 输入层\n",
|
||
" self.linear2 = nn.Linear(hidden_dim, hidden_dim) # 隐藏层\n",
|
||
" self.linear3 = nn.Linear(hidden_dim, 1)\n",
|
||
"\n",
|
||
" self.linear3.weight.data.uniform_(-init_w, init_w) # 初始化权重\n",
|
||
" self.linear3.bias.data.uniform_(-init_w, init_w)\n",
|
||
" \n",
|
||
" def forward(self, state):\n",
|
||
" x = F.relu(self.linear1(state))\n",
|
||
" x = F.relu(self.linear2(x))\n",
|
||
" x = self.linear3(x)\n",
|
||
" return x\n",
|
||
" \n",
|
||
" \n",
|
||
"class SoftQNet(nn.Module):\n",
|
||
" def __init__(self, n_states, n_actions, hidden_dim, init_w=3e-3):\n",
|
||
" super(SoftQNet, self).__init__()\n",
|
||
" '''定义Q网络,n_states, n_actions, hidden_dim, init_w分别为状态维度、动作维度隐藏层维度和初始化权重\n",
|
||
" '''\n",
|
||
" self.linear1 = nn.Linear(n_states + n_actions, hidden_dim)\n",
|
||
" self.linear2 = nn.Linear(hidden_dim, hidden_dim)\n",
|
||
" self.linear3 = nn.Linear(hidden_dim, 1)\n",
|
||
" \n",
|
||
" self.linear3.weight.data.uniform_(-init_w, init_w)\n",
|
||
" self.linear3.bias.data.uniform_(-init_w, init_w)\n",
|
||
" \n",
|
||
" def forward(self, state, action):\n",
|
||
" x = torch.cat([state, action], 1)\n",
|
||
" x = F.relu(self.linear1(x))\n",
|
||
" x = F.relu(self.linear2(x))\n",
|
||
" x = self.linear3(x)\n",
|
||
" return x\n",
|
||
" \n",
|
||
" \n",
|
||
"class PolicyNet(nn.Module):\n",
|
||
" def __init__(self, n_states, n_actions, hidden_dim, init_w=3e-3, log_std_min=-20, log_std_max=2):\n",
|
||
" super(PolicyNet, self).__init__()\n",
|
||
" '''定义策略网络,n_states, n_actions, hidden_dim, init_w分别为状态维度、动作维度隐藏层维度和初始化权重\n",
|
||
" log_std_min和log_std_max为标准差对数的最大值和最小值\n",
|
||
" '''\n",
|
||
" self.log_std_min = log_std_min\n",
|
||
" self.log_std_max = log_std_max\n",
|
||
" \n",
|
||
" self.linear1 = nn.Linear(n_states, hidden_dim)\n",
|
||
" self.linear2 = nn.Linear(hidden_dim, hidden_dim)\n",
|
||
" \n",
|
||
" self.mean_linear = nn.Linear(hidden_dim, n_actions)\n",
|
||
" self.mean_linear.weight.data.uniform_(-init_w, init_w)\n",
|
||
" self.mean_linear.bias.data.uniform_(-init_w, init_w)\n",
|
||
" \n",
|
||
" self.log_std_linear = nn.Linear(hidden_dim, n_actions)\n",
|
||
" self.log_std_linear.weight.data.uniform_(-init_w, init_w)\n",
|
||
" self.log_std_linear.bias.data.uniform_(-init_w, init_w)\n",
|
||
" \n",
|
||
" def forward(self, state):\n",
|
||
" x = F.relu(self.linear1(state))\n",
|
||
" x = F.relu(self.linear2(x))\n",
|
||
" \n",
|
||
" mean = self.mean_linear(x)\n",
|
||
" log_std = self.log_std_linear(x)\n",
|
||
" log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max)\n",
|
||
" \n",
|
||
" return mean, log_std\n",
|
||
" \n",
|
||
" def evaluate(self, state, epsilon=1e-6):\n",
|
||
" mean, log_std = self.forward(state)\n",
|
||
" std = log_std.exp()\n",
|
||
" ## 计算动作\n",
|
||
" normal = Normal(mean, std)\n",
|
||
" z = normal.sample()\n",
|
||
" action = torch.tanh(z)\n",
|
||
" ## 计算动作概率\n",
|
||
" log_prob = normal.log_prob(z) - torch.log(1 - action.pow(2) + epsilon)\n",
|
||
" log_prob = log_prob.sum(-1, keepdim=True)\n",
|
||
" \n",
|
||
" return action, log_prob, z, mean, log_std\n",
|
||
" \n",
|
||
" \n",
|
||
" def get_action(self, state):\n",
|
||
" state = torch.FloatTensor(state).unsqueeze(0)\n",
|
||
" mean, log_std = self.forward(state)\n",
|
||
" std = log_std.exp()\n",
|
||
" \n",
|
||
" normal = Normal(mean, std)\n",
|
||
" z = normal.sample()\n",
|
||
" action = torch.tanh(z)\n",
|
||
" \n",
|
||
" action = action.detach().cpu().numpy()\n",
|
||
" return action[0]"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "debce530",
|
||
"metadata": {},
|
||
"source": [
|
||
"### 1.2 定义经验回放池"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 76,
|
||
"id": "1c740ca0",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"class ReplayBuffer:\n",
|
||
" def __init__(self, capacity):\n",
|
||
" self.capacity = capacity # 经验回放的容量\n",
|
||
" self.buffer = [] # 缓冲区\n",
|
||
" self.position = 0 \n",
|
||
" \n",
|
||
" def push(self, state, action, reward, next_state, done):\n",
|
||
" ''' 缓冲区是一个队列,容量超出时去掉开始存入的转移(transition)\n",
|
||
" '''\n",
|
||
" if len(self.buffer) < self.capacity:\n",
|
||
" self.buffer.append(None)\n",
|
||
" self.buffer[self.position] = (state, action, reward, next_state, done)\n",
|
||
" self.position = (self.position + 1) % self.capacity \n",
|
||
" \n",
|
||
" def sample(self, batch_size):\n",
|
||
" batch = random.sample(self.buffer, batch_size) # 随机采出小批量转移\n",
|
||
" state, action, reward, next_state, done = zip(*batch) # 解压成状态,动作等\n",
|
||
" return state, action, reward, next_state, done\n",
|
||
" \n",
|
||
" def __len__(self):\n",
|
||
" ''' 返回当前存储的量\n",
|
||
" '''\n",
|
||
" return len(self.buffer)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "32a65b71",
|
||
"metadata": {},
|
||
"source": [
|
||
"### 1.3 SAC算法"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 77,
|
||
"id": "5a86f725",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"class SAC:\n",
|
||
" def __init__(self,cfg) -> None:\n",
|
||
" self.n_states = cfg.n_states\n",
|
||
" self.n_actions = cfg.n_actions\n",
|
||
" self.batch_size = cfg.batch_size \n",
|
||
" self.memory = ReplayBuffer(cfg.capacity)\n",
|
||
" self.device = cfg.device\n",
|
||
" self.action_space = cfg.action_space\n",
|
||
" self.value_net = ValueNet(self.n_states, cfg.hidden_dim).to(self.device)\n",
|
||
" self.target_value_net = ValueNet(self.n_states, cfg.hidden_dim).to(self.device)\n",
|
||
" self.soft_q_net = SoftQNet(self.n_states, self.n_actions, cfg.hidden_dim).to(self.device)\n",
|
||
" self.policy_net = PolicyNet(self.n_states, self.n_actions, cfg.hidden_dim).to(self.device) \n",
|
||
" self.value_optimizer = Adam(self.value_net.parameters(), lr=cfg.value_lr)\n",
|
||
" self.soft_q_optimizer = Adam(self.soft_q_net.parameters(), lr=cfg.soft_q_lr)\n",
|
||
" self.policy_optimizer = Adam(self.policy_net.parameters(), lr=cfg.policy_lr) \n",
|
||
" for target_param, param in zip(self.target_value_net.parameters(), self.value_net.parameters()):\n",
|
||
" target_param.data.copy_(param.data)\n",
|
||
" self.value_criterion = nn.MSELoss()\n",
|
||
" self.soft_q_criterion = nn.MSELoss()\n",
|
||
" def update(self, gamma=0.99,mean_lambda=1e-3,\n",
|
||
" std_lambda=1e-3,\n",
|
||
" z_lambda=0.0,\n",
|
||
" soft_tau=1e-2,\n",
|
||
" ):\n",
|
||
" if len(self.memory) < self.batch_size: # 当经验回放中不满足一个批量时,不更新策略\n",
|
||
" return \n",
|
||
" state, action, reward, next_state, done = self.memory.sample(self.batch_size) # 从经验回放中随机采样一个批量的转移(transition)\n",
|
||
" # 将数据转换为tensor\n",
|
||
" state = torch.FloatTensor(state).to(self.device)\n",
|
||
" next_state = torch.FloatTensor(next_state).to(self.device)\n",
|
||
" action = torch.FloatTensor(action).to(self.device)\n",
|
||
" reward = torch.FloatTensor(reward).unsqueeze(1).to(self.device)\n",
|
||
" done = torch.FloatTensor(np.float32(done)).unsqueeze(1).to(self.device)\n",
|
||
" \n",
|
||
" expected_q_value = self.soft_q_net(state, action) #计算t时刻的状态-动作Q值\n",
|
||
" expected_value = self.value_net(state) #计算t时刻的状态值\n",
|
||
" new_action, log_prob, z, mean, log_std = self.policy_net.evaluate(state) #计算t时刻的动作、动作似然概率、正态分布抽样、分布均值和标准差\n",
|
||
"\n",
|
||
"\n",
|
||
" target_value = self.target_value_net(next_state) #计算t+1时刻的状态值\n",
|
||
" next_q_value = reward + (1 - done) * gamma * target_value # 时序差分计算t+1时刻的Q值\n",
|
||
" # JQ = 𝔼(st,at)~D[0.5(Q1(st,at) - r(st,at) - γ(𝔼st+1~p[V(st+1)]))^2]\n",
|
||
" q_value_loss = self.soft_q_criterion(expected_q_value, next_q_value.detach()) #计算q网路的损失函数\n",
|
||
"\n",
|
||
" expected_new_q_value = self.soft_q_net(state, new_action) #计算t时刻动作对应的q值\n",
|
||
" next_value = expected_new_q_value - log_prob # 计算t时刻的v值\n",
|
||
" value_loss = self.value_criterion(expected_value, next_value.detach()) #计算值网络损失函数\n",
|
||
" \n",
|
||
" ## 计算策略损失\n",
|
||
" log_prob_target = expected_new_q_value - expected_value \n",
|
||
" # Jπ = 𝔼st∼D,εt∼N[α * logπ(f(εt;st)|st) − Q(st,f(εt;st))]\n",
|
||
" policy_loss = (log_prob * (log_prob - log_prob_target).detach()).mean()\n",
|
||
"\n",
|
||
" ## 计算reparameterization参数损失\n",
|
||
" mean_loss = mean_lambda * mean.pow(2).mean()\n",
|
||
" std_loss = std_lambda * log_std.pow(2).mean()\n",
|
||
" z_loss = z_lambda * z.pow(2).sum(1).mean()\n",
|
||
"\n",
|
||
" policy_loss += mean_loss + std_loss + z_loss\n",
|
||
"\n",
|
||
" self.soft_q_optimizer.zero_grad()\n",
|
||
" q_value_loss.backward()\n",
|
||
" self.soft_q_optimizer.step()\n",
|
||
"\n",
|
||
" self.value_optimizer.zero_grad()\n",
|
||
" value_loss.backward()\n",
|
||
" self.value_optimizer.step()\n",
|
||
"\n",
|
||
" self.policy_optimizer.zero_grad()\n",
|
||
" policy_loss.backward()\n",
|
||
" self.policy_optimizer.step()\n",
|
||
" ## 更新目标值网络参数\n",
|
||
" for target_param, param in zip(self.target_value_net.parameters(), self.value_net.parameters()):\n",
|
||
" target_param.data.copy_(\n",
|
||
" target_param.data * (1.0 - soft_tau) + param.data * soft_tau\n",
|
||
" )\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "e2581771",
|
||
"metadata": {},
|
||
"source": [
|
||
"## 2.模型训练与测试"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 78,
|
||
"id": "0a3e3413",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"def train(cfg, env, agent):\n",
|
||
" print(\"开始训练!\")\n",
|
||
" rewards = [] # 记录所有回合的奖励\n",
|
||
" for i_ep in range(cfg.train_eps):\n",
|
||
" state = env.reset() # 重置环境,返回初始状态\n",
|
||
" ep_reward = 0 # 记录一回合内的奖励\n",
|
||
" for i_step in range(cfg.max_steps):\n",
|
||
" action = agent.policy_net.get_action(state) # 抽样动作\n",
|
||
" next_state, reward, terminated, info = env.step(action) # 更新环境,返回transitions\n",
|
||
" agent.memory.push(state, action, reward,next_state, terminated) # 保存transition\n",
|
||
" agent.update() # 更新智能体\n",
|
||
" state = next_state # 更新下一个状态\n",
|
||
" ep_reward += reward # 累加奖励\n",
|
||
" if terminated:\n",
|
||
" break\n",
|
||
" if (i_ep+1)%10 == 0:\n",
|
||
" print(f\"回合:{i_ep+1}/{cfg.train_eps},奖励:{ep_reward:.2f}\")\n",
|
||
" rewards.append(ep_reward)\n",
|
||
" print(\"完成训练!\")\n",
|
||
" return {'rewards':rewards}\n",
|
||
"def test(cfg, env, agent):\n",
|
||
" print(\"开始测试!\")\n",
|
||
" rewards = [] # 记录所有回合的奖励\n",
|
||
" for i_ep in range(cfg.test_eps):\n",
|
||
" state = env.reset() # 重置环境,返回初始状态\n",
|
||
" ep_reward = 0\n",
|
||
" for i_step in range(cfg.max_steps):\n",
|
||
" action = agent.policy_net.get_action(state) # 抽样动作\n",
|
||
" next_state, reward, terminated, info = env.step(action) # 更新环境,返回transitions\n",
|
||
" state = next_state # 更新下一个状态\n",
|
||
" ep_reward += reward # 累加奖励\n",
|
||
" if terminated:\n",
|
||
" break\n",
|
||
" rewards.append(ep_reward)\n",
|
||
" print(f\"回合:{i_ep+1}/{cfg.test_eps},奖励:{ep_reward:.2f}\")\n",
|
||
" print(\"完成测试!\")\n",
|
||
" return {'rewards':rewards}"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "d4d45832",
|
||
"metadata": {},
|
||
"source": [
|
||
"## 3.定义环境"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 79,
|
||
"id": "15b94efa",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"import gym\n",
|
||
"import os\n",
|
||
"import numpy as np\n",
|
||
"\n",
|
||
"class NormalizedActions(gym.ActionWrapper):\n",
|
||
" def action(self, action):\n",
|
||
" low = self.action_space.low\n",
|
||
" high = self.action_space.high\n",
|
||
" \n",
|
||
" action = low + (action + 1.0) * 0.5 * (high - low)\n",
|
||
" action = np.clip(action, low, high)\n",
|
||
" \n",
|
||
" return action\n",
|
||
"\n",
|
||
" def reverse_action(self, action):\n",
|
||
" low = self.action_space.low\n",
|
||
" high = self.action_space.high\n",
|
||
" action = 2 * (action - low) / (high - low) - 1\n",
|
||
" action = np.clip(action, low, high)\n",
|
||
" return action\n",
|
||
" \n",
|
||
"def all_seed(env,seed = 1):\n",
|
||
" ''' 万能的seed函数\n",
|
||
" '''\n",
|
||
" env.seed(seed) # env config\n",
|
||
" np.random.seed(seed)\n",
|
||
" random.seed(seed)\n",
|
||
" torch.manual_seed(seed) # config for CPU\n",
|
||
" torch.cuda.manual_seed(seed) # config for GPU\n",
|
||
" os.environ['PYTHONHASHSEED'] = str(seed) # config for python scripts\n",
|
||
" # config for cudnn\n",
|
||
" torch.backends.cudnn.deterministic = True\n",
|
||
" torch.backends.cudnn.benchmark = False\n",
|
||
" torch.backends.cudnn.enabled = False\n",
|
||
"def env_agent_config(cfg):\n",
|
||
" env = NormalizedActions(gym.make(cfg.env_name)) # 创建环境\n",
|
||
" all_seed(env,seed=cfg.seed)\n",
|
||
" n_states = env.observation_space.shape[0]\n",
|
||
" n_actions = env.action_space.shape[0]\n",
|
||
" print(f\"状态空间维度:{n_states},动作空间维度:{n_actions}\")\n",
|
||
" # 更新n_states和n_actions到cfg参数中\n",
|
||
" setattr(cfg, 'n_states', n_states)\n",
|
||
" setattr(cfg, 'n_actions', n_actions) \n",
|
||
" setattr(cfg, 'action_space', env.action_space) \n",
|
||
" agent = SAC(cfg)\n",
|
||
" return env,agent"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "87423249",
|
||
"metadata": {},
|
||
"source": [
|
||
"## 4.设置参数"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 84,
|
||
"id": "fbd710ef",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"import argparse\n",
|
||
"import matplotlib.pyplot as plt\n",
|
||
"import seaborn as sns\n",
|
||
"class Config:\n",
|
||
" def __init__(self):\n",
|
||
" self.algo_name = 'SAC'\n",
|
||
" self.env_name = 'Pendulum-v1'\n",
|
||
" self.seed = 50 # 随机种子\n",
|
||
" self.train_eps = 400 # 训练迭代次数\n",
|
||
" self.test_eps = 10 # 测试迭代次数\n",
|
||
" self.eval_eps = 10 # 评估迭代次数\n",
|
||
" self.max_steps = 200 # 每次迭代最大时间步\n",
|
||
" self.gamma = 0.99 #折扣因子\n",
|
||
" self.mean_lambda=1e-3 # 重参数化分布均值的损失权重\n",
|
||
" self.std_lambda=1e-3 # 重参数化分布标准差的损失权重\n",
|
||
" self.z_lambda=0.0 # 重参数化分布抽样值的损失权重\n",
|
||
" self.soft_tau=1e-2 # 目标网络软更新系数\n",
|
||
" self.value_lr = 3e-4 # 值网络的学习率\n",
|
||
" self.soft_q_lr = 3e-4 # Q网络的学习率\n",
|
||
" self.policy_lr = 3e-4 # 策略网络的学习率\n",
|
||
" self.capacity = 1000000 # 经验缓存池的大小\n",
|
||
" self.hidden_dim = 256 # 隐藏层维度\n",
|
||
" self.batch_size = 128 # 批次大小\n",
|
||
" self.start_steps = 1000 # 利用前的探索步数\n",
|
||
" self.buffer_size = 1000000 # 经验回放池大小\n",
|
||
" self.device=torch.device(\"cpu\") # 使用设备\n",
|
||
"\n",
|
||
"def smooth(data, weight=0.9): \n",
|
||
" '''用于平滑曲线,类似于Tensorboard中的smooth曲线\n",
|
||
" '''\n",
|
||
" last = data[0] \n",
|
||
" smoothed = []\n",
|
||
" for point in data:\n",
|
||
" smoothed_val = last * weight + (1 - weight) * point # 计算平滑值\n",
|
||
" smoothed.append(smoothed_val) \n",
|
||
" last = smoothed_val \n",
|
||
" return smoothed\n",
|
||
"\n",
|
||
"def plot_rewards(rewards,title=\"learning curve\"):\n",
|
||
" sns.set()\n",
|
||
" plt.figure() # 创建一个图形实例,方便同时多画几个图\n",
|
||
" plt.title(f\"{title}\")\n",
|
||
" plt.xlim(0, len(rewards), 10) # 设置x轴的范围\n",
|
||
" plt.xlabel('epsiodes')\n",
|
||
" plt.plot(rewards, label='rewards')\n",
|
||
" plt.plot(smooth(rewards), label='smoothed')\n",
|
||
" plt.legend()\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "2bc3f1c6",
|
||
"metadata": {},
|
||
"source": [
|
||
"## 5.开始训练"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 85,
|
||
"id": "80de3242",
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"状态空间维度:3,动作空间维度:1\n",
|
||
"开始训练!\n",
|
||
"回合:10/400,奖励:-1165.43\n",
|
||
"回合:20/400,奖励:-994.18\n",
|
||
"回合:30/400,奖励:-634.86\n",
|
||
"回合:40/400,奖励:-883.46\n",
|
||
"回合:50/400,奖励:-745.16\n",
|
||
"回合:60/400,奖励:-617.77\n",
|
||
"回合:70/400,奖励:-0.98\n",
|
||
"回合:80/400,奖励:-386.85\n",
|
||
"回合:90/400,奖励:-245.64\n",
|
||
"回合:100/400,奖励:-118.91\n",
|
||
"回合:110/400,奖励:-250.71\n",
|
||
"回合:120/400,奖励:-245.33\n",
|
||
"回合:130/400,奖励:-125.58\n",
|
||
"回合:140/400,奖励:-124.50\n",
|
||
"回合:150/400,奖励:-121.64\n",
|
||
"回合:160/400,奖励:-244.01\n",
|
||
"回合:170/400,奖励:-121.33\n",
|
||
"回合:180/400,奖励:-239.27\n",
|
||
"回合:190/400,奖励:-127.06\n",
|
||
"回合:200/400,奖励:-122.01\n",
|
||
"回合:210/400,奖励:-126.99\n",
|
||
"回合:220/400,奖励:-348.44\n",
|
||
"回合:230/400,奖励:-116.88\n",
|
||
"回合:240/400,奖励:-124.86\n",
|
||
"回合:250/400,奖励:-121.31\n",
|
||
"回合:260/400,奖励:-3.03\n",
|
||
"回合:270/400,奖励:-125.63\n",
|
||
"回合:280/400,奖励:-244.81\n",
|
||
"回合:290/400,奖励:-123.32\n",
|
||
"回合:300/400,奖励:-119.85\n",
|
||
"回合:310/400,奖励:-121.64\n",
|
||
"回合:320/400,奖励:-4.73\n",
|
||
"回合:330/400,奖励:-127.96\n",
|
||
"回合:340/400,奖励:-119.40\n",
|
||
"回合:350/400,奖励:-244.30\n",
|
||
"回合:360/400,奖励:-121.79\n",
|
||
"回合:370/400,奖励:-244.21\n",
|
||
"回合:380/400,奖励:-123.19\n",
|
||
"回合:390/400,奖励:-341.91\n",
|
||
"回合:400/400,奖励:-117.78\n",
|
||
"完成训练!\n",
|
||
"开始测试!\n",
|
||
"回合:1/10,奖励:-123.43\n",
|
||
"回合:2/10,奖励:-245.39\n",
|
||
"回合:3/10,奖励:-366.64\n",
|
||
"回合:4/10,奖励:-121.86\n",
|
||
"回合:5/10,奖励:-124.73\n",
|
||
"回合:6/10,奖励:-359.53\n",
|
||
"回合:7/10,奖励:-125.78\n",
|
||
"回合:8/10,奖励:-2.40\n",
|
||
"回合:9/10,奖励:-348.00\n",
|
||
"回合:10/10,奖励:-361.15\n",
|
||
"完成测试!\n"
|
||
]
|
||
},
|
||
{
|
||
"data": {
|
||
"image/png": "",
|
||
"text/plain": [
|
||
"<Figure size 432x288 with 1 Axes>"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"image/png": "",
|
||
"text/plain": [
|
||
"<Figure size 432x288 with 1 Axes>"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
}
|
||
],
|
||
"source": [
|
||
"# 获取参数\n",
|
||
"cfg = Config() \n",
|
||
"# 训练\n",
|
||
"env, agent = env_agent_config(cfg)\n",
|
||
"res_dic = train(cfg, env, agent)\n",
|
||
" \n",
|
||
"plot_rewards(res_dic['rewards'], title=f\"training curve on {cfg.device} of {cfg.algo_name} for {cfg.env_name}\") \n",
|
||
"# 测试\n",
|
||
"res_dic = test(cfg, env, agent)\n",
|
||
"plot_rewards(res_dic['rewards'], title=f\"testing curve on {cfg.device} of {cfg.algo_name} for {cfg.env_name}\") # 画出结果"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"id": "46214798",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": []
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"id": "badea21e",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": []
|
||
}
|
||
],
|
||
"metadata": {
|
||
"kernelspec": {
|
||
"display_name": "Python (tensorflowenv)",
|
||
"language": "python",
|
||
"name": "python3"
|
||
},
|
||
"language_info": {
|
||
"codemirror_mode": {
|
||
"name": "ipython",
|
||
"version": 3
|
||
},
|
||
"file_extension": ".py",
|
||
"mimetype": "text/x-python",
|
||
"name": "python",
|
||
"nbconvert_exporter": "python",
|
||
"pygments_lexer": "ipython3",
|
||
"version": "3.6.5"
|
||
}
|
||
},
|
||
"nbformat": 4,
|
||
"nbformat_minor": 5
|
||
}
|