Merge branch 'master' of github.com:datawhalechina/leedeeprl-notes
codes/PolicyGradient/README.md (new file, 42 lines)

@@ -0,0 +1,42 @@
# Policy Gradient

Implements REINFORCE, the most basic policy gradient method.

## How it works

See my blog post [Policy Gradient算法实战](https://blog.csdn.net/JohnJim0/article/details/110236851); a minimal sketch of the update appears below.
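A minimal REINFORCE sketch (illustrative only; the repo's actual implementation lives in agent.py): discounted returns are computed per episode and the loss is the negative log-probability of each action weighted by its return.

```python
import torch

def discounted_returns(rewards, gamma=0.99):
    """G_t = r_t + gamma * G_{t+1}, computed backwards over one episode."""
    returns, g = [], 0.0
    for r in reversed(rewards):
        g = r + gamma * g
        returns.insert(0, g)
    returns = torch.tensor(returns)
    # normalizing returns is a common variance-reduction trick
    return (returns - returns.mean()) / (returns.std() + 1e-8)

def reinforce_loss(log_probs, rewards, gamma=0.99):
    """log_probs: list of log pi(a_t|s_t) tensors collected while acting."""
    returns = discounted_returns(rewards, gamma)
    return -(torch.stack(log_probs) * returns).sum()

# tiny usage example with dummy values
log_probs = [torch.tensor(-0.1, requires_grad=True) for _ in range(3)]
loss = reinforce_loss(log_probs, [1.0, 1.0, 1.0])
loss.backward()
```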

## Environment

python 3.7.9

pytorch 1.6.0

tensorboard 2.3.0

torchvision 0.7.0
## How to run

Train:

```python
python main.py
```

Evaluate:

```python
python main.py --train 0
```

TensorBoard:

```python
tensorboard --logdir logs
```

## References

[REINFORCE and the Reparameterization Trick](https://blog.csdn.net/JohnJim0/article/details/110230703)

[Policy Gradient paper](https://papers.nips.cc/paper/1713-policy-gradient-methods-for-reinforcement-learning-with-function-approximation.pdf)

[REINFORCE](https://towardsdatascience.com/policy-gradient-methods-104c783251e0)
codes/PolicyGradient/agent.py

@@ -5,7 +5,7 @@ Author: John
Email: johnjim0816@gmail.com
Date: 2020-11-22 23:27:44
LastEditor: John
-LastEditTime: 2020-11-23 12:05:03
+LastEditTime: 2020-11-23 17:04:37
Discription:
Environment:
'''
@@ -18,9 +18,9 @@ from model import FCN

class PolicyGradient:

-    def __init__(self, n_states,device='cpu',gamma = 0.99,lr = 0.01,batch_size=5):
+    def __init__(self, state_dim,device='cpu',gamma = 0.99,lr = 0.01,batch_size=5):
        self.gamma = gamma
-        self.policy_net = FCN(n_states)
+        self.policy_net = FCN(state_dim)
        self.optimizer = torch.optim.RMSprop(self.policy_net.parameters(), lr=lr)
        self.batch_size = batch_size
@@ -65,4 +65,8 @@ class PolicyGradient:
        loss = -m.log_prob(action) * reward  # negative score function x reward
        # print(loss)
        loss.backward()
        self.optimizer.step()
+    def save_model(self,path):
+        torch.save(self.policy_net.state_dict(), path)
+    def load_model(self,path):
+        self.policy_net.load_state_dict(torch.load(path))
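For context, `m` in the loss above is a torch.distributions object built from the policy network's output. Since the FCN below ends in a single unit commented as the probability of pushing left, a Bernoulli distribution is the natural fit. A sketch of how the action and its log-probability could be produced (names and details are illustrative, not necessarily the repo's exact code):

```python
import torch
from torch.distributions import Bernoulli

def choose_action_sketch(policy_net, state):
    """Sample an action and return the distribution used for the log-prob loss."""
    state = torch.tensor([state], dtype=torch.float32)
    prob_left = policy_net(state)   # assumed sigmoid output in (0, 1)
    m = Bernoulli(prob_left)        # single-parameter action distribution
    action = m.sample()             # tensor holding 0. or 1.
    # later: loss = -m.log_prob(action) * reward, as in update() above
    return int(action.item()), m
```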
codes/PolicyGradient/env.py

@@ -14,6 +14,6 @@ import gym
def env_init():
    env = gym.make('CartPole-v0')  # you can look up why gym envs are sometimes unwrapped; it is usually not needed here
    env.seed(1)  # set the env random seed
-    n_states = env.observation_space.shape[0]
+    state_dim = env.observation_space.shape[0]
    n_actions = env.action_space.n
-    return env,n_states,n_actions
+    return env,state_dim,n_actions
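For CartPole-v0 these values are fixed: `state_dim` is 4 and `n_actions` is 2. A quick check (assuming the classic gym API used throughout this repo):

```python
import gym

env = gym.make('CartPole-v0')
print(env.observation_space.shape[0])  # 4: cart position/velocity, pole angle/angular velocity
print(env.action_space.n)              # 2: push left or push right
```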
Binary files not shown.
codes/PolicyGradient/main.py

@@ -5,28 +5,38 @@ Author: John
Email: johnjim0816@gmail.com
Date: 2020-11-22 23:21:53
LastEditor: John
-LastEditTime: 2020-11-23 12:06:15
+LastEditTime: 2020-11-24 19:52:40
Discription:
Environment:
'''
from itertools import count
import torch
import os
from torch.utils.tensorboard import SummaryWriter

from env import env_init
from params import get_args
from agent import PolicyGradient

from params import SEQUENCE, SAVED_MODEL_PATH, RESULT_PATH
from utils import save_results,save_model
from plot import plot

def train(cfg):
-    env,n_states,n_actions = env_init()
+    env,state_dim,n_actions = env_init()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # detect GPU
-    agent = PolicyGradient(n_states,device = device,lr = cfg.policy_lr)
+    agent = PolicyGradient(state_dim,device = device,lr = cfg.policy_lr)
    '''the *_pool lists below store the transition sequences used for the gradient update'''
    state_pool = []  # states of each batch_size episodes
    action_pool = []
    reward_pool = []
    '''store each episode's reward for plotting'''
    rewards = []
    moving_average_rewards = []
    log_dir = os.path.split(os.path.abspath(__file__))[0]+"/logs/train/" + SEQUENCE
    writer = SummaryWriter(log_dir)  # tensorboard writer
    for i_episode in range(cfg.train_eps):
        state = env.reset()
        ep_reward = 0
-        for t in count():
+        for _ in count():
            action = agent.choose_action(state)  # choose an action based on the current state
            next_state, reward, done, _ = env.step(action)
            ep_reward += reward
@@ -39,14 +49,61 @@ def train(cfg):
            if done:
                print('Episode:', i_episode, ' Reward:', ep_reward)
                break
        # if i_episode % cfg.batch_size == 0:
-        if i_episode > 0 and i_episode % 5 == 0:
+        if i_episode > 0 and i_episode % cfg.batch_size == 0:
            agent.update(reward_pool,state_pool,action_pool)
            state_pool = []  # states of each episode
            action_pool = []
            reward_pool = []
        rewards.append(ep_reward)
        if i_episode == 0:
            moving_average_rewards.append(ep_reward)
        else:
            moving_average_rewards.append(
                0.9*moving_average_rewards[-1]+0.1*ep_reward)
        writer.add_scalars('rewards',{'raw':rewards[-1], 'moving_average': moving_average_rewards[-1]}, i_episode+1)
    writer.close()
    print('Complete training!')
    save_model(agent,model_path=SAVED_MODEL_PATH)
    '''save rewards and related results'''
    save_results(rewards,moving_average_rewards,tag='train',result_path=RESULT_PATH)
    plot(rewards)
    plot(moving_average_rewards,ylabel='moving_average_rewards_train')

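The smoothed curve above is an exponential moving average with factor 0.9; the recurrence can be checked in isolation:

```python
# ma[0] = r[0]; ma[t] = 0.9 * ma[t-1] + 0.1 * r[t]
rewards = [20, 40, 10, 30]
ma = []
for r in rewards:
    ma.append(r if not ma else 0.9 * ma[-1] + 0.1 * r)
print(ma)  # [20, 22.0, 20.8, 21.72] (up to float rounding)
```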

def eval(cfg,saved_model_path = SAVED_MODEL_PATH):
    env,state_dim,n_actions = env_init()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # detect GPU
    agent = PolicyGradient(state_dim,device = device,lr = cfg.policy_lr)
    agent.load_model(saved_model_path+'checkpoint.pth')
    rewards = []
    moving_average_rewards = []
    log_dir = os.path.split(os.path.abspath(__file__))[0]+"/logs/eval/" + SEQUENCE
    writer = SummaryWriter(log_dir)  # tensorboard writer
    for i_episode in range(cfg.eval_eps):
        state = env.reset()
        ep_reward = 0
        for _ in count():
            action = agent.choose_action(state)  # choose an action based on the current state
            next_state, reward, done, _ = env.step(action)
            ep_reward += reward
            state = next_state
            if done:
                print('Episode:', i_episode, ' Reward:', ep_reward)
                break
        rewards.append(ep_reward)
        if i_episode == 0:
            moving_average_rewards.append(ep_reward)
        else:
            moving_average_rewards.append(
                0.9*moving_average_rewards[-1]+0.1*ep_reward)
        writer.add_scalars('rewards',{'raw':rewards[-1], 'moving_average': moving_average_rewards[-1]}, i_episode+1)
    writer.close()
    print('Complete evaling!')

if __name__ == "__main__":
    cfg = get_args()
-    train(cfg)
+    if cfg.train:
+        train(cfg)
+        eval(cfg)
+    else:
+        model_path = os.path.split(os.path.abspath(__file__))[0]+"/saved_model/"
+        eval(cfg,saved_model_path=model_path)
codes/PolicyGradient/model.py

@@ -5,7 +5,7 @@ Author: John
Email: johnjim0816@gmail.com
Date: 2020-11-22 23:18:46
LastEditor: John
-LastEditTime: 2020-11-23 01:58:22
+LastEditTime: 2020-11-27 16:55:25
Discription:
Environment:
'''
@@ -13,11 +13,11 @@ import torch.nn as nn
import torch.nn.functional as F
class FCN(nn.Module):
    '''fully connected network'''
-    def __init__(self,n_states):
+    def __init__(self,state_dim):
        super(FCN, self).__init__()
-        # 24 and 36 are the hidden layer sizes; they can be changed according to n_states, n_actions
-        self.fc1 = nn.Linear(n_states, 24)
-        self.fc2 = nn.Linear(24, 36)
+        # 24 and 36 are the hidden layer sizes; they can be changed according to state_dim, n_actions
+        self.fc1 = nn.Linear(state_dim, 36)
+        self.fc2 = nn.Linear(36, 36)
        self.fc3 = nn.Linear(36, 1)  # Prob of Left

    def forward(self, x):
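The body of forward is not shown in this hunk; given the single output unit commented as "Prob of Left", a plausible sketch looks like the following (layer sizes match the hunk, the activations are an assumption):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class FCNSketch(nn.Module):
    """Illustrative stand-in for FCN: maps a state to P(action = left)."""
    def __init__(self, state_dim):
        super().__init__()
        self.fc1 = nn.Linear(state_dim, 36)
        self.fc2 = nn.Linear(36, 36)
        self.fc3 = nn.Linear(36, 1)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return torch.sigmoid(self.fc3(x))  # probability in (0, 1) for the Bernoulli policy

# quick shape check for CartPole (state_dim = 4)
net = FCNSketch(4)
print(net(torch.zeros(1, 4)).shape)  # torch.Size([1, 1])
```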
codes/PolicyGradient/params.py

@@ -5,15 +5,25 @@ Author: John
Email: johnjim0816@gmail.com
Date: 2020-11-22 23:25:37
LastEditor: John
-LastEditTime: 2020-11-22 23:32:44
+LastEditTime: 2020-11-26 19:11:21
Discription: stores the parameters
Environment:
'''
import argparse
import datetime
import os

SEQUENCE = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
SAVED_MODEL_PATH = os.path.split(os.path.abspath(__file__))[0]+"/saved_model/"+SEQUENCE+'/'
RESULT_PATH = os.path.split(os.path.abspath(__file__))[0]+"/result/"+SEQUENCE+'/'

def get_args():
    '''training parameters'''
    parser = argparse.ArgumentParser()
-    parser.add_argument("--train_eps", default=1200, type=int)  # max number of training episodes
+    parser.add_argument("--train", default=1, type=int)  # 1 means train, 0 means eval only
+    parser.add_argument("--train_eps", default=300, type=int)  # max number of training episodes
    parser.add_argument("--eval_eps", default=100, type=int)  # max number of evaluation episodes
    parser.add_argument("--batch_size", default=4, type=int)  # number of episodes per gradient update
    parser.add_argument("--policy_lr", default=0.01, type=float)  # learning rate
    config = parser.parse_args()
    return config
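Taken together, these flags allow longer training runs or eval-only runs from the shell, for example (values are illustrative):

```bash
python main.py --train 1 --train_eps 500 --batch_size 8 --policy_lr 0.005
python main.py --train 0 --eval_eps 50
```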
codes/PolicyGradient/plot.py (new file, 46 lines)

@@ -0,0 +1,46 @@
#!/usr/bin/env python
# coding=utf-8
'''
Author: John
Email: johnjim0816@gmail.com
Date: 2020-11-23 13:48:46
LastEditor: John
LastEditTime: 2020-11-23 13:48:48
Discription:
Environment:
'''
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import os

def plot(item, ylabel='rewards_train', save_fig=True):
    '''plot a curve using seaborn'''
    sns.set()
    plt.figure()
    plt.plot(np.arange(len(item)), item)
    plt.title(ylabel+' of DQN')
    plt.ylabel(ylabel)
    plt.xlabel('episodes')
    if save_fig:
        plt.savefig(os.path.dirname(__file__)+"/result/"+ylabel+".png")
    plt.show()

if __name__ == "__main__":
    output_path = os.path.split(os.path.abspath(__file__))[0]+"/result/"
    tag = 'train'
    rewards = np.load(output_path+"rewards_"+tag+".npy")
    moving_average_rewards = np.load(output_path+"moving_average_rewards_"+tag+".npy")
    steps = np.load(output_path+"steps_"+tag+".npy")
    plot(rewards)
    plot(moving_average_rewards, ylabel='moving_average_rewards_'+tag)
    plot(steps, ylabel='steps_'+tag)
    tag = 'eval'
    rewards = np.load(output_path+"rewards_"+tag+".npy")
    moving_average_rewards = np.load(output_path+"moving_average_rewards_"+tag+".npy")
    steps = np.load(output_path+"steps_"+tag+".npy")
    plot(rewards, ylabel='rewards_'+tag)
    plot(moving_average_rewards, ylabel='moving_average_rewards_'+tag)
    plot(steps, ylabel='steps_'+tag)
New binary files (content not shown):

- codes/PolicyGradient/result/20201123-135302/rewards_train.npy
- codes/PolicyGradient/result/20201126-191039/rewards_train.npy
- codes/PolicyGradient/result/20201126-191145/rewards_train.npy
- codes/PolicyGradient/result/moving_average_rewards_train.png (40 KiB)
- codes/PolicyGradient/result/rewards_train.png (59 KiB)
- codes/PolicyGradient/saved_model/20201123-135302/checkpoint.pth
- codes/PolicyGradient/saved_model/20201126-191039/checkpoint.pth
- codes/PolicyGradient/saved_model/20201126-191145/checkpoint.pth
- codes/PolicyGradient/saved_model/checkpoint.pth
codes/PolicyGradient/utils.py (new file, 29 lines)

@@ -0,0 +1,29 @@
#!/usr/bin/env python
# coding=utf-8
'''
Author: John
Email: johnjim0816@gmail.com
Date: 2020-11-23 13:44:52
LastEditor: John
LastEditTime: 2020-11-23 13:45:42
Discription:
Environment:
'''
import os
import numpy as np


def save_results(rewards, moving_average_rewards, tag='train', result_path='./result'):
    '''save rewards and related results'''
    if not os.path.exists(result_path):  # check whether the folder exists
        os.mkdir(result_path)
    np.save(result_path+'rewards_'+tag+'.npy', rewards)
    np.save(result_path+'moving_average_rewards_'+tag+'.npy', moving_average_rewards)
    print('results saved!')

def save_model(agent, model_path='./saved_model'):
    if not os.path.exists(model_path):  # check whether the folder exists
        os.mkdir(model_path)
    agent.save_model(model_path+'checkpoint.pth')
    print('model saved!')
codes/dqn/.vscode/settings.json (vendored, 3 lines deleted)

@@ -1,3 +0,0 @@
-{
-    "python.pythonPath": "/Users/jj/anaconda3/envs/py37/bin/python"
-}
codes/dqn/README.md

@@ -1,3 +1,8 @@
+## Idea
+
+See [my blog post](https://blog.csdn.net/JohnJim0/article/details/109557173)
+
## Environment

python 3.7.9

pytorch 1.6.0

@@ -6,6 +11,7 @@ tensorboard 2.3.0

torchvision 0.7.0

## Usage

train:

@@ -18,7 +24,12 @@ eval:

```python
python main.py --train 0
```

+Visualization:
+
+```python
+tensorboard --logdir logs
+```

## Torch notes

[with torch.no_grad()](https://www.jianshu.com/p/1cea017f5d11)
@@ -5,7 +5,7 @@
@Email: johnjim0816@gmail.com
@Date: 2020-06-12 00:50:49
@LastEditor: John
-LastEditTime: 2020-10-15 21:56:21
+LastEditTime: 2020-11-22 11:12:30
@Discription:
@Environment: python 3.7.7
'''
@@ -24,11 +24,12 @@ from memory import ReplayBuffer
from model import FCN
class DQN:
    def __init__(self, n_states, n_actions, gamma=0.99, epsilon_start=0.9, epsilon_end=0.05, epsilon_decay=200, memory_capacity=10000, policy_lr=0.01, batch_size=128, device="cpu"):
-        self.actions_count = 0

        self.n_actions = n_actions  # total number of actions
        self.device = device  # device: cpu or gpu
-        self.gamma = gamma
+        self.gamma = gamma  # reward discount factor
+        # parameters of the e-greedy policy
+        self.actions_count = 0  # counter used for epsilon decay
        self.epsilon = 0
        self.epsilon_start = epsilon_start
        self.epsilon_end = epsilon_end
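The decay schedule itself is outside this hunk; a common epsilon-greedy schedule consistent with the parameter names above (epsilon_start, epsilon_end, epsilon_decay, actions_count) is exponential decay — a sketch, not necessarily the repo's exact formula:

```python
import math

def epsilon_by_step(step, eps_start=0.9, eps_end=0.05, eps_decay=200):
    """Exponentially anneal exploration from eps_start down towards eps_end."""
    return eps_end + (eps_start - eps_end) * math.exp(-step / eps_decay)

print(round(epsilon_by_step(0), 3))     # 0.9
print(round(epsilon_by_step(200), 3))   # ~0.363
print(round(epsilon_by_step(1000), 3))  # ~0.056
```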
@@ -67,12 +68,11 @@ class DQN:
            action = random.randrange(self.n_actions)
            return action
        else:
-            with torch.no_grad():
+            with torch.no_grad():  # do not track gradients
                # convert to a tensor first so it can be fed to the network; the state entries are originally float64
                # note that state=torch.tensor(state).unsqueeze(0) is equivalent to state=torch.tensor([state])
                state = torch.tensor(
-                    [state], device='cpu', dtype=torch.float32)
-                # e.g. tensor([[-0.0798, -0.0079]], grad_fn=<AddmmBackward>)
+                    [state], device='cpu', dtype=torch.float32)  # e.g. tensor([[-0.0798, -0.0079]], grad_fn=<AddmmBackward>)
                q_value = self.target_net(state)
                # tensor.max(1) returns the max value of each row together with its index,
                # e.g. torch.return_types.max(values=tensor([10.3587]), indices=tensor([0]))
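The comment about tensor.max(1) can be checked directly in a REPL:

```python
import torch

q = torch.tensor([[10.3587, 2.0, -1.5]])
values, indices = q.max(1)      # row-wise max and its column index
print(values)                   # tensor([10.3587])
print(indices)                  # tensor([0])
action = q.max(1)[1].item()     # the greedy action index as a plain int
print(action)                   # 0
```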
@@ -86,8 +86,8 @@ class DQN:
        # randomly sample transitions from memory
        state_batch, action_batch, reward_batch, next_state_batch, done_batch = self.memory.sample(
            self.batch_size)
-        # convert to tensors
-        # e.g. tensor([[-4.5543e-02, -2.3910e-01, 1.8344e-02, 2.3158e-01],...,[-1.8615e-02, -2.3921e-01, -1.1791e-02, 2.3400e-01]])
+        '''convert to tensors,
+        e.g. tensor([[-4.5543e-02, -2.3910e-01, 1.8344e-02, 2.3158e-01],...,[-1.8615e-02, -2.3921e-01, -1.1791e-02, 2.3400e-01]])'''
        state_batch = torch.tensor(
            state_batch, device=self.device, dtype=torch.float)
        action_batch = torch.tensor(action_batch, device=self.device).unsqueeze(
@@ -99,9 +99,8 @@ class DQN:
        done_batch = torch.tensor(np.float32(
            done_batch), device=self.device).unsqueeze(1)  # convert bool to float, then to a tensor

-        # compute Q(s_t, a) for the current (s_t, a) pairs
-        # about torch.gather: for a=torch.Tensor([[1,2],[3,4]]),
-        # a.gather(1,torch.Tensor([[0],[1]]))=torch.Tensor([[1],[3]])
+        '''compute Q(s_t, a) for the current (s_t, a) pairs'''
+        '''torch.gather: for a=torch.Tensor([[1,2],[3,4]]), a.gather(1,torch.Tensor([[0],[1]]))=torch.Tensor([[1],[3]])'''
        q_values = self.policy_net(state_batch).gather(
            dim=1, index=action_batch)  # equivalent to self.forward
        # compute V(s_{t+1}) for all next states, i.e. take the largest value from target_net for the corresponding states
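The gather comment can likewise be verified standalone (note that the index must be an integer tensor):

```python
import torch

a = torch.tensor([[1, 2], [3, 4]])
idx = torch.tensor([[0], [1]])   # which column to pick from each row
print(a.gather(1, idx))          # tensor([[1], [3]])
# In the DQN update this pattern selects Q(s_t, a_t) for the action actually taken.
```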
@@ -119,6 +118,7 @@ class DQN:
        self.loss.backward()
        for param in self.policy_net.parameters():  # clip to prevent exploding gradients
            param.grad.data.clamp_(-1, 1)

        self.optimizer.step()  # update the model

    def save_model(self,path):
@@ -5,7 +5,7 @@
@Email: johnjim0816@gmail.com
@Date: 2020-06-12 00:48:57
@LastEditor: John
-LastEditTime: 2020-10-15 22:00:28
+LastEditTime: 2020-11-23 11:58:17
@Discription:
@Environment: python 3.7.7
'''
@@ -16,7 +16,7 @@ import argparse
from torch.utils.tensorboard import SummaryWriter
import datetime
import os
-from utils import save_results
+from utils import save_results,save_model

SEQUENCE = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
SAVED_MODEL_PATH = os.path.split(os.path.abspath(__file__))[0]+"/saved_model/"+SEQUENCE+'/'
@@ -53,7 +53,7 @@ def get_args():
def train(cfg):
    print('Start to train ! \n')
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # detect GPU
-    env = gym.make('CartPole-v0').unwrapped  # you can look up why gym envs are sometimes unwrapped; it is usually not needed here
+    env = gym.make('CartPole-v0')
    env.seed(1)  # set the env random seed
    n_states = env.observation_space.shape[0]
    n_actions = env.action_space.n
@@ -95,10 +95,7 @@ def train(cfg):
    writer.close()
    print('Complete training!')
    '''save the model'''
-    if not os.path.exists(SAVED_MODEL_PATH):  # check whether the folder exists
-        os.mkdir(SAVED_MODEL_PATH)
-    agent.save_model(SAVED_MODEL_PATH+'checkpoint.pth')
-    print('model saved!')
+    save_model(agent,model_path=SAVED_MODEL_PATH)
    '''save rewards and related results'''
    save_results(rewards,moving_average_rewards,ep_steps,tag='train',result_path=RESULT_PATH)
@@ -110,7 +107,7 @@ def eval(cfg, saved_model_path = SAVED_MODEL_PATH):
    env.seed(1)  # set the env random seed
    n_states = env.observation_space.shape[0]
    n_actions = env.action_space.n
-    agent = DQN(n_states=n_states, n_actions=n_actions, device=device, gamma=cfg.gamma, epsilon_start=cfg.epsilon_start,
+    agent = DQN(n_states=n_states, n_actions=n_actions, device="cpu", gamma=cfg.gamma, epsilon_start=cfg.epsilon_start,
                epsilon_end=cfg.epsilon_end, epsilon_decay=cfg.epsilon_decay, policy_lr=cfg.policy_lr, memory_capacity=cfg.memory_capacity, batch_size=cfg.batch_size)
    agent.load_model(saved_model_path+'checkpoint.pth')
    rewards = []
@@ -5,7 +5,7 @@
@Email: johnjim0816@gmail.com
@Date: 2020-06-11 16:30:09
@LastEditor: John
-LastEditTime: 2020-10-15 22:01:50
+LastEditTime: 2020-11-23 13:48:31
@Discription:
@Environment: python 3.7.7
'''
@@ -27,18 +27,6 @@ def plot(item,ylabel='rewards_train', save_fig = True):
        plt.savefig(os.path.dirname(__file__)+"/result/"+ylabel+".png")
    plt.show()

-# def plot(item,ylabel='rewards'):
-#
-#     df = pd.DataFrame(dict(time=np.arange(len(item)),value=item))
-#     g = sns.relplot(x="time", y="value", kind="line", data=df)
-#     # g.fig.autofmt_xdate()
-#     # sns.lineplot(time=time, data=item, color="r", condition="behavior_cloning")
-#     # # sns.tsplot(time=time, data=x2, color="b", condition="dagger")
-#     # plt.ylabel("Reward")
-#     # plt.xlabel("Iteration Number")
-#     # plt.title("Imitation Learning")
-
-#     plt.show()
if __name__ == "__main__":

    output_path = os.path.split(os.path.abspath(__file__))[0]+"/result/"
@@ -5,7 +5,7 @@ Author: John
Email: johnjim0816@gmail.com
Date: 2020-10-15 21:28:00
LastEditor: John
-LastEditTime: 2020-10-15 21:50:30
+LastEditTime: 2020-10-30 16:56:55
Discription:
Environment:
'''
@@ -14,8 +14,17 @@ import numpy as np


def save_results(rewards,moving_average_rewards,ep_steps,tag='train',result_path='./result'):
    '''save rewards and related results'''
    if not os.path.exists(result_path):  # check whether the folder exists
        os.mkdir(result_path)
    np.save(result_path+'rewards_'+tag+'.npy', rewards)
    np.save(result_path+'moving_average_rewards_'+tag+'.npy', moving_average_rewards)
    np.save(result_path+'steps_'+tag+'.npy', ep_steps)
    print('results saved!')

+def save_model(agent,model_path='./saved_model'):
+    if not os.path.exists(model_path):  # check whether the folder exists
+        os.mkdir(model_path)
+    agent.save_model(model_path+'checkpoint.pth')
+    print('model saved!')