#!/usr/bin/env python
# coding=utf-8
'''
@Author: John
@Email: johnjim0816@gmail.com
@Date: 2020-06-10 15:03:59
@LastEditor: John
@LastEditTime: 2020-06-14 11:42:45
@Description: Actor and critic MLP networks for a deterministic actor-critic agent.
@Environment: python 3.7.7
'''
import torch
import torch.nn as nn
import torch.nn.functional as F


def _init_output_layer(layer, init_w):
    """Re-initialise ``layer``'s weight and bias uniformly in [-init_w, init_w].

    A small range keeps the network's initial outputs near zero.
    """
    nn.init.uniform_(layer.weight, -init_w, init_w)
    nn.init.uniform_(layer.bias, -init_w, init_w)


class Critic(nn.Module):
    """Two-hidden-layer ReLU MLP that scores (state, action) pairs with one scalar.

    Args:
        n_obs: Size of the state/observation vector.
        n_actions: Size of the action vector.
        hidden_size: Width of both hidden layers.
        init_w: Half-width of the uniform range used to initialise the
            output layer's weight and bias.
    """

    def __init__(self, n_obs, n_actions, hidden_size, init_w=3e-3):
        super().__init__()
        self.linear1 = nn.Linear(n_obs + n_actions, hidden_size)
        self.linear2 = nn.Linear(hidden_size, hidden_size)
        self.linear3 = nn.Linear(hidden_size, 1)
        _init_output_layer(self.linear3, init_w)

    def forward(self, state, action):
        """Return a tensor of scores with a trailing dimension of size 1.

        ``state`` and ``action`` are concatenated along dim 1, so they are
        expected to be ``(batch, n_obs)`` and ``(batch, n_actions)`` tensors
        with the same batch size.
        """
        x = torch.cat([state, action], 1)
        x = F.relu(self.linear1(x))
        x = F.relu(self.linear2(x))
        return self.linear3(x)


class Actor(nn.Module):
    """Two-hidden-layer ReLU MLP mapping a state to an action in [-1, 1].

    Args:
        n_obs: Size of the state/observation vector.
        n_actions: Size of the action vector produced.
        hidden_size: Width of both hidden layers.
        init_w: Half-width of the uniform range used to initialise the
            output layer's weight and bias.
    """

    def __init__(self, n_obs, n_actions, hidden_size, init_w=3e-3):
        super().__init__()
        self.linear1 = nn.Linear(n_obs, hidden_size)
        self.linear2 = nn.Linear(hidden_size, hidden_size)
        self.linear3 = nn.Linear(hidden_size, n_actions)
        _init_output_layer(self.linear3, init_w)

    def forward(self, x):
        """Return actions of shape ``(..., n_actions)``, squashed by tanh into [-1, 1].

        NOTE(review): presumably the caller rescales this to the
        environment's action bounds — confirm.
        """
        x = F.relu(self.linear1(x))
        x = F.relu(self.linear2(x))
        # F.tanh is deprecated; torch.tanh is the supported equivalent.
        return torch.tanh(self.linear3(x))

    @torch.no_grad()
    def select_action(self, state):
        """Compute the action for a single, un-batched ``state``.

        Runs without autograd tracking, since the result is returned as
        NumPy and cannot carry gradients anyway.

        Args:
            state: Array-like (list, NumPy array, or tensor) observation of
                length ``n_obs``.

        Returns:
            The first action component as a NumPy float32 scalar.
        """
        # Place the input on the device the network actually lives on.
        # Choosing CUDA merely because it is available crashes when the
        # model was kept on the CPU.
        device = next(self.parameters()).device
        state = torch.as_tensor(state, dtype=torch.float32, device=device).unsqueeze(0)
        action = self.forward(state)
        # NOTE(review): only component [0, 0] is returned, so for
        # n_actions > 1 the remaining components are dropped — confirm
        # callers only use one-dimensional action spaces.
        return action.cpu().numpy()[0, 0]