56 lines
1.6 KiB
Python
56 lines
1.6 KiB
Python
#!/usr/bin/env python
|
||
# coding=utf-8
|
||
'''
|
||
Author: JiangJi
|
||
Email: johnjim0816@gmail.com
|
||
Date: 2021-05-03 22:16:08
|
||
LastEditor: JiangJi
|
||
LastEditTime: 2022-07-20 23:54:40
|
||
Description:
|
||
Environment:
|
||
'''
|
||
import torch
|
||
import torch.optim as optim
|
||
import torch.nn as nn
|
||
import torch.nn.functional as F
|
||
from torch.distributions import Categorical
|
||
|
||
class ActorCritic(nn.Module):
    """A2C network model with separate actor and critic heads.

    The critic and the actor are two independent MLPs (no shared layers)
    that both read the same input features.

    Args:
        input_dim: Size of the last dimension of the state tensor.
        output_dim: Number of discrete actions.
        hidden_dim: Width of the hidden layer in each head.
    """

    def __init__(self, input_dim, output_dim, hidden_dim):
        super().__init__()
        # Value head: input -> hidden -> single scalar per state.
        self.critic = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )
        # Policy head: input -> hidden -> action probabilities.
        # Softmax over the last dim (not dim=1): Categorical treats the last
        # dim as the category axis, and dim=-1 also works for a single
        # unbatched state of shape (input_dim,). For 2-D batched input the
        # result is identical to dim=1.
        self.actor = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
            nn.Softmax(dim=-1),
        )

    def forward(self, x):
        """Run both heads on ``x``.

        Args:
            x: State tensor whose last dimension is ``input_dim``; any
                leading (batch) dimensions are carried through.

        Returns:
            A tuple ``(dist, value)``. ``dist`` is a ``Categorical`` built
            from the actor's probabilities. ``value`` is the critic's output,
            with a trailing dimension of size 1.
        """
        value = self.critic(x)
        probs = self.actor(x)
        dist = Categorical(probs)
        return dist, value
|
||
class A2C:
    """A2C (Advantage Actor-Critic) agent holding the network, optimizer and discount factor.

    Args:
        n_states: Size of the state vector fed to the network.
        n_actions: Number of discrete actions.
        cfg: Config object; this class reads ``cfg.gamma``, ``cfg.device``
            and ``cfg.hidden_size``.
    """

    def __init__(self, n_states, n_actions, cfg) -> None:
        self.gamma = cfg.gamma
        self.device = torch.device(cfg.device)
        self.model = ActorCritic(n_states, n_actions, cfg.hidden_size).to(self.device)
        # NOTE(review): Adam is constructed with its default learning rate;
        # no learning-rate field from cfg is used here -- confirm intended.
        self.optimizer = optim.Adam(self.model.parameters())

    def compute_returns(self, next_value, rewards, masks):
        """Compute discounted returns, bootstrapping from ``next_value``.

        Works backwards from the last step:
        ``R_t = rewards[t] + gamma * R_{t+1} * masks[t]``, with
        ``R_T = next_value``. A mask of 0 drops the bootstrapped tail at
        that step.

        Args:
            next_value: Bootstrap value for the state after the last step.
            rewards: Indexable per-step rewards, oldest first.
            masks: Indexable per-step multipliers aligned with ``rewards`` by
                index (presumably 0 at episode ends, 1 otherwise -- confirm
                with caller).

        Returns:
            A list with one return per reward, in the same order as
            ``rewards``.

        Raises:
            ValueError: If ``masks`` has fewer entries than ``rewards``.
        """
        if len(masks) < len(rewards):
            raise ValueError(
                f"masks has {len(masks)} entries but rewards has {len(rewards)}"
            )
        R = next_value
        returns = []
        for step in reversed(range(len(rewards))):
            R = rewards[step] + self.gamma * R * masks[step]
            returns.append(R)
        # Built back-to-front with O(1) appends. Reversing once is O(n)
        # overall, whereas repeated insert(0, ...) would be O(n^2).
        returns.reverse()
        return returns