update rainbowdqn
This commit is contained in:
@@ -70,9 +70,9 @@ class ReplayBuffer:
|
||||
return len(self.buffer)
|
||||
|
||||
class DQN:
|
||||
def __init__(self, state_dim, action_dim, cfg):
|
||||
def __init__(self, n_states, n_actions, cfg):
|
||||
|
||||
self.action_dim = action_dim # 总的动作个数
|
||||
self.n_actions = n_actions # 总的动作个数
|
||||
self.device = cfg.device # 设备,cpu或gpu等
|
||||
self.gamma = cfg.gamma # 奖励的折扣因子
|
||||
# e-greedy策略相关参数
|
||||
@@ -81,8 +81,8 @@ class DQN:
|
||||
(cfg.epsilon_start - cfg.epsilon_end) * \
|
||||
math.exp(-1. * frame_idx / cfg.epsilon_decay)
|
||||
self.batch_size = cfg.batch_size
|
||||
self.policy_net = CNN(state_dim, action_dim).to(self.device)
|
||||
self.target_net = CNN(state_dim, action_dim).to(self.device)
|
||||
self.policy_net = CNN(n_states, n_actions).to(self.device)
|
||||
self.target_net = CNN(n_states, n_actions).to(self.device)
|
||||
for target_param, param in zip(self.target_net.parameters(),self.policy_net.parameters()): # 复制参数到目标网路targe_net
|
||||
target_param.data.copy_(param.data)
|
||||
self.optimizer = optim.Adam(self.policy_net.parameters(), lr=cfg.lr) # 优化器
|
||||
@@ -94,11 +94,12 @@ class DQN:
|
||||
self.frame_idx += 1
|
||||
if random.random() > self.epsilon(self.frame_idx):
|
||||
with torch.no_grad():
|
||||
print(type(state))
|
||||
state = torch.tensor([state], device=self.device, dtype=torch.float32)
|
||||
q_values = self.policy_net(state)
|
||||
action = q_values.max(1)[1].item() # 选择Q值最大的动作
|
||||
else:
|
||||
action = random.randrange(self.action_dim)
|
||||
action = random.randrange(self.n_actions)
|
||||
return action
|
||||
def update(self):
|
||||
if len(self.memory) < self.batch_size: # 当memory中不满足一个批量时,不更新策略
|
||||
|
||||
Reference in New Issue
Block a user