update
This commit is contained in:
@@ -5,7 +5,7 @@
|
||||
@Email: johnjim0816@gmail.com
|
||||
@Date: 2020-06-12 00:50:49
|
||||
@LastEditor: John
|
||||
LastEditTime: 2021-05-07 16:30:05
|
||||
LastEditTime: 2021-09-15 02:18:56
|
||||
@Discription:
|
||||
@Environment: python 3.7.7
|
||||
'''
|
||||
@@ -37,18 +37,20 @@ class DQN:
|
||||
self.batch_size = cfg.batch_size
|
||||
self.policy_net = MLP(state_dim, action_dim,hidden_dim=cfg.hidden_dim).to(self.device)
|
||||
self.target_net = MLP(state_dim, action_dim,hidden_dim=cfg.hidden_dim).to(self.device)
|
||||
for target_param, param in zip(self.target_net.parameters(),self.policy_net.parameters()): # copy params from policy net
|
||||
for target_param, param in zip(self.target_net.parameters(),self.policy_net.parameters()): # 复制参数到目标网路targe_net
|
||||
target_param.data.copy_(param.data)
|
||||
self.optimizer = optim.Adam(self.policy_net.parameters(), lr=cfg.lr)
|
||||
self.optimizer = optim.Adam(self.policy_net.parameters(), lr=cfg.lr) # 优化器
|
||||
self.memory = ReplayBuffer(cfg.memory_capacity)
|
||||
|
||||
|
||||
def choose_action(self, state):
|
||||
'''选择动作
|
||||
'''
|
||||
self.frame_idx += 1
|
||||
if random.random() > self.epsilon(self.frame_idx):
|
||||
action = self.predict(state)
|
||||
with torch.no_grad():
|
||||
state = torch.tensor([state], device=self.device, dtype=torch.float32)
|
||||
q_values = self.policy_net(state)
|
||||
action = q_values.max(1)[1].item()
|
||||
else:
|
||||
action = random.randrange(self.action_dim)
|
||||
return action
|
||||
|
||||
Reference in New Issue
Block a user