update rainbowdqn

johnjim0816
2022-05-31 01:20:58 +08:00
parent cfc0f6492e
commit c7c94468c9
149 changed files with 1866 additions and 1549 deletions


@@ -43,10 +43,10 @@ class ReplayBuffer:
         return len(self.buffer)
 class ValueNet(nn.Module):
-    def __init__(self, state_dim, hidden_dim, init_w=3e-3):
+    def __init__(self, n_states, hidden_dim, init_w=3e-3):
         super(ValueNet, self).__init__()
-        self.linear1 = nn.Linear(state_dim, hidden_dim)
+        self.linear1 = nn.Linear(n_states, hidden_dim)
         self.linear2 = nn.Linear(hidden_dim, hidden_dim)
         self.linear3 = nn.Linear(hidden_dim, 1)
@@ -61,10 +61,10 @@ class ValueNet(nn.Module):
 class SoftQNet(nn.Module):
-    def __init__(self, state_dim, action_dim, hidden_dim, init_w=3e-3):
+    def __init__(self, n_states, n_actions, hidden_dim, init_w=3e-3):
         super(SoftQNet, self).__init__()
-        self.linear1 = nn.Linear(state_dim + action_dim, hidden_dim)
+        self.linear1 = nn.Linear(n_states + n_actions, hidden_dim)
         self.linear2 = nn.Linear(hidden_dim, hidden_dim)
         self.linear3 = nn.Linear(hidden_dim, 1)
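The input width of linear1 makes the rename's intent visible: the soft Q-network scores a (state, action) pair, so its first layer takes n_states + n_actions features. A minimal sketch of the forward pass this implies (the diff shows only the constructor, so the method body and the ReLU activations here are assumptions):

```python
import torch
import torch.nn.functional as F

# Sketch of SoftQNet.forward (not part of the diff; activations assumed).
def forward(self, state, action):
    x = torch.cat([state, action], dim=1)  # (batch, n_states + n_actions)
    x = F.relu(self.linear1(x))
    x = F.relu(self.linear2(x))
    return self.linear3(x)                 # one scalar Q-value per sample
```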
@@ -80,20 +80,20 @@ class SoftQNet(nn.Module):
 class PolicyNet(nn.Module):
-    def __init__(self, state_dim, action_dim, hidden_dim, init_w=3e-3, log_std_min=-20, log_std_max=2):
+    def __init__(self, n_states, n_actions, hidden_dim, init_w=3e-3, log_std_min=-20, log_std_max=2):
         super(PolicyNet, self).__init__()
         self.log_std_min = log_std_min
         self.log_std_max = log_std_max
-        self.linear1 = nn.Linear(state_dim, hidden_dim)
+        self.linear1 = nn.Linear(n_states, hidden_dim)
         self.linear2 = nn.Linear(hidden_dim, hidden_dim)
-        self.mean_linear = nn.Linear(hidden_dim, action_dim)
+        self.mean_linear = nn.Linear(hidden_dim, n_actions)
         self.mean_linear.weight.data.uniform_(-init_w, init_w)
         self.mean_linear.bias.data.uniform_(-init_w, init_w)
-        self.log_std_linear = nn.Linear(hidden_dim, action_dim)
+        self.log_std_linear = nn.Linear(hidden_dim, n_actions)
         self.log_std_linear.weight.data.uniform_(-init_w, init_w)
         self.log_std_linear.bias.data.uniform_(-init_w, init_w)
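PolicyNet keeps two output heads of width n_actions: one for the Gaussian mean and one for its log standard deviation, with log_std_min/log_std_max bounding the latter. A minimal sketch of how SAC typically turns those heads into an action (the sampling code is not in this hunk, so treat it as illustrative):

```python
import torch
from torch.distributions import Normal

# Illustrative SAC-style sampling from the two policy heads (assumed, not shown in the diff).
def sample_action(mean, log_std, log_std_min=-20, log_std_max=2):
    log_std = torch.clamp(log_std, log_std_min, log_std_max)  # numerical stability
    std = log_std.exp()
    z = Normal(mean, std).rsample()  # reparameterized sample keeps gradients flowing
    return torch.tanh(z)             # squash the action into (-1, 1)
```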
@@ -134,14 +134,14 @@ class PolicyNet(nn.Module):
         return action[0]
 class SAC:
-    def __init__(self,state_dim,action_dim,cfg) -> None:
+    def __init__(self,n_states,n_actions,cfg) -> None:
         self.batch_size = cfg.batch_size
         self.memory = ReplayBuffer(cfg.capacity)
         self.device = cfg.device
-        self.value_net = ValueNet(state_dim, cfg.hidden_dim).to(self.device)
-        self.target_value_net = ValueNet(state_dim, cfg.hidden_dim).to(self.device)
-        self.soft_q_net = SoftQNet(state_dim, action_dim, cfg.hidden_dim).to(self.device)
-        self.policy_net = PolicyNet(state_dim, action_dim, cfg.hidden_dim).to(self.device)
+        self.value_net = ValueNet(n_states, cfg.hidden_dim).to(self.device)
+        self.target_value_net = ValueNet(n_states, cfg.hidden_dim).to(self.device)
+        self.soft_q_net = SoftQNet(n_states, n_actions, cfg.hidden_dim).to(self.device)
+        self.policy_net = PolicyNet(n_states, n_actions, cfg.hidden_dim).to(self.device)
         self.value_optimizer = optim.Adam(self.value_net.parameters(), lr=cfg.value_lr)
         self.soft_q_optimizer = optim.Adam(self.soft_q_net.parameters(), lr=cfg.soft_q_lr)
         self.policy_optimizer = optim.Adam(self.policy_net.parameters(), lr=cfg.policy_lr)
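Call sites build the agent from the environment's dimensions, so they change in step with the rename. A minimal sketch, assuming a Gym-style continuous-control environment and that cfg defines the fields read above (batch_size, capacity, device, hidden_dim, and the three learning rates); the environment name is only an example:

```python
import gym

env = gym.make('Pendulum-v1')
n_states = env.observation_space.shape[0]  # size of the observation vector
n_actions = env.action_space.shape[0]      # size of the action vector
agent = SAC(n_states, n_actions, cfg)      # was SAC(state_dim, action_dim, cfg) before this commit
```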