update codes
This commit is contained in:
@@ -29,13 +29,16 @@ class PPO:
|
||||
self.memory = PPOMemory(cfg.batch_size)
|
||||
self.loss = 0
|
||||
|
||||
def choose_action(self, state, continuous=False):
    """Sample an action for a single observation.

    Args:
        state: One observation (anything ``torch.as_tensor`` accepts).
            A leading batch dimension of size 1 is added before it is
            passed to ``self.actor`` and ``self.critic``.
        continuous: When true, the sampled action is squashed with
            ``tanh`` and returned as a tensor; otherwise the sampled
            action is returned as a Python scalar.

    Returns:
        A tuple ``(action, probs, value)`` where ``probs`` is the
        log-probability of the sampled (pre-``tanh``) action under the
        object returned by ``self.actor`` and ``value`` is the output
        of ``self.critic``, both as Python scalars.
    """
    # Nothing computed here is back-propagated through, so skip building
    # the autograd graph for the actor/critic forward passes.
    with torch.no_grad():
        # unsqueeze(0) adds the batch dimension without first wrapping the
        # observation in a Python list, which is the slow conversion path
        # when the observation is a NumPy array.
        state = torch.as_tensor(state, dtype=torch.float).unsqueeze(0).to(self.device)
        dist = self.actor(state)
        value = self.critic(state)
        action = dist.sample()
        # NOTE(review): .item() requires a single-element log-prob; a
        # distribution with per-dimension log-probs would need reducing
        # first -- confirm against the actor's output.
        probs = torch.squeeze(dist.log_prob(action)).item()
        if continuous:
            # Must operate on the tensor: converting to a Python number
            # before this point makes torch.tanh raise TypeError.
            action = torch.tanh(action)
        else:
            action = torch.squeeze(action).item()
        value = torch.squeeze(value).item()
    return action, probs, value
|
||||
|
||||
|
||||
Reference in New Issue
Block a user