update rainbowdqn
@@ -40,10 +40,10 @@ class ActorCritic(nn.Module):
 class A2C:
     ''' A2C algorithm
     '''
-    def __init__(self,state_dim,action_dim,cfg) -> None:
+    def __init__(self,n_states,n_actions,cfg) -> None:
         self.gamma = cfg.gamma
         self.device = cfg.device
-        self.model = ActorCritic(state_dim, action_dim, cfg.hidden_size).to(self.device)
+        self.model = ActorCritic(n_states, n_actions, cfg.hidden_size).to(self.device)
         self.optimizer = optim.Adam(self.model.parameters())
 
     def compute_returns(self,next_value, rewards, masks):
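The body of compute_returns is not shown in this diff; for context, a minimal sketch of the standard masked, bootstrapped discounted return used in A2C implementations with this signature (the list-based rollout buffers are an assumption, not part of the commit):

    def compute_returns(self, next_value, rewards, masks):
        # Walk the rollout backwards, bootstrapping from the critic's
        # estimate of the state after the last step; masks[t] is 0 at
        # episode boundaries so returns do not leak across episodes.
        R = next_value
        returns = []
        for step in reversed(range(len(rewards))):
            R = rewards[step] + self.gamma * R * masks[step]
            returns.insert(0, R)
        return returns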
@@ -10,7 +10,7 @@ import torch
 import torch.optim as optim
 import datetime
 from common.multiprocessing_env import SubprocVecEnv
-from A2C.agent import ActorCritic
+from a2c import ActorCritic
 from common.utils import save_results, make_dir
 from common.utils import plot_rewards
 
@@ -74,9 +74,9 @@ def train(cfg,envs):
     print(f'Env: {cfg.env_name}, Algorithm: {cfg.algo}, Device: {cfg.device}')
     env = gym.make(cfg.env_name) # a single env
     env.seed(10)
-    state_dim = envs.observation_space.shape[0]
-    action_dim = envs.action_space.n
-    model = ActorCritic(state_dim, action_dim, cfg.hidden_dim).to(cfg.device)
+    n_states = envs.observation_space.shape[0]
+    n_actions = envs.action_space.n
+    model = ActorCritic(n_states, n_actions, cfg.hidden_dim).to(cfg.device)
     optimizer = optim.Adam(model.parameters())
     frame_idx = 0
     test_rewards = []
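Taken together, the hunks rename state_dim/action_dim to n_states/n_actions in both the agent and the training script, and point the script at the relocated a2c module. A minimal sketch of constructing the renamed agent; the CartPole environment, the SimpleNamespace config values, and importing A2C from the a2c module (the diff only shows ActorCritic moving there) are assumptions:

    import gym
    import torch
    from types import SimpleNamespace
    from a2c import A2C  # assumed location; the diff only relocates ActorCritic here

    # Hypothetical config mirroring the fields the diff reads from cfg.
    cfg = SimpleNamespace(gamma=0.99, device=torch.device('cpu'), hidden_size=256)

    env = gym.make('CartPole-v0')
    n_states = env.observation_space.shape[0]  # size of the observation vector
    n_actions = env.action_space.n             # number of discrete actions
    agent = A2C(n_states, n_actions, cfg)      # new signature introduced by this commit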