hot update A2C
@@ -1,56 +1,60 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
'''
|
||||
Author: JiangJi
|
||||
Email: johnjim0816@gmail.com
|
||||
Date: 2021-05-03 22:16:08
|
||||
LastEditor: JiangJi
|
||||
LastEditTime: 2022-07-20 23:54:40
|
||||
Description:
|
||||
Environment:
|
||||
'''
|
||||
import torch
|
||||
import torch.optim as optim
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.distributions import Categorical
|
||||
import numpy as np
|
||||
|
||||
|
||||
class ActorCritic(nn.Module):
    """Actor-critic network: a softmax policy head plus a scalar value head.

    Both heads are independent two-layer MLPs over the same state input.
    The attribute names `critic`/`actor` are part of the checkpoint format
    (state_dict keys), so they must not be renamed.
    """

    def __init__(self, input_dim, output_dim, hidden_dim):
        super(ActorCritic, self).__init__()
        # Value head: state -> scalar V(s)
        self.critic = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )
        # Policy head: state -> action probabilities (softmax over the action dim)
        self.actor = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
            nn.Softmax(dim=1),
        )

    def forward(self, x):
        """Return (action distribution, state value) for a batch of states."""
        probs = self.actor(x)
        state_value = self.critic(x)
        return Categorical(probs), state_value
|
||||
class A2C:
    """Advantage Actor-Critic (A2C) with separate actor and critic networks.

    Expected arguments:
        models:   dict with 'Actor' (state -> action probabilities) and
                  'Critic' (state -> scalar value) torch modules.
        memories: dict with 'ACMemory', a buffer supporting push/sample/clear.
        cfg:      dict with keys 'n_actions', 'gamma', 'device',
                  'actor_lr', 'critic_lr'.
    """

    def __init__(self, models, memories, cfg):
        # NOTE: the original file defined TWO __init__ methods; the first
        # (n_states, n_actions, cfg) was dead code silently shadowed by this
        # one and has been removed.
        self.n_actions = cfg['n_actions']
        self.gamma = cfg['gamma']
        self.device = torch.device(cfg['device'])
        self.memory = memories['ACMemory']
        self.actor = models['Actor'].to(self.device)
        self.critic = models['Critic'].to(self.device)
        self.actor_optim = torch.optim.Adam(self.actor.parameters(), lr=cfg['actor_lr'])
        self.critic_optim = torch.optim.Adam(self.critic.parameters(), lr=cfg['critic_lr'])

    def sample_action(self, state):
        """Sample a training-time action; returns (action, value, dist).

        'dist' keeps requires_grad=True so the caller can build log-probs
        from it for the policy-gradient update.
        """
        state = torch.tensor(state, device=self.device, dtype=torch.float32).unsqueeze(dim=0)
        dist = self.actor(state)
        value = self.critic(state)
        # .cpu() makes the numpy conversion safe when the model is on CUDA
        value = value.detach().cpu().numpy().squeeze(0)[0]
        action = np.random.choice(self.n_actions, p=dist.detach().cpu().numpy().squeeze(0))
        return action, value, dist

    def predict_action(self, state):
        """Evaluation-time action; the whole computation runs under no_grad.

        (Previously a byte-copy of sample_action that tracked gradients;
        its own comment said prediction should not need them.)
        """
        with torch.no_grad():
            state = torch.tensor(state, device=self.device, dtype=torch.float32).unsqueeze(dim=0)
            dist = self.actor(state)
            value = self.critic(state)
            value = value.cpu().numpy().squeeze(0)[0]
            action = np.random.choice(self.n_actions, p=dist.cpu().numpy().squeeze(0))
        return action, value, dist

    def update(self, next_state, entropy):
        """One A2C update over the episode stored in memory, then clear it."""
        value_pool, log_prob_pool, reward_pool = self.memory.sample()
        next_state = torch.tensor(next_state, device=self.device, dtype=torch.float32).unsqueeze(dim=0)
        next_value = self.critic(next_state)
        # Bootstrapped discounted returns: G(s_t,a_t) = r_{t+1} + gamma * V(s_{t+1})
        returns = np.zeros_like(reward_pool)
        for t in reversed(range(len(reward_pool))):
            next_value = reward_pool[t] + self.gamma * next_value
            returns[t] = next_value
        returns = torch.tensor(returns, device=self.device)
        value_pool = torch.tensor(value_pool, device=self.device)
        advantages = returns - value_pool
        log_prob_pool = torch.stack(log_prob_pool)
        actor_loss = (-log_prob_pool * advantages).mean()
        critic_loss = 0.5 * advantages.pow(2).mean()
        # small entropy bonus encourages exploration
        tot_loss = actor_loss + critic_loss + 0.001 * entropy
        self.actor_optim.zero_grad()
        self.critic_optim.zero_grad()
        tot_loss.backward()
        self.actor_optim.step()
        self.critic_optim.step()
        self.memory.clear()

    def save_model(self, path):
        """Save actor and critic weights under `path` (created if missing)."""
        from pathlib import Path
        Path(path).mkdir(parents=True, exist_ok=True)
        torch.save(self.actor.state_dict(), f"{path}/actor_checkpoint.pt")
        torch.save(self.critic.state_dict(), f"{path}/critic_checkpoint.pt")

    def compute_returns(self, next_value, rewards, masks):
        """Discounted returns bootstrapped from next_value.

        masks[t] is 0 at terminal steps, which cuts the bootstrap chain.
        """
        R = next_value
        returns = []
        for step in reversed(range(len(rewards))):
            R = rewards[step] + self.gamma * R * masks[step]
            returns.insert(0, R)
        return returns

    def load_model(self, path):
        """Load weights previously written by save_model."""
        self.actor.load_state_dict(torch.load(f"{path}/actor_checkpoint.pt"))
        self.critic.load_state_dict(torch.load(f"{path}/critic_checkpoint.pt"))
|
||||
55
projects/codes/A2C/a2c_2.py
Normal file
@@ -0,0 +1,55 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
class A2C_2:
    """A2C variant built on a single shared actor-critic network."""

    def __init__(self, models, memories, cfg):
        self.n_actions = cfg['n_actions']
        self.gamma = cfg['gamma']
        self.device = torch.device(cfg['device'])
        self.memory = memories['ACMemory']
        self.ac_net = models['ActorCritic'].to(self.device)
        self.ac_optimizer = torch.optim.Adam(self.ac_net.parameters(), lr=cfg['lr'])

    def sample_action(self, state):
        """Sample a training-time action; returns (action, value, dist).

        'dist' keeps requires_grad=True so the caller can take log-probs.
        """
        obs = torch.tensor(state, device=self.device, dtype=torch.float32).unsqueeze(dim=0)
        value, dist = self.ac_net(obs)
        value = value.detach().numpy().squeeze(0)[0]
        action = np.random.choice(self.n_actions, p=dist.detach().numpy().squeeze(0))
        return action, value, dist

    def predict_action(self, state):
        """Evaluation-time action; the whole computation runs under no_grad,
        so no detach() calls are needed.
        """
        with torch.no_grad():
            obs = torch.tensor(state, device=self.device, dtype=torch.float32).unsqueeze(dim=0)
            value, dist = self.ac_net(obs)
            value = value.numpy().squeeze(0)[0]  # shape (1,) -> scalar
            action = np.random.choice(self.n_actions, p=dist.numpy().squeeze(0))
            return action, value, dist

    def update(self, next_state, entropy):
        """One A2C update over the episode stored in memory, then clear it."""
        value_pool, log_prob_pool, reward_pool = self.memory.sample()
        next_state = torch.tensor(next_state, device=self.device, dtype=torch.float32).unsqueeze(dim=0)
        next_value, _ = self.ac_net(next_state)
        # Bootstrapped discounted returns: G(s_t,a_t) = r_{t+1} + gamma * V(s_{t+1})
        returns = np.zeros_like(reward_pool)
        for step in reversed(range(len(reward_pool))):
            next_value = reward_pool[step] + self.gamma * next_value
            returns[step] = next_value
        returns = torch.tensor(returns, device=self.device)
        value_pool = torch.tensor(value_pool, device=self.device)
        advantages = returns - value_pool
        log_prob_pool = torch.stack(log_prob_pool)
        actor_loss = (-log_prob_pool * advantages).mean()
        critic_loss = 0.5 * advantages.pow(2).mean()
        # small entropy bonus encourages exploration
        ac_loss = actor_loss + critic_loss + 0.001 * entropy
        self.ac_optimizer.zero_grad()
        ac_loss.backward()
        self.ac_optimizer.step()
        self.memory.clear()

    def save_model(self, path):
        """Save the shared network's weights under `path` (created if missing)."""
        from pathlib import Path
        Path(path).mkdir(parents=True, exist_ok=True)
        torch.save(self.ac_net.state_dict(), f"{path}/a2c_checkpoint.pt")

    def load_model(self, path):
        """Load weights previously written by save_model."""
        self.ac_net.load_state_dict(torch.load(f"{path}/a2c_checkpoint.pt"))
|
||||
|
||||
|
||||
121
projects/codes/A2C/main.py
Normal file
@@ -0,0 +1,121 @@
|
||||
import sys,os
|
||||
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE" # avoid "OMP: Error #15: Initializing libiomp5md.dll, but found libiomp5md.dll already initialized."
|
||||
curr_path = os.path.dirname(os.path.abspath(__file__)) # current path
|
||||
parent_path = os.path.dirname(curr_path) # parent path
|
||||
sys.path.append(parent_path) # add path to system path
|
||||
|
||||
import datetime
|
||||
import argparse
|
||||
import gym
|
||||
import torch
|
||||
import numpy as np
|
||||
from common.utils import all_seed
|
||||
from common.launcher import Launcher
|
||||
from common.memories import PGReplay
|
||||
from common.models import ActorSoftmax,Critic
|
||||
from envs.register import register_env
|
||||
from a2c import A2C
|
||||
|
||||
def _str2bool(v):
    """argparse-friendly boolean parser.

    `type=bool` is a classic argparse pitfall: bool("False") is True, so any
    non-empty string enables the flag. This helper parses the usual textual
    spellings instead, making `--show_fig False` work as intended.
    """
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    if v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f"boolean value expected, got {v!r}")


class Main(Launcher):
    def get_args(self):
        """Parse hyperparameters from the CLI and add derived output paths."""
        curr_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")  # obtain current time
        parser = argparse.ArgumentParser(description="hyperparameters")
        parser.add_argument('--algo_name',default='A2C',type=str,help="name of algorithm")
        parser.add_argument('--env_name',default='CartPole-v0',type=str,help="name of environment")
        parser.add_argument('--train_eps',default=1600,type=int,help="episodes of training")
        parser.add_argument('--test_eps',default=20,type=int,help="episodes of testing")
        parser.add_argument('--ep_max_steps',default = 100000,type=int,help="steps per episode, much larger value can simulate infinite steps")
        parser.add_argument('--gamma',default=0.99,type=float,help="discounted factor")
        parser.add_argument('--actor_lr',default=3e-4,type=float,help="learning rate of actor")
        parser.add_argument('--critic_lr',default=1e-3,type=float,help="learning rate of critic")
        parser.add_argument('--actor_hidden_dim',default=256,type=int,help="hidden of actor net")
        parser.add_argument('--critic_hidden_dim',default=256,type=int,help="hidden of critic net")
        parser.add_argument('--device',default='cpu',type=str,help="cpu or cuda")
        parser.add_argument('--seed',default=10,type=int,help="seed")
        # was type=bool: bool("False") is True, which made these flags un-disableable
        parser.add_argument('--show_fig',default=False,type=_str2bool,help="if show figure or not")
        parser.add_argument('--save_fig',default=True,type=_str2bool,help="if save figure or not")
        args = parser.parse_args()
        default_args = {'result_path':f"{curr_path}/outputs/{args.env_name}/{curr_time}/results/",
                        'model_path':f"{curr_path}/outputs/{args.env_name}/{curr_time}/models/",
        }
        args = {**vars(args),**default_args} # type(dict)
        return args

    def env_agent_config(self,cfg):
        ''' create env and agent
        '''
        register_env(cfg['env_name'])
        env = gym.make(cfg['env_name'])
        if cfg['seed'] != 0: # set random seed
            all_seed(env,seed=cfg["seed"])
        try: # state dimension: discrete observation spaces expose .n
            n_states = env.observation_space.n
        except AttributeError: # box observation spaces expose .shape
            n_states = env.observation_space.shape[0]
        n_actions = env.action_space.n # action dimension
        print(f"n_states: {n_states}, n_actions: {n_actions}")
        cfg.update({"n_states":n_states,"n_actions":n_actions}) # update to cfg paramters
        models = {'Actor':ActorSoftmax(cfg['n_states'],cfg['n_actions'], hidden_dim = cfg['actor_hidden_dim']),'Critic':Critic(cfg['n_states'],1,hidden_dim=cfg['critic_hidden_dim'])}
        memories = {'ACMemory':PGReplay()}
        agent = A2C(models,memories,cfg)
        return env,agent

    def train(self,cfg,env,agent):
        """Run training episodes; returns dict of episodes/rewards/steps."""
        print("Start training!")
        print(f"Env: {cfg['env_name']}, Algorithm: {cfg['algo_name']}, Device: {cfg['device']}")
        rewards = [] # record rewards for all episodes
        steps = [] # record steps for all episodes
        for i_ep in range(cfg['train_eps']):
            ep_reward = 0 # reward per episode
            ep_step = 0 # step per episode
            ep_entropy = 0
            state = env.reset() # reset and obtain initial state
            for _ in range(cfg['ep_max_steps']):
                action, value, dist = agent.sample_action(state) # sample action
                next_state, reward, done, _ = env.step(action) # update env and return transitions
                log_prob = torch.log(dist.squeeze(0)[action])
                # NOTE(review): this is not the standard entropy -sum(p*log p);
                # kept as-is to preserve the original training behavior.
                entropy = -np.sum(np.mean(dist.detach().numpy()) * np.log(dist.detach().numpy()))
                agent.memory.push((value,log_prob,reward)) # save transitions
                state = next_state # update state
                ep_reward += reward
                ep_entropy += entropy
                ep_step += 1
                if done:
                    break
            agent.update(next_state,ep_entropy) # update agent
            rewards.append(ep_reward)
            steps.append(ep_step)
            if (i_ep+1)%10==0:
                print(f'Episode: {i_ep+1}/{cfg["train_eps"]}, Reward: {ep_reward:.2f}, Steps:{ep_step}')
        print("Finish training!")
        return {'episodes':range(len(rewards)),'rewards':rewards,'steps':steps}

    def test(self,cfg,env,agent):
        """Run evaluation episodes with the trained policy."""
        print("Start testing!")
        print(f"Env: {cfg['env_name']}, Algorithm: {cfg['algo_name']}, Device: {cfg['device']}")
        rewards = [] # record rewards for all episodes
        steps = [] # record steps for all episodes
        for i_ep in range(cfg['test_eps']):
            ep_reward = 0 # reward per episode
            ep_step = 0
            state = env.reset() # reset and obtain initial state
            for _ in range(cfg['ep_max_steps']):
                action,_,_ = agent.predict_action(state) # predict action
                next_state, reward, done, _ = env.step(action)
                state = next_state
                ep_reward += reward
                ep_step += 1
                if done:
                    break
            rewards.append(ep_reward)
            steps.append(ep_step)
            print(f"Episode: {i_ep+1}/{cfg['test_eps']}, Steps:{ep_step}, Reward: {ep_reward:.2f}")
        print("Finish testing!")
        return {'episodes':range(len(rewards)),'rewards':rewards,'steps':steps}


if __name__ == "__main__":
    main = Main()
    main.run()
|
||||
|
||||
|
||||
|
||||
|
||||
120
projects/codes/A2C/main2.py
Normal file
@@ -0,0 +1,120 @@
|
||||
import sys,os
|
||||
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE" # avoid "OMP: Error #15: Initializing libiomp5md.dll, but found libiomp5md.dll already initialized."
|
||||
curr_path = os.path.dirname(os.path.abspath(__file__)) # current path
|
||||
parent_path = os.path.dirname(curr_path) # parent path
|
||||
sys.path.append(parent_path) # add path to system path
|
||||
|
||||
import datetime
|
||||
import argparse
|
||||
import gym
|
||||
import torch
|
||||
import numpy as np
|
||||
from common.utils import all_seed
|
||||
from common.launcher import Launcher
|
||||
from common.memories import PGReplay
|
||||
from common.models import ActorCriticSoftmax
|
||||
from envs.register import register_env
|
||||
from a2c_2 import A2C_2
|
||||
|
||||
def _str2bool(v):
    """argparse-friendly boolean parser.

    `type=bool` is a classic argparse pitfall: bool("False") is True, so any
    non-empty string enables the flag. This helper parses the usual textual
    spellings instead, making `--show_fig False` work as intended.
    """
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    if v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f"boolean value expected, got {v!r}")


class Main(Launcher):
    def get_args(self):
        """Parse hyperparameters from the CLI and add derived output paths."""
        curr_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")  # obtain current time
        parser = argparse.ArgumentParser(description="hyperparameters")
        parser.add_argument('--algo_name',default='A2C',type=str,help="name of algorithm")
        parser.add_argument('--env_name',default='CartPole-v0',type=str,help="name of environment")
        parser.add_argument('--train_eps',default=2000,type=int,help="episodes of training")
        parser.add_argument('--test_eps',default=20,type=int,help="episodes of testing")
        parser.add_argument('--ep_max_steps',default = 100000,type=int,help="steps per episode, much larger value can simulate infinite steps")
        parser.add_argument('--gamma',default=0.99,type=float,help="discounted factor")
        parser.add_argument('--lr',default=3e-4,type=float,help="learning rate")
        parser.add_argument('--actor_hidden_dim',default=256,type=int)
        parser.add_argument('--critic_hidden_dim',default=256,type=int)
        parser.add_argument('--device',default='cpu',type=str,help="cpu or cuda")
        parser.add_argument('--seed',default=10,type=int,help="seed")
        # was type=bool: bool("False") is True, which made these flags un-disableable
        parser.add_argument('--show_fig',default=False,type=_str2bool,help="if show figure or not")
        parser.add_argument('--save_fig',default=True,type=_str2bool,help="if save figure or not")
        args = parser.parse_args()
        default_args = {'result_path':f"{curr_path}/outputs/{args.env_name}/{curr_time}/results/",
                        'model_path':f"{curr_path}/outputs/{args.env_name}/{curr_time}/models/",
        }
        args = {**vars(args),**default_args} # type(dict)
        return args

    def env_agent_config(self,cfg):
        ''' create env and agent
        '''
        register_env(cfg['env_name'])
        env = gym.make(cfg['env_name'])
        if cfg['seed'] != 0: # set random seed
            all_seed(env,seed=cfg["seed"])
        try: # state dimension: discrete observation spaces expose .n
            n_states = env.observation_space.n
        except AttributeError: # box observation spaces expose .shape
            n_states = env.observation_space.shape[0]
        n_actions = env.action_space.n # action dimension
        print(f"n_states: {n_states}, n_actions: {n_actions}")
        cfg.update({"n_states":n_states,"n_actions":n_actions}) # update to cfg paramters
        models = {'ActorCritic':ActorCriticSoftmax(cfg['n_states'],cfg['n_actions'], actor_hidden_dim = cfg['actor_hidden_dim'],critic_hidden_dim=cfg['critic_hidden_dim'])}
        memories = {'ACMemory':PGReplay()}
        agent = A2C_2(models,memories,cfg)
        return env,agent

    def train(self,cfg,env,agent):
        """Run training episodes; returns dict of episodes/rewards/steps."""
        print("Start training!")
        print(f"Env: {cfg['env_name']}, Algorithm: {cfg['algo_name']}, Device: {cfg['device']}")
        rewards = [] # record rewards for all episodes
        steps = [] # record steps for all episodes
        for i_ep in range(cfg['train_eps']):
            ep_reward = 0 # reward per episode
            ep_step = 0 # step per episode
            ep_entropy = 0
            state = env.reset() # reset and obtain initial state
            for _ in range(cfg['ep_max_steps']):
                action, value, dist = agent.sample_action(state) # sample action
                next_state, reward, done, _ = env.step(action) # update env and return transitions
                log_prob = torch.log(dist.squeeze(0)[action])
                # NOTE(review): this is not the standard entropy -sum(p*log p);
                # kept as-is to preserve the original training behavior.
                entropy = -np.sum(np.mean(dist.detach().numpy()) * np.log(dist.detach().numpy()))
                agent.memory.push((value,log_prob,reward)) # save transitions
                state = next_state # update state
                ep_reward += reward
                ep_entropy += entropy
                ep_step += 1
                if done:
                    break
            agent.update(next_state,ep_entropy) # update agent
            rewards.append(ep_reward)
            steps.append(ep_step)
            if (i_ep+1)%10==0:
                print(f'Episode: {i_ep+1}/{cfg["train_eps"]}, Reward: {ep_reward:.2f}, Steps:{ep_step}')
        print("Finish training!")
        return {'episodes':range(len(rewards)),'rewards':rewards,'steps':steps}

    def test(self,cfg,env,agent):
        """Run evaluation episodes with the trained policy."""
        print("Start testing!")
        print(f"Env: {cfg['env_name']}, Algorithm: {cfg['algo_name']}, Device: {cfg['device']}")
        rewards = [] # record rewards for all episodes
        steps = [] # record steps for all episodes
        for i_ep in range(cfg['test_eps']):
            ep_reward = 0 # reward per episode
            ep_step = 0
            state = env.reset() # reset and obtain initial state
            for _ in range(cfg['ep_max_steps']):
                action,_,_ = agent.predict_action(state) # predict action
                next_state, reward, done, _ = env.step(action)
                state = next_state
                ep_reward += reward
                ep_step += 1
                if done:
                    break
            rewards.append(ep_reward)
            steps.append(ep_step)
            print(f"Episode: {i_ep+1}/{cfg['test_eps']}, Steps:{ep_step}, Reward: {ep_reward:.2f}")
        print("Finish testing!")
        return {'episodes':range(len(rewards)),'rewards':rewards,'steps':steps}


if __name__ == "__main__":
    main = Main()
    main.run()
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,19 @@
|
||||
{
|
||||
"algo_name": "A2C",
|
||||
"env_name": "CartPole-v0",
|
||||
"train_eps": 2000,
|
||||
"test_eps": 20,
|
||||
"ep_max_steps": 100000,
|
||||
"gamma": 0.99,
|
||||
"lr": 0.0003,
|
||||
"actor_hidden_dim": 256,
|
||||
"critic_hidden_dim": 256,
|
||||
"device": "cpu",
|
||||
"seed": 10,
|
||||
"show_fig": false,
|
||||
"save_fig": true,
|
||||
"result_path": "/Users/jj/Desktop/rl-tutorials/codes/A2C/outputs/CartPole-v0/20220829-135818/results/",
|
||||
"model_path": "/Users/jj/Desktop/rl-tutorials/codes/A2C/outputs/CartPole-v0/20220829-135818/models/",
|
||||
"n_states": 4,
|
||||
"n_actions": 2
|
||||
}
|
||||
|
After Width: | Height: | Size: 44 KiB |
@@ -0,0 +1,21 @@
|
||||
episodes,rewards,steps
|
||||
0,200.0,200
|
||||
1,200.0,200
|
||||
2,93.0,93
|
||||
3,155.0,155
|
||||
4,116.0,116
|
||||
5,200.0,200
|
||||
6,190.0,190
|
||||
7,176.0,176
|
||||
8,200.0,200
|
||||
9,200.0,200
|
||||
10,200.0,200
|
||||
11,179.0,179
|
||||
12,200.0,200
|
||||
13,185.0,185
|
||||
14,191.0,191
|
||||
15,200.0,200
|
||||
16,200.0,200
|
||||
17,124.0,124
|
||||
18,200.0,200
|
||||
19,172.0,172
|
||||
|
|
After Width: | Height: | Size: 63 KiB |
@@ -0,0 +1 @@
|
||||
{"algo_name": "A2C", "env_name": "CartPole-v0", "train_eps": 1600, "test_eps": 20, "ep_max_steps": 100000, "gamma": 0.99, "actor_lr": 0.0003, "critic_lr": 0.001, "actor_hidden_dim": 256, "critic_hidden_dim": 256, "device": "cpu", "seed": 10, "show_fig": false, "save_fig": true, "result_path": "/Users/jj/Desktop/rl-tutorials/codes/A2C/outputs/CartPole-v0/20220829-143327/results/", "model_path": "/Users/jj/Desktop/rl-tutorials/codes/A2C/outputs/CartPole-v0/20220829-143327/models/", "n_states": 4, "n_actions": 2}
|
||||
|
After Width: | Height: | Size: 41 KiB |
@@ -0,0 +1,21 @@
|
||||
episodes,rewards,steps
|
||||
0,177.0,177
|
||||
1,180.0,180
|
||||
2,200.0,200
|
||||
3,200.0,200
|
||||
4,167.0,167
|
||||
5,124.0,124
|
||||
6,128.0,128
|
||||
7,200.0,200
|
||||
8,200.0,200
|
||||
9,200.0,200
|
||||
10,186.0,186
|
||||
11,187.0,187
|
||||
12,200.0,200
|
||||
13,176.0,176
|
||||
14,200.0,200
|
||||
15,200.0,200
|
||||
16,200.0,200
|
||||
17,200.0,200
|
||||
18,185.0,185
|
||||
19,180.0,180
|
||||
|
|
After Width: | Height: | Size: 66 KiB |
56
projects/codes/A3C/a3c.py
Normal file
@@ -0,0 +1,56 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
'''
|
||||
Author: JiangJi
|
||||
Email: johnjim0816@gmail.com
|
||||
Date: 2021-05-03 22:16:08
|
||||
LastEditor: JiangJi
|
||||
LastEditTime: 2022-07-20 23:54:40
|
||||
Description:
|
||||
Environment:
|
||||
'''
|
||||
import torch
|
||||
import torch.optim as optim
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.distributions import Categorical
|
||||
|
||||
class ActorCritic(nn.Module):
    """Actor-critic network: a softmax policy head plus a scalar value head.

    Both heads are independent two-layer MLPs over the same state input.
    The attribute names `critic`/`actor` are part of the checkpoint format
    (state_dict keys), so they must not be renamed.
    """

    def __init__(self, input_dim, output_dim, hidden_dim):
        super(ActorCritic, self).__init__()
        # Value head: state -> scalar V(s)
        self.critic = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )
        # Policy head: state -> action probabilities (softmax over the action dim)
        self.actor = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
            nn.Softmax(dim=1),
        )

    def forward(self, x):
        """Return (action distribution, state value) for a batch of states."""
        probs = self.actor(x)
        state_value = self.critic(x)
        return Categorical(probs), state_value
|
||||
class A2C:
    """A2C agent skeleton: one shared ActorCritic model and its optimizer.

    cfg is attribute-style here (cfg.gamma, cfg.device, cfg.hidden_size),
    unlike the dict-style cfg used elsewhere in the project.
    """

    def __init__(self, n_states, n_actions, cfg) -> None:
        self.gamma = cfg.gamma
        self.device = torch.device(cfg.device)
        self.model = ActorCritic(n_states, n_actions, cfg.hidden_size).to(self.device)
        self.optimizer = optim.Adam(self.model.parameters())

    def compute_returns(self, next_value, rewards, masks):
        """Discounted returns bootstrapped from next_value.

        masks[t] is 0 at terminal steps, which cuts the bootstrap chain.
        Returns a list aligned with `rewards` (earliest step first).
        """
        R = next_value
        out = []
        for step in range(len(rewards) - 1, -1, -1):
            R = rewards[step] + self.gamma * R * masks[step]
            out.append(R)
        out.reverse()
        return out
|
||||
|
Before Width: | Height: | Size: 64 KiB After Width: | Height: | Size: 64 KiB |
@@ -10,7 +10,7 @@ import torch.optim as optim
|
||||
import datetime
|
||||
import argparse
|
||||
from common.multiprocessing_env import SubprocVecEnv
|
||||
from a2c import ActorCritic
|
||||
from a3c import ActorCritic
|
||||
from common.utils import save_results, make_dir
|
||||
from common.utils import plot_rewards, save_args
|
||||
|
||||
@@ -24,6 +24,7 @@ def get_args():
|
||||
parser.add_argument('--env_name',default='CartPole-v0',type=str,help="name of environment")
|
||||
parser.add_argument('--train_eps',default=200,type=int,help="episodes of training")
|
||||
parser.add_argument('--test_eps',default=20,type=int,help="episodes of testing")
|
||||
parser.add_argument('--ep_max_steps',default = 100000,type=int,help="steps per episode, much larger value can simulate infinite steps")
|
||||
parser.add_argument('--gamma',default=0.95,type=float,help="discounted factor")
|
||||
parser.add_argument('--epsilon_start',default=0.95,type=float,help="initial value of epsilon")
|
||||
parser.add_argument('--epsilon_end',default=0.01,type=float,help="final value of epsilon")
|
||||
@@ -72,7 +73,7 @@ def train(cfg, env, agent):
|
||||
ep_reward = 0 # reward per episode
|
||||
ep_step = 0
|
||||
state = env.reset() # reset and obtain initial state
|
||||
while True:
|
||||
for _ in range(cfg['ep_max_steps']):
|
||||
ep_step += 1
|
||||
action = agent.sample_action(state) # sample action
|
||||
next_state, reward, done, _ = env.step(action) # update env and return transitions
|
||||
@@ -91,7 +92,7 @@ def train(cfg, env, agent):
|
||||
print(f'Episode: {i_ep+1}/{cfg["train_eps"]}, Reward: {ep_reward:.2f}: Epislon: {agent.epsilon:.3f}')
|
||||
print("Finish training!")
|
||||
env.close()
|
||||
res_dic = {'episodes':range(len(rewards)),'rewards':rewards}
|
||||
res_dic = {'episodes':range(len(rewards)),'rewards':rewards,'steps':steps}
|
||||
return res_dic
|
||||
|
||||
def test(cfg, env, agent):
|
||||
@@ -103,7 +104,7 @@ def test(cfg, env, agent):
|
||||
ep_reward = 0 # reward per episode
|
||||
ep_step = 0
|
||||
state = env.reset() # reset and obtain initial state
|
||||
while True:
|
||||
for _ in range(cfg['ep_max_steps']):
|
||||
ep_step+=1
|
||||
action = agent.predict_action(state) # predict action
|
||||
next_state, reward, done, _ = env.step(action)
|
||||
@@ -116,7 +117,7 @@ def test(cfg, env, agent):
|
||||
print(f"Episode: {i_ep+1}/{cfg['test_eps']},Reward: {ep_reward:.2f}")
|
||||
print("Finish testing!")
|
||||
env.close()
|
||||
return {'episodes':range(len(rewards)),'rewards':rewards}
|
||||
return {'episodes':range(len(rewards)),'rewards':rewards,'steps':steps}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
{"algo_name": "DQN", "env_name": "CartPole-v1", "train_eps": 2000, "test_eps": 20, "ep_max_steps": 100000, "gamma": 0.99, "epsilon_start": 0.95, "epsilon_end": 0.01, "epsilon_decay": 6000, "lr": 1e-05, "memory_capacity": 200000, "batch_size": 64, "target_update": 4, "hidden_dim": 256, "device": "cuda", "seed": 10, "show_fig": false, "save_fig": true, "result_path": "C:\\Users\\24438\\Desktop\\rl-tutorials\\codes\\DQN/outputs/CartPole-v1/20220828-214702/results", "model_path": "C:\\Users\\24438\\Desktop\\rl-tutorials\\codes\\DQN/outputs/CartPole-v1/20220828-214702/models", "n_states": 4, "n_actions": 2}
|
||||
|
After Width: | Height: | Size: 50 KiB |
@@ -0,0 +1,21 @@
|
||||
episodes,rewards,steps
|
||||
0,371.0,371
|
||||
1,446.0,446
|
||||
2,300.0,300
|
||||
3,500.0,500
|
||||
4,313.0,313
|
||||
5,500.0,500
|
||||
6,341.0,341
|
||||
7,489.0,489
|
||||
8,304.0,304
|
||||
9,358.0,358
|
||||
10,278.0,278
|
||||
11,500.0,500
|
||||
12,500.0,500
|
||||
13,500.0,500
|
||||
14,500.0,500
|
||||
15,476.0,476
|
||||
16,308.0,308
|
||||
17,394.0,394
|
||||
18,500.0,500
|
||||
19,500.0,500
|
||||
|
|
After Width: | Height: | Size: 50 KiB |
@@ -5,7 +5,7 @@ Author: John
|
||||
Email: johnjim0816@gmail.com
|
||||
Date: 2020-11-22 23:21:53
|
||||
LastEditor: John
|
||||
LastEditTime: 2022-08-25 20:59:23
|
||||
LastEditTime: 2022-08-27 00:04:08
|
||||
Discription:
|
||||
Environment:
|
||||
'''
|
||||
@@ -34,7 +34,7 @@ class PGNet(MLP):
|
||||
def forward(self, x):
|
||||
x = F.relu(self.fc1(x))
|
||||
x = F.relu(self.fc2(x))
|
||||
x = F.sigmoid(self.fc3(x))
|
||||
x = torch.sigmoid(self.fc3(x))
|
||||
return x
|
||||
|
||||
class Main(Launcher):
|
||||
@@ -47,8 +47,9 @@ class Main(Launcher):
|
||||
parser.add_argument('--env_name',default='CartPole-v0',type=str,help="name of environment")
|
||||
parser.add_argument('--train_eps',default=200,type=int,help="episodes of training")
|
||||
parser.add_argument('--test_eps',default=20,type=int,help="episodes of testing")
|
||||
parser.add_argument('--ep_max_steps',default = 100000,type=int,help="steps per episode, much larger value can simulate infinite steps")
|
||||
parser.add_argument('--gamma',default=0.99,type=float,help="discounted factor")
|
||||
parser.add_argument('--lr',default=0.005,type=float,help="learning rate")
|
||||
parser.add_argument('--lr',default=0.01,type=float,help="learning rate")
|
||||
parser.add_argument('--update_fre',default=8,type=int)
|
||||
parser.add_argument('--hidden_dim',default=36,type=int)
|
||||
parser.add_argument('--device',default='cpu',type=str,help="cpu or cuda")
|
||||
@@ -81,7 +82,7 @@ class Main(Launcher):
|
||||
for i_ep in range(cfg['train_eps']):
|
||||
state = env.reset()
|
||||
ep_reward = 0
|
||||
for _ in count():
|
||||
for _ in range(cfg['ep_max_steps']):
|
||||
action = agent.sample_action(state) # sample action
|
||||
next_state, reward, done, _ = env.step(action)
|
||||
ep_reward += reward
|
||||
@@ -90,8 +91,9 @@ class Main(Launcher):
|
||||
agent.memory.push((state,float(action),reward))
|
||||
state = next_state
|
||||
if done:
|
||||
print(f"Episode:{i_ep+1}/{cfg['train_eps']}, Reward:{ep_reward:.2f}")
|
||||
break
|
||||
if (i_ep+1) % 10 == 0:
|
||||
print(f"Episode:{i_ep+1}/{cfg['train_eps']}, Reward:{ep_reward:.2f}")
|
||||
if (i_ep+1) % cfg['update_fre'] == 0:
|
||||
agent.update()
|
||||
rewards.append(ep_reward)
|
||||
@@ -107,7 +109,7 @@ class Main(Launcher):
|
||||
for i_ep in range(cfg['test_eps']):
|
||||
state = env.reset()
|
||||
ep_reward = 0
|
||||
for _ in count():
|
||||
for _ in range(cfg['ep_max_steps']):
|
||||
action = agent.predict_action(state)
|
||||
next_state, reward, done, _ = env.step(action)
|
||||
ep_reward += reward
|
||||
@@ -115,9 +117,9 @@ class Main(Launcher):
|
||||
reward = 0
|
||||
state = next_state
|
||||
if done:
|
||||
print(f"Episode: {i_ep+1}/{cfg['test_eps']},Reward: {ep_reward:.2f}")
|
||||
break
|
||||
rewards.append(ep_reward)
|
||||
print(f"Episode: {i_ep+1}/{cfg['test_eps']},Reward: {ep_reward:.2f}")
|
||||
rewards.append(ep_reward)
|
||||
print("Finish testing!")
|
||||
env.close()
|
||||
return {'episodes':range(len(rewards)),'rewards':rewards}
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
{"algo_name": "PolicyGradient", "env_name": "CartPole-v0", "train_eps": 200, "test_eps": 20, "gamma": 0.99, "lr": 0.005, "update_fre": 8, "hidden_dim": 36, "device": "cpu", "seed": 1, "save_fig": true, "show_fig": false, "result_path": "/Users/jj/Desktop/rl-tutorials/codes/PolicyGradient/outputs/CartPole-v0/20220825-205930/results/", "model_path": "/Users/jj/Desktop/rl-tutorials/codes/PolicyGradient/outputs/CartPole-v0/20220825-205930/models/", "n_states": 4, "n_actions": 2}
|
||||
|
Before Width: | Height: | Size: 35 KiB |
|
Before Width: | Height: | Size: 66 KiB |
@@ -1,201 +0,0 @@
|
||||
episodes,rewards
|
||||
0,26.0
|
||||
1,53.0
|
||||
2,10.0
|
||||
3,37.0
|
||||
4,22.0
|
||||
5,21.0
|
||||
6,12.0
|
||||
7,34.0
|
||||
8,38.0
|
||||
9,40.0
|
||||
10,23.0
|
||||
11,14.0
|
||||
12,16.0
|
||||
13,25.0
|
||||
14,15.0
|
||||
15,23.0
|
||||
16,11.0
|
||||
17,28.0
|
||||
18,21.0
|
||||
19,62.0
|
||||
20,33.0
|
||||
21,27.0
|
||||
22,15.0
|
||||
23,17.0
|
||||
24,26.0
|
||||
25,35.0
|
||||
26,26.0
|
||||
27,14.0
|
||||
28,42.0
|
||||
29,45.0
|
||||
30,34.0
|
||||
31,39.0
|
||||
32,31.0
|
||||
33,17.0
|
||||
34,42.0
|
||||
35,41.0
|
||||
36,31.0
|
||||
37,39.0
|
||||
38,28.0
|
||||
39,12.0
|
||||
40,36.0
|
||||
41,33.0
|
||||
42,47.0
|
||||
43,40.0
|
||||
44,63.0
|
||||
45,36.0
|
||||
46,64.0
|
||||
47,79.0
|
||||
48,49.0
|
||||
49,40.0
|
||||
50,65.0
|
||||
51,47.0
|
||||
52,51.0
|
||||
53,30.0
|
||||
54,26.0
|
||||
55,41.0
|
||||
56,86.0
|
||||
57,61.0
|
||||
58,38.0
|
||||
59,200.0
|
||||
60,49.0
|
||||
61,70.0
|
||||
62,61.0
|
||||
63,101.0
|
||||
64,200.0
|
||||
65,152.0
|
||||
66,108.0
|
||||
67,46.0
|
||||
68,72.0
|
||||
69,87.0
|
||||
70,27.0
|
||||
71,126.0
|
||||
72,46.0
|
||||
73,25.0
|
||||
74,14.0
|
||||
75,42.0
|
||||
76,38.0
|
||||
77,55.0
|
||||
78,42.0
|
||||
79,51.0
|
||||
80,67.0
|
||||
81,83.0
|
||||
82,178.0
|
||||
83,115.0
|
||||
84,140.0
|
||||
85,97.0
|
||||
86,85.0
|
||||
87,61.0
|
||||
88,153.0
|
||||
89,200.0
|
||||
90,200.0
|
||||
91,200.0
|
||||
92,200.0
|
||||
93,64.0
|
||||
94,200.0
|
||||
95,200.0
|
||||
96,157.0
|
||||
97,128.0
|
||||
98,160.0
|
||||
99,35.0
|
||||
100,140.0
|
||||
101,113.0
|
||||
102,200.0
|
||||
103,154.0
|
||||
104,200.0
|
||||
105,200.0
|
||||
106,200.0
|
||||
107,198.0
|
||||
108,137.0
|
||||
109,200.0
|
||||
110,200.0
|
||||
111,102.0
|
||||
112,200.0
|
||||
113,200.0
|
||||
114,200.0
|
||||
115,200.0
|
||||
116,148.0
|
||||
117,200.0
|
||||
118,200.0
|
||||
119,200.0
|
||||
120,200.0
|
||||
121,200.0
|
||||
122,194.0
|
||||
123,200.0
|
||||
124,200.0
|
||||
125,200.0
|
||||
126,183.0
|
||||
127,200.0
|
||||
128,200.0
|
||||
129,200.0
|
||||
130,200.0
|
||||
131,200.0
|
||||
132,200.0
|
||||
133,200.0
|
||||
134,200.0
|
||||
135,200.0
|
||||
136,93.0
|
||||
137,96.0
|
||||
138,84.0
|
||||
139,103.0
|
||||
140,79.0
|
||||
141,104.0
|
||||
142,82.0
|
||||
143,105.0
|
||||
144,200.0
|
||||
145,200.0
|
||||
146,171.0
|
||||
147,200.0
|
||||
148,200.0
|
||||
149,200.0
|
||||
150,200.0
|
||||
151,197.0
|
||||
152,133.0
|
||||
153,142.0
|
||||
154,147.0
|
||||
155,156.0
|
||||
156,131.0
|
||||
157,181.0
|
||||
158,163.0
|
||||
159,146.0
|
||||
160,200.0
|
||||
161,176.0
|
||||
162,200.0
|
||||
163,173.0
|
||||
164,177.0
|
||||
165,200.0
|
||||
166,200.0
|
||||
167,200.0
|
||||
168,200.0
|
||||
169,200.0
|
||||
170,200.0
|
||||
171,200.0
|
||||
172,200.0
|
||||
173,200.0
|
||||
174,200.0
|
||||
175,200.0
|
||||
176,200.0
|
||||
177,200.0
|
||||
178,200.0
|
||||
179,200.0
|
||||
180,200.0
|
||||
181,200.0
|
||||
182,200.0
|
||||
183,200.0
|
||||
184,200.0
|
||||
185,200.0
|
||||
186,200.0
|
||||
187,200.0
|
||||
188,200.0
|
||||
189,200.0
|
||||
190,200.0
|
||||
191,200.0
|
||||
192,200.0
|
||||
193,200.0
|
||||
194,200.0
|
||||
195,200.0
|
||||
196,190.0
|
||||
197,200.0
|
||||
198,189.0
|
||||
199,200.0
|
||||
|
@@ -0,0 +1 @@
|
||||
{"algo_name": "PolicyGradient", "env_name": "CartPole-v0", "train_eps": 200, "test_eps": 20, "ep_max_steps": 100000, "gamma": 0.99, "lr": 0.01, "update_fre": 8, "hidden_dim": 36, "device": "cpu", "seed": 1, "save_fig": true, "show_fig": false, "result_path": "c:\\Users\\24438\\Desktop\\rl-tutorials\\codes\\PolicyGradient/outputs/CartPole-v0/20220827-000433/results/", "model_path": "c:\\Users\\24438\\Desktop\\rl-tutorials\\codes\\PolicyGradient/outputs/CartPole-v0/20220827-000433/models/", "n_states": 4, "n_actions": 2}
|
||||
|
After Width: | Height: | Size: 28 KiB |
@@ -1,7 +1,7 @@
|
||||
episodes,rewards
|
||||
0,200.0
|
||||
1,200.0
|
||||
2,165.0
|
||||
2,200.0
|
||||
3,200.0
|
||||
4,200.0
|
||||
5,200.0
|
||||
@@ -10,12 +10,12 @@ episodes,rewards
|
||||
8,200.0
|
||||
9,200.0
|
||||
10,200.0
|
||||
11,168.0
|
||||
11,200.0
|
||||
12,200.0
|
||||
13,200.0
|
||||
14,200.0
|
||||
15,115.0
|
||||
16,198.0
|
||||
15,200.0
|
||||
16,200.0
|
||||
17,200.0
|
||||
18,200.0
|
||||
19,200.0
|
||||
|
|
After Width: | Height: | Size: 60 KiB |
@@ -0,0 +1,201 @@
|
||||
episodes,rewards
|
||||
0,26.0
|
||||
1,53.0
|
||||
2,10.0
|
||||
3,37.0
|
||||
4,22.0
|
||||
5,21.0
|
||||
6,12.0
|
||||
7,34.0
|
||||
8,93.0
|
||||
9,36.0
|
||||
10,29.0
|
||||
11,18.0
|
||||
12,14.0
|
||||
13,62.0
|
||||
14,20.0
|
||||
15,40.0
|
||||
16,10.0
|
||||
17,10.0
|
||||
18,10.0
|
||||
19,11.0
|
||||
20,10.0
|
||||
21,14.0
|
||||
22,12.0
|
||||
23,8.0
|
||||
24,19.0
|
||||
25,33.0
|
||||
26,22.0
|
||||
27,32.0
|
||||
28,16.0
|
||||
29,24.0
|
||||
30,24.0
|
||||
31,24.0
|
||||
32,75.0
|
||||
33,33.0
|
||||
34,33.0
|
||||
35,72.0
|
||||
36,110.0
|
||||
37,48.0
|
||||
38,60.0
|
||||
39,43.0
|
||||
40,61.0
|
||||
41,34.0
|
||||
42,50.0
|
||||
43,61.0
|
||||
44,53.0
|
||||
45,58.0
|
||||
46,36.0
|
||||
47,44.0
|
||||
48,42.0
|
||||
49,64.0
|
||||
50,67.0
|
||||
51,52.0
|
||||
52,39.0
|
||||
53,42.0
|
||||
54,40.0
|
||||
55,33.0
|
||||
56,200.0
|
||||
57,199.0
|
||||
58,149.0
|
||||
59,185.0
|
||||
60,134.0
|
||||
61,174.0
|
||||
62,162.0
|
||||
63,200.0
|
||||
64,93.0
|
||||
65,72.0
|
||||
66,69.0
|
||||
67,51.0
|
||||
68,62.0
|
||||
69,98.0
|
||||
70,73.0
|
||||
71,73.0
|
||||
72,200.0
|
||||
73,200.0
|
||||
74,200.0
|
||||
75,200.0
|
||||
76,200.0
|
||||
77,200.0
|
||||
78,200.0
|
||||
79,133.0
|
||||
80,200.0
|
||||
81,200.0
|
||||
82,200.0
|
||||
83,200.0
|
||||
84,200.0
|
||||
85,200.0
|
||||
86,200.0
|
||||
87,200.0
|
||||
88,114.0
|
||||
89,151.0
|
||||
90,129.0
|
||||
91,156.0
|
||||
92,112.0
|
||||
93,172.0
|
||||
94,171.0
|
||||
95,141.0
|
||||
96,200.0
|
||||
97,200.0
|
||||
98,200.0
|
||||
99,200.0
|
||||
100,200.0
|
||||
101,200.0
|
||||
102,200.0
|
||||
103,200.0
|
||||
104,188.0
|
||||
105,199.0
|
||||
106,138.0
|
||||
107,200.0
|
||||
108,200.0
|
||||
109,181.0
|
||||
110,145.0
|
||||
111,200.0
|
||||
112,135.0
|
||||
113,119.0
|
||||
114,112.0
|
||||
115,122.0
|
||||
116,118.0
|
||||
117,119.0
|
||||
118,131.0
|
||||
119,119.0
|
||||
120,109.0
|
||||
121,96.0
|
||||
122,105.0
|
||||
123,29.0
|
||||
124,110.0
|
||||
125,113.0
|
||||
126,18.0
|
||||
127,90.0
|
||||
128,145.0
|
||||
129,152.0
|
||||
130,151.0
|
||||
131,109.0
|
||||
132,141.0
|
||||
133,109.0
|
||||
134,136.0
|
||||
135,143.0
|
||||
136,200.0
|
||||
137,200.0
|
||||
138,200.0
|
||||
139,200.0
|
||||
140,200.0
|
||||
141,200.0
|
||||
142,200.0
|
||||
143,200.0
|
||||
144,192.0
|
||||
145,173.0
|
||||
146,180.0
|
||||
147,182.0
|
||||
148,186.0
|
||||
149,175.0
|
||||
150,176.0
|
||||
151,191.0
|
||||
152,200.0
|
||||
153,200.0
|
||||
154,200.0
|
||||
155,200.0
|
||||
156,200.0
|
||||
157,200.0
|
||||
158,200.0
|
||||
159,200.0
|
||||
160,200.0
|
||||
161,200.0
|
||||
162,200.0
|
||||
163,200.0
|
||||
164,200.0
|
||||
165,200.0
|
||||
166,200.0
|
||||
167,200.0
|
||||
168,200.0
|
||||
169,200.0
|
||||
170,200.0
|
||||
171,200.0
|
||||
172,200.0
|
||||
173,200.0
|
||||
174,200.0
|
||||
175,200.0
|
||||
176,200.0
|
||||
177,200.0
|
||||
178,200.0
|
||||
179,200.0
|
||||
180,200.0
|
||||
181,200.0
|
||||
182,200.0
|
||||
183,200.0
|
||||
184,200.0
|
||||
185,200.0
|
||||
186,200.0
|
||||
187,200.0
|
||||
188,200.0
|
||||
189,200.0
|
||||
190,200.0
|
||||
191,200.0
|
||||
192,200.0
|
||||
193,200.0
|
||||
194,200.0
|
||||
195,200.0
|
||||
196,200.0
|
||||
197,200.0
|
||||
198,200.0
|
||||
199,200.0
|
||||
|
@@ -5,7 +5,7 @@ Author: John
|
||||
Email: johnjim0816@gmail.com
|
||||
Date: 2020-11-22 23:27:44
|
||||
LastEditor: John
|
||||
LastEditTime: 2022-08-25 20:58:59
|
||||
LastEditTime: 2022-08-27 13:45:26
|
||||
Discription:
|
||||
Environment:
|
||||
'''
|
||||
@@ -31,8 +31,11 @@ class PolicyGradient:
|
||||
state = torch.from_numpy(state).float()
|
||||
state = Variable(state)
|
||||
probs = self.policy_net(state)
|
||||
print("probs")
|
||||
print(probs)
|
||||
m = Bernoulli(probs) # 伯努利分布
|
||||
action = m.sample()
|
||||
|
||||
action = action.data.numpy().astype(int)[0] # 转为标量
|
||||
return action
|
||||
def predict_action(self,state):
|
||||
|
||||
@@ -5,7 +5,7 @@ Author: John
|
||||
Email: johnjim0816@gmail.com
|
||||
Date: 2020-09-11 23:03:00
|
||||
LastEditor: John
|
||||
LastEditTime: 2022-08-25 14:59:15
|
||||
LastEditTime: 2022-08-26 22:46:21
|
||||
Discription:
|
||||
Environment:
|
||||
'''
|
||||
@@ -57,7 +57,10 @@ class Main(Launcher):
|
||||
env = CliffWalkingWapper(env)
|
||||
if cfg['seed'] !=0: # set random seed
|
||||
all_seed(env,seed=cfg["seed"])
|
||||
n_states = env.observation_space.n # state dimension
|
||||
try: # state dimension
|
||||
n_states = env.observation_space.n # print(hasattr(env.observation_space, 'n'))
|
||||
except AttributeError:
|
||||
n_states = env.observation_space.shape[0] # print(hasattr(env.observation_space, 'shape'))
|
||||
n_actions = env.action_space.n # action dimension
|
||||
print(f"n_states: {n_states}, n_actions: {n_actions}")
|
||||
cfg.update({"n_states":n_states,"n_actions":n_actions}) # update to cfg paramters
|
||||
|
||||
@@ -1,21 +0,0 @@
|
||||
episodes,rewards
|
||||
0,-13
|
||||
1,-13
|
||||
2,-13
|
||||
3,-13
|
||||
4,-13
|
||||
5,-13
|
||||
6,-13
|
||||
7,-13
|
||||
8,-13
|
||||
9,-13
|
||||
10,-13
|
||||
11,-13
|
||||
12,-13
|
||||
13,-13
|
||||
14,-13
|
||||
15,-13
|
||||
16,-13
|
||||
17,-13
|
||||
18,-13
|
||||
19,-13
|
||||
|
@@ -1,401 +0,0 @@
|
||||
episodes,rewards
|
||||
0,-2131
|
||||
1,-1086
|
||||
2,-586
|
||||
3,-220
|
||||
4,-154
|
||||
5,-122
|
||||
6,-150
|
||||
7,-159
|
||||
8,-164
|
||||
9,-88
|
||||
10,-195
|
||||
11,-114
|
||||
12,-60
|
||||
13,-179
|
||||
14,-101
|
||||
15,-304
|
||||
16,-96
|
||||
17,-119
|
||||
18,-113
|
||||
19,-98
|
||||
20,-106
|
||||
21,-105
|
||||
22,-77
|
||||
23,-51
|
||||
24,-105
|
||||
25,-136
|
||||
26,-100
|
||||
27,-29
|
||||
28,-79
|
||||
29,-114
|
||||
30,-82
|
||||
31,-70
|
||||
32,-75
|
||||
33,-51
|
||||
34,-94
|
||||
35,-52
|
||||
36,-93
|
||||
37,-71
|
||||
38,-73
|
||||
39,-48
|
||||
40,-52
|
||||
41,-96
|
||||
42,-46
|
||||
43,-65
|
||||
44,-57
|
||||
45,-41
|
||||
46,-104
|
||||
47,-51
|
||||
48,-181
|
||||
49,-229
|
||||
50,-39
|
||||
51,-69
|
||||
52,-53
|
||||
53,-59
|
||||
54,-26
|
||||
55,-75
|
||||
56,-31
|
||||
57,-60
|
||||
58,-63
|
||||
59,-40
|
||||
60,-35
|
||||
61,-79
|
||||
62,-42
|
||||
63,-22
|
||||
64,-73
|
||||
65,-71
|
||||
66,-18
|
||||
67,-55
|
||||
68,-29
|
||||
69,-43
|
||||
70,-70
|
||||
71,-49
|
||||
72,-42
|
||||
73,-29
|
||||
74,-81
|
||||
75,-36
|
||||
76,-38
|
||||
77,-36
|
||||
78,-52
|
||||
79,-28
|
||||
80,-42
|
||||
81,-52
|
||||
82,-66
|
||||
83,-31
|
||||
84,-27
|
||||
85,-49
|
||||
86,-28
|
||||
87,-54
|
||||
88,-34
|
||||
89,-35
|
||||
90,-50
|
||||
91,-36
|
||||
92,-36
|
||||
93,-46
|
||||
94,-34
|
||||
95,-135
|
||||
96,-39
|
||||
97,-36
|
||||
98,-26
|
||||
99,-56
|
||||
100,-40
|
||||
101,-40
|
||||
102,-26
|
||||
103,-28
|
||||
104,-31
|
||||
105,-35
|
||||
106,-26
|
||||
107,-57
|
||||
108,-44
|
||||
109,-41
|
||||
110,-31
|
||||
111,-26
|
||||
112,-25
|
||||
113,-41
|
||||
114,-32
|
||||
115,-44
|
||||
116,-30
|
||||
117,-32
|
||||
118,-30
|
||||
119,-25
|
||||
120,-23
|
||||
121,-47
|
||||
122,-24
|
||||
123,-45
|
||||
124,-39
|
||||
125,-21
|
||||
126,-43
|
||||
127,-143
|
||||
128,-26
|
||||
129,-20
|
||||
130,-32
|
||||
131,-16
|
||||
132,-24
|
||||
133,-42
|
||||
134,-25
|
||||
135,-36
|
||||
136,-19
|
||||
137,-29
|
||||
138,-43
|
||||
139,-17
|
||||
140,-150
|
||||
141,-32
|
||||
142,-34
|
||||
143,-19
|
||||
144,-26
|
||||
145,-30
|
||||
146,-31
|
||||
147,-49
|
||||
148,-33
|
||||
149,-21
|
||||
150,-17
|
||||
151,-48
|
||||
152,-34
|
||||
153,-20
|
||||
154,-20
|
||||
155,-26
|
||||
156,-21
|
||||
157,-13
|
||||
158,-40
|
||||
159,-22
|
||||
160,-26
|
||||
161,-30
|
||||
162,-29
|
||||
163,-25
|
||||
164,-26
|
||||
165,-27
|
||||
166,-21
|
||||
167,-29
|
||||
168,-24
|
||||
169,-17
|
||||
170,-22
|
||||
171,-35
|
||||
172,-35
|
||||
173,-18
|
||||
174,-135
|
||||
175,-15
|
||||
176,-23
|
||||
177,-28
|
||||
178,-25
|
||||
179,-24
|
||||
180,-29
|
||||
181,-31
|
||||
182,-24
|
||||
183,-129
|
||||
184,-45
|
||||
185,-24
|
||||
186,-17
|
||||
187,-20
|
||||
188,-21
|
||||
189,-23
|
||||
190,-15
|
||||
191,-32
|
||||
192,-22
|
||||
193,-19
|
||||
194,-17
|
||||
195,-45
|
||||
196,-15
|
||||
197,-14
|
||||
198,-14
|
||||
199,-37
|
||||
200,-23
|
||||
201,-17
|
||||
202,-19
|
||||
203,-21
|
||||
204,-23
|
||||
205,-27
|
||||
206,-14
|
||||
207,-18
|
||||
208,-23
|
||||
209,-34
|
||||
210,-23
|
||||
211,-13
|
||||
212,-25
|
||||
213,-17
|
||||
214,-13
|
||||
215,-21
|
||||
216,-29
|
||||
217,-18
|
||||
218,-24
|
||||
219,-15
|
||||
220,-27
|
||||
221,-25
|
||||
222,-21
|
||||
223,-19
|
||||
224,-17
|
||||
225,-18
|
||||
226,-13
|
||||
227,-22
|
||||
228,-14
|
||||
229,-13
|
||||
230,-29
|
||||
231,-23
|
||||
232,-15
|
||||
233,-15
|
||||
234,-14
|
||||
235,-28
|
||||
236,-25
|
||||
237,-17
|
||||
238,-23
|
||||
239,-29
|
||||
240,-15
|
||||
241,-14
|
||||
242,-15
|
||||
243,-23
|
||||
244,-15
|
||||
245,-16
|
||||
246,-19
|
||||
247,-13
|
||||
248,-16
|
||||
249,-17
|
||||
250,-25
|
||||
251,-30
|
||||
252,-13
|
||||
253,-14
|
||||
254,-15
|
||||
255,-22
|
||||
256,-14
|
||||
257,-17
|
||||
258,-126
|
||||
259,-15
|
||||
260,-21
|
||||
261,-16
|
||||
262,-23
|
||||
263,-14
|
||||
264,-13
|
||||
265,-13
|
||||
266,-19
|
||||
267,-13
|
||||
268,-19
|
||||
269,-17
|
||||
270,-17
|
||||
271,-13
|
||||
272,-19
|
||||
273,-13
|
||||
274,-13
|
||||
275,-16
|
||||
276,-22
|
||||
277,-14
|
||||
278,-15
|
||||
279,-19
|
||||
280,-34
|
||||
281,-13
|
||||
282,-15
|
||||
283,-32
|
||||
284,-13
|
||||
285,-13
|
||||
286,-13
|
||||
287,-14
|
||||
288,-16
|
||||
289,-13
|
||||
290,-13
|
||||
291,-17
|
||||
292,-13
|
||||
293,-13
|
||||
294,-22
|
||||
295,-14
|
||||
296,-15
|
||||
297,-13
|
||||
298,-13
|
||||
299,-13
|
||||
300,-16
|
||||
301,-13
|
||||
302,-14
|
||||
303,-13
|
||||
304,-13
|
||||
305,-13
|
||||
306,-24
|
||||
307,-13
|
||||
308,-13
|
||||
309,-15
|
||||
310,-13
|
||||
311,-13
|
||||
312,-13
|
||||
313,-15
|
||||
314,-13
|
||||
315,-19
|
||||
316,-15
|
||||
317,-17
|
||||
318,-13
|
||||
319,-13
|
||||
320,-13
|
||||
321,-13
|
||||
322,-13
|
||||
323,-15
|
||||
324,-13
|
||||
325,-13
|
||||
326,-13
|
||||
327,-123
|
||||
328,-13
|
||||
329,-13
|
||||
330,-13
|
||||
331,-13
|
||||
332,-13
|
||||
333,-13
|
||||
334,-13
|
||||
335,-13
|
||||
336,-16
|
||||
337,-13
|
||||
338,-23
|
||||
339,-13
|
||||
340,-13
|
||||
341,-13
|
||||
342,-13
|
||||
343,-13
|
||||
344,-13
|
||||
345,-13
|
||||
346,-13
|
||||
347,-13
|
||||
348,-13
|
||||
349,-13
|
||||
350,-134
|
||||
351,-13
|
||||
352,-13
|
||||
353,-13
|
||||
354,-13
|
||||
355,-13
|
||||
356,-13
|
||||
357,-13
|
||||
358,-13
|
||||
359,-13
|
||||
360,-15
|
||||
361,-13
|
||||
362,-13
|
||||
363,-13
|
||||
364,-13
|
||||
365,-13
|
||||
366,-13
|
||||
367,-13
|
||||
368,-13
|
||||
369,-14
|
||||
370,-13
|
||||
371,-13
|
||||
372,-13
|
||||
373,-13
|
||||
374,-13
|
||||
375,-13
|
||||
376,-13
|
||||
377,-124
|
||||
378,-13
|
||||
379,-13
|
||||
380,-13
|
||||
381,-13
|
||||
382,-13
|
||||
383,-13
|
||||
384,-13
|
||||
385,-13
|
||||
386,-13
|
||||
387,-13
|
||||
388,-13
|
||||
389,-121
|
||||
390,-13
|
||||
391,-13
|
||||
392,-13
|
||||
393,-13
|
||||
394,-13
|
||||
395,-13
|
||||
396,-13
|
||||
397,-13
|
||||
398,-17
|
||||
399,-13
|
||||
|
@@ -0,0 +1 @@
|
||||
{"algo_name": "Q-learning", "env_name": "CliffWalking-v0", "train_eps": 400, "test_eps": 20, "gamma": 0.9, "epsilon_start": 0.95, "epsilon_end": 0.01, "epsilon_decay": 300, "lr": 0.1, "device": "cpu", "seed": 10, "show_fig": false, "save_fig": true, "result_path": "C:\\Users\\24438\\Desktop\\rl-tutorials\\codes\\QLearning/outputs/CliffWalking-v0/20220826-224730/results/", "model_path": "C:\\Users\\24438\\Desktop\\rl-tutorials\\codes\\QLearning/outputs/CliffWalking-v0/20220826-224730/models/", "n_states": 48, "n_actions": 4}
|
||||
|
Before Width: | Height: | Size: 24 KiB After Width: | Height: | Size: 24 KiB |
@@ -0,0 +1,21 @@
|
||||
episodes,rewards,steps
|
||||
0,-13,13
|
||||
1,-13,13
|
||||
2,-13,13
|
||||
3,-13,13
|
||||
4,-13,13
|
||||
5,-13,13
|
||||
6,-13,13
|
||||
7,-13,13
|
||||
8,-13,13
|
||||
9,-13,13
|
||||
10,-13,13
|
||||
11,-13,13
|
||||
12,-13,13
|
||||
13,-13,13
|
||||
14,-13,13
|
||||
15,-13,13
|
||||
16,-13,13
|
||||
17,-13,13
|
||||
18,-13,13
|
||||
19,-13,13
|
||||
|
|
Before Width: | Height: | Size: 35 KiB After Width: | Height: | Size: 35 KiB |
@@ -0,0 +1,401 @@
|
||||
episodes,rewards,steps
|
||||
0,-2131,448
|
||||
1,-1086,492
|
||||
2,-586,388
|
||||
3,-220,220
|
||||
4,-154,154
|
||||
5,-122,122
|
||||
6,-150,150
|
||||
7,-159,159
|
||||
8,-164,164
|
||||
9,-88,88
|
||||
10,-195,195
|
||||
11,-114,114
|
||||
12,-60,60
|
||||
13,-179,179
|
||||
14,-101,101
|
||||
15,-304,205
|
||||
16,-96,96
|
||||
17,-119,119
|
||||
18,-113,113
|
||||
19,-98,98
|
||||
20,-106,106
|
||||
21,-105,105
|
||||
22,-77,77
|
||||
23,-51,51
|
||||
24,-105,105
|
||||
25,-136,136
|
||||
26,-100,100
|
||||
27,-29,29
|
||||
28,-79,79
|
||||
29,-114,114
|
||||
30,-82,82
|
||||
31,-70,70
|
||||
32,-75,75
|
||||
33,-51,51
|
||||
34,-94,94
|
||||
35,-52,52
|
||||
36,-93,93
|
||||
37,-71,71
|
||||
38,-73,73
|
||||
39,-48,48
|
||||
40,-52,52
|
||||
41,-96,96
|
||||
42,-46,46
|
||||
43,-65,65
|
||||
44,-57,57
|
||||
45,-41,41
|
||||
46,-104,104
|
||||
47,-51,51
|
||||
48,-181,82
|
||||
49,-229,130
|
||||
50,-39,39
|
||||
51,-69,69
|
||||
52,-53,53
|
||||
53,-59,59
|
||||
54,-26,26
|
||||
55,-75,75
|
||||
56,-31,31
|
||||
57,-60,60
|
||||
58,-63,63
|
||||
59,-40,40
|
||||
60,-35,35
|
||||
61,-79,79
|
||||
62,-42,42
|
||||
63,-22,22
|
||||
64,-73,73
|
||||
65,-71,71
|
||||
66,-18,18
|
||||
67,-55,55
|
||||
68,-29,29
|
||||
69,-43,43
|
||||
70,-70,70
|
||||
71,-49,49
|
||||
72,-42,42
|
||||
73,-29,29
|
||||
74,-81,81
|
||||
75,-36,36
|
||||
76,-38,38
|
||||
77,-36,36
|
||||
78,-52,52
|
||||
79,-28,28
|
||||
80,-42,42
|
||||
81,-52,52
|
||||
82,-66,66
|
||||
83,-31,31
|
||||
84,-27,27
|
||||
85,-49,49
|
||||
86,-28,28
|
||||
87,-54,54
|
||||
88,-34,34
|
||||
89,-35,35
|
||||
90,-50,50
|
||||
91,-36,36
|
||||
92,-36,36
|
||||
93,-46,46
|
||||
94,-34,34
|
||||
95,-135,36
|
||||
96,-39,39
|
||||
97,-36,36
|
||||
98,-26,26
|
||||
99,-56,56
|
||||
100,-40,40
|
||||
101,-40,40
|
||||
102,-26,26
|
||||
103,-28,28
|
||||
104,-31,31
|
||||
105,-35,35
|
||||
106,-26,26
|
||||
107,-57,57
|
||||
108,-44,44
|
||||
109,-41,41
|
||||
110,-31,31
|
||||
111,-26,26
|
||||
112,-25,25
|
||||
113,-41,41
|
||||
114,-32,32
|
||||
115,-44,44
|
||||
116,-30,30
|
||||
117,-32,32
|
||||
118,-30,30
|
||||
119,-25,25
|
||||
120,-23,23
|
||||
121,-47,47
|
||||
122,-24,24
|
||||
123,-45,45
|
||||
124,-39,39
|
||||
125,-21,21
|
||||
126,-43,43
|
||||
127,-143,44
|
||||
128,-26,26
|
||||
129,-20,20
|
||||
130,-32,32
|
||||
131,-16,16
|
||||
132,-24,24
|
||||
133,-42,42
|
||||
134,-25,25
|
||||
135,-36,36
|
||||
136,-19,19
|
||||
137,-29,29
|
||||
138,-43,43
|
||||
139,-17,17
|
||||
140,-150,51
|
||||
141,-32,32
|
||||
142,-34,34
|
||||
143,-19,19
|
||||
144,-26,26
|
||||
145,-30,30
|
||||
146,-31,31
|
||||
147,-49,49
|
||||
148,-33,33
|
||||
149,-21,21
|
||||
150,-17,17
|
||||
151,-48,48
|
||||
152,-34,34
|
||||
153,-20,20
|
||||
154,-20,20
|
||||
155,-26,26
|
||||
156,-21,21
|
||||
157,-13,13
|
||||
158,-40,40
|
||||
159,-22,22
|
||||
160,-26,26
|
||||
161,-30,30
|
||||
162,-29,29
|
||||
163,-25,25
|
||||
164,-26,26
|
||||
165,-27,27
|
||||
166,-21,21
|
||||
167,-29,29
|
||||
168,-24,24
|
||||
169,-17,17
|
||||
170,-22,22
|
||||
171,-35,35
|
||||
172,-35,35
|
||||
173,-18,18
|
||||
174,-135,36
|
||||
175,-15,15
|
||||
176,-23,23
|
||||
177,-28,28
|
||||
178,-25,25
|
||||
179,-24,24
|
||||
180,-29,29
|
||||
181,-31,31
|
||||
182,-24,24
|
||||
183,-129,30
|
||||
184,-45,45
|
||||
185,-24,24
|
||||
186,-17,17
|
||||
187,-20,20
|
||||
188,-21,21
|
||||
189,-23,23
|
||||
190,-15,15
|
||||
191,-32,32
|
||||
192,-22,22
|
||||
193,-19,19
|
||||
194,-17,17
|
||||
195,-45,45
|
||||
196,-15,15
|
||||
197,-14,14
|
||||
198,-14,14
|
||||
199,-37,37
|
||||
200,-23,23
|
||||
201,-17,17
|
||||
202,-19,19
|
||||
203,-21,21
|
||||
204,-23,23
|
||||
205,-27,27
|
||||
206,-14,14
|
||||
207,-18,18
|
||||
208,-23,23
|
||||
209,-34,34
|
||||
210,-23,23
|
||||
211,-13,13
|
||||
212,-25,25
|
||||
213,-17,17
|
||||
214,-13,13
|
||||
215,-21,21
|
||||
216,-29,29
|
||||
217,-18,18
|
||||
218,-24,24
|
||||
219,-15,15
|
||||
220,-27,27
|
||||
221,-25,25
|
||||
222,-21,21
|
||||
223,-19,19
|
||||
224,-17,17
|
||||
225,-18,18
|
||||
226,-13,13
|
||||
227,-22,22
|
||||
228,-14,14
|
||||
229,-13,13
|
||||
230,-29,29
|
||||
231,-23,23
|
||||
232,-15,15
|
||||
233,-15,15
|
||||
234,-14,14
|
||||
235,-28,28
|
||||
236,-25,25
|
||||
237,-17,17
|
||||
238,-23,23
|
||||
239,-29,29
|
||||
240,-15,15
|
||||
241,-14,14
|
||||
242,-15,15
|
||||
243,-23,23
|
||||
244,-15,15
|
||||
245,-16,16
|
||||
246,-19,19
|
||||
247,-13,13
|
||||
248,-16,16
|
||||
249,-17,17
|
||||
250,-25,25
|
||||
251,-30,30
|
||||
252,-13,13
|
||||
253,-14,14
|
||||
254,-15,15
|
||||
255,-22,22
|
||||
256,-14,14
|
||||
257,-17,17
|
||||
258,-126,27
|
||||
259,-15,15
|
||||
260,-21,21
|
||||
261,-16,16
|
||||
262,-23,23
|
||||
263,-14,14
|
||||
264,-13,13
|
||||
265,-13,13
|
||||
266,-19,19
|
||||
267,-13,13
|
||||
268,-19,19
|
||||
269,-17,17
|
||||
270,-17,17
|
||||
271,-13,13
|
||||
272,-19,19
|
||||
273,-13,13
|
||||
274,-13,13
|
||||
275,-16,16
|
||||
276,-22,22
|
||||
277,-14,14
|
||||
278,-15,15
|
||||
279,-19,19
|
||||
280,-34,34
|
||||
281,-13,13
|
||||
282,-15,15
|
||||
283,-32,32
|
||||
284,-13,13
|
||||
285,-13,13
|
||||
286,-13,13
|
||||
287,-14,14
|
||||
288,-16,16
|
||||
289,-13,13
|
||||
290,-13,13
|
||||
291,-17,17
|
||||
292,-13,13
|
||||
293,-13,13
|
||||
294,-22,22
|
||||
295,-14,14
|
||||
296,-15,15
|
||||
297,-13,13
|
||||
298,-13,13
|
||||
299,-13,13
|
||||
300,-16,16
|
||||
301,-13,13
|
||||
302,-14,14
|
||||
303,-13,13
|
||||
304,-13,13
|
||||
305,-13,13
|
||||
306,-24,24
|
||||
307,-13,13
|
||||
308,-13,13
|
||||
309,-15,15
|
||||
310,-13,13
|
||||
311,-13,13
|
||||
312,-13,13
|
||||
313,-15,15
|
||||
314,-13,13
|
||||
315,-19,19
|
||||
316,-15,15
|
||||
317,-17,17
|
||||
318,-13,13
|
||||
319,-13,13
|
||||
320,-13,13
|
||||
321,-13,13
|
||||
322,-13,13
|
||||
323,-15,15
|
||||
324,-13,13
|
||||
325,-13,13
|
||||
326,-13,13
|
||||
327,-123,24
|
||||
328,-13,13
|
||||
329,-13,13
|
||||
330,-13,13
|
||||
331,-13,13
|
||||
332,-13,13
|
||||
333,-13,13
|
||||
334,-13,13
|
||||
335,-13,13
|
||||
336,-16,16
|
||||
337,-13,13
|
||||
338,-23,23
|
||||
339,-13,13
|
||||
340,-13,13
|
||||
341,-13,13
|
||||
342,-13,13
|
||||
343,-13,13
|
||||
344,-13,13
|
||||
345,-13,13
|
||||
346,-13,13
|
||||
347,-13,13
|
||||
348,-13,13
|
||||
349,-13,13
|
||||
350,-134,35
|
||||
351,-13,13
|
||||
352,-13,13
|
||||
353,-13,13
|
||||
354,-13,13
|
||||
355,-13,13
|
||||
356,-13,13
|
||||
357,-13,13
|
||||
358,-13,13
|
||||
359,-13,13
|
||||
360,-15,15
|
||||
361,-13,13
|
||||
362,-13,13
|
||||
363,-13,13
|
||||
364,-13,13
|
||||
365,-13,13
|
||||
366,-13,13
|
||||
367,-13,13
|
||||
368,-13,13
|
||||
369,-14,14
|
||||
370,-13,13
|
||||
371,-13,13
|
||||
372,-13,13
|
||||
373,-13,13
|
||||
374,-13,13
|
||||
375,-13,13
|
||||
376,-13,13
|
||||
377,-124,25
|
||||
378,-13,13
|
||||
379,-13,13
|
||||
380,-13,13
|
||||
381,-13,13
|
||||
382,-13,13
|
||||
383,-13,13
|
||||
384,-13,13
|
||||
385,-13,13
|
||||
386,-13,13
|
||||
387,-13,13
|
||||
388,-13,13
|
||||
389,-121,22
|
||||
390,-13,13
|
||||
391,-13,13
|
||||
392,-13,13
|
||||
393,-13,13
|
||||
394,-13,13
|
||||
395,-13,13
|
||||
396,-13,13
|
||||
397,-13,13
|
||||
398,-17,17
|
||||
399,-13,13
|
||||
|
@@ -0,0 +1 @@
|
||||
{"algo_name": "Q-learning", "env_name": "Racetrack-v0", "train_eps": 400, "test_eps": 20, "gamma": 0.9, "epsilon_start": 0.95, "epsilon_end": 0.01, "epsilon_decay": 300, "lr": 0.1, "device": "cpu", "seed": 10, "show_fig": false, "save_fig": true, "result_path": "C:\\Users\\24438\\Desktop\\rl-tutorials\\codes\\QLearning/outputs/Racetrack-v0/20220826-224626/results/", "model_path": "C:\\Users\\24438\\Desktop\\rl-tutorials\\codes\\QLearning/outputs/Racetrack-v0/20220826-224626/models/", "n_states": 4, "n_actions": 9}
|
||||
|
After Width: | Height: | Size: 39 KiB |
@@ -0,0 +1,21 @@
|
||||
episodes,rewards,steps
|
||||
0,-1000,1000
|
||||
1,2,8
|
||||
2,4,6
|
||||
3,3,7
|
||||
4,2,8
|
||||
5,3,7
|
||||
6,4,6
|
||||
7,-1000,1000
|
||||
8,3,7
|
||||
9,-11,11
|
||||
10,-19,19
|
||||
11,-18,18
|
||||
12,1,9
|
||||
13,1,9
|
||||
14,4,6
|
||||
15,-16,16
|
||||
16,-17,17
|
||||
17,4,6
|
||||
18,-16,16
|
||||
19,4,6
|
||||
|
|
After Width: | Height: | Size: 40 KiB |
@@ -0,0 +1,401 @@
|
||||
episodes,rewards,steps
|
||||
0,-3580,1000
|
||||
1,-2960,1000
|
||||
2,-2670,1000
|
||||
3,-2720,1000
|
||||
4,-2670,1000
|
||||
5,-2570,1000
|
||||
6,-2407,977
|
||||
7,-2012,852
|
||||
8,-2500,1000
|
||||
9,-2530,1000
|
||||
10,-2550,1000
|
||||
11,-437,187
|
||||
12,-80,40
|
||||
13,-2450,1000
|
||||
14,-338,148
|
||||
15,-1175,525
|
||||
16,-755,325
|
||||
17,-411,181
|
||||
18,-1068,448
|
||||
19,-785,325
|
||||
20,-149,79
|
||||
21,-628,268
|
||||
22,-423,183
|
||||
23,-282,122
|
||||
24,-2198,938
|
||||
25,-13,13
|
||||
26,-253,113
|
||||
27,-48,28
|
||||
28,-72,42
|
||||
29,-123,63
|
||||
30,-305,145
|
||||
31,-72,32
|
||||
32,-142,72
|
||||
33,-13,13
|
||||
34,4,6
|
||||
35,-1285,545
|
||||
36,-174,94
|
||||
37,-436,196
|
||||
38,-759,339
|
||||
39,-11,11
|
||||
40,-17,17
|
||||
41,-283,123
|
||||
42,-181,81
|
||||
43,-44,24
|
||||
44,-55,35
|
||||
45,-135,65
|
||||
46,-577,277
|
||||
47,-234,114
|
||||
48,-54,34
|
||||
49,4,6
|
||||
50,-29,19
|
||||
51,-100,50
|
||||
52,-32,22
|
||||
53,-23,23
|
||||
54,4,6
|
||||
55,-17,17
|
||||
56,-18,18
|
||||
57,-48,28
|
||||
58,-34,24
|
||||
59,-45,25
|
||||
60,-29,19
|
||||
61,1,9
|
||||
62,-77,37
|
||||
63,3,7
|
||||
64,-25,15
|
||||
65,-3,13
|
||||
66,-78,48
|
||||
67,-69,39
|
||||
68,-105,45
|
||||
69,-48,28
|
||||
70,3,7
|
||||
71,4,6
|
||||
72,-100,50
|
||||
73,-130,60
|
||||
74,-20,20
|
||||
75,4,6
|
||||
76,4,6
|
||||
77,4,6
|
||||
78,4,6
|
||||
79,-47,27
|
||||
80,4,6
|
||||
81,4,6
|
||||
82,-174,94
|
||||
83,-12,12
|
||||
84,-26,16
|
||||
85,3,7
|
||||
86,3,7
|
||||
87,-42,32
|
||||
88,-48,28
|
||||
89,-97,57
|
||||
90,-11,11
|
||||
91,-16,16
|
||||
92,-15,15
|
||||
93,4,6
|
||||
94,-147,67
|
||||
95,-52,32
|
||||
96,-97,47
|
||||
97,3,7
|
||||
98,-17,17
|
||||
99,3,7
|
||||
100,4,6
|
||||
101,3,7
|
||||
102,3,7
|
||||
103,3,7
|
||||
104,1,9
|
||||
105,4,6
|
||||
106,4,6
|
||||
107,3,7
|
||||
108,4,6
|
||||
109,-68,38
|
||||
110,3,7
|
||||
111,4,6
|
||||
112,-14,14
|
||||
113,4,6
|
||||
114,-57,37
|
||||
115,3,7
|
||||
116,4,6
|
||||
117,-12,12
|
||||
118,3,7
|
||||
119,3,7
|
||||
120,-64,34
|
||||
121,-13,13
|
||||
122,3,7
|
||||
123,-13,13
|
||||
124,4,6
|
||||
125,3,7
|
||||
126,-32,22
|
||||
127,-41,31
|
||||
128,3,7
|
||||
129,3,7
|
||||
130,3,7
|
||||
131,4,6
|
||||
132,4,6
|
||||
133,3,7
|
||||
134,-12,12
|
||||
135,-31,21
|
||||
136,4,6
|
||||
137,3,7
|
||||
138,-51,31
|
||||
139,-48,28
|
||||
140,4,6
|
||||
141,-85,45
|
||||
142,-14,14
|
||||
143,4,6
|
||||
144,3,7
|
||||
145,-6,16
|
||||
146,4,6
|
||||
147,4,6
|
||||
148,-15,15
|
||||
149,4,6
|
||||
150,-24,24
|
||||
151,3,7
|
||||
152,-14,14
|
||||
153,-18,18
|
||||
154,3,7
|
||||
155,4,6
|
||||
156,-85,45
|
||||
157,-51,31
|
||||
158,3,7
|
||||
159,2,8
|
||||
160,3,7
|
||||
161,-79,39
|
||||
162,-14,14
|
||||
163,-13,13
|
||||
164,4,6
|
||||
165,3,7
|
||||
166,4,6
|
||||
167,3,7
|
||||
168,-74,34
|
||||
169,-15,15
|
||||
170,4,6
|
||||
171,-14,14
|
||||
172,4,6
|
||||
173,-31,21
|
||||
174,-8,18
|
||||
175,4,6
|
||||
176,4,6
|
||||
177,4,6
|
||||
178,4,6
|
||||
179,-29,19
|
||||
180,4,6
|
||||
181,3,7
|
||||
182,4,6
|
||||
183,-82,42
|
||||
184,3,7
|
||||
185,4,6
|
||||
186,4,6
|
||||
187,-11,11
|
||||
188,-23,23
|
||||
189,-33,23
|
||||
190,3,7
|
||||
191,-12,12
|
||||
192,-44,24
|
||||
193,-62,42
|
||||
194,-16,16
|
||||
195,4,6
|
||||
196,-12,12
|
||||
197,3,7
|
||||
198,-13,13
|
||||
199,3,7
|
||||
200,3,7
|
||||
201,4,6
|
||||
202,4,6
|
||||
203,4,6
|
||||
204,-28,18
|
||||
205,-16,16
|
||||
206,3,7
|
||||
207,4,6
|
||||
208,-12,12
|
||||
209,-13,13
|
||||
210,-66,36
|
||||
211,-14,14
|
||||
212,4,6
|
||||
213,4,6
|
||||
214,-15,15
|
||||
215,-60,30
|
||||
216,4,6
|
||||
217,3,7
|
||||
218,4,6
|
||||
219,-33,23
|
||||
220,-12,12
|
||||
221,-14,14
|
||||
222,4,6
|
||||
223,3,7
|
||||
224,-97,47
|
||||
225,4,6
|
||||
226,2,8
|
||||
227,4,6
|
||||
228,4,6
|
||||
229,3,7
|
||||
230,-11,11
|
||||
231,4,6
|
||||
232,3,7
|
||||
233,3,7
|
||||
234,4,6
|
||||
235,3,7
|
||||
236,3,7
|
||||
237,-32,22
|
||||
238,-13,13
|
||||
239,3,7
|
||||
240,-22,22
|
||||
241,4,6
|
||||
242,2,8
|
||||
243,-31,21
|
||||
244,4,6
|
||||
245,-4,14
|
||||
246,-30,20
|
||||
247,4,6
|
||||
248,3,7
|
||||
249,-26,16
|
||||
250,4,6
|
||||
251,-12,12
|
||||
252,2,8
|
||||
253,1,9
|
||||
254,4,6
|
||||
255,2,8
|
||||
256,2,8
|
||||
257,-12,12
|
||||
258,3,7
|
||||
259,-48,28
|
||||
260,4,6
|
||||
261,4,6
|
||||
262,-51,31
|
||||
263,-12,12
|
||||
264,4,6
|
||||
265,2,8
|
||||
266,2,8
|
||||
267,2,8
|
||||
268,3,7
|
||||
269,4,6
|
||||
270,4,6
|
||||
271,-17,17
|
||||
272,4,6
|
||||
273,-13,13
|
||||
274,-16,16
|
||||
275,-97,57
|
||||
276,3,7
|
||||
277,-1,11
|
||||
278,-32,22
|
||||
279,3,7
|
||||
280,4,6
|
||||
281,3,7
|
||||
282,3,7
|
||||
283,3,7
|
||||
284,3,7
|
||||
285,2,8
|
||||
286,3,7
|
||||
287,-15,15
|
||||
288,2,8
|
||||
289,-18,18
|
||||
290,4,6
|
||||
291,-36,26
|
||||
292,4,6
|
||||
293,4,6
|
||||
294,4,6
|
||||
295,4,6
|
||||
296,-77,47
|
||||
297,-14,14
|
||||
298,3,7
|
||||
299,3,7
|
||||
300,3,7
|
||||
301,4,6
|
||||
302,3,7
|
||||
303,4,6
|
||||
304,-12,12
|
||||
305,-45,35
|
||||
306,-63,43
|
||||
307,2,8
|
||||
308,4,6
|
||||
309,4,6
|
||||
310,-13,13
|
||||
311,4,6
|
||||
312,-13,13
|
||||
313,4,6
|
||||
314,3,7
|
||||
315,-30,20
|
||||
316,-13,13
|
||||
317,3,7
|
||||
318,4,6
|
||||
319,4,6
|
||||
320,-12,12
|
||||
321,-13,13
|
||||
322,3,7
|
||||
323,3,7
|
||||
324,3,7
|
||||
325,3,7
|
||||
326,-36,26
|
||||
327,4,6
|
||||
328,3,7
|
||||
329,3,7
|
||||
330,3,7
|
||||
331,3,7
|
||||
332,-14,14
|
||||
333,-16,16
|
||||
334,3,7
|
||||
335,3,7
|
||||
336,-14,14
|
||||
337,1,9
|
||||
338,2,8
|
||||
339,3,7
|
||||
340,4,6
|
||||
341,-36,26
|
||||
342,-14,14
|
||||
343,-78,48
|
||||
344,2,8
|
||||
345,-37,27
|
||||
346,3,7
|
||||
347,3,7
|
||||
348,-37,27
|
||||
349,-16,16
|
||||
350,4,6
|
||||
351,-15,15
|
||||
352,4,6
|
||||
353,2,8
|
||||
354,-44,24
|
||||
355,-13,13
|
||||
356,-14,14
|
||||
357,-17,17
|
||||
358,-13,13
|
||||
359,3,7
|
||||
360,2,8
|
||||
361,4,6
|
||||
362,3,7
|
||||
363,-5,15
|
||||
364,-14,14
|
||||
365,2,8
|
||||
366,-12,12
|
||||
367,3,7
|
||||
368,4,6
|
||||
369,2,8
|
||||
370,2,8
|
||||
371,1,9
|
||||
372,-16,16
|
||||
373,1,9
|
||||
374,4,6
|
||||
375,-16,16
|
||||
376,3,7
|
||||
377,2,8
|
||||
378,-13,13
|
||||
379,-44,34
|
||||
380,-16,16
|
||||
381,-30,20
|
||||
382,4,6
|
||||
383,4,6
|
||||
384,2,8
|
||||
385,-15,15
|
||||
386,4,6
|
||||
387,3,7
|
||||
388,2,8
|
||||
389,4,6
|
||||
390,2,8
|
||||
391,3,7
|
||||
392,3,7
|
||||
393,-14,14
|
||||
394,-15,15
|
||||
395,3,7
|
||||
396,-13,13
|
||||
397,3,7
|
||||
398,4,6
|
||||
399,3,7
|
||||
|
@@ -5,7 +5,7 @@ Author: John
|
||||
Email: johnjim0816@gmail.com
|
||||
Date: 2021-03-11 17:59:16
|
||||
LastEditor: John
|
||||
LastEditTime: 2022-08-25 14:26:36
|
||||
LastEditTime: 2022-08-26 23:03:39
|
||||
Discription:
|
||||
Environment:
|
||||
'''
|
||||
@@ -20,117 +20,105 @@ import argparse
|
||||
from envs.register import register_env
|
||||
from envs.wrappers import CliffWalkingWapper
|
||||
from Sarsa.sarsa import Sarsa
|
||||
from common.utils import save_results,make_dir,plot_rewards,save_args,all_seed
|
||||
from common.utils import all_seed
|
||||
from common.launcher import Launcher
|
||||
|
||||
def get_args():
|
||||
curr_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") # obtain current time
|
||||
parser = argparse.ArgumentParser(description="hyperparameters")
|
||||
parser.add_argument('--algo_name',default='Sarsa',type=str,help="name of algorithm")
|
||||
parser.add_argument('--env_name',default='Racetrack-v0',type=str,help="name of environment")
|
||||
parser.add_argument('--train_eps',default=300,type=int,help="episodes of training")
|
||||
parser.add_argument('--test_eps',default=20,type=int,help="episodes of testing")
|
||||
parser.add_argument('--gamma',default=0.99,type=float,help="discounted factor")
|
||||
parser.add_argument('--epsilon_start',default=0.90,type=float,help="initial value of epsilon")
|
||||
parser.add_argument('--epsilon_end',default=0.01,type=float,help="final value of epsilon")
|
||||
parser.add_argument('--epsilon_decay',default=200,type=int,help="decay rate of epsilon")
|
||||
parser.add_argument('--lr',default=0.2,type=float,help="learning rate")
|
||||
parser.add_argument('--device',default='cpu',type=str,help="cpu or cuda")
|
||||
parser.add_argument('--seed',default=10,type=int,help="seed")
|
||||
parser.add_argument('--show_fig',default=False,type=bool,help="if show figure or not")
|
||||
parser.add_argument('--save_fig',default=True,type=bool,help="if save figure or not")
|
||||
args = parser.parse_args()
|
||||
default_args = {'result_path':f"{curr_path}/outputs/{args.env_name}/{curr_time}/results/",
|
||||
'model_path':f"{curr_path}/outputs/{args.env_name}/{curr_time}/models/",
|
||||
}
|
||||
args = {**vars(args),**default_args} # type(dict)
|
||||
return args
|
||||
class Main(Launcher):
|
||||
def get_args(self):
|
||||
curr_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") # obtain current time
|
||||
parser = argparse.ArgumentParser(description="hyperparameters")
|
||||
parser.add_argument('--algo_name',default = 'Sarsa',type=str,help="name of algorithm")
|
||||
parser.add_argument('--env_name',default = 'Racetrack-v0',type=str,help="name of environment")
|
||||
parser.add_argument('--train_eps',default = 300,type=int,help="episodes of training")
|
||||
parser.add_argument('--test_eps',default = 20,type=int,help="episodes of testing")
|
||||
parser.add_argument('--ep_max_steps',default = 100000,type=int,help="steps per episode, much larger value can simulate infinite steps")
|
||||
parser.add_argument('--gamma',default=0.99,type=float,help="discounted factor")
|
||||
parser.add_argument('--epsilon_start',default=0.90,type=float,help="initial value of epsilon")
|
||||
parser.add_argument('--epsilon_end',default=0.01,type=float,help="final value of epsilon")
|
||||
parser.add_argument('--epsilon_decay',default=200,type=int,help="decay rate of epsilon")
|
||||
parser.add_argument('--lr',default=0.2,type=float,help="learning rate")
|
||||
parser.add_argument('--device',default='cpu',type=str,help="cpu or cuda")
|
||||
parser.add_argument('--seed',default=10,type=int,help="seed")
|
||||
parser.add_argument('--show_fig',default=False,type=bool,help="if show figure or not")
|
||||
parser.add_argument('--save_fig',default=True,type=bool,help="if save figure or not")
|
||||
args = parser.parse_args()
|
||||
default_args = {'result_path':f"{curr_path}/outputs/{args.env_name}/{curr_time}/results/",
|
||||
'model_path':f"{curr_path}/outputs/{args.env_name}/{curr_time}/models/",
|
||||
}
|
||||
args = {**vars(args),**default_args} # type(dict)
|
||||
return args
|
||||
|
||||
def env_agent_config(cfg):
|
||||
register_env(cfg['env_name'])
|
||||
env = gym.make(cfg['env_name'])
|
||||
if cfg['seed'] !=0: # set random seed
|
||||
all_seed(env,seed= cfg['seed'])
|
||||
if cfg['env_name'] == 'CliffWalking-v0':
|
||||
env = CliffWalkingWapper(env)
|
||||
try: # state dimension
|
||||
n_states = env.observation_space.n # print(hasattr(env.observation_space, 'n'))
|
||||
except AttributeError:
|
||||
n_states = env.observation_space.shape[0] # print(hasattr(env.observation_space, 'shape'))
|
||||
n_actions = env.action_space.n # action dimension
|
||||
print(f"n_states: {n_states}, n_actions: {n_actions}")
|
||||
cfg.update({"n_states":n_states,"n_actions":n_actions}) # update to cfg paramters
|
||||
agent = Sarsa(cfg)
|
||||
return env,agent
|
||||
def env_agent_config(self,cfg):
|
||||
register_env(cfg['env_name'])
|
||||
env = gym.make(cfg['env_name'])
|
||||
if cfg['seed'] !=0: # set random seed
|
||||
all_seed(env,seed= cfg['seed'])
|
||||
if cfg['env_name'] == 'CliffWalking-v0':
|
||||
env = CliffWalkingWapper(env)
|
||||
try: # state dimension
|
||||
n_states = env.observation_space.n # print(hasattr(env.observation_space, 'n'))
|
||||
except AttributeError:
|
||||
n_states = env.observation_space.shape[0] # print(hasattr(env.observation_space, 'shape'))
|
||||
n_actions = env.action_space.n # action dimension
|
||||
print(f"n_states: {n_states}, n_actions: {n_actions}")
|
||||
cfg.update({"n_states":n_states,"n_actions":n_actions}) # update to cfg paramters
|
||||
agent = Sarsa(cfg)
|
||||
return env,agent
|
||||
|
||||
def train(cfg,env,agent):
|
||||
print("Start training!")
|
||||
print(f"Env: {cfg['env_name']}, Algorithm: {cfg['algo_name']}, Device: {cfg['device']}")
|
||||
rewards = [] # record rewards for all episodes
|
||||
steps = [] # record steps for all episodes
|
||||
for i_ep in range(cfg['train_eps']):
|
||||
ep_reward = 0 # reward per episode
|
||||
ep_step = 0 # step per episode
|
||||
state = env.reset() # reset and obtain initial state
|
||||
action = agent.sample_action(state)
|
||||
while True:
|
||||
# for _ in range(cfg.ep_max_steps):
|
||||
next_state, reward, done, _ = env.step(action) # update env and return transitions
|
||||
next_action = agent.sample_action(next_state)
|
||||
agent.update(state, action, reward, next_state, next_action,done) # update agent
|
||||
state = next_state # update state
|
||||
action = next_action
|
||||
ep_reward += reward
|
||||
ep_step += 1
|
||||
if done:
|
||||
break
|
||||
rewards.append(ep_reward)
|
||||
steps.append(ep_step)
|
||||
if (i_ep+1)%10==0:
|
||||
print(f'Episode: {i_ep+1}/{cfg["train_eps"]}, Reward: {ep_reward:.2f}, Steps:{ep_step}, Epislon: {agent.epsilon:.3f}')
|
||||
print("Finish training!")
|
||||
return {'episodes':range(len(rewards)),'rewards':rewards,'steps':steps}
|
||||
def train(self,cfg,env,agent):
|
||||
print("Start training!")
|
||||
print(f"Env: {cfg['env_name']}, Algorithm: {cfg['algo_name']}, Device: {cfg['device']}")
|
||||
rewards = [] # record rewards for all episodes
|
||||
steps = [] # record steps for all episodes
|
||||
for i_ep in range(cfg['train_eps']):
|
||||
ep_reward = 0 # reward per episode
|
||||
ep_step = 0 # step per episode
|
||||
state = env.reset() # reset and obtain initial state
|
||||
action = agent.sample_action(state)
|
||||
# while True:
|
||||
for _ in range(cfg['ep_max_steps']):
|
||||
next_state, reward, done, _ = env.step(action) # update env and return transitions
|
||||
next_action = agent.sample_action(next_state)
|
||||
agent.update(state, action, reward, next_state, next_action,done) # update agent
|
||||
state = next_state # update state
|
||||
action = next_action
|
||||
ep_reward += reward
|
||||
ep_step += 1
|
||||
if done:
|
||||
break
|
||||
rewards.append(ep_reward)
|
||||
steps.append(ep_step)
|
||||
if (i_ep+1)%10==0:
|
||||
print(f'Episode: {i_ep+1}/{cfg["train_eps"]}, Reward: {ep_reward:.2f}, Steps: {ep_step}, Epislon: {agent.epsilon:.3f}')
|
||||
print("Finish training!")
|
||||
return {'episodes':range(len(rewards)),'rewards':rewards,'steps':steps}
|
||||
|
||||
def test(cfg,env,agent):
|
||||
print("Start testing!")
|
||||
print(f"Env: {cfg['env_name']}, Algorithm: {cfg['algo_name']}, Device: {cfg['device']}")
|
||||
rewards = [] # record rewards for all episodes
|
||||
steps = [] # record steps for all episodes
|
||||
for i_ep in range(cfg['test_eps']):
|
||||
ep_reward = 0 # reward per episode
|
||||
ep_step = 0
|
||||
while True:
|
||||
# for _ in range(cfg.ep_max_steps):
|
||||
action = agent.predict_action(state)
|
||||
next_state, reward, done = env.step(action)
|
||||
state = next_state
|
||||
ep_reward+=reward
|
||||
ep_step+=1
|
||||
if done:
|
||||
break
|
||||
rewards.append(ep_reward)
|
||||
steps.append(ep_step)
|
||||
print(f"Episode: {i_ep+1}/{cfg['test_eps']}, Steps:{ep_step}, Reward: {ep_reward:.2f}")
|
||||
print("Finish testing!")
|
||||
return {'episodes':range(len(rewards)),'rewards':rewards,'steps':steps}
|
||||
def test(self,cfg,env,agent):
|
||||
print("Start testing!")
|
||||
print(f"Env: {cfg['env_name']}, Algorithm: {cfg['algo_name']}, Device: {cfg['device']}")
|
||||
rewards = [] # record rewards for all episodes
|
||||
steps = [] # record steps for all episodes
|
||||
for i_ep in range(cfg['test_eps']):
|
||||
ep_reward = 0 # reward per episode
|
||||
ep_step = 0
|
||||
state = env.reset() # reset and obtain initial state
|
||||
for _ in range(cfg['ep_max_steps']):
|
||||
action = agent.predict_action(state)
|
||||
next_state, reward, done, _ = env.step(action)
|
||||
state = next_state
|
||||
ep_reward+=reward
|
||||
ep_step+=1
|
||||
if done:
|
||||
break
|
||||
rewards.append(ep_reward)
|
||||
steps.append(ep_step)
|
||||
print(f"Episode: {i_ep+1}/{cfg['test_eps']}, Steps: {ep_step}, Reward: {ep_reward:.2f}")
|
||||
print("Finish testing!")
|
||||
return {'episodes':range(len(rewards)),'rewards':rewards,'steps':steps}
|
||||
|
||||
if __name__ == "__main__":
|
||||
cfg = get_args()
|
||||
# 训练
|
||||
env, agent = env_agent_config(cfg)
|
||||
res_dic = train(cfg, env, agent)
|
||||
make_dir(cfg.result_path, cfg.model_path)
|
||||
save_args(cfg) # save parameters
|
||||
agent.save(path=cfg.model_path) # save model
|
||||
save_results(res_dic, tag='train',
|
||||
path=cfg.result_path)
|
||||
plot_rewards(res_dic['rewards'], cfg, tag="train")
|
||||
# 测试
|
||||
env, agent = env_agent_config(cfg)
|
||||
agent.load(path=cfg.model_path) # 导入模型
|
||||
res_dic = test(cfg, env, agent)
|
||||
save_results(res_dic, tag='test',
|
||||
path=cfg.result_path) # 保存结果
|
||||
plot_rewards(res_dic['rewards'], cfg, tag="test") # 画出结果
|
||||
main = Main()
|
||||
main.run()
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
{"algo_name": "Sarsa", "env_name": "CliffWalking-v0", "train_eps": 300, "test_eps": 20, "ep_max_steps": 200, "gamma": 0.99, "epsilon_start": 0.9, "epsilon_end": 0.01, "epsilon_decay": 200, "lr": 0.2, "device": "cpu", "result_path": "/Users/jj/Desktop/rl-tutorials/codes/Sarsa/outputs/CliffWalking-v0/20220803-142740/results/", "model_path": "/Users/jj/Desktop/rl-tutorials/codes/Sarsa/outputs/CliffWalking-v0/20220803-142740/models/", "save_fig": true}
|
||||
|
Before Width: | Height: | Size: 34 KiB |
|
Before Width: | Height: | Size: 54 KiB |
@@ -1,15 +0,0 @@
|
||||
{
|
||||
"algo_name": "Sarsa",
|
||||
"env_name": "CliffWalking-v0",
|
||||
"train_eps": 400,
|
||||
"test_eps": 20,
|
||||
"gamma": 0.9,
|
||||
"epsilon_start": 0.95,
|
||||
"epsilon_end": 0.01,
|
||||
"epsilon_decay": 300,
|
||||
"lr": 0.1,
|
||||
"device": "cpu",
|
||||
"result_path": "c:\\Users\\24438\\Desktop\\rl-tutorials\\codes\\Sarsa/outputs/CliffWalking-v0/20220804-223029/results/",
|
||||
"model_path": "c:\\Users\\24438\\Desktop\\rl-tutorials\\codes\\Sarsa/outputs/CliffWalking-v0/20220804-223029/models/",
|
||||
"save_fig": true
|
||||
}
|
||||
|
Before Width: | Height: | Size: 34 KiB |
@@ -1,5 +1,5 @@
|
||||
{
|
||||
"algo_name": "Q-learning",
|
||||
"algo_name": "Sarsa",
|
||||
"env_name": "CliffWalking-v0",
|
||||
"train_eps": 400,
|
||||
"test_eps": 20,
|
||||
@@ -12,8 +12,8 @@
|
||||
"seed": 10,
|
||||
"show_fig": false,
|
||||
"save_fig": true,
|
||||
"result_path": "/Users/jj/Desktop/rl-tutorials/codes/QLearning/outputs/CliffWalking-v0/20220824-103255/results/",
|
||||
"model_path": "/Users/jj/Desktop/rl-tutorials/codes/QLearning/outputs/CliffWalking-v0/20220824-103255/models/",
|
||||
"result_path": "/Users/jj/Desktop/rl-tutorials/codes/Sarsa/outputs/CliffWalking-v0/20220825-213316/results/",
|
||||
"model_path": "/Users/jj/Desktop/rl-tutorials/codes/Sarsa/outputs/CliffWalking-v0/20220825-213316/models/",
|
||||
"n_states": 48,
|
||||
"n_actions": 4
|
||||
}
|
||||
|
Before Width: | Height: | Size: 25 KiB After Width: | Height: | Size: 25 KiB |
@@ -0,0 +1,21 @@
|
||||
episodes,rewards,steps
|
||||
0,-15,15
|
||||
1,-15,15
|
||||
2,-15,15
|
||||
3,-15,15
|
||||
4,-15,15
|
||||
5,-15,15
|
||||
6,-15,15
|
||||
7,-15,15
|
||||
8,-15,15
|
||||
9,-15,15
|
||||
10,-15,15
|
||||
11,-15,15
|
||||
12,-15,15
|
||||
13,-15,15
|
||||
14,-15,15
|
||||
15,-15,15
|
||||
16,-15,15
|
||||
17,-15,15
|
||||
18,-15,15
|
||||
19,-15,15
|
||||
|
|
After Width: | Height: | Size: 33 KiB |
@@ -0,0 +1,401 @@
|
||||
episodes,rewards,steps
|
||||
0,-649,154
|
||||
1,-2822,842
|
||||
2,-176,176
|
||||
3,-139,139
|
||||
4,-221,221
|
||||
5,-51,51
|
||||
6,-219,219
|
||||
7,-247,148
|
||||
8,-90,90
|
||||
9,-145,145
|
||||
10,-104,104
|
||||
11,-162,162
|
||||
12,-49,49
|
||||
13,-129,129
|
||||
14,-140,140
|
||||
15,-19,19
|
||||
16,-131,131
|
||||
17,-115,115
|
||||
18,-43,43
|
||||
19,-133,133
|
||||
20,-73,73
|
||||
21,-89,89
|
||||
22,-131,131
|
||||
23,-61,61
|
||||
24,-113,113
|
||||
25,-119,119
|
||||
26,-119,119
|
||||
27,-71,71
|
||||
28,-132,132
|
||||
29,-47,47
|
||||
30,-79,79
|
||||
31,-57,57
|
||||
32,-125,125
|
||||
33,-77,77
|
||||
34,-87,87
|
||||
35,-49,49
|
||||
36,-57,57
|
||||
37,-81,81
|
||||
38,-81,81
|
||||
39,-97,97
|
||||
40,-61,61
|
||||
41,-85,85
|
||||
42,-217,118
|
||||
43,-39,39
|
||||
44,-117,117
|
||||
45,-41,41
|
||||
46,-71,71
|
||||
47,-105,105
|
||||
48,-73,73
|
||||
49,-68,68
|
||||
50,-95,95
|
||||
51,-41,41
|
||||
52,-41,41
|
||||
53,-67,67
|
||||
54,-71,71
|
||||
55,-65,65
|
||||
56,-41,41
|
||||
57,-61,61
|
||||
58,-81,81
|
||||
59,-21,21
|
||||
60,-76,76
|
||||
61,-80,80
|
||||
62,-23,23
|
||||
63,-53,53
|
||||
64,-67,67
|
||||
65,-33,33
|
||||
66,-41,41
|
||||
67,-59,59
|
||||
68,-33,33
|
||||
69,-64,64
|
||||
70,-188,89
|
||||
71,-47,47
|
||||
72,-57,57
|
||||
73,-45,45
|
||||
74,-33,33
|
||||
75,-79,79
|
||||
76,-45,45
|
||||
77,-23,23
|
||||
78,-47,47
|
||||
79,-57,57
|
||||
80,-47,47
|
||||
81,-45,45
|
||||
82,-53,53
|
||||
83,-29,29
|
||||
84,-33,33
|
||||
85,-69,69
|
||||
86,-61,61
|
||||
87,-35,35
|
||||
88,-59,59
|
||||
89,-43,43
|
||||
90,-17,17
|
||||
91,-39,39
|
||||
92,-59,59
|
||||
93,-29,29
|
||||
94,-31,31
|
||||
95,-55,55
|
||||
96,-35,35
|
||||
97,-45,45
|
||||
98,-29,29
|
||||
99,-59,59
|
||||
100,-25,25
|
||||
101,-29,29
|
||||
102,-33,33
|
||||
103,-39,39
|
||||
104,-19,19
|
||||
105,-47,47
|
||||
106,-57,57
|
||||
107,-19,19
|
||||
108,-47,47
|
||||
109,-25,25
|
||||
110,-23,23
|
||||
111,-53,53
|
||||
112,-39,39
|
||||
113,-34,34
|
||||
114,-27,27
|
||||
115,-27,27
|
||||
116,-63,63
|
||||
117,-33,33
|
||||
118,-17,17
|
||||
119,-21,21
|
||||
120,-19,19
|
||||
121,-49,49
|
||||
122,-25,25
|
||||
123,-39,39
|
||||
124,-25,25
|
||||
125,-167,68
|
||||
126,-35,35
|
||||
127,-29,29
|
||||
128,-31,31
|
||||
129,-44,44
|
||||
130,-33,33
|
||||
131,-23,23
|
||||
132,-37,37
|
||||
133,-134,35
|
||||
134,-31,31
|
||||
135,-19,19
|
||||
136,-29,29
|
||||
137,-37,37
|
||||
138,-25,25
|
||||
139,-39,39
|
||||
140,-47,47
|
||||
141,-29,29
|
||||
142,-27,27
|
||||
143,-21,21
|
||||
144,-41,41
|
||||
145,-29,29
|
||||
146,-25,25
|
||||
147,-25,25
|
||||
148,-21,21
|
||||
149,-29,29
|
||||
150,-39,39
|
||||
151,-35,35
|
||||
152,-35,35
|
||||
153,-32,32
|
||||
154,-31,31
|
||||
155,-19,19
|
||||
156,-21,21
|
||||
157,-35,35
|
||||
158,-33,33
|
||||
159,-37,37
|
||||
160,-25,25
|
||||
161,-41,41
|
||||
162,-25,25
|
||||
163,-23,23
|
||||
164,-27,27
|
||||
165,-25,25
|
||||
166,-39,39
|
||||
167,-28,28
|
||||
168,-24,24
|
||||
169,-23,23
|
||||
170,-41,41
|
||||
171,-17,17
|
||||
172,-35,35
|
||||
173,-23,23
|
||||
174,-29,29
|
||||
175,-17,17
|
||||
176,-39,39
|
||||
177,-33,33
|
||||
178,-29,29
|
||||
179,-24,24
|
||||
180,-23,23
|
||||
181,-19,19
|
||||
182,-15,15
|
||||
183,-23,23
|
||||
184,-39,39
|
||||
185,-25,25
|
||||
186,-35,35
|
||||
187,-33,33
|
||||
188,-19,19
|
||||
189,-35,35
|
||||
190,-21,21
|
||||
191,-131,32
|
||||
192,-15,15
|
||||
193,-23,23
|
||||
194,-21,21
|
||||
195,-17,17
|
||||
196,-23,23
|
||||
197,-31,31
|
||||
198,-21,21
|
||||
199,-31,31
|
||||
200,-35,35
|
||||
201,-27,27
|
||||
202,-19,19
|
||||
203,-21,21
|
||||
204,-23,23
|
||||
205,-23,23
|
||||
206,-21,21
|
||||
207,-31,31
|
||||
208,-25,25
|
||||
209,-23,23
|
||||
210,-17,17
|
||||
211,-19,19
|
||||
212,-25,25
|
||||
213,-23,23
|
||||
214,-19,19
|
||||
215,-19,19
|
||||
216,-25,25
|
||||
217,-25,25
|
||||
218,-25,25
|
||||
219,-25,25
|
||||
220,-23,23
|
||||
221,-19,19
|
||||
222,-19,19
|
||||
223,-149,50
|
||||
224,-41,41
|
||||
225,-19,19
|
||||
226,-29,29
|
||||
227,-37,37
|
||||
228,-17,17
|
||||
229,-17,17
|
||||
230,-19,19
|
||||
231,-27,27
|
||||
232,-19,19
|
||||
233,-33,33
|
||||
234,-23,23
|
||||
235,-23,23
|
||||
236,-34,34
|
||||
237,-15,15
|
||||
238,-33,33
|
||||
239,-29,29
|
||||
240,-17,17
|
||||
241,-23,23
|
||||
242,-17,17
|
||||
243,-19,19
|
||||
244,-21,21
|
||||
245,-23,23
|
||||
246,-17,17
|
||||
247,-15,15
|
||||
248,-39,39
|
||||
249,-21,21
|
||||
250,-23,23
|
||||
251,-29,29
|
||||
252,-15,15
|
||||
253,-17,17
|
||||
254,-29,29
|
||||
255,-15,15
|
||||
256,-21,21
|
||||
257,-19,19
|
||||
258,-19,19
|
||||
259,-21,21
|
||||
260,-17,17
|
||||
261,-21,21
|
||||
262,-27,27
|
||||
263,-27,27
|
||||
264,-21,21
|
||||
265,-19,19
|
||||
266,-17,17
|
||||
267,-23,23
|
||||
268,-19,19
|
||||
269,-17,17
|
||||
270,-19,19
|
||||
271,-19,19
|
||||
272,-17,17
|
||||
273,-23,23
|
||||
274,-17,17
|
||||
275,-22,22
|
||||
276,-31,31
|
||||
277,-19,19
|
||||
278,-17,17
|
||||
279,-33,33
|
||||
280,-19,19
|
||||
281,-17,17
|
||||
282,-31,31
|
||||
283,-15,15
|
||||
284,-15,15
|
||||
285,-15,15
|
||||
286,-29,29
|
||||
287,-19,19
|
||||
288,-17,17
|
||||
289,-26,26
|
||||
290,-17,17
|
||||
291,-19,19
|
||||
292,-15,15
|
||||
293,-21,21
|
||||
294,-21,21
|
||||
295,-15,15
|
||||
296,-19,19
|
||||
297,-15,15
|
||||
298,-17,17
|
||||
299,-19,19
|
||||
300,-17,17
|
||||
301,-21,21
|
||||
302,-17,17
|
||||
303,-27,27
|
||||
304,-17,17
|
||||
305,-19,19
|
||||
306,-15,15
|
||||
307,-19,19
|
||||
308,-33,33
|
||||
309,-17,17
|
||||
310,-20,20
|
||||
311,-19,19
|
||||
312,-17,17
|
||||
313,-15,15
|
||||
314,-23,23
|
||||
315,-15,15
|
||||
316,-15,15
|
||||
317,-17,17
|
||||
318,-25,25
|
||||
319,-15,15
|
||||
320,-17,17
|
||||
321,-19,19
|
||||
322,-17,17
|
||||
323,-15,15
|
||||
324,-23,23
|
||||
325,-19,19
|
||||
326,-17,17
|
||||
327,-23,23
|
||||
328,-15,15
|
||||
329,-19,19
|
||||
330,-15,15
|
||||
331,-17,17
|
||||
332,-19,19
|
||||
333,-15,15
|
||||
334,-17,17
|
||||
335,-17,17
|
||||
336,-19,19
|
||||
337,-15,15
|
||||
338,-19,19
|
||||
339,-19,19
|
||||
340,-17,17
|
||||
341,-15,15
|
||||
342,-21,21
|
||||
343,-19,19
|
||||
344,-17,17
|
||||
345,-17,17
|
||||
346,-15,15
|
||||
347,-21,21
|
||||
348,-20,20
|
||||
349,-15,15
|
||||
350,-15,15
|
||||
351,-15,15
|
||||
352,-19,19
|
||||
353,-17,17
|
||||
354,-15,15
|
||||
355,-27,27
|
||||
356,-15,15
|
||||
357,-15,15
|
||||
358,-23,23
|
||||
359,-125,26
|
||||
360,-132,33
|
||||
361,-17,17
|
||||
362,-15,15
|
||||
363,-17,17
|
||||
364,-23,23
|
||||
365,-17,17
|
||||
366,-15,15
|
||||
367,-15,15
|
||||
368,-17,17
|
||||
369,-15,15
|
||||
370,-17,17
|
||||
371,-15,15
|
||||
372,-15,15
|
||||
373,-15,15
|
||||
374,-15,15
|
||||
375,-15,15
|
||||
376,-15,15
|
||||
377,-15,15
|
||||
378,-15,15
|
||||
379,-15,15
|
||||
380,-17,17
|
||||
381,-15,15
|
||||
382,-15,15
|
||||
383,-19,19
|
||||
384,-15,15
|
||||
385,-17,17
|
||||
386,-27,27
|
||||
387,-15,15
|
||||
388,-21,21
|
||||
389,-125,26
|
||||
390,-15,15
|
||||
391,-15,15
|
||||
392,-15,15
|
||||
393,-27,27
|
||||
394,-15,15
|
||||
395,-15,15
|
||||
396,-17,17
|
||||
397,-15,15
|
||||
398,-15,15
|
||||
399,-15,15
|
||||
|
@@ -0,0 +1 @@
|
||||
{"algo_name": "Sarsa", "env_name": "Racetrack-v0", "train_eps": 300, "test_eps": 20, "gamma": 0.99, "epsilon_start": 0.9, "epsilon_end": 0.01, "epsilon_decay": 200, "lr": 0.2, "device": "cpu", "seed": 10, "show_fig": false, "save_fig": true, "result_path": "/Users/jj/Desktop/rl-tutorials/codes/Sarsa/outputs/Racetrack-v0/20220825-212738/results/", "model_path": "/Users/jj/Desktop/rl-tutorials/codes/Sarsa/outputs/Racetrack-v0/20220825-212738/models/", "n_states": 4, "n_actions": 9}
|
||||
|
After Width: | Height: | Size: 39 KiB |
@@ -0,0 +1,21 @@
|
||||
episodes,rewards,steps
|
||||
0,4,6
|
||||
1,4,6
|
||||
2,-1010,1000
|
||||
3,-14,14
|
||||
4,4,6
|
||||
5,4,6
|
||||
6,4,6
|
||||
7,-1060,1000
|
||||
8,2,8
|
||||
9,-12,12
|
||||
10,3,7
|
||||
11,-15,15
|
||||
12,3,7
|
||||
13,4,6
|
||||
14,-14,14
|
||||
15,3,7
|
||||
16,-18,18
|
||||
17,4,6
|
||||
18,4,6
|
||||
19,-1020,1000
|
||||
|
|
After Width: | Height: | Size: 41 KiB |
@@ -0,0 +1,301 @@
|
||||
episodes,rewards,steps
|
||||
0,-3460,1000
|
||||
1,-2800,1000
|
||||
2,-2910,1000
|
||||
3,-2620,1000
|
||||
4,-2620,1000
|
||||
5,-2590,1000
|
||||
6,-2390,1000
|
||||
7,-2510,1000
|
||||
8,-2470,1000
|
||||
9,-611,251
|
||||
10,-891,371
|
||||
11,-265,125
|
||||
12,-2281,911
|
||||
13,-1203,523
|
||||
14,-616,266
|
||||
15,-213,113
|
||||
16,-633,273
|
||||
17,-1112,482
|
||||
18,-350,160
|
||||
19,-852,342
|
||||
20,-87,47
|
||||
21,-11,11
|
||||
22,-27,17
|
||||
23,-117,57
|
||||
24,-15,15
|
||||
25,4,6
|
||||
26,-27,17
|
||||
27,-94,44
|
||||
28,-184,84
|
||||
29,-44,24
|
||||
30,-150,80
|
||||
31,-14,14
|
||||
32,-219,89
|
||||
33,-50,30
|
||||
34,-111,61
|
||||
35,-10,10
|
||||
36,-28,18
|
||||
37,-34,24
|
||||
38,-12,12
|
||||
39,-19,19
|
||||
40,-136,66
|
||||
41,-171,71
|
||||
42,-51,31
|
||||
43,4,6
|
||||
44,-117,57
|
||||
45,4,6
|
||||
46,4,6
|
||||
47,-127,67
|
||||
48,-78,48
|
||||
49,-311,131
|
||||
50,-25,15
|
||||
51,4,6
|
||||
52,-49,29
|
||||
53,-25,15
|
||||
54,-78,48
|
||||
55,-238,108
|
||||
56,4,6
|
||||
57,-17,17
|
||||
58,-29,19
|
||||
59,-218,98
|
||||
60,4,6
|
||||
61,-129,59
|
||||
62,-344,144
|
||||
63,-25,15
|
||||
64,-15,15
|
||||
65,-77,37
|
||||
66,2,8
|
||||
67,0,10
|
||||
68,4,6
|
||||
69,4,6
|
||||
70,-242,102
|
||||
71,3,7
|
||||
72,4,6
|
||||
73,-53,33
|
||||
74,-14,14
|
||||
75,4,6
|
||||
76,4,6
|
||||
77,-30,20
|
||||
78,-12,12
|
||||
79,2,8
|
||||
80,-12,12
|
||||
81,-150,70
|
||||
82,-48,28
|
||||
83,-102,52
|
||||
84,4,6
|
||||
85,-97,47
|
||||
86,-10,10
|
||||
87,-125,55
|
||||
88,-28,18
|
||||
89,-26,16
|
||||
90,-107,57
|
||||
91,4,6
|
||||
92,-16,16
|
||||
93,-84,44
|
||||
94,-13,13
|
||||
95,-43,23
|
||||
96,-14,14
|
||||
97,-12,12
|
||||
98,-13,13
|
||||
99,-2,12
|
||||
100,-14,14
|
||||
101,-47,27
|
||||
102,4,6
|
||||
103,4,6
|
||||
104,-91,51
|
||||
105,-65,35
|
||||
106,4,6
|
||||
107,-12,12
|
||||
108,-14,14
|
||||
109,-13,13
|
||||
110,4,6
|
||||
111,-41,31
|
||||
112,-13,13
|
||||
113,4,6
|
||||
114,-4,14
|
||||
115,-74,34
|
||||
116,4,6
|
||||
117,-60,30
|
||||
118,4,6
|
||||
119,-15,15
|
||||
120,3,7
|
||||
121,4,6
|
||||
122,4,6
|
||||
123,-19,19
|
||||
124,4,6
|
||||
125,-49,29
|
||||
126,-13,13
|
||||
127,-30,20
|
||||
128,2,8
|
||||
129,-21,21
|
||||
130,-45,25
|
||||
131,-32,22
|
||||
132,-67,37
|
||||
133,-46,26
|
||||
134,0,10
|
||||
135,-12,12
|
||||
136,-9,9
|
||||
137,-10,10
|
||||
138,-14,14
|
||||
139,4,6
|
||||
140,-11,11
|
||||
141,-12,12
|
||||
142,2,8
|
||||
143,-35,25
|
||||
144,4,6
|
||||
145,-73,43
|
||||
146,4,6
|
||||
147,-20,20
|
||||
148,4,6
|
||||
149,2,8
|
||||
150,-29,19
|
||||
151,-20,20
|
||||
152,4,6
|
||||
153,-28,18
|
||||
154,4,6
|
||||
155,4,6
|
||||
156,4,6
|
||||
157,4,6
|
||||
158,-34,24
|
||||
159,4,6
|
||||
160,4,6
|
||||
161,4,6
|
||||
162,-25,15
|
||||
163,4,6
|
||||
164,3,7
|
||||
165,-48,28
|
||||
166,4,6
|
||||
167,-58,38
|
||||
168,-20,20
|
||||
169,-9,9
|
||||
170,3,7
|
||||
171,4,6
|
||||
172,3,7
|
||||
173,-33,23
|
||||
174,-50,30
|
||||
175,-16,16
|
||||
176,-32,22
|
||||
177,-65,35
|
||||
178,4,6
|
||||
179,-13,13
|
||||
180,-11,11
|
||||
181,3,7
|
||||
182,4,6
|
||||
183,-16,16
|
||||
184,-12,12
|
||||
185,4,6
|
||||
186,-48,28
|
||||
187,-13,13
|
||||
188,2,8
|
||||
189,3,7
|
||||
190,-27,17
|
||||
191,3,7
|
||||
192,4,6
|
||||
193,4,6
|
||||
194,4,6
|
||||
195,4,6
|
||||
196,4,6
|
||||
197,-13,13
|
||||
198,-14,14
|
||||
199,4,6
|
||||
200,4,6
|
||||
201,-13,13
|
||||
202,-33,23
|
||||
203,4,6
|
||||
204,-32,22
|
||||
205,4,6
|
||||
206,-48,28
|
||||
207,4,6
|
||||
208,4,6
|
||||
209,3,7
|
||||
210,4,6
|
||||
211,-34,24
|
||||
212,3,7
|
||||
213,4,6
|
||||
214,4,6
|
||||
215,4,6
|
||||
216,3,7
|
||||
217,-12,12
|
||||
218,3,7
|
||||
219,-8,8
|
||||
220,3,7
|
||||
221,4,6
|
||||
222,-46,26
|
||||
223,-33,23
|
||||
224,4,6
|
||||
225,1,9
|
||||
226,3,7
|
||||
227,2,8
|
||||
228,-34,24
|
||||
229,4,6
|
||||
230,4,6
|
||||
231,4,6
|
||||
232,4,6
|
||||
233,-55,35
|
||||
234,-37,27
|
||||
235,4,6
|
||||
236,-14,14
|
||||
237,-65,35
|
||||
238,4,6
|
||||
239,-13,13
|
||||
240,4,6
|
||||
241,4,6
|
||||
242,-13,13
|
||||
243,-30,20
|
||||
244,3,7
|
||||
245,-13,13
|
||||
246,4,6
|
||||
247,4,6
|
||||
248,-13,13
|
||||
249,-32,22
|
||||
250,4,6
|
||||
251,-55,35
|
||||
252,-12,12
|
||||
253,3,7
|
||||
254,3,7
|
||||
255,3,7
|
||||
256,4,6
|
||||
257,2,8
|
||||
258,-12,12
|
||||
259,3,7
|
||||
260,-10,10
|
||||
261,-12,12
|
||||
262,4,6
|
||||
263,3,7
|
||||
264,3,7
|
||||
265,-16,16
|
||||
266,3,7
|
||||
267,-47,27
|
||||
268,-13,13
|
||||
269,4,6
|
||||
270,3,7
|
||||
271,-13,13
|
||||
272,4,6
|
||||
273,4,6
|
||||
274,-17,17
|
||||
275,4,6
|
||||
276,3,7
|
||||
277,3,7
|
||||
278,4,6
|
||||
279,-41,31
|
||||
280,3,7
|
||||
281,-47,27
|
||||
282,-32,22
|
||||
283,4,6
|
||||
284,3,7
|
||||
285,-17,17
|
||||
286,3,7
|
||||
287,3,7
|
||||
288,3,7
|
||||
289,-12,12
|
||||
290,4,6
|
||||
291,3,7
|
||||
292,3,7
|
||||
293,-24,14
|
||||
294,3,7
|
||||
295,4,6
|
||||
296,3,7
|
||||
297,3,7
|
||||
298,3,7
|
||||
299,-13,13
|
||||
|
@@ -5,7 +5,7 @@ Author: John
|
||||
Email: johnjim0816@gmail.com
|
||||
Date: 2021-03-12 16:58:16
|
||||
LastEditor: John
|
||||
LastEditTime: 2022-08-25 00:23:22
|
||||
LastEditTime: 2022-08-25 21:26:08
|
||||
Discription:
|
||||
Environment:
|
||||
'''
|
||||
@@ -30,7 +30,7 @@ class Sarsa(object):
|
||||
self.sample_count += 1
|
||||
self.epsilon = self.epsilon_end + (self.epsilon_start - self.epsilon_end) * \
|
||||
math.exp(-1. * self.sample_count / self.epsilon_decay) # The probability to select a random action, is is log decayed
|
||||
best_action = np.argmax(self.Q_table[state])
|
||||
best_action = np.argmax(self.Q_table[str(state)]) # array cannot be hashtable, thus convert to str
|
||||
action_probs = np.ones(self.n_actions, dtype=float) * self.epsilon / self.n_actions
|
||||
action_probs[best_action] += (1.0 - self.epsilon)
|
||||
action = np.random.choice(np.arange(len(action_probs)), p=action_probs)
|
||||
@@ -38,27 +38,27 @@ class Sarsa(object):
|
||||
def predict_action(self,state):
|
||||
''' predict action while testing
|
||||
'''
|
||||
action = np.argmax(self.Q_table[state])
|
||||
action = np.argmax(self.Q_table[str(state)])
|
||||
return action
|
||||
def update(self, state, action, reward, next_state, next_action,done):
|
||||
Q_predict = self.Q_table[state][action]
|
||||
Q_predict = self.Q_table[str(state)][action]
|
||||
if done:
|
||||
Q_target = reward # terminal state
|
||||
else:
|
||||
Q_target = reward + self.gamma * self.Q_table[next_state][next_action] # the only difference from Q learning
|
||||
self.Q_table[state][action] += self.lr * (Q_target - Q_predict)
|
||||
Q_target = reward + self.gamma * self.Q_table[str(next_state)][next_action] # the only difference from Q learning
|
||||
self.Q_table[str(state)][action] += self.lr * (Q_target - Q_predict)
|
||||
def save_model(self,path):
|
||||
import dill
|
||||
from pathlib import Path
|
||||
# create path
|
||||
Path(path).mkdir(parents=True, exist_ok=True)
|
||||
torch.save(
|
||||
obj=self.Q_table_table,
|
||||
obj=self.Q_table,
|
||||
f=path+"checkpoint.pkl",
|
||||
pickle_module=dill
|
||||
)
|
||||
print("Model saved!")
|
||||
def load_model(self, path):
|
||||
import dill
|
||||
self.Q_table_table =torch.load(f=path+'checkpoint.pkl',pickle_module=dill)
|
||||
self.Q_table=torch.load(f=path+'checkpoint.pkl',pickle_module=dill)
|
||||
print("Mode loaded!")
|
||||
@@ -1,131 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
'''
|
||||
Author: John
|
||||
Email: johnjim0816@gmail.com
|
||||
Date: 2020-09-11 23:03:00
|
||||
LastEditor: John
|
||||
LastEditTime: 2022-08-04 22:44:00
|
||||
Discription:
|
||||
Environment:
|
||||
'''
|
||||
import sys
|
||||
import os
|
||||
curr_path = os.path.dirname(os.path.abspath(__file__)) # 当前文件所在绝对路径
|
||||
parent_path = os.path.dirname(curr_path) # 父路径
|
||||
sys.path.append(parent_path) # 添加路径到系统路径
|
||||
|
||||
import gym
|
||||
import torch
|
||||
import datetime
|
||||
import argparse
|
||||
from envs.gridworld_env import CliffWalkingWapper
|
||||
from Sarsa.sarsa import Sarsa
|
||||
from common.utils import plot_rewards,save_args
|
||||
from common.utils import save_results,make_dir
|
||||
|
||||
|
||||
def get_args():
|
||||
"""
|
||||
"""
|
||||
curr_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") # 获取当前时间
|
||||
parser = argparse.ArgumentParser(description="hyperparameters")
|
||||
parser.add_argument('--algo_name',default='Sarsa',type=str,help="name of algorithm")
|
||||
parser.add_argument('--env_name',default='CliffWalking-v0',type=str,help="name of environment")
|
||||
parser.add_argument('--train_eps',default=400,type=int,help="episodes of training") # 训练的回合数
|
||||
parser.add_argument('--test_eps',default=20,type=int,help="episodes of testing") # 测试的回合数
|
||||
parser.add_argument('--gamma',default=0.90,type=float,help="discounted factor") # 折扣因子
|
||||
parser.add_argument('--epsilon_start',default=0.95,type=float,help="initial value of epsilon") # e-greedy策略中初始epsilon
|
||||
parser.add_argument('--epsilon_end',default=0.01,type=float,help="final value of epsilon") # e-greedy策略中的终止epsilon
|
||||
parser.add_argument('--epsilon_decay',default=300,type=int,help="decay rate of epsilon") # e-greedy策略中epsilon的衰减率
|
||||
parser.add_argument('--lr',default=0.1,type=float,help="learning rate")
|
||||
parser.add_argument('--device',default='cpu',type=str,help="cpu or cuda")
|
||||
parser.add_argument('--result_path',default=curr_path + "/outputs/" + parser.parse_args().env_name + \
|
||||
'/' + curr_time + '/results/' )
|
||||
parser.add_argument('--model_path',default=curr_path + "/outputs/" + parser.parse_args().env_name + \
|
||||
'/' + curr_time + '/models/' ) # path to save models
|
||||
parser.add_argument('--save_fig',default=True,type=bool,help="if save figure or not")
|
||||
args = parser.parse_args([])
|
||||
return args
|
||||
curr_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") # 获取当前时间
|
||||
|
||||
def train(cfg,env,agent):
|
||||
print('开始训练!')
|
||||
print(f'环境:{cfg.env_name}, 算法:{cfg.algo_name}, 设备:{cfg.device}')
|
||||
rewards = [] # 记录奖励
|
||||
for i_ep in range(cfg.train_eps):
|
||||
ep_reward = 0 # 记录每个回合的奖励
|
||||
state = env.reset() # 重置环境,即开始新的回合
|
||||
action = agent.sample(state)
|
||||
while True:
|
||||
action = agent.sample(state) # 根据算法采样一个动作
|
||||
next_state, reward, done, _ = env.step(action) # 与环境进行一次动作交互
|
||||
next_action = agent.sample(next_state)
|
||||
agent.update(state, action, reward, next_state, next_action,done) # 算法更新
|
||||
state = next_state # 更新状态
|
||||
action = next_action
|
||||
ep_reward += reward
|
||||
if done:
|
||||
break
|
||||
rewards.append(ep_reward)
|
||||
print(f"回合:{i_ep+1}/{cfg.train_eps},奖励:{ep_reward:.1f},Epsilon:{agent.epsilon}")
|
||||
print('完成训练!')
|
||||
return {"rewards":rewards}
|
||||
|
||||
def test(cfg,env,agent):
|
||||
print('开始测试!')
|
||||
print(f'环境:{cfg.env_name}, 算法:{cfg.algo_name}, 设备:{cfg.device}')
|
||||
rewards = [] # 记录所有回合的奖励
|
||||
for i_ep in range(cfg.test_eps):
|
||||
ep_reward = 0 # 记录每个episode的reward
|
||||
state = env.reset() # 重置环境, 重新开一局(即开始新的一个回合)
|
||||
while True:
|
||||
action = agent.predict(state) # 根据算法选择一个动作
|
||||
next_state, reward, done, _ = env.step(action) # 与环境进行一个交互
|
||||
state = next_state # 更新状态
|
||||
ep_reward += reward
|
||||
if done:
|
||||
break
|
||||
rewards.append(ep_reward)
|
||||
print(f"回合数:{i_ep+1}/{cfg.test_eps}, 奖励:{ep_reward:.1f}")
|
||||
print('完成测试!')
|
||||
return {"rewards":rewards}
|
||||
|
||||
def env_agent_config(cfg,seed=1):
    '''Create the CliffWalking environment and a Sarsa agent.

    Args:
        cfg: config object; cfg.env_name names the gym environment to build.
        seed (int, optional): random seed passed to the environment. Defaults to 1.

    Returns:
        env: the wrapped gym environment.
        agent: the Sarsa agent.
    '''
    env = gym.make(cfg.env_name)
    env = CliffWalkingWapper(env)  # project-local wrapper around the raw env
    env.seed(seed) # seed the environment for reproducibility
    n_states = env.observation_space.n # number of discrete states
    n_actions = env.action_space.n # number of discrete actions
    print(f"状态数:{n_states},动作数:{n_actions}")
    # NOTE(review): n_states is only printed, never passed to Sarsa --
    # presumably the agent gets it from cfg; confirm against the Sarsa ctor.
    agent = Sarsa(n_actions,cfg)
    return env,agent
|
||||
if __name__ == "__main__":
    cfg = get_args()  # parse hyperparameters / paths from CLI defaults
    # training phase
    env, agent = env_agent_config(cfg)
    res_dic = train(cfg, env, agent)
    make_dir(cfg.result_path, cfg.model_path)  # ensure output dirs exist before saving
    save_args(cfg) # save parameters
    agent.save(path=cfg.model_path) # save model
    save_results(res_dic, tag='train',
        path=cfg.result_path)
    plot_rewards(res_dic['rewards'], cfg, tag="train")
    # testing phase: rebuild env/agent, then restore the trained model
    env, agent = env_agent_config(cfg)
    agent.load(path=cfg.model_path) # load the model saved during training
    res_dic = test(cfg, env, agent)
    save_results(res_dic, tag='test',
        path=cfg.result_path) # save test results
    plot_rewards(res_dic['rewards'], cfg, tag="test") # plot test results
|
||||
|
||||
|
||||
|
||||
@@ -24,7 +24,7 @@ class Launcher:
|
||||
save_results(res_dic, tag = 'train', path = cfg['result_path']) # save results
|
||||
plot_rewards(res_dic['rewards'], cfg, path = cfg['result_path'],tag = "train") # plot results
|
||||
# testing
|
||||
env, agent = self.env_agent_config(cfg) # create new env for testing, sometimes can ignore this step
|
||||
# env, agent = self.env_agent_config(cfg) # create new env for testing, sometimes can ignore this step
|
||||
agent.load_model(path = cfg['model_path']) # load model
|
||||
res_dic = self.test(cfg, env, agent)
|
||||
save_results(res_dic, tag='test',
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
@Email: johnjim0816@gmail.com
|
||||
@Date: 2020-06-10 15:27:16
|
||||
@LastEditor: John
|
||||
LastEditTime: 2022-08-22 17:23:21
|
||||
LastEditTime: 2022-08-28 23:44:06
|
||||
@Discription:
|
||||
@Environment: python 3.7.7
|
||||
'''
|
||||
@@ -39,12 +39,12 @@ class ReplayBufferQue:
|
||||
def __init__(self, capacity: int) -> None:
|
||||
self.capacity = capacity
|
||||
self.buffer = deque(maxlen=self.capacity)
|
||||
def push(self,trainsitions):
|
||||
def push(self,transitions):
|
||||
'''_summary_
|
||||
Args:
|
||||
trainsitions (tuple): _description_
|
||||
'''
|
||||
self.buffer.append(trainsitions)
|
||||
self.buffer.append(transitions)
|
||||
def sample(self, batch_size: int, sequential: bool = False):
|
||||
if batch_size > len(self.buffer):
|
||||
batch_size = len(self.buffer)
|
||||
|
||||
@@ -5,7 +5,7 @@ Author: John
|
||||
Email: johnjim0816@gmail.com
|
||||
Date: 2021-03-12 21:14:12
|
||||
LastEditor: John
|
||||
LastEditTime: 2021-09-15 13:21:03
|
||||
LastEditTime: 2022-08-29 14:24:44
|
||||
Discription:
|
||||
Environment:
|
||||
'''
|
||||
@@ -31,40 +31,45 @@ class MLP(nn.Module):
|
||||
x = F.relu(self.fc2(x))
|
||||
return self.fc3(x)
|
||||
|
||||
class ActorSoftmax(nn.Module):
    '''Two-layer MLP policy head producing a softmax action distribution.

    Args:
        input_dim (int): state/observation dimension.
        output_dim (int): number of discrete actions.
        hidden_dim (int): hidden-layer width. Defaults to 256.
    '''
    def __init__(self, input_dim, output_dim, hidden_dim=256):
        super(ActorSoftmax, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)
    def forward(self,state):
        '''Return action probabilities for `state`.

        Softmax is taken over the last axis (dim=-1): identical to the
        original dim=1 for batched (batch, input_dim) input, but also
        supports a single unbatched (input_dim,) vector, where dim=1
        would raise.
        '''
        dist = F.relu(self.fc1(state))
        dist = F.softmax(self.fc2(dist), dim=-1)
        return dist
|
||||
class Critic(nn.Module):
|
||||
def __init__(self, n_obs, n_actions, hidden_size, init_w=3e-3):
|
||||
super(Critic, self).__init__()
|
||||
|
||||
self.linear1 = nn.Linear(n_obs + n_actions, hidden_size)
|
||||
self.linear2 = nn.Linear(hidden_size, hidden_size)
|
||||
self.linear3 = nn.Linear(hidden_size, 1)
|
||||
# 随机初始化为较小的值
|
||||
self.linear3.weight.data.uniform_(-init_w, init_w)
|
||||
self.linear3.bias.data.uniform_(-init_w, init_w)
|
||||
|
||||
def forward(self, state, action):
|
||||
# 按维数1拼接
|
||||
x = torch.cat([state, action], 1)
|
||||
x = F.relu(self.linear1(x))
|
||||
x = F.relu(self.linear2(x))
|
||||
x = self.linear3(x)
|
||||
return x
|
||||
def __init__(self,input_dim,output_dim,hidden_dim=256):
|
||||
super(Critic,self).__init__()
|
||||
assert output_dim == 1 # critic must output a single value
|
||||
self.fc1 = nn.Linear(input_dim, hidden_dim)
|
||||
self.fc2 = nn.Linear(hidden_dim, output_dim)
|
||||
def forward(self,state):
|
||||
value = F.relu(self.fc1(state))
|
||||
value = self.fc2(value)
|
||||
return value
|
||||
|
||||
class Actor(nn.Module):
|
||||
def __init__(self, n_obs, n_actions, hidden_size, init_w=3e-3):
|
||||
super(Actor, self).__init__()
|
||||
self.linear1 = nn.Linear(n_obs, hidden_size)
|
||||
self.linear2 = nn.Linear(hidden_size, hidden_size)
|
||||
self.linear3 = nn.Linear(hidden_size, n_actions)
|
||||
class ActorCriticSoftmax(nn.Module):
    '''Actor-critic network with separate actor and critic MLP branches.

    This span contained unresolved diff residue: lines from the removed
    DDPG-style Actor (tanh head, init_w scaling) were interleaved with this
    added class. Resolved here to the added version only.

    Args:
        input_dim (int): state dimension.
        output_dim (int): number of discrete actions.
        actor_hidden_dim (int): actor hidden width. Defaults to 256.
        critic_hidden_dim (int): critic hidden width. Defaults to 256.
    '''
    def __init__(self, input_dim, output_dim, actor_hidden_dim=256,critic_hidden_dim=256):
        super(ActorCriticSoftmax, self).__init__()
        # critic branch: state -> scalar value
        self.critic_fc1 = nn.Linear(input_dim, critic_hidden_dim)
        self.critic_fc2 = nn.Linear(critic_hidden_dim, 1)
        # actor branch: state -> action logits
        self.actor_fc1 = nn.Linear(input_dim, actor_hidden_dim)
        self.actor_fc2 = nn.Linear(actor_hidden_dim, output_dim)

    def forward(self, state):
        '''Return (value, policy_dist) for a batched state (batch, input_dim).

        value: (batch, 1) state value; policy_dist: (batch, output_dim)
        softmax action probabilities (dim=1 — assumes batched input).
        '''
        value = F.relu(self.critic_fc1(state))
        value = self.critic_fc2(value)

        policy_dist = F.relu(self.actor_fc1(state))
        policy_dist = F.softmax(self.actor_fc2(policy_dist), dim=1)

        return value, policy_dist
|
||||
|
||||
class ActorCritic(nn.Module):
|
||||
def __init__(self, n_states, n_actions, hidden_dim=256):
|
||||
|
||||
@@ -5,7 +5,7 @@ def register_env(env_name):
|
||||
if env_name == 'Racetrack-v0':
|
||||
register(
|
||||
id='Racetrack-v0',
|
||||
entry_point='racetrack:RacetrackEnv',
|
||||
entry_point='envs.racetrack:RacetrackEnv',
|
||||
max_episode_steps=1000,
|
||||
kwargs={}
|
||||
)
|
||||
|
||||
15
projects/codes/scripts/A2C_CartPole-v0.sh
Normal file
@@ -0,0 +1,15 @@
|
||||
# run A2C on CartPole-v0
# source conda, if you are already in proper conda environment, then comment the codes util "conda activate easyrl"

if [ -f "$HOME/anaconda3/etc/profile.d/conda.sh" ]; then
    echo "source file at ~/anaconda3/etc/profile.d/conda.sh"
    source ~/anaconda3/etc/profile.d/conda.sh
elif [ -f "$HOME/opt/anaconda3/etc/profile.d/conda.sh" ]; then
    echo "source file at ~/opt/anaconda3/etc/profile.d/conda.sh"
    source ~/opt/anaconda3/etc/profile.d/conda.sh
else
    echo 'please manually config the conda source path'
fi
conda activate easyrl # easyrl here can be changed to another name of conda env that you have created
codes_dir=$(dirname $(dirname $(readlink -f "$0"))) # "codes" path
# launch A2C training with the defaults defined in A2C/main.py
python $codes_dir/A2C/main.py
|
||||
@@ -1,6 +1,4 @@
|
||||
'''
|
||||
run DQN on CartPole-v1, not finished yet
|
||||
'''
|
||||
# run DQN on CartPole-v1, not finished yet
|
||||
# source conda, if you are already in proper conda environment, then comment the codes util "conda activate easyrl"
|
||||
if [ -f "$HOME/anaconda3/etc/profile.d/conda.sh" ]; then
|
||||
echo "source file at ~/anaconda3/etc/profile.d/conda.sh"
|
||||
@@ -13,4 +11,4 @@ else
|
||||
fi
|
||||
conda activate easyrl # easyrl here can be changed to another name of conda env that you have created
|
||||
codes_dir=$(dirname $(dirname $(readlink -f "$0"))) # "codes" path
|
||||
python $codes_dir/DQN/main.py --env_name CartPole-v1 --train_eps 500 --epsilon_decay 1000 --memory_capacity 200000 --batch_size 128 --device cuda
|
||||
python $codes_dir/DQN/main.py --env_name CartPole-v1 --train_eps 2000 --gamma 0.99 --epsilon_decay 6000 --lr 0.00001 --memory_capacity 200000 --batch_size 64 --device cuda
|
||||
@@ -1,4 +1,3 @@
|
||||
|
||||
# source conda, if you are already in proper conda environment, then comment the codes util "conda activate easyrl"
|
||||
if [ -f "$HOME/anaconda3/etc/profile.d/conda.sh" ]; then
|
||||
echo "source file at ~/anaconda3/etc/profile.d/conda.sh"
|
||||
@@ -11,4 +10,4 @@ else
|
||||
fi
|
||||
conda activate easyrl # easyrl here can be changed to another name of conda env that you have created
|
||||
codes_dir=$(dirname $(dirname $(readlink -f "$0"))) # "codes" path
|
||||
python $codes_dir/QLearning/main.py --device cpu
|
||||
python $codes_dir/PolicyGradient/main.py
|
||||
12
projects/codes/scripts/Qlearning_CliffWalking-v0.sh
Normal file
@@ -0,0 +1,12 @@
|
||||
# run QLearning on CliffWalking-v0
# source conda, if you are already in a proper conda environment, comment out everything up to "conda activate easyrl"
if [ -f "$HOME/anaconda3/etc/profile.d/conda.sh" ]; then
    echo "source file at ~/anaconda3/etc/profile.d/conda.sh"
    source ~/anaconda3/etc/profile.d/conda.sh
elif [ -f "$HOME/opt/anaconda3/etc/profile.d/conda.sh" ]; then
    echo "source file at ~/opt/anaconda3/etc/profile.d/conda.sh"
    source ~/opt/anaconda3/etc/profile.d/conda.sh
else
    echo 'please manually config the conda source path'
fi
conda activate easyrl # easyrl here can be changed to another name of conda env that you have created
codes_dir=$(dirname $(dirname $(readlink -f "$0"))) # "codes" path
python $codes_dir/QLearning/main.py --env_name CliffWalking-v0 --train_eps 400 --gamma 0.90 --epsilon_start 0.95 --epsilon_end 0.01 --epsilon_decay 300 --lr 0.1 --device cpu
|
||||
@@ -11,5 +11,4 @@ else
|
||||
fi
|
||||
conda activate easyrl # easyrl here can be changed to another name of conda env that you have created
|
||||
codes_dir=$(dirname $(dirname $(readlink -f "$0"))) # "codes" path
|
||||
python $codes_dir/envs/register.py # register environment
|
||||
python $codes_dir/QLearning/main.py --env_name FrozenLakeNoSlippery-v1 --train_eps 800 --epsilon_start 0.70 --epsilon_end 0.1 --epsilon_decay 2000 --gamma 0.9 --lr 0.9 --device cpu
|
||||
14
projects/codes/scripts/Qlearning_Racetrack-v0.sh
Normal file
@@ -0,0 +1,14 @@
|
||||
|
||||
# run QLearning on Racetrack-v0
# source conda, if you are already in proper conda environment, then comment the codes util "conda activate easyrl"
if [ -f "$HOME/anaconda3/etc/profile.d/conda.sh" ]; then
    echo "source file at ~/anaconda3/etc/profile.d/conda.sh"
    source ~/anaconda3/etc/profile.d/conda.sh
elif [ -f "$HOME/opt/anaconda3/etc/profile.d/conda.sh" ]; then
    echo "source file at ~/opt/anaconda3/etc/profile.d/conda.sh"
    source ~/opt/anaconda3/etc/profile.d/conda.sh
else
    echo 'please manually config the conda source path'
fi
conda activate easyrl # easyrl here can be changed to another name of conda env that you have created
codes_dir=$(dirname $(dirname $(readlink -f "$0"))) # "codes" path
python $codes_dir/QLearning/main.py --env_name Racetrack-v0 --device cpu
|
||||
12
projects/codes/scripts/Sarsa_CliffWalking-v0.sh
Normal file
@@ -0,0 +1,12 @@
|
||||
# run Sarsa on CliffWalking-v0
# source conda, if you are already in a proper conda environment, comment out everything up to "conda activate easyrl"
if [ -f "$HOME/anaconda3/etc/profile.d/conda.sh" ]; then
    echo "source file at ~/anaconda3/etc/profile.d/conda.sh"
    source ~/anaconda3/etc/profile.d/conda.sh
elif [ -f "$HOME/opt/anaconda3/etc/profile.d/conda.sh" ]; then
    echo "source file at ~/opt/anaconda3/etc/profile.d/conda.sh"
    source ~/opt/anaconda3/etc/profile.d/conda.sh
else
    echo 'please manually config the conda source path'
fi
conda activate easyrl # easyrl here can be changed to another name of conda env that you have created
codes_dir=$(dirname $(dirname $(readlink -f "$0"))) # "codes" path
python $codes_dir/Sarsa/main.py --env_name CliffWalking-v0 --train_eps 400 --gamma 0.90 --epsilon_start 0.95 --epsilon_end 0.01 --epsilon_decay 300 --lr 0.1 --device cpu
|
||||
13
projects/codes/scripts/Sarsa_FrozenLakeNoSlippery-v1.sh
Normal file
@@ -0,0 +1,13 @@
|
||||
# Sarsa for FrozenLakeNoSlippery-v1, cannot converge like Qlearning!
# source conda, if you are already in a proper conda environment, comment out everything up to "conda activate easyrl"
if [ -f "$HOME/anaconda3/etc/profile.d/conda.sh" ]; then
    echo "source file at ~/anaconda3/etc/profile.d/conda.sh"
    source ~/anaconda3/etc/profile.d/conda.sh
elif [ -f "$HOME/opt/anaconda3/etc/profile.d/conda.sh" ]; then
    echo "source file at ~/opt/anaconda3/etc/profile.d/conda.sh"
    source ~/opt/anaconda3/etc/profile.d/conda.sh
else
    echo 'please manually config the conda source path'
fi
conda activate easyrl # easyrl here can be changed to another name of conda env that you have created
codes_dir=$(dirname $(dirname $(readlink -f "$0"))) # "codes" path
python $codes_dir/Sarsa/main.py --env_name FrozenLakeNoSlippery-v1 --train_eps 800 --ep_max_steps 10 --epsilon_start 0.50 --epsilon_end 0.01 --epsilon_decay 2000 --gamma 0.9 --lr 0.1 --device cpu
|
||||
@@ -9,5 +9,4 @@ else
|
||||
fi
|
||||
conda activate easyrl # easyrl here can be changed to another name of conda env that you have created
|
||||
codes_dir=$(dirname $(dirname $(readlink -f "$0"))) # "codes" path
|
||||
python $codes_dir/envs/register.py # register environment
|
||||
python $codes_dir/Sarsa/main.py
|
||||
python $codes_dir/Sarsa/main.py --env_name Racetrack-v0
|
||||