This commit is contained in:
JohnJim0816
2021-05-06 02:07:56 +08:00
parent 747f3238c0
commit b17c8f4e41
107 changed files with 1439 additions and 987 deletions

View File

@@ -5,7 +5,7 @@ Author: John
Email: johnjim0816@gmail.com
Date: 2021-03-24 22:18:18
LastEditor: John
LastEditTime: 2021-03-31 14:51:09
LastEditTime: 2021-05-04 22:39:34
Discription:
Environment:
'''
@@ -65,11 +65,11 @@ class HierarchicalDQN:
if self.batch_size > len(self.memory):
return
state_batch, action_batch, reward_batch, next_state_batch, done_batch = self.memory.sample(self.batch_size)
state_batch = torch.tensor(state_batch,dtype=torch.float)
action_batch = torch.tensor(action_batch,dtype=torch.int64).unsqueeze(1)
reward_batch = torch.tensor(reward_batch,dtype=torch.float)
next_state_batch = torch.tensor(next_state_batch, dtype=torch.float)
done_batch = torch.tensor(np.float32(done_batch))
state_batch = torch.tensor(state_batch,device=self.device,dtype=torch.float)
action_batch = torch.tensor(action_batch,device=self.device,dtype=torch.int64).unsqueeze(1)
reward_batch = torch.tensor(reward_batch,device=self.device,dtype=torch.float)
next_state_batch = torch.tensor(next_state_batch,device=self.device, dtype=torch.float)
done_batch = torch.tensor(np.float32(done_batch),device=self.device)
q_values = self.policy_net(state_batch).gather(dim=1, index=action_batch).squeeze(1)
next_state_values = self.policy_net(next_state_batch).max(1)[0].detach()
expected_q_values = reward_batch + 0.99 * next_state_values * (1-done_batch)
@@ -79,17 +79,17 @@ class HierarchicalDQN:
for param in self.policy_net.parameters(): # clip防止梯度爆炸
param.grad.data.clamp_(-1, 1)
self.optimizer.step()
self.loss_numpy = loss.detach().numpy()
self.loss_numpy = loss.detach().cpu().numpy()
self.losses.append(self.loss_numpy)
def update_meta(self):
if self.batch_size > len(self.meta_memory):
return
state_batch, action_batch, reward_batch, next_state_batch, done_batch = self.meta_memory.sample(self.batch_size)
state_batch = torch.tensor(state_batch,dtype=torch.float)
action_batch = torch.tensor(action_batch,dtype=torch.int64).unsqueeze(1)
reward_batch = torch.tensor(reward_batch,dtype=torch.float)
next_state_batch = torch.tensor(next_state_batch, dtype=torch.float)
done_batch = torch.tensor(np.float32(done_batch))
state_batch = torch.tensor(state_batch,device=self.device,dtype=torch.float)
action_batch = torch.tensor(action_batch,device=self.device,dtype=torch.int64).unsqueeze(1)
reward_batch = torch.tensor(reward_batch,device=self.device,dtype=torch.float)
next_state_batch = torch.tensor(next_state_batch,device=self.device, dtype=torch.float)
done_batch = torch.tensor(np.float32(done_batch),device=self.device)
q_values = self.meta_policy_net(state_batch).gather(dim=1, index=action_batch).squeeze(1)
next_state_values = self.meta_policy_net(next_state_batch).max(1)[0].detach()
expected_q_values = reward_batch + 0.99 * next_state_values * (1-done_batch)
@@ -99,7 +99,7 @@ class HierarchicalDQN:
for param in self.meta_policy_net.parameters(): # clip防止梯度爆炸
param.grad.data.clamp_(-1, 1)
self.meta_optimizer.step()
self.meta_loss_numpy = meta_loss.detach().numpy()
self.meta_loss_numpy = meta_loss.detach().cpu().numpy()
self.meta_losses.append(self.meta_loss_numpy)
def save(self, path):