This commit is contained in:
JohnJim0816
2021-05-04 15:30:01 +08:00
parent 4b96f5a6b0
commit 747f3238c0
41 changed files with 282 additions and 782 deletions

View File

@@ -5,7 +5,7 @@
@Email: johnjim0816@gmail.com
@Date: 2020-06-12 00:48:57
@LastEditor: John
LastEditTime: 2021-04-29 22:23:38
LastEditTime: 2021-05-04 15:01:34
@Discription:
@Environment: python 3.7.7
'''
@@ -18,16 +18,13 @@ import datetime
import torch
import gym
from common.utils import save_results, make_dir, del_empty_dir
from common.utils import save_results, make_dir
from common.plot import plot_rewards
from DQN.agent import DQN
curr_time = datetime.datetime.now().strftime(
"%Y%m%d-%H%M%S") # obtain current time
class DQNConfig:
def __init__(self):
self.algo = "DQN" # name of algo
@@ -80,7 +77,7 @@ def train(cfg, env, agent):
agent.target_net.load_state_dict(agent.policy_net.state_dict())
print('Episode:{}/{}, Reward:{}'.format(i_episode+1, cfg.train_eps, ep_reward))
rewards.append(ep_reward)
# 计算滑动窗口的reward
# save ma rewards
if ma_rewards:
ma_rewards.append(0.9*ma_rewards[-1]+0.1*ep_reward)
else: