update
This commit is contained in:
@@ -5,7 +5,7 @@
|
||||
@Email: johnjim0816@gmail.com
|
||||
@Date: 2020-06-12 00:50:49
|
||||
@LastEditor: John
|
||||
LastEditTime: 2022-07-13 00:08:18
|
||||
LastEditTime: 2022-07-20 23:57:16
|
||||
@Discription:
|
||||
@Environment: python 3.7.7
|
||||
'''
|
||||
@@ -64,8 +64,8 @@ class ReplayBuffer:
|
||||
class DQN:
|
||||
def __init__(self, n_states,n_actions,cfg):
|
||||
|
||||
self.n_actions = n_actions # 总的动作个数
|
||||
self.device = cfg.device # 设备,cpu或gpu等
|
||||
self.n_actions = n_actions
|
||||
self.device = torch.device(cfg.device) # cpu or cuda
|
||||
self.gamma = cfg.gamma # 奖励的折扣因子
|
||||
# e-greedy策略相关参数
|
||||
self.frame_idx = 0 # 用于epsilon的衰减计数
|
||||
|
||||
@@ -0,0 +1,19 @@
|
||||
{
|
||||
"algo_name": "DQN",
|
||||
"env_name": "CartPole-v0",
|
||||
"train_eps": 200,
|
||||
"test_eps": 20,
|
||||
"gamma": 0.95,
|
||||
"epsilon_start": 0.95,
|
||||
"epsilon_end": 0.01,
|
||||
"epsilon_decay": 500,
|
||||
"lr": 0.0001,
|
||||
"memory_capacity": 100000,
|
||||
"batch_size": 64,
|
||||
"target_update": 4,
|
||||
"hidden_dim": 256,
|
||||
"deivce": "cpu",
|
||||
"result_path": "C:\\Users\\24438\\Desktop\\rl-tutorials/outputs/CartPole-v0/20220713-211653/results/",
|
||||
"model_path": "C:\\Users\\24438\\Desktop\\rl-tutorials/outputs/CartPole-v0/20220713-211653/models/",
|
||||
"save_fig": true
|
||||
}
|
||||
@@ -1,19 +0,0 @@
|
||||
------------------ start ------------------
|
||||
algo_name : DQN
|
||||
env_name : CartPole-v0
|
||||
train_eps : 200
|
||||
test_eps : 20
|
||||
gamma : 0.95
|
||||
epsilon_start : 0.95
|
||||
epsilon_end : 0.01
|
||||
epsilon_decay : 500
|
||||
lr : 0.0001
|
||||
memory_capacity : 100000
|
||||
batch_size : 64
|
||||
target_update : 4
|
||||
hidden_dim : 256
|
||||
result_path : C:\Users\24438\Desktop\rl-tutorials\codes\DQN/outputs/CartPole-v0/20220713-211653/results/
|
||||
model_path : C:\Users\24438\Desktop\rl-tutorials\codes\DQN/outputs/CartPole-v0/20220713-211653/models/
|
||||
save_fig : True
|
||||
device : cuda
|
||||
------------------- end -------------------
|
||||
@@ -1,12 +1,9 @@
|
||||
from lib2to3.pytree import type_repr
|
||||
import sys
|
||||
import os
|
||||
from parso import parse
|
||||
import sys,os
|
||||
curr_path = os.path.dirname(os.path.abspath(__file__)) # current path
|
||||
parent_path = os.path.dirname(curr_path) # parent path
|
||||
sys.path.append(parent_path) # add to system path
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
curr_path = os.path.dirname(os.path.abspath(__file__)) # 当前文件所在绝对路径
|
||||
parent_path = os.path.dirname(curr_path) # 父路径
|
||||
sys.path.append(parent_path) # 添加路径到系统路径
|
||||
|
||||
import gym
|
||||
import torch
|
||||
@@ -35,14 +32,13 @@ def get_args():
|
||||
parser.add_argument('--batch_size',default=64,type=int)
|
||||
parser.add_argument('--target_update',default=4,type=int)
|
||||
parser.add_argument('--hidden_dim',default=256,type=int)
|
||||
parser.add_argument('--device',default='cpu',type=str,help="cpu or cuda")
|
||||
parser.add_argument('--result_path',default=curr_path + "/outputs/" + parser.parse_args().env_name + \
|
||||
'/' + curr_time + '/results/' )
|
||||
parser.add_argument('--model_path',default=curr_path + "/outputs/" + parser.parse_args().env_name + \
|
||||
'/' + curr_time + '/models/' ) # path to save models
|
||||
parser.add_argument('--save_fig',default=True,type=bool,help="if save figure or not")
|
||||
args = parser.parse_args()
|
||||
args.device = torch.device(
|
||||
"cuda" if torch.cuda.is_available() else "cpu") # check GPU
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
def env_agent_config(cfg,seed=1):
|
||||
|
||||
Reference in New Issue
Block a user