hot update A2C

This commit is contained in:
johnjim0816
2022-08-29 15:12:33 +08:00
parent 99a3c1afec
commit 0b0f7e857d
109 changed files with 8213 additions and 1658 deletions

View File

@@ -5,7 +5,7 @@ Author: John
Email: johnjim0816@gmail.com
Date: 2021-03-11 17:59:16
LastEditor: John
LastEditTime: 2022-08-25 14:26:36
LastEditTime: 2022-08-26 23:03:39
Description:
Environment:
'''
@@ -20,117 +20,105 @@ import argparse
from envs.register import register_env
from envs.wrappers import CliffWalkingWapper
from Sarsa.sarsa import Sarsa
from common.utils import save_results,make_dir,plot_rewards,save_args,all_seed
from common.utils import all_seed
from common.launcher import Launcher
def get_args():
curr_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") # obtain current time
parser = argparse.ArgumentParser(description="hyperparameters")
parser.add_argument('--algo_name',default='Sarsa',type=str,help="name of algorithm")
parser.add_argument('--env_name',default='Racetrack-v0',type=str,help="name of environment")
parser.add_argument('--train_eps',default=300,type=int,help="episodes of training")
parser.add_argument('--test_eps',default=20,type=int,help="episodes of testing")
parser.add_argument('--gamma',default=0.99,type=float,help="discounted factor")
parser.add_argument('--epsilon_start',default=0.90,type=float,help="initial value of epsilon")
parser.add_argument('--epsilon_end',default=0.01,type=float,help="final value of epsilon")
parser.add_argument('--epsilon_decay',default=200,type=int,help="decay rate of epsilon")
parser.add_argument('--lr',default=0.2,type=float,help="learning rate")
parser.add_argument('--device',default='cpu',type=str,help="cpu or cuda")
parser.add_argument('--seed',default=10,type=int,help="seed")
parser.add_argument('--show_fig',default=False,type=bool,help="if show figure or not")
parser.add_argument('--save_fig',default=True,type=bool,help="if save figure or not")
args = parser.parse_args()
default_args = {'result_path':f"{curr_path}/outputs/{args.env_name}/{curr_time}/results/",
'model_path':f"{curr_path}/outputs/{args.env_name}/{curr_time}/models/",
}
args = {**vars(args),**default_args} # type(dict)
return args
class Main(Launcher):
def get_args(self):
curr_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") # obtain current time
parser = argparse.ArgumentParser(description="hyperparameters")
parser.add_argument('--algo_name',default = 'Sarsa',type=str,help="name of algorithm")
parser.add_argument('--env_name',default = 'Racetrack-v0',type=str,help="name of environment")
parser.add_argument('--train_eps',default = 300,type=int,help="episodes of training")
parser.add_argument('--test_eps',default = 20,type=int,help="episodes of testing")
parser.add_argument('--ep_max_steps',default = 100000,type=int,help="max steps per episode; a large value approximates an infinite horizon")
parser.add_argument('--gamma',default=0.99,type=float,help="discounted factor")
parser.add_argument('--epsilon_start',default=0.90,type=float,help="initial value of epsilon")
parser.add_argument('--epsilon_end',default=0.01,type=float,help="final value of epsilon")
parser.add_argument('--epsilon_decay',default=200,type=int,help="decay rate of epsilon")
parser.add_argument('--lr',default=0.2,type=float,help="learning rate")
parser.add_argument('--device',default='cpu',type=str,help="cpu or cuda")
parser.add_argument('--seed',default=10,type=int,help="seed")
parser.add_argument('--show_fig',default=False,type=bool,help="if show figure or not")
parser.add_argument('--save_fig',default=True,type=bool,help="if save figure or not")
args = parser.parse_args()
default_args = {'result_path':f"{curr_path}/outputs/{args.env_name}/{curr_time}/results/",
'model_path':f"{curr_path}/outputs/{args.env_name}/{curr_time}/models/",
}
args = {**vars(args),**default_args} # type(dict)
return args
def env_agent_config(cfg):
register_env(cfg['env_name'])
env = gym.make(cfg['env_name'])
if cfg['seed'] !=0: # set random seed
all_seed(env,seed= cfg['seed'])
if cfg['env_name'] == 'CliffWalking-v0':
env = CliffWalkingWapper(env)
try: # state dimension
n_states = env.observation_space.n # print(hasattr(env.observation_space, 'n'))
except AttributeError:
n_states = env.observation_space.shape[0] # print(hasattr(env.observation_space, 'shape'))
n_actions = env.action_space.n # action dimension
print(f"n_states: {n_states}, n_actions: {n_actions}")
cfg.update({"n_states":n_states,"n_actions":n_actions}) # update to cfg parameters
agent = Sarsa(cfg)
return env,agent
def env_agent_config(self,cfg):
register_env(cfg['env_name'])
env = gym.make(cfg['env_name'])
if cfg['seed'] !=0: # set random seed
all_seed(env,seed= cfg['seed'])
if cfg['env_name'] == 'CliffWalking-v0':
env = CliffWalkingWapper(env)
try: # state dimension
n_states = env.observation_space.n # print(hasattr(env.observation_space, 'n'))
except AttributeError:
n_states = env.observation_space.shape[0] # print(hasattr(env.observation_space, 'shape'))
n_actions = env.action_space.n # action dimension
print(f"n_states: {n_states}, n_actions: {n_actions}")
cfg.update({"n_states":n_states,"n_actions":n_actions}) # update to cfg parameters
agent = Sarsa(cfg)
return env,agent
def train(cfg,env,agent):
print("Start training!")
print(f"Env: {cfg['env_name']}, Algorithm: {cfg['algo_name']}, Device: {cfg['device']}")
rewards = [] # record rewards for all episodes
steps = [] # record steps for all episodes
for i_ep in range(cfg['train_eps']):
ep_reward = 0 # reward per episode
ep_step = 0 # step per episode
state = env.reset() # reset and obtain initial state
action = agent.sample_action(state)
while True:
# for _ in range(cfg.ep_max_steps):
next_state, reward, done, _ = env.step(action) # update env and return transitions
next_action = agent.sample_action(next_state)
agent.update(state, action, reward, next_state, next_action,done) # update agent
state = next_state # update state
action = next_action
ep_reward += reward
ep_step += 1
if done:
break
rewards.append(ep_reward)
steps.append(ep_step)
if (i_ep+1)%10==0:
print(f'Episode: {i_ep+1}/{cfg["train_eps"]}, Reward: {ep_reward:.2f}, Steps: {ep_step}, Epsilon: {agent.epsilon:.3f}')
print("Finish training!")
return {'episodes':range(len(rewards)),'rewards':rewards,'steps':steps}
def train(self,cfg,env,agent):
print("Start training!")
print(f"Env: {cfg['env_name']}, Algorithm: {cfg['algo_name']}, Device: {cfg['device']}")
rewards = [] # record rewards for all episodes
steps = [] # record steps for all episodes
for i_ep in range(cfg['train_eps']):
ep_reward = 0 # reward per episode
ep_step = 0 # step per episode
state = env.reset() # reset and obtain initial state
action = agent.sample_action(state)
# while True:
for _ in range(cfg['ep_max_steps']):
next_state, reward, done, _ = env.step(action) # update env and return transitions
next_action = agent.sample_action(next_state)
agent.update(state, action, reward, next_state, next_action,done) # update agent
state = next_state # update state
action = next_action
ep_reward += reward
ep_step += 1
if done:
break
rewards.append(ep_reward)
steps.append(ep_step)
if (i_ep+1)%10==0:
print(f'Episode: {i_ep+1}/{cfg["train_eps"]}, Reward: {ep_reward:.2f}, Steps: {ep_step}, Epsilon: {agent.epsilon:.3f}')
print("Finish training!")
return {'episodes':range(len(rewards)),'rewards':rewards,'steps':steps}
def test(cfg,env,agent):
print("Start testing!")
print(f"Env: {cfg['env_name']}, Algorithm: {cfg['algo_name']}, Device: {cfg['device']}")
rewards = [] # record rewards for all episodes
steps = [] # record steps for all episodes
for i_ep in range(cfg['test_eps']):
ep_reward = 0 # reward per episode
ep_step = 0
while True:
# for _ in range(cfg.ep_max_steps):
action = agent.predict_action(state)
next_state, reward, done = env.step(action)
state = next_state
ep_reward+=reward
ep_step+=1
if done:
break
rewards.append(ep_reward)
steps.append(ep_step)
print(f"Episode: {i_ep+1}/{cfg['test_eps']}, Steps:{ep_step}, Reward: {ep_reward:.2f}")
print("Finish testing!")
return {'episodes':range(len(rewards)),'rewards':rewards,'steps':steps}
def test(self,cfg,env,agent):
print("Start testing!")
print(f"Env: {cfg['env_name']}, Algorithm: {cfg['algo_name']}, Device: {cfg['device']}")
rewards = [] # record rewards for all episodes
steps = [] # record steps for all episodes
for i_ep in range(cfg['test_eps']):
ep_reward = 0 # reward per episode
ep_step = 0
state = env.reset() # reset and obtain initial state
for _ in range(cfg['ep_max_steps']):
action = agent.predict_action(state)
next_state, reward, done, _ = env.step(action)
state = next_state
ep_reward+=reward
ep_step+=1
if done:
break
rewards.append(ep_reward)
steps.append(ep_step)
print(f"Episode: {i_ep+1}/{cfg['test_eps']}, Steps: {ep_step}, Reward: {ep_reward:.2f}")
print("Finish testing!")
return {'episodes':range(len(rewards)),'rewards':rewards,'steps':steps}
if __name__ == "__main__":
cfg = get_args()
# train
env, agent = env_agent_config(cfg)
res_dic = train(cfg, env, agent)
make_dir(cfg.result_path, cfg.model_path)
save_args(cfg) # save parameters
agent.save(path=cfg.model_path) # save model
save_results(res_dic, tag='train',
path=cfg.result_path)
plot_rewards(res_dic['rewards'], cfg, tag="train")
# test
env, agent = env_agent_config(cfg)
agent.load(path=cfg.model_path) # load model
res_dic = test(cfg, env, agent)
save_results(res_dic, tag='test',
path=cfg.result_path) # save results
plot_rewards(res_dic['rewards'], cfg, tag="test") # plot results
main = Main()
main.run()
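The old script-level __main__ block is replaced by this Launcher-based entry point. common.launcher itself is not part of this diff, so the following is only a minimal sketch of what its run() presumably does, assuming it chains the overridden methods and delegates saving/loading to the agent (class name and method order are assumptions, not the repo's actual code):

# Hypothetical sketch of common.launcher.Launcher; the real class is not shown in this commit.
class Launcher:
    def get_args(self):
        raise NotImplementedError
    def env_agent_config(self, cfg):
        raise NotImplementedError
    def train(self, cfg, env, agent):
        raise NotImplementedError
    def test(self, cfg, env, agent):
        raise NotImplementedError
    def run(self):
        cfg = self.get_args()
        # training phase
        env, agent = self.env_agent_config(cfg)
        train_res = self.train(cfg, env, agent)
        agent.save_model(path=cfg['model_path'])   # assumption: results/plots handled similarly
        # testing phase
        env, agent = self.env_agent_config(cfg)
        agent.load_model(path=cfg['model_path'])
        test_res = self.test(cfg, env, agent)
        return train_res, test_res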

View File

@@ -1 +0,0 @@
{"algo_name": "Sarsa", "env_name": "CliffWalking-v0", "train_eps": 300, "test_eps": 20, "ep_max_steps": 200, "gamma": 0.99, "epsilon_start": 0.9, "epsilon_end": 0.01, "epsilon_decay": 200, "lr": 0.2, "device": "cpu", "result_path": "/Users/jj/Desktop/rl-tutorials/codes/Sarsa/outputs/CliffWalking-v0/20220803-142740/results/", "model_path": "/Users/jj/Desktop/rl-tutorials/codes/Sarsa/outputs/CliffWalking-v0/20220803-142740/models/", "save_fig": true}

Two binary image files removed (34 KiB and 54 KiB); previews not shown.

View File

@@ -1,15 +0,0 @@
{
"algo_name": "Sarsa",
"env_name": "CliffWalking-v0",
"train_eps": 400,
"test_eps": 20,
"gamma": 0.9,
"epsilon_start": 0.95,
"epsilon_end": 0.01,
"epsilon_decay": 300,
"lr": 0.1,
"device": "cpu",
"result_path": "c:\\Users\\24438\\Desktop\\rl-tutorials\\codes\\Sarsa/outputs/CliffWalking-v0/20220804-223029/results/",
"model_path": "c:\\Users\\24438\\Desktop\\rl-tutorials\\codes\\Sarsa/outputs/CliffWalking-v0/20220804-223029/models/",
"save_fig": true
}

Binary image file removed (34 KiB); preview not shown.

View File

@@ -0,0 +1,19 @@
{
"algo_name": "Sarsa",
"env_name": "CliffWalking-v0",
"train_eps": 400,
"test_eps": 20,
"gamma": 0.9,
"epsilon_start": 0.95,
"epsilon_end": 0.01,
"epsilon_decay": 300,
"lr": 0.1,
"device": "cpu",
"seed": 10,
"show_fig": false,
"save_fig": true,
"result_path": "/Users/jj/Desktop/rl-tutorials/codes/Sarsa/outputs/CliffWalking-v0/20220825-213316/results/",
"model_path": "/Users/jj/Desktop/rl-tutorials/codes/Sarsa/outputs/CliffWalking-v0/20220825-213316/models/",
"n_states": 48,
"n_actions": 4
}

View File

@@ -0,0 +1,21 @@
episodes,rewards,steps
0,-15,15
1,-15,15
2,-15,15
3,-15,15
4,-15,15
5,-15,15
6,-15,15
7,-15,15
8,-15,15
9,-15,15
10,-15,15
11,-15,15
12,-15,15
13,-15,15
14,-15,15
15,-15,15
16,-15,15
17,-15,15
18,-15,15
19,-15,15

Binary image file added (33 KiB); preview not shown.

View File

@@ -0,0 +1,401 @@
episodes,rewards,steps
0,-649,154
1,-2822,842
2,-176,176
3,-139,139
4,-221,221
5,-51,51
6,-219,219
7,-247,148
8,-90,90
9,-145,145
10,-104,104
11,-162,162
12,-49,49
13,-129,129
14,-140,140
15,-19,19
16,-131,131
17,-115,115
18,-43,43
19,-133,133
20,-73,73
21,-89,89
22,-131,131
23,-61,61
24,-113,113
25,-119,119
26,-119,119
27,-71,71
28,-132,132
29,-47,47
30,-79,79
31,-57,57
32,-125,125
33,-77,77
34,-87,87
35,-49,49
36,-57,57
37,-81,81
38,-81,81
39,-97,97
40,-61,61
41,-85,85
42,-217,118
43,-39,39
44,-117,117
45,-41,41
46,-71,71
47,-105,105
48,-73,73
49,-68,68
50,-95,95
51,-41,41
52,-41,41
53,-67,67
54,-71,71
55,-65,65
56,-41,41
57,-61,61
58,-81,81
59,-21,21
60,-76,76
61,-80,80
62,-23,23
63,-53,53
64,-67,67
65,-33,33
66,-41,41
67,-59,59
68,-33,33
69,-64,64
70,-188,89
71,-47,47
72,-57,57
73,-45,45
74,-33,33
75,-79,79
76,-45,45
77,-23,23
78,-47,47
79,-57,57
80,-47,47
81,-45,45
82,-53,53
83,-29,29
84,-33,33
85,-69,69
86,-61,61
87,-35,35
88,-59,59
89,-43,43
90,-17,17
91,-39,39
92,-59,59
93,-29,29
94,-31,31
95,-55,55
96,-35,35
97,-45,45
98,-29,29
99,-59,59
100,-25,25
101,-29,29
102,-33,33
103,-39,39
104,-19,19
105,-47,47
106,-57,57
107,-19,19
108,-47,47
109,-25,25
110,-23,23
111,-53,53
112,-39,39
113,-34,34
114,-27,27
115,-27,27
116,-63,63
117,-33,33
118,-17,17
119,-21,21
120,-19,19
121,-49,49
122,-25,25
123,-39,39
124,-25,25
125,-167,68
126,-35,35
127,-29,29
128,-31,31
129,-44,44
130,-33,33
131,-23,23
132,-37,37
133,-134,35
134,-31,31
135,-19,19
136,-29,29
137,-37,37
138,-25,25
139,-39,39
140,-47,47
141,-29,29
142,-27,27
143,-21,21
144,-41,41
145,-29,29
146,-25,25
147,-25,25
148,-21,21
149,-29,29
150,-39,39
151,-35,35
152,-35,35
153,-32,32
154,-31,31
155,-19,19
156,-21,21
157,-35,35
158,-33,33
159,-37,37
160,-25,25
161,-41,41
162,-25,25
163,-23,23
164,-27,27
165,-25,25
166,-39,39
167,-28,28
168,-24,24
169,-23,23
170,-41,41
171,-17,17
172,-35,35
173,-23,23
174,-29,29
175,-17,17
176,-39,39
177,-33,33
178,-29,29
179,-24,24
180,-23,23
181,-19,19
182,-15,15
183,-23,23
184,-39,39
185,-25,25
186,-35,35
187,-33,33
188,-19,19
189,-35,35
190,-21,21
191,-131,32
192,-15,15
193,-23,23
194,-21,21
195,-17,17
196,-23,23
197,-31,31
198,-21,21
199,-31,31
200,-35,35
201,-27,27
202,-19,19
203,-21,21
204,-23,23
205,-23,23
206,-21,21
207,-31,31
208,-25,25
209,-23,23
210,-17,17
211,-19,19
212,-25,25
213,-23,23
214,-19,19
215,-19,19
216,-25,25
217,-25,25
218,-25,25
219,-25,25
220,-23,23
221,-19,19
222,-19,19
223,-149,50
224,-41,41
225,-19,19
226,-29,29
227,-37,37
228,-17,17
229,-17,17
230,-19,19
231,-27,27
232,-19,19
233,-33,33
234,-23,23
235,-23,23
236,-34,34
237,-15,15
238,-33,33
239,-29,29
240,-17,17
241,-23,23
242,-17,17
243,-19,19
244,-21,21
245,-23,23
246,-17,17
247,-15,15
248,-39,39
249,-21,21
250,-23,23
251,-29,29
252,-15,15
253,-17,17
254,-29,29
255,-15,15
256,-21,21
257,-19,19
258,-19,19
259,-21,21
260,-17,17
261,-21,21
262,-27,27
263,-27,27
264,-21,21
265,-19,19
266,-17,17
267,-23,23
268,-19,19
269,-17,17
270,-19,19
271,-19,19
272,-17,17
273,-23,23
274,-17,17
275,-22,22
276,-31,31
277,-19,19
278,-17,17
279,-33,33
280,-19,19
281,-17,17
282,-31,31
283,-15,15
284,-15,15
285,-15,15
286,-29,29
287,-19,19
288,-17,17
289,-26,26
290,-17,17
291,-19,19
292,-15,15
293,-21,21
294,-21,21
295,-15,15
296,-19,19
297,-15,15
298,-17,17
299,-19,19
300,-17,17
301,-21,21
302,-17,17
303,-27,27
304,-17,17
305,-19,19
306,-15,15
307,-19,19
308,-33,33
309,-17,17
310,-20,20
311,-19,19
312,-17,17
313,-15,15
314,-23,23
315,-15,15
316,-15,15
317,-17,17
318,-25,25
319,-15,15
320,-17,17
321,-19,19
322,-17,17
323,-15,15
324,-23,23
325,-19,19
326,-17,17
327,-23,23
328,-15,15
329,-19,19
330,-15,15
331,-17,17
332,-19,19
333,-15,15
334,-17,17
335,-17,17
336,-19,19
337,-15,15
338,-19,19
339,-19,19
340,-17,17
341,-15,15
342,-21,21
343,-19,19
344,-17,17
345,-17,17
346,-15,15
347,-21,21
348,-20,20
349,-15,15
350,-15,15
351,-15,15
352,-19,19
353,-17,17
354,-15,15
355,-27,27
356,-15,15
357,-15,15
358,-23,23
359,-125,26
360,-132,33
361,-17,17
362,-15,15
363,-17,17
364,-23,23
365,-17,17
366,-15,15
367,-15,15
368,-17,17
369,-15,15
370,-17,17
371,-15,15
372,-15,15
373,-15,15
374,-15,15
375,-15,15
376,-15,15
377,-15,15
378,-15,15
379,-15,15
380,-17,17
381,-15,15
382,-15,15
383,-19,19
384,-15,15
385,-17,17
386,-27,27
387,-15,15
388,-21,21
389,-125,26
390,-15,15
391,-15,15
392,-15,15
393,-27,27
394,-15,15
395,-15,15
396,-17,17
397,-15,15
398,-15,15
399,-15,15

View File

@@ -0,0 +1 @@
{"algo_name": "Sarsa", "env_name": "Racetrack-v0", "train_eps": 300, "test_eps": 20, "gamma": 0.99, "epsilon_start": 0.9, "epsilon_end": 0.01, "epsilon_decay": 200, "lr": 0.2, "device": "cpu", "seed": 10, "show_fig": false, "save_fig": true, "result_path": "/Users/jj/Desktop/rl-tutorials/codes/Sarsa/outputs/Racetrack-v0/20220825-212738/results/", "model_path": "/Users/jj/Desktop/rl-tutorials/codes/Sarsa/outputs/Racetrack-v0/20220825-212738/models/", "n_states": 4, "n_actions": 9}

Binary image file added (39 KiB); preview not shown.

View File

@@ -0,0 +1,21 @@
episodes,rewards,steps
0,4,6
1,4,6
2,-1010,1000
3,-14,14
4,4,6
5,4,6
6,4,6
7,-1060,1000
8,2,8
9,-12,12
10,3,7
11,-15,15
12,3,7
13,4,6
14,-14,14
15,3,7
16,-18,18
17,4,6
18,4,6
19,-1020,1000

Binary image file added (41 KiB); preview not shown.

View File

@@ -0,0 +1,301 @@
episodes,rewards,steps
0,-3460,1000
1,-2800,1000
2,-2910,1000
3,-2620,1000
4,-2620,1000
5,-2590,1000
6,-2390,1000
7,-2510,1000
8,-2470,1000
9,-611,251
10,-891,371
11,-265,125
12,-2281,911
13,-1203,523
14,-616,266
15,-213,113
16,-633,273
17,-1112,482
18,-350,160
19,-852,342
20,-87,47
21,-11,11
22,-27,17
23,-117,57
24,-15,15
25,4,6
26,-27,17
27,-94,44
28,-184,84
29,-44,24
30,-150,80
31,-14,14
32,-219,89
33,-50,30
34,-111,61
35,-10,10
36,-28,18
37,-34,24
38,-12,12
39,-19,19
40,-136,66
41,-171,71
42,-51,31
43,4,6
44,-117,57
45,4,6
46,4,6
47,-127,67
48,-78,48
49,-311,131
50,-25,15
51,4,6
52,-49,29
53,-25,15
54,-78,48
55,-238,108
56,4,6
57,-17,17
58,-29,19
59,-218,98
60,4,6
61,-129,59
62,-344,144
63,-25,15
64,-15,15
65,-77,37
66,2,8
67,0,10
68,4,6
69,4,6
70,-242,102
71,3,7
72,4,6
73,-53,33
74,-14,14
75,4,6
76,4,6
77,-30,20
78,-12,12
79,2,8
80,-12,12
81,-150,70
82,-48,28
83,-102,52
84,4,6
85,-97,47
86,-10,10
87,-125,55
88,-28,18
89,-26,16
90,-107,57
91,4,6
92,-16,16
93,-84,44
94,-13,13
95,-43,23
96,-14,14
97,-12,12
98,-13,13
99,-2,12
100,-14,14
101,-47,27
102,4,6
103,4,6
104,-91,51
105,-65,35
106,4,6
107,-12,12
108,-14,14
109,-13,13
110,4,6
111,-41,31
112,-13,13
113,4,6
114,-4,14
115,-74,34
116,4,6
117,-60,30
118,4,6
119,-15,15
120,3,7
121,4,6
122,4,6
123,-19,19
124,4,6
125,-49,29
126,-13,13
127,-30,20
128,2,8
129,-21,21
130,-45,25
131,-32,22
132,-67,37
133,-46,26
134,0,10
135,-12,12
136,-9,9
137,-10,10
138,-14,14
139,4,6
140,-11,11
141,-12,12
142,2,8
143,-35,25
144,4,6
145,-73,43
146,4,6
147,-20,20
148,4,6
149,2,8
150,-29,19
151,-20,20
152,4,6
153,-28,18
154,4,6
155,4,6
156,4,6
157,4,6
158,-34,24
159,4,6
160,4,6
161,4,6
162,-25,15
163,4,6
164,3,7
165,-48,28
166,4,6
167,-58,38
168,-20,20
169,-9,9
170,3,7
171,4,6
172,3,7
173,-33,23
174,-50,30
175,-16,16
176,-32,22
177,-65,35
178,4,6
179,-13,13
180,-11,11
181,3,7
182,4,6
183,-16,16
184,-12,12
185,4,6
186,-48,28
187,-13,13
188,2,8
189,3,7
190,-27,17
191,3,7
192,4,6
193,4,6
194,4,6
195,4,6
196,4,6
197,-13,13
198,-14,14
199,4,6
200,4,6
201,-13,13
202,-33,23
203,4,6
204,-32,22
205,4,6
206,-48,28
207,4,6
208,4,6
209,3,7
210,4,6
211,-34,24
212,3,7
213,4,6
214,4,6
215,4,6
216,3,7
217,-12,12
218,3,7
219,-8,8
220,3,7
221,4,6
222,-46,26
223,-33,23
224,4,6
225,1,9
226,3,7
227,2,8
228,-34,24
229,4,6
230,4,6
231,4,6
232,4,6
233,-55,35
234,-37,27
235,4,6
236,-14,14
237,-65,35
238,4,6
239,-13,13
240,4,6
241,4,6
242,-13,13
243,-30,20
244,3,7
245,-13,13
246,4,6
247,4,6
248,-13,13
249,-32,22
250,4,6
251,-55,35
252,-12,12
253,3,7
254,3,7
255,3,7
256,4,6
257,2,8
258,-12,12
259,3,7
260,-10,10
261,-12,12
262,4,6
263,3,7
264,3,7
265,-16,16
266,3,7
267,-47,27
268,-13,13
269,4,6
270,3,7
271,-13,13
272,4,6
273,4,6
274,-17,17
275,4,6
276,3,7
277,3,7
278,4,6
279,-41,31
280,3,7
281,-47,27
282,-32,22
283,4,6
284,3,7
285,-17,17
286,3,7
287,3,7
288,3,7
289,-12,12
290,4,6
291,3,7
292,3,7
293,-24,14
294,3,7
295,4,6
296,3,7
297,3,7
298,3,7
299,-13,13

View File

@@ -5,7 +5,7 @@ Author: John
Email: johnjim0816@gmail.com
Date: 2021-03-12 16:58:16
LastEditor: John
LastEditTime: 2022-08-25 00:23:22
LastEditTime: 2022-08-25 21:26:08
Description:
Environment:
'''
@@ -30,7 +30,7 @@ class Sarsa(object):
self.sample_count += 1
self.epsilon = self.epsilon_end + (self.epsilon_start - self.epsilon_end) * \
math.exp(-1. * self.sample_count / self.epsilon_decay) # probability of selecting a random action; it decays exponentially with the sample count
best_action = np.argmax(self.Q_table[state])
best_action = np.argmax(self.Q_table[str(state)]) # arrays are not hashable, so the state is converted to str
action_probs = np.ones(self.n_actions, dtype=float) * self.epsilon / self.n_actions
action_probs[best_action] += (1.0 - self.epsilon)
action = np.random.choice(np.arange(len(action_probs)), p=action_probs)
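For reference, the schedule in the hunk above decays epsilon exponentially from epsilon_start toward epsilon_end. With the defaults used in this repo (epsilon_start=0.90, epsilon_end=0.01, epsilon_decay=200), the values it produces can be inspected directly:

import math

# Illustration only: evaluate the decay schedule at a few sample counts.
epsilon_start, epsilon_end, epsilon_decay = 0.90, 0.01, 200
for sample_count in (0, 100, 200, 500, 1000):
    epsilon = epsilon_end + (epsilon_start - epsilon_end) * \
        math.exp(-1. * sample_count / epsilon_decay)
    print(f"sample_count={sample_count:4d} -> epsilon={epsilon:.3f}")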
@@ -38,27 +38,27 @@ class Sarsa(object):
def predict_action(self,state):
''' predict action while testing
'''
action = np.argmax(self.Q_table[state])
action = np.argmax(self.Q_table[str(state)])
return action
def update(self, state, action, reward, next_state, next_action,done):
Q_predict = self.Q_table[state][action]
Q_predict = self.Q_table[str(state)][action]
if done:
Q_target = reward # terminal state
else:
Q_target = reward + self.gamma * self.Q_table[next_state][next_action] # the only difference from Q learning
self.Q_table[state][action] += self.lr * (Q_target - Q_predict)
Q_target = reward + self.gamma * self.Q_table[str(next_state)][next_action] # the only difference from Q learning
self.Q_table[str(state)][action] += self.lr * (Q_target - Q_predict)
def save_model(self,path):
import dill
from pathlib import Path
# create path
Path(path).mkdir(parents=True, exist_ok=True)
torch.save(
obj=self.Q_table_table,
obj=self.Q_table,
f=path+"checkpoint.pkl",
pickle_module=dill
)
print("Model saved!")
def load_model(self, path):
import dill
self.Q_table_table =torch.load(f=path+'checkpoint.pkl',pickle_module=dill)
self.Q_table=torch.load(f=path+'checkpoint.pkl',pickle_module=dill)
print("Mode loaded!")

View File

@@ -1,131 +0,0 @@
#!/usr/bin/env python
# coding=utf-8
'''
Author: John
Email: johnjim0816@gmail.com
Date: 2020-09-11 23:03:00
LastEditor: John
LastEditTime: 2022-08-04 22:44:00
Description:
Environment:
'''
import sys
import os
curr_path = os.path.dirname(os.path.abspath(__file__)) # absolute path of the current file
parent_path = os.path.dirname(curr_path) # parent path
sys.path.append(parent_path) # add the parent path to sys.path
import gym
import torch
import datetime
import argparse
from envs.gridworld_env import CliffWalkingWapper
from Sarsa.sarsa import Sarsa
from common.utils import plot_rewards,save_args
from common.utils import save_results,make_dir
def get_args():
"""
"""
curr_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") # obtain current time
parser = argparse.ArgumentParser(description="hyperparameters")
parser.add_argument('--algo_name',default='Sarsa',type=str,help="name of algorithm")
parser.add_argument('--env_name',default='CliffWalking-v0',type=str,help="name of environment")
parser.add_argument('--train_eps',default=400,type=int,help="episodes of training") # number of training episodes
parser.add_argument('--test_eps',default=20,type=int,help="episodes of testing") # number of testing episodes
parser.add_argument('--gamma',default=0.90,type=float,help="discounted factor") # discount factor
parser.add_argument('--epsilon_start',default=0.95,type=float,help="initial value of epsilon") # initial epsilon of the e-greedy policy
parser.add_argument('--epsilon_end',default=0.01,type=float,help="final value of epsilon") # final epsilon of the e-greedy policy
parser.add_argument('--epsilon_decay',default=300,type=int,help="decay rate of epsilon") # decay rate of epsilon in the e-greedy policy
parser.add_argument('--lr',default=0.1,type=float,help="learning rate")
parser.add_argument('--device',default='cpu',type=str,help="cpu or cuda")
parser.add_argument('--result_path',default=curr_path + "/outputs/" + parser.parse_args().env_name + \
'/' + curr_time + '/results/' )
parser.add_argument('--model_path',default=curr_path + "/outputs/" + parser.parse_args().env_name + \
'/' + curr_time + '/models/' ) # path to save models
parser.add_argument('--save_fig',default=True,type=bool,help="if save figure or not")
args = parser.parse_args([])
return args
curr_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") # obtain current time
def train(cfg,env,agent):
print('Start training!')
print(f'Env: {cfg.env_name}, Algorithm: {cfg.algo_name}, Device: {cfg.device}')
rewards = [] # record rewards
for i_ep in range(cfg.train_eps):
ep_reward = 0 # reward per episode
state = env.reset() # reset env, i.e. start a new episode
action = agent.sample(state)
while True:
action = agent.sample(state) # sample an action with the algorithm
next_state, reward, done, _ = env.step(action) # interact with the environment once
next_action = agent.sample(next_state)
agent.update(state, action, reward, next_state, next_action,done) # update the algorithm
state = next_state # update state
action = next_action
ep_reward += reward
if done:
break
rewards.append(ep_reward)
print(f"Episode: {i_ep+1}/{cfg.train_eps}, Reward: {ep_reward:.1f}, Epsilon: {agent.epsilon}")
print('Finish training!')
return {"rewards":rewards}
def test(cfg,env,agent):
print('Start testing!')
print(f'Env: {cfg.env_name}, Algorithm: {cfg.algo_name}, Device: {cfg.device}')
rewards = [] # record rewards for all episodes
for i_ep in range(cfg.test_eps):
ep_reward = 0 # reward per episode
state = env.reset() # reset env, i.e. start a new episode
while True:
action = agent.predict(state) # select an action with the algorithm
next_state, reward, done, _ = env.step(action) # interact with the environment once
state = next_state # update state
ep_reward += reward
if done:
break
rewards.append(ep_reward)
print(f"Episode: {i_ep+1}/{cfg.test_eps}, Reward: {ep_reward:.1f}")
print('Finish testing!')
return {"rewards":rewards}
def env_agent_config(cfg,seed=1):
'''Create the environment and the agent
Args:
cfg ([type]): [description]
seed (int, optional): random seed. Defaults to 1.
Returns:
env [type]: environment
agent : agent
'''
env = gym.make(cfg.env_name)
env = CliffWalkingWapper(env)
env.seed(seed) # set random seed
n_states = env.observation_space.n # state dimension
n_actions = env.action_space.n # action dimension
print(f"n_states: {n_states}, n_actions: {n_actions}")
agent = Sarsa(n_actions,cfg)
return env,agent
if __name__ == "__main__":
cfg = get_args()
# train
env, agent = env_agent_config(cfg)
res_dic = train(cfg, env, agent)
make_dir(cfg.result_path, cfg.model_path)
save_args(cfg) # save parameters
agent.save(path=cfg.model_path) # save model
save_results(res_dic, tag='train',
path=cfg.result_path)
plot_rewards(res_dic['rewards'], cfg, tag="train")
# test
env, agent = env_agent_config(cfg)
agent.load(path=cfg.model_path) # load model
res_dic = test(cfg, env, agent)
save_results(res_dic, tag='test',
path=cfg.result_path) # save results
plot_rewards(res_dic['rewards'], cfg, tag="test") # plot results