Files
easy-rl/projects/codes/DQN/dqn.py
2022-08-15 22:31:37 +08:00

96 lines
4.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python
# coding=utf-8
'''
@Author: John
@Email: johnjim0816@gmail.com
@Date: 2020-06-12 00:50:49
@LastEditor: John
LastEditTime: 2022-08-11 09:52:23
@Description:
@Environment: python 3.7.7
'''
'''off-policy
'''
import copy
import math
import random

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
class DQN:
    '''Deep Q-Network agent (off-policy) with epsilon-greedy exploration.

    The agent owns two networks built from the same ``model``:
    ``policy_net`` is trained by ``update()`` and used to pick actions, while
    ``target_net`` is an independent copy used only to compute bootstrap
    targets. This class never refreshes ``target_net`` by itself during
    training; callers are expected to do so, e.g. with
    ``agent.target_net.load_state_dict(agent.policy_net.state_dict())``.
    '''

    def __init__(self, n_actions, model, memory, cfg):
        '''Build the agent.

        Args:
            n_actions: number of discrete actions; used for random exploration.
            model: ``nn.Module`` mapping a batch of states to per-action Q-values.
            memory: replay buffer supporting ``len()`` and ``sample(batch_size)``,
                where ``sample`` returns
                ``(states, actions, rewards, next_states, dones)``.
            cfg: config object exposing ``device``, ``gamma``, ``epsilon_start``,
                ``epsilon_end``, ``epsilon_decay``, ``batch_size`` and ``lr``.
        '''
        self.n_actions = n_actions
        self.device = torch.device(cfg.device)  # cpu or cuda
        self.gamma = cfg.gamma  # reward discount factor
        # epsilon-greedy parameters
        self.sample_count = 0  # number of sample() calls, drives the epsilon decay
        self.epsilon = cfg.epsilon_start
        self.epsilon_start = cfg.epsilon_start
        self.epsilon_end = cfg.epsilon_end
        self.epsilon_decay = cfg.epsilon_decay
        self.batch_size = cfg.batch_size
        self.policy_net = model.to(self.device)
        # nn.Module.to() returns the module itself, so calling it twice would
        # make both attributes alias one network. Deep-copy to get a target
        # network that starts with identical weights but is a separate object.
        self.target_net = copy.deepcopy(self.policy_net)
        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=cfg.lr)
        self.memory = memory  # experience replay buffer

    def _greedy_action(self, state):
        '''Return the index of the action with the highest policy-net Q-value.'''
        with torch.no_grad():
            state = torch.tensor(state, device=self.device, dtype=torch.float32).unsqueeze(dim=0)
            q_values = self.policy_net(state)
            return q_values.max(1)[1].item()

    def sample(self, state):
        '''Choose an action epsilon-greedily (used while training).

        Every call increments ``sample_count`` and recomputes ``epsilon`` with
        an exponential decay from ``epsilon_start`` towards ``epsilon_end``.
        '''
        self.sample_count += 1
        self.epsilon = self.epsilon_end + (self.epsilon_start - self.epsilon_end) * \
            math.exp(-1. * self.sample_count / self.epsilon_decay)
        if random.random() > self.epsilon:
            return self._greedy_action(state)  # exploit
        return random.randrange(self.n_actions)  # explore

    def predict(self, state):
        '''Choose the greedy action for ``state`` (no exploration).'''
        return self._greedy_action(state)

    def update(self):
        '''Run one gradient step on ``policy_net`` from a sampled replay batch.

        Does nothing while the replay buffer holds fewer than ``batch_size``
        transitions.
        '''
        if len(self.memory) < self.batch_size:
            return
        # randomly sample a batch of transitions from the replay memory
        state_batch, action_batch, reward_batch, next_state_batch, done_batch = self.memory.sample(
            self.batch_size)
        state_batch = torch.tensor(np.array(state_batch), device=self.device, dtype=torch.float)
        # gather() requires int64 indices, so do not rely on dtype inference
        action_batch = torch.tensor(action_batch, device=self.device, dtype=torch.long).unsqueeze(1)
        reward_batch = torch.tensor(reward_batch, device=self.device, dtype=torch.float)
        next_state_batch = torch.tensor(np.array(next_state_batch), device=self.device, dtype=torch.float)
        done_batch = torch.tensor(np.asarray(done_batch, dtype=np.float32), device=self.device)
        # Q(s_t, a_t) for the actions that were actually taken
        q_values = self.policy_net(state_batch).gather(dim=1, index=action_batch)
        # max_a Q_target(s_{t+1}, a); no gradient flows into the target network
        with torch.no_grad():
            next_q_values = self.target_net(next_state_batch).max(1)[0]
        # for terminal transitions done == 1, so the target reduces to the reward
        expected_q_values = reward_batch + self.gamma * next_q_values * (1 - done_batch)
        loss = nn.MSELoss()(q_values, expected_q_values.unsqueeze(1))  # mean squared error
        # optimize the policy network
        self.optimizer.zero_grad()
        loss.backward()
        for param in self.policy_net.parameters():
            # frozen or unused parameters have no gradient to clip
            if param.grad is not None:
                param.grad.data.clamp_(-1, 1)  # clip to avoid exploding gradients
        self.optimizer.step()

    def save(self, path):
        '''Save the trained weights to ``path + 'checkpoint.pth'``.

        ``path`` is created as a directory if missing and is concatenated
        verbatim with the file name, so it should end with a path separator.
        '''
        from pathlib import Path
        Path(path).mkdir(parents=True, exist_ok=True)
        # policy_net holds the most recent weights and is what predict() uses
        torch.save(self.policy_net.state_dict(), path+'checkpoint.pth')

    def load(self, path):
        '''Load weights from ``path + 'checkpoint.pth'`` into both networks.'''
        # map_location lets a checkpoint saved on another device (e.g. CUDA)
        # be restored onto this agent's device
        state_dict = torch.load(path+'checkpoint.pth', map_location=self.device)
        # load_state_dict copies values, so the two networks stay independent
        self.policy_net.load_state_dict(state_dict)
        self.target_net.load_state_dict(state_dict)