update codes
@@ -17,9 +17,9 @@ from PolicyGradient.model import MLP
 
 class PolicyGradient:
 
-    def __init__(self, state_dim,cfg):
+    def __init__(self, n_states,cfg):
         self.gamma = cfg.gamma
-        self.policy_net = MLP(state_dim,hidden_dim=cfg.hidden_dim)
+        self.policy_net = MLP(n_states,hidden_dim=cfg.hidden_dim)
         self.optimizer = torch.optim.RMSprop(self.policy_net.parameters(), lr=cfg.lr)
         self.batch_size = cfg.batch_size
 
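For context, a minimal sketch of how the renamed n_states argument flows into the policy network. Only the __init__ body comes from this hunk; the imports and the choose_action method (with a Bernoulli policy head matching the MLP's single-probability output) are assumptions added for illustration.

import torch
from torch.distributions import Bernoulli
from PolicyGradient.model import MLP

class PolicyGradient:
    def __init__(self, n_states, cfg):
        self.gamma = cfg.gamma  # discount factor used when computing returns
        self.policy_net = MLP(n_states, hidden_dim=cfg.hidden_dim)
        self.optimizer = torch.optim.RMSprop(self.policy_net.parameters(), lr=cfg.lr)
        self.batch_size = cfg.batch_size

    def choose_action(self, state):
        # assumed sampling step: the network output is treated as P(action = 0)
        state = torch.from_numpy(state).float().unsqueeze(0)
        prob = self.policy_net(state)
        action = Bernoulli(prob).sample()
        return int(action.item())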
@@ -19,7 +19,7 @@ class MLP(nn.Module):
     '''
     def __init__(self,input_dim,hidden_dim = 36):
         super(MLP, self).__init__()
-        # 24 and 36 are the hidden layer sizes; they can be changed according to input_dim and action_dim
+        # 24 and 36 are the hidden layer sizes; they can be changed according to input_dim and n_actions
         self.fc1 = nn.Linear(input_dim, hidden_dim)
         self.fc2 = nn.Linear(hidden_dim,hidden_dim)
         self.fc3 = nn.Linear(hidden_dim, 1) # Prob of Left
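The hunk above only touches a comment; for reference, a sketch of the full MLP it belongs to. The layer definitions are taken from the diff, while the forward pass (ReLU on the hidden layers, sigmoid on the single output so it can be read as the "Prob of Left") is an assumption.

import torch
import torch.nn as nn
import torch.nn.functional as F

class MLP(nn.Module):
    '''Small fully connected network mapping a state to a single action probability.'''
    def __init__(self, input_dim, hidden_dim=36):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, 1)  # Prob of Left

    def forward(self, x):
        # assumed forward pass: ReLU on the hidden layers, sigmoid to squash the output into [0, 1]
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return torch.sigmoid(self.fc3(x))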
@@ -46,8 +46,8 @@ class PGConfig:
 def env_agent_config(cfg,seed=1):
     env = gym.make(cfg.env)
     env.seed(seed)
-    state_dim = env.observation_space.shape[0]
-    agent = PolicyGradient(state_dim,cfg)
+    n_states = env.observation_space.shape[0]
+    agent = PolicyGradient(n_states,cfg)
     return env,agent
 
 def train(cfg,env,agent):
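A hedged usage example of env_agent_config after the rename, assuming the function above is in scope. The PGConfig attribute values below, including the environment name 'CartPole-v0', are illustrative assumptions rather than values from this commit; only the attribute names (env, hidden_dim, gamma, lr, batch_size) are implied by the code in the diff.

import gym

class PGConfig:
    # illustrative hyperparameters; only the attribute names are implied by the diff
    env = 'CartPole-v0'  # assumed environment name
    hidden_dim = 36
    gamma = 0.99
    lr = 0.01
    batch_size = 8

cfg = PGConfig()
env, agent = env_agent_config(cfg, seed=1)
print(env.observation_space.shape[0])  # n_states, e.g. 4 for CartPole-v0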