update
This commit is contained in:
@@ -5,7 +5,7 @@ Author: John
|
||||
Email: johnjim0816@gmail.com
|
||||
Date: 2020-09-11 23:03:00
|
||||
LastEditor: John
|
||||
LastEditTime: 2021-04-29 17:01:08
|
||||
LastEditTime: 2021-05-06 17:04:38
|
||||
Description:
|
||||
Environment:
|
||||
'''
|
||||
@@ -15,6 +15,7 @@ parent_path=os.path.dirname(curr_path)
|
||||
sys.path.append(parent_path) # add current terminal path to sys.path
|
||||
|
||||
import gym
|
||||
import torch
|
||||
import datetime
|
||||
|
||||
from envs.gridworld_env import CliffWalkingWapper
|
||||
@@ -37,6 +38,8 @@ class QlearningConfig:
|
||||
self.epsilon_end = 0.01 # e-greedy策略中的终止epsilon
|
||||
self.epsilon_decay = 200 # e-greedy策略中epsilon的衰减率
|
||||
self.lr = 0.1 # learning rate
|
||||
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # check gpu
|
||||
|
||||
|
||||
def env_agent_config(cfg,seed=1):
|
||||
env = gym.make(cfg.env)
|
||||
@@ -48,6 +51,8 @@ def env_agent_config(cfg,seed=1):
|
||||
return env,agent
|
||||
|
||||
def train(cfg,env,agent):
|
||||
print('Start to train !')
|
||||
print(f'Env:{cfg.env}, Algorithm:{cfg.algo}, Device:{cfg.device}')
|
||||
rewards = []
|
||||
ma_rewards = [] # moving average reward
|
||||
for i_ep in range(cfg.train_eps):
|
||||
@@ -67,11 +72,14 @@ def train(cfg,env,agent):
|
||||
else:
|
||||
ma_rewards.append(ep_reward)
|
||||
print("Episode:{}/{}: reward:{:.1f}".format(i_ep+1, cfg.train_eps,ep_reward))
|
||||
print('Complete training!')
|
||||
return rewards,ma_rewards
|
||||
|
||||
def eval(cfg,env,agent):
|
||||
# env = gym.make("FrozenLake-v0", is_slippery=False) # 0 left, 1 down, 2 right, 3 up
|
||||
# env = FrozenLakeWapper(env)
|
||||
print('Start to eval !')
|
||||
print(f'Env:{cfg.env}, Algorithm:{cfg.algo}, Device:{cfg.device}')
|
||||
rewards = [] # 记录所有episode的reward
|
||||
ma_rewards = [] # 滑动平均的reward
|
||||
for i_ep in range(cfg.eval_eps):
|
||||
@@ -90,6 +98,7 @@ def eval(cfg,env,agent):
|
||||
else:
|
||||
ma_rewards.append(ep_reward)
|
||||
print(f"Episode:{i_ep+1}/{cfg.eval_eps}, reward:{ep_reward:.1f}")
|
||||
print('Complete evaling!')
|
||||
return rewards,ma_rewards
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user