hot update Double DQN

This commit is contained in:
johnjim0816
2022-08-30 16:29:57 +08:00
parent 0b0f7e857d
commit 764ba63d40
26 changed files with 803 additions and 365 deletions

View File

@@ -5,7 +5,7 @@
@Email: johnjim0816@gmail.com
@Date: 2020-06-12 00:50:49
@LastEditor: John
LastEditTime: 2022-08-23 23:59:54
LastEditTime: 2022-08-29 23:30:08
@Discription:
@Environment: python 3.7.7
'''
@@ -78,7 +78,7 @@ class DQN:
self.batch_size)
state_batch = torch.tensor(np.array(state_batch), device=self.device, dtype=torch.float) # shape(batchsize,n_states)
action_batch = torch.tensor(action_batch, device=self.device).unsqueeze(1) # shape(batchsize,1)
reward_batch = torch.tensor(reward_batch, device=self.device, dtype=torch.float).unsqueeze(1) # shape(batchsize)
reward_batch = torch.tensor(reward_batch, device=self.device, dtype=torch.float).unsqueeze(1) # shape(batchsize,1)
next_state_batch = torch.tensor(np.array(next_state_batch), device=self.device, dtype=torch.float) # shape(batchsize,n_states)
done_batch = torch.tensor(np.float32(done_batch), device=self.device).unsqueeze(1) # shape(batchsize,1)
# print(state_batch.shape,action_batch.shape,reward_batch.shape,next_state_batch.shape,done_batch.shape)
@@ -91,7 +91,7 @@ class DQN:
# compute expected q value, for terminal state, done_batch[0]=1, and expected_q_value=rewardcorrespondingly
expected_q_value_batch = reward_batch + self.gamma * next_max_q_value_batch* (1-done_batch)
# print(expected_q_value_batch.shape,expected_q_value_batch.requires_grad)
loss = nn.MSELoss()(q_value_batch, expected_q_value_batch) # shape same to
loss = nn.MSELoss()(q_value_batch, expected_q_value_batch) # shape same to
# backpropagation
self.optimizer.zero_grad()
loss.backward()