update
This commit is contained in:
@@ -17,9 +17,9 @@ from PolicyGradient.model import MLP
|
||||
|
||||
class PolicyGradient:
|
||||
|
||||
def __init__(self, n_states,cfg):
|
||||
def __init__(self, state_dim,cfg):
|
||||
self.gamma = cfg.gamma
|
||||
self.policy_net = MLP(n_states,hidden_dim=cfg.hidden_dim)
|
||||
self.policy_net = MLP(state_dim,hidden_dim=cfg.hidden_dim)
|
||||
self.optimizer = torch.optim.RMSprop(self.policy_net.parameters(), lr=cfg.lr)
|
||||
self.batch_size = cfg.batch_size
|
||||
|
||||
|
||||
Reference in New Issue
Block a user