update
@@ -5,7 +5,7 @@ Author: John
 Email: johnjim0816@gmail.com
 Date: 2021-03-23 15:17:42
 LastEditor: John
-LastEditTime: 2021-04-28 10:11:09
+LastEditTime: 2021-09-26 22:02:00
 Discription:
 Environment:
 '''
@@ -41,10 +41,8 @@ class PPO:

     def update(self):
         for _ in range(self.n_epochs):
-            state_arr, action_arr, old_prob_arr, vals_arr,\
-                reward_arr, dones_arr, batches = \
-                        self.memory.sample()
-            values = vals_arr
+            state_arr, action_arr, old_prob_arr, vals_arr,reward_arr, dones_arr, batches = self.memory.sample()
+            values = vals_arr[:]
             ### compute advantage ###
             advantage = np.zeros(len(reward_arr), dtype=np.float32)
             for t in range(len(reward_arr)-1):
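Note: the hunk above only tidies how the rollout is unpacked from memory (one-line sample(), values copied with [:]) and stops right at the start of the advantage loop, which update() then reuses over several epochs of mini-batch updates. As a minimal illustrative sketch of what such a GAE-style advantage loop typically computes — gamma and gae_lambda are assumed hyperparameters and the function is not the repository's exact code:

import numpy as np

def compute_advantage(reward_arr, values, dones_arr, gamma=0.99, gae_lambda=0.95):
    # Illustrative GAE-style advantage estimate over one rollout.
    # gamma and gae_lambda are assumed defaults, not taken from the diff.
    advantage = np.zeros(len(reward_arr), dtype=np.float32)
    for t in range(len(reward_arr) - 1):
        discount = 1.0
        a_t = 0.0
        for k in range(t, len(reward_arr) - 1):
            # TD error at step k, masked when the episode terminates
            a_t += discount * (reward_arr[k] + gamma * values[k + 1] * (1 - int(dones_arr[k])) - values[k])
            discount *= gamma * gae_lambda
        advantage[t] = a_t
    return advantage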
@@ -5,7 +5,7 @@ Author: John
 Email: johnjim0816@gmail.com
 Date: 2021-03-23 15:30:46
 LastEditor: John
-LastEditTime: 2021-03-23 15:30:55
+LastEditTime: 2021-09-26 22:00:07
 Discription:
 Environment:
 '''
@@ -24,14 +24,9 @@ class PPOMemory:
         indices = np.arange(len(self.states), dtype=np.int64)
         np.random.shuffle(indices)
         batches = [indices[i:i+self.batch_size] for i in batch_step]
-        return np.array(self.states),\
-                np.array(self.actions),\
-                np.array(self.probs),\
-                np.array(self.vals),\
-                np.array(self.rewards),\
-                np.array(self.dones),\
-                batches
+        return np.array(self.states),np.array(self.actions),np.array(self.probs),\
+            np.array(self.vals),np.array(self.rewards),np.array(self.dones),batches

     def push(self, state, action, probs, vals, reward, done):
         self.states.append(state)
         self.actions.append(action)
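Note: this hunk only collapses PPOMemory.sample()'s multi-line return into two lines; the sampling logic itself is unchanged. The context line uses batch_step, which is defined just above the hunk and not shown. A rough sketch of how that mini-batching scheme works, with n_states and batch_size as assumed inputs:

import numpy as np

def make_batches(n_states, batch_size):
    # Starting offsets of each mini-batch (what the diff calls batch_step).
    batch_step = np.arange(0, n_states, batch_size)
    # Shuffle all transition indices so each mini-batch mixes time steps.
    indices = np.arange(n_states, dtype=np.int64)
    np.random.shuffle(indices)
    # Slice the shuffled indices into consecutive mini-batches;
    # the last batch may be smaller than batch_size.
    return [indices[i:i + batch_size] for i in batch_step]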
@@ -5,7 +5,7 @@ Author: John
 Email: johnjim0816@gmail.com
 Date: 2021-03-22 16:18:10
 LastEditor: John
-LastEditTime: 2021-05-06 00:43:36
+LastEditTime: 2021-09-26 22:05:00
 Discription:
 Environment:
 '''
@@ -17,6 +17,7 @@ sys.path.append(parent_path) # add current terminal path to sys.path
 import gym
 import torch
 import datetime
 import tqdm
 from PPO.agent import PPO
 from common.plot import plot_rewards
 from common.utils import save_results,make_dir
@@ -51,7 +52,7 @@ def env_agent_config(cfg,seed=1):
     return env,agent

 def train(cfg,env,agent):
-    print('Start to train !')
+    print('开始训练!')
     print(f'Env:{cfg.env}, Algorithm:{cfg.algo}, Device:{cfg.device}')
     rewards= []
     ma_rewards = [] # moving average rewards
@@ -75,7 +76,7 @@ def train(cfg,env,agent):
                 0.9*ma_rewards[-1]+0.1*ep_reward)
         else:
             ma_rewards.append(ep_reward)
-        print(f"Episode:{i_ep+1}/{cfg.train_eps}, Reward:{ep_reward:.3f}")
+        print(f"回合:{i_ep+1}/{cfg.train_eps},奖励:{ep_reward:.2f}")
     print('Complete training!')
     return rewards,ma_rewards
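Note: the two train() hunks swap the English progress prints for Chinese ones ('开始训练!' means 'Start training!'; the per-episode line prints 回合/Episode and 奖励/Reward with the format tightened to two decimals). The surrounding context also shows the exponential moving average of episode rewards; a minimal sketch of that update, with the 0.9/0.1 smoothing taken from the diff:

def update_ma_rewards(ma_rewards, ep_reward):
    # Exponential moving average of episode rewards, as in the context
    # lines above: keep 90% of the previous average and blend in 10%
    # of the newest episode reward; seed with the first episode reward.
    if ma_rewards:
        ma_rewards.append(0.9 * ma_rewards[-1] + 0.1 * ep_reward)
    else:
        ma_rewards.append(ep_reward)
    return ma_rewards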