hot update A2C

This commit is contained in:
johnjim0816
2022-08-29 15:12:33 +08:00
parent 99a3c1afec
commit 0b0f7e857d
109 changed files with 8213 additions and 1658 deletions

View File

@@ -5,7 +5,7 @@ Author: John
Email: johnjim0816@gmail.com
Date: 2020-11-22 23:21:53
LastEditor: John
- LastEditTime: 2022-08-25 20:59:23
+ LastEditTime: 2022-08-27 00:04:08
Discription:
Environment:
'''
@@ -34,7 +34,7 @@ class PGNet(MLP):
def forward(self, x):
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
- x = F.sigmoid(self.fc3(x))
+ x = torch.sigmoid(self.fc3(x))
return x
class Main(Launcher):
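
Note on the change above: torch.sigmoid replaces F.sigmoid, which is deprecated in torch.nn.functional and emits a warning. A minimal, self-contained sketch of this policy head (the class name below is illustrative, not part of the commit; layer sizes follow the defaults used in this file):

import torch
import torch.nn as nn
import torch.nn.functional as F

class BernoulliPolicyNet(nn.Module):
    """Single sigmoid output parameterizing P(action = 1) for CartPole."""
    def __init__(self, n_states=4, hidden_dim=36):
        super().__init__()
        self.fc1 = nn.Linear(n_states, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, 1)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return torch.sigmoid(self.fc3(x))  # torch.sigmoid is the non-deprecated form
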
@@ -47,8 +47,9 @@ class Main(Launcher):
parser.add_argument('--env_name',default='CartPole-v0',type=str,help="name of environment")
parser.add_argument('--train_eps',default=200,type=int,help="episodes of training")
parser.add_argument('--test_eps',default=20,type=int,help="episodes of testing")
+ parser.add_argument('--ep_max_steps',default = 100000,type=int,help="steps per episode, much larger value can simulate infinite steps")
parser.add_argument('--gamma',default=0.99,type=float,help="discounted factor")
- parser.add_argument('--lr',default=0.005,type=float,help="learning rate")
+ parser.add_argument('--lr',default=0.01,type=float,help="learning rate")
parser.add_argument('--update_fre',default=8,type=int)
parser.add_argument('--hidden_dim',default=36,type=int)
parser.add_argument('--device',default='cpu',type=str,help="cpu or cuda")
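
The new --ep_max_steps argument bounds the per-episode loop that previously ran on itertools.count(): a very large cap still behaves like "infinite" steps, but the rollout cannot hang if the environment never reports done. A standalone sketch of such a bounded rollout (assumes the pre-0.26 Gym step API that the training loop below uses):

import gym

env = gym.make('CartPole-v0')
state = env.reset()
for _ in range(100000):  # default value of the new --ep_max_steps flag
    state, reward, done, _ = env.step(env.action_space.sample())
    if done:  # CartPole-v0 terminates within 200 steps, so this always triggers
        break
env.close()
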
@@ -81,7 +82,7 @@ class Main(Launcher):
for i_ep in range(cfg['train_eps']):
state = env.reset()
ep_reward = 0
- for _ in count():
+ for _ in range(cfg['ep_max_steps']):
action = agent.sample_action(state) # sample action
next_state, reward, done, _ = env.step(action)
ep_reward += reward
@@ -90,8 +91,9 @@ class Main(Launcher):
agent.memory.push((state,float(action),reward))
state = next_state
if done:
print(f"Episode{i_ep+1}/{cfg['train_eps']}, Reward:{ep_reward:.2f}")
break
if (i_ep+1) % 10 == 0:
print(f"Episode{i_ep+1}/{cfg['train_eps']}, Reward:{ep_reward:.2f}")
if (i_ep+1) % cfg['update_fre'] == 0:
agent.update()
rewards.append(ep_reward)
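
agent.update() is called every update_fre episodes, but its body is not part of this diff. The sketch below is a hypothetical REINFORCE-style update consistent with the (state, float(action), reward) tuples pushed above: discounted returns weighting the Bernoulli log-likelihood. Function and variable names are assumptions, and returns are accumulated over the whole batch for brevity (per-episode resets would be more precise):

import torch
from torch.distributions import Bernoulli

def reinforce_update(policy_net, optimizer, memory, gamma=0.99):
    # memory: list of (state, action, reward) tuples collected over update_fre episodes
    states, actions, rewards = zip(*memory)
    returns, g = [], 0.0
    for r in reversed(rewards):        # discounted return, accumulated backwards
        g = r + gamma * g
        returns.insert(0, g)
    returns = torch.tensor(returns)
    returns = (returns - returns.mean()) / (returns.std() + 1e-8)  # variance reduction

    loss = 0.0
    for s, a, g in zip(states, actions, returns):
        probs = policy_net(torch.as_tensor(s, dtype=torch.float32))
        dist = Bernoulli(probs)
        # maximize return-weighted log-probability of the taken action
        loss = loss - dist.log_prob(torch.tensor([a], dtype=torch.float32)) * g
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    memory.clear()
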
@@ -107,7 +109,7 @@ class Main(Launcher):
for i_ep in range(cfg['test_eps']):
state = env.reset()
ep_reward = 0
- for _ in count():
+ for _ in range(cfg['ep_max_steps']):
action = agent.predict_action(state)
next_state, reward, done, _ = env.step(action)
ep_reward += reward
@@ -115,9 +117,9 @@ class Main(Launcher):
reward = 0
state = next_state
if done:
print(f"Episode: {i_ep+1}/{cfg['test_eps']}Reward: {ep_reward:.2f}")
break
rewards.append(ep_reward)
print(f"Episode: {i_ep+1}/{cfg['test_eps']}Reward: {ep_reward:.2f}")
rewards.append(ep_reward)
print("Finish testing!")
env.close()
return {'episodes':range(len(rewards)),'rewards':rewards}

View File

@@ -1 +0,0 @@
{"algo_name": "PolicyGradient", "env_name": "CartPole-v0", "train_eps": 200, "test_eps": 20, "gamma": 0.99, "lr": 0.005, "update_fre": 8, "hidden_dim": 36, "device": "cpu", "seed": 1, "save_fig": true, "show_fig": false, "result_path": "/Users/jj/Desktop/rl-tutorials/codes/PolicyGradient/outputs/CartPole-v0/20220825-205930/results/", "model_path": "/Users/jj/Desktop/rl-tutorials/codes/PolicyGradient/outputs/CartPole-v0/20220825-205930/models/", "n_states": 4, "n_actions": 2}

Binary file not shown.

Before: 35 KiB image

Binary file not shown.

Before: 66 KiB image

View File

@@ -1,201 +0,0 @@
episodes,rewards
0,26.0
1,53.0
2,10.0
3,37.0
4,22.0
5,21.0
6,12.0
7,34.0
8,38.0
9,40.0
10,23.0
11,14.0
12,16.0
13,25.0
14,15.0
15,23.0
16,11.0
17,28.0
18,21.0
19,62.0
20,33.0
21,27.0
22,15.0
23,17.0
24,26.0
25,35.0
26,26.0
27,14.0
28,42.0
29,45.0
30,34.0
31,39.0
32,31.0
33,17.0
34,42.0
35,41.0
36,31.0
37,39.0
38,28.0
39,12.0
40,36.0
41,33.0
42,47.0
43,40.0
44,63.0
45,36.0
46,64.0
47,79.0
48,49.0
49,40.0
50,65.0
51,47.0
52,51.0
53,30.0
54,26.0
55,41.0
56,86.0
57,61.0
58,38.0
59,200.0
60,49.0
61,70.0
62,61.0
63,101.0
64,200.0
65,152.0
66,108.0
67,46.0
68,72.0
69,87.0
70,27.0
71,126.0
72,46.0
73,25.0
74,14.0
75,42.0
76,38.0
77,55.0
78,42.0
79,51.0
80,67.0
81,83.0
82,178.0
83,115.0
84,140.0
85,97.0
86,85.0
87,61.0
88,153.0
89,200.0
90,200.0
91,200.0
92,200.0
93,64.0
94,200.0
95,200.0
96,157.0
97,128.0
98,160.0
99,35.0
100,140.0
101,113.0
102,200.0
103,154.0
104,200.0
105,200.0
106,200.0
107,198.0
108,137.0
109,200.0
110,200.0
111,102.0
112,200.0
113,200.0
114,200.0
115,200.0
116,148.0
117,200.0
118,200.0
119,200.0
120,200.0
121,200.0
122,194.0
123,200.0
124,200.0
125,200.0
126,183.0
127,200.0
128,200.0
129,200.0
130,200.0
131,200.0
132,200.0
133,200.0
134,200.0
135,200.0
136,93.0
137,96.0
138,84.0
139,103.0
140,79.0
141,104.0
142,82.0
143,105.0
144,200.0
145,200.0
146,171.0
147,200.0
148,200.0
149,200.0
150,200.0
151,197.0
152,133.0
153,142.0
154,147.0
155,156.0
156,131.0
157,181.0
158,163.0
159,146.0
160,200.0
161,176.0
162,200.0
163,173.0
164,177.0
165,200.0
166,200.0
167,200.0
168,200.0
169,200.0
170,200.0
171,200.0
172,200.0
173,200.0
174,200.0
175,200.0
176,200.0
177,200.0
178,200.0
179,200.0
180,200.0
181,200.0
182,200.0
183,200.0
184,200.0
185,200.0
186,200.0
187,200.0
188,200.0
189,200.0
190,200.0
191,200.0
192,200.0
193,200.0
194,200.0
195,200.0
196,190.0
197,200.0
198,189.0
199,200.0

View File

@@ -0,0 +1 @@
{"algo_name": "PolicyGradient", "env_name": "CartPole-v0", "train_eps": 200, "test_eps": 20, "ep_max_steps": 100000, "gamma": 0.99, "lr": 0.01, "update_fre": 8, "hidden_dim": 36, "device": "cpu", "seed": 1, "save_fig": true, "show_fig": false, "result_path": "c:\\Users\\24438\\Desktop\\rl-tutorials\\codes\\PolicyGradient/outputs/CartPole-v0/20220827-000433/results/", "model_path": "c:\\Users\\24438\\Desktop\\rl-tutorials\\codes\\PolicyGradient/outputs/CartPole-v0/20220827-000433/models/", "n_states": 4, "n_actions": 2}

Binary file not shown.

After: 28 KiB image

View File

@@ -1,7 +1,7 @@
episodes,rewards
0,200.0
1,200.0
- 2,165.0
+ 2,200.0
3,200.0
4,200.0
5,200.0
@@ -10,12 +10,12 @@ episodes,rewards
8,200.0
9,200.0
10,200.0
- 11,168.0
+ 11,200.0
12,200.0
13,200.0
14,200.0
- 15,115.0
- 16,198.0
+ 15,200.0
+ 16,200.0
17,200.0
18,200.0
19,200.0

Binary file not shown.

After: 60 KiB image

View File

@@ -0,0 +1,201 @@
episodes,rewards
0,26.0
1,53.0
2,10.0
3,37.0
4,22.0
5,21.0
6,12.0
7,34.0
8,93.0
9,36.0
10,29.0
11,18.0
12,14.0
13,62.0
14,20.0
15,40.0
16,10.0
17,10.0
18,10.0
19,11.0
20,10.0
21,14.0
22,12.0
23,8.0
24,19.0
25,33.0
26,22.0
27,32.0
28,16.0
29,24.0
30,24.0
31,24.0
32,75.0
33,33.0
34,33.0
35,72.0
36,110.0
37,48.0
38,60.0
39,43.0
40,61.0
41,34.0
42,50.0
43,61.0
44,53.0
45,58.0
46,36.0
47,44.0
48,42.0
49,64.0
50,67.0
51,52.0
52,39.0
53,42.0
54,40.0
55,33.0
56,200.0
57,199.0
58,149.0
59,185.0
60,134.0
61,174.0
62,162.0
63,200.0
64,93.0
65,72.0
66,69.0
67,51.0
68,62.0
69,98.0
70,73.0
71,73.0
72,200.0
73,200.0
74,200.0
75,200.0
76,200.0
77,200.0
78,200.0
79,133.0
80,200.0
81,200.0
82,200.0
83,200.0
84,200.0
85,200.0
86,200.0
87,200.0
88,114.0
89,151.0
90,129.0
91,156.0
92,112.0
93,172.0
94,171.0
95,141.0
96,200.0
97,200.0
98,200.0
99,200.0
100,200.0
101,200.0
102,200.0
103,200.0
104,188.0
105,199.0
106,138.0
107,200.0
108,200.0
109,181.0
110,145.0
111,200.0
112,135.0
113,119.0
114,112.0
115,122.0
116,118.0
117,119.0
118,131.0
119,119.0
120,109.0
121,96.0
122,105.0
123,29.0
124,110.0
125,113.0
126,18.0
127,90.0
128,145.0
129,152.0
130,151.0
131,109.0
132,141.0
133,109.0
134,136.0
135,143.0
136,200.0
137,200.0
138,200.0
139,200.0
140,200.0
141,200.0
142,200.0
143,200.0
144,192.0
145,173.0
146,180.0
147,182.0
148,186.0
149,175.0
150,176.0
151,191.0
152,200.0
153,200.0
154,200.0
155,200.0
156,200.0
157,200.0
158,200.0
159,200.0
160,200.0
161,200.0
162,200.0
163,200.0
164,200.0
165,200.0
166,200.0
167,200.0
168,200.0
169,200.0
170,200.0
171,200.0
172,200.0
173,200.0
174,200.0
175,200.0
176,200.0
177,200.0
178,200.0
179,200.0
180,200.0
181,200.0
182,200.0
183,200.0
184,200.0
185,200.0
186,200.0
187,200.0
188,200.0
189,200.0
190,200.0
191,200.0
192,200.0
193,200.0
194,200.0
195,200.0
196,200.0
197,200.0
198,200.0
199,200.0

View File

@@ -5,7 +5,7 @@ Author: John
Email: johnjim0816@gmail.com
Date: 2020-11-22 23:27:44
LastEditor: John
- LastEditTime: 2022-08-25 20:58:59
+ LastEditTime: 2022-08-27 13:45:26
Discription:
Environment:
'''
@@ -31,8 +31,11 @@ class PolicyGradient:
state = torch.from_numpy(state).float()
state = Variable(state)
probs = self.policy_net(state)
print("probs")
print(probs)
m = Bernoulli(probs) # Bernoulli distribution
action = m.sample()
action = action.data.numpy().astype(int)[0] # convert to a scalar
return action
def predict_action(self,state):
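
The two print calls added above look like temporary debug output around the Bernoulli sampling. For context, a self-contained sketch of the stochastic/greedy action pair; predict_action's body is not shown in this diff, so the greedy version below is an assumption:

import numpy as np
import torch
from torch.distributions import Bernoulli

def sample_action(policy_net, state):
    state = torch.from_numpy(np.asarray(state, dtype=np.float32))
    probs = policy_net(state)                 # P(action = 1), shape [1]
    action = Bernoulli(probs).sample()        # stochastic: used during training
    return int(action.item())

def predict_action(policy_net, state):
    state = torch.from_numpy(np.asarray(state, dtype=np.float32))
    probs = policy_net(state)
    return int(probs.item() > 0.5)            # greedy: used during testing
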