hot update PG
This commit is contained in:
136
projects/codes/Sarsa/main.py
Normal file
136
projects/codes/Sarsa/main.py
Normal file
@@ -0,0 +1,136 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
'''
|
||||
Author: John
|
||||
Email: johnjim0816@gmail.com
|
||||
Date: 2021-03-11 17:59:16
|
||||
LastEditor: John
|
||||
LastEditTime: 2022-08-25 14:26:36
|
||||
Discription:
|
||||
Environment:
|
||||
'''
|
||||
import sys,os
|
||||
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE" # avoid "OMP: Error #15: Initializing libiomp5md.dll, but found libiomp5md.dll already initialized."
|
||||
curr_path = os.path.dirname(os.path.abspath(__file__)) # current path
|
||||
parent_path = os.path.dirname(curr_path) # parent path
|
||||
sys.path.append(parent_path) # add path to system path
|
||||
import gym
|
||||
import datetime
|
||||
import argparse
|
||||
from envs.register import register_env
|
||||
from envs.wrappers import CliffWalkingWapper
|
||||
from Sarsa.sarsa import Sarsa
|
||||
from common.utils import save_results,make_dir,plot_rewards,save_args,all_seed
|
||||
|
||||
def get_args():
|
||||
curr_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") # obtain current time
|
||||
parser = argparse.ArgumentParser(description="hyperparameters")
|
||||
parser.add_argument('--algo_name',default='Sarsa',type=str,help="name of algorithm")
|
||||
parser.add_argument('--env_name',default='Racetrack-v0',type=str,help="name of environment")
|
||||
parser.add_argument('--train_eps',default=300,type=int,help="episodes of training")
|
||||
parser.add_argument('--test_eps',default=20,type=int,help="episodes of testing")
|
||||
parser.add_argument('--gamma',default=0.99,type=float,help="discounted factor")
|
||||
parser.add_argument('--epsilon_start',default=0.90,type=float,help="initial value of epsilon")
|
||||
parser.add_argument('--epsilon_end',default=0.01,type=float,help="final value of epsilon")
|
||||
parser.add_argument('--epsilon_decay',default=200,type=int,help="decay rate of epsilon")
|
||||
parser.add_argument('--lr',default=0.2,type=float,help="learning rate")
|
||||
parser.add_argument('--device',default='cpu',type=str,help="cpu or cuda")
|
||||
parser.add_argument('--seed',default=10,type=int,help="seed")
|
||||
parser.add_argument('--show_fig',default=False,type=bool,help="if show figure or not")
|
||||
parser.add_argument('--save_fig',default=True,type=bool,help="if save figure or not")
|
||||
args = parser.parse_args()
|
||||
default_args = {'result_path':f"{curr_path}/outputs/{args.env_name}/{curr_time}/results/",
|
||||
'model_path':f"{curr_path}/outputs/{args.env_name}/{curr_time}/models/",
|
||||
}
|
||||
args = {**vars(args),**default_args} # type(dict)
|
||||
return args
|
||||
|
||||
def env_agent_config(cfg):
|
||||
register_env(cfg['env_name'])
|
||||
env = gym.make(cfg['env_name'])
|
||||
if cfg['seed'] !=0: # set random seed
|
||||
all_seed(env,seed= cfg['seed'])
|
||||
if cfg['env_name'] == 'CliffWalking-v0':
|
||||
env = CliffWalkingWapper(env)
|
||||
try: # state dimension
|
||||
n_states = env.observation_space.n # print(hasattr(env.observation_space, 'n'))
|
||||
except AttributeError:
|
||||
n_states = env.observation_space.shape[0] # print(hasattr(env.observation_space, 'shape'))
|
||||
n_actions = env.action_space.n # action dimension
|
||||
print(f"n_states: {n_states}, n_actions: {n_actions}")
|
||||
cfg.update({"n_states":n_states,"n_actions":n_actions}) # update to cfg paramters
|
||||
agent = Sarsa(cfg)
|
||||
return env,agent
|
||||
|
||||
def train(cfg,env,agent):
|
||||
print("Start training!")
|
||||
print(f"Env: {cfg['env_name']}, Algorithm: {cfg['algo_name']}, Device: {cfg['device']}")
|
||||
rewards = [] # record rewards for all episodes
|
||||
steps = [] # record steps for all episodes
|
||||
for i_ep in range(cfg['train_eps']):
|
||||
ep_reward = 0 # reward per episode
|
||||
ep_step = 0 # step per episode
|
||||
state = env.reset() # reset and obtain initial state
|
||||
action = agent.sample_action(state)
|
||||
while True:
|
||||
# for _ in range(cfg.ep_max_steps):
|
||||
next_state, reward, done, _ = env.step(action) # update env and return transitions
|
||||
next_action = agent.sample_action(next_state)
|
||||
agent.update(state, action, reward, next_state, next_action,done) # update agent
|
||||
state = next_state # update state
|
||||
action = next_action
|
||||
ep_reward += reward
|
||||
ep_step += 1
|
||||
if done:
|
||||
break
|
||||
rewards.append(ep_reward)
|
||||
steps.append(ep_step)
|
||||
if (i_ep+1)%10==0:
|
||||
print(f'Episode: {i_ep+1}/{cfg["train_eps"]}, Reward: {ep_reward:.2f}, Steps:{ep_step}, Epislon: {agent.epsilon:.3f}')
|
||||
print("Finish training!")
|
||||
return {'episodes':range(len(rewards)),'rewards':rewards,'steps':steps}
|
||||
|
||||
def test(cfg,env,agent):
|
||||
print("Start testing!")
|
||||
print(f"Env: {cfg['env_name']}, Algorithm: {cfg['algo_name']}, Device: {cfg['device']}")
|
||||
rewards = [] # record rewards for all episodes
|
||||
steps = [] # record steps for all episodes
|
||||
for i_ep in range(cfg['test_eps']):
|
||||
ep_reward = 0 # reward per episode
|
||||
ep_step = 0
|
||||
while True:
|
||||
# for _ in range(cfg.ep_max_steps):
|
||||
action = agent.predict_action(state)
|
||||
next_state, reward, done = env.step(action)
|
||||
state = next_state
|
||||
ep_reward+=reward
|
||||
ep_step+=1
|
||||
if done:
|
||||
break
|
||||
rewards.append(ep_reward)
|
||||
steps.append(ep_step)
|
||||
print(f"Episode: {i_ep+1}/{cfg['test_eps']}, Steps:{ep_step}, Reward: {ep_reward:.2f}")
|
||||
print("Finish testing!")
|
||||
return {'episodes':range(len(rewards)),'rewards':rewards,'steps':steps}
|
||||
|
||||
if __name__ == "__main__":
|
||||
cfg = get_args()
|
||||
# 训练
|
||||
env, agent = env_agent_config(cfg)
|
||||
res_dic = train(cfg, env, agent)
|
||||
make_dir(cfg.result_path, cfg.model_path)
|
||||
save_args(cfg) # save parameters
|
||||
agent.save(path=cfg.model_path) # save model
|
||||
save_results(res_dic, tag='train',
|
||||
path=cfg.result_path)
|
||||
plot_rewards(res_dic['rewards'], cfg, tag="train")
|
||||
# 测试
|
||||
env, agent = env_agent_config(cfg)
|
||||
agent.load(path=cfg.model_path) # 导入模型
|
||||
res_dic = test(cfg, env, agent)
|
||||
save_results(res_dic, tag='test',
|
||||
path=cfg.result_path) # 保存结果
|
||||
plot_rewards(res_dic['rewards'], cfg, tag="test") # 画出结果
|
||||
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ Author: John
|
||||
Email: johnjim0816@gmail.com
|
||||
Date: 2021-03-12 16:58:16
|
||||
LastEditor: John
|
||||
LastEditTime: 2022-08-04 22:22:16
|
||||
LastEditTime: 2022-08-25 00:23:22
|
||||
Discription:
|
||||
Environment:
|
||||
'''
|
||||
@@ -14,45 +14,51 @@ from collections import defaultdict
|
||||
import torch
|
||||
import math
|
||||
class Sarsa(object):
|
||||
def __init__(self,
|
||||
n_actions,cfg):
|
||||
self.n_actions = n_actions
|
||||
self.lr = cfg.lr
|
||||
self.gamma = cfg.gamma
|
||||
self.sample_count = 0
|
||||
self.epsilon_start = cfg.epsilon_start
|
||||
self.epsilon_end = cfg.epsilon_end
|
||||
self.epsilon_decay = cfg.epsilon_decay
|
||||
self.Q = defaultdict(lambda: np.zeros(n_actions)) # Q table
|
||||
def sample(self, state):
|
||||
def __init__(self,cfg):
|
||||
self.n_actions = cfg['n_actions']
|
||||
self.lr = cfg['lr']
|
||||
self.gamma = cfg['gamma']
|
||||
self.epsilon = cfg['epsilon_start']
|
||||
self.sample_count = 0
|
||||
self.epsilon_start = cfg['epsilon_start']
|
||||
self.epsilon_end = cfg['epsilon_end']
|
||||
self.epsilon_decay = cfg['epsilon_decay']
|
||||
self.Q_table = defaultdict(lambda: np.zeros(self.n_actions)) # Q table
|
||||
def sample_action(self, state):
|
||||
''' another way to represent e-greedy policy
|
||||
'''
|
||||
self.sample_count += 1
|
||||
self.epsilon = self.epsilon_end + (self.epsilon_start - self.epsilon_end) * \
|
||||
math.exp(-1. * self.sample_count / self.epsilon_decay) # The probability to select a random action, is is log decayed
|
||||
best_action = np.argmax(self.Q[state])
|
||||
best_action = np.argmax(self.Q_table[state])
|
||||
action_probs = np.ones(self.n_actions, dtype=float) * self.epsilon / self.n_actions
|
||||
action_probs[best_action] += (1.0 - self.epsilon)
|
||||
action = np.random.choice(np.arange(len(action_probs)), p=action_probs)
|
||||
return action
|
||||
def predict(self,state):
|
||||
return np.argmax(self.Q[state])
|
||||
def update(self, state, action, reward, next_state, next_action,done):
|
||||
Q_predict = self.Q[state][action]
|
||||
if done:
|
||||
Q_target = reward # 终止状态
|
||||
else:
|
||||
Q_target = reward + self.gamma * self.Q[next_state][next_action] # 与Q learning不同,Sarsa是拿下一步动作对应的Q值去更新
|
||||
self.Q[state][action] += self.lr * (Q_target - Q_predict)
|
||||
def save(self,path):
|
||||
'''把 Q表格 的数据保存到文件中
|
||||
def predict_action(self,state):
|
||||
''' predict action while testing
|
||||
'''
|
||||
action = np.argmax(self.Q_table[state])
|
||||
return action
|
||||
def update(self, state, action, reward, next_state, next_action,done):
|
||||
Q_predict = self.Q_table[state][action]
|
||||
if done:
|
||||
Q_target = reward # terminal state
|
||||
else:
|
||||
Q_target = reward + self.gamma * self.Q_table[next_state][next_action] # the only difference from Q learning
|
||||
self.Q_table[state][action] += self.lr * (Q_target - Q_predict)
|
||||
def save_model(self,path):
|
||||
import dill
|
||||
from pathlib import Path
|
||||
# create path
|
||||
Path(path).mkdir(parents=True, exist_ok=True)
|
||||
torch.save(
|
||||
obj=self.Q,
|
||||
f=path+"sarsa_model.pkl",
|
||||
obj=self.Q_table_table,
|
||||
f=path+"checkpoint.pkl",
|
||||
pickle_module=dill
|
||||
)
|
||||
def load(self, path):
|
||||
'''从文件中读取数据到 Q表格
|
||||
'''
|
||||
print("Model saved!")
|
||||
def load_model(self, path):
|
||||
import dill
|
||||
self.Q =torch.load(f=path+'sarsa_model.pkl',pickle_module=dill)
|
||||
self.Q_table_table =torch.load(f=path+'checkpoint.pkl',pickle_module=dill)
|
||||
print("Mode loaded!")
|
||||
@@ -1,118 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
'''
|
||||
Author: John
|
||||
Email: johnjim0816@gmail.com
|
||||
Date: 2021-03-11 17:59:16
|
||||
LastEditor: John
|
||||
LastEditTime: 2022-08-04 22:28:51
|
||||
Discription:
|
||||
Environment:
|
||||
'''
|
||||
import sys,os
|
||||
curr_path = os.path.dirname(os.path.abspath(__file__)) # 当前文件所在绝对路径
|
||||
parent_path = os.path.dirname(curr_path) # 父路径
|
||||
sys.path.append(parent_path) # 添加路径到系统路径
|
||||
|
||||
import datetime
|
||||
import argparse
|
||||
from envs.racetrack_env import RacetrackEnv
|
||||
from Sarsa.sarsa import Sarsa
|
||||
from common.utils import save_results,make_dir,plot_rewards,save_args
|
||||
|
||||
def get_args():
|
||||
""" 超参数
|
||||
"""
|
||||
curr_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") # 获取当前时间
|
||||
parser = argparse.ArgumentParser(description="hyperparameters")
|
||||
parser.add_argument('--algo_name',default='Sarsa',type=str,help="name of algorithm")
|
||||
parser.add_argument('--env_name',default='CliffWalking-v0',type=str,help="name of environment")
|
||||
parser.add_argument('--train_eps',default=300,type=int,help="episodes of training") # 训练的回合数
|
||||
parser.add_argument('--test_eps',default=20,type=int,help="episodes of testing") # 测试的回合数
|
||||
parser.add_argument('--ep_max_steps',default=200,type=int) # 每回合最大的部署
|
||||
parser.add_argument('--gamma',default=0.99,type=float,help="discounted factor") # 折扣因子
|
||||
parser.add_argument('--epsilon_start',default=0.90,type=float,help="initial value of epsilon") # e-greedy策略中初始epsilon
|
||||
parser.add_argument('--epsilon_end',default=0.01,type=float,help="final value of epsilon") # e-greedy策略中的终止epsilon
|
||||
parser.add_argument('--epsilon_decay',default=200,type=int,help="decay rate of epsilon") # e-greedy策略中epsilon的衰减率
|
||||
parser.add_argument('--lr',default=0.2,type=float,help="learning rate")
|
||||
parser.add_argument('--device',default='cpu',type=str,help="cpu or cuda")
|
||||
parser.add_argument('--result_path',default=curr_path + "/outputs/" + parser.parse_args().env_name + \
|
||||
'/' + curr_time + '/results/' )
|
||||
parser.add_argument('--model_path',default=curr_path + "/outputs/" + parser.parse_args().env_name + \
|
||||
'/' + curr_time + '/models/' ) # path to save models
|
||||
parser.add_argument('--save_fig',default=True,type=bool,help="if save figure or not")
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
def env_agent_config(cfg,seed=1):
|
||||
env = RacetrackEnv()
|
||||
n_actions = 9 # 动作数
|
||||
agent = Sarsa(n_actions,cfg)
|
||||
return env,agent
|
||||
|
||||
def train(cfg,env,agent):
|
||||
print('开始训练!')
|
||||
print(f'环境:{cfg.env_name}, 算法:{cfg.algo_name}, 设备:{cfg.device}')
|
||||
rewards = [] # 记录奖励
|
||||
for i_ep in range(cfg.train_eps):
|
||||
state = env.reset()
|
||||
action = agent.sample(state)
|
||||
ep_reward = 0
|
||||
# while True:
|
||||
for _ in range(cfg.ep_max_steps):
|
||||
next_state, reward, done = env.step(action)
|
||||
ep_reward+=reward
|
||||
next_action = agent.sample(next_state)
|
||||
agent.update(state, action, reward, next_state, next_action,done)
|
||||
state = next_state
|
||||
action = next_action
|
||||
if done:
|
||||
break
|
||||
rewards.append(ep_reward)
|
||||
if (i_ep+1)%2==0:
|
||||
print(f"回合:{i_ep+1}/{cfg.train_eps},奖励:{ep_reward:.1f},Epsilon:{agent.epsilon}")
|
||||
print('完成训练!')
|
||||
return {"rewards":rewards}
|
||||
|
||||
def test(cfg,env,agent):
|
||||
print('开始测试!')
|
||||
print(f'环境:{cfg.env_name}, 算法:{cfg.algo_name}, 设备:{cfg.device}')
|
||||
rewards = []
|
||||
for i_ep in range(cfg.test_eps):
|
||||
state = env.reset()
|
||||
ep_reward = 0
|
||||
# while True:
|
||||
for _ in range(cfg.ep_max_steps):
|
||||
action = agent.predict(state)
|
||||
next_state, reward, done = env.step(action)
|
||||
ep_reward+=reward
|
||||
state = next_state
|
||||
if done:
|
||||
break
|
||||
rewards.append(ep_reward)
|
||||
print(f"回合数:{i_ep+1}/{cfg.test_eps}, 奖励:{ep_reward:.1f}")
|
||||
print('完成测试!')
|
||||
return {"rewards":rewards}
|
||||
|
||||
if __name__ == "__main__":
|
||||
cfg = get_args()
|
||||
# 训练
|
||||
env, agent = env_agent_config(cfg)
|
||||
res_dic = train(cfg, env, agent)
|
||||
make_dir(cfg.result_path, cfg.model_path)
|
||||
save_args(cfg) # save parameters
|
||||
agent.save(path=cfg.model_path) # save model
|
||||
save_results(res_dic, tag='train',
|
||||
path=cfg.result_path)
|
||||
plot_rewards(res_dic['rewards'], cfg, tag="train")
|
||||
# 测试
|
||||
env, agent = env_agent_config(cfg)
|
||||
agent.load(path=cfg.model_path) # 导入模型
|
||||
res_dic = test(cfg, env, agent)
|
||||
save_results(res_dic, tag='test',
|
||||
path=cfg.result_path) # 保存结果
|
||||
plot_rewards(res_dic['rewards'], cfg, tag="test") # 画出结果
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user