This commit is contained in:
JohnJim0816
2021-03-28 11:18:52 +08:00
parent 2df8d965d2
commit 6e4d966e1f
56 changed files with 497 additions and 165 deletions

View File

@@ -5,7 +5,7 @@
@Email: johnjim0816@gmail.com
@Date: 2020-06-12 00:48:57
@LastEditor: John
LastEditTime: 2021-03-17 20:11:19
LastEditTime: 2021-03-28 11:05:14
@Discription:
@Environment: python 3.7.7
'''
@@ -32,7 +32,7 @@ if not os.path.exists(RESULT_PATH):
class DoubleDQNConfig:
def __init__(self):
self.algo = "Double DQN" # 算法名称
self.algo = "Double DQN" # name of algo
self.gamma = 0.99
self.epsilon_start = 0.9 # e-greedy策略的初始epsilon
self.epsilon_end = 0.01
@@ -40,7 +40,7 @@ class DoubleDQNConfig:
self.lr = 0.01 # 学习率
self.memory_capacity = 10000 # Replay Memory容量
self.batch_size = 128
self.train_eps = 250 # 训练的episode数目
self.train_eps = 300 # 训练的episode数目
self.train_steps = 200 # 训练每个episode的最大长度
self.target_update = 2 # target net的更新频率
self.eval_eps = 20 # 测试的episode数目
@@ -84,9 +84,9 @@ if __name__ == "__main__":
cfg = DoubleDQNConfig()
env = gym.make('CartPole-v0').unwrapped # 可google为什么unwrapped gym此处一般不需要
env.seed(1) # 设置env随机种子
n_states = env.observation_space.shape[0]
n_actions = env.action_space.n
agent = DoubleDQN(n_states,n_actions,cfg)
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n
agent = DoubleDQN(state_dim,action_dim,cfg)
rewards,ma_rewards = train(cfg,env,agent)
agent.save(path=SAVED_MODEL_PATH)
save_results(rewards,ma_rewards,tag='train',path=RESULT_PATH)