update codes

This commit is contained in:
johnjim0816
2021-12-28 18:46:52 +08:00
parent 41fb561d25
commit bd51b5a7ad
52 changed files with 305 additions and 292 deletions

View File

@@ -17,9 +17,9 @@ from PolicyGradient.model import MLP
class PolicyGradient:
def __init__(self, n_states,cfg):
def __init__(self, state_dim,cfg):
self.gamma = cfg.gamma
self.policy_net = MLP(n_states,hidden_dim=cfg.hidden_dim)
self.policy_net = MLP(state_dim,hidden_dim=cfg.hidden_dim)
self.optimizer = torch.optim.RMSprop(self.policy_net.parameters(), lr=cfg.lr)
self.batch_size = cfg.batch_size