update codes

This commit is contained in:
johnjim0816
2021-12-22 16:55:09 +08:00
parent 75df999258
commit 41fb561d25
75 changed files with 1248 additions and 918 deletions

View File

@@ -17,9 +17,9 @@ from PolicyGradient.model import MLP
class PolicyGradient:
def __init__(self, state_dim,cfg):
def __init__(self, n_states,cfg):
self.gamma = cfg.gamma
self.policy_net = MLP(state_dim,hidden_dim=cfg.hidden_dim)
self.policy_net = MLP(n_states,hidden_dim=cfg.hidden_dim)
self.optimizer = torch.optim.RMSprop(self.policy_net.parameters(), lr=cfg.lr)
self.batch_size = cfg.batch_size