hot update A2C
This commit is contained in:
@@ -5,7 +5,7 @@ Author: John
|
||||
Email: johnjim0816@gmail.com
|
||||
Date: 2020-11-22 23:21:53
|
||||
LastEditor: John
|
||||
LastEditTime: 2022-08-25 20:59:23
|
||||
LastEditTime: 2022-08-27 00:04:08
|
||||
Discription:
|
||||
Environment:
|
||||
'''
|
||||
@@ -34,7 +34,7 @@ class PGNet(MLP):
|
||||
def forward(self, x):
|
||||
x = F.relu(self.fc1(x))
|
||||
x = F.relu(self.fc2(x))
|
||||
x = F.sigmoid(self.fc3(x))
|
||||
x = torch.sigmoid(self.fc3(x))
|
||||
return x
|
||||
|
||||
class Main(Launcher):
|
||||
@@ -47,8 +47,9 @@ class Main(Launcher):
|
||||
parser.add_argument('--env_name',default='CartPole-v0',type=str,help="name of environment")
|
||||
parser.add_argument('--train_eps',default=200,type=int,help="episodes of training")
|
||||
parser.add_argument('--test_eps',default=20,type=int,help="episodes of testing")
|
||||
parser.add_argument('--ep_max_steps',default = 100000,type=int,help="steps per episode, much larger value can simulate infinite steps")
|
||||
parser.add_argument('--gamma',default=0.99,type=float,help="discounted factor")
|
||||
parser.add_argument('--lr',default=0.005,type=float,help="learning rate")
|
||||
parser.add_argument('--lr',default=0.01,type=float,help="learning rate")
|
||||
parser.add_argument('--update_fre',default=8,type=int)
|
||||
parser.add_argument('--hidden_dim',default=36,type=int)
|
||||
parser.add_argument('--device',default='cpu',type=str,help="cpu or cuda")
|
||||
@@ -81,7 +82,7 @@ class Main(Launcher):
|
||||
for i_ep in range(cfg['train_eps']):
|
||||
state = env.reset()
|
||||
ep_reward = 0
|
||||
for _ in count():
|
||||
for _ in range(cfg['ep_max_steps']):
|
||||
action = agent.sample_action(state) # sample action
|
||||
next_state, reward, done, _ = env.step(action)
|
||||
ep_reward += reward
|
||||
@@ -90,8 +91,9 @@ class Main(Launcher):
|
||||
agent.memory.push((state,float(action),reward))
|
||||
state = next_state
|
||||
if done:
|
||||
print(f"Episode:{i_ep+1}/{cfg['train_eps']}, Reward:{ep_reward:.2f}")
|
||||
break
|
||||
if (i_ep+1) % 10 == 0:
|
||||
print(f"Episode:{i_ep+1}/{cfg['train_eps']}, Reward:{ep_reward:.2f}")
|
||||
if (i_ep+1) % cfg['update_fre'] == 0:
|
||||
agent.update()
|
||||
rewards.append(ep_reward)
|
||||
@@ -107,7 +109,7 @@ class Main(Launcher):
|
||||
for i_ep in range(cfg['test_eps']):
|
||||
state = env.reset()
|
||||
ep_reward = 0
|
||||
for _ in count():
|
||||
for _ in range(cfg['ep_max_steps']):
|
||||
action = agent.predict_action(state)
|
||||
next_state, reward, done, _ = env.step(action)
|
||||
ep_reward += reward
|
||||
@@ -115,9 +117,9 @@ class Main(Launcher):
|
||||
reward = 0
|
||||
state = next_state
|
||||
if done:
|
||||
print(f"Episode: {i_ep+1}/{cfg['test_eps']},Reward: {ep_reward:.2f}")
|
||||
break
|
||||
rewards.append(ep_reward)
|
||||
print(f"Episode: {i_ep+1}/{cfg['test_eps']},Reward: {ep_reward:.2f}")
|
||||
rewards.append(ep_reward)
|
||||
print("Finish testing!")
|
||||
env.close()
|
||||
return {'episodes':range(len(rewards)),'rewards':rewards}
|
||||
|
||||
Binary file not shown.
@@ -1 +0,0 @@
|
||||
{"algo_name": "PolicyGradient", "env_name": "CartPole-v0", "train_eps": 200, "test_eps": 20, "gamma": 0.99, "lr": 0.005, "update_fre": 8, "hidden_dim": 36, "device": "cpu", "seed": 1, "save_fig": true, "show_fig": false, "result_path": "/Users/jj/Desktop/rl-tutorials/codes/PolicyGradient/outputs/CartPole-v0/20220825-205930/results/", "model_path": "/Users/jj/Desktop/rl-tutorials/codes/PolicyGradient/outputs/CartPole-v0/20220825-205930/models/", "n_states": 4, "n_actions": 2}
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 35 KiB |
Binary file not shown.
|
Before Width: | Height: | Size: 66 KiB |
@@ -1,201 +0,0 @@
|
||||
episodes,rewards
|
||||
0,26.0
|
||||
1,53.0
|
||||
2,10.0
|
||||
3,37.0
|
||||
4,22.0
|
||||
5,21.0
|
||||
6,12.0
|
||||
7,34.0
|
||||
8,38.0
|
||||
9,40.0
|
||||
10,23.0
|
||||
11,14.0
|
||||
12,16.0
|
||||
13,25.0
|
||||
14,15.0
|
||||
15,23.0
|
||||
16,11.0
|
||||
17,28.0
|
||||
18,21.0
|
||||
19,62.0
|
||||
20,33.0
|
||||
21,27.0
|
||||
22,15.0
|
||||
23,17.0
|
||||
24,26.0
|
||||
25,35.0
|
||||
26,26.0
|
||||
27,14.0
|
||||
28,42.0
|
||||
29,45.0
|
||||
30,34.0
|
||||
31,39.0
|
||||
32,31.0
|
||||
33,17.0
|
||||
34,42.0
|
||||
35,41.0
|
||||
36,31.0
|
||||
37,39.0
|
||||
38,28.0
|
||||
39,12.0
|
||||
40,36.0
|
||||
41,33.0
|
||||
42,47.0
|
||||
43,40.0
|
||||
44,63.0
|
||||
45,36.0
|
||||
46,64.0
|
||||
47,79.0
|
||||
48,49.0
|
||||
49,40.0
|
||||
50,65.0
|
||||
51,47.0
|
||||
52,51.0
|
||||
53,30.0
|
||||
54,26.0
|
||||
55,41.0
|
||||
56,86.0
|
||||
57,61.0
|
||||
58,38.0
|
||||
59,200.0
|
||||
60,49.0
|
||||
61,70.0
|
||||
62,61.0
|
||||
63,101.0
|
||||
64,200.0
|
||||
65,152.0
|
||||
66,108.0
|
||||
67,46.0
|
||||
68,72.0
|
||||
69,87.0
|
||||
70,27.0
|
||||
71,126.0
|
||||
72,46.0
|
||||
73,25.0
|
||||
74,14.0
|
||||
75,42.0
|
||||
76,38.0
|
||||
77,55.0
|
||||
78,42.0
|
||||
79,51.0
|
||||
80,67.0
|
||||
81,83.0
|
||||
82,178.0
|
||||
83,115.0
|
||||
84,140.0
|
||||
85,97.0
|
||||
86,85.0
|
||||
87,61.0
|
||||
88,153.0
|
||||
89,200.0
|
||||
90,200.0
|
||||
91,200.0
|
||||
92,200.0
|
||||
93,64.0
|
||||
94,200.0
|
||||
95,200.0
|
||||
96,157.0
|
||||
97,128.0
|
||||
98,160.0
|
||||
99,35.0
|
||||
100,140.0
|
||||
101,113.0
|
||||
102,200.0
|
||||
103,154.0
|
||||
104,200.0
|
||||
105,200.0
|
||||
106,200.0
|
||||
107,198.0
|
||||
108,137.0
|
||||
109,200.0
|
||||
110,200.0
|
||||
111,102.0
|
||||
112,200.0
|
||||
113,200.0
|
||||
114,200.0
|
||||
115,200.0
|
||||
116,148.0
|
||||
117,200.0
|
||||
118,200.0
|
||||
119,200.0
|
||||
120,200.0
|
||||
121,200.0
|
||||
122,194.0
|
||||
123,200.0
|
||||
124,200.0
|
||||
125,200.0
|
||||
126,183.0
|
||||
127,200.0
|
||||
128,200.0
|
||||
129,200.0
|
||||
130,200.0
|
||||
131,200.0
|
||||
132,200.0
|
||||
133,200.0
|
||||
134,200.0
|
||||
135,200.0
|
||||
136,93.0
|
||||
137,96.0
|
||||
138,84.0
|
||||
139,103.0
|
||||
140,79.0
|
||||
141,104.0
|
||||
142,82.0
|
||||
143,105.0
|
||||
144,200.0
|
||||
145,200.0
|
||||
146,171.0
|
||||
147,200.0
|
||||
148,200.0
|
||||
149,200.0
|
||||
150,200.0
|
||||
151,197.0
|
||||
152,133.0
|
||||
153,142.0
|
||||
154,147.0
|
||||
155,156.0
|
||||
156,131.0
|
||||
157,181.0
|
||||
158,163.0
|
||||
159,146.0
|
||||
160,200.0
|
||||
161,176.0
|
||||
162,200.0
|
||||
163,173.0
|
||||
164,177.0
|
||||
165,200.0
|
||||
166,200.0
|
||||
167,200.0
|
||||
168,200.0
|
||||
169,200.0
|
||||
170,200.0
|
||||
171,200.0
|
||||
172,200.0
|
||||
173,200.0
|
||||
174,200.0
|
||||
175,200.0
|
||||
176,200.0
|
||||
177,200.0
|
||||
178,200.0
|
||||
179,200.0
|
||||
180,200.0
|
||||
181,200.0
|
||||
182,200.0
|
||||
183,200.0
|
||||
184,200.0
|
||||
185,200.0
|
||||
186,200.0
|
||||
187,200.0
|
||||
188,200.0
|
||||
189,200.0
|
||||
190,200.0
|
||||
191,200.0
|
||||
192,200.0
|
||||
193,200.0
|
||||
194,200.0
|
||||
195,200.0
|
||||
196,190.0
|
||||
197,200.0
|
||||
198,189.0
|
||||
199,200.0
|
||||
|
Binary file not shown.
@@ -0,0 +1 @@
|
||||
{"algo_name": "PolicyGradient", "env_name": "CartPole-v0", "train_eps": 200, "test_eps": 20, "ep_max_steps": 100000, "gamma": 0.99, "lr": 0.01, "update_fre": 8, "hidden_dim": 36, "device": "cpu", "seed": 1, "save_fig": true, "show_fig": false, "result_path": "c:\\Users\\24438\\Desktop\\rl-tutorials\\codes\\PolicyGradient/outputs/CartPole-v0/20220827-000433/results/", "model_path": "c:\\Users\\24438\\Desktop\\rl-tutorials\\codes\\PolicyGradient/outputs/CartPole-v0/20220827-000433/models/", "n_states": 4, "n_actions": 2}
|
||||
Binary file not shown.
|
After Width: | Height: | Size: 28 KiB |
@@ -1,7 +1,7 @@
|
||||
episodes,rewards
|
||||
0,200.0
|
||||
1,200.0
|
||||
2,165.0
|
||||
2,200.0
|
||||
3,200.0
|
||||
4,200.0
|
||||
5,200.0
|
||||
@@ -10,12 +10,12 @@ episodes,rewards
|
||||
8,200.0
|
||||
9,200.0
|
||||
10,200.0
|
||||
11,168.0
|
||||
11,200.0
|
||||
12,200.0
|
||||
13,200.0
|
||||
14,200.0
|
||||
15,115.0
|
||||
16,198.0
|
||||
15,200.0
|
||||
16,200.0
|
||||
17,200.0
|
||||
18,200.0
|
||||
19,200.0
|
||||
|
Binary file not shown.
|
After Width: | Height: | Size: 60 KiB |
@@ -0,0 +1,201 @@
|
||||
episodes,rewards
|
||||
0,26.0
|
||||
1,53.0
|
||||
2,10.0
|
||||
3,37.0
|
||||
4,22.0
|
||||
5,21.0
|
||||
6,12.0
|
||||
7,34.0
|
||||
8,93.0
|
||||
9,36.0
|
||||
10,29.0
|
||||
11,18.0
|
||||
12,14.0
|
||||
13,62.0
|
||||
14,20.0
|
||||
15,40.0
|
||||
16,10.0
|
||||
17,10.0
|
||||
18,10.0
|
||||
19,11.0
|
||||
20,10.0
|
||||
21,14.0
|
||||
22,12.0
|
||||
23,8.0
|
||||
24,19.0
|
||||
25,33.0
|
||||
26,22.0
|
||||
27,32.0
|
||||
28,16.0
|
||||
29,24.0
|
||||
30,24.0
|
||||
31,24.0
|
||||
32,75.0
|
||||
33,33.0
|
||||
34,33.0
|
||||
35,72.0
|
||||
36,110.0
|
||||
37,48.0
|
||||
38,60.0
|
||||
39,43.0
|
||||
40,61.0
|
||||
41,34.0
|
||||
42,50.0
|
||||
43,61.0
|
||||
44,53.0
|
||||
45,58.0
|
||||
46,36.0
|
||||
47,44.0
|
||||
48,42.0
|
||||
49,64.0
|
||||
50,67.0
|
||||
51,52.0
|
||||
52,39.0
|
||||
53,42.0
|
||||
54,40.0
|
||||
55,33.0
|
||||
56,200.0
|
||||
57,199.0
|
||||
58,149.0
|
||||
59,185.0
|
||||
60,134.0
|
||||
61,174.0
|
||||
62,162.0
|
||||
63,200.0
|
||||
64,93.0
|
||||
65,72.0
|
||||
66,69.0
|
||||
67,51.0
|
||||
68,62.0
|
||||
69,98.0
|
||||
70,73.0
|
||||
71,73.0
|
||||
72,200.0
|
||||
73,200.0
|
||||
74,200.0
|
||||
75,200.0
|
||||
76,200.0
|
||||
77,200.0
|
||||
78,200.0
|
||||
79,133.0
|
||||
80,200.0
|
||||
81,200.0
|
||||
82,200.0
|
||||
83,200.0
|
||||
84,200.0
|
||||
85,200.0
|
||||
86,200.0
|
||||
87,200.0
|
||||
88,114.0
|
||||
89,151.0
|
||||
90,129.0
|
||||
91,156.0
|
||||
92,112.0
|
||||
93,172.0
|
||||
94,171.0
|
||||
95,141.0
|
||||
96,200.0
|
||||
97,200.0
|
||||
98,200.0
|
||||
99,200.0
|
||||
100,200.0
|
||||
101,200.0
|
||||
102,200.0
|
||||
103,200.0
|
||||
104,188.0
|
||||
105,199.0
|
||||
106,138.0
|
||||
107,200.0
|
||||
108,200.0
|
||||
109,181.0
|
||||
110,145.0
|
||||
111,200.0
|
||||
112,135.0
|
||||
113,119.0
|
||||
114,112.0
|
||||
115,122.0
|
||||
116,118.0
|
||||
117,119.0
|
||||
118,131.0
|
||||
119,119.0
|
||||
120,109.0
|
||||
121,96.0
|
||||
122,105.0
|
||||
123,29.0
|
||||
124,110.0
|
||||
125,113.0
|
||||
126,18.0
|
||||
127,90.0
|
||||
128,145.0
|
||||
129,152.0
|
||||
130,151.0
|
||||
131,109.0
|
||||
132,141.0
|
||||
133,109.0
|
||||
134,136.0
|
||||
135,143.0
|
||||
136,200.0
|
||||
137,200.0
|
||||
138,200.0
|
||||
139,200.0
|
||||
140,200.0
|
||||
141,200.0
|
||||
142,200.0
|
||||
143,200.0
|
||||
144,192.0
|
||||
145,173.0
|
||||
146,180.0
|
||||
147,182.0
|
||||
148,186.0
|
||||
149,175.0
|
||||
150,176.0
|
||||
151,191.0
|
||||
152,200.0
|
||||
153,200.0
|
||||
154,200.0
|
||||
155,200.0
|
||||
156,200.0
|
||||
157,200.0
|
||||
158,200.0
|
||||
159,200.0
|
||||
160,200.0
|
||||
161,200.0
|
||||
162,200.0
|
||||
163,200.0
|
||||
164,200.0
|
||||
165,200.0
|
||||
166,200.0
|
||||
167,200.0
|
||||
168,200.0
|
||||
169,200.0
|
||||
170,200.0
|
||||
171,200.0
|
||||
172,200.0
|
||||
173,200.0
|
||||
174,200.0
|
||||
175,200.0
|
||||
176,200.0
|
||||
177,200.0
|
||||
178,200.0
|
||||
179,200.0
|
||||
180,200.0
|
||||
181,200.0
|
||||
182,200.0
|
||||
183,200.0
|
||||
184,200.0
|
||||
185,200.0
|
||||
186,200.0
|
||||
187,200.0
|
||||
188,200.0
|
||||
189,200.0
|
||||
190,200.0
|
||||
191,200.0
|
||||
192,200.0
|
||||
193,200.0
|
||||
194,200.0
|
||||
195,200.0
|
||||
196,200.0
|
||||
197,200.0
|
||||
198,200.0
|
||||
199,200.0
|
||||
|
@@ -5,7 +5,7 @@ Author: John
|
||||
Email: johnjim0816@gmail.com
|
||||
Date: 2020-11-22 23:27:44
|
||||
LastEditor: John
|
||||
LastEditTime: 2022-08-25 20:58:59
|
||||
LastEditTime: 2022-08-27 13:45:26
|
||||
Discription:
|
||||
Environment:
|
||||
'''
|
||||
@@ -31,8 +31,11 @@ class PolicyGradient:
|
||||
state = torch.from_numpy(state).float()
|
||||
state = Variable(state)
|
||||
probs = self.policy_net(state)
|
||||
print("probs")
|
||||
print(probs)
|
||||
m = Bernoulli(probs) # 伯努利分布
|
||||
action = m.sample()
|
||||
|
||||
action = action.data.numpy().astype(int)[0] # 转为标量
|
||||
return action
|
||||
def predict_action(self,state):
|
||||
|
||||
Reference in New Issue
Block a user