update codes
This commit is contained in:
@@ -57,16 +57,16 @@ model就是actor和critic两个网络了:
|
||||
import torch.nn as nn
|
||||
from torch.distributions.categorical import Categorical
|
||||
class Actor(nn.Module):
|
||||
def __init__(self,state_dim, action_dim,
|
||||
def __init__(self,n_states, n_actions,
|
||||
hidden_dim=256):
|
||||
super(Actor, self).__init__()
|
||||
|
||||
self.actor = nn.Sequential(
|
||||
nn.Linear(state_dim, hidden_dim),
|
||||
nn.Linear(n_states, hidden_dim),
|
||||
nn.ReLU(),
|
||||
nn.Linear(hidden_dim, hidden_dim),
|
||||
nn.ReLU(),
|
||||
nn.Linear(hidden_dim, action_dim),
|
||||
nn.Linear(hidden_dim, n_actions),
|
||||
nn.Softmax(dim=-1)
|
||||
)
|
||||
def forward(self, state):
|
||||
@@ -75,10 +75,10 @@ class Actor(nn.Module):
|
||||
return dist
|
||||
|
||||
class Critic(nn.Module):
|
||||
def __init__(self, state_dim,hidden_dim=256):
|
||||
def __init__(self, n_states,hidden_dim=256):
|
||||
super(Critic, self).__init__()
|
||||
self.critic = nn.Sequential(
|
||||
nn.Linear(state_dim, hidden_dim),
|
||||
nn.Linear(n_states, hidden_dim),
|
||||
nn.ReLU(),
|
||||
nn.Linear(hidden_dim, hidden_dim),
|
||||
nn.ReLU(),
|
||||
@@ -88,7 +88,7 @@ class Critic(nn.Module):
|
||||
value = self.critic(state)
|
||||
return value
|
||||
```
|
||||
这里Actor就是得到一个概率分布(Categorica,也可以是别的分布,可以搜索torch distributionsl),critc根据当前状态得到一个值,这里的输入维度可以是```state_dim+action_dim```,即将action信息也纳入critic网络中,这样会更好一些,感兴趣的小伙伴可以试试。
|
||||
这里Actor就是得到一个概率分布(Categorica,也可以是别的分布,可以搜索torch distributionsl),critc根据当前状态得到一个值,这里的输入维度可以是```n_states+n_actions```,即将action信息也纳入critic网络中,这样会更好一些,感兴趣的小伙伴可以试试。
|
||||
|
||||
### PPO update
|
||||
定义一个update函数主要实现伪代码中的第六步和第七步:
|
||||
|
||||
@@ -16,15 +16,15 @@ import torch.optim as optim
|
||||
from PPO.model import Actor,Critic
|
||||
from PPO.memory import PPOMemory
|
||||
class PPO:
|
||||
def __init__(self, state_dim, action_dim,cfg):
|
||||
def __init__(self, n_states, n_actions,cfg):
|
||||
self.gamma = cfg.gamma
|
||||
self.continuous = cfg.continuous
|
||||
self.policy_clip = cfg.policy_clip
|
||||
self.n_epochs = cfg.n_epochs
|
||||
self.gae_lambda = cfg.gae_lambda
|
||||
self.device = cfg.device
|
||||
self.actor = Actor(state_dim, action_dim,cfg.hidden_dim).to(self.device)
|
||||
self.critic = Critic(state_dim,cfg.hidden_dim).to(self.device)
|
||||
self.actor = Actor(n_states, n_actions,cfg.hidden_dim).to(self.device)
|
||||
self.critic = Critic(n_states,cfg.hidden_dim).to(self.device)
|
||||
self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=cfg.actor_lr)
|
||||
self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=cfg.critic_lr)
|
||||
self.memory = PPOMemory(cfg.batch_size)
|
||||
|
||||
@@ -12,16 +12,16 @@ Environment:
|
||||
import torch.nn as nn
|
||||
from torch.distributions.categorical import Categorical
|
||||
class Actor(nn.Module):
|
||||
def __init__(self,state_dim, action_dim,
|
||||
def __init__(self,n_states, n_actions,
|
||||
hidden_dim):
|
||||
super(Actor, self).__init__()
|
||||
|
||||
self.actor = nn.Sequential(
|
||||
nn.Linear(state_dim, hidden_dim),
|
||||
nn.Linear(n_states, hidden_dim),
|
||||
nn.ReLU(),
|
||||
nn.Linear(hidden_dim, hidden_dim),
|
||||
nn.ReLU(),
|
||||
nn.Linear(hidden_dim, action_dim),
|
||||
nn.Linear(hidden_dim, n_actions),
|
||||
nn.Softmax(dim=-1)
|
||||
)
|
||||
def forward(self, state):
|
||||
@@ -30,10 +30,10 @@ class Actor(nn.Module):
|
||||
return dist
|
||||
|
||||
class Critic(nn.Module):
|
||||
def __init__(self, state_dim,hidden_dim):
|
||||
def __init__(self, n_states,hidden_dim):
|
||||
super(Critic, self).__init__()
|
||||
self.critic = nn.Sequential(
|
||||
nn.Linear(state_dim, hidden_dim),
|
||||
nn.Linear(n_states, hidden_dim),
|
||||
nn.ReLU(),
|
||||
nn.Linear(hidden_dim, hidden_dim),
|
||||
nn.ReLU(),
|
||||
|
||||
@@ -45,9 +45,9 @@ class PlotConfig:
|
||||
def env_agent_config(cfg,seed=1):
|
||||
env = gym.make(cfg.env_name)
|
||||
env.seed(seed)
|
||||
state_dim = env.observation_space.shape[0]
|
||||
action_dim = env.action_space.n
|
||||
agent = PPO(state_dim,action_dim,cfg)
|
||||
n_states = env.observation_space.shape[0]
|
||||
n_actions = env.action_space.n
|
||||
agent = PPO(n_states,n_actions,cfg)
|
||||
return env,agent
|
||||
|
||||
cfg = PPOConfig()
|
||||
|
||||
@@ -45,9 +45,9 @@ class PlotConfig:
|
||||
def env_agent_config(cfg,seed=1):
|
||||
env = gym.make(cfg.env_name)
|
||||
env.seed(seed)
|
||||
state_dim = env.observation_space.shape[0]
|
||||
action_dim = env.action_space.shape[0]
|
||||
agent = PPO(state_dim,action_dim,cfg)
|
||||
n_states = env.observation_space.shape[0]
|
||||
n_actions = env.action_space.shape[0]
|
||||
agent = PPO(n_states,n_actions,cfg)
|
||||
return env,agent
|
||||
|
||||
|
||||
|
||||
@@ -90,9 +90,9 @@
|
||||
"def env_agent_config(cfg,seed=1):\n",
|
||||
" env = gym.make(cfg.env) \n",
|
||||
" env.seed(seed)\n",
|
||||
" state_dim = env.observation_space.shape[0]\n",
|
||||
" action_dim = env.action_space.n\n",
|
||||
" agent = PPO(state_dim,action_dim,cfg)\n",
|
||||
" n_states = env.observation_space.shape[0]\n",
|
||||
" n_actions = env.action_space.n\n",
|
||||
" agent = PPO(n_states,n_actions,cfg)\n",
|
||||
" return env,agent"
|
||||
]
|
||||
},
|
||||
|
||||
@@ -99,9 +99,9 @@ if __name__ == '__main__':
|
||||
def env_agent_config(cfg,seed=1):
|
||||
env = gym.make(cfg.env_name)
|
||||
env.seed(seed)
|
||||
state_dim = env.observation_space.shape[0]
|
||||
action_dim = env.action_space.n
|
||||
agent = PPO(state_dim,action_dim,cfg)
|
||||
n_states = env.observation_space.shape[0]
|
||||
n_actions = env.action_space.n
|
||||
agent = PPO(n_states,n_actions,cfg)
|
||||
return env,agent
|
||||
|
||||
cfg = PPOConfig()
|
||||
|
||||
Reference in New Issue
Block a user