This commit is contained in:
JohnJim0816
2021-03-28 11:18:52 +08:00
parent 2df8d965d2
commit 6e4d966e1f
56 changed files with 497 additions and 165 deletions

View File

@@ -5,7 +5,7 @@
@Email: johnjim0816@gmail.com
@Date: 2020-06-12 00:50:49
@LastEditor: John
LastEditTime: 2021-03-13 15:01:27
LastEditTime: 2021-03-28 11:07:35
@Discription:
@Environment: python 3.7.7
'''
@@ -16,16 +16,15 @@ LastEditTime: 2021-03-13 15:01:27
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import random
import math
import numpy as np
from common.memory import ReplayBuffer
from common.model import MLP2
from common.model import MLP
class DoubleDQN:
def __init__(self, n_states, n_actions, cfg):
def __init__(self, state_dim, action_dim, cfg):
self.n_actions = n_actions # 总的动作个数
self.action_dim = action_dim # 总的动作个数
self.device = cfg.device # 设备cpu或gpu等
self.gamma = cfg.gamma
# e-greedy策略相关参数
@@ -34,8 +33,8 @@ class DoubleDQN:
self.epsilon_end = cfg.epsilon_end
self.epsilon_decay = cfg.epsilon_decay
self.batch_size = cfg.batch_size
self.policy_net = MLP2(n_states, n_actions,hidden_dim=cfg.hidden_dim).to(self.device)
self.target_net = MLP2(n_states, n_actions,hidden_dim=cfg.hidden_dim).to(self.device)
self.policy_net = MLP(state_dim, action_dim,hidden_dim=cfg.hidden_dim).to(self.device)
self.target_net = MLP(state_dim, action_dim,hidden_dim=cfg.hidden_dim).to(self.device)
# target_net的初始模型参数完全复制policy_net
self.target_net.load_state_dict(self.policy_net.state_dict())
self.target_net.eval() # 不启用 BatchNormalization 和 Dropout
@@ -63,7 +62,7 @@ class DoubleDQN:
# 所以tensor.max(1)[1]返回最大值对应的下标即action
action = q_value.max(1)[1].item()
else:
action = random.randrange(self.n_actions)
action = random.randrange(self.action_dim)
return action
def update(self):