update
@@ -46,15 +46,15 @@ import torch.nn as nn
 import torch.nn.functional as F

 class FCN(nn.Module):
-    def __init__(self, n_states=4, n_actions=18):
+    def __init__(self, state_dim=4, action_dim=18):
         """ Initialize the Q-network, a fully connected network
-        n_states: number of input features, i.e. the dimension of the environment state
-        n_actions: total number of output actions
+        state_dim: number of input features, i.e. the dimension of the environment state
+        action_dim: total number of output actions
         """
         super(FCN, self).__init__()
-        self.fc1 = nn.Linear(n_states, 128)  # input layer
+        self.fc1 = nn.Linear(state_dim, 128)  # input layer
         self.fc2 = nn.Linear(128, 128)  # hidden layer
-        self.fc3 = nn.Linear(128, n_actions)  # output layer
+        self.fc3 = nn.Linear(128, action_dim)  # output layer

     def forward(self, x):
         # activation functions for each layer
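For context, a minimal sketch of how the renamed network might look in full after this commit. The body of ```forward``` lies outside the hunk, so the ReLU activations shown here are an assumption, not part of the diff:

```python
import torch.nn as nn
import torch.nn.functional as F

class FCN(nn.Module):
    def __init__(self, state_dim=4, action_dim=18):
        super(FCN, self).__init__()
        self.fc1 = nn.Linear(state_dim, 128)   # input layer
        self.fc2 = nn.Linear(128, 128)         # hidden layer
        self.fc3 = nn.Linear(128, action_dim)  # output layer

    def forward(self, x):
        # ReLU after each hidden layer; raw Q-values at the output (assumed)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)
```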
@@ -66,8 +66,8 @@ class FCN(nn.Module):

 In ```agent.py``` we define the reinforcement learning algorithm, with ```choose_action``` and ```update``` as the two main functions. In the initialization:
 ```python
-self.policy_net = FCN(n_states, n_actions).to(self.device)
-self.target_net = FCN(n_states, n_actions).to(self.device)
+self.policy_net = FCN(state_dim, action_dim).to(self.device)
+self.target_net = FCN(state_dim, action_dim).to(self.device)
 # target_net's initial parameters are an exact copy of policy_net's
 self.target_net.load_state_dict(self.policy_net.state_dict())
 self.target_net.eval()  # eval mode: disables BatchNormalization and Dropout
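Only ```policy_net``` is trained by gradient descent; ```target_net``` is only ever updated by copying. An optimizer line consistent with that and with the ```lr``` field in the config further down might look as follows (the attribute name ```self.optimizer``` and the choice of Adam are assumptions, not shown in this diff):

```python
# Hypothetical: optimize only the policy network; target_net is never trained directly
self.optimizer = torch.optim.Adam(self.policy_net.parameters(), lr=cfg.lr)
```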
@@ -20,11 +20,11 @@ import random
 import math
 import numpy as np
 from common.memory import ReplayBuffer
-from common.model import MLP2
+from common.model import MLP
 class DQN:
-    def __init__(self, n_states, n_actions, cfg):
+    def __init__(self, state_dim, action_dim, cfg):

-        self.n_actions = n_actions  # total number of actions
+        self.action_dim = action_dim  # total number of actions
         self.device = cfg.device  # device, e.g. cpu or gpu
         self.gamma = cfg.gamma  # reward discount factor
         # epsilon-greedy policy parameters
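The agent draws its training batches from the ```ReplayBuffer``` imported above, whose implementation is not part of this diff. A minimal sketch of what such a buffer typically looks like, assuming a simple ring buffer with uniform sampling:

```python
import random

class ReplayBuffer:
    def __init__(self, capacity):
        self.capacity = capacity  # maximum number of stored transitions
        self.buffer = []
        self.position = 0

    def push(self, state, action, reward, next_state, done):
        # overwrite the oldest transition once the buffer is full
        if len(self.buffer) < self.capacity:
            self.buffer.append(None)
        self.buffer[self.position] = (state, action, reward, next_state, done)
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        # uniform random minibatch, unzipped into per-field tuples
        batch = random.sample(self.buffer, batch_size)
        return zip(*batch)

    def __len__(self):
        return len(self.buffer)
```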
@@ -34,8 +34,8 @@ class DQN:
         self.epsilon_end = cfg.epsilon_end
         self.epsilon_decay = cfg.epsilon_decay
         self.batch_size = cfg.batch_size
-        self.policy_net = MLP2(n_states, n_actions, hidden_dim=cfg.hidden_dim).to(self.device)
-        self.target_net = MLP2(n_states, n_actions, hidden_dim=cfg.hidden_dim).to(self.device)
+        self.policy_net = MLP(state_dim, action_dim, hidden_dim=cfg.hidden_dim).to(self.device)
+        self.target_net = MLP(state_dim, action_dim, hidden_dim=cfg.hidden_dim).to(self.device)
         # target_net's initial parameters are an exact copy of policy_net's
         self.target_net.load_state_dict(self.policy_net.state_dict())
         self.target_net.eval()  # eval mode: disables BatchNormalization and Dropout
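The other main function mentioned earlier, ```update```, does not appear in this diff. As a rough sketch of what a DQN update step built on these attributes (```batch_size```, ```gamma```, the two networks, and the replay buffer) usually looks like; the names ```self.memory``` and ```self.optimizer```, and the module-level ```torch```/```nn``` imports, are assumptions:

```python
    def update(self):
        if len(self.memory) < self.batch_size:
            return  # not enough transitions collected yet
        # sample a minibatch of transitions (assumed buffer layout)
        state_batch, action_batch, reward_batch, next_state_batch, done_batch = \
            self.memory.sample(self.batch_size)
        state_batch = torch.tensor(state_batch, device=self.device, dtype=torch.float)
        action_batch = torch.tensor(action_batch, device=self.device).unsqueeze(1)
        reward_batch = torch.tensor(reward_batch, device=self.device, dtype=torch.float)
        next_state_batch = torch.tensor(next_state_batch, device=self.device, dtype=torch.float)
        done_batch = torch.tensor(done_batch, device=self.device, dtype=torch.float)

        # Q(s, a) from the policy network for the actions actually taken
        q_values = self.policy_net(state_batch).gather(1, action_batch).squeeze(1)
        # TD target: r + gamma * max_a' Q_target(s', a'), zeroed at terminal states
        next_q_values = self.target_net(next_state_batch).max(1)[0].detach()
        expected_q_values = reward_batch + self.gamma * next_q_values * (1 - done_batch)

        loss = nn.MSELoss()(q_values, expected_q_values)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
```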
@@ -64,7 +64,7 @@ class DQN:
                     # so tensor.max(1)[1] returns the index of the maximum value, i.e. the action
                     action = q_value.max(1)[1].item()
             else:
-                action = random.randrange(self.n_actions)
+                action = random.randrange(self.action_dim)
             return action
         else:
             with torch.no_grad():  # no gradient tracking
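Putting the fragments above together, the epsilon-greedy action selection presumably looks roughly like the sketch below. The exponential epsilon-decay schedule and the ```self.frame_idx``` counter are assumptions, inferred from the ```epsilon_start```/```epsilon_end```/```epsilon_decay``` attributes and the ```import math``` in this file:

```python
    def choose_action(self, state):
        # decay epsilon from epsilon_start towards epsilon_end (assumed exponential schedule)
        self.frame_idx += 1
        epsilon = self.epsilon_end + (self.epsilon_start - self.epsilon_end) * \
            math.exp(-1. * self.frame_idx / self.epsilon_decay)
        if random.random() > epsilon:
            # exploit: pick the action with the highest predicted Q-value
            with torch.no_grad():
                state = torch.tensor([state], device=self.device, dtype=torch.float)
                q_value = self.policy_net(state)
                action = q_value.max(1)[1].item()
        else:
            # explore: pick a uniformly random action
            action = random.randrange(self.action_dim)
        return action
```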
@@ -5,7 +5,7 @@
 @Email: johnjim0816@gmail.com
 @Date: 2020-06-12 00:48:57
 @LastEditor: John
-LastEditTime: 2021-03-17 20:35:37
+LastEditTime: 2021-03-26 17:17:17
 @Discription:
 @Environment: python 3.7.7
 '''
@@ -40,7 +40,7 @@ class DQNConfig:
         self.lr = 0.01  # learning rate
         self.memory_capacity = 800  # capacity of the replay memory
         self.batch_size = 64
-        self.train_eps = 250  # number of training episodes
+        self.train_eps = 300  # number of training episodes
         self.train_steps = 200  # maximum number of steps per training episode
         self.target_update = 2  # update frequency of the target net
         self.eval_eps = 20  # number of evaluation episodes
@@ -84,9 +84,9 @@ if __name__ == "__main__":
     cfg = DQNConfig()
     env = gym.make('CartPole-v0').unwrapped  # you can google why gym envs are unwrapped; usually not needed here
     env.seed(1)  # set the env random seed
-    n_states = env.observation_space.shape[0]
-    n_actions = env.action_space.n
-    agent = DQN(n_states, n_actions, cfg)
+    state_dim = env.observation_space.shape[0]
+    action_dim = env.action_space.n
+    agent = DQN(state_dim, action_dim, cfg)
     rewards, ma_rewards = train(cfg, env, agent)
     agent.save(path=SAVED_MODEL_PATH)
     save_results(rewards, ma_rewards, tag='train', path=RESULT_PATH)
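```train(cfg, env, agent)``` is called above but its body is not part of this commit. A minimal sketch of how such a training loop typically ties the pieces together (episode loop, ```choose_action```, ```memory.push```, ```update```, and the periodic target-network sync driven by ```cfg.target_update```); variable names and the moving-average smoothing factor are assumptions:

```python
def train(cfg, env, agent):
    rewards, ma_rewards = [], []  # per-episode reward and its moving average
    for i_episode in range(cfg.train_eps):
        state = env.reset()
        ep_reward = 0
        for _ in range(cfg.train_steps):
            action = agent.choose_action(state)
            next_state, reward, done, _ = env.step(action)
            agent.memory.push(state, action, reward, next_state, done)
            agent.update()
            state = next_state
            ep_reward += reward
            if done:
                break
        # periodically sync the target network with the policy network
        if i_episode % cfg.target_update == 0:
            agent.target_net.load_state_dict(agent.policy_net.state_dict())
        rewards.append(ep_reward)
        # smoothed (moving-average) reward for plotting
        if ma_rewards:
            ma_rewards.append(0.9 * ma_rewards[-1] + 0.1 * ep_reward)
        else:
            ma_rewards.append(ep_reward)
    return rewards, ma_rewards
```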
BIN  codes/DQN/results/20210326-171704/ma_rewards_train.npy (new file)
BIN  codes/DQN/results/20210326-171704/rewards_curve_train.png (new file, 27 KiB)
BIN  codes/DQN/results/20210326-171704/rewards_train.npy (new file)
BIN  codes/DQN/results/20210326-171722/ma_rewards_train.npy (new file)
BIN  codes/DQN/results/20210326-171722/rewards_curve_train.png (new file, 66 KiB)
BIN  codes/DQN/results/20210326-171722/rewards_train.npy (new file)
BIN  codes/DQN/saved_model/20210326-171704/dqn_checkpoint.pth (new file)
BIN  codes/DQN/saved_model/20210326-171722/dqn_checkpoint.pth (new file)