This commit is contained in:
johnjim0816
2021-09-15 10:32:52 +08:00
parent 95f3f4dd57
commit 5085040330
74 changed files with 431 additions and 433 deletions

View File

@@ -5,7 +5,7 @@
@Email: johnjim0816@gmail.com
@Date: 2020-06-12 00:50:49
@LastEditor: John
LastEditTime: 2021-05-07 16:30:05
LastEditTime: 2021-09-15 02:18:56
@Discription:
@Environment: python 3.7.7
'''
@@ -37,18 +37,20 @@ class DQN:
self.batch_size = cfg.batch_size
self.policy_net = MLP(state_dim, action_dim,hidden_dim=cfg.hidden_dim).to(self.device)
self.target_net = MLP(state_dim, action_dim,hidden_dim=cfg.hidden_dim).to(self.device)
for target_param, param in zip(self.target_net.parameters(),self.policy_net.parameters()): # copy params from policy net
for target_param, param in zip(self.target_net.parameters(),self.policy_net.parameters()): # 复制参数到目标网路targe_net
target_param.data.copy_(param.data)
self.optimizer = optim.Adam(self.policy_net.parameters(), lr=cfg.lr)
self.optimizer = optim.Adam(self.policy_net.parameters(), lr=cfg.lr) # 优化器
self.memory = ReplayBuffer(cfg.memory_capacity)
def choose_action(self, state):
'''选择动作
'''
self.frame_idx += 1
if random.random() > self.epsilon(self.frame_idx):
action = self.predict(state)
with torch.no_grad():
state = torch.tensor([state], device=self.device, dtype=torch.float32)
q_values = self.policy_net(state)
action = q_values.max(1)[1].item()
else:
action = random.randrange(self.action_dim)
return action