Update algorithm template

johnjim0816
2022-11-06 12:15:36 +08:00
parent 466a17707f
commit dc78698262
256 changed files with 17282 additions and 10229 deletions

View File

@@ -0,0 +1,37 @@
#!/usr/bin/env python
# coding=utf-8
'''
Author: JiangJi
Email: johnjim0816@gmail.com
Date: 2022-10-30 11:30:56
LastEditor: JiangJi
LastEditTime: 2022-10-31 00:33:15
Description: default parameters of PPO
'''
from common.config import GeneralConfig,AlgoConfig
class GeneralConfigPPO(GeneralConfig):
def __init__(self) -> None:
self.env_name = "CartPole-v0"
self.algo_name = "PPO"
self.seed = 1
self.device = "cuda"
self.train_eps = 100 # number of episodes for training
self.test_eps = 10 # number of episodes for testing
self.max_steps = 200 # max steps for each episode
class AlgoConfigPPO(AlgoConfig):
def __init__(self) -> None:
self.gamma = 0.99 # discount factor
self.continuous = False # continuous action space or not
self.policy_clip = 0.2 # clip parameter of PPO-clip
self.n_epochs = 4 # number of optimization epochs per update
self.gae_lambda = 0.95 # GAE lambda
self.actor_lr = 0.0003 # learning rate of actor
self.critic_lr = 0.0003 # learning rate of critic
self.actor_hidden_dim = 256 # hidden dimension of actor network
self.critic_hidden_dim = 256 # hidden dimension of critic network
self.batch_size = 5 # mini-batch size per update epoch
self.update_fre = 20 # frequency (in steps) of updating the agent
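
For reference, policy_clip above is the clip range of the PPO-clip surrogate objective, and gae_lambda controls the advantage estimator. A minimal sketch of the loss that policy_clip enters (illustrative only, not this repo's update code):

import torch

def clipped_surrogate_loss(new_log_probs, old_log_probs, advantages, policy_clip=0.2):
    # probability ratio between the current and the old policy for the sampled actions
    ratio = torch.exp(new_log_probs - old_log_probs)
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1.0 - policy_clip, 1.0 + policy_clip) * advantages
    # PPO maximizes the clipped surrogate, so we minimize its negative mean
    return -torch.min(unclipped, clipped).mean()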

View File

@@ -1,20 +0,0 @@
{
"algo_name": "PPO",
"env_name": "CartPole-v0",
"continuous": false,
"train_eps": 200,
"test_eps": 20,
"gamma": 0.99,
"batch_size": 5,
"n_epochs": 4,
"actor_lr": 0.0003,
"critic_lr": 0.0003,
"gae_lambda": 0.95,
"policy_clip": 0.2,
"update_fre": 20,
"hidden_dim": 256,
"device": "cpu",
"result_path": "C:\\Users\\24438\\Desktop\\rl-tutorials\\codes\\PPO/outputs/CartPole-v0/20220731-233512/results/",
"model_path": "C:\\Users\\24438\\Desktop\\rl-tutorials\\codes\\PPO/outputs/CartPole-v0/20220731-233512/models/",
"save_fig": true
}

Binary file not shown (before: 27 KiB).

Binary file not shown (before: 65 KiB).

View File

@@ -0,0 +1,25 @@
{
"algo_name": "PPO",
"env_name": "CartPole-v0",
"continuous": false,
"train_eps": 200,
"test_eps": 20,
"gamma": 0.99,
"batch_size": 5,
"n_epochs": 4,
"actor_lr": 0.0003,
"critic_lr": 0.0003,
"gae_lambda": 0.95,
"policy_clip": 0.2,
"update_fre": 20,
"actor_hidden_dim": 256,
"critic_hidden_dim": 256,
"device": "cpu",
"seed": 10,
"show_fig": false,
"save_fig": true,
"result_path": "c:\\Users\\24438\\Desktop\\rl-tutorials\\codes\\PPO/outputs/CartPole-v0/20220920-213310/results/",
"model_path": "c:\\Users\\24438\\Desktop\\rl-tutorials\\codes\\PPO/outputs/CartPole-v0/20220920-213310/models/",
"n_states": 4,
"n_actions": 2
}
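
A saved params file like the one above is plain JSON, so it can be read back into an attribute-style object when reproducing a run; a generic sketch, independent of the repo's own config helpers (the file name and location below are hypothetical):

import json
from types import SimpleNamespace

def load_params(path):
    # read a saved params JSON into a simple namespace (illustrative helper only)
    with open(path, "r", encoding="utf-8") as f:
        return SimpleNamespace(**json.load(f))

# cfg = load_params("path/to/params.json")  # hypothetical path
# cfg.actor_lr -> 0.0003, cfg.policy_clip -> 0.2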

Binary file not shown (after: 27 KiB).

View File

@@ -0,0 +1,21 @@
episodes,rewards
0,200.0
1,200.0
2,200.0
3,200.0
4,200.0
5,200.0
6,200.0
7,200.0
8,200.0
9,200.0
10,200.0
11,200.0
12,200.0
13,200.0
14,200.0
15,200.0
16,200.0
17,200.0
18,200.0
19,200.0

Binary file not shown (after: 72 KiB).

View File

@@ -0,0 +1,201 @@
episodes,rewards
0,34.0
1,12.0
2,47.0
3,29.0
4,20.0
5,23.0
6,33.0
7,25.0
8,11.0
9,30.0
10,18.0
11,16.0
12,15.0
13,25.0
14,33.0
15,19.0
16,50.0
17,23.0
18,21.0
19,42.0
20,60.0
21,64.0
22,30.0
23,31.0
24,90.0
25,43.0
26,54.0
27,74.0
28,30.0
29,82.0
30,50.0
31,53.0
32,25.0
33,27.0
34,145.0
35,118.0
36,141.0
37,148.0
38,200.0
39,191.0
40,71.0
41,105.0
42,100.0
43,120.0
44,80.0
45,40.0
46,104.0
47,39.0
48,89.0
49,60.0
50,30.0
51,24.0
52,20.0
53,23.0
54,30.0
55,32.0
56,20.0
57,12.0
58,25.0
59,25.0
60,24.0
61,29.0
62,200.0
63,62.0
64,200.0
65,58.0
66,81.0
67,200.0
68,52.0
69,140.0
70,200.0
71,74.0
72,200.0
73,29.0
74,124.0
75,129.0
76,200.0
77,194.0
78,175.0
79,117.0
80,200.0
81,186.0
82,114.0
83,200.0
84,166.0
85,150.0
86,135.0
87,200.0
88,200.0
89,133.0
90,111.0
91,200.0
92,90.0
93,200.0
94,147.0
95,30.0
96,137.0
97,200.0
98,200.0
99,179.0
100,167.0
101,186.0
102,169.0
103,200.0
104,200.0
105,171.0
106,200.0
107,181.0
108,125.0
109,200.0
110,200.0
111,122.0
112,200.0
113,124.0
114,95.0
115,102.0
116,118.0
117,91.0
118,64.0
119,124.0
120,122.0
121,76.0
122,68.0
123,40.0
124,52.0
125,51.0
126,50.0
127,49.0
128,37.0
129,76.0
130,83.0
131,76.0
132,92.0
133,113.0
134,94.0
135,157.0
136,92.0
137,200.0
138,123.0
139,200.0
140,200.0
141,200.0
142,140.0
143,200.0
144,200.0
145,200.0
146,200.0
147,200.0
148,200.0
149,200.0
150,200.0
151,78.0
152,200.0
153,200.0
154,200.0
155,200.0
156,200.0
157,200.0
158,200.0
159,200.0
160,200.0
161,200.0
162,107.0
163,187.0
164,200.0
165,200.0
166,200.0
167,200.0
168,200.0
169,200.0
170,200.0
171,200.0
172,200.0
173,200.0
174,200.0
175,200.0
176,200.0
177,200.0
178,200.0
179,200.0
180,200.0
181,200.0
182,200.0
183,200.0
184,200.0
185,200.0
186,200.0
187,200.0
188,200.0
189,200.0
190,200.0
191,200.0
192,200.0
193,200.0
194,200.0
195,200.0
196,200.0
197,200.0
198,200.0
199,200.0
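
Rewards CSVs like the ones above (columns: episodes, rewards) are plain files and can be inspected without the training code; a generic sketch using pandas and matplotlib (not utilities from this repo):

import pandas as pd
import matplotlib.pyplot as plt

def plot_rewards_csv(path, window=10):
    # plot raw rewards and a simple moving average from a saved results CSV
    df = pd.read_csv(path)
    plt.plot(df["episodes"], df["rewards"], label="rewards")
    plt.plot(df["episodes"], df["rewards"].rolling(window, min_periods=1).mean(),
             label=f"moving average ({window})")
    plt.xlabel("episode")
    plt.ylabel("reward")
    plt.legend()
    plt.show()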

View File

@@ -1,99 +1,53 @@
#!/usr/bin/env python
# coding=utf-8
'''
Author: John
Author: JiangJi
Email: johnjim0816@gmail.com
Date: 2021-03-23 15:17:42
LastEditor: John
LastEditTime: 2021-12-31 19:38:33
Description:
Environment:
Date: 2022-09-26 16:11:36
LastEditor: JiangJi
LastEditTime: 2022-10-31 00:36:37
Description: PPO-clip
'''
import os
import numpy as np
import torch
import torch.optim as optim
import torch.nn as nn
from torch.distributions.categorical import Categorical
class PPOMemory:
def __init__(self, batch_size):
self.states = []
self.probs = []
self.vals = []
self.actions = []
self.rewards = []
self.dones = []
self.batch_size = batch_size
def sample(self):
batch_step = np.arange(0, len(self.states), self.batch_size)
indices = np.arange(len(self.states), dtype=np.int64)
np.random.shuffle(indices)
batches = [indices[i:i+self.batch_size] for i in batch_step]
return np.array(self.states),np.array(self.actions),np.array(self.probs),\
np.array(self.vals),np.array(self.rewards),np.array(self.dones),batches
def push(self, state, action, probs, vals, reward, done):
self.states.append(state)
self.actions.append(action)
self.probs.append(probs)
self.vals.append(vals)
self.rewards.append(reward)
self.dones.append(done)
def clear(self):
self.states = []
self.probs = []
self.actions = []
self.rewards = []
self.dones = []
self.vals = []
class Actor(nn.Module):
def __init__(self,n_states, n_actions,
hidden_dim):
super(Actor, self).__init__()
self.actor = nn.Sequential(
nn.Linear(n_states, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, n_actions),
nn.Softmax(dim=-1)
)
def forward(self, state):
dist = self.actor(state)
dist = Categorical(dist)
return dist
class Critic(nn.Module):
def __init__(self, n_states,hidden_dim):
super(Critic, self).__init__()
self.critic = nn.Sequential(
nn.Linear(n_states, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, 1)
)
def forward(self, state):
value = self.critic(state)
return value
class PPO:
def __init__(self, n_states, n_actions,cfg):
def __init__(self, models,memory,cfg):
self.gamma = cfg.gamma
self.continuous = cfg.continuous
self.continuous = cfg.continuous
self.policy_clip = cfg.policy_clip
self.n_epochs = cfg.n_epochs
self.batch_size = cfg.batch_size
self.gae_lambda = cfg.gae_lambda
self.device = cfg.device
self.actor = Actor(n_states, n_actions,cfg.hidden_dim).to(self.device)
self.critic = Critic(n_states,cfg.hidden_dim).to(self.device)
self.device = torch.device(cfg.device)
self.actor = models['Actor'].to(self.device)
self.critic = models['Critic'].to(self.device)
self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=cfg.actor_lr)
self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=cfg.critic_lr)
self.memory = PPOMemory(cfg.batch_size)
self.memory = memory
self.loss = 0
def choose_action(self, state):
def sample_action(self, state):
state = np.array([state]) # convert to an array first, then to a tensor (more efficient)
state = torch.tensor(state, dtype=torch.float).to(self.device)
probs = self.actor(state)
dist = Categorical(probs)
value = self.critic(state)
action = dist.sample()
probs = torch.squeeze(dist.log_prob(action)).item()
if self.continuous:
action = torch.tanh(action)
else:
action = torch.squeeze(action).item()
value = torch.squeeze(value).item()
return action, probs, value
@torch.no_grad()
def predict_action(self, state):
state = np.array([state]) # convert to an array first, then to a tensor (more efficient)
state = torch.tensor(state, dtype=torch.float).to(self.device)
dist = self.actor(state)
@@ -148,12 +102,15 @@ class PPO:
self.actor_optimizer.step()
self.critic_optimizer.step()
self.memory.clear()
def save(self,path):
def save_model(self,path):
from pathlib import Path
# create path
Path(path).mkdir(parents=True, exist_ok=True)
actor_checkpoint = os.path.join(path, 'ppo_actor.pt')
critic_checkpoint= os.path.join(path, 'ppo_critic.pt')
torch.save(self.actor.state_dict(), actor_checkpoint)
torch.save(self.critic.state_dict(), critic_checkpoint)
def load(self,path):
def load_model(self,path):
actor_checkpoint = os.path.join(path, 'ppo_actor.pt')
critic_checkpoint= os.path.join(path, 'ppo_critic.pt')
self.actor.load_state_dict(torch.load(actor_checkpoint))
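
The body of update() is elided by the hunk above (@@ -148,12 +102,15 @@). Purely as a sketch of what a PPO-clip update with GAE typically looks like given the hyperparameters used here (gamma, gae_lambda, policy_clip, n_epochs), and assuming a memory exposing the PPOMemory-style sample()/clear() shown earlier and an actor returning action probabilities, one possible implementation is (not this repo's exact code):

import numpy as np
import torch
from torch.distributions import Categorical

def ppo_update_sketch(agent):
    # illustrative only: GAE advantages followed by several epochs of
    # clipped-surrogate updates over shuffled mini-batches
    states, actions, old_log_probs, values, rewards, dones, batches = agent.memory.sample()
    advantages = np.zeros(len(rewards), dtype=np.float32)
    gae = 0.0
    for t in reversed(range(len(rewards) - 1)):
        mask = 1.0 - float(dones[t])
        delta = rewards[t] + agent.gamma * values[t + 1] * mask - values[t]
        gae = delta + agent.gamma * agent.gae_lambda * mask * gae
        advantages[t] = gae
    advantages = torch.tensor(advantages).to(agent.device)
    values = torch.tensor(values, dtype=torch.float).to(agent.device)
    for _ in range(agent.n_epochs):
        for batch in batches:
            batch_states = torch.tensor(states[batch], dtype=torch.float).to(agent.device)
            batch_old_log_probs = torch.tensor(old_log_probs[batch]).to(agent.device)
            batch_actions = torch.tensor(actions[batch]).to(agent.device)
            dist = Categorical(agent.actor(batch_states))
            new_log_probs = dist.log_prob(batch_actions)
            ratio = (new_log_probs - batch_old_log_probs).exp()
            surr1 = ratio * advantages[batch]
            surr2 = torch.clamp(ratio, 1 - agent.policy_clip, 1 + agent.policy_clip) * advantages[batch]
            actor_loss = -torch.min(surr1, surr2).mean()
            returns = advantages[batch] + values[batch]
            critic_loss = ((returns - agent.critic(batch_states).squeeze(-1)) ** 2).mean()
            agent.actor_optimizer.zero_grad()
            agent.critic_optimizer.zero_grad()
            (actor_loss + 0.5 * critic_loss).backward()
            agent.actor_optimizer.step()
            agent.critic_optimizer.step()
    agent.memory.clear()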

View File

@@ -1,132 +1,159 @@
import sys,os
curr_path = os.path.dirname(os.path.abspath(__file__)) # absolute path of the current file
parent_path = os.path.dirname(curr_path) # parent path
sys.path.append(parent_path) # add path to system path
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE" # avoid "OMP: Error #15: Initializing libiomp5md.dll, but found libiomp5md.dll already initialized."
curr_path = os.path.dirname(os.path.abspath(__file__)) # current path
parent_path = os.path.dirname(curr_path) # parent path
sys.path.append(parent_path) # add path to system path
import gym
import torch
import numpy as np
import datetime
import numpy as np
import argparse
from common.utils import plot_rewards,save_args,save_results,make_dir
import torch.nn as nn
from common.utils import all_seed,merge_class_attrs
from common.models import ActorSoftmax, Critic
from common.memories import PGReplay
from common.launcher import Launcher
from envs.register import register_env
from ppo2 import PPO
from config.config import GeneralConfigPPO,AlgoConfigPPO
class PPOMemory:
def __init__(self, batch_size):
self.states = []
self.probs = []
self.vals = []
self.actions = []
self.rewards = []
self.terminateds = []
self.batch_size = batch_size
def sample(self):
batch_step = np.arange(0, len(self.states), self.batch_size)
indices = np.arange(len(self.states), dtype=np.int64)
np.random.shuffle(indices)
batches = [indices[i:i+self.batch_size] for i in batch_step]
return np.array(self.states),np.array(self.actions),np.array(self.probs),\
np.array(self.vals),np.array(self.rewards),np.array(self.terminateds),batches
def push(self, state, action, probs, vals, reward, terminated):
self.states.append(state)
self.actions.append(action)
self.probs.append(probs)
self.vals.append(vals)
self.rewards.append(reward)
self.terminateds.append(terminated)
def get_args():
""" Hyperparameters
"""
curr_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") # get current time
parser = argparse.ArgumentParser(description="hyperparameters")
parser.add_argument('--algo_name',default='PPO',type=str,help="name of algorithm")
parser.add_argument('--env_name',default='CartPole-v0',type=str,help="name of environment")
parser.add_argument('--continuous',default=False,type=bool,help="if PPO is continuous") # PPO works for both continuous and discrete action spaces
parser.add_argument('--train_eps',default=200,type=int,help="episodes of training")
parser.add_argument('--test_eps',default=20,type=int,help="episodes of testing")
parser.add_argument('--gamma',default=0.99,type=float,help="discounted factor")
parser.add_argument('--batch_size',default=5,type=int) # batch size for mini-batch SGD
parser.add_argument('--n_epochs',default=4,type=int)
parser.add_argument('--actor_lr',default=0.0003,type=float,help="learning rate of actor net")
parser.add_argument('--critic_lr',default=0.0003,type=float,help="learning rate of critic net")
parser.add_argument('--gae_lambda',default=0.95,type=float)
parser.add_argument('--policy_clip',default=0.2,type=float) # clip parameter of PPO-clip, usually around 0.1~0.2
parser.add_argument('--update_fre',default=20,type=int)
parser.add_argument('--hidden_dim',default=256,type=int)
parser.add_argument('--device',default='cpu',type=str,help="cpu or cuda")
parser.add_argument('--result_path',default=curr_path + "/outputs/" + parser.parse_args().env_name + \
'/' + curr_time + '/results/' )
parser.add_argument('--model_path',default=curr_path + "/outputs/" + parser.parse_args().env_name + \
'/' + curr_time + '/models/' ) # path to save models
parser.add_argument('--save_fig',default=True,type=bool,help="if save figure or not")
args = parser.parse_args()
return args
def env_agent_config(cfg,seed = 1):
''' create env and agent
'''
env = gym.make(cfg.env_name) # create env
n_states = env.observation_space.shape[0] # state dimension
if cfg.continuous:
n_actions = env.action_space.shape[0] # action dimension
else:
n_actions = env.action_space.n # action dimension
agent = PPO(n_states, n_actions, cfg) # create agent
if seed !=0: # set random seed
torch.manual_seed(seed)
env.seed(seed)
np.random.seed(seed)
return env, agent
def clear(self):
self.states = []
self.probs = []
self.actions = []
self.rewards = []
self.terminateds = []
self.vals = []
def train(cfg,env,agent):
print('Start training!')
print(f'Env: {cfg.env_name}, Algorithm: {cfg.algo_name}, Device: {cfg.device}')
rewards = [] # record rewards for all episodes
ma_rewards = [] # record moving-average rewards for all episodes
steps = 0
for i_ep in range(cfg.train_eps):
class Main(Launcher):
def __init__(self) -> None:
super().__init__()
self.cfgs['general_cfg'] = merge_class_attrs(self.cfgs['general_cfg'],GeneralConfigPPO())
self.cfgs['algo_cfg'] = merge_class_attrs(self.cfgs['algo_cfg'],AlgoConfigPPO())
def env_agent_config(self,cfg,logger):
''' create env and agent
'''
register_env(cfg.env_name)
env = gym.make(cfg.env_name,new_step_api=False) # create env
if cfg.seed !=0: # set random seed
all_seed(env,seed=cfg.seed)
try: # state dimension
n_states = env.observation_space.n # discrete observation space
except AttributeError:
n_states = env.observation_space.shape[0] # continuous (Box) observation space
n_actions = env.action_space.n # action dimension
logger.info(f"n_states: {n_states}, n_actions: {n_actions}") # print info
# update to cfg paramters
setattr(cfg, 'n_states', n_states)
setattr(cfg, 'n_actions', n_actions)
models = {'Actor':ActorSoftmax(n_states,n_actions, hidden_dim = cfg.actor_hidden_dim),'Critic':Critic(n_states,1,hidden_dim=cfg.critic_hidden_dim)}
memory = PGReplay # replay buffer
agent = PPO(models,memory,cfg) # create agent
return env, agent
def train_one_episode(self, env, agent, cfg):
ep_reward = 0 # reward per episode
ep_step = 0 # step per episode
state = env.reset()
done = False
ep_reward = 0
while not done:
action, prob, val = agent.choose_action(state)
state_, reward, done, _ = env.step(action)
steps += 1
for _ in range(cfg.max_steps):
action, prob, val = agent.sample_action(state)
next_state, reward, terminated, _ = env.step(action)
ep_reward += reward
agent.memory.push(state, action, prob, val, reward, done)
if steps % cfg.update_fre == 0:
ep_step += 1
agent.memory.push((state, action, prob, val, reward, terminated))
if ep_step % cfg['update_fre'] == 0:
agent.update()
state = state_
rewards.append(ep_reward)
if ma_rewards:
ma_rewards.append(0.9*ma_rewards[-1]+0.1*ep_reward)
else:
ma_rewards.append(ep_reward)
if (i_ep+1)%10 == 0:
print(f"Episode: {i_ep+1}/{cfg.train_eps}, Reward: {ep_reward:.2f}")
print('Finish training!')
env.close()
res_dic = {'rewards':rewards,'ma_rewards':ma_rewards}
return res_dic
def test(cfg,env,agent):
print('Start testing!')
print(f'Env: {cfg.env_name}, Algorithm: {cfg.algo_name}, Device: {cfg.device}')
rewards = [] # record rewards for all episodes
ma_rewards = [] # record moving-average rewards for all episodes
for i_ep in range(cfg.test_eps):
state = next_state
if terminated:
break
return agent, ep_reward, ep_step
def test_one_episode(self, env, agent, cfg):
ep_reward = 0 # reward per episode
ep_step = 0 # step per episode
state = env.reset()
done = False
ep_reward = 0
while not done:
action, prob, val = agent.choose_action(state)
state_, reward, done, _ = env.step(action)
for _ in range(cfg.max_steps):
action, prob, val = agent.sample_action(state)
next_state, reward, terminated, _ = env.step(action)
ep_reward += reward
state = state_
rewards.append(ep_reward)
if ma_rewards:
ma_rewards.append(
0.9*ma_rewards[-1]+0.1*ep_reward)
else:
ma_rewards.append(ep_reward)
print('Episode: {}/{}, Reward: {}'.format(i_ep+1, cfg.test_eps, ep_reward))
print('Finish testing!')
env.close()
res_dic = {'rewards':rewards,'ma_rewards':ma_rewards}
return res_dic
ep_step += 1
state = next_state
if terminated:
break
return agent, ep_reward, ep_step
def train(self,cfg,env,agent):
''' train agent
'''
print("Start training!")
print(f"Env: {cfg['env_name']}, Algorithm: {cfg['algo_name']}, Device: {cfg['device']}")
rewards = [] # record rewards for all episodes
steps = 0
for i_ep in range(cfg['train_eps']):
state = env.reset()
ep_reward = 0
while True:
action, prob, val = agent.sample_action(state)
next_state, reward, terminated, _ = env.step(action)
steps += 1
ep_reward += reward
agent.memory.push(state, action, prob, val, reward, terminated)
if steps % cfg['update_fre'] == 0:
agent.update()
state = next_state
if terminated:
break
rewards.append(ep_reward)
if (i_ep+1)%10==0:
print(f"Episode: {i_ep+1}/{cfg['train_eps']}, Reward: {ep_reward:.2f}")
print("Finish training!")
return {'episodes':range(len(rewards)),'rewards':rewards}
def test(self,cfg,env,agent):
''' test agent
'''
print("Start testing!")
print(f"Env: {cfg['env_name']}, Algorithm: {cfg['algo_name']}, Device: {cfg['device']}")
rewards = [] # record rewards for all episodes
for i_ep in range(cfg['test_eps']):
state = env.reset()
ep_reward = 0
while True:
action, prob, val = agent.predict_action(state)
next_state, reward, terminated, _ = env.step(action)
ep_reward += reward
state = next_state
if terminated:
break
rewards.append(ep_reward)
print(f"Episode: {i_ep+1}/{cfg['test_eps']}, Reward: {ep_reward:.2f}")
print("Finish testing!")
return {'episodes':range(len(rewards)),'rewards':rewards}
if __name__ == "__main__":
cfg = get_args()
# train
env, agent = env_agent_config(cfg)
res_dic = train(cfg, env, agent)
make_dir(cfg.result_path, cfg.model_path)
save_args(cfg) # save parameters
agent.save(path=cfg.model_path) # save model
save_results(res_dic, tag='train',
path=cfg.result_path)
plot_rewards(res_dic['rewards'], res_dic['ma_rewards'], cfg, tag="train")
# test
env, agent = env_agent_config(cfg)
agent.load(path=cfg.model_path) # load model
res_dic = test(cfg, env, agent)
save_results(res_dic, tag='test',
path=cfg.result_path) # save results
plot_rewards(res_dic['rewards'], res_dic['ma_rewards'],cfg, tag="test") # plot results
main = Main()
main.run()
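
Launcher comes from common.launcher and is not part of this diff; the template now only overrides environment/agent construction and the per-episode rollout hooks, while the base class drives the loops. A hypothetical skeleton of that flow (everything except the overridden hook names is invented for illustration):

import logging

class LauncherSketch:
    # hypothetical skeleton only; the real Launcher presumably also merges the
    # general/algo configs, sets up logging, and saves models, results and figures
    def run(self):
        logger = logging.getLogger("launcher")
        cfg = self.cfgs['general_cfg']  # in practice likely merged with self.cfgs['algo_cfg']
        env, agent = self.env_agent_config(cfg, logger)  # hook overridden by Main
        for i_ep in range(cfg.train_eps):
            agent, ep_reward, ep_step = self.train_one_episode(env, agent, cfg)
            logger.info(f"Episode {i_ep + 1}/{cfg.train_eps}, reward {ep_reward:.2f}, steps {ep_step}")
        for i_ep in range(cfg.test_eps):
            agent, ep_reward, ep_step = self.test_one_episode(env, agent, cfg)
            logger.info(f"Test episode {i_ep + 1}/{cfg.test_eps}, reward {ep_reward:.2f}")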

View File

@@ -1,3 +1,13 @@
#!/usr/bin/env python
# coding=utf-8
'''
Author: JiangJi
Email: johnjim0816@gmail.com
Date: 2022-09-19 14:48:16
LastEditor: JiangJi
LastEditTime: 2022-10-30 00:45:14
Description:
'''
import sys,os
curr_path = os.path.dirname(os.path.abspath(__file__)) # absolute path of the current file
parent_path = os.path.dirname(curr_path) # parent path