修复DoubleDQN:为 reward_batch 增加 unsqueeze(1),使其形状为 [batch, 1],避免计算 MSE 损失时目标被广播成 [64, 64]
This commit is contained in:
@@ -165,7 +165,7 @@
|
|||||||
" # 将数据转换为tensor\n",
|
" # 将数据转换为tensor\n",
|
||||||
" state_batch = torch.tensor(np.array(state_batch), device=self.device, dtype=torch.float)\n",
|
" state_batch = torch.tensor(np.array(state_batch), device=self.device, dtype=torch.float)\n",
|
||||||
" action_batch = torch.tensor(action_batch, device=self.device).unsqueeze(1) \n",
|
" action_batch = torch.tensor(action_batch, device=self.device).unsqueeze(1) \n",
|
||||||
" reward_batch = torch.tensor(reward_batch, device=self.device, dtype=torch.float) \n",
|
" reward_batch = torch.tensor(reward_batch, device=self.device, dtype=torch.float).unsqueeze(1) \n",
|
||||||
" next_state_batch = torch.tensor(np.array(next_state_batch), device=self.device, dtype=torch.float)\n",
|
" next_state_batch = torch.tensor(np.array(next_state_batch), device=self.device, dtype=torch.float)\n",
|
||||||
" done_batch = torch.tensor(np.float32(done_batch), device=self.device).unsqueeze(1)\n",
|
" done_batch = torch.tensor(np.float32(done_batch), device=self.device).unsqueeze(1)\n",
|
||||||
" q_value_batch = self.policy_net(state_batch).gather(dim=1, index=action_batch) # 实际的Q值\n",
|
" q_value_batch = self.policy_net(state_batch).gather(dim=1, index=action_batch) # 实际的Q值\n",
|
||||||
@@ -418,21 +418,7 @@
|
|||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
"开始训练!\n",
|
"开始训练!\n",
|
||||||
"开始更新策略!\n"
|
"开始更新策略!\n",
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stderr",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"c:\\Users\\24438\\anaconda3\\envs\\easyrl\\lib\\site-packages\\torch\\nn\\modules\\loss.py:520: UserWarning: Using a target size (torch.Size([64, 64])) that is different to the input size (torch.Size([64, 1])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.\n",
|
|
||||||
" return F.mse_loss(input, target, reduction=self.reduction)\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"回合:10/100,奖励:24.00,Epislon:0.663\n",
|
"回合:10/100,奖励:24.00,Epislon:0.663\n",
|
||||||
"回合:20/100,奖励:10.00,Epislon:0.508\n",
|
"回合:20/100,奖励:10.00,Epislon:0.508\n",
|
||||||
"回合:30/100,奖励:10.00,Epislon:0.395\n",
|
"回合:30/100,奖励:10.00,Epislon:0.395\n",
|
||||||
|
|||||||
Reference in New Issue
Block a user