update
This commit is contained in:
@@ -5,7 +5,7 @@
|
||||
@Email: johnjim0816@gmail.com
|
||||
@Date: 2020-06-09 20:25:52
|
||||
@LastEditor: John
|
||||
LastEditTime: 2021-03-31 00:56:32
|
||||
LastEditTime: 2021-05-04 14:50:17
|
||||
@Discription:
|
||||
@Environment: python 3.7.7
|
||||
'''
|
||||
@@ -26,6 +26,7 @@ class DDPG:
|
||||
self.target_critic = Critic(state_dim, action_dim, cfg.hidden_dim).to(cfg.device)
|
||||
self.target_actor = Actor(state_dim, action_dim, cfg.hidden_dim).to(cfg.device)
|
||||
|
||||
# copy parameters to target net
|
||||
for target_param, param in zip(self.target_critic.parameters(), self.critic.parameters()):
|
||||
target_param.data.copy_(param.data)
|
||||
for target_param, param in zip(self.target_actor.parameters(), self.actor.parameters()):
|
||||
@@ -42,7 +43,6 @@ class DDPG:
|
||||
def choose_action(self, state):
|
||||
state = torch.FloatTensor(state).unsqueeze(0).to(self.device)
|
||||
action = self.actor(state)
|
||||
# torch.detach()用于切断反向传播
|
||||
return action.detach().cpu().numpy()[0, 0]
|
||||
|
||||
def update(self):
|
||||
@@ -50,13 +50,13 @@ class DDPG:
|
||||
return
|
||||
state, action, reward, next_state, done = self.memory.sample(
|
||||
self.batch_size)
|
||||
# 将所有变量转为张量
|
||||
# convert variables to Tensor
|
||||
state = torch.FloatTensor(state).to(self.device)
|
||||
next_state = torch.FloatTensor(next_state).to(self.device)
|
||||
action = torch.FloatTensor(action).to(self.device)
|
||||
reward = torch.FloatTensor(reward).unsqueeze(1).to(self.device)
|
||||
done = torch.FloatTensor(np.float32(done)).unsqueeze(1).to(self.device)
|
||||
# 注意critic将(s_t,a)作为输入
|
||||
|
||||
policy_loss = self.critic(state, self.actor(state))
|
||||
policy_loss = -policy_loss.mean()
|
||||
next_action = self.target_actor(next_state)
|
||||
|
||||
Reference in New Issue
Block a user