Files
easy-rl/codes/PolicyGradient/model.py
JohnJim0816 6e4d966e1f update
2021-03-28 11:18:52 +08:00

30 lines
838 B
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python
# coding=utf-8
'''
Author: John
Email: johnjim0816@gmail.com
Date: 2021-03-23 16:35:58
LastEditor: John
LastEditTime: 2021-03-23 16:36:20
Description:
Environment:
'''
import torch
import torch.nn as nn
import torch.nn.functional as F
class MLP(nn.Module):
    """Multi-layer perceptron used as the policy network.

    Maps a state vector to a single probability through two ReLU hidden
    layers followed by a one-unit sigmoid output layer.

    Input: tensor whose last dimension is ``state_dim``.
    Output: tensor whose last dimension is 1, with values in (0, 1). The
    original author labels this output "Prob of Left", i.e. presumably the
    probability of one of two discrete actions -- confirm the action
    mapping against the agent that consumes it.
    """

    def __init__(self, state_dim, hidden_dim=36):
        """Build the three linear layers.

        Args:
            state_dim: size of the input state vector.
            hidden_dim: width (number of units) of each of the two hidden
                layers; adjust it to the complexity of the state space.
        """
        super().__init__()
        # Attribute names fc1/fc2/fc3 and their creation order are kept
        # stable: state_dict keys derive from the names (so saved
        # checkpoints still load) and seeded weight initialisation depends
        # on the order in which the layers are constructed.
        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, 1)  # single output unit ("Prob of Left")

    def forward(self, x):
        """Return the output probability for state(s) ``x``.

        Args:
            x: floating-point tensor whose last dimension is ``state_dim``,
                e.g. shape ``(state_dim,)`` or ``(batch, state_dim)``.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of 1, squashed into (0, 1) by a sigmoid.
        """
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        # torch.sigmoid is the canonical spelling; nn.functional.sigmoid was
        # deprecated in favour of it in the PyTorch 1.x releases.
        return torch.sigmoid(self.fc3(x))