Merge branch 'master' of github.com:datawhalechina/easy-rl

qiwang067
2022-08-06 11:23:59 +08:00
238 changed files with 372 additions and 135 deletions

@@ -1,131 +0,0 @@
import sys, os
curr_path = os.path.dirname(os.path.abspath(__file__))  # absolute path of the current file
parent_path = os.path.dirname(curr_path)  # parent directory
sys.path.append(parent_path)  # add the parent directory to the Python path
import gym
import torch
import numpy as np
import datetime
from common.utils import plot_rewards, save_results, make_dir
from ppo2 import PPO

curr_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")  # current timestamp

class Config:
    def __init__(self) -> None:
        ################################ environment hyperparameters ##################
        self.algo_name = "PPO"  # algorithm name
        self.env_name = 'CartPole-v0'  # environment name
        self.continuous = False  # whether the action space is continuous
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # use a GPU if available
        self.seed = 10  # random seed; set to 0 to disable seeding
        self.train_eps = 200  # number of training episodes
        self.test_eps = 20  # number of test episodes
        ################################################################################
        ################################ algorithm hyperparameters #####################
        self.batch_size = 5  # batch size for mini-batch SGD
        self.gamma = 0.95  # discount factor
        self.n_epochs = 4  # number of epochs per policy update
        self.actor_lr = 0.0003  # learning rate of the actor
        self.critic_lr = 0.0003  # learning rate of the critic
        self.gae_lambda = 0.95  # lambda for generalized advantage estimation
        self.policy_clip = 0.2  # clipping range of the PPO surrogate objective
        self.hidden_dim = 256  # width of the hidden layers
        self.update_fre = 20  # policy update frequency, in environment steps
        ################################################################################
        ################################ result-saving parameters ######################
        self.result_path = curr_path + "/outputs/" + self.env_name + \
            '/' + curr_time + '/results/'  # path for saving results
        self.model_path = curr_path + "/outputs/" + self.env_name + \
            '/' + curr_time + '/models/'  # path for saving models
        self.save = True  # whether to save figures
        ################################################################################

def env_agent_config(cfg):
    ''' create the environment and the agent
    '''
    env = gym.make(cfg.env_name)  # create the environment
    n_states = env.observation_space.shape[0]  # state dimension
    if cfg.continuous:
        n_actions = env.action_space.shape[0]  # action dimension
    else:
        n_actions = env.action_space.n  # number of discrete actions
    agent = PPO(n_states, n_actions, cfg)  # create the agent
    if cfg.seed != 0:  # set random seeds
        torch.manual_seed(cfg.seed)
        env.seed(cfg.seed)
        np.random.seed(cfg.seed)
    return env, agent

def train(cfg, env, agent):
    print('Start training!')
    print(f'Environment: {cfg.env_name}, algorithm: {cfg.algo_name}, device: {cfg.device}')
    rewards = []  # rewards of all episodes
    ma_rewards = []  # moving-average rewards of all episodes
    steps = 0
    for i_ep in range(cfg.train_eps):
        state = env.reset()
        done = False
        ep_reward = 0
        while not done:
            action, prob, val = agent.choose_action(state)
            state_, reward, done, _ = env.step(action)
            steps += 1
            ep_reward += reward
            agent.memory.push(state, action, prob, val, reward, done)
            if steps % cfg.update_fre == 0:
                agent.update()
            state = state_
        rewards.append(ep_reward)
        # exponential moving average: 0.9 * previous average + 0.1 * current reward
        if ma_rewards:
            ma_rewards.append(0.9 * ma_rewards[-1] + 0.1 * ep_reward)
        else:
            ma_rewards.append(ep_reward)
        if (i_ep + 1) % 10 == 0:
            print(f"Episode: {i_ep + 1}/{cfg.train_eps}, reward: {ep_reward:.2f}")
    print('Finished training!')
    return rewards, ma_rewards

def test(cfg, env, agent):
    print('Start testing!')
    print(f'Environment: {cfg.env_name}, algorithm: {cfg.algo_name}, device: {cfg.device}')
    rewards = []  # rewards of all episodes
    ma_rewards = []  # moving-average rewards of all episodes
    for i_ep in range(cfg.test_eps):
        state = env.reset()
        done = False
        ep_reward = 0
        while not done:
            action, prob, val = agent.choose_action(state)
            state_, reward, done, _ = env.step(action)
            ep_reward += reward
            state = state_
        rewards.append(ep_reward)
        if ma_rewards:
            ma_rewards.append(0.9 * ma_rewards[-1] + 0.1 * ep_reward)
        else:
            ma_rewards.append(ep_reward)
        print(f'Episode: {i_ep + 1}/{cfg.test_eps}, reward: {ep_reward}')
    print('Finished testing!')
    return rewards, ma_rewards

if __name__ == "__main__":
    cfg = Config()
    # training
    env, agent = env_agent_config(cfg)
    rewards, ma_rewards = train(cfg, env, agent)
    make_dir(cfg.result_path, cfg.model_path)  # create folders for saving results and models
    agent.save(path=cfg.model_path)
    save_results(rewards, ma_rewards, tag='train', path=cfg.result_path)
    plot_rewards(rewards, ma_rewards, cfg, tag="train")
    # testing
    env, agent = env_agent_config(cfg)
    agent.load(path=cfg.model_path)
    rewards, ma_rewards = test(cfg, env, agent)
    save_results(rewards, ma_rewards, tag='test', path=cfg.result_path)
    plot_rewards(rewards, ma_rewards, cfg, tag="test")
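
The training script above imports ```PPO``` from ```ppo2```, which this diff does not show. Below is a minimal sketch of the interface the training loop relies on; the method names are taken from the script itself, while the bodies are placeholders, not the repository's actual implementation.

```python
class PPOMemory:
    ''' rollout buffer matching the push() calls in the training loop above '''
    def __init__(self):
        self.states, self.actions, self.probs = [], [], []
        self.vals, self.rewards, self.dones = [], [], []

    def push(self, state, action, prob, val, reward, done):
        ''' store one transition '''
        self.states.append(state)
        self.actions.append(action)
        self.probs.append(prob)    # log-probability of the chosen action
        self.vals.append(val)      # critic's value estimate for the state
        self.rewards.append(reward)
        self.dones.append(done)

    def clear(self):
        self.__init__()


class PPO:
    ''' interface sketch only; the real networks and update rule live in ppo2.py '''
    def __init__(self, n_states, n_actions, cfg):
        self.memory = PPOMemory()  # accessed as agent.memory.push(...) above

    def choose_action(self, state):
        ''' return (action, log_prob, value), as unpacked in the loop above '''
        raise NotImplementedError

    def update(self):
        ''' run cfg.n_epochs of clipped-surrogate updates on the stored rollout '''
        raise NotImplementedError

    def save(self, path):
        raise NotImplementedError

    def load(self, path):
        raise NotImplementedError
```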


@@ -73,7 +73,7 @@ The PDF version is the book's first draft; the editors at Posts & Telecom Press …
| [Chapter 13: AlphaStar Paper Walkthrough](https://datawhalechina.github.io/easy-rl/#/chapter13/chapter13) | | |
## Hands-On Algorithms
[Click here](https://github.com/datawhalechina/easy-rl/tree/master/codes) or go to the ```codes``` folder for the hands-on algorithm implementations
[Click here](../projects) or go to the ```projects``` folder for the hands-on algorithm implementations
## Contributors


@@ -150,7 +150,7 @@ $$
1. [DeepMind's walking agents](https://www.youtube.com/watch?v=gn4nRCC9TwQ). Each step the agent takes forward earns it a reward. The agents come in different body shapes and can learn many interesting behaviors; for example, a humanoid agent learns to walk along a winding road. The result is quite amusing: the agent holds its arms up high, because raising them keeps its body balanced and lets it move forward faster. We can also make the environment harder by adding perturbations, which makes the agent more robust.
2. [Robotic-arm grasping](https://ai.googleblog.com/2016/03/deep-learning-for-robots-learning-from.htm). Applying reinforcement learning to automatic grasping requires a large number of rollouts, so multiple robotic arms are trained in parallel. A distributed system lets the arms try grasping different objects; since the objects on the tray vary in shape, the arms can learn a single unified grasping policy that then works for each kind of object. Because object shapes differ so much, traditional grasping algorithms cannot pick everything up: they require a model of every individual object, which is very time-consuming. With reinforcement learning, we can learn one grasping policy that applies to different objects.
2. [Robotic-arm grasping](https://ai.googleblog.com/2016/03/deep-learning-for-robots-learning-from.html). Applying reinforcement learning to automatic grasping requires a large number of rollouts, so multiple robotic arms are trained in parallel. A distributed system lets the arms try grasping different objects; since the objects on the tray vary in shape, the arms can learn a single unified grasping policy that then works for each kind of object. Because object shapes differ so much, traditional grasping algorithms cannot pick everything up: they require a model of every individual object, which is very time-consuming. With reinforcement learning, we can learn one grasping policy that applies to different objects.
3. [OpenAI's Rubik's-cube robot hand](https://www.youtube.com/watch?v=jwSbzNHGflM). In 2018, OpenAI built a robotic hand with "fingers" that can flip a block in its palm into a target orientation. Human fingers are remarkably dexterous, and giving a robotic hand that kind of dexterity has long been an open problem. OpenAI first trained the agent with reinforcement learning in a simulated environment and then transferred it to the real hand. This is a common practice in reinforcement learning: first obtain a good agent in simulation, then deploy it on a real robot, because real robotic arms break easily and are expensive, so they generally cannot be bought in bulk. In 2019, OpenAI further improved the hand; after the upgrade, it could solve a Rubik's cube.

projects/.gitignore vendored Normal file

@@ -0,0 +1,5 @@
.DS_Store
.ipynb_checkpoints
__pycache__
.vscode
test.py

projects/LICENSE Normal file

@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2020 John Jim

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

projects/README.md Normal file

@@ -0,0 +1,49 @@
## 0. Preface
This project is for learning basic RL algorithms, aiming for **detailed comments** (after a long period of hesitation, we settled on Chinese comments!) and a **clear structure**.
The code is organized into the following scripts:
* ```[algorithm_name].py```: the script containing the algorithm itself, e.g. ```dqn.py```; each algorithm includes a set of basic building blocks such as a ```Replay Buffer```, an ```MLP``` (multilayer perceptron), and so on;
* ```task.py```: the script defining the task, which generally includes the ```argparse```-based parameters plus the training and test functions;
* ```utils.py```: helper functions for saving results and plotting (a rough sketch of such helpers follows this list); in real projects or research, we recommend logging results with ```Tensorboard``` and then using libraries such as ```matplotlib``` and ```seaborn``` for further plotting.
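As an illustration, here is a minimal sketch of what the saving and plotting helpers in ```utils.py``` might look like; the function names (```make_dir```, ```save_results```, ```plot_rewards```) come from the training script earlier in this diff, while the bodies and the ```cfg``` attributes used here are assumptions, not the repository's exact code.

```python
import os
import numpy as np
import matplotlib.pyplot as plt

def make_dir(*paths):
    ''' create each output directory if it does not already exist '''
    for path in paths:
        os.makedirs(path, exist_ok=True)

def save_results(rewards, ma_rewards, tag='train', path='./results/'):
    ''' save per-episode rewards and their moving average as .npy files '''
    np.save(path + f'{tag}_rewards.npy', rewards)
    np.save(path + f'{tag}_ma_rewards.npy', ma_rewards)

def plot_rewards(rewards, ma_rewards, cfg, tag='train'):
    ''' plot raw and moving-average rewards over episodes '''
    plt.figure()
    plt.title(f'{tag} curve of {cfg.algo_name} on {cfg.env_name}')
    plt.xlabel('episode')
    plt.plot(rewards, label='rewards')
    plt.plot(ma_rewards, label='moving average rewards')
    plt.legend()
    if cfg.save:
        plt.savefig(cfg.result_path + f'{tag}_rewards_curve.png')
    plt.show()
```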
## Requirements
Python 3.7, PyTorch 1.6.0-1.9.0, Gym 0.21.0
Alternatively, run the following command in the directory containing ```README.md``` to reproduce the environment:
```bash
conda env create -f environment.yaml
```
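The ```environment.yaml``` file itself is not reproduced in this diff; as a rough sketch under the version assumptions listed above (the actual file may pin different or additional packages), it might look like:

```yaml
name: easy-rl
dependencies:
  - python=3.7
  - pip
  - pip:
      - torch==1.9.0
      - gym==0.21.0
```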
## Usage
Run any .py or .ipynb file whose name contains ```train``` to train on the default task.
Alternatively, run the .py files whose names contain ```task``` to train on different tasks.
## Contents
| Algorithm | References | Environment | Notes |
| :---: | :---: | :---: | :---: |
| [On-Policy First-Visit MC](./MonteCarlo) | [medium blog](https://medium.com/analytics-vidhya/monte-carlo-methods-in-reinforcement-learning-part-1-on-policy-methods-1f004d59686a) | [Racetrack](./envs/racetrack_env.md) | |
| [Q-Learning](./QLearning) | [towardsdatascience blog](https://towardsdatascience.com/simple-reinforcement-learning-q-learning-fcddc4b6fe56),[q learning paper](https://ieeexplore.ieee.org/document/8836506) | [CliffWalking-v0](./envs/gym_info.md) | |
| [Sarsa](./Sarsa) | [geeksforgeeks blog](https://www.geeksforgeeks.org/sarsa-reinforcement-learning/) | [Racetrack](./envs/racetrack_env.md) | |
| [DQN](./DQN) | [DQN Paper](https://www.cs.toronto.edu/~vmnih/docs/dqn.pdf),[Nature DQN Paper](https://www.nature.com/articles/nature14236) | [CartPole-v0](./envs/gym_info.md) | |
| [DQN-cnn](./DQN_cnn) | [DQN Paper](https://www.cs.toronto.edu/~vmnih/docs/dqn.pdf) | [CartPole-v0](./envs/gym_info.md) | Uses a CNN instead of a fully connected network, compared with DQN |
| [DoubleDQN](./DoubleDQN) | [DoubleDQN Paper](https://arxiv.org/abs/1509.06461) | [CartPole-v0](./envs/gym_info.md) | |
| [Hierarchical DQN](HierarchicalDQN) | [H-DQN Paper](https://arxiv.org/abs/1604.06057) | [CartPole-v0](./envs/gym_info.md) | |
| [PolicyGradient](./PolicyGradient) | [Lil'log](https://lilianweng.github.io/lil-log/2018/04/08/policy-gradient-algorithms.html) | [CartPole-v0](./envs/gym_info.md) | |
| [A2C](./A2C) | [A3C Paper](https://arxiv.org/abs/1602.01783) | [CartPole-v0](./envs/gym_info.md) | |
| [SAC](./SoftActorCritic) | [SAC Paper](https://arxiv.org/abs/1801.01290) | [Pendulum-v0](./envs/gym_info.md) | |
| [PPO](./PPO) | [PPO paper](https://arxiv.org/abs/1707.06347) | [CartPole-v0](./envs/gym_info.md) | |
| [DDPG](./DDPG) | [DDPG Paper](https://arxiv.org/abs/1509.02971) | [Pendulum-v0](./envs/gym_info.md) | |
| [TD3](./TD3) | [TD3 Paper](https://arxiv.org/abs/1802.09477) | [HalfCheetah-v2](./envs/mujoco_info.md) | |
## Refs
[RL-Adventure-2](https://github.com/higgsfield/RL-Adventure-2)
[RL-Adventure](https://github.com/higgsfield/RL-Adventure)
[Google Open Source Project Style Guide (Chinese edition)](https://zh-google-styleguide.readthedocs.io/en/latest/google-python-styleguide/python_style_rules/#comments)
