update codes
This commit is contained in:
@@ -18,6 +18,7 @@ from PPO.memory import PPOMemory
|
||||
class PPO:
|
||||
def __init__(self, state_dim, action_dim,cfg):
|
||||
self.gamma = cfg.gamma
|
||||
self.continuous = cfg.continuous
|
||||
self.policy_clip = cfg.policy_clip
|
||||
self.n_epochs = cfg.n_epochs
|
||||
self.gae_lambda = cfg.gae_lambda
|
||||
@@ -29,13 +30,13 @@ class PPO:
|
||||
self.memory = PPOMemory(cfg.batch_size)
|
||||
self.loss = 0
|
||||
|
||||
def choose_action(self, state,continuous=False):
|
||||
def choose_action(self, state):
|
||||
state = torch.tensor([state], dtype=torch.float).to(self.device)
|
||||
dist = self.actor(state)
|
||||
value = self.critic(state)
|
||||
action = dist.sample()
|
||||
probs = torch.squeeze(dist.log_prob(action)).item()
|
||||
if continuous:
|
||||
if self.continuous:
|
||||
action = torch.tanh(action)
|
||||
else:
|
||||
action = torch.squeeze(action).item()
|
||||
|
||||
Reference in New Issue
Block a user