update
@@ -5,7 +5,7 @@ Author: John
 Email: johnjim0816@gmail.com
 Date: 2021-03-23 15:17:42
 LastEditor: John
-LastEditTime: 2021-04-11 01:24:24
+LastEditTime: 2021-04-28 10:11:09
 Discription:
 Environment:
 '''
@@ -17,7 +17,6 @@ from PPO.model import Actor,Critic
 from PPO.memory import PPOMemory
 class PPO:
     def __init__(self, state_dim, action_dim,cfg):
-        self.env = cfg.env
         self.gamma = cfg.gamma
         self.policy_clip = cfg.policy_clip
         self.n_epochs = cfg.n_epochs
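The constructor now reads only hyperparameters off the config object. A generic, self-contained illustration of that pattern (MiniConfig and MiniAgent are made-up names for this sketch, and 0.2 is just a typical clip value, not taken from this diff); it underlines that once checkpoints stop embedding the environment name, the agent has no remaining use for cfg.env:

class MiniConfig:
    # Hypothetical stand-in for PPOConfig with only the fields this sketch needs.
    gamma = 0.99
    policy_clip = 0.2
    n_epochs = 4

class MiniAgent:
    def __init__(self, cfg):
        # Only hyperparameters are copied; nothing about the environment name is kept.
        self.gamma = cfg.gamma
        self.policy_clip = cfg.policy_clip
        self.n_epochs = cfg.n_epochs

agent = MiniAgent(MiniConfig())
print(agent.gamma, agent.policy_clip, agent.n_epochs)  # 0.99 0.2 4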
@@ -84,13 +83,13 @@ class PPO:
         self.critic_optimizer.step()
         self.memory.clear()
     def save(self,path):
-        actor_checkpoint = os.path.join(path, self.env+'_actor.pt')
-        critic_checkpoint= os.path.join(path, self.env+'_critic.pt')
+        actor_checkpoint = os.path.join(path, 'ppo_actor.pt')
+        critic_checkpoint= os.path.join(path, 'ppo_critic.pt')
         torch.save(self.actor.state_dict(), actor_checkpoint)
         torch.save(self.critic.state_dict(), critic_checkpoint)
     def load(self,path):
-        actor_checkpoint = os.path.join(path, self.env+'_actor.pt')
-        critic_checkpoint= os.path.join(path, self.env+'_critic.pt')
+        actor_checkpoint = os.path.join(path, 'ppo_actor.pt')
+        critic_checkpoint= os.path.join(path, 'ppo_critic.pt')
         self.actor.load_state_dict(torch.load(actor_checkpoint))
         self.critic.load_state_dict(torch.load(critic_checkpoint))
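After this hunk the checkpoint file names no longer depend on the environment, which is why self.env could be dropped from __init__ above. A minimal, self-contained sketch of the resulting save/load pattern, using throwaway nn.Linear modules in place of the repo's Actor/Critic and a hypothetical ./tmp_ckpt/ directory:

import os
import torch
import torch.nn as nn

# Placeholder networks standing in for PPO.model.Actor / Critic.
actor = nn.Linear(4, 2)
critic = nn.Linear(4, 1)

def save(path):
    # Checkpoint names are fixed, independent of the environment name.
    torch.save(actor.state_dict(), os.path.join(path, 'ppo_actor.pt'))
    torch.save(critic.state_dict(), os.path.join(path, 'ppo_critic.pt'))

def load(path):
    actor.load_state_dict(torch.load(os.path.join(path, 'ppo_actor.pt')))
    critic.load_state_dict(torch.load(os.path.join(path, 'ppo_critic.pt')))

if __name__ == '__main__':
    ckpt_dir = './tmp_ckpt/'              # hypothetical directory for the demo
    os.makedirs(ckpt_dir, exist_ok=True)
    save(ckpt_dir)
    load(ckpt_dir)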
@@ -5,7 +5,7 @@ Author: John
 Email: johnjim0816@gmail.com
 Date: 2021-03-22 16:18:10
 LastEditor: John
-LastEditTime: 2021-04-11 01:24:41
+LastEditTime: 2021-04-28 10:13:00
 Discription:
 Environment:
 '''
@@ -19,24 +19,16 @@ import torch
 import datetime
 from PPO.agent import PPO
 from common.plot import plot_rewards
-from common.utils import save_results
+from common.utils import save_results,make_dir

-SEQUENCE = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") # get current time
-SAVED_MODEL_PATH = os.path.split(os.path.abspath(__file__))[0]+"/saved_model/"+SEQUENCE+'/' # path to save models
-if not os.path.exists(os.path.split(os.path.abspath(__file__))[0]+"/saved_model/"): # check whether the folder exists
-    os.mkdir(os.path.split(os.path.abspath(__file__))[0]+"/saved_model/")
-if not os.path.exists(SAVED_MODEL_PATH): # check whether the folder exists
-    os.mkdir(SAVED_MODEL_PATH)
-RESULT_PATH = os.path.split(os.path.abspath(__file__))[0]+"/results/"+SEQUENCE+'/' # path to store rewards
-if not os.path.exists(os.path.split(os.path.abspath(__file__))[0]+"/results/"): # check whether the folder exists
-    os.mkdir(os.path.split(os.path.abspath(__file__))[0]+"/results/")
-if not os.path.exists(RESULT_PATH): # check whether the folder exists
-    os.mkdir(RESULT_PATH)
+curr_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") # obtain current time

 class PPOConfig:
     def __init__(self) -> None:
         self.env = 'CartPole-v0'
         self.algo = 'PPO'
+        self.result_path = curr_path+"/results/" +self.env+'/'+curr_time+'/results/' # path to save results
+        self.model_path = curr_path+"/results/" +self.env+'/'+curr_time+'/models/' # path to save models
         self.batch_size = 5
         self.gamma=0.99
         self.n_epochs = 4
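make_dir and curr_path are referenced here but defined elsewhere (common.utils and the top of the script, neither shown in this hunk). A plausible sketch under the assumption that make_dir simply wraps os.makedirs with exist_ok=True and that curr_path is the script's directory; this is not the repo's actual implementation:

import os
import datetime

curr_path = os.path.dirname(os.path.abspath(__file__))          # assumed: directory of the running script
curr_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")   # obtain current time

def make_dir(*paths):
    """Create each path if it does not exist yet (sketch of common.utils.make_dir)."""
    for path in paths:
        os.makedirs(path, exist_ok=True)

# Usage mirroring the new config layout:
result_path = curr_path + "/results/" + 'CartPole-v0' + '/' + curr_time + '/results/'
model_path = curr_path + "/results/" + 'CartPole-v0' + '/' + curr_time + '/models/'
make_dir(result_path, model_path)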
@@ -50,12 +42,10 @@ class PPOConfig:
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # check gpu

 def train(cfg,env,agent):
-    best_reward = env.reward_range[0]
     rewards= []
     ma_rewards = [] # moving average rewards
-    avg_reward = 0
-    running_steps = 0
-    for i_episode in range(cfg.train_eps):
+    for i_ep in range(cfg.train_eps):
         state = env.reset()
         done = False
         ep_reward = 0
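The trimmed train() keeps two lists: the raw per-episode returns and a smoothed curve. A self-contained sketch of the 0.9/0.1 moving-average update that the next hunk relies on:

def update_ma_rewards(ma_rewards, ep_reward):
    # Exponential moving average used for the smoothed curve (0.9/0.1 weights as in the diff).
    if ma_rewards:
        ma_rewards.append(0.9 * ma_rewards[-1] + 0.1 * ep_reward)
    else:
        ma_rewards.append(ep_reward)
    return ma_rewards

# Example: smooth a short run of episode returns.
rewards, ma_rewards = [], []
for ep_reward in [10.0, 12.0, 9.0, 20.0]:
    rewards.append(ep_reward)
    update_ma_rewards(ma_rewards, ep_reward)
print(ma_rewards)  # approximately [10.0, 10.2, 10.08, 11.072]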
@@ -74,21 +64,18 @@ def train(cfg,env,agent):
                 0.9*ma_rewards[-1]+0.1*ep_reward)
         else:
             ma_rewards.append(ep_reward)
-        avg_reward = np.mean(rewards[-100:])
-        if avg_reward > best_reward:
-            best_reward = avg_reward
-            agent.save(path=SAVED_MODEL_PATH)
-        print('Episode:{}/{}, Reward:{:.1f}, avg reward:{:.1f}, Done:{}'.format(i_episode+1,cfg.train_eps,ep_reward,avg_reward,done))
+        print(f"Episode:{i_ep+1}/{cfg.train_eps}, Reward:{ep_reward:.3f}")
     return rewards,ma_rewards

 if __name__ == '__main__':
     cfg = PPOConfig()
     env = gym.make(cfg.env)
-    env.seed(1)
+    env.seed(1) # Set seeds
     state_dim=env.observation_space.shape[0]
     action_dim=env.action_space.n
     agent = PPO(state_dim,action_dim,cfg)
     rewards,ma_rewards = train(cfg,env,agent)
-    save_results(rewards,ma_rewards,tag='train',path=RESULT_PATH)
-    plot_rewards(rewards,ma_rewards,tag="train",algo = cfg.algo,path=RESULT_PATH)
+    make_dir(cfg.result_path,cfg.model_path)
+    agent.save(path=cfg.model_path)
+    save_results(rewards,ma_rewards,tag='train',path=cfg.result_path)
+    plot_rewards(rewards,ma_rewards,tag="train",env=cfg.env,algo = cfg.algo,path=cfg.result_path)
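save_results and plot_rewards come from common.utils and common.plot, whose bodies are not part of this diff. A hedged stand-in with the same call shape, showing one way such helpers could persist the two reward curves and render the training plot; the _sketch names, the .npy format and the figure layout are assumptions, not the repo's code:

import os
import numpy as np
import matplotlib.pyplot as plt

def save_results_sketch(rewards, ma_rewards, tag='train', path='./results/'):
    # Stand-in for common.utils.save_results: persist both curves as .npy files.
    os.makedirs(path, exist_ok=True)
    np.save(os.path.join(path, f'{tag}_rewards.npy'), np.array(rewards))
    np.save(os.path.join(path, f'{tag}_ma_rewards.npy'), np.array(ma_rewards))

def plot_rewards_sketch(rewards, ma_rewards, tag='train', env='CartPole-v0', algo='PPO', path='./results/'):
    # Stand-in for common.plot.plot_rewards: one figure with raw and smoothed curves.
    plt.figure()
    plt.title(f'{tag} rewards of {algo} on {env}')
    plt.plot(rewards, label='rewards')
    plt.plot(ma_rewards, label='ma_rewards')
    plt.legend()
    plt.savefig(os.path.join(path, f'{tag}_rewards_curve.png'))

rewards = [10.0, 12.0, 9.0, 20.0]          # toy data for the demo
ma_rewards = [10.0, 10.2, 10.08, 11.072]
save_results_sketch(rewards, ma_rewards, path='./results_demo/')
plot_rewards_sketch(rewards, ma_rewards, path='./results_demo/')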
[Binary files changed in this commit are not shown: several binary artifacts were updated, along with two images (one 58 KiB before and 65 KiB after, another 63 KiB after).]