update rainbowdqn

johnjim0816
2022-05-31 01:20:58 +08:00
parent cfc0f6492e
commit c7c94468c9
149 changed files with 1866 additions and 1549 deletions


@@ -5,7 +5,7 @@
@Email: johnjim0816@gmail.com
@Date: 2020-06-12 00:50:49
@LastEditor: John
-LastEditTime: 2021-12-22 14:01:37
+LastEditTime: 2022-03-02 11:05:11
@Discription:
@Environment: python 3.7.7
'''
@@ -20,22 +20,7 @@ import random
import math
import numpy as np
-class MLP(nn.Module):
-    def __init__(self, state_dim, action_dim, hidden_dim=128):
-        """ Initialize the Q-network as a fully connected network
-            state_dim: number of input features, i.e. the state dimension of the environment
-            action_dim: dimension of the action output
-        """
-        super(MLP, self).__init__()
-        self.fc1 = nn.Linear(state_dim, hidden_dim)  # input layer
-        self.fc2 = nn.Linear(hidden_dim, hidden_dim)  # hidden layer
-        self.fc3 = nn.Linear(hidden_dim, action_dim)  # output layer
-    def forward(self, x):
-        # activation function applied after each layer
-        x = F.relu(self.fc1(x))
-        x = F.relu(self.fc2(x))
-        return self.fc3(x)
class ReplayBuffer:
    def __init__(self, capacity):
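Only fragments of ReplayBuffer are visible in this diff (its __init__ and __len__). For context, a minimal sketch of the kind of buffer the agent's update() below relies on, i.e. one whose sample() returns five parallel sequences; the exact implementation in the repository may differ:

import random
from collections import deque

class ReplayBuffer:
    def __init__(self, capacity):
        self.capacity = capacity
        self.buffer = deque(maxlen=capacity)  # oldest transitions are dropped when full
    def push(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))
    def sample(self, batch_size):
        batch = random.sample(self.buffer, batch_size)
        # transpose the batch into five parallel tuples, matching the unpacking in DQN.update()
        return zip(*batch)
    def __len__(self):
        return len(self.buffer)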
@@ -62,9 +47,9 @@ class ReplayBuffer:
        return len(self.buffer)
class DQN:
-    def __init__(self, state_dim, action_dim, cfg):
+    def __init__(self, n_actions, model, cfg):
-        self.action_dim = action_dim  # total number of actions
+        self.n_actions = n_actions  # total number of actions
        self.device = cfg.device  # device: CPU or GPU
        self.gamma = cfg.gamma  # discount factor for rewards
        # parameters of the epsilon-greedy policy
@@ -73,8 +58,8 @@ class DQN:
            (cfg.epsilon_start - cfg.epsilon_end) * \
            math.exp(-1. * frame_idx / cfg.epsilon_decay)
        self.batch_size = cfg.batch_size
-        self.policy_net = MLP(state_dim, action_dim, hidden_dim=cfg.hidden_dim).to(self.device)
-        self.target_net = MLP(state_dim, action_dim, hidden_dim=cfg.hidden_dim).to(self.device)
+        self.policy_net = model.to(self.device)
+        self.target_net = model.to(self.device)
        for target_param, param in zip(self.target_net.parameters(), self.policy_net.parameters()):  # copy parameters to the target network target_net
            target_param.data.copy_(param.data)
        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=cfg.lr)  # optimizer
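With this change the network is no longer built inside the agent but injected through the constructor, so the MLP removed above presumably lives in a separate models module. A minimal usage sketch under that assumption (the import path and variable names are illustrative, not taken from this commit):

from common.models import MLP  # assumed location of the relocated MLP; the path is an assumption

model = MLP(state_dim, n_actions, hidden_dim=cfg.hidden_dim)
agent = DQN(n_actions, model, cfg)
# note: model.to(device) returns the same nn.Module, so policy_net and target_net
# here refer to one object; a fully separate target net would be copy.deepcopy(model)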
@@ -86,23 +71,24 @@ class DQN:
        self.frame_idx += 1
        if random.random() > self.epsilon(self.frame_idx):
            with torch.no_grad():
-                state = torch.tensor([state], device=self.device, dtype=torch.float32)
+                state = torch.tensor(state, device=self.device, dtype=torch.float32).unsqueeze(dim=0)
                q_values = self.policy_net(state)
                action = q_values.max(1)[1].item()  # choose the action with the largest Q value
        else:
-            action = random.randrange(self.action_dim)
+            action = random.randrange(self.n_actions)
        return action
    def update(self):
        if len(self.memory) < self.batch_size:  # do not update the policy until memory holds at least one batch
            return
        # randomly sample a batch of transitions from the replay memory
+        # print('updating')
        state_batch, action_batch, reward_batch, next_state_batch, done_batch = self.memory.sample(
            self.batch_size)
        # convert to tensors
-        state_batch = torch.tensor(state_batch, device=self.device, dtype=torch.float)
+        state_batch = torch.tensor(np.array(state_batch), device=self.device, dtype=torch.float)
        action_batch = torch.tensor(action_batch, device=self.device).unsqueeze(1)
        reward_batch = torch.tensor(reward_batch, device=self.device, dtype=torch.float)
-        next_state_batch = torch.tensor(next_state_batch, device=self.device, dtype=torch.float)
+        next_state_batch = torch.tensor(np.array(next_state_batch), device=self.device, dtype=torch.float)
        done_batch = torch.tensor(np.float32(done_batch), device=self.device)
        q_values = self.policy_net(state_batch).gather(dim=1, index=action_batch)  # compute Q(s_t, a) for the current states and chosen actions
        next_q_values = self.target_net(next_state_batch).max(1)[0].detach()  # compute the max Q value of the next state under the target net
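The hunk ends midway through update(). For reference, a minimal sketch of how a standard DQN update typically continues from the tensors computed above (TD target, MSE loss, gradient step); this is the textbook form, not necessarily the exact code of this commit:

# TD target: r + gamma * max_a Q_target(s', a), with the bootstrap term zeroed at terminal states
expected_q_values = reward_batch + self.gamma * next_q_values * (1 - done_batch)
loss = nn.MSELoss()(q_values, expected_q_values.unsqueeze(1))
self.optimizer.zero_grad()
loss.backward()
for param in self.policy_net.parameters():  # clip gradients for stability
    param.grad.data.clamp_(-1, 1)
self.optimizer.step()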