This commit is contained in:
JohnJim0816
2021-03-28 11:18:52 +08:00
parent 2df8d965d2
commit 6e4d966e1f
56 changed files with 497 additions and 165 deletions

View File

@@ -17,9 +17,9 @@ from PolicyGradient.model import MLP
class PolicyGradient:
def __init__(self, n_states,cfg):
def __init__(self, state_dim,cfg):
self.gamma = cfg.gamma
self.policy_net = MLP(n_states,hidden_dim=cfg.hidden_dim)
self.policy_net = MLP(state_dim,hidden_dim=cfg.hidden_dim)
self.optimizer = torch.optim.RMSprop(self.policy_net.parameters(), lr=cfg.lr)
self.batch_size = cfg.batch_size

View File

@@ -80,9 +80,9 @@ if __name__ == "__main__":
cfg = PGConfig()
env = gym.make('CartPole-v0') # 可google为什么unwrapped gym此处一般不需要
env.seed(1) # 设置env随机种子
n_states = env.observation_space.shape[0]
n_actions = env.action_space.n
agent = PolicyGradient(n_states,cfg)
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n
agent = PolicyGradient(state_dim,cfg)
rewards, ma_rewards = train(cfg,env,agent)
agent.save_model(SAVED_MODEL_PATH)
save_results(rewards,ma_rewards,tag='train',path=RESULT_PATH)

View File

@@ -16,10 +16,10 @@ class MLP(nn.Module):
输入state维度
输出:概率
'''
def __init__(self,n_states,hidden_dim = 36):
def __init__(self,state_dim,hidden_dim = 36):
super(MLP, self).__init__()
# 24和36为hidden layer的层数可根据state_dim, n_actions的情况来改变
self.fc1 = nn.Linear(n_states, hidden_dim)
# 24和36为hidden layer的层数可根据state_dim, action_dim的情况来改变
self.fc1 = nn.Linear(state_dim, hidden_dim)
self.fc2 = nn.Linear(hidden_dim,hidden_dim)
self.fc3 = nn.Linear(hidden_dim, 1) # Prob of Left