update rainbowdqn

This commit is contained in:
johnjim0816
2022-05-31 01:20:58 +08:00
parent cfc0f6492e
commit c7c94468c9
149 changed files with 1866 additions and 1549 deletions

View File

@@ -6,10 +6,9 @@ sys.path.append(parent_path) # 添加路径到系统路径
import gym
import torch
import datetime
from common.plot import plot_rewards
from common.utils import plot_rewards
from common.utils import save_results,make_dir
from PPO.agent import PPO
from PPO.train import train
from ppo2 import PPO
curr_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") # 获取当前时间
@@ -45,9 +44,9 @@ class PlotConfig:
def env_agent_config(cfg,seed=1):
env = gym.make(cfg.env_name)
env.seed(seed)
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]
agent = PPO(state_dim,action_dim,cfg)
n_states = env.observation_space.shape[0]
n_actions = env.action_space.shape[0]
agent = PPO(n_states,n_actions,cfg)
return env,agent