This commit is contained in:
johnjim0816
2021-04-28 22:11:22 +08:00
parent e4690ac89f
commit ed7b60fd5b
73 changed files with 502 additions and 187 deletions

View File

@@ -5,7 +5,7 @@ Author: John
Email: johnjim0816@gmail.com
Date: 2021-03-23 15:17:42
LastEditor: John
LastEditTime: 2021-04-11 01:24:24
LastEditTime: 2021-04-28 10:11:09
Discription:
Environment:
'''
@@ -17,7 +17,6 @@ from PPO.model import Actor,Critic
from PPO.memory import PPOMemory
class PPO:
def __init__(self, state_dim, action_dim,cfg):
self.env = cfg.env
self.gamma = cfg.gamma
self.policy_clip = cfg.policy_clip
self.n_epochs = cfg.n_epochs
@@ -84,13 +83,13 @@ class PPO:
self.critic_optimizer.step()
self.memory.clear()
def save(self,path):
actor_checkpoint = os.path.join(path, self.env+'_actor.pt')
critic_checkpoint= os.path.join(path, self.env+'_critic.pt')
actor_checkpoint = os.path.join(path, 'ppo_actor.pt')
critic_checkpoint= os.path.join(path, 'ppo_critic.pt')
torch.save(self.actor.state_dict(), actor_checkpoint)
torch.save(self.critic.state_dict(), critic_checkpoint)
def load(self,path):
actor_checkpoint = os.path.join(path, self.env+'_actor.pt')
critic_checkpoint= os.path.join(path, self.env+'_critic.pt')
actor_checkpoint = os.path.join(path, 'ppo_actor.pt')
critic_checkpoint= os.path.join(path, 'ppo_critic.pt')
self.actor.load_state_dict(torch.load(actor_checkpoint))
self.critic.load_state_dict(torch.load(critic_checkpoint))

View File

@@ -5,7 +5,7 @@ Author: John
Email: johnjim0816@gmail.com
Date: 2021-03-22 16:18:10
LastEditor: John
LastEditTime: 2021-04-11 01:24:41
LastEditTime: 2021-04-28 10:13:00
Discription:
Environment:
'''
@@ -19,24 +19,16 @@ import torch
import datetime
from PPO.agent import PPO
from common.plot import plot_rewards
from common.utils import save_results
from common.utils import save_results,make_dir
SEQUENCE = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") # 获取当前时间
SAVED_MODEL_PATH = os.path.split(os.path.abspath(__file__))[0]+"/saved_model/"+SEQUENCE+'/' # 生成保存的模型路径
if not os.path.exists(os.path.split(os.path.abspath(__file__))[0]+"/saved_model/"): # 检测是否存在文件夹
os.mkdir(os.path.split(os.path.abspath(__file__))[0]+"/saved_model/")
if not os.path.exists(SAVED_MODEL_PATH): # 检测是否存在文件夹
os.mkdir(SAVED_MODEL_PATH)
RESULT_PATH = os.path.split(os.path.abspath(__file__))[0]+"/results/"+SEQUENCE+'/' # 存储reward的路径
if not os.path.exists(os.path.split(os.path.abspath(__file__))[0]+"/results/"): # 检测是否存在文件夹
os.mkdir(os.path.split(os.path.abspath(__file__))[0]+"/results/")
if not os.path.exists(RESULT_PATH): # 检测是否存在文件夹
os.mkdir(RESULT_PATH)
curr_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") # obtain current time
class PPOConfig:
def __init__(self) -> None:
self.env = 'CartPole-v0'
self.algo = 'PPO'
self.result_path = curr_path+"/results/" +self.env+'/'+curr_time+'/results/' # path to save results
self.model_path = curr_path+"/results/" +self.env+'/'+curr_time+'/models/' # path to save models
self.batch_size = 5
self.gamma=0.99
self.n_epochs = 4
@@ -50,12 +42,10 @@ class PPOConfig:
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # check gpu
def train(cfg,env,agent):
best_reward = env.reward_range[0]
rewards= []
ma_rewards = [] # moving average rewards
avg_reward = 0
running_steps = 0
for i_episode in range(cfg.train_eps):
for i_ep in range(cfg.train_eps):
state = env.reset()
done = False
ep_reward = 0
@@ -74,21 +64,18 @@ def train(cfg,env,agent):
0.9*ma_rewards[-1]+0.1*ep_reward)
else:
ma_rewards.append(ep_reward)
avg_reward = np.mean(rewards[-100:])
if avg_rewardself.actor_lr = 0.002
self.critic_lr = 0.005 > best_reward:
best_reward = avg_reward
agent.save(path=SAVED_MODEL_PATH)
print('Episode:{}/{}, Reward:{:.1f}, avg reward:{:.1f}, Done:{}'.format(i_episode+1,cfg.train_eps,ep_reward,avg_reward,done))
print(f"Episode:{i_ep+1}/{cfg.train_eps}, Reward:{ep_reward:.3f}")
return rewards,ma_rewards
if __name__ == '__main__':
cfg = PPOConfig()
cfg = PPOConfig()
env = gym.make(cfg.env)
env.seed(1)
env.seed(1) # Set seeds
state_dim=env.observation_space.shape[0]
action_dim=env.action_space.n
agent = PPO(state_dim,action_dim,cfg)
rewards,ma_rewards = train(cfg,env,agent)
save_results(rewards,ma_rewards,tag='train',path=RESULT_PATH)
plot_rewards(rewards,ma_rewards,tag="train",algo = cfg.algo,path=RESULT_PATH)
make_dir(cfg.result_path,cfg.model_path)
agent.save(path=cfg.model_path)
save_results(rewards,ma_rewards,tag='train',path=cfg.result_path)
plot_rewards(rewards,ma_rewards,tag="train",env=cfg.env,algo = cfg.algo,path=cfg.result_path)

Binary file not shown.

Before

Width:  |  Height:  |  Size: 58 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 65 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 63 KiB