108 lines
3.5 KiB
Python
108 lines
3.5 KiB
Python
#!/usr/bin/env python
|
|
# coding=utf-8
|
|
'''
|
|
Author: JiangJi
|
|
Email: johnjim0816@gmail.com
|
|
Date: 2021-04-29 12:53:58
|
|
LastEditor: JiangJi
|
|
LastEditTime: 2021-04-29 12:57:29
|
|
Description: Value, soft-Q and Gaussian policy network modules (ValueNet, SoftQNet, PolicyNet).
|
|
Environment:
|
|
'''
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from torch.distributions import Normal
|
|
|
|
# Default compute device for this module: CUDA when it is available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
|
|
|
class ValueNet(nn.Module):
    """State-value network: a three-layer MLP mapping a state to one scalar.

    Args:
        state_dim: Size of the last dimension of the state input.
        hidden_dim: Width of both hidden layers.
        init_w: Half-width of the uniform range ``[-init_w, init_w]`` used to
            re-initialise the weight and bias of the output layer.
    """

    def __init__(self, state_dim, hidden_dim, init_w=3e-3):
        super().__init__()

        self.linear1 = nn.Linear(state_dim, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, hidden_dim)
        self.linear3 = nn.Linear(hidden_dim, 1)

        # Only the output layer gets the custom uniform init (weight first,
        # then bias); the hidden layers keep nn.Linear's default init.
        for param in (self.linear3.weight, self.linear3.bias):
            param.data.uniform_(-init_w, init_w)

    def forward(self, state):
        """Return the value estimate for ``state``; the last dimension has size 1."""
        hidden = state
        for layer in (self.linear1, self.linear2):
            hidden = F.relu(layer(hidden))
        return self.linear3(hidden)
|
|
|
|
|
|
class SoftQNet(nn.Module):
    """Q-network: a three-layer MLP scoring a (state, action) pair with one scalar.

    Args:
        num_inputs: Size of the last dimension of the state input.
        num_actions: Size of the last dimension of the action input.
        hidden_size: Width of both hidden layers.
        init_w: Half-width of the uniform range ``[-init_w, init_w]`` used to
            re-initialise the weight and bias of the output layer.
    """

    def __init__(self, num_inputs, num_actions, hidden_size, init_w=3e-3):
        super(SoftQNet, self).__init__()

        self.linear1 = nn.Linear(num_inputs + num_actions, hidden_size)
        self.linear2 = nn.Linear(hidden_size, hidden_size)
        self.linear3 = nn.Linear(hidden_size, 1)

        # Only the output layer gets the custom uniform init; the hidden
        # layers keep nn.Linear's default initialisation.
        self.linear3.weight.data.uniform_(-init_w, init_w)
        self.linear3.bias.data.uniform_(-init_w, init_w)

    def forward(self, state, action):
        """Return the Q-value for ``state`` and ``action``.

        ``state`` and ``action`` must agree on every dimension except the
        last. The result keeps those leading dimensions and has a last
        dimension of size 1.
        """
        # Join along the feature (last) axis. For the usual 2-D
        # (batch, features) inputs this is the same as dim=1, and it also
        # accepts unbatched 1-D inputs or extra leading dimensions.
        x = torch.cat([state, action], dim=-1)
        x = F.relu(self.linear1(x))
        x = F.relu(self.linear2(x))
        x = self.linear3(x)
        return x
|
|
|
|
|
|
class PolicyNet(nn.Module):
    """Gaussian policy network with tanh-squashed actions.

    Two ReLU hidden layers feed two linear heads that output the mean and the
    clamped log standard deviation of a diagonal Normal distribution. Sampled
    values are passed through ``tanh``, so actions lie in ``[-1, 1]``.

    Args:
        num_inputs: Size of the last dimension of the state input.
        num_actions: Number of action dimensions (output size of both heads).
        hidden_size: Width of both hidden layers.
        init_w: Half-width of the uniform range ``[-init_w, init_w]`` used to
            re-initialise the weights and biases of both output heads.
        log_std_min: Lower clamp applied to the predicted log std.
        log_std_max: Upper clamp applied to the predicted log std.
    """

    def __init__(self, num_inputs, num_actions, hidden_size, init_w=3e-3, log_std_min=-20, log_std_max=2):
        super(PolicyNet, self).__init__()

        self.log_std_min = log_std_min
        self.log_std_max = log_std_max

        self.linear1 = nn.Linear(num_inputs, hidden_size)
        self.linear2 = nn.Linear(hidden_size, hidden_size)

        # Both heads get the custom uniform init; the hidden layers keep
        # nn.Linear's default initialisation.
        self.mean_linear = nn.Linear(hidden_size, num_actions)
        self.mean_linear.weight.data.uniform_(-init_w, init_w)
        self.mean_linear.bias.data.uniform_(-init_w, init_w)

        self.log_std_linear = nn.Linear(hidden_size, num_actions)
        self.log_std_linear.weight.data.uniform_(-init_w, init_w)
        self.log_std_linear.bias.data.uniform_(-init_w, init_w)

    def forward(self, state):
        """Return ``(mean, log_std)`` of the pre-tanh Normal for ``state``.

        ``log_std`` is clamped to ``[log_std_min, log_std_max]``.
        """
        x = F.relu(self.linear1(state))
        x = F.relu(self.linear2(x))

        mean = self.mean_linear(x)
        log_std = self.log_std_linear(x)
        log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max)

        return mean, log_std

    def evaluate(self, state, epsilon=1e-6):
        """Sample a squashed action and compute its log-probability.

        Args:
            state: State tensor accepted by ``forward``.
            epsilon: Small constant added inside the log of the tanh
                correction term to avoid ``log(0)``.

        Returns:
            Tuple ``(action, log_prob, z, mean, log_std)`` where ``z`` is the
            raw Normal sample, ``action = tanh(z)`` and ``log_prob`` is summed
            over the last dimension (kept with size 1).
        """
        mean, log_std = self.forward(state)
        std = log_std.exp()

        normal = Normal(mean, std)
        # sample() (not rsample()): z itself carries no gradient; log_prob
        # below is still differentiable through mean and std.
        z = normal.sample()
        action = torch.tanh(z)

        # Change-of-variables correction for the tanh squashing.
        log_prob = normal.log_prob(z) - torch.log(1 - action.pow(2) + epsilon)
        log_prob = log_prob.sum(-1, keepdim=True)

        return action, log_prob, z, mean, log_std

    def get_action(self, state):
        """Sample one action for a single, unbatched state.

        Args:
            state: Array-like state without a batch dimension.

        Returns:
            NumPy array holding the tanh-squashed action, batch dimension
            removed.
        """
        # Place the input on the device the network actually lives on rather
        # than on a module-level global, so a model kept on the CPU still
        # works on a CUDA-capable machine.
        model_device = next(self.parameters()).device
        state = torch.as_tensor(state, dtype=torch.float32, device=model_device).unsqueeze(0)

        # Pure inference: no autograd graph is needed for acting.
        with torch.no_grad():
            mean, log_std = self.forward(state)
            std = log_std.exp()

            normal = Normal(mean, std)
            z = normal.sample()
            action = torch.tanh(z)

        action = action.cpu().numpy()
        return action[0]