update
This commit is contained in:
@@ -5,7 +5,7 @@
|
||||
@Email: johnjim0816@gmail.com
|
||||
@Date: 2020-06-12 00:48:57
|
||||
@LastEditor: John
|
||||
LastEditTime: 2021-04-29 22:23:38
|
||||
LastEditTime: 2021-05-04 15:01:34
|
||||
@Discription:
|
||||
@Environment: python 3.7.7
|
||||
'''
|
||||
@@ -18,16 +18,13 @@ import datetime
|
||||
import torch
|
||||
import gym
|
||||
|
||||
from common.utils import save_results, make_dir, del_empty_dir
|
||||
from common.utils import save_results, make_dir
|
||||
from common.plot import plot_rewards
|
||||
from DQN.agent import DQN
|
||||
|
||||
|
||||
|
||||
curr_time = datetime.datetime.now().strftime(
|
||||
"%Y%m%d-%H%M%S") # obtain current time
|
||||
|
||||
|
||||
class DQNConfig:
|
||||
def __init__(self):
|
||||
self.algo = "DQN" # name of algo
|
||||
@@ -80,7 +77,7 @@ def train(cfg, env, agent):
|
||||
agent.target_net.load_state_dict(agent.policy_net.state_dict())
|
||||
print('Episode:{}/{}, Reward:{}'.format(i_episode+1, cfg.train_eps, ep_reward))
|
||||
rewards.append(ep_reward)
|
||||
# 计算滑动窗口的reward
|
||||
# save ma rewards
|
||||
if ma_rewards:
|
||||
ma_rewards.append(0.9*ma_rewards[-1]+0.1*ep_reward)
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user