update codes

This commit is contained in:
johnjim0816
2021-11-19 16:02:34 +08:00
parent 129c0c65fa
commit 64c319cab4
47 changed files with 262 additions and 255 deletions

View File

@@ -18,6 +18,7 @@ from PPO.memory import PPOMemory
class PPO:
def __init__(self, state_dim, action_dim,cfg):
self.gamma = cfg.gamma
self.continuous = cfg.continuous
self.policy_clip = cfg.policy_clip
self.n_epochs = cfg.n_epochs
self.gae_lambda = cfg.gae_lambda
@@ -29,13 +30,13 @@ class PPO:
self.memory = PPOMemory(cfg.batch_size)
self.loss = 0
def choose_action(self, state,continuous=False):
def choose_action(self, state):
state = torch.tensor([state], dtype=torch.float).to(self.device)
dist = self.actor(state)
value = self.critic(state)
action = dist.sample()
probs = torch.squeeze(dist.log_prob(action)).item()
if continuous:
if self.continuous:
action = torch.tanh(action)
else:
action = torch.squeeze(action).item()