update
@@ -20,11 +20,11 @@ import random
 import math
 import numpy as np
 from common.memory import ReplayBuffer
-from common.model import MLP2
+from common.model import MLP
 class DQN:
-    def __init__(self, n_states, n_actions, cfg):
+    def __init__(self, state_dim, action_dim, cfg):
 
-        self.n_actions = n_actions  # total number of actions
+        self.action_dim = action_dim  # total number of actions
         self.device = cfg.device  # device to run on, e.g. cpu or gpu
         self.gamma = cfg.gamma  # discount factor for rewards
         # parameters of the e-greedy (epsilon-greedy) policy
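This hunk renames the constructor arguments from n_states/n_actions to state_dim/action_dim. As a hedged illustration only, a caller would typically wire the renamed constructor to an environment roughly as below; the environment name, the Config class, its field values, and the import path of DQN are assumptions and not part of this commit:

import gym

class Config:
    # Placeholder hyperparameters matching only the cfg fields visible in this diff;
    # the real config likely has more (e.g. learning rate, target update interval).
    device = 'cpu'
    gamma = 0.99
    epsilon_start, epsilon_end, epsilon_decay = 0.95, 0.01, 500
    batch_size = 64
    hidden_dim = 256

cfg = Config()
env = gym.make('CartPole-v0')                # hypothetical environment
state_dim = env.observation_space.shape[0]   # dimensionality of the state vector
action_dim = env.action_space.n              # total number of discrete actions
agent = DQN(state_dim, action_dim, cfg)      # DQN is the class edited in this diff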
@@ -34,8 +34,8 @@ class DQN:
         self.epsilon_end = cfg.epsilon_end
         self.epsilon_decay = cfg.epsilon_decay
         self.batch_size = cfg.batch_size
-        self.policy_net = MLP2(n_states, n_actions, hidden_dim=cfg.hidden_dim).to(self.device)
-        self.target_net = MLP2(n_states, n_actions, hidden_dim=cfg.hidden_dim).to(self.device)
+        self.policy_net = MLP(state_dim, action_dim, hidden_dim=cfg.hidden_dim).to(self.device)
+        self.target_net = MLP(state_dim, action_dim, hidden_dim=cfg.hidden_dim).to(self.device)
         # target_net's initial model parameters are an exact copy of policy_net's
         self.target_net.load_state_dict(self.policy_net.state_dict())
         self.target_net.eval()  # do not enable BatchNormalization and Dropout
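This hunk swaps the network class from MLP2 to MLP while keeping the (input, output, hidden_dim) call signature. The definition of common.model.MLP is not part of this diff; the following is only a plausible sketch of a fully connected Q-network matching that signature, with illustrative layer sizes and activations:

import torch.nn as nn
import torch.nn.functional as F

class MLP(nn.Module):
    # Illustrative stand-in for common.model.MLP; the real layer count,
    # sizes, and activations may differ from this sketch.
    def __init__(self, state_dim, action_dim, hidden_dim=128):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(state_dim, hidden_dim)   # input layer
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)  # hidden layer
        self.fc3 = nn.Linear(hidden_dim, action_dim)  # one Q-value per action

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)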
@@ -64,7 +64,7 @@ class DQN:
                     # so tensor.max(1)[1] returns the index of the maximum value, i.e. the action
                     action = q_value.max(1)[1].item()
             else:
-                action = random.randrange(self.n_actions)
+                action = random.randrange(self.action_dim)
             return action
         else:
             with torch.no_grad():  # do not track gradients
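For context, the epsilon-greedy branch shown above is usually gated by an exploration rate built from cfg.epsilon_start, cfg.epsilon_end, and cfg.epsilon_decay (note the math import in the first hunk). A minimal sketch of such a schedule, assuming exponential decay over frames; the exact formula is not shown in this diff:

import math

def epsilon_by_frame(frame_idx, cfg):
    # Decays from cfg.epsilon_start toward cfg.epsilon_end as frame_idx grows;
    # this exact form is an assumption, not taken from the commit.
    return cfg.epsilon_end + (cfg.epsilon_start - cfg.epsilon_end) * \
        math.exp(-1.0 * frame_idx / cfg.epsilon_decay)

# e.g. epsilon_by_frame(0, cfg) is roughly cfg.epsilon_start, and the value
# approaches cfg.epsilon_end for large frame_idx, shifting from exploration to exploitation.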