update codes
This commit is contained in:
@@ -39,11 +39,11 @@ class ReplayBuffer:
|
||||
'''
|
||||
return len(self.buffer)
|
||||
class Actor(nn.Module):
|
||||
def __init__(self, n_states, n_actions, hidden_dim, init_w=3e-3):
|
||||
def __init__(self, state_dim, action_dim, hidden_dim, init_w=3e-3):
|
||||
super(Actor, self).__init__()
|
||||
self.linear1 = nn.Linear(n_states, hidden_dim)
|
||||
self.linear1 = nn.Linear(state_dim, hidden_dim)
|
||||
self.linear2 = nn.Linear(hidden_dim, hidden_dim)
|
||||
self.linear3 = nn.Linear(hidden_dim, n_actions)
|
||||
self.linear3 = nn.Linear(hidden_dim, action_dim)
|
||||
|
||||
self.linear3.weight.data.uniform_(-init_w, init_w)
|
||||
self.linear3.bias.data.uniform_(-init_w, init_w)
|
||||
@@ -54,10 +54,10 @@ class Actor(nn.Module):
|
||||
x = torch.tanh(self.linear3(x))
|
||||
return x
|
||||
class Critic(nn.Module):
|
||||
def __init__(self, n_states, n_actions, hidden_dim, init_w=3e-3):
|
||||
def __init__(self, state_dim, action_dim, hidden_dim, init_w=3e-3):
|
||||
super(Critic, self).__init__()
|
||||
|
||||
self.linear1 = nn.Linear(n_states + n_actions, hidden_dim)
|
||||
self.linear1 = nn.Linear(state_dim + action_dim, hidden_dim)
|
||||
self.linear2 = nn.Linear(hidden_dim, hidden_dim)
|
||||
self.linear3 = nn.Linear(hidden_dim, 1)
|
||||
# 随机初始化为较小的值
|
||||
@@ -72,12 +72,12 @@ class Critic(nn.Module):
|
||||
x = self.linear3(x)
|
||||
return x
|
||||
class DDPG:
|
||||
def __init__(self, n_states, n_actions, cfg):
|
||||
def __init__(self, state_dim, action_dim, cfg):
|
||||
self.device = cfg.device
|
||||
self.critic = Critic(n_states, n_actions, cfg.hidden_dim).to(cfg.device)
|
||||
self.actor = Actor(n_states, n_actions, cfg.hidden_dim).to(cfg.device)
|
||||
self.target_critic = Critic(n_states, n_actions, cfg.hidden_dim).to(cfg.device)
|
||||
self.target_actor = Actor(n_states, n_actions, cfg.hidden_dim).to(cfg.device)
|
||||
self.critic = Critic(state_dim, action_dim, cfg.hidden_dim).to(cfg.device)
|
||||
self.actor = Actor(state_dim, action_dim, cfg.hidden_dim).to(cfg.device)
|
||||
self.target_critic = Critic(state_dim, action_dim, cfg.hidden_dim).to(cfg.device)
|
||||
self.target_actor = Actor(state_dim, action_dim, cfg.hidden_dim).to(cfg.device)
|
||||
|
||||
# 复制参数到目标网络
|
||||
for target_param, param in zip(self.target_critic.parameters(), self.critic.parameters()):
|
||||
|
||||
@@ -39,15 +39,15 @@ class OUNoise(object):
|
||||
self.max_sigma = max_sigma
|
||||
self.min_sigma = min_sigma
|
||||
self.decay_period = decay_period
|
||||
self.n_actions = action_space.shape[0]
|
||||
self.action_dim = action_space.shape[0]
|
||||
self.low = action_space.low
|
||||
self.high = action_space.high
|
||||
self.reset()
|
||||
def reset(self):
|
||||
self.obs = np.ones(self.n_actions) * self.mu
|
||||
self.obs = np.ones(self.action_dim) * self.mu
|
||||
def evolve_obs(self):
|
||||
x = self.obs
|
||||
dx = self.theta * (self.mu - x) + self.sigma * np.random.randn(self.n_actions)
|
||||
dx = self.theta * (self.mu - x) + self.sigma * np.random.randn(self.action_dim)
|
||||
self.obs = x + dx
|
||||
return self.obs
|
||||
def get_action(self, action, t=0):
|
||||
|
||||
@@ -58,9 +58,9 @@ class PlotConfig:
|
||||
def env_agent_config(cfg,seed=1):
|
||||
env = NormalizedActions(gym.make(cfg.env_name)) # 装饰action噪声
|
||||
env.seed(seed) # 随机种子
|
||||
n_states = env.observation_space.shape[0]
|
||||
n_actions = env.action_space.shape[0]
|
||||
agent = DDPG(n_states,n_actions,cfg)
|
||||
state_dim = env.observation_space.shape[0]
|
||||
action_dim = env.action_space.shape[0]
|
||||
agent = DDPG(state_dim,action_dim,cfg)
|
||||
return env,agent
|
||||
|
||||
cfg = DDPGConfig()
|
||||
|
||||
Reference in New Issue
Block a user