更新算法模版
This commit is contained in:
@@ -1,14 +1,24 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
'''
|
||||
Author: JiangJi
|
||||
Email: johnjim0816@gmail.com
|
||||
Date: 2022-09-19 14:48:16
|
||||
LastEditor: JiangJi
|
||||
LastEditTime: 2022-10-30 01:21:50
|
||||
Discription: #TODO,待更新模版
|
||||
'''
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
class A2C_2:
|
||||
def __init__(self,models,memories,cfg):
|
||||
self.n_actions = cfg['n_actions']
|
||||
self.gamma = cfg['gamma']
|
||||
self.device = torch.device(cfg['device'])
|
||||
self.n_actions = cfg.n_actions
|
||||
self.gamma = cfg.gamma
|
||||
self.device = torch.device(cfg.device)
|
||||
self.memory = memories['ACMemory']
|
||||
self.ac_net = models['ActorCritic'].to(self.device)
|
||||
self.ac_optimizer = torch.optim.Adam(self.ac_net.parameters(), lr=cfg['lr'])
|
||||
self.ac_optimizer = torch.optim.Adam(self.ac_net.parameters(), lr = cfg.lr)
|
||||
def sample_action(self,state):
|
||||
state = torch.tensor(state, device=self.device, dtype=torch.float32).unsqueeze(dim=0)
|
||||
value, dist = self.ac_net(state) # note that 'dist' need require_grad=True
|
||||
|
||||
Reference in New Issue
Block a user