update projects

This commit is contained in:
johnjim0816
2022-07-31 23:42:12 +08:00
parent e9b3e92141
commit ffab9e3028
236 changed files with 370 additions and 133 deletions

56
projects/codes/A2C/a2c.py Normal file
View File

@@ -0,0 +1,56 @@
#!/usr/bin/env python
# coding=utf-8
'''
Author: JiangJi
Email: johnjim0816@gmail.com
Date: 2021-05-03 22:16:08
LastEditor: JiangJi
LastEditTime: 2022-07-20 23:54:40
Discription:
Environment:
'''
import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical
class ActorCritic(nn.Module):
''' A2C网络模型包含一个Actor和Critic
'''
def __init__(self, input_dim, output_dim, hidden_dim):
super(ActorCritic, self).__init__()
self.critic = nn.Sequential(
nn.Linear(input_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, 1)
)
self.actor = nn.Sequential(
nn.Linear(input_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, output_dim),
nn.Softmax(dim=1),
)
def forward(self, x):
value = self.critic(x)
probs = self.actor(x)
dist = Categorical(probs)
return dist, value
class A2C:
''' A2C算法
'''
def __init__(self,n_states,n_actions,cfg) -> None:
self.gamma = cfg.gamma
self.device = torch.device(cfg.device)
self.model = ActorCritic(n_states, n_actions, cfg.hidden_size).to(self.device)
self.optimizer = optim.Adam(self.model.parameters())
def compute_returns(self,next_value, rewards, masks):
R = next_value
returns = []
for step in reversed(range(len(rewards))):
R = rewards[step] + self.gamma * R * masks[step]
returns.insert(0, R)
return returns