diff --git a/codes/dqn/.vscode/settings.json b/codes/dqn/.vscode/settings.json
new file mode 100644
index 0000000..be0f1ab
--- /dev/null
+++ b/codes/dqn/.vscode/settings.json
@@ -0,0 +1,3 @@
+{
+    "python.pythonPath": "/Users/jj/anaconda3/envs/py37/bin/python"
+}
\ No newline at end of file
diff --git a/codes/dqn/README.md b/codes/dqn/README.md
new file mode 100644
index 0000000..e0419b9
--- /dev/null
+++ b/codes/dqn/README.md
@@ -0,0 +1,26 @@
+python 3.7.9
+
+pytorch 1.6.0
+
+tensorboard 2.3.0
+
+torchvision 0.7.0
+
+
+train:
+
+```bash
+python main.py
+```
+
+eval:
+
+```bash
+python main.py --train 0
+```
+
+visualize the logged curves with tensorboard:
+
+```bash
+tensorboard --logdir logs
+```
\ No newline at end of file
diff --git a/codes/dqn/dqn.py b/codes/dqn/agent.py
similarity index 63%
rename from codes/dqn/dqn.py
rename to codes/dqn/agent.py
index bff4cef..9ff8b28 100644
--- a/codes/dqn/dqn.py
+++ b/codes/dqn/agent.py
@@ -5,7 +5,7 @@
 @Email: johnjim0816@gmail.com
 @Date: 2020-06-12 00:50:49
 @LastEditor: John
-LastEditTime: 2020-10-07 17:32:18
+LastEditTime: 2020-10-15 21:56:21
 @Discription:
 @Environment: python 3.7.7
 '''
@@ -13,8 +13,6 @@ LastEditTime: 2020-10-07 17:32:18
 '''
-
-
 import torch
 import torch.nn as nn
 import torch.optim as optim
@@ -30,7 +28,7 @@ class DQN:
         self.n_actions = n_actions  # total number of actions
         self.device = device  # device, cpu or gpu
         self.gamma = gamma
-        # epsilon-greedy policy parameters
+        # epsilon-greedy policy parameters
         self.epsilon = 0
         self.epsilon_start = epsilon_start
         self.epsilon_end = epsilon_end
@@ -46,32 +44,41 @@ class DQN:
         self.loss = 0
         self.memory = ReplayBuffer(memory_capacity)
-    def select_action(self, state):
+    def choose_action(self, state, train=True):
         '''choose an action
-        Args:
-            state [array]: [description]
-        Returns:
-            action [array]: [description]
         '''
-        self.epsilon = self.epsilon_end + (self.epsilon_start - self.epsilon_end) * \
-            math.exp(-1. * self.actions_count / self.epsilon_decay)
-        self.actions_count += 1
-        if random.random() > self.epsilon:
+        if train:
+            self.epsilon = self.epsilon_end + (self.epsilon_start - self.epsilon_end) * \
+                math.exp(-1. * self.actions_count / self.epsilon_decay)
+            self.actions_count += 1
+            if random.random() > self.epsilon:
+                with torch.no_grad():
+                    # convert to a tensor before feeding it to the network; state elements are originally float64
+                    # note that state=torch.tensor(state).unsqueeze(0) is equivalent to state=torch.tensor([state])
+                    state = torch.tensor(
+                        [state], device=self.device, dtype=torch.float32)
+                    # e.g. tensor([[-0.0798, -0.0079]], grad_fn=<AddmmBackward>)
+                    q_value = self.policy_net(state)
+                    # tensor.max(1) returns the max value of each row and its index,
+                    # e.g. torch.return_types.max(values=tensor([10.3587]),indices=tensor([0]))
+                    # so tensor.max(1)[1] returns the index of the max value, i.e. the action
+                    action = q_value.max(1)[1].item()
+            else:
+                action = random.randrange(self.n_actions)
+            return action
+        else:
             with torch.no_grad():
-                # convert to a tensor before feeding it to the network; state elements are originally float64
-                # note that state=torch.tensor(state).unsqueeze(0) is equivalent to state=torch.tensor([state])
-                state = torch.tensor(
-                    [state], device=self.device, dtype=torch.float32)
-                # e.g. tensor([[-0.0798, -0.0079]], grad_fn=<AddmmBackward>)
-                q_value = self.policy_net(state)
-                # tensor.max(1) returns the max value of each row and its index,
-                # e.g. torch.return_types.max(values=tensor([10.3587]),indices=tensor([0]))
-                # so tensor.max(1)[1] returns the index of the max value, i.e. the action
-                action = q_value.max(1)[1].item()
-        else:
-            action = random.randrange(self.n_actions)
-        return action
-
+                # convert to a tensor before feeding it to the network; state elements are originally float64
+                # note that state=torch.tensor(state).unsqueeze(0) is equivalent to state=torch.tensor([state])
+                state = torch.tensor(
+                    [state], device='cpu', dtype=torch.float32)
+                # e.g. tensor([[-0.0798, -0.0079]], grad_fn=<AddmmBackward>)
+                q_value = self.target_net(state)
+                # tensor.max(1) returns the max value of each row and its index,
+                # e.g. torch.return_types.max(values=tensor([10.3587]),indices=tensor([0]))
+                # so tensor.max(1)[1] returns the index of the max value, i.e. the action
+                action = q_value.max(1)[1].item()
+            return action
     def update(self):
         if len(self.memory) < self.batch_size:
@@ -113,8 +120,9 @@ class DQN:
         for param in self.policy_net.parameters():  # clip to prevent exploding gradients
             param.grad.data.clamp_(-1, 1)
         self.optimizer.step()  # update the model
-
-    def save_model():
-        pass
-    def load_model():
-        pass
\ No newline at end of file
+
+    def save_model(self,path):
+        torch.save(self.target_net.state_dict(), path)
+
+    def load_model(self,path):
+        self.target_net.load_state_dict(torch.load(path))
diff --git a/codes/dqn/logs/eval/20201015-215937/events.out.tfevents.1602770409.MacBook-Pro.local.21607.3 b/codes/dqn/logs/eval/20201015-215937/events.out.tfevents.1602770409.MacBook-Pro.local.21607.3
new file mode 100644
index 0000000..8ceddf5
Binary files /dev/null and b/codes/dqn/logs/eval/20201015-215937/events.out.tfevents.1602770409.MacBook-Pro.local.21607.3 differ
diff --git a/codes/dqn/logs/eval/20201015-215937/rewards_moving_average/events.out.tfevents.1602770409.MacBook-Pro.local.21607.5 b/codes/dqn/logs/eval/20201015-215937/rewards_moving_average/events.out.tfevents.1602770409.MacBook-Pro.local.21607.5
new file mode 100644
index 0000000..aada812
Binary files /dev/null and b/codes/dqn/logs/eval/20201015-215937/rewards_moving_average/events.out.tfevents.1602770409.MacBook-Pro.local.21607.5 differ
diff --git a/codes/dqn/logs/eval/20201015-215937/rewards_raw/events.out.tfevents.1602770409.MacBook-Pro.local.21607.4 b/codes/dqn/logs/eval/20201015-215937/rewards_raw/events.out.tfevents.1602770409.MacBook-Pro.local.21607.4
new file mode 100644
index 0000000..ae17517
Binary files /dev/null and b/codes/dqn/logs/eval/20201015-215937/rewards_raw/events.out.tfevents.1602770409.MacBook-Pro.local.21607.4 differ
diff --git a/codes/dqn/logs/train/20201015-215937/events.out.tfevents.1602770377.MacBook-Pro.local.21607.0 b/codes/dqn/logs/train/20201015-215937/events.out.tfevents.1602770377.MacBook-Pro.local.21607.0
new file mode 100644
index 0000000..9a4f3f5
Binary files /dev/null and b/codes/dqn/logs/train/20201015-215937/events.out.tfevents.1602770377.MacBook-Pro.local.21607.0 differ
diff --git a/codes/dqn/logs/train/20201015-215937/rewards_moving_average/events.out.tfevents.1602770377.MacBook-Pro.local.21607.2 b/codes/dqn/logs/train/20201015-215937/rewards_moving_average/events.out.tfevents.1602770377.MacBook-Pro.local.21607.2
new file mode 100644
index 0000000..8eed693
Binary files /dev/null and b/codes/dqn/logs/train/20201015-215937/rewards_moving_average/events.out.tfevents.1602770377.MacBook-Pro.local.21607.2 differ
diff --git a/codes/dqn/logs/train/20201015-215937/rewards_raw/events.out.tfevents.1602770377.MacBook-Pro.local.21607.1 b/codes/dqn/logs/train/20201015-215937/rewards_raw/events.out.tfevents.1602770377.MacBook-Pro.local.21607.1
new file mode 100644
index 0000000..4322bc3
Binary files /dev/null and b/codes/dqn/logs/train/20201015-215937/rewards_raw/events.out.tfevents.1602770377.MacBook-Pro.local.21607.1 differ
diff --git a/codes/dqn/main.py b/codes/dqn/main.py
index 21852fb..9bdc94d 100644
--- a/codes/dqn/main.py
+++ b/codes/dqn/main.py
@@ -5,20 +5,28 @@
 @Email: johnjim0816@gmail.com
 @Date: 2020-06-12 00:48:57
 @LastEditor: John
-LastEditTime: 2020-08-22 18:02:56
+LastEditTime: 2020-10-15 22:00:28
 @Discription:
 @Environment: python 3.7.7
 '''
 import gym
 import torch
-from dqn import DQN
-from plot import plot
+from agent import DQN
 import argparse
+from torch.utils.tensorboard import SummaryWriter
+import datetime
+import os
+from utils import save_results
+
+SEQUENCE = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
+SAVED_MODEL_PATH = os.path.split(os.path.abspath(__file__))[0]+"/saved_model/"+SEQUENCE+'/'
+RESULT_PATH = os.path.split(os.path.abspath(__file__))[0]+"/result/"+SEQUENCE+'/'
 def get_args():
     '''model parameters
     '''
     parser = argparse.ArgumentParser()
+    parser.add_argument("--train", default=1, type=int)  # 1 means train, 0 means eval only
     parser.add_argument("--gamma", default=0.99, type=float)  # gamma in Q-learning
     parser.add_argument("--epsilon_start", default=0.95,
@@ -31,20 +39,19 @@ def get_args():
     parser.add_argument("--batch_size", default=32,
                         type=int, help="batch size of memory sampling")
-    parser.add_argument("--max_episodes", default=200, type=int)  # max number of training episodes
-    parser.add_argument("--max_steps", default=200, type=int)
-    # setting the target net update frequency to 1 gives vanilla dqn, greater than 1 gives double dqn
-    parser.add_argument("--target_update", default=1, type=int,
-                        help="when(every default 10 eisodes) to update target net ")
+    parser.add_argument("--train_eps", default=200, type=int)  # max number of training episodes
+    parser.add_argument("--train_steps", default=200, type=int)
+    parser.add_argument("--target_update", default=2, type=int,
+                        help="how often (every 2 episodes by default) to update the target net")  # update frequency
+
+    parser.add_argument("--eval_eps", default=100, type=int)  # max number of evaluation episodes
+    parser.add_argument("--eval_steps", default=200,
+                        type=int)  # max steps per evaluation episode
     config = parser.parse_args()
     return config
-
-
-if __name__ == "__main__":
-
-    cfg = get_args()
-    # if gpu is to be used
+def train(cfg):
+    print('Start to train!\n')
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # check gpu
     env = gym.make('CartPole-v0').unwrapped  # you can google why gym envs are unwrapped; usually not needed here
     env.seed(1)  # set the random seed of the env
@@ -55,11 +62,13 @@ if __name__ == "__main__":
     rewards = []
     moving_average_rewards = []
     ep_steps = []
-    for i_episode in range(1, cfg.max_episodes+1):
+    log_dir=os.path.split(os.path.abspath(__file__))[0]+"/logs/train/" + SEQUENCE
+    writer = SummaryWriter(log_dir)
+    for i_episode in range(1, cfg.train_eps+1):
         state = env.reset()  # reset the environment
         ep_reward = 0
-        for i_step in range(1, cfg.max_steps+1):
-            action = agent.select_action(state)  # choose an action based on the current state
+        for i_step in range(1, cfg.train_steps+1):
+            action = agent.choose_action(state)  # choose an action based on the current state
             next_state, reward, done, _ = env.step(action)  # step the environment
             ep_reward += reward
             agent.memory.push(state, action, reward, next_state, done)  # store the transition in memory
@@ -80,17 +89,68 @@ if __name__ == "__main__":
         else:
             moving_average_rewards.append(
                 0.9*moving_average_rewards[-1]+0.1*ep_reward)
-    # save rewards and related results
-    import os
-    import numpy as np
-    output_path = os.path.dirname(__file__)+"/result/"
-    # check whether the folder exists
-    if not os.path.exists(output_path):
-        os.mkdir(output_path)
-    np.save(output_path+"rewards.npy", rewards)
-    np.save(output_path+"moving_average_rewards.npy", moving_average_rewards)
-    np.save(output_path+"steps.npy", ep_steps)
-    print('Complete!')
-    plot(rewards)
-    plot(moving_average_rewards, ylabel="moving_average_rewards")
-    plot(ep_steps, ylabel="steps_of_each_episode")
+        writer.add_scalars('rewards',{'raw':rewards[-1], 'moving_average': moving_average_rewards[-1]}, i_episode)
+        writer.add_scalar('steps_of_each_episode',
+                          ep_steps[-1], i_episode)
+    writer.close()
+    print('Complete training!')
+    ''' save the model '''
+    if not os.path.exists(SAVED_MODEL_PATH):  # check whether the folder exists
+        os.mkdir(SAVED_MODEL_PATH)
+    agent.save_model(SAVED_MODEL_PATH+'checkpoint.pth')
+    print('model saved!')
+    '''save rewards and related results'''
+    save_results(rewards,moving_average_rewards,ep_steps,tag='train',result_path=RESULT_PATH)
+
+
+def eval(cfg, saved_model_path = SAVED_MODEL_PATH):
+    print('Start to eval!\n')
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # check gpu
+    env = gym.make('CartPole-v0').unwrapped  # you can google why gym envs are unwrapped; usually not needed here
+    env.seed(1)  # set the random seed of the env
+    n_states = env.observation_space.shape[0]
+    n_actions = env.action_space.n
+    agent = DQN(n_states=n_states, n_actions=n_actions, device=device, gamma=cfg.gamma, epsilon_start=cfg.epsilon_start,
+                epsilon_end=cfg.epsilon_end, epsilon_decay=cfg.epsilon_decay, policy_lr=cfg.policy_lr, memory_capacity=cfg.memory_capacity, batch_size=cfg.batch_size)
+    agent.load_model(saved_model_path+'checkpoint.pth')
+    rewards = []
+    moving_average_rewards = []
+    ep_steps = []
+    log_dir=os.path.split(os.path.abspath(__file__))[0]+"/logs/eval/" + SEQUENCE
+    writer = SummaryWriter(log_dir)
+    for i_episode in range(1, cfg.eval_eps+1):
+        state = env.reset()  # reset the environment
+        ep_reward = 0
+        for i_step in range(1, cfg.eval_steps+1):
+            action = agent.choose_action(state,train=False)  # choose an action based on the current state
+            next_state, reward, done, _ = env.step(action)  # step the environment
+            ep_reward += reward
+            state = next_state  # move to the next state
+            if done:
+                break
+        print('Episode:', i_episode, ' Reward: %i' %
+              int(ep_reward), 'n_steps:', i_step, 'done: ', done)
+        ep_steps.append(i_step)
+        rewards.append(ep_reward)
+        # compute the moving average reward
+        if i_episode == 1:
+            moving_average_rewards.append(ep_reward)
+        else:
+            moving_average_rewards.append(
+                0.9*moving_average_rewards[-1]+0.1*ep_reward)
+        writer.add_scalars('rewards',{'raw':rewards[-1], 'moving_average': moving_average_rewards[-1]}, i_episode)
+        writer.add_scalar('steps_of_each_episode',
+                          ep_steps[-1], i_episode)
+    writer.close()
+    '''save rewards and related results'''
+    save_results(rewards,moving_average_rewards,ep_steps,tag='eval',result_path=RESULT_PATH)
+    print('Complete evaluation!')
+
+if __name__ == "__main__":
+    cfg = get_args()
+    if cfg.train:
+        train(cfg)
+        eval(cfg)
+    else:
+        model_path = os.path.split(os.path.abspath(__file__))[0]+"/saved_model/"
+        eval(cfg,saved_model_path=model_path)
diff --git a/codes/dqn/plot.py b/codes/dqn/plot.py
index 41f524e..2bc3e04 100644
--- a/codes/dqn/plot.py
+++ b/codes/dqn/plot.py
@@ -5,7 +5,7 @@
 @Email: johnjim0816@gmail.com
 @Date: 2020-06-11 16:30:09
 @LastEditor: John
-LastEditTime: 2020-10-07 20:57:22
+LastEditTime: 2020-10-15 22:01:50
 @Discription:
 @Environment: python 3.7.7
 '''
@@ -14,19 +14,45 @@ import seaborn as sns
 import numpy as np
 import os
-def plot(item,ylabel='rewards'):
+def plot(item,ylabel='rewards_train', save_fig = True):
+    '''plot curves using seaborn
+    '''
     sns.set()
     plt.figure()
     plt.plot(np.arange(len(item)), item)
     plt.title(ylabel+' of DQN')
     plt.ylabel(ylabel)
     plt.xlabel('episodes')
-    plt.savefig(os.path.dirname(__file__)+"/result/"+ylabel+".png")
+    if save_fig:
+        plt.savefig(os.path.dirname(__file__)+"/result/"+ylabel+".png")
     plt.show()
+
+# def plot(item,ylabel='rewards'):
+#
+#     df = pd.DataFrame(dict(time=np.arange(len(item)),value=item))
+#     g = sns.relplot(x="time", y="value", kind="line", data=df)
+#     # g.fig.autofmt_xdate()
+#     # sns.lineplot(time=time, data=item, color="r", condition="behavior_cloning")
+#     # # sns.tsplot(time=time, data=x2, color="b", condition="dagger")
+#     # plt.ylabel("Reward")
+#     # plt.xlabel("Iteration Number")
+#     # plt.title("Imitation Learning")
+
+    # plt.show()
 if __name__ == "__main__":
-    output_path = os.path.dirname(__file__)+"/result/"
-    rewards=np.load(output_path+"rewards.npy", )
-    moving_average_rewards=np.load(output_path+"moving_average_rewards.npy",)
+    output_path = os.path.split(os.path.abspath(__file__))[0]+"/result/"
+    tag =
'train' + rewards=np.load(output_path+"rewards_"+tag+".npy", ) + moving_average_rewards=np.load(output_path+"moving_average_rewards_"+tag+".npy",) + steps=np.load(output_path+"steps_"+tag+".npy") plot(rewards) - plot(moving_average_rewards,ylabel='moving_average_rewards') + plot(moving_average_rewards,ylabel='moving_average_rewards_'+tag) + plot(steps,ylabel='steps_'+tag) + tag = 'eval' + rewards=np.load(output_path+"rewards_"+tag+".npy", ) + moving_average_rewards=np.load(output_path+"moving_average_rewards_"+tag+".npy",) + steps=np.load(output_path+"steps_"+tag+".npy") + plot(rewards,ylabel='rewards_'+tag) + plot(moving_average_rewards,ylabel='moving_average_rewards_'+tag) + plot(steps,ylabel='steps_'+tag) diff --git a/codes/dqn/result/20201015-215937/moving_average_rewards_eval.npy b/codes/dqn/result/20201015-215937/moving_average_rewards_eval.npy new file mode 100644 index 0000000..4d9dbaa Binary files /dev/null and b/codes/dqn/result/20201015-215937/moving_average_rewards_eval.npy differ diff --git a/codes/dqn/result/20201015-215937/moving_average_rewards_train.npy b/codes/dqn/result/20201015-215937/moving_average_rewards_train.npy new file mode 100644 index 0000000..67c5579 Binary files /dev/null and b/codes/dqn/result/20201015-215937/moving_average_rewards_train.npy differ diff --git a/codes/dqn/result/20201015-215937/rewards_eval.npy b/codes/dqn/result/20201015-215937/rewards_eval.npy new file mode 100644 index 0000000..b992efa Binary files /dev/null and b/codes/dqn/result/20201015-215937/rewards_eval.npy differ diff --git a/codes/dqn/result/20201015-215937/rewards_train.npy b/codes/dqn/result/20201015-215937/rewards_train.npy new file mode 100644 index 0000000..b4758a9 Binary files /dev/null and b/codes/dqn/result/20201015-215937/rewards_train.npy differ diff --git a/codes/dqn/result/20201015-215937/steps_eval.npy b/codes/dqn/result/20201015-215937/steps_eval.npy new file mode 100644 index 0000000..d10f0eb Binary files /dev/null and b/codes/dqn/result/20201015-215937/steps_eval.npy differ diff --git a/codes/dqn/result/20201015-215937/steps_train.npy b/codes/dqn/result/20201015-215937/steps_train.npy new file mode 100644 index 0000000..ccc81c7 Binary files /dev/null and b/codes/dqn/result/20201015-215937/steps_train.npy differ diff --git a/codes/dqn/result/moving_average_rewards.npy b/codes/dqn/result/moving_average_rewards.npy deleted file mode 100644 index b770ce2..0000000 Binary files a/codes/dqn/result/moving_average_rewards.npy and /dev/null differ diff --git a/codes/dqn/result/moving_average_rewards.png b/codes/dqn/result/moving_average_rewards.png deleted file mode 100644 index 86469f8..0000000 Binary files a/codes/dqn/result/moving_average_rewards.png and /dev/null differ diff --git a/codes/dqn/result/moving_average_rewards_eval.npy b/codes/dqn/result/moving_average_rewards_eval.npy new file mode 100644 index 0000000..4d9dbaa Binary files /dev/null and b/codes/dqn/result/moving_average_rewards_eval.npy differ diff --git a/codes/dqn/result/moving_average_rewards_eval.png b/codes/dqn/result/moving_average_rewards_eval.png new file mode 100644 index 0000000..c2ba80b Binary files /dev/null and b/codes/dqn/result/moving_average_rewards_eval.png differ diff --git a/codes/dqn/result/moving_average_rewards_train.npy b/codes/dqn/result/moving_average_rewards_train.npy new file mode 100644 index 0000000..67c5579 Binary files /dev/null and b/codes/dqn/result/moving_average_rewards_train.npy differ diff --git a/codes/dqn/result/moving_average_rewards_train.png 
b/codes/dqn/result/moving_average_rewards_train.png new file mode 100644 index 0000000..34af087 Binary files /dev/null and b/codes/dqn/result/moving_average_rewards_train.png differ diff --git a/codes/dqn/result/rewards.npy b/codes/dqn/result/rewards.npy deleted file mode 100644 index 2d0ff55..0000000 Binary files a/codes/dqn/result/rewards.npy and /dev/null differ diff --git a/codes/dqn/result/rewards.png b/codes/dqn/result/rewards.png deleted file mode 100644 index 106d8ed..0000000 Binary files a/codes/dqn/result/rewards.png and /dev/null differ diff --git a/codes/dqn/result/rewards_eval.npy b/codes/dqn/result/rewards_eval.npy new file mode 100644 index 0000000..b992efa Binary files /dev/null and b/codes/dqn/result/rewards_eval.npy differ diff --git a/codes/dqn/result/rewards_eval.png b/codes/dqn/result/rewards_eval.png new file mode 100644 index 0000000..735fa2b Binary files /dev/null and b/codes/dqn/result/rewards_eval.png differ diff --git a/codes/dqn/result/rewards_train.npy b/codes/dqn/result/rewards_train.npy new file mode 100644 index 0000000..b4758a9 Binary files /dev/null and b/codes/dqn/result/rewards_train.npy differ diff --git a/codes/dqn/result/rewards_train.png b/codes/dqn/result/rewards_train.png new file mode 100644 index 0000000..471ecff Binary files /dev/null and b/codes/dqn/result/rewards_train.png differ diff --git a/codes/dqn/result/steps.npy b/codes/dqn/result/steps.npy deleted file mode 100644 index a40309e..0000000 Binary files a/codes/dqn/result/steps.npy and /dev/null differ diff --git a/codes/dqn/result/steps_eval.npy b/codes/dqn/result/steps_eval.npy new file mode 100644 index 0000000..d10f0eb Binary files /dev/null and b/codes/dqn/result/steps_eval.npy differ diff --git a/codes/dqn/result/steps_eval.png b/codes/dqn/result/steps_eval.png new file mode 100644 index 0000000..c3864ee Binary files /dev/null and b/codes/dqn/result/steps_eval.png differ diff --git a/codes/dqn/result/steps_of_each_episode.png b/codes/dqn/result/steps_of_each_episode.png deleted file mode 100644 index 408f4cd..0000000 Binary files a/codes/dqn/result/steps_of_each_episode.png and /dev/null differ diff --git a/codes/dqn/result/steps_train.npy b/codes/dqn/result/steps_train.npy new file mode 100644 index 0000000..ccc81c7 Binary files /dev/null and b/codes/dqn/result/steps_train.npy differ diff --git a/codes/dqn/result/steps_train.png b/codes/dqn/result/steps_train.png new file mode 100644 index 0000000..3ba5e60 Binary files /dev/null and b/codes/dqn/result/steps_train.png differ diff --git a/codes/dqn/saved_model/20201015-215937/checkpoint.pth b/codes/dqn/saved_model/20201015-215937/checkpoint.pth new file mode 100644 index 0000000..de23719 Binary files /dev/null and b/codes/dqn/saved_model/20201015-215937/checkpoint.pth differ diff --git a/codes/dqn/saved_model/checkpoint.pth b/codes/dqn/saved_model/checkpoint.pth new file mode 100644 index 0000000..de23719 Binary files /dev/null and b/codes/dqn/saved_model/checkpoint.pth differ diff --git a/codes/dqn/utils.py b/codes/dqn/utils.py new file mode 100644 index 0000000..0c75408 --- /dev/null +++ b/codes/dqn/utils.py @@ -0,0 +1,21 @@ +#!/usr/bin/env python +# coding=utf-8 +''' +Author: John +Email: johnjim0816@gmail.com +Date: 2020-10-15 21:28:00 +LastEditor: John +LastEditTime: 2020-10-15 21:50:30 +Discription: +Environment: +''' +import os +import numpy as np + + +def save_results(rewards,moving_average_rewards,ep_steps,tag='train',result_path='./result'): + if not os.path.exists(result_path): # 检测是否存在文件夹 + os.mkdir(result_path) + 
np.save(result_path+'rewards_'+tag+'.npy', rewards)
+    np.save(result_path+'moving_average_rewards_'+tag+'.npy', moving_average_rewards)
+    np.save(result_path+'steps_'+tag+'.npy', ep_steps)
\ No newline at end of file
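
For reference, the epsilon-greedy schedule that `choose_action` applies during training decays epsilon exponentially from `epsilon_start` toward `epsilon_end` as `actions_count` grows. A minimal standalone sketch of that schedule follows; only `epsilon_start=0.95` is visible in this diff, so the `epsilon_end` and `epsilon_decay` values below are illustrative assumptions rather than the project's defaults.

```python
import math


def epsilon_by_step(step, epsilon_start=0.95, epsilon_end=0.01, epsilon_decay=500):
    """Same exponential decay as DQN.choose_action with train=True.

    epsilon_end and epsilon_decay here are assumed values for illustration,
    not the defaults defined in get_args().
    """
    return epsilon_end + (epsilon_start - epsilon_end) * \
        math.exp(-1. * step / epsilon_decay)


if __name__ == "__main__":
    # epsilon shrinks toward epsilon_end as more actions are taken
    for step in (0, 100, 500, 1000, 5000):
        print(step, round(epsilon_by_step(step), 3))
```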
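The eval path added in `main.py` can also be exercised directly: build a `DQN` agent with the constructor arguments shown above, restore a checkpoint with `load_model` (which loads the weights into `target_net`), and act greedily via `choose_action(state, train=False)`. Below is a rough usage sketch under the same assumptions as the patch (CartPole-v0 and the pre-0.26 `gym` step API that returns four values); `gamma`, `epsilon_start` and `batch_size` match the defaults visible in `get_args`, while the remaining hyperparameters and the checkpoint path are placeholders.

```python
import gym
import torch

from agent import DQN

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
env = gym.make('CartPole-v0').unwrapped
# constructor signature as in agent.py; epsilon_end/epsilon_decay, policy_lr and
# memory_capacity below are placeholder values, not the repo's defaults
agent = DQN(n_states=env.observation_space.shape[0],
            n_actions=env.action_space.n,
            device=device, gamma=0.99,
            epsilon_start=0.95, epsilon_end=0.01, epsilon_decay=500,
            policy_lr=0.01, memory_capacity=1000, batch_size=32)
agent.load_model('saved_model/checkpoint.pth')  # hypothetical path to a saved checkpoint

state = env.reset()
ep_reward = 0
for _ in range(200):
    action = agent.choose_action(state, train=False)  # greedy action from target_net
    state, reward, done, _ = env.step(action)  # old gym API: 4 return values
    ep_reward += reward
    if done:
        break
print('episode reward:', ep_reward)
```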
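One caveat on `utils.save_results`: the file names are built by plain string concatenation, so `result_path` needs a trailing separator (as `RESULT_PATH` in `main.py` provides), and `os.mkdir` only creates the last path component. A hypothetical variant, not part of this patch, that sidesteps both issues:

```python
import os

import numpy as np


def save_results_joined(rewards, moving_average_rewards, ep_steps,
                        tag='train', result_path='./result'):
    # makedirs handles nested folders; os.path.join tolerates a missing trailing slash
    os.makedirs(result_path, exist_ok=True)
    np.save(os.path.join(result_path, 'rewards_' + tag + '.npy'), rewards)
    np.save(os.path.join(result_path, 'moving_average_rewards_' + tag + '.npy'),
            moving_average_rewards)
    np.save(os.path.join(result_path, 'steps_' + tag + '.npy'), ep_steps)
```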