update codes
This commit is contained in:
@@ -5,7 +5,7 @@ Author: JiangJi
|
||||
Email: johnjim0816@gmail.com
|
||||
Date: 2021-04-29 12:53:58
|
||||
LastEditor: JiangJi
|
||||
LastEditTime: 2021-04-29 12:57:29
|
||||
LastEditTime: 2021-11-19 18:04:19
|
||||
Discription:
|
||||
Environment:
|
||||
'''
|
||||
@@ -35,12 +35,12 @@ class ValueNet(nn.Module):
|
||||
|
||||
|
||||
class SoftQNet(nn.Module):
|
||||
def __init__(self, num_inputs, num_actions, hidden_size, init_w=3e-3):
|
||||
def __init__(self, state_dim, action_dim, hidden_dim, init_w=3e-3):
|
||||
super(SoftQNet, self).__init__()
|
||||
|
||||
self.linear1 = nn.Linear(num_inputs + num_actions, hidden_size)
|
||||
self.linear2 = nn.Linear(hidden_size, hidden_size)
|
||||
self.linear3 = nn.Linear(hidden_size, 1)
|
||||
self.linear1 = nn.Linear(state_dim + action_dim, hidden_dim)
|
||||
self.linear2 = nn.Linear(hidden_dim, hidden_dim)
|
||||
self.linear3 = nn.Linear(hidden_dim, 1)
|
||||
|
||||
self.linear3.weight.data.uniform_(-init_w, init_w)
|
||||
self.linear3.bias.data.uniform_(-init_w, init_w)
|
||||
@@ -54,20 +54,20 @@ class SoftQNet(nn.Module):
|
||||
|
||||
|
||||
class PolicyNet(nn.Module):
|
||||
def __init__(self, num_inputs, num_actions, hidden_size, init_w=3e-3, log_std_min=-20, log_std_max=2):
|
||||
def __init__(self, state_dim, action_dim, hidden_dim, init_w=3e-3, log_std_min=-20, log_std_max=2):
|
||||
super(PolicyNet, self).__init__()
|
||||
|
||||
self.log_std_min = log_std_min
|
||||
self.log_std_max = log_std_max
|
||||
|
||||
self.linear1 = nn.Linear(num_inputs, hidden_size)
|
||||
self.linear2 = nn.Linear(hidden_size, hidden_size)
|
||||
self.linear1 = nn.Linear(state_dim, hidden_dim)
|
||||
self.linear2 = nn.Linear(hidden_dim, hidden_dim)
|
||||
|
||||
self.mean_linear = nn.Linear(hidden_size, num_actions)
|
||||
self.mean_linear = nn.Linear(hidden_dim, action_dim)
|
||||
self.mean_linear.weight.data.uniform_(-init_w, init_w)
|
||||
self.mean_linear.bias.data.uniform_(-init_w, init_w)
|
||||
|
||||
self.log_std_linear = nn.Linear(hidden_size, num_actions)
|
||||
self.log_std_linear = nn.Linear(hidden_dim, action_dim)
|
||||
self.log_std_linear.weight.data.uniform_(-init_w, init_w)
|
||||
self.log_std_linear.bias.data.uniform_(-init_w, init_w)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user