更新算法模版

This commit is contained in:
johnjim0816
2022-11-06 12:15:36 +08:00
parent 466a17707f
commit dc78698262
256 changed files with 17282 additions and 10229 deletions

View File

@@ -1,14 +1,24 @@
#!/usr/bin/env python
# coding=utf-8
'''
Author: JiangJi
Email: johnjim0816@gmail.com
Date: 2022-09-19 14:48:16
LastEditor: JiangJi
LastEditTime: 2022-10-30 01:21:50
Discription: #TODO待更新模版
'''
import torch
import numpy as np
class A2C_2:
def __init__(self,models,memories,cfg):
self.n_actions = cfg['n_actions']
self.gamma = cfg['gamma']
self.device = torch.device(cfg['device'])
self.n_actions = cfg.n_actions
self.gamma = cfg.gamma
self.device = torch.device(cfg.device)
self.memory = memories['ACMemory']
self.ac_net = models['ActorCritic'].to(self.device)
self.ac_optimizer = torch.optim.Adam(self.ac_net.parameters(), lr=cfg['lr'])
self.ac_optimizer = torch.optim.Adam(self.ac_net.parameters(), lr = cfg.lr)
def sample_action(self,state):
state = torch.tensor(state, device=self.device, dtype=torch.float32).unsqueeze(dim=0)
value, dist = self.ac_net(state) # note that 'dist' need require_grad=True