update codes

This commit is contained in:
johnjim0816
2021-12-28 18:46:52 +08:00
parent 41fb561d25
commit bd51b5a7ad
52 changed files with 305 additions and 292 deletions

View File

@@ -12,16 +12,16 @@ Environment:
import torch.nn as nn
from torch.distributions.categorical import Categorical
class Actor(nn.Module):
def __init__(self,n_states, n_actions,
def __init__(self,state_dim, action_dim,
hidden_dim):
super(Actor, self).__init__()
self.actor = nn.Sequential(
nn.Linear(n_states, hidden_dim),
nn.Linear(state_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, n_actions),
nn.Linear(hidden_dim, action_dim),
nn.Softmax(dim=-1)
)
def forward(self, state):
@@ -30,10 +30,10 @@ class Actor(nn.Module):
return dist
class Critic(nn.Module):
def __init__(self, n_states,hidden_dim):
def __init__(self, state_dim,hidden_dim):
super(Critic, self).__init__()
self.critic = nn.Sequential(
nn.Linear(n_states, hidden_dim),
nn.Linear(state_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),