#!/usr/bin/env python
# coding=utf-8
'''
Author: JiangJi
Email: johnjim0816@gmail.com
Date: 2021-05-03 22:16:08
LastEditor: JiangJi
LastEditTime: 2022-07-20 23:54:40
Description:
Environment:
'''
import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical

class ActorCritic(nn.Module):
    ''' A2C network model, containing an actor and a critic '''
    def __init__(self, input_dim, output_dim, hidden_dim):
        super(ActorCritic, self).__init__()
        # critic: maps a state to a scalar value estimate V(s)
        self.critic = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )
        # actor: maps a state to a probability distribution over actions
        self.actor = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
            nn.Softmax(dim=1),  # expects batched input of shape (batch, input_dim)
        )

    def forward(self, x):
        value = self.critic(x)
        probs = self.actor(x)
        dist = Categorical(probs)
        return dist, value

class A2C:
    ''' A2C algorithm '''
    def __init__(self, n_states, n_actions, cfg) -> None:
        self.gamma = cfg.gamma
        self.device = torch.device(cfg.device)
        self.model = ActorCritic(n_states, n_actions, cfg.hidden_size).to(self.device)
        self.optimizer = optim.Adam(self.model.parameters())  # Adam with its default learning rate

    def compute_returns(self, next_value, rewards, masks):
        ''' Compute discounted returns for a rollout by bootstrapping from
        next_value and working backwards; masks[step] is 0 at an episode
        boundary so the return is not propagated across episodes. '''
        R = next_value
        returns = []
        for step in reversed(range(len(rewards))):
            R = rewards[step] + self.gamma * R * masks[step]
            returns.insert(0, R)
        return returns
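
# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal sketch of one A2C update built on compute_returns. The Cfg
# fields (gamma, device, hidden_size), the rollout length, and the
# synthetic transitions are assumptions for illustration only; a real
# agent would collect them from an environment rollout.
if __name__ == "__main__":
    class Cfg:
        gamma = 0.99
        device = 'cpu'
        hidden_size = 256

    n_states, n_actions, n_steps = 4, 2, 5  # hypothetical problem sizes
    agent = A2C(n_states, n_actions, Cfg())

    log_probs, values, rewards, masks = [], [], [], []
    for _ in range(n_steps):
        state = torch.randn(1, n_states).to(agent.device)   # stand-in observation
        dist, value = agent.model(state)
        action = dist.sample()
        log_probs.append(dist.log_prob(action))
        values.append(value)
        rewards.append(torch.ones(1, device=agent.device))  # stand-in reward
        masks.append(torch.ones(1, device=agent.device))    # 1.0 while the episode continues

    # bootstrap from the value of the state reached after the rollout
    next_state = torch.randn(1, n_states).to(agent.device)
    _, next_value = agent.model(next_state)
    returns = torch.cat(agent.compute_returns(next_value, rewards, masks)).detach().squeeze(-1)

    log_probs = torch.cat(log_probs)
    values = torch.cat(values).squeeze(-1)
    advantage = returns - values
    actor_loss = -(log_probs * advantage.detach()).mean()  # policy-gradient term
    critic_loss = advantage.pow(2).mean()                  # value-regression term
    loss = actor_loss + 0.5 * critic_loss                  # 0.5 critic weight is a common choice

    agent.optimizer.zero_grad()
    loss.backward()
    agent.optimizer.step()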