update DQN

This commit is contained in:
JohnJim0816
2020-11-28 10:16:25 +08:00
parent abfe6ea62b
commit 59a09144f1
6 changed files with 41 additions and 39 deletions

View File

@@ -5,7 +5,7 @@
@Email: johnjim0816@gmail.com
@Date: 2020-06-12 00:48:57
@LastEditor: John
LastEditTime: 2020-10-15 22:00:28
LastEditTime: 2020-11-23 11:58:17
@Discription:
@Environment: python 3.7.7
'''
@@ -16,7 +16,7 @@ import argparse
from torch.utils.tensorboard import SummaryWriter
import datetime
import os
from utils import save_results
from utils import save_results,save_model
SEQUENCE = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
SAVED_MODEL_PATH = os.path.split(os.path.abspath(__file__))[0]+"/saved_model/"+SEQUENCE+'/'
@@ -53,7 +53,7 @@ def get_args():
def train(cfg):
print('Start to train ! \n')
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # 检测gpu
env = gym.make('CartPole-v0').unwrapped # 可google为什么unwrapped gym此处一般不需要
env = gym.make('CartPole-v0')
env.seed(1) # 设置env随机种子
n_states = env.observation_space.shape[0]
n_actions = env.action_space.n
@@ -95,10 +95,7 @@ def train(cfg):
writer.close()
print('Complete training')
''' 保存模型 '''
if not os.path.exists(SAVED_MODEL_PATH): # 检测是否存在文件夹
os.mkdir(SAVED_MODEL_PATH)
agent.save_model(SAVED_MODEL_PATH+'checkpoint.pth')
print('model saved')
save_model(agent,model_path=SAVED_MODEL_PATH)
'''存储reward等相关结果'''
save_results(rewards,moving_average_rewards,ep_steps,tag='train',result_path=RESULT_PATH)
@@ -110,7 +107,7 @@ def eval(cfg, saved_model_path = SAVED_MODEL_PATH):
env.seed(1) # 设置env随机种子
n_states = env.observation_space.shape[0]
n_actions = env.action_space.n
agent = DQN(n_states=n_states, n_actions=n_actions, device=device, gamma=cfg.gamma, epsilon_start=cfg.epsilon_start,
agent = DQN(n_states=n_states, n_actions=n_actions, device="cpu", gamma=cfg.gamma, epsilon_start=cfg.epsilon_start,
epsilon_end=cfg.epsilon_end, epsilon_decay=cfg.epsilon_decay, policy_lr=cfg.policy_lr, memory_capacity=cfg.memory_capacity, batch_size=cfg.batch_size)
agent.load_model(saved_model_path+'checkpoint.pth')
rewards = []