hot update DQN

This commit is contained in:
johnjim0816
2022-08-24 12:49:16 +08:00
parent 07fb1d233e
commit 4f4658503e
24 changed files with 148 additions and 512 deletions

View File

@@ -71,7 +71,7 @@ class DQN:
return
else:
if not self.update_flag:
print("begin to update!")
print("Begin to update!")
self.update_flag = True
# sample a batch of transitions from replay buffer
state_batch, action_batch, reward_batch, next_state_batch, done_batch = self.memory.sample(

View File

@@ -27,7 +27,7 @@ def get_args():
parser.add_argument('--gamma',default=0.95,type=float,help="discounted factor")
parser.add_argument('--epsilon_start',default=0.95,type=float,help="initial value of epsilon")
parser.add_argument('--epsilon_end',default=0.01,type=float,help="final value of epsilon")
parser.add_argument('--epsilon_decay',default=500,type=int,help="decay rate of epsilon")
parser.add_argument('--epsilon_decay',default=500,type=int,help="decay rate of epsilon, the higher value, the slower decay")
parser.add_argument('--lr',default=0.0001,type=float,help="learning rate")
parser.add_argument('--memory_capacity',default=100000,type=int,help="memory capacity")
parser.add_argument('--batch_size',default=64,type=int)
@@ -64,8 +64,8 @@ def env_agent_config(cfg):
def train(cfg, env, agent):
''' 训练
'''
print("start training!")
print(f"Env: {cfg['env_name']}, Algo: {cfg['algo_name']}, Device: {cfg['device']}")
print("Start training!")
print(f"Env: {cfg['env_name']}, Algorithm: {cfg['algo_name']}, Device: {cfg['device']}")
rewards = [] # record rewards for all episodes
steps = []
for i_ep in range(cfg["train_eps"]):
@@ -89,17 +89,17 @@ def train(cfg, env, agent):
rewards.append(ep_reward)
if (i_ep + 1) % 10 == 0:
print(f'Episode: {i_ep+1}/{cfg["train_eps"]}, Reward: {ep_reward:.2f}: Epislon: {agent.epsilon:.3f}')
print("finish training!")
print("Finish training!")
env.close()
res_dic = {'episodes':range(len(rewards)),'rewards':rewards}
return res_dic
def test(cfg, env, agent):
print("start testing!")
print(f"Env: {cfg.env_name}, Algo: {cfg.algo_name}, Device: {cfg.device}")
print("Start testing!")
print(f"Env: {cfg['env_name']}, Algorithm: {cfg['algo_name']}, Device: {cfg['device']}")
rewards = [] # record rewards for all episodes
steps = []
for i_ep in range(cfg.test_eps):
for i_ep in range(cfg['test_eps']):
ep_reward = 0 # reward per episode
ep_step = 0
state = env.reset() # reset and obtain initial state
@@ -113,8 +113,8 @@ def test(cfg, env, agent):
break
steps.append(ep_step)
rewards.append(ep_reward)
print(f'Episode: {i_ep+1}/{cfg.test_eps}Reward: {ep_reward:.2f}')
print("finish testing!")
print(f"Episode: {i_ep+1}/{cfg['test_eps']}Reward: {ep_reward:.2f}")
print("Finish testing!")
env.close()
return {'episodes':range(len(rewards)),'rewards':rewards}

View File

@@ -0,0 +1 @@
{"algo_name": "DQN", "env_name": "Acrobot-v1", "train_eps": 100, "test_eps": 20, "gamma": 0.95, "epsilon_start": 0.95, "epsilon_end": 0.01, "epsilon_decay": 1500, "lr": 0.002, "memory_capacity": 200000, "batch_size": 128, "target_update": 4, "hidden_dim": 256, "device": "cuda", "seed": 10, "show_fig": false, "save_fig": true, "result_path": "C:\\Users\\jiangji\\Desktop\\rl-tutorials\\codes\\DQN/outputs/Acrobot-v1/20220824-124401/results", "model_path": "C:\\Users\\jiangji\\Desktop\\rl-tutorials\\codes\\DQN/outputs/Acrobot-v1/20220824-124401/models", "n_states": 6, "n_actions": 3}

Binary file not shown.

After

Width:  |  Height:  |  Size: 51 KiB

View File

@@ -0,0 +1,21 @@
episodes,rewards
0,-79.0
1,-113.0
2,-81.0
3,-132.0
4,-110.0
5,-114.0
6,-80.0
7,-101.0
8,-78.0
9,-91.0
10,-107.0
11,-87.0
12,-105.0
13,-91.0
14,-128.0
15,-132.0
16,-119.0
17,-77.0
18,-89.0
19,-134.0
1 episodes rewards
2 0 -79.0
3 1 -113.0
4 2 -81.0
5 3 -132.0
6 4 -110.0
7 5 -114.0
8 6 -80.0
9 7 -101.0
10 8 -78.0
11 9 -91.0
12 10 -107.0
13 11 -87.0
14 12 -105.0
15 13 -91.0
16 14 -128.0
17 15 -132.0
18 16 -119.0
19 17 -77.0
20 18 -89.0
21 19 -134.0

Binary file not shown.

After

Width:  |  Height:  |  Size: 54 KiB

View File

@@ -0,0 +1,101 @@
episodes,rewards
0,-500.0
1,-500.0
2,-500.0
3,-370.0
4,-449.0
5,-500.0
6,-312.0
7,-374.0
8,-180.0
9,-154.0
10,-137.0
11,-185.0
12,-135.0
13,-302.0
14,-146.0
15,-137.0
16,-119.0
17,-149.0
18,-217.0
19,-191.0
20,-157.0
21,-166.0
22,-138.0
23,-135.0
24,-182.0
25,-130.0
26,-175.0
27,-222.0
28,-133.0
29,-108.0
30,-250.0
31,-119.0
32,-135.0
33,-148.0
34,-194.0
35,-194.0
36,-186.0
37,-131.0
38,-185.0
39,-79.0
40,-129.0
41,-271.0
42,-117.0
43,-159.0
44,-156.0
45,-117.0
46,-158.0
47,-153.0
48,-119.0
49,-164.0
50,-134.0
51,-231.0
52,-117.0
53,-119.0
54,-136.0
55,-173.0
56,-202.0
57,-133.0
58,-142.0
59,-169.0
60,-137.0
61,-123.0
62,-205.0
63,-107.0
64,-194.0
65,-150.0
66,-143.0
67,-218.0
68,-145.0
69,-90.0
70,-107.0
71,-169.0
72,-125.0
73,-142.0
74,-145.0
75,-94.0
76,-150.0
77,-134.0
78,-159.0
79,-137.0
80,-146.0
81,-191.0
82,-242.0
83,-117.0
84,-92.0
85,-193.0
86,-239.0
87,-173.0
88,-140.0
89,-157.0
90,-133.0
91,-148.0
92,-87.0
93,-398.0
94,-98.0
95,-121.0
96,-102.0
97,-120.0
98,-195.0
99,-219.0
1 episodes rewards
2 0 -500.0
3 1 -500.0
4 2 -500.0
5 3 -370.0
6 4 -449.0
7 5 -500.0
8 6 -312.0
9 7 -374.0
10 8 -180.0
11 9 -154.0
12 10 -137.0
13 11 -185.0
14 12 -135.0
15 13 -302.0
16 14 -146.0
17 15 -137.0
18 16 -119.0
19 17 -149.0
20 18 -217.0
21 19 -191.0
22 20 -157.0
23 21 -166.0
24 22 -138.0
25 23 -135.0
26 24 -182.0
27 25 -130.0
28 26 -175.0
29 27 -222.0
30 28 -133.0
31 29 -108.0
32 30 -250.0
33 31 -119.0
34 32 -135.0
35 33 -148.0
36 34 -194.0
37 35 -194.0
38 36 -186.0
39 37 -131.0
40 38 -185.0
41 39 -79.0
42 40 -129.0
43 41 -271.0
44 42 -117.0
45 43 -159.0
46 44 -156.0
47 45 -117.0
48 46 -158.0
49 47 -153.0
50 48 -119.0
51 49 -164.0
52 50 -134.0
53 51 -231.0
54 52 -117.0
55 53 -119.0
56 54 -136.0
57 55 -173.0
58 56 -202.0
59 57 -133.0
60 58 -142.0
61 59 -169.0
62 60 -137.0
63 61 -123.0
64 62 -205.0
65 63 -107.0
66 64 -194.0
67 65 -150.0
68 66 -143.0
69 67 -218.0
70 68 -145.0
71 69 -90.0
72 70 -107.0
73 71 -169.0
74 72 -125.0
75 73 -142.0
76 74 -145.0
77 75 -94.0
78 76 -150.0
79 77 -134.0
80 78 -159.0
81 79 -137.0
82 80 -146.0
83 81 -191.0
84 82 -242.0
85 83 -117.0
86 84 -92.0
87 85 -193.0
88 86 -239.0
89 87 -173.0
90 88 -140.0
91 89 -157.0
92 90 -133.0
93 91 -148.0
94 92 -87.0
95 93 -398.0
96 94 -98.0
97 95 -121.0
98 96 -102.0
99 97 -120.0
100 98 -195.0
101 99 -219.0

View File

@@ -1 +0,0 @@
{"algo_name": "DQN", "env_name": "CartPole-v0", "train_eps": 200, "test_eps": 20, "gamma": 0.95, "epsilon_start": 0.95, "epsilon_end": 0.01, "epsilon_decay": 500, "lr": 0.0001, "memory_capacity": 100000, "batch_size": 64, "target_update": 4, "hidden_dim": 256, "device": "cpu", "seed": 10, "result_path": "/Users/jj/Desktop/rl-tutorials/codes/DQN/outputs/CartPole-v0/20220818-143132/results", "model_path": "/Users/jj/Desktop/rl-tutorials/codes/DQN/outputs/CartPole-v0/20220818-143132/models", "show_fig": false, "save_fig": true}

Binary file not shown.

Before

Width:  |  Height:  |  Size: 27 KiB

View File

@@ -1,21 +0,0 @@
episodes,rewards
0,200.0
1,200.0
2,200.0
3,200.0
4,200.0
5,200.0
6,200.0
7,200.0
8,200.0
9,200.0
10,200.0
11,200.0
12,200.0
13,200.0
14,200.0
15,200.0
16,200.0
17,200.0
18,200.0
19,200.0
1 episodes rewards
2 0 200.0
3 1 200.0
4 2 200.0
5 3 200.0
6 4 200.0
7 5 200.0
8 6 200.0
9 7 200.0
10 8 200.0
11 9 200.0
12 10 200.0
13 11 200.0
14 12 200.0
15 13 200.0
16 14 200.0
17 15 200.0
18 16 200.0
19 17 200.0
20 18 200.0
21 19 200.0

Binary file not shown.

Before

Width:  |  Height:  |  Size: 38 KiB

View File

@@ -1,201 +0,0 @@
episodes,rewards
0,38.0
1,16.0
2,37.0
3,15.0
4,22.0
5,34.0
6,20.0
7,12.0
8,16.0
9,14.0
10,13.0
11,21.0
12,14.0
13,12.0
14,17.0
15,12.0
16,10.0
17,14.0
18,10.0
19,10.0
20,16.0
21,9.0
22,14.0
23,13.0
24,10.0
25,9.0
26,12.0
27,12.0
28,14.0
29,11.0
30,9.0
31,8.0
32,9.0
33,11.0
34,12.0
35,10.0
36,11.0
37,10.0
38,10.0
39,18.0
40,13.0
41,15.0
42,10.0
43,9.0
44,14.0
45,14.0
46,23.0
47,17.0
48,15.0
49,15.0
50,20.0
51,28.0
52,36.0
53,36.0
54,23.0
55,27.0
56,53.0
57,19.0
58,35.0
59,62.0
60,57.0
61,38.0
62,61.0
63,65.0
64,58.0
65,43.0
66,67.0
67,56.0
68,91.0
69,128.0
70,71.0
71,126.0
72,100.0
73,200.0
74,200.0
75,200.0
76,200.0
77,200.0
78,200.0
79,200.0
80,200.0
81,200.0
82,200.0
83,200.0
84,200.0
85,200.0
86,200.0
87,200.0
88,200.0
89,200.0
90,200.0
91,200.0
92,200.0
93,200.0
94,200.0
95,200.0
96,200.0
97,200.0
98,200.0
99,200.0
100,200.0
101,200.0
102,200.0
103,200.0
104,200.0
105,200.0
106,200.0
107,200.0
108,200.0
109,200.0
110,200.0
111,200.0
112,200.0
113,200.0
114,200.0
115,200.0
116,200.0
117,200.0
118,200.0
119,200.0
120,200.0
121,200.0
122,200.0
123,200.0
124,200.0
125,200.0
126,200.0
127,200.0
128,200.0
129,200.0
130,200.0
131,200.0
132,200.0
133,200.0
134,200.0
135,200.0
136,200.0
137,200.0
138,200.0
139,200.0
140,200.0
141,200.0
142,200.0
143,200.0
144,200.0
145,200.0
146,200.0
147,200.0
148,200.0
149,200.0
150,200.0
151,200.0
152,200.0
153,200.0
154,200.0
155,200.0
156,200.0
157,200.0
158,200.0
159,200.0
160,200.0
161,200.0
162,200.0
163,200.0
164,200.0
165,200.0
166,200.0
167,200.0
168,200.0
169,200.0
170,200.0
171,200.0
172,200.0
173,200.0
174,200.0
175,200.0
176,200.0
177,200.0
178,200.0
179,200.0
180,200.0
181,200.0
182,200.0
183,200.0
184,200.0
185,200.0
186,200.0
187,200.0
188,200.0
189,200.0
190,200.0
191,200.0
192,200.0
193,200.0
194,200.0
195,200.0
196,200.0
197,200.0
198,200.0
199,200.0
1 episodes rewards
2 0 38.0
3 1 16.0
4 2 37.0
5 3 15.0
6 4 22.0
7 5 34.0
8 6 20.0
9 7 12.0
10 8 16.0
11 9 14.0
12 10 13.0
13 11 21.0
14 12 14.0
15 13 12.0
16 14 17.0
17 15 12.0
18 16 10.0
19 17 14.0
20 18 10.0
21 19 10.0
22 20 16.0
23 21 9.0
24 22 14.0
25 23 13.0
26 24 10.0
27 25 9.0
28 26 12.0
29 27 12.0
30 28 14.0
31 29 11.0
32 30 9.0
33 31 8.0
34 32 9.0
35 33 11.0
36 34 12.0
37 35 10.0
38 36 11.0
39 37 10.0
40 38 10.0
41 39 18.0
42 40 13.0
43 41 15.0
44 42 10.0
45 43 9.0
46 44 14.0
47 45 14.0
48 46 23.0
49 47 17.0
50 48 15.0
51 49 15.0
52 50 20.0
53 51 28.0
54 52 36.0
55 53 36.0
56 54 23.0
57 55 27.0
58 56 53.0
59 57 19.0
60 58 35.0
61 59 62.0
62 60 57.0
63 61 38.0
64 62 61.0
65 63 65.0
66 64 58.0
67 65 43.0
68 66 67.0
69 67 56.0
70 68 91.0
71 69 128.0
72 70 71.0
73 71 126.0
74 72 100.0
75 73 200.0
76 74 200.0
77 75 200.0
78 76 200.0
79 77 200.0
80 78 200.0
81 79 200.0
82 80 200.0
83 81 200.0
84 82 200.0
85 83 200.0
86 84 200.0
87 85 200.0
88 86 200.0
89 87 200.0
90 88 200.0
91 89 200.0
92 90 200.0
93 91 200.0
94 92 200.0
95 93 200.0
96 94 200.0
97 95 200.0
98 96 200.0
99 97 200.0
100 98 200.0
101 99 200.0
102 100 200.0
103 101 200.0
104 102 200.0
105 103 200.0
106 104 200.0
107 105 200.0
108 106 200.0
109 107 200.0
110 108 200.0
111 109 200.0
112 110 200.0
113 111 200.0
114 112 200.0
115 113 200.0
116 114 200.0
117 115 200.0
118 116 200.0
119 117 200.0
120 118 200.0
121 119 200.0
122 120 200.0
123 121 200.0
124 122 200.0
125 123 200.0
126 124 200.0
127 125 200.0
128 126 200.0
129 127 200.0
130 128 200.0
131 129 200.0
132 130 200.0
133 131 200.0
134 132 200.0
135 133 200.0
136 134 200.0
137 135 200.0
138 136 200.0
139 137 200.0
140 138 200.0
141 139 200.0
142 140 200.0
143 141 200.0
144 142 200.0
145 143 200.0
146 144 200.0
147 145 200.0
148 146 200.0
149 147 200.0
150 148 200.0
151 149 200.0
152 150 200.0
153 151 200.0
154 152 200.0
155 153 200.0
156 154 200.0
157 155 200.0
158 156 200.0
159 157 200.0
160 158 200.0
161 159 200.0
162 160 200.0
163 161 200.0
164 162 200.0
165 163 200.0
166 164 200.0
167 165 200.0
168 166 200.0
169 167 200.0
170 168 200.0
171 169 200.0
172 170 200.0
173 171 200.0
174 172 200.0
175 173 200.0
176 174 200.0
177 175 200.0
178 176 200.0
179 177 200.0
180 178 200.0
181 179 200.0
182 180 200.0
183 181 200.0
184 182 200.0
185 183 200.0
186 184 200.0
187 185 200.0
188 186 200.0
189 187 200.0
190 188 200.0
191 189 200.0
192 190 200.0
193 191 200.0
194 192 200.0
195 193 200.0
196 194 200.0
197 195 200.0
198 196 200.0
199 197 200.0
200 198 200.0
201 199 200.0

View File

@@ -1,133 +0,0 @@
import sys,os
curr_path = os.path.dirname(os.path.abspath(__file__)) # current path
parent_path = os.path.dirname(curr_path) # parent path
sys.path.append(parent_path) # add path to system path
import gym
import torch
import datetime
import numpy as np
import argparse
from common.utils import save_results,all_seed
from common.utils import plot_rewards,save_args
from common.models import MLP
from common.memories import ReplayBuffer
from dqn import DQN
def get_args():
""" hyperparameters
"""
curr_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") # obtain current time
parser = argparse.ArgumentParser(description="hyperparameters")
parser.add_argument('--algo_name',default='DQN',type=str,help="name of algorithm")
parser.add_argument('--env_name',default='CartPole-v0',type=str,help="name of environment")
parser.add_argument('--train_eps',default=200,type=int,help="episodes of training")
parser.add_argument('--test_eps',default=20,type=int,help="episodes of testing")
parser.add_argument('--gamma',default=0.95,type=float,help="discounted factor")
parser.add_argument('--epsilon_start',default=0.95,type=float,help="initial value of epsilon")
parser.add_argument('--epsilon_end',default=0.01,type=float,help="final value of epsilon")
parser.add_argument('--epsilon_decay',default=500,type=int,help="decay rate of epsilon")
parser.add_argument('--lr',default=0.0001,type=float,help="learning rate")
parser.add_argument('--memory_capacity',default=100000,type=int,help="memory capacity")
parser.add_argument('--batch_size',default=64,type=int)
parser.add_argument('--target_update',default=4,type=int)
parser.add_argument('--hidden_dim',default=256,type=int)
parser.add_argument('--device',default='cpu',type=str,help="cpu or cuda")
parser.add_argument('--seed',default=10,type=int,help="seed")
parser.add_argument('--result_path',default=curr_path + "/outputs/" + parser.parse_args().env_name + \
'/' + curr_time + '/results' )
parser.add_argument('--model_path',default=curr_path + "/outputs/" + parser.parse_args().env_name + \
'/' + curr_time + '/models' )
parser.add_argument('--show_fig',default=False,type=bool,help="if show figure or not")
parser.add_argument('--save_fig',default=True,type=bool,help="if save figure or not")
args = parser.parse_args()
return args
def env_agent_config(cfg):
''' create env and agent
'''
env = gym.make(cfg.env_name) # create env
if cfg.seed !=0: # set random seed
all_seed(env,seed=cfg.seed)
n_states = env.observation_space.shape[0] # state dimension
n_actions = env.action_space.n # action dimension
print(f"state dim: {n_states}, action dim: {n_actions}")
model = MLP(n_states,n_actions,hidden_dim=cfg.hidden_dim)
memory = ReplayBuffer(cfg.memory_capacity) # replay buffer
agent = DQN(n_actions,model,memory,cfg) # create agent
return env, agent
def train(cfg, env, agent):
''' 训练
'''
print("start training!")
print(f"Env: {cfg.env_name}, Algo: {cfg.algo_name}, Device: {cfg.device}")
rewards = [] # record rewards for all episodes
steps = []
for i_ep in range(cfg.train_eps):
ep_reward = 0 # reward per episode
ep_step = 0
state = env.reset() # reset and obtain initial state
while True:
ep_step += 1
action = agent.sample_action(state) # sample action
next_state, reward, done, _ = env.step(action) # update env and return transitions
agent.memory.push(state, action, reward,
next_state, done) # save transitions
state = next_state # update next state for env
agent.update() # update agent
ep_reward += reward #
if done:
break
if (i_ep + 1) % cfg.target_update == 0: # target net update, target_update means "C" in pseucodes
agent.target_net.load_state_dict(agent.policy_net.state_dict())
steps.append(ep_step)
rewards.append(ep_reward)
if (i_ep + 1) % 10 == 0:
print(f'Episode: {i_ep+1}/{cfg.train_eps}, Reward: {ep_reward:.2f}: Epislon: {agent.epsilon:.3f}')
print("finish training!")
env.close()
res_dic = {'episodes':range(len(rewards)),'rewards':rewards}
return res_dic
def test(cfg, env, agent):
print("start testing!")
print(f"Env: {cfg.env_name}, Algo: {cfg.algo_name}, Device: {cfg.device}")
rewards = [] # record rewards for all episodes
steps = []
for i_ep in range(cfg.test_eps):
ep_reward = 0 # reward per episode
ep_step = 0
state = env.reset() # reset and obtain initial state
while True:
ep_step+=1
action = agent.predict_action(state) # predict action
next_state, reward, done, _ = env.step(action)
state = next_state
ep_reward += reward
if done:
break
steps.append(ep_step)
rewards.append(ep_reward)
print(f'Episode: {i_ep+1}/{cfg.test_eps}Reward: {ep_reward:.2f}')
print("finish testing!")
env.close()
return {'episodes':range(len(rewards)),'rewards':rewards}
if __name__ == "__main__":
cfg = get_args()
# training
env, agent = env_agent_config(cfg)
res_dic = train(cfg, env, agent)
save_args(cfg,path = cfg.result_path) # save parameters
agent.save_model(path = cfg.model_path) # save models
save_results(res_dic, tag = 'train', path = cfg.result_path) # save results
plot_rewards(res_dic['rewards'], cfg, path = cfg.result_path,tag = "train") # plot results
# testing
env, agent = env_agent_config(cfg) # create new env for testing, sometimes can ignore this step
agent.load_model(path = cfg.model_path) # load model
res_dic = test(cfg, env, agent)
save_results(res_dic, tag='test',
path = cfg.result_path)
plot_rewards(res_dic['rewards'], cfg, path = cfg.result_path,tag = "test")

View File

@@ -1,15 +0,0 @@
{
"algo_name": "Q-learning",
"env_name": "CliffWalking-v0",
"train_eps": 400,
"test_eps": 20,
"gamma": 0.9,
"epsilon_start": 0.95,
"epsilon_end": 0.01,
"epsilon_decay": 300,
"lr": 0.1,
"device": "cpu",
"result_path": "/root/Desktop/rl-tutorials/codes/QLearning/outputs/CliffWalking-v0/20220802-163256/results/",
"model_path": "/root/Desktop/rl-tutorials/codes/QLearning/outputs/CliffWalking-v0/20220802-163256/models/",
"save_fig": true
}

Binary file not shown.

Before

Width:  |  Height:  |  Size: 25 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 36 KiB

View File

@@ -1,127 +0,0 @@
#!/usr/bin/env python
# coding=utf-8
'''
Author: John
Email: johnjim0816@gmail.com
Date: 2020-09-11 23:03:00
LastEditor: John
LastEditTime: 2022-08-10 11:25:56
Discription:
Environment:
'''
import sys
import os
curr_path = os.path.dirname(os.path.abspath(__file__)) # 当前文件所在绝对路径
parent_path = os.path.dirname(curr_path) # 父路径
sys.path.append(parent_path) # 添加路径到系统路径
import gym
import torch
import datetime
import argparse
from envs.gridworld_env import CliffWalkingWapper
from qlearning import QLearning
from common.utils import plot_rewards,save_args
from common.utils import save_results,make_dir
def get_args():
"""
"""
curr_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") # 获取当前时间
parser = argparse.ArgumentParser(description="hyperparameters")
parser.add_argument('--algo_name',default='Q-learning',type=str,help="name of algorithm")
parser.add_argument('--env_name',default='CliffWalking-v0',type=str,help="name of environment")
parser.add_argument('--train_eps',default=400,type=int,help="episodes of training") # 训练的回合数
parser.add_argument('--test_eps',default=20,type=int,help="episodes of testing") # 测试的回合数
parser.add_argument('--gamma',default=0.90,type=float,help="discounted factor") # 折扣因子
parser.add_argument('--epsilon_start',default=0.95,type=float,help="initial value of epsilon") # e-greedy策略中初始epsilon
parser.add_argument('--epsilon_end',default=0.01,type=float,help="final value of epsilon") # e-greedy策略中的终止epsilon
parser.add_argument('--epsilon_decay',default=300,type=int,help="decay rate of epsilon") # e-greedy策略中epsilon的衰减率
parser.add_argument('--lr',default=0.1,type=float,help="learning rate")
parser.add_argument('--device',default='cpu',type=str,help="cpu or cuda")
parser.add_argument('--result_path',default=curr_path + "/outputs/" + parser.parse_args().env_name + \
'/' + curr_time + '/results/',type=str )
parser.add_argument('--model_path',default=curr_path + "/outputs/" + parser.parse_args().env_name + \
'/' + curr_time + '/models/',type=str,help="path to save models")
parser.add_argument('--save_fig',default=True,type=bool,help="if save figure or not")
args = parser.parse_args()
return args
curr_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") # 获取当前时间
def train(cfg,env,agent):
print('开始训练!')
print(f'环境:{cfg.env_name}, 算法:{cfg.algo_name}, 设备:{cfg.device}')
rewards = [] # 记录奖励
for i_ep in range(cfg.train_eps):
ep_reward = 0 # 记录每个回合的奖励
state = env.reset() # 重置环境,即开始新的回合
while True:
action = agent.sample(state) # 根据算法采样一个动作
next_state, reward, done, _ = env.step(action) # 与环境进行一次动作交互
agent.update(state, action, reward, next_state, done) # Q学习算法更新
state = next_state # 更新状态
ep_reward += reward
if done:
break
rewards.append(ep_reward)
print(f"回合:{i_ep+1}/{cfg.train_eps},奖励:{ep_reward:.1f}Epsilon{agent.epsilon}")
print('完成训练!')
return {"rewards":rewards}
def test(cfg,env,agent):
print('开始测试!')
print(f'环境:{cfg.env_name}, 算法:{cfg.algo_name}, 设备:{cfg.device}')
rewards = [] # 记录所有回合的奖励
for i_ep in range(cfg.test_eps):
ep_reward = 0 # 记录每个episode的reward
state = env.reset() # 重置环境, 重新开一局(即开始新的一个回合)
while True:
action = agent.predict(state) # 根据算法选择一个动作
next_state, reward, done, _ = env.step(action) # 与环境进行一个交互
state = next_state # 更新状态
ep_reward += reward
if done:
break
rewards.append(ep_reward)
print(f"回合数:{i_ep+1}/{cfg.test_eps}, 奖励:{ep_reward:.1f}")
print('完成测试!')
return {"rewards":rewards}
def env_agent_config(cfg,seed=1):
'''创建环境和智能体
Args:
cfg ([type]): [description]
seed (int, optional): 随机种子. Defaults to 1.
Returns:
env [type]: 环境
agent : 智能体
'''
env = gym.make(cfg.env_name)
env = CliffWalkingWapper(env)
env.seed(seed) # 设置随机种子
n_states = env.observation_space.n # 状态维度
n_actions = env.action_space.n # 动作维度
print(f"状态数:{n_states},动作数:{n_actions}")
agent = QLearning(n_actions,cfg)
return env,agent
if __name__ == "__main__":
cfg = get_args()
# 训练
env, agent = env_agent_config(cfg)
res_dic = train(cfg, env, agent)
make_dir(cfg.result_path, cfg.model_path)
save_args(cfg) # save parameters
agent.save(path=cfg.model_path) # save model
save_results(res_dic, tag='train',
path=cfg.result_path)
plot_rewards(res_dic['rewards'], cfg, tag="train")
# 测试
env, agent = env_agent_config(cfg)
agent.load(path=cfg.model_path) # 导入模型
res_dic = test(cfg, env, agent)
save_results(res_dic, tag='test',
path=cfg.result_path) # 保存结果
plot_rewards(res_dic['rewards'], cfg, tag="test") # 画出结果

View File

@@ -1,4 +0,0 @@
class SAC:
def __init__(self,n_actions,model,memory,cfg):
pass

View File

@@ -0,0 +1,15 @@
# run DQN on Acrobot-v1, not the best tuned parameters
# source conda, if you are already in proper conda environment, then comment the codes util "conda activate easyrl"
if [ -f "$HOME/anaconda3/etc/profile.d/conda.sh" ]; then
echo "source file at ~/anaconda3/etc/profile.d/conda.sh"
source ~/anaconda3/etc/profile.d/conda.sh
elif [ -f "$HOME/opt/anaconda3/etc/profile.d/conda.sh" ]; then
echo "source file at ~/opt/anaconda3/etc/profile.d/conda.sh"
source ~/opt/anaconda3/etc/profile.d/conda.sh
else
echo 'please manually config the conda source path'
fi
conda activate easyrl # easyrl here can be changed to another name of conda env that you have created
codes_dir=$(dirname $(dirname $(readlink -f "$0"))) # "codes" path
python $codes_dir/DQN/main.py --env_name Acrobot-v1 --train_eps 100 --epsilon_decay 1500 --lr 0.002 --memory_capacity 200000 --batch_size 128 --device cuda