This commit is contained in:
johnjim0816
2021-05-03 23:00:01 +08:00
parent 895094a893
commit 8028f7145e
67 changed files with 738 additions and 1137 deletions


@@ -1,32 +1,27 @@
 #!/usr/bin/env python
 # coding=utf-8
 '''
-Author: John
+Author: JiangJi
 Email: johnjim0816@gmail.com
-Date: 2020-11-03 20:47:09
-LastEditor: John
-LastEditTime: 2021-03-20 17:41:21
+Date: 2021-05-03 22:16:08
+LastEditor: JiangJi
+LastEditTime: 2021-05-03 22:23:48
 Discription:
 Environment:
 '''
-from A2C.model import ActorCritic
 import torch.optim as optim
+from A2C.model import ActorCritic
 class A2C:
-    def __init__(self,state_dim, action_dim, cfg):
-        self.gamma = 0.99
-        self.model = ActorCritic(state_dim, action_dim, hidden_dim=cfg.hidden_dim).to(cfg.device)
-        self.optimizer = optim.Adam(self.model.parameters(),lr=cfg.lr)
-    def choose_action(self, state):
-        dist, value = self.model(state)
-        action = dist.sample()
-        return action
+    def __init__(self,state_dim,action_dim,cfg) -> None:
+        self.gamma = cfg.gamma
+        self.device = cfg.device
+        self.model = ActorCritic(state_dim, action_dim, cfg.hidden_size).to(self.device)
+        self.optimizer = optim.Adam(self.model.parameters())
     def compute_returns(self,next_value, rewards, masks):
         R = next_value
         returns = []
         for step in reversed(range(len(rewards))):
             R = rewards[step] + self.gamma * R * masks[step]
             returns.insert(0, R)
-        return returns
-    def update(self):
-        pass
+        return returns
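
The refactored agent keeps `compute_returns` (bootstrapped, mask-weighted discounted returns) but this hunk no longer shows `choose_action` or an `update()` body. Below is a minimal sketch, not part of the commit, of how such an agent is typically driven in an n-step A2C training step. The names `train_step`, `env`, `n_steps`, the loss coefficients, and the classic gym-style `env.step`/`env.reset` API are illustrative assumptions, and the `ActorCritic` forward pass is assumed to return `(dist, value)` as in the removed `choose_action`.

```python
import torch

def train_step(agent, env, state, n_steps=5):
    """Hypothetical n-step A2C update using the agent from this commit.
    Collects a short rollout, bootstraps returns with agent.compute_returns,
    and applies a combined actor/critic loss."""
    log_probs, values, rewards, masks = [], [], [], []
    for _ in range(n_steps):
        state_t = torch.FloatTensor(state).unsqueeze(0).to(agent.device)
        dist, value = agent.model(state_t)  # assumed: action distribution and V(s)
        action = dist.sample()
        next_state, reward, done, _ = env.step(action.item())  # classic gym API
        log_probs.append(dist.log_prob(action).unsqueeze(-1))
        values.append(value)
        rewards.append(torch.tensor([[reward]], dtype=torch.float, device=agent.device))
        masks.append(torch.tensor([[1.0 - done]], device=agent.device))
        state = env.reset() if done else next_state
    # bootstrap from the value of the last observed state (masked out if the episode ended)
    next_state_t = torch.FloatTensor(state).unsqueeze(0).to(agent.device)
    _, next_value = agent.model(next_state_t)
    returns = torch.cat(agent.compute_returns(next_value, rewards, masks)).detach()
    log_probs = torch.cat(log_probs)
    values = torch.cat(values)
    advantage = returns - values
    actor_loss = -(log_probs * advantage.detach()).mean()  # policy-gradient term
    critic_loss = advantage.pow(2).mean()                   # value regression term
    loss = actor_loss + 0.5 * critic_loss                   # 0.5 is an illustrative coefficient
    agent.optimizer.zero_grad()
    loss.backward()
    agent.optimizer.step()
    return state
```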