From e522ba5510c74a4ff91a05550b40ac780f08a678 Mon Sep 17 00:00:00 2001 From: JohnJim0816 Date: Sat, 13 Mar 2021 11:51:51 +0800 Subject: [PATCH] update --- codes/common/model.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/codes/common/model.py b/codes/common/model.py index bb71722..0f367f2 100644 --- a/codes/common/model.py +++ b/codes/common/model.py @@ -5,20 +5,23 @@ Author: John Email: johnjim0816@gmail.com Date: 2021-03-12 21:14:12 LastEditor: John -LastEditTime: 2021-03-12 21:28:46 +LastEditTime: 2021-03-13 11:51:38 Discription: Environment: ''' import torch.nn as nn import torch.nn.functional as F class MLP(nn.Module): - ''' 全连接网络''' - def __init__(self,state_dim): + ''' 多层感知机 + 输入:state维度 + 输出:概率 + ''' + def __init__(self,n_states,hidden_dim = 36): super(MLP, self).__init__() # 24和36为hidden layer的层数,可根据state_dim, n_actions的情况来改变 - self.fc1 = nn.Linear(state_dim, 36) - self.fc2 = nn.Linear(36, 36) - self.fc3 = nn.Linear(36, 1) # Prob of Left + self.fc1 = nn.Linear(n_states, hidden_dim) + self.fc2 = nn.Linear(hidden_dim,hidden_dim) + self.fc3 = nn.Linear(hidden_dim, 1) # Prob of Left def forward(self, x): x = F.relu(self.fc1(x))