update codes

This commit is contained in:
johnjim0816
2021-12-22 16:55:09 +08:00
parent 75df999258
commit 41fb561d25
75 changed files with 1248 additions and 918 deletions

View File

@@ -12,16 +12,16 @@ Environment:
import torch.nn as nn
from torch.distributions.categorical import Categorical
class Actor(nn.Module):
def __init__(self,state_dim, action_dim,
def __init__(self,n_states, n_actions,
hidden_dim):
super(Actor, self).__init__()
self.actor = nn.Sequential(
nn.Linear(state_dim, hidden_dim),
nn.Linear(n_states, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, action_dim),
nn.Linear(hidden_dim, n_actions),
nn.Softmax(dim=-1)
)
def forward(self, state):
@@ -30,10 +30,10 @@ class Actor(nn.Module):
return dist
class Critic(nn.Module):
def __init__(self, state_dim,hidden_dim):
def __init__(self, n_states,hidden_dim):
super(Critic, self).__init__()
self.critic = nn.Sequential(
nn.Linear(state_dim, hidden_dim),
nn.Linear(n_states, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),