update codes

This commit is contained in:
johnjim0816
2021-12-28 18:46:52 +08:00
parent 41fb561d25
commit bd51b5a7ad
52 changed files with 305 additions and 292 deletions

View File

@@ -17,10 +17,10 @@ from torch.distributions import Normal
device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
class ValueNet(nn.Module):
def __init__(self, n_states, hidden_dim, init_w=3e-3):
def __init__(self, state_dim, hidden_dim, init_w=3e-3):
super(ValueNet, self).__init__()
self.linear1 = nn.Linear(n_states, hidden_dim)
self.linear1 = nn.Linear(state_dim, hidden_dim)
self.linear2 = nn.Linear(hidden_dim, hidden_dim)
self.linear3 = nn.Linear(hidden_dim, 1)
@@ -35,10 +35,10 @@ class ValueNet(nn.Module):
class SoftQNet(nn.Module):
def __init__(self, n_states, n_actions, hidden_dim, init_w=3e-3):
def __init__(self, state_dim, action_dim, hidden_dim, init_w=3e-3):
super(SoftQNet, self).__init__()
self.linear1 = nn.Linear(n_states + n_actions, hidden_dim)
self.linear1 = nn.Linear(state_dim + action_dim, hidden_dim)
self.linear2 = nn.Linear(hidden_dim, hidden_dim)
self.linear3 = nn.Linear(hidden_dim, 1)
@@ -54,20 +54,20 @@ class SoftQNet(nn.Module):
class PolicyNet(nn.Module):
def __init__(self, n_states, n_actions, hidden_dim, init_w=3e-3, log_std_min=-20, log_std_max=2):
def __init__(self, state_dim, action_dim, hidden_dim, init_w=3e-3, log_std_min=-20, log_std_max=2):
super(PolicyNet, self).__init__()
self.log_std_min = log_std_min
self.log_std_max = log_std_max
self.linear1 = nn.Linear(n_states, hidden_dim)
self.linear1 = nn.Linear(state_dim, hidden_dim)
self.linear2 = nn.Linear(hidden_dim, hidden_dim)
self.mean_linear = nn.Linear(hidden_dim, n_actions)
self.mean_linear = nn.Linear(hidden_dim, action_dim)
self.mean_linear.weight.data.uniform_(-init_w, init_w)
self.mean_linear.bias.data.uniform_(-init_w, init_w)
self.log_std_linear = nn.Linear(hidden_dim, n_actions)
self.log_std_linear = nn.Linear(hidden_dim, action_dim)
self.log_std_linear.weight.data.uniform_(-init_w, init_w)
self.log_std_linear.bias.data.uniform_(-init_w, init_w)