update
This commit is contained in:
@@ -92,14 +92,10 @@ class TD3(object):
|
||||
self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=3e-4)
|
||||
self.memory = ReplayBuffer(state_dim, action_dim)
|
||||
|
||||
|
||||
|
||||
|
||||
def choose_action(self, state):
|
||||
state = torch.FloatTensor(state.reshape(1, -1)).to(self.device)
|
||||
return self.actor(state).cpu().data.numpy().flatten()
|
||||
|
||||
|
||||
def update(self):
|
||||
self.total_it += 1
|
||||
|
||||
@@ -167,4 +163,4 @@ class TD3(object):
|
||||
self.actor.load_state_dict(torch.load(path + "td3_actor"))
|
||||
self.actor_optimizer.load_state_dict(torch.load(path + "td3_actor_optimizer"))
|
||||
self.actor_target = copy.deepcopy(self.actor)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user