update rainbowdqn
This commit is contained in:
@@ -46,15 +46,15 @@ class ReplayBuffer:
|
||||
return len(self.buffer)
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, state_dim,action_dim,hidden_dim=128):
|
||||
def __init__(self, n_states,n_actions,hidden_dim=128):
|
||||
""" 初始化q网络,为全连接网络
|
||||
state_dim: 输入的特征数即环境的状态维度
|
||||
action_dim: 输出的动作维度
|
||||
n_states: 输入的特征数即环境的状态维度
|
||||
n_actions: 输出的动作维度
|
||||
"""
|
||||
super(MLP, self).__init__()
|
||||
self.fc1 = nn.Linear(state_dim, hidden_dim) # 输入层
|
||||
self.fc1 = nn.Linear(n_states, hidden_dim) # 输入层
|
||||
self.fc2 = nn.Linear(hidden_dim,hidden_dim) # 隐藏层
|
||||
self.fc3 = nn.Linear(hidden_dim, action_dim) # 输出层
|
||||
self.fc3 = nn.Linear(hidden_dim, n_actions) # 输出层
|
||||
|
||||
def forward(self, x):
|
||||
# 各层对应的激活函数
|
||||
@@ -63,8 +63,8 @@ class MLP(nn.Module):
|
||||
return self.fc3(x)
|
||||
|
||||
class DoubleDQN:
|
||||
def __init__(self, state_dim, action_dim, cfg):
|
||||
self.action_dim = action_dim # 总的动作个数
|
||||
def __init__(self, n_states, n_actions, cfg):
|
||||
self.n_actions = n_actions # 总的动作个数
|
||||
self.device = cfg.device # 设备,cpu或gpu等
|
||||
self.gamma = cfg.gamma
|
||||
# e-greedy策略相关参数
|
||||
@@ -73,8 +73,8 @@ class DoubleDQN:
|
||||
self.epsilon_end = cfg.epsilon_end
|
||||
self.epsilon_decay = cfg.epsilon_decay
|
||||
self.batch_size = cfg.batch_size
|
||||
self.policy_net = MLP(state_dim, action_dim,hidden_dim=cfg.hidden_dim).to(self.device)
|
||||
self.target_net = MLP(state_dim, action_dim,hidden_dim=cfg.hidden_dim).to(self.device)
|
||||
self.policy_net = MLP(n_states, n_actions,hidden_dim=cfg.hidden_dim).to(self.device)
|
||||
self.target_net = MLP(n_states, n_actions,hidden_dim=cfg.hidden_dim).to(self.device)
|
||||
# target_net copy from policy_net
|
||||
for target_param, param in zip(self.target_net.parameters(), self.policy_net.parameters()):
|
||||
target_param.data.copy_(param.data)
|
||||
@@ -103,7 +103,7 @@ class DoubleDQN:
|
||||
# 所以tensor.max(1)[1]返回最大值对应的下标,即action
|
||||
action = q_value.max(1)[1].item()
|
||||
else:
|
||||
action = random.randrange(self.action_dim)
|
||||
action = random.randrange(self.n_actions)
|
||||
return action
|
||||
def update(self):
|
||||
|
||||
|
||||
Reference in New Issue
Block a user