update codes

This commit is contained in:
johnjim0816
2021-11-18 15:41:27 +08:00
parent 442e307b01
commit 129c0c65fa
103 changed files with 1025 additions and 558 deletions

View File

@@ -29,13 +29,16 @@ class PPO:
self.memory = PPOMemory(cfg.batch_size)
self.loss = 0
def choose_action(self, observation):
state = torch.tensor([observation], dtype=torch.float).to(self.device)
def choose_action(self, state,continuous=False):
state = torch.tensor([state], dtype=torch.float).to(self.device)
dist = self.actor(state)
value = self.critic(state)
action = dist.sample()
probs = torch.squeeze(dist.log_prob(action)).item()
action = torch.squeeze(action).item()
if continuous:
action = torch.tanh(action)
else:
action = torch.squeeze(action).item()
value = torch.squeeze(value).item()
return action, probs, value