update codes

johnjim0816
2021-12-21 20:14:13 +08:00
parent 64c319cab4
commit 3b712e8815
71 changed files with 1097 additions and 1340 deletions


@@ -14,16 +14,57 @@ LastEditTime: 2021-09-15 13:35:36
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import random
import math
import numpy as np
from common.memory import ReplayBuffer
from common.model import MLP
class DQN:
    def __init__(self, n_states, n_actions, cfg):
        self.n_actions = n_actions  # total number of actions
class MLP(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_dim=128):
        """ Initialize the Q-network as a fully connected network
        state_dim: number of input features, i.e. the dimension of the environment state
        action_dim: dimension of the output actions
        """
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(state_dim, hidden_dim)  # input layer
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)  # hidden layer
        self.fc3 = nn.Linear(hidden_dim, action_dim)  # output layer
    def forward(self, x):
        # activation function for each layer
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)
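# Illustrative sketch only (not part of this commit): a quick shape check for the MLP above,
# assuming a 4-dimensional state and 2 discrete actions (CartPole-like dimensions).
_q_net = MLP(state_dim=4, action_dim=2, hidden_dim=128)
_q_values = _q_net(torch.randn(1, 4))  # one state in a batch -> Q-values of shape (1, 2)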
class ReplayBuffer:
    def __init__(self, capacity):
        self.capacity = capacity  # capacity of the replay buffer
        self.buffer = []  # buffer
        self.position = 0
    def push(self, state, action, reward, next_state, done):
        ''' The buffer works like a queue: once capacity is exceeded, the earliest stored transitions are overwritten
        '''
        if len(self.buffer) < self.capacity:
            self.buffer.append(None)
        self.buffer[self.position] = (state, action, reward, next_state, done)
        self.position = (self.position + 1) % self.capacity
    def sample(self, batch_size):
        batch = random.sample(self.buffer, batch_size)  # randomly sample a mini-batch of transitions
        state, action, reward, next_state, done = zip(*batch)  # unzip into states, actions, etc.
        return state, action, reward, next_state, done
    def __len__(self):
        ''' Return the current number of stored transitions
        '''
        return len(self.buffer)
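# Illustrative sketch only (not part of this commit): filling the ReplayBuffer above with
# dummy transitions and sampling a mini-batch; shapes and values here are assumptions.
_buffer = ReplayBuffer(capacity=1000)
for _ in range(10):
    _s, _s_next = np.random.rand(4), np.random.rand(4)  # fake 4-dimensional states
    _buffer.push(_s, 0, 1.0, _s_next, False)  # (state, action, reward, next_state, done)
_states, _actions, _rewards, _next_states, _dones = _buffer.sample(batch_size=8)
_state_batch = torch.tensor(np.array(_states), dtype=torch.float32)  # (8, 4) batch for the Q-network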
class DQN:
    def __init__(self, state_dim, action_dim, cfg):
        self.action_dim = action_dim  # total number of actions
        self.device = cfg.device  # device, e.g. cpu or gpu
        self.gamma = cfg.gamma  # discount factor for rewards
        # parameters of the epsilon-greedy policy
@@ -32,8 +73,8 @@ class DQN:
            (cfg.epsilon_start - cfg.epsilon_end) * \
            math.exp(-1. * frame_idx / cfg.epsilon_decay)
        self.batch_size = cfg.batch_size
        self.policy_net = MLP(n_states, n_actions, hidden_dim=cfg.hidden_dim).to(self.device)
        self.target_net = MLP(n_states, n_actions, hidden_dim=cfg.hidden_dim).to(self.device)
        self.policy_net = MLP(state_dim, action_dim, hidden_dim=cfg.hidden_dim).to(self.device)
        self.target_net = MLP(state_dim, action_dim, hidden_dim=cfg.hidden_dim).to(self.device)
        for target_param, param in zip(self.target_net.parameters(), self.policy_net.parameters()):  # copy parameters to the target network target_net
            target_param.data.copy_(param.data)
        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=cfg.lr)  # optimizer
@@ -49,7 +90,7 @@ class DQN:
                q_values = self.policy_net(state)
                action = q_values.max(1)[1].item()  # choose the action with the largest Q-value
        else:
            action = random.randrange(self.n_actions)
            action = random.randrange(self.action_dim)
        return action
    def update(self):
        if len(self.memory) < self.batch_size:  # do not update the policy until memory holds at least one full batch