update
This commit is contained in:
@@ -5,10 +5,11 @@ Author: JiangJi
|
||||
Email: johnjim0816@gmail.com
|
||||
Date: 2021-05-03 22:16:08
|
||||
LastEditor: JiangJi
|
||||
LastEditTime: 2021-05-03 22:23:48
|
||||
LastEditTime: 2022-07-20 23:54:40
|
||||
Discription:
|
||||
Environment:
|
||||
'''
|
||||
import torch
|
||||
import torch.optim as optim
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
@@ -42,7 +43,7 @@ class A2C:
|
||||
'''
|
||||
def __init__(self,n_states,n_actions,cfg) -> None:
|
||||
self.gamma = cfg.gamma
|
||||
self.device = cfg.device
|
||||
self.device = torch.device(cfg.device)
|
||||
self.model = ActorCritic(n_states, n_actions, cfg.hidden_size).to(self.device)
|
||||
self.optimizer = optim.Adam(self.model.parameters())
|
||||
|
||||
|
||||
@@ -0,0 +1,14 @@
|
||||
{
|
||||
"algo_name": "A2C",
|
||||
"env_name": "CartPole-v0",
|
||||
"n_envs": 8,
|
||||
"max_steps": 20000,
|
||||
"n_steps": 5,
|
||||
"gamma": 0.99,
|
||||
"lr": 0.001,
|
||||
"hidden_dim": 256,
|
||||
"deivce": "cpu",
|
||||
"result_path": "C:\\Users\\24438\\Desktop\\rl-tutorials/outputs/CartPole-v0/20220713-221850/results/",
|
||||
"model_path": "C:\\Users\\24438\\Desktop\\rl-tutorials/outputs/CartPole-v0/20220713-221850/models/",
|
||||
"save_fig": true
|
||||
}
|
||||
@@ -1,14 +0,0 @@
|
||||
------------------ start ------------------
|
||||
algo_name : A2C
|
||||
env_name : CartPole-v0
|
||||
n_envs : 8
|
||||
max_steps : 30000
|
||||
n_steps : 5
|
||||
gamma : 0.99
|
||||
lr : 0.001
|
||||
hidden_dim : 256
|
||||
result_path : c:\Users\24438\Desktop\rl-tutorials\codes\A2C/outputs/CartPole-v0/20220713-221850/results/
|
||||
model_path : c:\Users\24438\Desktop\rl-tutorials\codes\A2C/outputs/CartPole-v0/20220713-221850/models/
|
||||
save_fig : True
|
||||
device : cuda
|
||||
------------------- end -------------------
|
||||
File diff suppressed because one or more lines are too long
@@ -29,14 +29,13 @@ def get_args():
|
||||
parser.add_argument('--gamma',default=0.99,type=float,help="discounted factor")
|
||||
parser.add_argument('--lr',default=1e-3,type=float,help="learning rate")
|
||||
parser.add_argument('--hidden_dim',default=256,type=int)
|
||||
parser.add_argument('--device',default='cpu',type=str,help="cpu or cuda")
|
||||
parser.add_argument('--result_path',default=curr_path + "/outputs/" + parser.parse_args().env_name + \
|
||||
'/' + curr_time + '/results/' )
|
||||
parser.add_argument('--model_path',default=curr_path + "/outputs/" + parser.parse_args().env_name + \
|
||||
'/' + curr_time + '/models/' ) # path to save models
|
||||
parser.add_argument('--save_fig',default=True,type=bool,help="if save figure or not")
|
||||
args = parser.parse_args()
|
||||
args.device = torch.device(
|
||||
"cuda" if torch.cuda.is_available() else "cpu") # check GPU
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
def make_envs(env_name):
|
||||
|
||||
Reference in New Issue
Block a user