update codes

This commit is contained in:
johnjim0816
2021-12-28 18:46:52 +08:00
parent 41fb561d25
commit bd51b5a7ad
52 changed files with 305 additions and 292 deletions

View File

@@ -38,9 +38,9 @@
"outputs": [],
"source": [
"class QLearning(object):\n",
" def __init__(self,n_states,\n",
" n_actions,cfg):\n",
" self.n_actions = n_actions \n",
" def __init__(self,state_dim,\n",
" action_dim,cfg):\n",
" self.action_dim = action_dim \n",
" self.lr = cfg.lr # 学习率\n",
" self.gamma = cfg.gamma \n",
" self.epsilon = 0 \n",
@@ -48,7 +48,7 @@
" self.epsilon_start = cfg.epsilon_start\n",
" self.epsilon_end = cfg.epsilon_end\n",
" self.epsilon_decay = cfg.epsilon_decay\n",
" self.Q_table = defaultdict(lambda: np.zeros(n_actions)) # 用嵌套字典存放状态->动作->状态-动作值Q值的映射即Q表\n",
" self.Q_table = defaultdict(lambda: np.zeros(action_dim)) # 用嵌套字典存放状态->动作->状态-动作值Q值的映射即Q表\n",
" def choose_action(self, state):\n",
" self.sample_count += 1\n",
" self.epsilon = self.epsilon_end + (self.epsilon_start - self.epsilon_end) * \\\n",
@@ -57,7 +57,7 @@
" if np.random.uniform(0, 1) > self.epsilon:\n",
" action = np.argmax(self.Q_table[str(state)]) # 选择Q(s,a)最大对应的动作\n",
" else:\n",
" action = np.random.choice(self.n_actions) # 随机选择动作\n",
" action = np.random.choice(self.action_dim) # 随机选择动作\n",
" return action\n",
" def predict(self,state):\n",
" action = np.argmax(self.Q_table[str(state)])\n",
@@ -238,9 +238,9 @@
" env = gym.make(cfg.env_name) \n",
" env = CliffWalkingWapper(env)\n",
" env.seed(seed) # 设置随机种子\n",
" n_states = env.observation_space.n # 状态维度\n",
" n_actions = env.action_space.n # 动作维度\n",
" agent = QLearning(n_states,n_actions,cfg)\n",
" state_dim = env.observation_space.n # 状态维度\n",
" action_dim = env.action_space.n # 动作维度\n",
" agent = QLearning(state_dim,action_dim,cfg)\n",
" return env,agent"
]
},