This commit is contained in:
johnjim0816
2021-12-22 11:19:13 +08:00
parent c257313d5b
commit 75df999258
55 changed files with 605 additions and 403 deletions

View File

@@ -5,21 +5,22 @@ Author: John
Email: johnjim0816@gmail.com
Date: 2021-03-23 16:35:58
LastEditor: John
LastEditTime: 2021-03-23 16:36:20
LastEditTime: 2021-12-21 23:21:26
Discription:
Environment:
'''
import torch.nn as nn
import torch.nn.functional as F
class MLP(nn.Module):
''' 多层感知机
输入state维度
输出:概率
'''
def __init__(self,state_dim,hidden_dim = 36):
def __init__(self,input_dim,hidden_dim = 36):
super(MLP, self).__init__()
# 24和36为hidden layer的层数可根据state_dim, action_dim的情况来改变
self.fc1 = nn.Linear(state_dim, hidden_dim)
# 24和36为hidden layer的层数可根据input_dim, action_dim的情况来改变
self.fc1 = nn.Linear(input_dim, hidden_dim)
self.fc2 = nn.Linear(hidden_dim,hidden_dim)
self.fc3 = nn.Linear(hidden_dim, 1) # Prob of Left

View File

@@ -34,7 +34,7 @@ class PGConfig:
self.model_path = curr_path+"/outputs/" + self.env + \
'/'+curr_time+'/models/' # 保存模型的路径
self.train_eps = 300 # 训练的回合数
self.eval_eps = 30 # 测试的回合数
self.test_eps = 30 # 测试的回合数
self.batch_size = 8
self.lr = 0.01 # 学习率
self.gamma = 0.99
@@ -94,7 +94,7 @@ def eval(cfg,env,agent):
print(f'Env:{cfg.env}, Algorithm:{cfg.algo}, Device:{cfg.device}')
rewards = []
ma_rewards = []
for i_ep in range(cfg.eval_eps):
for i_ep in range(cfg.test_eps):
state = env.reset()
ep_reward = 0
for _ in count():