update codes
@@ -15,9 +15,9 @@ import torch
 from collections import defaultdict
 
 class QLearning(object):
-    def __init__(self,n_states,
-                 n_actions,cfg):
-        self.n_actions = n_actions
+    def __init__(self,state_dim,
+                 action_dim,cfg):
+        self.action_dim = action_dim
         self.lr = cfg.lr # learning rate
         self.gamma = cfg.gamma
         self.epsilon = 0
@@ -25,7 +25,7 @@ class QLearning(object):
         self.epsilon_start = cfg.epsilon_start
         self.epsilon_end = cfg.epsilon_end
         self.epsilon_decay = cfg.epsilon_decay
-        self.Q_table = defaultdict(lambda: np.zeros(n_actions)) # a nested dict storing the state -> action -> state-action value (Q-value) mapping, i.e. the Q-table
+        self.Q_table = defaultdict(lambda: np.zeros(action_dim)) # a nested dict storing the state -> action -> state-action value (Q-value) mapping, i.e. the Q-table
     def choose_action(self, state):
         self.sample_count += 1
         self.epsilon = self.epsilon_end + (self.epsilon_start - self.epsilon_end) * \
@@ -34,7 +34,7 @@ class QLearning(object):
         if np.random.uniform(0, 1) > self.epsilon:
             action = np.argmax(self.Q_table[str(state)]) # choose the action with the largest Q(s,a)
         else:
-            action = np.random.choice(self.n_actions) # choose an action at random
+            action = np.random.choice(self.action_dim) # choose an action at random
         return action
     def predict(self,state):
         action = np.argmax(self.Q_table[str(state)])
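
For reference, a minimal self-contained sketch of the agent after the rename. Only the renamed constructor lines and the epsilon-greedy branch mirror the diff above; the exponential epsilon-decay continuation (cut off by the hunk boundary), the sample_count initialization, and the update method are not part of this diff and are assumptions here.

import math
import numpy as np
from collections import defaultdict

class QLearning(object):
    def __init__(self, state_dim, action_dim, cfg):
        self.action_dim = action_dim
        self.lr = cfg.lr                          # learning rate
        self.gamma = cfg.gamma                    # discount factor
        self.sample_count = 0                     # not shown in the diff; assumed
        self.epsilon = 0
        self.epsilon_start = cfg.epsilon_start
        self.epsilon_end = cfg.epsilon_end
        self.epsilon_decay = cfg.epsilon_decay
        # Q-table: stringified state -> vector of Q-values, one entry per action
        self.Q_table = defaultdict(lambda: np.zeros(action_dim))

    def choose_action(self, state):
        self.sample_count += 1
        # exponential epsilon decay; the continuation line is outside the hunk above,
        # so this exact form is an assumption
        self.epsilon = self.epsilon_end + (self.epsilon_start - self.epsilon_end) * \
            math.exp(-1. * self.sample_count / self.epsilon_decay)
        if np.random.uniform(0, 1) > self.epsilon:
            return np.argmax(self.Q_table[str(state)])  # exploit: greedy w.r.t. Q
        return np.random.choice(self.action_dim)        # explore: uniform random action

    def update(self, state, action, reward, next_state, done):
        # tabular Q-learning target: r + gamma * max_a' Q(s', a'); no bootstrap at terminal states
        q_predict = self.Q_table[str(state)][action]
        q_target = reward + (0 if done else self.gamma * np.max(self.Q_table[str(next_state)]))
        self.Q_table[str(state)][action] += self.lr * (q_target - q_predict)
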
@@ -38,9 +38,9 @@
 "outputs": [],
 "source": [
 "class QLearning(object):\n",
-"    def __init__(self,n_states,\n",
-"                 n_actions,cfg):\n",
-"        self.n_actions = n_actions\n",
+"    def __init__(self,state_dim,\n",
+"                 action_dim,cfg):\n",
+"        self.action_dim = action_dim\n",
 "        self.lr = cfg.lr # learning rate\n",
 "        self.gamma = cfg.gamma\n",
 "        self.epsilon = 0\n",
@@ -48,7 +48,7 @@
 "        self.epsilon_start = cfg.epsilon_start\n",
 "        self.epsilon_end = cfg.epsilon_end\n",
 "        self.epsilon_decay = cfg.epsilon_decay\n",
-"        self.Q_table = defaultdict(lambda: np.zeros(n_actions)) # a nested dict storing the state -> action -> state-action value (Q-value) mapping, i.e. the Q-table\n",
+"        self.Q_table = defaultdict(lambda: np.zeros(action_dim)) # a nested dict storing the state -> action -> state-action value (Q-value) mapping, i.e. the Q-table\n",
 "    def choose_action(self, state):\n",
 "        self.sample_count += 1\n",
 "        self.epsilon = self.epsilon_end + (self.epsilon_start - self.epsilon_end) * \\\n",
@@ -57,7 +57,7 @@
 "        if np.random.uniform(0, 1) > self.epsilon:\n",
 "            action = np.argmax(self.Q_table[str(state)]) # choose the action with the largest Q(s,a)\n",
 "        else:\n",
-"            action = np.random.choice(self.n_actions) # choose an action at random\n",
+"            action = np.random.choice(self.action_dim) # choose an action at random\n",
 "        return action\n",
 "    def predict(self,state):\n",
 "        action = np.argmax(self.Q_table[str(state)])\n",
@@ -238,9 +238,9 @@
 "    env = gym.make(cfg.env_name)\n",
 "    env = CliffWalkingWapper(env)\n",
 "    env.seed(seed) # set the random seed\n",
-"    n_states = env.observation_space.n # state dimension\n",
-"    n_actions = env.action_space.n # action dimension\n",
-"    agent = QLearning(n_states,n_actions,cfg)\n",
+"    state_dim = env.observation_space.n # state dimension\n",
+"    action_dim = env.action_space.n # action dimension\n",
+"    agent = QLearning(state_dim,action_dim,cfg)\n",
 "    return env,agent"
 ]
 },
@@ -68,9 +68,9 @@ def env_agent_config(cfg,seed=1):
     env = gym.make(cfg.env_name)
     env = CliffWalkingWapper(env)
     env.seed(seed) # set the random seed
-    n_states = env.observation_space.n # state dimension
-    n_actions = env.action_space.n # action dimension
-    agent = QLearning(n_states,n_actions,cfg)
+    state_dim = env.observation_space.n # state dimension
+    action_dim = env.action_space.n # action dimension
+    agent = QLearning(state_dim,action_dim,cfg)
     return env,agent
 
 cfg = QlearningConfig()
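
A hedged sketch of how the renamed env_agent_config output is typically consumed. env_agent_config, QLearning, QlearningConfig and CliffWalkingWapper come from this diff and are assumed to be in scope; the episode count, the agent.update signature and the loop body itself are not shown in the diff and are assumptions.

# assumes env_agent_config, QLearning, QlearningConfig and CliffWalkingWapper from this diff are in scope
cfg = QlearningConfig()
env, agent = env_agent_config(cfg, seed=1)      # wrapped CliffWalking env + renamed agent

for ep in range(200):                           # fixed episode count chosen for the sketch
    state = env.reset()
    done = False
    ep_reward = 0
    while not done:
        action = agent.choose_action(state)                # epsilon-greedy action
        next_state, reward, done, _ = env.step(action)     # 4-tuple step, matching the old gym API implied by env.seed(seed)
        agent.update(state, action, reward, next_state, done)  # assumed update signature; not shown in this diff
        state = next_state
        ep_reward += reward
    print(f"episode {ep}: reward {ep_reward:.1f}")
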