hot update PG

This commit is contained in:
johnjim0816
2022-08-25 21:00:53 +08:00
parent 4f4658503e
commit 80f20c73be
34 changed files with 1391 additions and 1695 deletions

View File

@@ -5,7 +5,7 @@ Author: John
Email: johnjim0816@gmail.com
Date: 2020-11-22 23:27:44
LastEditor: John
LastEditTime: 2022-08-22 17:35:34
LastEditTime: 2022-08-25 20:58:59
Discription:
Environment:
'''
@@ -19,12 +19,12 @@ import numpy as np
class PolicyGradient:
def __init__(self, n_states,model,memory,cfg):
self.gamma = cfg.gamma
self.device = torch.device(cfg.device)
def __init__(self, model,memory,cfg):
self.gamma = cfg['gamma']
self.device = torch.device(cfg['device'])
self.memory = memory
self.policy_net = model.to(self.device)
self.optimizer = torch.optim.RMSprop(self.policy_net.parameters(), lr=cfg.lr)
self.optimizer = torch.optim.RMSprop(self.policy_net.parameters(), lr=cfg['lr'])
def sample_action(self,state):