update codes

This commit is contained in:
johnjim0816
2021-12-28 18:46:52 +08:00
parent 41fb561d25
commit bd51b5a7ad
52 changed files with 305 additions and 292 deletions

View File

@@ -39,11 +39,11 @@ class ReplayBuffer:
'''
return len(self.buffer)
class Actor(nn.Module):
def __init__(self, n_states, n_actions, hidden_dim, init_w=3e-3):
def __init__(self, state_dim, action_dim, hidden_dim, init_w=3e-3):
super(Actor, self).__init__()
self.linear1 = nn.Linear(n_states, hidden_dim)
self.linear1 = nn.Linear(state_dim, hidden_dim)
self.linear2 = nn.Linear(hidden_dim, hidden_dim)
self.linear3 = nn.Linear(hidden_dim, n_actions)
self.linear3 = nn.Linear(hidden_dim, action_dim)
self.linear3.weight.data.uniform_(-init_w, init_w)
self.linear3.bias.data.uniform_(-init_w, init_w)
@@ -54,10 +54,10 @@ class Actor(nn.Module):
x = torch.tanh(self.linear3(x))
return x
class Critic(nn.Module):
def __init__(self, n_states, n_actions, hidden_dim, init_w=3e-3):
def __init__(self, state_dim, action_dim, hidden_dim, init_w=3e-3):
super(Critic, self).__init__()
self.linear1 = nn.Linear(n_states + n_actions, hidden_dim)
self.linear1 = nn.Linear(state_dim + action_dim, hidden_dim)
self.linear2 = nn.Linear(hidden_dim, hidden_dim)
self.linear3 = nn.Linear(hidden_dim, 1)
# 随机初始化为较小的值
@@ -72,12 +72,12 @@ class Critic(nn.Module):
x = self.linear3(x)
return x
class DDPG:
def __init__(self, n_states, n_actions, cfg):
def __init__(self, state_dim, action_dim, cfg):
self.device = cfg.device
self.critic = Critic(n_states, n_actions, cfg.hidden_dim).to(cfg.device)
self.actor = Actor(n_states, n_actions, cfg.hidden_dim).to(cfg.device)
self.target_critic = Critic(n_states, n_actions, cfg.hidden_dim).to(cfg.device)
self.target_actor = Actor(n_states, n_actions, cfg.hidden_dim).to(cfg.device)
self.critic = Critic(state_dim, action_dim, cfg.hidden_dim).to(cfg.device)
self.actor = Actor(state_dim, action_dim, cfg.hidden_dim).to(cfg.device)
self.target_critic = Critic(state_dim, action_dim, cfg.hidden_dim).to(cfg.device)
self.target_actor = Actor(state_dim, action_dim, cfg.hidden_dim).to(cfg.device)
# 复制参数到目标网络
for target_param, param in zip(self.target_critic.parameters(), self.critic.parameters()):

View File

@@ -39,15 +39,15 @@ class OUNoise(object):
self.max_sigma = max_sigma
self.min_sigma = min_sigma
self.decay_period = decay_period
self.n_actions = action_space.shape[0]
self.action_dim = action_space.shape[0]
self.low = action_space.low
self.high = action_space.high
self.reset()
def reset(self):
self.obs = np.ones(self.n_actions) * self.mu
self.obs = np.ones(self.action_dim) * self.mu
def evolve_obs(self):
x = self.obs
dx = self.theta * (self.mu - x) + self.sigma * np.random.randn(self.n_actions)
dx = self.theta * (self.mu - x) + self.sigma * np.random.randn(self.action_dim)
self.obs = x + dx
return self.obs
def get_action(self, action, t=0):

View File

@@ -58,9 +58,9 @@ class PlotConfig:
def env_agent_config(cfg,seed=1):
env = NormalizedActions(gym.make(cfg.env_name)) # 装饰action噪声
env.seed(seed) # 随机种子
n_states = env.observation_space.shape[0]
n_actions = env.action_space.shape[0]
agent = DDPG(n_states,n_actions,cfg)
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]
agent = DDPG(state_dim,action_dim,cfg)
return env,agent
cfg = DDPGConfig()