#!/usr/bin/env python
# coding=utf-8
'''
Author: JiangJi
Email: johnjim0816@gmail.com
Date: 2021-04-29 12:53:58
LastEditor: JiangJi
LastEditTime: 2021-04-29 12:57:29
Description: Value, soft-Q and tanh-squashed Gaussian policy networks
             for a Soft Actor-Critic (SAC) style agent.
Environment:
'''
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal

# Preferred compute device. The networks below do not depend on it (they use
# the device their own parameters live on); it is kept as a public name
# because training code presumably imports it -- verify before removing.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _init_uniform_(layer, init_w):
    """Re-initialise ``layer.weight`` and ``layer.bias`` in place from U(-init_w, init_w)."""
    nn.init.uniform_(layer.weight, -init_w, init_w)
    nn.init.uniform_(layer.bias, -init_w, init_w)


class ValueNet(nn.Module):
    """State-value network V(s): two ReLU hidden layers and one scalar output.

    Args:
        state_dim: size of the last dimension of the state input.
        hidden_dim: width of both hidden layers.
        init_w: half-width of the uniform init of the output layer.
    """

    def __init__(self, state_dim, hidden_dim, init_w=3e-3):
        super().__init__()
        self.linear1 = nn.Linear(state_dim, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, hidden_dim)
        self.linear3 = nn.Linear(hidden_dim, 1)
        # Tiny output weights make the initial value estimates close to zero.
        _init_uniform_(self.linear3, init_w)

    def forward(self, state):
        """Return V(s) with shape ``(*batch, 1)`` for a ``(*batch, state_dim)`` input."""
        x = F.relu(self.linear1(state))
        x = F.relu(self.linear2(x))
        return self.linear3(x)


class SoftQNet(nn.Module):
    """Soft action-value network Q(s, a) on the concatenation of state and action.

    Args:
        num_inputs: size of the last dimension of the state input.
        num_actions: size of the last dimension of the action input.
        hidden_size: width of both hidden layers.
        init_w: half-width of the uniform init of the output layer.
    """

    def __init__(self, num_inputs, num_actions, hidden_size, init_w=3e-3):
        super().__init__()
        self.linear1 = nn.Linear(num_inputs + num_actions, hidden_size)
        self.linear2 = nn.Linear(hidden_size, hidden_size)
        self.linear3 = nn.Linear(hidden_size, 1)
        _init_uniform_(self.linear3, init_w)

    def forward(self, state, action):
        """Return Q(s, a) with shape ``(*batch, 1)``.

        ``state`` and ``action`` must share their leading (batch) dimensions;
        they are joined along the last (feature) dimension.
        """
        # dim=-1 is identical to dim=1 for 2-D input, and also supports
        # additional leading batch dimensions.
        x = torch.cat([state, action], dim=-1)
        x = F.relu(self.linear1(x))
        x = F.relu(self.linear2(x))
        return self.linear3(x)


class PolicyNet(nn.Module):
    """Gaussian policy whose samples are squashed into (-1, 1) by ``tanh``.

    Args:
        num_inputs: size of the last dimension of the state input.
        num_actions: dimensionality of the action.
        hidden_size: width of both hidden layers.
        init_w: half-width of the uniform init of the mean / log-std heads.
        log_std_min, log_std_max: clamp range applied to the predicted log-std.
    """

    def __init__(self, num_inputs, num_actions, hidden_size, init_w=3e-3,
                 log_std_min=-20, log_std_max=2):
        super().__init__()
        self.log_std_min = log_std_min
        self.log_std_max = log_std_max

        self.linear1 = nn.Linear(num_inputs, hidden_size)
        self.linear2 = nn.Linear(hidden_size, hidden_size)

        self.mean_linear = nn.Linear(hidden_size, num_actions)
        _init_uniform_(self.mean_linear, init_w)

        self.log_std_linear = nn.Linear(hidden_size, num_actions)
        _init_uniform_(self.log_std_linear, init_w)

    def forward(self, state):
        """Return ``(mean, log_std)`` of the pre-squash Gaussian.

        ``log_std`` is clamped to ``[log_std_min, log_std_max]`` to keep the
        standard deviation from collapsing or exploding.
        """
        x = F.relu(self.linear1(state))
        x = F.relu(self.linear2(x))
        mean = self.mean_linear(x)
        log_std = self.log_std_linear(x)
        log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max)
        return mean, log_std

    def evaluate(self, state, epsilon=1e-6, reparameterize=False):
        """Sample squashed actions for a batch of states, with their log-probabilities.

        Args:
            state: batched state tensor.
            epsilon: added inside the log of the tanh correction to avoid log(0).
            reparameterize: when False (default, original behaviour) ``z`` is
                drawn with ``sample()``. Then ``action`` carries no gradient,
                and gradients reach the policy only through ``log_prob``'s
                dependence on mean/std. When True, ``rsample()`` is used, so
                ``action`` is differentiable w.r.t. the policy parameters.

        Returns:
            ``(action, log_prob, z, mean, log_std)``. Here ``action = tanh(z)``.
            ``log_prob`` has shape ``(*batch, 1)`` and equals
            ``sum(log N(z; mean, std) - log(1 - tanh(z)^2 + epsilon))``.
        """
        mean, log_std = self.forward(state)
        std = log_std.exp()
        normal = Normal(mean, std)
        z = normal.rsample() if reparameterize else normal.sample()
        action = torch.tanh(z)
        # Change-of-variables correction for the tanh squashing.
        log_prob = normal.log_prob(z) - torch.log(1 - action.pow(2) + epsilon)
        log_prob = log_prob.sum(-1, keepdim=True)
        return action, log_prob, z, mean, log_std

    def get_action(self, state):
        """Sample one squashed action for a single, unbatched state.

        ``state`` may be anything ``torch.as_tensor`` accepts (list, numpy
        array, tensor). The input is moved to the device of this network's
        parameters, and no autograd graph is built.

        Returns:
            A numpy array of shape ``(num_actions,)`` with values in [-1, 1].
        """
        param_device = next(self.parameters()).device
        state = torch.as_tensor(state, dtype=torch.float32,
                                device=param_device).unsqueeze(0)
        with torch.no_grad():
            mean, log_std = self.forward(state)
            z = Normal(mean, log_std.exp()).sample()
            action = torch.tanh(z)
        return action.cpu().numpy()[0]