update codes

This commit is contained in:
johnjim0816
2021-12-22 16:55:09 +08:00
parent 75df999258
commit 41fb561d25
75 changed files with 1248 additions and 918 deletions

View File

@@ -57,16 +57,16 @@ The model is simply the two networks: the actor and the critic
import torch.nn as nn
from torch.distributions.categorical import Categorical
class Actor(nn.Module):
-def __init__(self,state_dim, action_dim,
+def __init__(self,n_states, n_actions,
hidden_dim=256):
super(Actor, self).__init__()
self.actor = nn.Sequential(
-nn.Linear(state_dim, hidden_dim),
+nn.Linear(n_states, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
-nn.Linear(hidden_dim, action_dim),
+nn.Linear(hidden_dim, n_actions),
nn.Softmax(dim=-1)
)
def forward(self, state):
@@ -75,10 +75,10 @@ class Actor(nn.Module):
return dist
class Critic(nn.Module):
-def __init__(self, state_dim,hidden_dim=256):
+def __init__(self, n_states,hidden_dim=256):
super(Critic, self).__init__()
self.critic = nn.Sequential(
-nn.Linear(state_dim, hidden_dim),
+nn.Linear(n_states, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
@@ -88,7 +88,7 @@ class Critic(nn.Module):
value = self.critic(state)
return value
```
-Here the Actor simply outputs a probability distribution (a Categorical here, though other distributions work too; see torch.distributions), while the Critic outputs a value for the current state. The Critic's input dimension could also be ```state_dim+action_dim```, i.e. the action information is fed into the critic network as well, which tends to work a bit better; interested readers can give it a try.
+Here the Actor simply outputs a probability distribution (a Categorical here, though other distributions work too; see torch.distributions), while the Critic outputs a value for the current state. The Critic's input dimension could also be ```n_states+n_actions```, i.e. the action information is fed into the critic network as well, which tends to work a bit better; interested readers can give it a try.
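As a hedged illustration of that suggestion, a critic that scores (state, action) pairs could look roughly like the sketch below. The class name `StateActionCritic` and the one-hot encoding of the discrete action are assumptions made for the example, not code from this repository:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class StateActionCritic(nn.Module):
    """Illustrative critic that takes the action alongside the state."""
    def __init__(self, n_states, n_actions, hidden_dim=256):
        super(StateActionCritic, self).__init__()
        self.n_actions = n_actions
        self.critic = nn.Sequential(
            nn.Linear(n_states + n_actions, hidden_dim),  # state concatenated with action
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, state, action):
        # one-hot encode the discrete action so it can be concatenated with the state vector
        action_onehot = F.one_hot(action.long(), num_classes=self.n_actions).float()
        return self.critic(torch.cat([state, action_onehot], dim=-1))
```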
### PPO update
Define an update function that mainly implements steps 6 and 7 of the pseudocode.
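As a rough guide to what that update does, here is a compact sketch of the clipped-surrogate PPO update, assuming the states, actions, old log-probabilities, advantages and returns have already been collected. The function and variable names are illustrative, not the repository's exact API:

```python
import torch

def ppo_update(actor, critic, actor_optimizer, critic_optimizer,
               states, actions, old_log_probs, advantages, returns,
               policy_clip=0.2, n_epochs=4):
    """Sketch of the clipped-surrogate PPO update (steps 6-7 of the pseudocode)."""
    for _ in range(n_epochs):
        dist = actor(states)                            # Categorical distribution over actions
        new_log_probs = dist.log_prob(actions)
        ratio = (new_log_probs - old_log_probs).exp()   # pi_theta / pi_theta_old
        surr1 = ratio * advantages
        surr2 = torch.clamp(ratio, 1 - policy_clip, 1 + policy_clip) * advantages
        actor_loss = -torch.min(surr1, surr2).mean()    # clipped surrogate objective
        critic_loss = (returns - critic(states).squeeze(-1)).pow(2).mean()

        actor_optimizer.zero_grad()
        actor_loss.backward()
        actor_optimizer.step()

        critic_optimizer.zero_grad()
        critic_loss.backward()
        critic_optimizer.step()
```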

View File

@@ -16,15 +16,15 @@ import torch.optim as optim
from PPO.model import Actor,Critic
from PPO.memory import PPOMemory
class PPO:
-def __init__(self, state_dim, action_dim,cfg):
+def __init__(self, n_states, n_actions,cfg):
self.gamma = cfg.gamma
self.continuous = cfg.continuous
self.policy_clip = cfg.policy_clip
self.n_epochs = cfg.n_epochs
self.gae_lambda = cfg.gae_lambda
self.device = cfg.device
-self.actor = Actor(state_dim, action_dim,cfg.hidden_dim).to(self.device)
-self.critic = Critic(state_dim,cfg.hidden_dim).to(self.device)
+self.actor = Actor(n_states, n_actions,cfg.hidden_dim).to(self.device)
+self.critic = Critic(n_states,cfg.hidden_dim).to(self.device)
self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=cfg.actor_lr)
self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=cfg.critic_lr)
self.memory = PPOMemory(cfg.batch_size)
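For reference, the attributes read from `cfg` in this constructor imply a config object roughly like the sketch below. The default values are illustrative assumptions, not the repository's actual `PPOConfig`:

```python
import torch

class PPOConfigSketch:
    """Illustrative config with just the attributes the constructor above reads."""
    def __init__(self):
        self.gamma = 0.99          # discount factor
        self.continuous = False    # discrete action space by default
        self.policy_clip = 0.2     # clipping range of the surrogate objective
        self.n_epochs = 4          # update epochs per batch
        self.gae_lambda = 0.95     # GAE lambda
        self.hidden_dim = 256      # hidden layer size for actor and critic
        self.actor_lr = 3e-4
        self.critic_lr = 1e-3
        self.batch_size = 5
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# usage sketch: agent = PPO(n_states, n_actions, PPOConfigSketch())
```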

View File

@@ -12,16 +12,16 @@ Environment:
import torch.nn as nn
from torch.distributions.categorical import Categorical
class Actor(nn.Module):
-def __init__(self,state_dim, action_dim,
+def __init__(self,n_states, n_actions,
hidden_dim):
super(Actor, self).__init__()
self.actor = nn.Sequential(
-nn.Linear(state_dim, hidden_dim),
+nn.Linear(n_states, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
-nn.Linear(hidden_dim, action_dim),
+nn.Linear(hidden_dim, n_actions),
nn.Softmax(dim=-1)
)
def forward(self, state):
@@ -30,10 +30,10 @@ class Actor(nn.Module):
return dist
class Critic(nn.Module):
-def __init__(self, state_dim,hidden_dim):
+def __init__(self, n_states,hidden_dim):
super(Critic, self).__init__()
self.critic = nn.Sequential(
-nn.Linear(state_dim, hidden_dim),
+nn.Linear(n_states, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),

View File

@@ -45,9 +45,9 @@ class PlotConfig:
def env_agent_config(cfg,seed=1):
env = gym.make(cfg.env_name)
env.seed(seed)
-state_dim = env.observation_space.shape[0]
-action_dim = env.action_space.n
-agent = PPO(state_dim,action_dim,cfg)
+n_states = env.observation_space.shape[0]
+n_actions = env.action_space.n
+agent = PPO(n_states,n_actions,cfg)
return env,agent
cfg = PPOConfig()

View File

@@ -45,9 +45,9 @@ class PlotConfig:
def env_agent_config(cfg,seed=1):
env = gym.make(cfg.env_name)
env.seed(seed)
-state_dim = env.observation_space.shape[0]
-action_dim = env.action_space.shape[0]
-agent = PPO(state_dim,action_dim,cfg)
+n_states = env.observation_space.shape[0]
+n_actions = env.action_space.shape[0]
+agent = PPO(n_states,n_actions,cfg)
return env,agent
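The only difference from the discrete variant earlier is how the action dimension is read: `env.action_space.n` for a discrete space versus `env.action_space.shape[0]` for a continuous (Box) space. A small helper that handles both could look like this; the function name is an illustrative assumption:

```python
import gym

def get_env_dims(env):
    """Return (n_states, n_actions) for either a discrete or a continuous action space."""
    n_states = env.observation_space.shape[0]
    if isinstance(env.action_space, gym.spaces.Discrete):
        n_actions = env.action_space.n          # number of discrete actions
    else:
        n_actions = env.action_space.shape[0]   # dimension of a continuous action vector
    return n_states, n_actions
```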

View File

@@ -90,9 +90,9 @@
"def env_agent_config(cfg,seed=1):\n",
" env = gym.make(cfg.env) \n",
" env.seed(seed)\n",
" state_dim = env.observation_space.shape[0]\n",
" action_dim = env.action_space.n\n",
" agent = PPO(state_dim,action_dim,cfg)\n",
" n_states = env.observation_space.shape[0]\n",
" n_actions = env.action_space.n\n",
" agent = PPO(n_states,n_actions,cfg)\n",
" return env,agent"
]
},

View File

@@ -99,9 +99,9 @@ if __name__ == '__main__':
def env_agent_config(cfg,seed=1):
env = gym.make(cfg.env_name)
env.seed(seed)
-state_dim = env.observation_space.shape[0]
-action_dim = env.action_space.n
-agent = PPO(state_dim,action_dim,cfg)
+n_states = env.observation_space.shape[0]
+n_actions = env.action_space.n
+agent = PPO(n_states,n_actions,cfg)
return env,agent
cfg = PPOConfig()