update rainbowdqn

johnjim0816
2022-05-31 01:20:58 +08:00
parent cfc0f6492e
commit c7c94468c9
149 changed files with 1866 additions and 1549 deletions


@@ -57,16 +57,16 @@ the model is just the two networks, actor and critic
 import torch.nn as nn
 from torch.distributions.categorical import Categorical
 class Actor(nn.Module):
-    def __init__(self, state_dim, action_dim,
+    def __init__(self, n_states, n_actions,
                  hidden_dim=256):
         super(Actor, self).__init__()
         self.actor = nn.Sequential(
-            nn.Linear(state_dim, hidden_dim),
+            nn.Linear(n_states, hidden_dim),
             nn.ReLU(),
             nn.Linear(hidden_dim, hidden_dim),
             nn.ReLU(),
-            nn.Linear(hidden_dim, action_dim),
+            nn.Linear(hidden_dim, n_actions),
             nn.Softmax(dim=-1)
         )
     def forward(self, state):
@@ -75,10 +75,10 @@ class Actor(nn.Module):
         return dist
 class Critic(nn.Module):
-    def __init__(self, state_dim, hidden_dim=256):
+    def __init__(self, n_states, hidden_dim=256):
         super(Critic, self).__init__()
         self.critic = nn.Sequential(
-            nn.Linear(state_dim, hidden_dim),
+            nn.Linear(n_states, hidden_dim),
             nn.ReLU(),
             nn.Linear(hidden_dim, hidden_dim),
             nn.ReLU(),
@@ -88,7 +88,7 @@ class Critic(nn.Module):
         value = self.critic(state)
         return value
 ```
-Here the Actor just outputs a probability distribution (Categorical here, but other distributions work as well; see torch.distributions), while the Critic maps the current state to a single value. The critic's input dimension could also be ```state_dim+action_dim```, i.e., the action information is fed into the critic network too, which usually works a bit better; interested readers can try it.
+Here the Actor just outputs a probability distribution (Categorical here, but other distributions work as well; see torch.distributions), while the Critic maps the current state to a single value. The critic's input dimension could also be ```n_states+n_actions```, i.e., the action information is fed into the critic network too, which usually works a bit better; interested readers can try it.
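For reference, here is a minimal sketch of the state-action critic variant suggested in the paragraph above. It is not code from this commit: the class name `QCritic`, the two-argument `forward`, and the assumption that discrete actions are one-hot encoded before being passed in are all illustrative choices.

```python
import torch
import torch.nn as nn

class QCritic(nn.Module):
    """Critic variant whose input is the concatenated (state, action) pair,
    so the first layer is sized n_states + n_actions.
    Illustrative sketch, not code from this commit."""
    def __init__(self, n_states, n_actions, hidden_dim=256):
        super(QCritic, self).__init__()
        self.critic = nn.Sequential(
            nn.Linear(n_states + n_actions, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )

    def forward(self, state, action):
        # for discrete actions, `action` is assumed to be one-hot encoded,
        # so it can be concatenated with the state along the feature axis
        x = torch.cat([state, action], dim=-1)
        return self.critic(x)
```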
### PPO update
Define an update function that mainly implements steps 6 and 7 of the pseudocode.
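The update code itself is not part of this diff, so the following is only a minimal sketch of a clipped-surrogate PPO update, written as a method of an assumed agent class; the names `self.actor`, `self.critic`, `self.actor_optimizer`, `self.critic_optimizer`, `self.eps_clip`, and `self.k_epochs` are assumptions, not the repo's actual API.

```python
import torch
import torch.nn.functional as F

# Sketch of a clipped-surrogate PPO update; all self.* names are
# assumptions, not taken from this commit.
def update(self, states, actions, old_log_probs, returns):
    for _ in range(self.k_epochs):
        dist = self.actor(states)                 # Categorical from the Actor above
        values = self.critic(states).squeeze(-1)  # V(s)
        advantages = returns - values.detach()    # simple advantage estimate

        # probability ratio r(theta) = pi_theta(a|s) / pi_theta_old(a|s)
        log_probs = dist.log_prob(actions)
        ratio = torch.exp(log_probs - old_log_probs)

        # clipped surrogate objective (negated because we minimize)
        surr1 = ratio * advantages
        surr2 = torch.clamp(ratio, 1 - self.eps_clip, 1 + self.eps_clip) * advantages
        actor_loss = -torch.min(surr1, surr2).mean()

        # value loss: regress V(s) toward the empirical returns
        critic_loss = F.mse_loss(values, returns)

        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()
```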