#!/usr/bin/env python
# coding=utf-8
'''
Author: John
Email: johnjim0816@gmail.com
Date: 2021-03-12 21:14:12
LastEditor: John
LastEditTime: 2021-05-04 02:45:27
Discription: Neural-network building blocks for RL agents (DQN / DDPG / A2C style).
Environment:
'''
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical


class MLP(nn.Module):
    def __init__(self, input_dim, output_dim, hidden_dim=128):
        """Fully-connected Q-network.

        Args:
            input_dim: number of input features (size of the environment state).
            output_dim: number of actions (one Q-value per action).
            hidden_dim: width of the two hidden layers.
        """
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)   # input layer
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)  # hidden layer
        self.fc3 = nn.Linear(hidden_dim, output_dim)  # output layer

    def forward(self, x):
        """Return raw Q-values (no activation on the output layer)."""
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)


class Critic(nn.Module):
    def __init__(self, n_obs, output_dim, hidden_size, init_w=3e-3):
        """State-action value network Q(s, a) for DDPG-style agents.

        Args:
            n_obs: dimensionality of the observation.
            output_dim: dimensionality of the action (concatenated with the state).
            hidden_size: width of the hidden layers.
            init_w: half-range of the uniform init for the output layer; small
                values keep initial Q estimates near zero.
        """
        super().__init__()
        self.linear1 = nn.Linear(n_obs + output_dim, hidden_size)
        self.linear2 = nn.Linear(hidden_size, hidden_size)
        self.linear3 = nn.Linear(hidden_size, 1)
        # Initialize the output layer with small random values so early Q
        # estimates stay close to zero.
        self.linear3.weight.data.uniform_(-init_w, init_w)
        self.linear3.bias.data.uniform_(-init_w, init_w)

    def forward(self, state, action):
        """Return Q(state, action); state and action are concatenated on dim 1."""
        x = torch.cat([state, action], 1)
        x = F.relu(self.linear1(x))
        x = F.relu(self.linear2(x))
        return self.linear3(x)


class Actor(nn.Module):
    def __init__(self, n_obs, output_dim, hidden_size, init_w=3e-3):
        """Deterministic policy network for DDPG-style agents.

        Args:
            n_obs: dimensionality of the observation.
            output_dim: dimensionality of the action.
            hidden_size: width of the hidden layers.
            init_w: half-range of the uniform init for the output layer.
        """
        super().__init__()
        self.linear1 = nn.Linear(n_obs, hidden_size)
        self.linear2 = nn.Linear(hidden_size, hidden_size)
        self.linear3 = nn.Linear(hidden_size, output_dim)
        # Small output-layer init keeps initial actions near zero.
        self.linear3.weight.data.uniform_(-init_w, init_w)
        self.linear3.bias.data.uniform_(-init_w, init_w)

    def forward(self, x):
        """Return an action in [-1, 1] (tanh-squashed)."""
        x = F.relu(self.linear1(x))
        x = F.relu(self.linear2(x))
        return torch.tanh(self.linear3(x))


class ActorCritic(nn.Module):
    def __init__(self, input_dim, output_dim, hidden_dim=256):
        """Shared-input actor-critic for discrete actions (A2C/PPO style).

        Args:
            input_dim: dimensionality of the observation.
            output_dim: number of discrete actions.
            hidden_dim: width of the hidden layer in each head.
        """
        super().__init__()
        # Critic head: scalar state value V(s).
        self.critic = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )
        # Actor head: action probabilities. Softmax over the LAST dim
        # (dim=-1) is identical to dim=1 for the usual (batch, actions)
        # input, but also works for unbatched 1-D states, where dim=1
        # would raise an IndexError.
        self.actor = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
            nn.Softmax(dim=-1),
        )

    def forward(self, x):
        """Return (action distribution, state value) for observation x."""
        value = self.critic(x)
        probs = self.actor(x)
        dist = Categorical(probs)
        return dist, value