From ab3f84f1673e93991ce21f65a9ce3cead2002054 Mon Sep 17 00:00:00 2001
From: johnjim0816 <39483938+johnjim0816@users.noreply.github.com>
Date: Sat, 14 Jan 2023 15:24:43 +0800
Subject: [PATCH] Fix DoubleDQN
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 notebooks/DoubleDQN.ipynb | 18 ++----------------
 1 file changed, 2 insertions(+), 16 deletions(-)

diff --git a/notebooks/DoubleDQN.ipynb b/notebooks/DoubleDQN.ipynb
index 5a98124..ffae5d2 100644
--- a/notebooks/DoubleDQN.ipynb
+++ b/notebooks/DoubleDQN.ipynb
@@ -165,7 +165,7 @@
     "        # convert the batch data to tensors\n",
     "        state_batch = torch.tensor(np.array(state_batch), device=self.device, dtype=torch.float)\n",
     "        action_batch = torch.tensor(action_batch, device=self.device).unsqueeze(1)\n",
-    "        reward_batch = torch.tensor(reward_batch, device=self.device, dtype=torch.float)\n",
+    "        reward_batch = torch.tensor(reward_batch, device=self.device, dtype=torch.float).unsqueeze(1)\n",
     "        next_state_batch = torch.tensor(np.array(next_state_batch), device=self.device, dtype=torch.float)\n",
     "        done_batch = torch.tensor(np.float32(done_batch), device=self.device).unsqueeze(1)\n",
     "        q_value_batch = self.policy_net(state_batch).gather(dim=1, index=action_batch) # actual Q-values\n",
@@ -418,21 +418,7 @@
     "output_type": "stream",
     "text": [
      "Training started!\n",
-     "Policy update started!\n"
-    ]
-   },
-   {
-    "name": "stderr",
-    "output_type": "stream",
-    "text": [
-     "c:\\Users\\24438\\anaconda3\\envs\\easyrl\\lib\\site-packages\\torch\\nn\\modules\\loss.py:520: UserWarning: Using a target size (torch.Size([64, 64])) that is different to the input size (torch.Size([64, 1])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.\n",
-     "  return F.mse_loss(input, target, reduction=self.reduction)\n"
-    ]
-   },
-   {
-    "name": "stdout",
-    "output_type": "stream",
-    "text": [
+     "Policy update started!\n",
      "Episode: 10/100, Reward: 24.00, Epsilon: 0.663\n",
      "Episode: 20/100, Reward: 10.00, Epsilon: 0.508\n",
      "Episode: 30/100, Reward: 10.00, Epsilon: 0.395\n",
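
Why the one-line fix works: in the notebook's update step, reward_batch was left as a 1-D tensor of shape [64], while the gathered Q-values and done_batch both have shape [64, 1]. PyTorch broadcasting then expands the TD target to [64, 64], which is exactly what the removed UserWarning reported and what silently corrupts the MSE loss. Below is a minimal sketch of the pitfall under stated assumptions: the batch size 64 is taken from the warning message, and gamma and the variable names are illustrative placeholders, not the notebook's own.

    # Sketch of the shape bug this patch fixes. Assumptions: batch size 64
    # (from the UserWarning) and illustrative names (gamma, reward_1d, ...).
    import torch
    import torch.nn.functional as F

    batch_size, gamma = 64, 0.99

    q_values  = torch.randn(batch_size, 1)   # gathered Q-values:      [64, 1]
    next_q    = torch.randn(batch_size, 1)   # target-net estimate:    [64, 1]
    done      = torch.zeros(batch_size, 1)   # done flags, unsqueezed: [64, 1]
    reward_1d = torch.randn(batch_size)      # BUG: rewards stay 1-D:  [64]

    # Broadcasting [64] against [64, 1] silently yields a [64, 64] target,
    # which is what triggered the UserWarning in the notebook output.
    bad_target = reward_1d + gamma * next_q * (1 - done)
    print(bad_target.shape)                  # torch.Size([64, 64])

    # The patch's fix: unsqueeze the rewards to [64, 1] before mixing shapes.
    reward_2d   = reward_1d.unsqueeze(1)
    good_target = reward_2d + gamma * next_q * (1 - done)
    print(good_target.shape)                 # torch.Size([64, 1])

    loss = F.mse_loss(q_values, good_target) # shapes match; no warning

With the rewards unsqueezed, every term in the TD target is [64, 1], so the loss compares element for element instead of broadcasting, and the stderr block deleted from the notebook output no longer appears.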