update codes

This commit is contained in:
johnjim0816
2021-12-28 18:46:52 +08:00
parent 41fb561d25
commit bd51b5a7ad
52 changed files with 305 additions and 292 deletions

View File

@@ -42,7 +42,7 @@ class ReplayBuffer:
class MLP(nn.Module):
def __init__(self, input_dim,output_dim,hidden_dim=128):
""" 初始化q网络为全连接网络
input_dim: 输入的特征数即环境的状态
input_dim: 输入的特征数即环境的状态维度
output_dim: 输出的动作维度
"""
super(MLP, self).__init__()
@@ -57,16 +57,16 @@ class MLP(nn.Module):
return self.fc3(x)
class HierarchicalDQN:
def __init__(self,n_states,n_actions,cfg):
self.n_states = n_states
self.n_actions = n_actions
def __init__(self,state_dim,action_dim,cfg):
self.state_dim = state_dim
self.action_dim = action_dim
self.gamma = cfg.gamma
self.device = cfg.device
self.batch_size = cfg.batch_size
self.frame_idx = 0 # 用于epsilon的衰减计数
self.epsilon = lambda frame_idx: cfg.epsilon_end + (cfg.epsilon_start - cfg.epsilon_end ) * math.exp(-1. * frame_idx / cfg.epsilon_decay)
self.policy_net = MLP(2*n_states, n_actions,cfg.hidden_dim).to(self.device)
self.meta_policy_net = MLP(n_states, n_states,cfg.hidden_dim).to(self.device)
self.policy_net = MLP(2*state_dim, action_dim,cfg.hidden_dim).to(self.device)
self.meta_policy_net = MLP(state_dim, state_dim,cfg.hidden_dim).to(self.device)
self.optimizer = optim.Adam(self.policy_net.parameters(),lr=cfg.lr)
self.meta_optimizer = optim.Adam(self.meta_policy_net.parameters(),lr=cfg.lr)
self.memory = ReplayBuffer(cfg.memory_capacity)
@@ -76,7 +76,7 @@ class HierarchicalDQN:
self.losses = []
self.meta_losses = []
def to_onehot(self,x):
oh = np.zeros(self.n_states)
oh = np.zeros(self.state_dim)
oh[x - 1] = 1.
return oh
def set_goal(self,state):
@@ -85,7 +85,7 @@ class HierarchicalDQN:
state = torch.tensor(state, device=self.device, dtype=torch.float32).unsqueeze(0)
goal = self.meta_policy_net(state).max(1)[1].item()
else:
goal = random.randrange(self.n_states)
goal = random.randrange(self.state_dim)
return goal
def choose_action(self,state):
self.frame_idx += 1
@@ -95,7 +95,7 @@ class HierarchicalDQN:
q_value = self.policy_net(state)
action = q_value.max(1)[1].item()
else:
action = random.randrange(self.n_actions)
action = random.randrange(self.action_dim)
return action
def update(self):
self.update_policy()