diff --git a/codes/DDPG/main.py b/codes/DDPG/main.py
index 58a574c..736178b 100644
--- a/codes/DDPG/main.py
+++ b/codes/DDPG/main.py
@@ -5,7 +5,7 @@
 @Email: johnjim0816@gmail.com
 @Date: 2020-06-11 20:58:21
 @LastEditor: John
-LastEditTime: 2021-04-08 21:50:13
+LastEditTime: 2021-04-29 01:58:50
 @Discription:
 @Environment: python 3.7.7
 '''
@@ -82,7 +82,7 @@

 if __name__ == "__main__":
     cfg = DDPGConfig()
-    env = env = NormalizedActions(gym.make("Pendulum-v0"))
+    env = NormalizedActions(gym.make("Pendulum-v0"))
     env.seed(1) # set env random seed
     state_dim = env.observation_space.shape[0]
diff --git a/codes/DQN/main.py b/codes/DQN/main.py
index 2478050..cd22aad 100644
--- a/codes/DQN/main.py
+++ b/codes/DQN/main.py
@@ -5,22 +5,24 @@
 @Email: johnjim0816@gmail.com
 @Date: 2020-06-12 00:48:57
 @LastEditor: John
-LastEditTime: 2021-04-18 14:44:45
+LastEditTime: 2021-04-29 02:02:12
 @Discription:
 @Environment: python 3.7.7
 '''

-from common.utils import save_results, make_dir, del_empty_dir
-from common.plot import plot_rewards
-from DQN.agent import DQN
-import datetime
-import torch
-import gym
-import sys
-import os
+import sys,os
 curr_path = os.path.dirname(__file__)
 parent_path = os.path.dirname(curr_path)
 sys.path.append(parent_path) # add current terminal path to sys.path
+import datetime
+import torch
+import gym
+
+from common.utils import save_results, make_dir, del_empty_dir
+from common.plot import plot_rewards
+from DQN.agent import DQN
+
+
 curr_time = datetime.datetime.now().strftime(
     "%Y%m%d-%H%M%S")  # obtain current time
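Moving the `common.*` and `DQN.*` imports below `sys.path.append(parent_path)` is what lets them resolve when `main.py` is run directly. One caveat worth noting (my observation, not part of the change): `os.path.dirname(__file__)` returns an empty string when the script is started as `python main.py` from inside the `DQN` folder, so the appended "parent" path is then wrong. A slightly hardened variant of the same bootstrap would be:

```python
import os
import sys

# Resolve this file's absolute location so the bootstrap works from any working directory.
curr_path = os.path.dirname(os.path.abspath(__file__))  # .../codes/DQN
parent_path = os.path.dirname(curr_path)                # .../codes
if parent_path not in sys.path:
    sys.path.append(parent_path)

# These imports rely on the repo root being on sys.path, so they must come afterwards.
from common.utils import save_results, make_dir, del_empty_dir
from common.plot import plot_rewards
from DQN.agent import DQN
```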
diff --git a/codes/README.md b/codes/README.md
index fd86303..2f51e2a 100644
--- a/codes/README.md
+++ b/codes/README.md
@@ -37,13 +37,11 @@ python 3.7、pytorch 1.6.0-1.7.1、gym 0.17.0-0.18.0
 | [DoubleDQN](./DoubleDQN) | | [CartPole-v0](./envs/gym_info.md) | |
 | [Hierarchical DQN](HierarchicalDQN) | [H-DQN Paper](https://arxiv.org/abs/1604.06057) | [CartPole-v0](./envs/gym_info.md) | |
 | [PolicyGradient](./PolicyGradient) | | [CartPole-v0](./envs/gym_info.md) | |
-| A2C | [A3C Paper](https://arxiv.org/abs/1602.01783) | [CartPole-v0](./envs/gym_info.md) | |
-| A3C | [A3C Paper](https://arxiv.org/abs/1602.01783) | | |
-| SAC | [SAC Paper](https://arxiv.org/abs/1801.01290) | | |
+| [A2C](./A2C) | [A3C Paper](https://arxiv.org/abs/1602.01783) | [CartPole-v0](./envs/gym_info.md) | |
+| [SAC](./SAC) | [SAC Paper](https://arxiv.org/abs/1801.01290) | [Pendulum-v0](./envs/gym_info.md) | |
 | [PPO](./PPO) | [PPO paper](https://arxiv.org/abs/1707.06347) | [CartPole-v0](./envs/gym_info.md) | |
 | [DDPG](./DDPG) | [DDPG Paper](https://arxiv.org/abs/1509.02971) | [Pendulum-v0](./envs/gym_info.md) | |
 | [TD3](./TD3) | [TD3 Paper](https://arxiv.org/abs/1802.09477) | HalfCheetah-v2 | |
-| GAIL | [GAIL Paper](https://arxiv.org/abs/1606.03476) | | |
diff --git a/codes/README_en.md b/codes/README_en.md
index f3a95d6..95a6455 100644
--- a/codes/README_en.md
+++ b/codes/README_en.md
@@ -30,24 +30,21 @@
 similar to file with ```eval```, which means to evaluate the agent.

 ## Schedule

-| Name | Related materials | Used Envs | Notes |
-| :--------------------------------------: | :---------------------------------------------------------: | ------------------------------------- | :---: |
-| [On-Policy First-Visit MC](./MonteCarlo) | | [Racetrack](./envs/racetrack_env.md) | |
-| [Q-Learning](./QLearning) | | [CliffWalking-v0](./envs/gym_info.md) | |
-| [Sarsa](./Sarsa) | | [Racetrack](./envs/racetrack_env.md) | |
-| [DQN](./DQN) | [DQN-paper](https://www.cs.toronto.edu/~vmnih/docs/dqn.pdf) | [CartPole-v0](./envs/gym_info.md) | |
-| [DQN-cnn](./DQN_cnn) | [DQN-paper](https://www.cs.toronto.edu/~vmnih/docs/dqn.pdf) | [CartPole-v0](./envs/gym_info.md) | |
-| [DoubleDQN](./DoubleDQN) | | [CartPole-v0](./envs/gym_info.md) | |
-| [Hierarchical DQN](HierarchicalDQN) | [Hierarchical DQN](https://arxiv.org/abs/1604.06057) | [CartPole-v0](./envs/gym_info.md) | |
-| [PolicyGradient](./PolicyGradient) | | [CartPole-v0](./envs/gym_info.md) | |
-| A2C | [A3C Paper](https://arxiv.org/abs/1602.01783) | [CartPole-v0](./envs/gym_info.md) | |
-| A3C | [A3C Paper](https://arxiv.org/abs/1602.01783) | | |
-| SAC | [SAC Paper](https://arxiv.org/abs/1801.01290) | | |
-| [PPO](./PPO) | [PPO paper](https://arxiv.org/abs/1707.06347) | [CartPole-v0](./envs/gym_info.md) | |
-| [DDPG](./DDPG) | [DDPG Paper](https://arxiv.org/abs/1509.02971) | [Pendulum-v0](./envs/gym_info.md) | |
-| [TD3](./TD3) | [TD3 Paper](https://arxiv.org/abs/1802.09477) | HalfCheetah-v2 | |
-| GAIL | | | |
-
+| Name | Related materials | Used Envs | Notes |
+| :--------------------------------------: | :----------------------------------------------------------: | ------------------------------------- | :---: |
+| [On-Policy First-Visit MC](./MonteCarlo) | | [Racetrack](./envs/racetrack_env.md) | |
+| [Q-Learning](./QLearning) | | [CliffWalking-v0](./envs/gym_info.md) | |
+| [Sarsa](./Sarsa) | | [Racetrack](./envs/racetrack_env.md) | |
+| [DQN](./DQN) | [DQN-paper](https://www.cs.toronto.edu/~vmnih/docs/dqn.pdf), [Nature DQN Paper](https://www.nature.com/articles/nature14236) | [CartPole-v0](./envs/gym_info.md) | |
+| [DQN-cnn](./DQN_cnn) | [DQN-paper](https://www.cs.toronto.edu/~vmnih/docs/dqn.pdf) | [CartPole-v0](./envs/gym_info.md) | |
+| [DoubleDQN](./DoubleDQN) | | [CartPole-v0](./envs/gym_info.md) | |
+| [Hierarchical DQN](HierarchicalDQN) | [Hierarchical DQN](https://arxiv.org/abs/1604.06057) | [CartPole-v0](./envs/gym_info.md) | |
+| [PolicyGradient](./PolicyGradient) | | [CartPole-v0](./envs/gym_info.md) | |
+| [A2C](./A2C) | [A3C Paper](https://arxiv.org/abs/1602.01783) | [CartPole-v0](./envs/gym_info.md) | |
+| [SAC](./SAC) | [SAC Paper](https://arxiv.org/abs/1801.01290) | [Pendulum-v0](./envs/gym_info.md) | |
+| [PPO](./PPO) | [PPO paper](https://arxiv.org/abs/1707.06347) | [CartPole-v0](./envs/gym_info.md) | |
+| [DDPG](./DDPG) | [DDPG Paper](https://arxiv.org/abs/1509.02971) | [Pendulum-v0](./envs/gym_info.md) | |
+| [TD3](./TD3) | [TD3 Paper](https://arxiv.org/abs/1802.09477) | HalfCheetah-v2 | |

 ## Refs
diff --git a/codes/SAC/agent.py b/codes/SAC/agent.py
new file mode 100644
index 0000000..1568eb3
--- /dev/null
+++ b/codes/SAC/agent.py
@@ -0,0 +1,110 @@
+#!/usr/bin/env python
+# coding=utf-8
+'''
+Author: JiangJi
+Email: johnjim0816@gmail.com
+Date: 2021-04-29 12:53:54
+LastEditor: JiangJi
+LastEditTime: 2021-04-29 13:56:39
+Discription:
+Environment:
+'''
+import copy
+import torch
+import torch.nn as nn
+import torch.optim as optim
+import numpy as np
+
+from common.memory import ReplayBuffer
+from SAC.model import ValueNet, PolicyNet, SoftQNet
+
+class SAC:
+    def __init__(self, state_dim, action_dim, cfg) -> None:
+        self.batch_size = cfg.batch_size
+        self.memory = ReplayBuffer(cfg.capacity)
+        self.device = cfg.device
+        self.value_net = ValueNet(state_dim, cfg.hidden_dim).to(self.device)
+        self.target_value_net = ValueNet(state_dim, cfg.hidden_dim).to(self.device)
+        self.soft_q_net = SoftQNet(state_dim, action_dim, cfg.hidden_dim).to(self.device)
+        self.policy_net = PolicyNet(state_dim, action_dim, cfg.hidden_dim).to(self.device)
+        self.value_optimizer = optim.Adam(self.value_net.parameters(), lr=cfg.value_lr)
+        self.soft_q_optimizer = optim.Adam(self.soft_q_net.parameters(), lr=cfg.soft_q_lr)
+        self.policy_optimizer = optim.Adam(self.policy_net.parameters(), lr=cfg.policy_lr)
+        for target_param, param in zip(self.target_value_net.parameters(), self.value_net.parameters()):
+            target_param.data.copy_(param.data)
+        self.value_criterion = nn.MSELoss()
+        self.soft_q_criterion = nn.MSELoss()
+
+    def update(self, gamma=0.99, mean_lambda=1e-3,
+               std_lambda=1e-3,
+               z_lambda=0.0,
+               soft_tau=1e-2,
+               ):
+        if len(self.memory) < self.batch_size:
+            return
+        state, action, reward, next_state, done = self.memory.sample(self.batch_size)
+        state = torch.FloatTensor(state).to(self.device)
+        next_state = torch.FloatTensor(next_state).to(self.device)
+        action = torch.FloatTensor(action).to(self.device)
+        reward = torch.FloatTensor(reward).unsqueeze(1).to(self.device)
+        done = torch.FloatTensor(np.float32(done)).unsqueeze(1).to(self.device)
+        expected_q_value = self.soft_q_net(state, action)
+        expected_value = self.value_net(state)
+        new_action, log_prob, z, mean, log_std = self.policy_net.evaluate(state)
+
+        target_value = self.target_value_net(next_state)
+        next_q_value = reward + (1 - done) * gamma * target_value
+        q_value_loss = self.soft_q_criterion(expected_q_value, next_q_value.detach())
+
+        expected_new_q_value = self.soft_q_net(state, new_action)
+        next_value = expected_new_q_value - log_prob
+        value_loss = self.value_criterion(expected_value, next_value.detach())
+
+        log_prob_target = expected_new_q_value - expected_value
+        policy_loss = (log_prob * (log_prob - log_prob_target).detach()).mean()
+
+        mean_loss = mean_lambda * mean.pow(2).mean()
+        std_loss = std_lambda * log_std.pow(2).mean()
+        z_loss = z_lambda * z.pow(2).sum(1).mean()
+
+        policy_loss += mean_loss + std_loss + z_loss
+
+        self.soft_q_optimizer.zero_grad()
+        q_value_loss.backward()
+        self.soft_q_optimizer.step()
+
+        self.value_optimizer.zero_grad()
+        value_loss.backward()
+        self.value_optimizer.step()
+
+        self.policy_optimizer.zero_grad()
+        policy_loss.backward()
+        self.policy_optimizer.step()
+
+        for target_param, param in zip(self.target_value_net.parameters(), self.value_net.parameters()):
+            target_param.data.copy_(
+                target_param.data * (1.0 - soft_tau) + param.data * soft_tau
+            )
+
+    def save(self, path):
+        torch.save(self.value_net.state_dict(), path + "sac_value")
+        torch.save(self.value_optimizer.state_dict(), path + "sac_value_optimizer")
+
+        torch.save(self.soft_q_net.state_dict(), path + "sac_soft_q")
+        torch.save(self.soft_q_optimizer.state_dict(), path + "sac_soft_q_optimizer")
+
+        torch.save(self.policy_net.state_dict(), path + "sac_policy")
+        torch.save(self.policy_optimizer.state_dict(), path + "sac_policy_optimizer")
+
+    def load(self, path):
+        self.value_net.load_state_dict(torch.load(path + "sac_value"))
+        self.value_optimizer.load_state_dict(torch.load(path + "sac_value_optimizer"))
+        self.target_value_net = copy.deepcopy(self.value_net)
+
+        self.soft_q_net.load_state_dict(torch.load(path + "sac_soft_q"))
+        self.soft_q_optimizer.load_state_dict(torch.load(path + "sac_soft_q_optimizer"))
+
+        self.policy_net.load_state_dict(torch.load(path + "sac_policy"))
+        self.policy_optimizer.load_state_dict(torch.load(path + "sac_policy_optimizer"))
\ No newline at end of file
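`SAC.update` keeps `target_value_net` as an exponential moving average of `value_net` (Polyak averaging), so the detached bootstrap target `next_q_value` changes slowly between updates. The same operation, factored into a standalone helper for reference (the name `soft_update` is mine and is not part of this commit):

```python
import torch.nn as nn

def soft_update(target_net: nn.Module, source_net: nn.Module, soft_tau: float = 1e-2) -> None:
    """In-place Polyak update: target <- (1 - soft_tau) * target + soft_tau * source."""
    for target_param, param in zip(target_net.parameters(), source_net.parameters()):
        target_param.data.copy_(target_param.data * (1.0 - soft_tau) + param.data * soft_tau)
```

With `soft_tau=1e-2` the target network tracks the online value network with a time constant of roughly 100 updates, which is what keeps the value regression target stable.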
diff --git a/codes/SAC/env.py b/codes/SAC/env.py
new file mode 100644
index 0000000..14e37a7
--- /dev/null
+++ b/codes/SAC/env.py
@@ -0,0 +1,29 @@
+#!/usr/bin/env python
+# coding=utf-8
+'''
+Author: JiangJi
+Email: johnjim0816@gmail.com
+Date: 2021-04-29 12:52:11
+LastEditor: JiangJi
+LastEditTime: 2021-04-29 12:52:31
+Discription:
+Environment:
+'''
+import gym
+import numpy as np
+
+class NormalizedActions(gym.ActionWrapper):
+    def action(self, action):
+        low = self.action_space.low
+        high = self.action_space.high
+
+        action = low + (action + 1.0) * 0.5 * (high - low)
+        action = np.clip(action, low, high)
+
+        return action
+
+    def reverse_action(self, action):
+        low = self.action_space.low
+        high = self.action_space.high
+        action = 2 * (action - low) / (high - low) - 1
+        action = np.clip(action, low, high)
+        return action
\ No newline at end of file
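`NormalizedActions` lets the tanh-squashed policy emit actions in [-1, 1] and rescales them to the environment's native box before `step`. A quick sanity check of the mapping, assuming gym 0.17-0.18 (as listed in the README) and the `codes/` directory on `sys.path`; Pendulum-v0's torque range is [-2, 2]:

```python
import gym
from SAC.env import NormalizedActions

env = NormalizedActions(gym.make("Pendulum-v0"))
print(env.action_space.low, env.action_space.high)  # [-2.] [2.]
# action() maps -1 -> low, 0 -> midpoint, +1 -> high, then clips to the box.
print(env.action(-1.0), env.action(0.0), env.action(1.0))  # roughly [-2.] [0.] [2.]
```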
diff --git a/codes/SAC/model.py b/codes/SAC/model.py
new file mode 100644
index 0000000..146db0d
--- /dev/null
+++ b/codes/SAC/model.py
@@ -0,0 +1,108 @@
+#!/usr/bin/env python
+# coding=utf-8
+'''
+Author: JiangJi
+Email: johnjim0816@gmail.com
+Date: 2021-04-29 12:53:58
+LastEditor: JiangJi
+LastEditTime: 2021-04-29 12:57:29
+Discription:
+Environment:
+'''
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.distributions import Normal
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+class ValueNet(nn.Module):
+    def __init__(self, state_dim, hidden_dim, init_w=3e-3):
+        super(ValueNet, self).__init__()
+
+        self.linear1 = nn.Linear(state_dim, hidden_dim)
+        self.linear2 = nn.Linear(hidden_dim, hidden_dim)
+        self.linear3 = nn.Linear(hidden_dim, 1)
+
+        self.linear3.weight.data.uniform_(-init_w, init_w)
+        self.linear3.bias.data.uniform_(-init_w, init_w)
+
+    def forward(self, state):
+        x = F.relu(self.linear1(state))
+        x = F.relu(self.linear2(x))
+        x = self.linear3(x)
+        return x
+
+
+class SoftQNet(nn.Module):
+    def __init__(self, num_inputs, num_actions, hidden_size, init_w=3e-3):
+        super(SoftQNet, self).__init__()
+
+        self.linear1 = nn.Linear(num_inputs + num_actions, hidden_size)
+        self.linear2 = nn.Linear(hidden_size, hidden_size)
+        self.linear3 = nn.Linear(hidden_size, 1)
+
+        self.linear3.weight.data.uniform_(-init_w, init_w)
+        self.linear3.bias.data.uniform_(-init_w, init_w)
+
+    def forward(self, state, action):
+        x = torch.cat([state, action], 1)
+        x = F.relu(self.linear1(x))
+        x = F.relu(self.linear2(x))
+        x = self.linear3(x)
+        return x
+
+
+class PolicyNet(nn.Module):
+    def __init__(self, num_inputs, num_actions, hidden_size, init_w=3e-3, log_std_min=-20, log_std_max=2):
+        super(PolicyNet, self).__init__()
+
+        self.log_std_min = log_std_min
+        self.log_std_max = log_std_max
+
+        self.linear1 = nn.Linear(num_inputs, hidden_size)
+        self.linear2 = nn.Linear(hidden_size, hidden_size)
+
+        self.mean_linear = nn.Linear(hidden_size, num_actions)
+        self.mean_linear.weight.data.uniform_(-init_w, init_w)
+        self.mean_linear.bias.data.uniform_(-init_w, init_w)
+
+        self.log_std_linear = nn.Linear(hidden_size, num_actions)
+        self.log_std_linear.weight.data.uniform_(-init_w, init_w)
+        self.log_std_linear.bias.data.uniform_(-init_w, init_w)
+
+    def forward(self, state):
+        x = F.relu(self.linear1(state))
+        x = F.relu(self.linear2(x))
+
+        mean = self.mean_linear(x)
+        log_std = self.log_std_linear(x)
+        log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max)
+
+        return mean, log_std
+
+    def evaluate(self, state, epsilon=1e-6):
+        mean, log_std = self.forward(state)
+        std = log_std.exp()
+
+        normal = Normal(mean, std)
+        z = normal.sample()
+        action = torch.tanh(z)
+
+        log_prob = normal.log_prob(z) - torch.log(1 - action.pow(2) + epsilon)
+        log_prob = log_prob.sum(-1, keepdim=True)
+
+        return action, log_prob, z, mean, log_std
+
+    def get_action(self, state):
+        state = torch.FloatTensor(state).unsqueeze(0).to(device)
+        mean, log_std = self.forward(state)
+        std = log_std.exp()
+
+        normal = Normal(mean, std)
+        z = normal.sample()
+        action = torch.tanh(z)
+
+        action = action.detach().cpu().numpy()
+        return action[0]
\ No newline at end of file
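`PolicyNet.evaluate` squashes a Gaussian sample with `tanh` and corrects the log-density for that change of variables: log pi(a|s) = log N(z; mu, sigma) - log(1 - tanh(z)^2 + eps). A small self-contained check of the correction term, independent of the repo code (the values printed are illustrative):

```python
import torch
from torch.distributions import Normal

mean, std, epsilon = torch.zeros(1), torch.ones(1), 1e-6
normal = Normal(mean, std)
z = normal.sample()            # pre-squash Gaussian sample
action = torch.tanh(z)         # squashed action in (-1, 1)

# Jacobian correction for a = tanh(z): log|da/dz| = log(1 - tanh(z)^2)
log_prob = normal.log_prob(z) - torch.log(1 - action.pow(2) + epsilon)
print(action.item(), log_prob.item())
```

Note that `evaluate` uses `normal.sample()` rather than `rsample()`, so gradients do not flow through `z`; this is consistent with the likelihood-ratio style policy loss in `agent.py`, `(log_prob * (log_prob - log_prob_target).detach()).mean()`, rather than a reparameterized objective.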
diff --git a/codes/SAC/outputs/Pendulum-v0/20210429-135700/models/sac_policy b/codes/SAC/outputs/Pendulum-v0/20210429-135700/models/sac_policy
new file mode 100644
index 0000000..ce119d4
Binary files /dev/null and b/codes/SAC/outputs/Pendulum-v0/20210429-135700/models/sac_policy differ
diff --git a/codes/SAC/outputs/Pendulum-v0/20210429-135700/models/sac_policy_optimizer b/codes/SAC/outputs/Pendulum-v0/20210429-135700/models/sac_policy_optimizer
new file mode 100644
index 0000000..348eca7
Binary files /dev/null and b/codes/SAC/outputs/Pendulum-v0/20210429-135700/models/sac_policy_optimizer differ
diff --git a/codes/SAC/outputs/Pendulum-v0/20210429-135700/models/sac_soft_q b/codes/SAC/outputs/Pendulum-v0/20210429-135700/models/sac_soft_q
new file mode 100644
index 0000000..3c4f237
Binary files /dev/null and b/codes/SAC/outputs/Pendulum-v0/20210429-135700/models/sac_soft_q differ
diff --git a/codes/SAC/outputs/Pendulum-v0/20210429-135700/models/sac_soft_q_optimizer b/codes/SAC/outputs/Pendulum-v0/20210429-135700/models/sac_soft_q_optimizer
new file mode 100644
index 0000000..52fac5b
Binary files /dev/null and b/codes/SAC/outputs/Pendulum-v0/20210429-135700/models/sac_soft_q_optimizer differ
diff --git a/codes/SAC/outputs/Pendulum-v0/20210429-135700/models/sac_value b/codes/SAC/outputs/Pendulum-v0/20210429-135700/models/sac_value
new file mode 100644
index 0000000..11989ad
Binary files /dev/null and b/codes/SAC/outputs/Pendulum-v0/20210429-135700/models/sac_value differ
diff --git a/codes/SAC/outputs/Pendulum-v0/20210429-135700/models/sac_value_optimizer b/codes/SAC/outputs/Pendulum-v0/20210429-135700/models/sac_value_optimizer
new file mode 100644
index 0000000..1d9500b
Binary files /dev/null and b/codes/SAC/outputs/Pendulum-v0/20210429-135700/models/sac_value_optimizer differ
diff --git a/codes/SAC/outputs/Pendulum-v0/20210429-135700/results/ma_rewards_train.npy b/codes/SAC/outputs/Pendulum-v0/20210429-135700/results/ma_rewards_train.npy
new file mode 100644
index 0000000..b3676ce
Binary files /dev/null and b/codes/SAC/outputs/Pendulum-v0/20210429-135700/results/ma_rewards_train.npy differ
diff --git a/codes/SAC/outputs/Pendulum-v0/20210429-135700/results/rewards_curve_train.png b/codes/SAC/outputs/Pendulum-v0/20210429-135700/results/rewards_curve_train.png
new file mode 100644
index 0000000..b870654
Binary files /dev/null and b/codes/SAC/outputs/Pendulum-v0/20210429-135700/results/rewards_curve_train.png differ
diff --git a/codes/SAC/outputs/Pendulum-v0/20210429-135700/results/rewards_train.npy b/codes/SAC/outputs/Pendulum-v0/20210429-135700/results/rewards_train.npy
new file mode 100644
index 0000000..73336b5
Binary files /dev/null and b/codes/SAC/outputs/Pendulum-v0/20210429-135700/results/rewards_train.npy differ
diff --git a/codes/SAC/task0_train.py b/codes/SAC/task0_train.py
new file mode 100644
index 0000000..6956baa
--- /dev/null
+++ b/codes/SAC/task0_train.py
@@ -0,0 +1,89 @@
+#!/usr/bin/env python
+# coding=utf-8
+'''
+Author: JiangJi
+Email: johnjim0816@gmail.com
+Date: 2021-04-29 12:59:22
+LastEditor: JiangJi
+LastEditTime: 2021-04-29 13:56:56
+Discription:
+Environment:
+'''
+import sys,os
+curr_path = os.path.dirname(__file__)
+parent_path = os.path.dirname(curr_path)
+sys.path.append(parent_path)  # add current terminal path to sys.path
+
+import gym
+import torch
+import datetime
+
+from SAC.env import NormalizedActions
+from SAC.agent import SAC
+from common.utils import save_results, make_dir
+from common.plot import plot_rewards
+
+curr_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")  # obtain current time
+
+class SACConfig:
+    def __init__(self) -> None:
+        self.algo = 'SAC'
+        self.env = 'Pendulum-v0'
+        self.result_path = curr_path + "/outputs/" + self.env + '/' + curr_time + '/results/'  # path to save results
+        self.model_path = curr_path + "/outputs/" + self.env + '/' + curr_time + '/models/'  # path to save models
+        self.train_eps = 300
+        self.train_steps = 500
+        self.gamma = 0.99
+        self.mean_lambda = 1e-3
+        self.std_lambda = 1e-3
+        self.z_lambda = 0.0
+        self.soft_tau = 1e-2
+        self.value_lr = 3e-4
+        self.soft_q_lr = 3e-4
+        self.policy_lr = 3e-4
+        self.capacity = 1000000
+        self.hidden_dim = 256
+        self.batch_size = 128
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+def train(cfg, env, agent):
+    rewards = []
+    ma_rewards = []  # moving average reward
+    for i_ep in range(cfg.train_eps):
+        state = env.reset()
+        ep_reward = 0
+        for i_step in range(cfg.train_steps):
+            action = agent.policy_net.get_action(state)
+            next_state, reward, done, _ = env.step(action)
+            agent.memory.push(state, action, reward, next_state, done)
+            agent.update()
+            state = next_state
+            ep_reward += reward
+            if done:
+                break
+        print(f"Episode:{i_ep+1}/{cfg.train_eps}, Reward:{ep_reward:.3f}")
+        rewards.append(ep_reward)
+        if ma_rewards:
+            ma_rewards.append(0.9*ma_rewards[-1]+0.1*ep_reward)
+        else:
+            ma_rewards.append(ep_reward)
+    return rewards, ma_rewards
+
+if __name__ == "__main__":
+    cfg = SACConfig()
+    env = NormalizedActions(gym.make("Pendulum-v0"))
+    action_dim = env.action_space.shape[0]
+    state_dim = env.observation_space.shape[0]
+    agent = SAC(state_dim, action_dim, cfg)
+    rewards, ma_rewards = train(cfg, env, agent)
+    make_dir(cfg.result_path, cfg.model_path)
+    agent.save(path=cfg.model_path)
+    save_results(rewards, ma_rewards, tag='train', path=cfg.result_path)
+    plot_rewards(rewards, ma_rewards, tag="train", env=cfg.env, algo=cfg.algo, path=cfg.result_path)
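The script only trains and saves; the commit does not include an evaluation entry point. For completeness, a minimal hypothetical sketch of how the saved networks could be reloaded and rolled out, assuming it is run from the `codes/` directory and that the checkpoint directory matches one of the runs committed above; on a CPU-only machine, `torch.load` inside `SAC.load` may additionally need a `map_location` argument:

```python
import gym

from SAC.env import NormalizedActions
from SAC.agent import SAC
from SAC.task0_train import SACConfig

cfg = SACConfig()
env = NormalizedActions(gym.make(cfg.env))
agent = SAC(env.observation_space.shape[0], env.action_space.shape[0], cfg)
# Point this at the models/ directory of a finished run (one is committed in this diff).
agent.load(path="SAC/outputs/Pendulum-v0/20210429-135700/models/")

state, ep_reward = env.reset(), 0.0
for _ in range(cfg.train_steps):
    action = agent.policy_net.get_action(state)
    state, reward, done, _ = env.step(action)
    ep_reward += reward
    if done:
        break
print(f"evaluation reward: {ep_reward:.3f}")
```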
diff --git a/codes/common/multiprocessing_env.py b/codes/common/multiprocessing_env.py
new file mode 100644
index 0000000..04b4e3c
--- /dev/null
+++ b/codes/common/multiprocessing_env.py
@@ -0,0 +1,153 @@
+# This code is from openai baseline
+# https://github.com/openai/baselines/tree/master/baselines/common/vec_env
+
+import numpy as np
+from multiprocessing import Process, Pipe
+
+def worker(remote, parent_remote, env_fn_wrapper):
+    parent_remote.close()
+    env = env_fn_wrapper.x()
+    while True:
+        cmd, data = remote.recv()
+        if cmd == 'step':
+            ob, reward, done, info = env.step(data)
+            if done:
+                ob = env.reset()
+            remote.send((ob, reward, done, info))
+        elif cmd == 'reset':
+            ob = env.reset()
+            remote.send(ob)
+        elif cmd == 'reset_task':
+            ob = env.reset_task()
+            remote.send(ob)
+        elif cmd == 'close':
+            remote.close()
+            break
+        elif cmd == 'get_spaces':
+            remote.send((env.observation_space, env.action_space))
+        else:
+            raise NotImplementedError
+
+
+class VecEnv(object):
+    """
+    An abstract asynchronous, vectorized environment.
+    """
+    def __init__(self, num_envs, observation_space, action_space):
+        self.num_envs = num_envs
+        self.observation_space = observation_space
+        self.action_space = action_space
+
+    def reset(self):
+        """
+        Reset all the environments and return an array of
+        observations, or a tuple of observation arrays.
+        If step_async is still doing work, that work will
+        be cancelled and step_wait() should not be called
+        until step_async() is invoked again.
+        """
+        pass
+
+    def step_async(self, actions):
+        """
+        Tell all the environments to start taking a step
+        with the given actions.
+        Call step_wait() to get the results of the step.
+        You should not call this if a step_async run is
+        already pending.
+        """
+        pass
+
+    def step_wait(self):
+        """
+        Wait for the step taken with step_async().
+        Returns (obs, rews, dones, infos):
+         - obs: an array of observations, or a tuple of
+                arrays of observations.
+         - rews: an array of rewards
+         - dones: an array of "episode done" booleans
+         - infos: a sequence of info objects
+        """
+        pass
+
+    def close(self):
+        """
+        Clean up the environments' resources.
+        """
+        pass
+
+    def step(self, actions):
+        self.step_async(actions)
+        return self.step_wait()
+
+
+class CloudpickleWrapper(object):
+    """
+    Uses cloudpickle to serialize contents (otherwise multiprocessing tries to use pickle)
+    """
+    def __init__(self, x):
+        self.x = x
+    def __getstate__(self):
+        import cloudpickle
+        return cloudpickle.dumps(self.x)
+    def __setstate__(self, ob):
+        import pickle
+        self.x = pickle.loads(ob)
+
+
+class SubprocVecEnv(VecEnv):
+    def __init__(self, env_fns, spaces=None):
+        """
+        envs: list of gym environments to run in subprocesses
+        """
+        self.waiting = False
+        self.closed = False
+        nenvs = len(env_fns)
+        self.nenvs = nenvs
+        self.remotes, self.work_remotes = zip(*[Pipe() for _ in range(nenvs)])
+        self.ps = [Process(target=worker, args=(work_remote, remote, CloudpickleWrapper(env_fn)))
+                   for (work_remote, remote, env_fn) in zip(self.work_remotes, self.remotes, env_fns)]
+        for p in self.ps:
+            p.daemon = True  # if the main process crashes, we should not cause things to hang
+            p.start()
+        for remote in self.work_remotes:
+            remote.close()
+
+        self.remotes[0].send(('get_spaces', None))
+        observation_space, action_space = self.remotes[0].recv()
+        VecEnv.__init__(self, len(env_fns), observation_space, action_space)
+
+    def step_async(self, actions):
+        for remote, action in zip(self.remotes, actions):
+            remote.send(('step', action))
+        self.waiting = True
+
+    def step_wait(self):
+        results = [remote.recv() for remote in self.remotes]
+        self.waiting = False
+        obs, rews, dones, infos = zip(*results)
+        return np.stack(obs), np.stack(rews), np.stack(dones), infos
+
+    def reset(self):
+        for remote in self.remotes:
+            remote.send(('reset', None))
+        return np.stack([remote.recv() for remote in self.remotes])
+
+    def reset_task(self):
+        for remote in self.remotes:
+            remote.send(('reset_task', None))
+        return np.stack([remote.recv() for remote in self.remotes])
+
+    def close(self):
+        if self.closed:
+            return
+        if self.waiting:
+            for remote in self.remotes:
+                remote.recv()
+        for remote in self.remotes:
+            remote.send(('close', None))
+        for p in self.ps:
+            p.join()
+        self.closed = True
+
+    def __len__(self):
+        return self.nenvs
\ No newline at end of file
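`SubprocVecEnv` expects a list of zero-argument thunks so that each subprocess builds its own environment. A minimal usage sketch in the standard OpenAI baselines style; the `make_env` helper is mine, CartPole-v0 is used here only because it appears elsewhere in the repo, and `cloudpickle` must be installed since `CloudpickleWrapper` imports it:

```python
import gym
from common.multiprocessing_env import SubprocVecEnv

def make_env(env_id="CartPole-v0", seed=0):
    def _thunk():
        env = gym.make(env_id)
        env.seed(seed)
        return env
    return _thunk

if __name__ == "__main__":  # guard is needed when the multiprocessing start method is 'spawn'
    num_envs = 4
    envs = SubprocVecEnv([make_env(seed=i) for i in range(num_envs)])
    obs = envs.reset()                              # stacked observations, shape (num_envs, obs_dim)
    actions = [envs.action_space.sample() for _ in range(num_envs)]
    obs, rewards, dones, infos = envs.step(actions)  # workers auto-reset finished episodes
    envs.close()
```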
diff --git a/codes/test.py b/codes/test.py
new file mode 100644
index 0000000..90ccec4
--- /dev/null
+++ b/codes/test.py
@@ -0,0 +1,20 @@
+#!/usr/bin/env python
+# coding=utf-8
+'''
+Author: JiangJi
+Email: johnjim0816@gmail.com
+Date: 2021-03-25 23:25:15
+LastEditor: JiangJi
+LastEditTime: 2021-04-28 21:36:50
+Discription:
+Environment:
+'''
+import random
+
+dic = {0: "鳗鱼家", 1: "一心", 2: "bada"}
+print("0: 鳗鱼家, 1: 一心, 2: bada")
+print("Randomize three times and keep only the last pick")
+for i in range(3):
+    if i == 2:
+        print(f"Go to {dic[random.randint(0,2)]}")
+    else:
+        print(f"Don't go to {dic[random.randint(0,2)]}")
\ No newline at end of file