This commit is contained in:
JohnJim0816
2021-03-31 15:37:09 +08:00
parent 6a92f97138
commit b6f63a91bf
65 changed files with 1244 additions and 459 deletions

View File

@@ -5,7 +5,7 @@
@Email: johnjim0816@gmail.com
@Date: 2020-06-09 20:25:52
@LastEditor: John
LastEditTime: 2021-03-17 20:43:25
LastEditTime: 2021-03-31 00:56:32
@Discription:
@Environment: python 3.7.7
'''
@@ -58,9 +58,7 @@ class DDPG:
done = torch.FloatTensor(np.float32(done)).unsqueeze(1).to(self.device)
# 注意critic将(s_t,a)作为输入
policy_loss = self.critic(state, self.actor(state))
policy_loss = -policy_loss.mean()
next_action = self.target_actor(next_state)
target_value = self.target_critic(next_state, next_action.detach())
expected_value = reward + (1.0 - done) * self.gamma * target_value
@@ -87,7 +85,7 @@ class DDPG:
param.data * self.soft_tau
)
def save(self,path):
torch.save(self.target_net.state_dict(), path+'DDPG_checkpoint.pth')
torch.save(self.actor.state_dict(), path+'checkpoint.pt')
def load(self,path):
self.actor.load_state_dict(torch.load(path+'DDPG_checkpoint.pth'))
self.actor.load_state_dict(torch.load(path+'checkpoint.pt'))