hot update A2C

This commit is contained in:
johnjim0816
2022-08-29 15:12:33 +08:00
parent 99a3c1afec
commit 0b0f7e857d
109 changed files with 8213 additions and 1658 deletions

View File

@@ -24,6 +24,7 @@ def get_args():
parser.add_argument('--env_name',default='CartPole-v0',type=str,help="name of environment")
parser.add_argument('--train_eps',default=200,type=int,help="episodes of training")
parser.add_argument('--test_eps',default=20,type=int,help="episodes of testing")
parser.add_argument('--ep_max_steps',default = 100000,type=int,help="steps per episode, much larger value can simulate infinite steps")
parser.add_argument('--gamma',default=0.95,type=float,help="discounted factor")
parser.add_argument('--epsilon_start',default=0.95,type=float,help="initial value of epsilon")
parser.add_argument('--epsilon_end',default=0.01,type=float,help="final value of epsilon")
@@ -72,7 +73,7 @@ def train(cfg, env, agent):
ep_reward = 0 # reward per episode
ep_step = 0
state = env.reset() # reset and obtain initial state
while True:
for _ in range(cfg['ep_max_steps']):
ep_step += 1
action = agent.sample_action(state) # sample action
next_state, reward, done, _ = env.step(action) # update env and return transitions
@@ -91,7 +92,7 @@ def train(cfg, env, agent):
print(f'Episode: {i_ep+1}/{cfg["train_eps"]}, Reward: {ep_reward:.2f}: Epislon: {agent.epsilon:.3f}')
print("Finish training!")
env.close()
res_dic = {'episodes':range(len(rewards)),'rewards':rewards}
res_dic = {'episodes':range(len(rewards)),'rewards':rewards,'steps':steps}
return res_dic
def test(cfg, env, agent):
@@ -103,7 +104,7 @@ def test(cfg, env, agent):
ep_reward = 0 # reward per episode
ep_step = 0
state = env.reset() # reset and obtain initial state
while True:
for _ in range(cfg['ep_max_steps']):
ep_step+=1
action = agent.predict_action(state) # predict action
next_state, reward, done, _ = env.step(action)
@@ -116,7 +117,7 @@ def test(cfg, env, agent):
print(f"Episode: {i_ep+1}/{cfg['test_eps']}Reward: {ep_reward:.2f}")
print("Finish testing!")
env.close()
return {'episodes':range(len(rewards)),'rewards':rewards}
return {'episodes':range(len(rewards)),'rewards':rewards,'steps':steps}
if __name__ == "__main__":