This commit is contained in:
johnjim0816
2021-04-28 22:11:22 +08:00
parent e4690ac89f
commit ed7b60fd5b
73 changed files with 502 additions and 187 deletions

View File

@@ -92,14 +92,10 @@ class TD3(object):
self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=3e-4)
self.memory = ReplayBuffer(state_dim, action_dim)
def choose_action(self, state):
    """Return the actor network's action for a single state.

    Args:
        state: One observation, given as a NumPy array or any array-like
            (list, tuple) that ``torch.as_tensor`` accepts. It is flattened
            into a batch of size one before being fed to the actor.

    Returns:
        A 1-D NumPy array holding the actor's output for ``state``. No
        exploration noise is added here.
    """
    # Inference only: skip autograd bookkeeping so no graph is built or kept.
    with torch.no_grad():
        state = torch.as_tensor(
            state, dtype=torch.float32, device=self.device
        ).reshape(1, -1)
        action = self.actor(state)
    # detach() is the supported replacement for the legacy ``.data`` attribute.
    return action.detach().cpu().numpy().flatten()
def update(self):
self.total_it += 1
@@ -167,4 +163,4 @@ class TD3(object):
self.actor.load_state_dict(torch.load(path + "td3_actor"))
self.actor_optimizer.load_state_dict(torch.load(path + "td3_actor_optimizer"))
self.actor_target = copy.deepcopy(self.actor)