hot update
Binary file not shown.
@@ -0,0 +1 @@
{"algo_name": "SoftQ", "env_name": "CartPole-v0", "train_eps": 200, "test_eps": 20, "max_steps": 200, "gamma": 0.99, "alpha": 4, "lr": 0.0001, "memory_capacity": 50000, "batch_size": 128, "target_update": 2, "device": "cpu", "seed": 10, "result_path": "/Users/jj/Desktop/rl-tutorials/codes/SoftQ/outputs/CartPole-v0/20220818-154333/results/", "model_path": "/Users/jj/Desktop/rl-tutorials/codes/SoftQ/outputs/CartPole-v0/20220818-154333/models/", "show_fig": false, "save_fig": true}
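For reference only (not part of the commit): this JSON is the argument dump written by save_args under result_path. Assuming a plain-JSON file with the name used below (this diff does not show the actual file name), it can be loaded back into an argparse.Namespace to reuse a run's settings:

import json
import argparse

# hypothetical file name/path; adjust to whatever save_args wrote under result_path
with open("outputs/CartPole-v0/20220818-154333/results/params.json") as f:
    cfg = argparse.Namespace(**json.load(f))

print(cfg.algo_name, cfg.env_name, cfg.alpha, cfg.lr)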
Binary file not shown.
After Width: | Height: | Size: 31 KiB
@@ -0,0 +1,21 @@
episodes,rewards
0,200.0
1,200.0
2,200.0
3,200.0
4,200.0
5,200.0
6,200.0
7,200.0
8,199.0
9,200.0
10,200.0
11,200.0
12,200.0
13,200.0
14,200.0
15,200.0
16,200.0
17,200.0
18,200.0
19,200.0
Binary file not shown.
After Width: | Height: | Size: 62 KiB
@@ -0,0 +1,201 @@
episodes,rewards
0,21.0
1,23.0
2,24.0
3,27.0
4,33.0
5,18.0
6,47.0
7,18.0
8,18.0
9,21.0
10,26.0
11,31.0
12,11.0
13,17.0
14,22.0
15,16.0
16,17.0
17,34.0
18,20.0
19,11.0
20,50.0
21,15.0
22,11.0
23,39.0
24,11.0
25,28.0
26,37.0
27,26.0
28,63.0
29,18.0
30,17.0
31,13.0
32,9.0
33,15.0
34,13.0
35,21.0
36,17.0
37,22.0
38,20.0
39,31.0
40,9.0
41,10.0
42,11.0
43,15.0
44,18.0
45,10.0
46,30.0
47,14.0
48,36.0
49,26.0
50,21.0
51,15.0
52,9.0
53,14.0
54,10.0
55,27.0
56,14.0
57,15.0
58,22.0
59,12.0
60,20.0
61,10.0
62,12.0
63,29.0
64,11.0
65,13.0
66,27.0
67,50.0
68,29.0
69,40.0
70,29.0
71,18.0
72,27.0
73,11.0
74,15.0
75,10.0
76,13.0
77,11.0
78,17.0
79,13.0
80,18.0
81,24.0
82,15.0
83,34.0
84,11.0
85,35.0
86,26.0
87,9.0
88,19.0
89,19.0
90,16.0
91,25.0
92,18.0
93,37.0
94,46.0
95,88.0
96,26.0
97,55.0
98,43.0
99,141.0
100,89.0
101,151.0
102,47.0
103,56.0
104,64.0
105,56.0
106,49.0
107,87.0
108,58.0
109,55.0
110,57.0
111,165.0
112,31.0
113,200.0
114,57.0
115,107.0
116,46.0
117,45.0
118,64.0
119,69.0
120,67.0
121,65.0
122,47.0
123,63.0
124,134.0
125,60.0
126,89.0
127,99.0
128,51.0
129,109.0
130,131.0
131,156.0
132,118.0
133,185.0
134,86.0
135,149.0
136,138.0
137,143.0
138,114.0
139,130.0
140,139.0
141,106.0
142,135.0
143,164.0
144,156.0
145,155.0
146,200.0
147,186.0
148,64.0
149,200.0
150,135.0
151,135.0
152,168.0
153,200.0
154,200.0
155,200.0
156,167.0
157,198.0
158,188.0
159,200.0
160,200.0
161,200.0
162,200.0
163,200.0
164,200.0
165,200.0
166,200.0
167,200.0
168,189.0
169,200.0
170,146.0
171,200.0
172,200.0
173,200.0
174,115.0
175,170.0
176,200.0
177,200.0
178,178.0
179,200.0
180,200.0
181,200.0
182,200.0
183,200.0
184,200.0
185,200.0
186,120.0
187,200.0
188,200.0
189,200.0
190,200.0
191,200.0
192,200.0
193,200.0
194,200.0
195,200.0
196,200.0
197,200.0
198,200.0
199,200.0
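For reference only (not part of the commit): the per-episode training rewards above can be re-plotted without the repo's plot_rewards helper, e.g. with pandas and matplotlib. The file path below is an assumption based on result_path in the saved config; adjust it to wherever save_results actually wrote the CSV.

import pandas as pd
import matplotlib.pyplot as plt

# hypothetical file name/path; use whatever save_results wrote under cfg.result_path
df = pd.read_csv("outputs/CartPole-v0/20220818-154333/results/train_rewards.csv")
smoothed = df["rewards"].rolling(window=10, min_periods=1).mean()  # 10-episode moving average

plt.plot(df["episodes"], df["rewards"], alpha=0.4, label="raw")
plt.plot(df["episodes"], smoothed, label="smoothed (10-ep mean)")
plt.xlabel("episode")
plt.ylabel("reward")
plt.legend()
plt.show()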
projects/codes/SoftQ/softq.py (new file, +71 lines)
@@ -0,0 +1,71 @@
import copy
import torch
import torch.nn.functional as F
from torch.distributions import Categorical
import numpy as np


class SoftQ:
    def __init__(self, n_actions, model, memory, cfg):
        self.memory = memory
        self.alpha = cfg.alpha  # temperature of the entropy regularization
        self.gamma = cfg.gamma  # discount factor
        self.batch_size = cfg.batch_size
        self.device = torch.device(cfg.device)
        self.policy_net = model.to(self.device)
        # deep-copy the model so the target net has its own parameters
        # (reusing the same instance would make the target update a no-op)
        self.target_net = copy.deepcopy(model).to(self.device)
        self.target_net.load_state_dict(self.policy_net.state_dict())  # copy parameters
        self.optimizer = torch.optim.Adam(self.policy_net.parameters(), lr=cfg.lr)
        self.losses = []  # save losses

    def sample_action(self, state):
        ''' sample an action from the soft (Boltzmann) policy during training '''
        state = torch.FloatTensor(state).unsqueeze(0).to(self.device)
        with torch.no_grad():
            q = self.policy_net(state)
            # soft value: V(s) = alpha * log sum_a exp(Q(s, a) / alpha)
            v = self.alpha * torch.log(torch.sum(torch.exp(q / self.alpha), dim=1, keepdim=True)).squeeze()
            dist = torch.exp((q - v) / self.alpha)  # pi(a|s) proportional to exp((Q - V) / alpha)
            dist = dist / torch.sum(dist)  # renormalize against numerical error
            a = Categorical(dist).sample()
        return a.item()

    def predict_action(self, state):
        ''' same soft policy, used at test time '''
        state = torch.tensor(np.array(state), device=self.device, dtype=torch.float).unsqueeze(0)
        with torch.no_grad():
            q = self.policy_net(state)
            v = self.alpha * torch.log(torch.sum(torch.exp(q / self.alpha), dim=1, keepdim=True)).squeeze()
            dist = torch.exp((q - v) / self.alpha)
            dist = dist / torch.sum(dist)
            a = Categorical(dist).sample()
        return a.item()

    def update(self):
        if len(self.memory) < self.batch_size:  # do not update until the buffer holds at least one batch
            return
        state_batch, action_batch, reward_batch, next_state_batch, done_batch = self.memory.sample(self.batch_size)
        state_batch = torch.tensor(np.array(state_batch), device=self.device, dtype=torch.float)  # shape (batch_size, n_states)
        action_batch = torch.tensor(np.array(action_batch), device=self.device, dtype=torch.float).unsqueeze(1)  # shape (batch_size, 1)
        reward_batch = torch.tensor(np.array(reward_batch), device=self.device, dtype=torch.float).unsqueeze(1)  # shape (batch_size, 1)
        next_state_batch = torch.tensor(np.array(next_state_batch), device=self.device, dtype=torch.float)  # shape (batch_size, n_states)
        done_batch = torch.tensor(np.array(done_batch), device=self.device, dtype=torch.float).unsqueeze(1)  # shape (batch_size, 1)
        with torch.no_grad():
            next_q = self.target_net(next_state_batch)
            # soft value of the next state: V(s') = alpha * log sum_a exp(Q_target(s', a) / alpha)
            next_v = self.alpha * torch.log(torch.sum(torch.exp(next_q / self.alpha), dim=1, keepdim=True))
            y = reward_batch + (1 - done_batch) * self.gamma * next_v  # soft Q-learning target
        loss = F.mse_loss(self.policy_net(state_batch).gather(1, action_batch.long()), y)
        self.losses.append(loss.item())  # store the scalar, not the graph-holding tensor
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def save_model(self, path):
        from pathlib import Path
        Path(path).mkdir(parents=True, exist_ok=True)  # create the directory if it does not exist
        torch.save(self.target_net.state_dict(), path + 'checkpoint.pth')

    def load_model(self, path):
        self.target_net.load_state_dict(torch.load(path + 'checkpoint.pth'))
        for target_param, param in zip(self.target_net.parameters(), self.policy_net.parameters()):
            param.data.copy_(target_param.data)  # copy the loaded weights into the policy net as well
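For reference: sample_action and update above implement the soft value V(s) = alpha * log sum_a exp(Q(s, a) / alpha), the Boltzmann policy pi(a|s) proportional to exp((Q(s, a) - V(s)) / alpha), and the soft target y = r + gamma * (1 - done) * V(s'). A minimal sanity-check sketch (not part of the commit) showing that torch.logsumexp and torch.softmax compute the same quantities as the explicit exp/log/sum form, with better numerical stability:

import torch

alpha = 4.0
q = torch.randn(1, 2) * 10  # fake Q-values for a 2-action problem such as CartPole

# explicit form, as written in softq.py
v = alpha * torch.log(torch.sum(torch.exp(q / alpha), dim=1, keepdim=True))
dist = torch.exp((q - v) / alpha)
dist = dist / dist.sum()

# numerically stable equivalents
v_stable = alpha * torch.logsumexp(q / alpha, dim=1, keepdim=True)
dist_stable = torch.softmax(q / alpha, dim=1)

print(torch.allclose(v, v_stable), torch.allclose(dist, dist_stable))  # expected: True True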
projects/codes/SoftQ/task0.py (new file, +142 lines)
@@ -0,0 +1,142 @@
import sys, os
curr_path = os.path.dirname(os.path.abspath(__file__))  # current path
parent_path = os.path.dirname(curr_path)  # parent path
sys.path.append(parent_path)  # add parent path to the system path

import argparse
import datetime
import gym
import torch
import random
import numpy as np
import torch.nn as nn
from common.memories import ReplayBufferQue
from common.models import MLP
from common.utils import save_results, all_seed, plot_rewards, save_args
from softq import SoftQ


def get_args():
    """ hyperparameters """
    curr_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")  # obtain current time
    parser = argparse.ArgumentParser(description="hyperparameters")
    parser.add_argument('--algo_name', default='SoftQ', type=str, help="name of algorithm")
    parser.add_argument('--env_name', default='CartPole-v0', type=str, help="name of environment")
    parser.add_argument('--train_eps', default=200, type=int, help="episodes of training")
    parser.add_argument('--test_eps', default=20, type=int, help="episodes of testing")
    parser.add_argument('--max_steps', default=200, type=int, help="maximum steps per episode")
    parser.add_argument('--gamma', default=0.99, type=float, help="discount factor")
    parser.add_argument('--alpha', default=4, type=float, help="temperature alpha")
    parser.add_argument('--lr', default=0.0001, type=float, help="learning rate")
    parser.add_argument('--memory_capacity', default=50000, type=int, help="memory capacity")
    parser.add_argument('--batch_size', default=128, type=int)
    parser.add_argument('--target_update', default=2, type=int)
    parser.add_argument('--device', default='cpu', type=str, help="cpu or cuda")
    parser.add_argument('--seed', default=10, type=int, help="random seed")
    # parse_args() is called here only to read env_name for the default output paths
    parser.add_argument('--result_path', default=curr_path + "/outputs/" + parser.parse_args().env_name + \
                        '/' + curr_time + '/results/')
    parser.add_argument('--model_path', default=curr_path + "/outputs/" + parser.parse_args().env_name + \
                        '/' + curr_time + '/models/')
    # note: argparse's type=bool treats any non-empty string as True; only the defaults matter here
    parser.add_argument('--show_fig', default=False, type=bool, help="show figure or not")
    parser.add_argument('--save_fig', default=True, type=bool, help="save figure or not")
    args = parser.parse_args()
    return args


class SoftQNetwork(nn.Module):
    ''' essentially the same as common.models.MLP '''
    def __init__(self, input_dim, output_dim):
        super(SoftQNetwork, self).__init__()
        self.fc1 = nn.Linear(input_dim, 64)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(64, 256)
        self.fc3 = nn.Linear(256, output_dim)

    def forward(self, x):
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)
        return x


def env_agent_config(cfg):
    ''' create env and agent '''
    env = gym.make(cfg.env_name)  # create env
    if cfg.seed != 0:  # set random seed
        all_seed(env, seed=cfg.seed)
    n_states = env.observation_space.shape[0]  # state dimension
    n_actions = env.action_space.n  # action dimension
    print(f"state dim: {n_states}, action dim: {n_actions}")
    # model = MLP(n_states, n_actions)
    model = SoftQNetwork(n_states, n_actions)
    memory = ReplayBufferQue(cfg.memory_capacity)  # replay buffer
    agent = SoftQ(n_actions, model, memory, cfg)  # create agent
    return env, agent


def train(cfg, env, agent):
    ''' training '''
    print("start training!")
    print(f"Env: {cfg.env_name}, Algo: {cfg.algo_name}, Device: {cfg.device}")
    rewards = []  # record rewards for all episodes
    steps = []  # record steps for all episodes, sometimes needed
    for i_ep in range(cfg.train_eps):
        ep_reward = 0  # reward per episode
        ep_step = 0
        state = env.reset()  # reset and obtain the initial state
        while True:
            # for _ in range(cfg.max_steps):
            ep_step += 1
            action = agent.sample_action(state)  # sample action
            next_state, reward, done, _ = env.step(action)  # step the env and return the transition
            agent.memory.push((state, action, reward, next_state, done))  # save the transition
            state = next_state  # update the current state
            agent.update()  # update agent
            ep_reward += reward
            if done:
                break
        if (i_ep + 1) % cfg.target_update == 0:  # target net update; target_update is "C" in the pseudocode
            agent.target_net.load_state_dict(agent.policy_net.state_dict())
        steps.append(ep_step)
        rewards.append(ep_reward)
        if (i_ep + 1) % 10 == 0:
            print(f'Episode: {i_ep+1}/{cfg.train_eps}, Reward: {ep_reward:.2f}')
    print("finish training!")
    res_dic = {'episodes': range(len(rewards)), 'rewards': rewards}
    return res_dic


def test(cfg, env, agent):
    print("start testing!")
    print(f"Env: {cfg.env_name}, Algo: {cfg.algo_name}, Device: {cfg.device}")
    rewards = []  # record rewards for all episodes
    for i_ep in range(cfg.test_eps):
        ep_reward = 0  # reward per episode
        state = env.reset()  # reset and obtain the initial state
        while True:
            action = agent.predict_action(state)  # predict action
            next_state, reward, done, _ = env.step(action)
            state = next_state
            ep_reward += reward
            if done:
                break
        rewards.append(ep_reward)
        print(f'Episode: {i_ep+1}/{cfg.test_eps}, Reward: {ep_reward:.2f}')
    print("finish testing!")
    env.close()
    return {'episodes': range(len(rewards)), 'rewards': rewards}


if __name__ == "__main__":
    cfg = get_args()
    # training
    env, agent = env_agent_config(cfg)
    res_dic = train(cfg, env, agent)
    save_args(cfg, path=cfg.result_path)  # save parameters to the result path
    agent.save_model(path=cfg.model_path)  # save the model
    save_results(res_dic, tag='train', path=cfg.result_path)  # save results
    plot_rewards(res_dic['rewards'], cfg, path=cfg.result_path, tag="train")
    # testing
    env, agent = env_agent_config(cfg)  # recreate the env for testing to avoid any state left over from training
    agent.load_model(path=cfg.model_path)  # load the model
    res_dic = test(cfg, env, agent)
    save_results(res_dic, tag='test', path=cfg.result_path)  # save results
    plot_rewards(res_dic['rewards'], cfg, path=cfg.result_path, tag="test")  # plot results
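Note that common.memories.ReplayBufferQue and the other common.* helpers imported by task0.py are not part of this diff. For readers without the rest of the repository, a hypothetical stand-in with the interface that task0.py and softq.py rely on (push a transition tuple, sample a batch, len) might look like the sketch below; the real implementation may differ.

import random
from collections import deque

class ReplayBufferQue:
    ''' hypothetical stand-in for common.memories.ReplayBufferQue (interface only) '''
    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)

    def push(self, transition):
        # transition = (state, action, reward, next_state, done)
        self.buffer.append(transition)

    def sample(self, batch_size):
        batch = random.sample(self.buffer, batch_size)
        return zip(*batch)  # -> (states, actions, rewards, next_states, dones)

    def __len__(self):
        return len(self.buffer)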