hot update DQN

johnjim0816
2022-08-24 12:49:16 +08:00
parent 07fb1d233e
commit 4f4658503e
24 changed files with 148 additions and 512 deletions

View File

@@ -71,7 +71,7 @@ class DQN:
             return
         else:
             if not self.update_flag:
-                print("begin to update!")
+                print("Begin to update!")
                 self.update_flag = True
         # sample a batch of transitions from replay buffer
         state_batch, action_batch, reward_batch, next_state_batch, done_batch = self.memory.sample(
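For context, the sampled batch feeds a standard DQN temporal-difference update. Below is a minimal sketch of that step, assuming a PyTorch policy_net/target_net pair and an MSE loss; the repository's actual update code sits further down in dqn.py and is not part of this hunk.

import torch
import torch.nn.functional as F

def dqn_update_step(policy_net, target_net, optimizer, batch, gamma=0.95, device="cpu"):
    # Unpack a sampled batch of transitions (names mirror the tuple above).
    states, actions, rewards, next_states, dones = batch
    states = torch.tensor(states, dtype=torch.float32, device=device)
    actions = torch.tensor(actions, dtype=torch.int64, device=device).unsqueeze(1)
    rewards = torch.tensor(rewards, dtype=torch.float32, device=device)
    next_states = torch.tensor(next_states, dtype=torch.float32, device=device)
    dones = torch.tensor(dones, dtype=torch.float32, device=device)
    q_values = policy_net(states).gather(1, actions).squeeze(1)  # Q(s, a) for the actions taken
    next_q = target_net(next_states).max(1)[0].detach()          # bootstrap from the frozen target net
    expected_q = rewards + gamma * next_q * (1.0 - dones)
    loss = F.mse_loss(q_values, expected_q)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()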

View File

@@ -27,7 +27,7 @@ def get_args():
parser.add_argument('--gamma',default=0.95,type=float,help="discounted factor")
parser.add_argument('--epsilon_start',default=0.95,type=float,help="initial value of epsilon")
parser.add_argument('--epsilon_end',default=0.01,type=float,help="final value of epsilon")
parser.add_argument('--epsilon_decay',default=500,type=int,help="decay rate of epsilon")
parser.add_argument('--epsilon_decay',default=500,type=int,help="decay rate of epsilon, the higher value, the slower decay")
parser.add_argument('--lr',default=0.0001,type=float,help="learning rate")
parser.add_argument('--memory_capacity',default=100000,type=int,help="memory capacity")
parser.add_argument('--batch_size',default=64,type=int)
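The updated help text describes the epsilon schedule only qualitatively. For reference, a common exponential schedule consistent with it, sketched as a standalone helper; the repository's actual formula lives in dqn.py and is not shown in this diff.

import math

def epsilon_by_step(step, epsilon_start=0.95, epsilon_end=0.01, epsilon_decay=500):
    # Larger epsilon_decay means a slower decay toward epsilon_end.
    return epsilon_end + (epsilon_start - epsilon_end) * math.exp(-step / epsilon_decay)

# e.g. epsilon_by_step(0) == 0.95, while epsilon_by_step(5000) is already close to 0.01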
@@ -64,8 +64,8 @@ def env_agent_config(cfg):
 def train(cfg, env, agent):
     ''' Training
     '''
-    print("start training!")
-    print(f"Env: {cfg['env_name']}, Algo: {cfg['algo_name']}, Device: {cfg['device']}")
+    print("Start training!")
+    print(f"Env: {cfg['env_name']}, Algorithm: {cfg['algo_name']}, Device: {cfg['device']}")
     rewards = [] # record rewards for all episodes
     steps = []
     for i_ep in range(cfg["train_eps"]):
@@ -89,17 +89,17 @@ def train(cfg, env, agent):
         rewards.append(ep_reward)
         if (i_ep + 1) % 10 == 0:
             print(f'Episode: {i_ep+1}/{cfg["train_eps"]}, Reward: {ep_reward:.2f}, Epsilon: {agent.epsilon:.3f}')
-    print("finish training!")
+    print("Finish training!")
     env.close()
     res_dic = {'episodes':range(len(rewards)),'rewards':rewards}
     return res_dic
 def test(cfg, env, agent):
-    print("start testing!")
-    print(f"Env: {cfg.env_name}, Algo: {cfg.algo_name}, Device: {cfg.device}")
+    print("Start testing!")
+    print(f"Env: {cfg['env_name']}, Algorithm: {cfg['algo_name']}, Device: {cfg['device']}")
     rewards = [] # record rewards for all episodes
     steps = []
-    for i_ep in range(cfg.test_eps):
+    for i_ep in range(cfg['test_eps']):
         ep_reward = 0 # reward per episode
         ep_step = 0
         state = env.reset() # reset and obtain initial state
@@ -113,8 +113,8 @@ def test(cfg, env, agent):
                 break
         steps.append(ep_step)
         rewards.append(ep_reward)
-        print(f'Episode: {i_ep+1}/{cfg.test_eps}Reward: {ep_reward:.2f}')
-    print("finish testing!")
+        print(f"Episode: {i_ep+1}/{cfg['test_eps']}, Reward: {ep_reward:.2f}")
+    print("Finish testing!")
     env.close()
     return {'episodes':range(len(rewards)),'rewards':rewards}
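The switch from attribute access (cfg.test_eps) to dictionary access (cfg['test_eps']) throughout train and test suggests the parsed arguments are now carried as a plain dict, which also matches the JSON config saved below. A minimal sketch of that conversion, using a hypothetical helper that is not part of this diff:

import argparse

def get_cfg():
    parser = argparse.ArgumentParser(description="hyperparameters")
    parser.add_argument('--env_name', default='CartPole-v0', type=str, help="name of environment")
    parser.add_argument('--test_eps', default=20, type=int, help="episodes of testing")
    args = parser.parse_args()
    cfg = vars(args)  # Namespace -> dict, so cfg['test_eps'] works and the config serializes to JSON directly
    return cfg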

View File

@@ -0,0 +1 @@
{"algo_name": "DQN", "env_name": "Acrobot-v1", "train_eps": 100, "test_eps": 20, "gamma": 0.95, "epsilon_start": 0.95, "epsilon_end": 0.01, "epsilon_decay": 1500, "lr": 0.002, "memory_capacity": 200000, "batch_size": 128, "target_update": 4, "hidden_dim": 256, "device": "cuda", "seed": 10, "show_fig": false, "save_fig": true, "result_path": "C:\\Users\\jiangji\\Desktop\\rl-tutorials\\codes\\DQN/outputs/Acrobot-v1/20220824-124401/results", "model_path": "C:\\Users\\jiangji\\Desktop\\rl-tutorials\\codes\\DQN/outputs/Acrobot-v1/20220824-124401/models", "n_states": 6, "n_actions": 3}

Binary image file added (not shown in diff); 51 KiB.

View File

@@ -0,0 +1,21 @@
episodes,rewards
0,-79.0
1,-113.0
2,-81.0
3,-132.0
4,-110.0
5,-114.0
6,-80.0
7,-101.0
8,-78.0
9,-91.0
10,-107.0
11,-87.0
12,-105.0
13,-91.0
14,-128.0
15,-132.0
16,-119.0
17,-77.0
18,-89.0
19,-134.0

Binary image file added (not shown in diff); 54 KiB.

View File

@@ -0,0 +1,101 @@
episodes,rewards
0,-500.0
1,-500.0
2,-500.0
3,-370.0
4,-449.0
5,-500.0
6,-312.0
7,-374.0
8,-180.0
9,-154.0
10,-137.0
11,-185.0
12,-135.0
13,-302.0
14,-146.0
15,-137.0
16,-119.0
17,-149.0
18,-217.0
19,-191.0
20,-157.0
21,-166.0
22,-138.0
23,-135.0
24,-182.0
25,-130.0
26,-175.0
27,-222.0
28,-133.0
29,-108.0
30,-250.0
31,-119.0
32,-135.0
33,-148.0
34,-194.0
35,-194.0
36,-186.0
37,-131.0
38,-185.0
39,-79.0
40,-129.0
41,-271.0
42,-117.0
43,-159.0
44,-156.0
45,-117.0
46,-158.0
47,-153.0
48,-119.0
49,-164.0
50,-134.0
51,-231.0
52,-117.0
53,-119.0
54,-136.0
55,-173.0
56,-202.0
57,-133.0
58,-142.0
59,-169.0
60,-137.0
61,-123.0
62,-205.0
63,-107.0
64,-194.0
65,-150.0
66,-143.0
67,-218.0
68,-145.0
69,-90.0
70,-107.0
71,-169.0
72,-125.0
73,-142.0
74,-145.0
75,-94.0
76,-150.0
77,-134.0
78,-159.0
79,-137.0
80,-146.0
81,-191.0
82,-242.0
83,-117.0
84,-92.0
85,-193.0
86,-239.0
87,-173.0
88,-140.0
89,-157.0
90,-133.0
91,-148.0
92,-87.0
93,-398.0
94,-98.0
95,-121.0
96,-102.0
97,-120.0
98,-195.0
99,-219.0

View File

@@ -1 +0,0 @@
{"algo_name": "DQN", "env_name": "CartPole-v0", "train_eps": 200, "test_eps": 20, "gamma": 0.95, "epsilon_start": 0.95, "epsilon_end": 0.01, "epsilon_decay": 500, "lr": 0.0001, "memory_capacity": 100000, "batch_size": 64, "target_update": 4, "hidden_dim": 256, "device": "cpu", "seed": 10, "result_path": "/Users/jj/Desktop/rl-tutorials/codes/DQN/outputs/CartPole-v0/20220818-143132/results", "model_path": "/Users/jj/Desktop/rl-tutorials/codes/DQN/outputs/CartPole-v0/20220818-143132/models", "show_fig": false, "save_fig": true}

Binary image file removed (not shown in diff); 27 KiB.

View File

@@ -1,21 +0,0 @@
episodes,rewards
0,200.0
1,200.0
2,200.0
3,200.0
4,200.0
5,200.0
6,200.0
7,200.0
8,200.0
9,200.0
10,200.0
11,200.0
12,200.0
13,200.0
14,200.0
15,200.0
16,200.0
17,200.0
18,200.0
19,200.0

Binary image file removed (not shown in diff); 38 KiB.

View File

@@ -1,201 +0,0 @@
episodes,rewards
0,38.0
1,16.0
2,37.0
3,15.0
4,22.0
5,34.0
6,20.0
7,12.0
8,16.0
9,14.0
10,13.0
11,21.0
12,14.0
13,12.0
14,17.0
15,12.0
16,10.0
17,14.0
18,10.0
19,10.0
20,16.0
21,9.0
22,14.0
23,13.0
24,10.0
25,9.0
26,12.0
27,12.0
28,14.0
29,11.0
30,9.0
31,8.0
32,9.0
33,11.0
34,12.0
35,10.0
36,11.0
37,10.0
38,10.0
39,18.0
40,13.0
41,15.0
42,10.0
43,9.0
44,14.0
45,14.0
46,23.0
47,17.0
48,15.0
49,15.0
50,20.0
51,28.0
52,36.0
53,36.0
54,23.0
55,27.0
56,53.0
57,19.0
58,35.0
59,62.0
60,57.0
61,38.0
62,61.0
63,65.0
64,58.0
65,43.0
66,67.0
67,56.0
68,91.0
69,128.0
70,71.0
71,126.0
72,100.0
73,200.0
74,200.0
75,200.0
76,200.0
77,200.0
78,200.0
79,200.0
80,200.0
81,200.0
82,200.0
83,200.0
84,200.0
85,200.0
86,200.0
87,200.0
88,200.0
89,200.0
90,200.0
91,200.0
92,200.0
93,200.0
94,200.0
95,200.0
96,200.0
97,200.0
98,200.0
99,200.0
100,200.0
101,200.0
102,200.0
103,200.0
104,200.0
105,200.0
106,200.0
107,200.0
108,200.0
109,200.0
110,200.0
111,200.0
112,200.0
113,200.0
114,200.0
115,200.0
116,200.0
117,200.0
118,200.0
119,200.0
120,200.0
121,200.0
122,200.0
123,200.0
124,200.0
125,200.0
126,200.0
127,200.0
128,200.0
129,200.0
130,200.0
131,200.0
132,200.0
133,200.0
134,200.0
135,200.0
136,200.0
137,200.0
138,200.0
139,200.0
140,200.0
141,200.0
142,200.0
143,200.0
144,200.0
145,200.0
146,200.0
147,200.0
148,200.0
149,200.0
150,200.0
151,200.0
152,200.0
153,200.0
154,200.0
155,200.0
156,200.0
157,200.0
158,200.0
159,200.0
160,200.0
161,200.0
162,200.0
163,200.0
164,200.0
165,200.0
166,200.0
167,200.0
168,200.0
169,200.0
170,200.0
171,200.0
172,200.0
173,200.0
174,200.0
175,200.0
176,200.0
177,200.0
178,200.0
179,200.0
180,200.0
181,200.0
182,200.0
183,200.0
184,200.0
185,200.0
186,200.0
187,200.0
188,200.0
189,200.0
190,200.0
191,200.0
192,200.0
193,200.0
194,200.0
195,200.0
196,200.0
197,200.0
198,200.0
199,200.0

View File

@@ -1,133 +0,0 @@
import sys,os
curr_path = os.path.dirname(os.path.abspath(__file__)) # current path
parent_path = os.path.dirname(curr_path) # parent path
sys.path.append(parent_path) # add path to system path
import gym
import torch
import datetime
import numpy as np
import argparse
from common.utils import save_results,all_seed
from common.utils import plot_rewards,save_args
from common.models import MLP
from common.memories import ReplayBuffer
from dqn import DQN
def get_args():
""" hyperparameters
"""
curr_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") # obtain current time
parser = argparse.ArgumentParser(description="hyperparameters")
parser.add_argument('--algo_name',default='DQN',type=str,help="name of algorithm")
parser.add_argument('--env_name',default='CartPole-v0',type=str,help="name of environment")
parser.add_argument('--train_eps',default=200,type=int,help="episodes of training")
parser.add_argument('--test_eps',default=20,type=int,help="episodes of testing")
parser.add_argument('--gamma',default=0.95,type=float,help="discounted factor")
parser.add_argument('--epsilon_start',default=0.95,type=float,help="initial value of epsilon")
parser.add_argument('--epsilon_end',default=0.01,type=float,help="final value of epsilon")
parser.add_argument('--epsilon_decay',default=500,type=int,help="decay rate of epsilon")
parser.add_argument('--lr',default=0.0001,type=float,help="learning rate")
parser.add_argument('--memory_capacity',default=100000,type=int,help="memory capacity")
parser.add_argument('--batch_size',default=64,type=int)
parser.add_argument('--target_update',default=4,type=int)
parser.add_argument('--hidden_dim',default=256,type=int)
parser.add_argument('--device',default='cpu',type=str,help="cpu or cuda")
parser.add_argument('--seed',default=10,type=int,help="seed")
parser.add_argument('--result_path',default=curr_path + "/outputs/" + parser.parse_args().env_name + \
'/' + curr_time + '/results' )
parser.add_argument('--model_path',default=curr_path + "/outputs/" + parser.parse_args().env_name + \
'/' + curr_time + '/models' )
parser.add_argument('--show_fig',default=False,type=bool,help="if show figure or not")
parser.add_argument('--save_fig',default=True,type=bool,help="if save figure or not")
args = parser.parse_args()
return args
def env_agent_config(cfg):
''' create env and agent
'''
env = gym.make(cfg.env_name) # create env
if cfg.seed !=0: # set random seed
all_seed(env,seed=cfg.seed)
n_states = env.observation_space.shape[0] # state dimension
n_actions = env.action_space.n # action dimension
print(f"state dim: {n_states}, action dim: {n_actions}")
model = MLP(n_states,n_actions,hidden_dim=cfg.hidden_dim)
memory = ReplayBuffer(cfg.memory_capacity) # replay buffer
agent = DQN(n_actions,model,memory,cfg) # create agent
return env, agent
def train(cfg, env, agent):
''' Training
'''
print("start training!")
print(f"Env: {cfg.env_name}, Algo: {cfg.algo_name}, Device: {cfg.device}")
rewards = [] # record rewards for all episodes
steps = []
for i_ep in range(cfg.train_eps):
ep_reward = 0 # reward per episode
ep_step = 0
state = env.reset() # reset and obtain initial state
while True:
ep_step += 1
action = agent.sample_action(state) # sample action
next_state, reward, done, _ = env.step(action) # update env and return transitions
agent.memory.push(state, action, reward,
next_state, done) # save transitions
state = next_state # update next state for env
agent.update() # update agent
ep_reward += reward #
if done:
break
if (i_ep + 1) % cfg.target_update == 0: # target net update; target_update corresponds to "C" in the pseudocode
agent.target_net.load_state_dict(agent.policy_net.state_dict())
steps.append(ep_step)
rewards.append(ep_reward)
if (i_ep + 1) % 10 == 0:
print(f'Episode: {i_ep+1}/{cfg.train_eps}, Reward: {ep_reward:.2f}, Epsilon: {agent.epsilon:.3f}')
print("finish training!")
env.close()
res_dic = {'episodes':range(len(rewards)),'rewards':rewards}
return res_dic
def test(cfg, env, agent):
print("start testing!")
print(f"Env: {cfg.env_name}, Algo: {cfg.algo_name}, Device: {cfg.device}")
rewards = [] # record rewards for all episodes
steps = []
for i_ep in range(cfg.test_eps):
ep_reward = 0 # reward per episode
ep_step = 0
state = env.reset() # reset and obtain initial state
while True:
ep_step+=1
action = agent.predict_action(state) # predict action
next_state, reward, done, _ = env.step(action)
state = next_state
ep_reward += reward
if done:
break
steps.append(ep_step)
rewards.append(ep_reward)
print(f'Episode: {i_ep+1}/{cfg.test_eps}Reward: {ep_reward:.2f}')
print("finish testing!")
env.close()
return {'episodes':range(len(rewards)),'rewards':rewards}
if __name__ == "__main__":
cfg = get_args()
# training
env, agent = env_agent_config(cfg)
res_dic = train(cfg, env, agent)
save_args(cfg,path = cfg.result_path) # save parameters
agent.save_model(path = cfg.model_path) # save models
save_results(res_dic, tag = 'train', path = cfg.result_path) # save results
plot_rewards(res_dic['rewards'], cfg, path = cfg.result_path,tag = "train") # plot results
# testing
env, agent = env_agent_config(cfg) # create new env for testing, sometimes can ignore this step
agent.load_model(path = cfg.model_path) # load model
res_dic = test(cfg, env, agent)
save_results(res_dic, tag='test',
path = cfg.result_path)
plot_rewards(res_dic['rewards'], cfg, path = cfg.result_path,tag = "test")