update projects

This commit is contained in:
johnjim0816
2022-07-31 23:42:12 +08:00
parent e9b3e92141
commit ffab9e3028
236 changed files with 370 additions and 133 deletions

Binary file not shown.

Before

Width:  |  Height:  |  Size: 27 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 69 KiB

View File

@@ -1,131 +0,0 @@
import sys,os
curr_path = os.path.dirname(os.path.abspath(__file__)) # 当前文件所在绝对路径
parent_path = os.path.dirname(curr_path) # 父路径
sys.path.append(parent_path) # 添加路径到系统路径
import gym
import torch
import numpy as np
import datetime
from common.utils import plot_rewards
from common.utils import save_results,make_dir
from ppo2 import PPO
curr_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") # 获取当前时间
class Config:
def __init__(self) -> None:
################################## 环境超参数 ###################################
self.algo_name = "PPO" # 算法名称
self.env_name = 'CartPole-v0' # 环境名称
self.continuous = False # 环境是否为连续动作
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # 检测GPU
self.seed = 10 # 随机种子置0则不设置随机种子
self.train_eps = 200 # 训练的回合数
self.test_eps = 20 # 测试的回合数
################################################################################
################################## 算法超参数 ####################################
self.batch_size = 5 # mini-batch SGD中的批量大小
self.gamma = 0.95 # 强化学习中的折扣因子
self.n_epochs = 4
self.actor_lr = 0.0003 # actor的学习率
self.critic_lr = 0.0003 # critic的学习率
self.gae_lambda = 0.95
self.policy_clip = 0.2
self.hidden_dim = 256
self.update_fre = 20 # 策略更新频率
################################################################################
################################# 保存结果相关参数 ################################
self.result_path = curr_path+"/outputs/" + self.env_name + \
'/'+curr_time+'/results/' # 保存结果的路径
self.model_path = curr_path+"/outputs/" + self.env_name + \
'/'+curr_time+'/models/' # 保存模型的路径
self.save = True # 是否保存图片
################################################################################
def env_agent_config(cfg):
''' 创建环境和智能体
'''
env = gym.make(cfg.env_name) # 创建环境
n_states = env.observation_space.shape[0] # 状态维度
if cfg.continuous:
n_actions = env.action_space.shape[0] # 动作维度
else:
n_actions = env.action_space.n # 动作维度
agent = PPO(n_states, n_actions, cfg) # 创建智能体
if cfg.seed !=0: # 设置随机种子
torch.manual_seed(cfg.seed)
env.seed(cfg.seed)
np.random.seed(cfg.seed)
return env, agent
def train(cfg,env,agent):
print('开始训练!')
print(f'环境:{cfg.env_name}, 算法:{cfg.algo_name}, 设备:{cfg.device}')
rewards = [] # 记录所有回合的奖励
ma_rewards = [] # 记录所有回合的滑动平均奖励
steps = 0
for i_ep in range(cfg.train_eps):
state = env.reset()
done = False
ep_reward = 0
while not done:
action, prob, val = agent.choose_action(state)
state_, reward, done, _ = env.step(action)
steps += 1
ep_reward += reward
agent.memory.push(state, action, prob, val, reward, done)
if steps % cfg.update_fre == 0:
agent.update()
state = state_
rewards.append(ep_reward)
if ma_rewards:
ma_rewards.append(0.9*ma_rewards[-1]+0.1*ep_reward)
else:
ma_rewards.append(ep_reward)
if (i_ep+1)%10 == 0:
print(f"回合:{i_ep+1}/{cfg.train_eps},奖励:{ep_reward:.2f}")
print('完成训练!')
return rewards,ma_rewards
def test(cfg,env,agent):
print('开始测试!')
print(f'环境:{cfg.env_name}, 算法:{cfg.algo_name}, 设备:{cfg.device}')
rewards = [] # 记录所有回合的奖励
ma_rewards = [] # 记录所有回合的滑动平均奖励
for i_ep in range(cfg.test_eps):
state = env.reset()
done = False
ep_reward = 0
while not done:
action, prob, val = agent.choose_action(state)
state_, reward, done, _ = env.step(action)
ep_reward += reward
state = state_
rewards.append(ep_reward)
if ma_rewards:
ma_rewards.append(
0.9*ma_rewards[-1]+0.1*ep_reward)
else:
ma_rewards.append(ep_reward)
print('回合:{}/{}, 奖励:{}'.format(i_ep+1, cfg.test_eps, ep_reward))
print('完成训练!')
return rewards,ma_rewards
if __name__ == "__main__":
cfg = Config()
# 训练
env,agent = env_agent_config(cfg)
rewards, ma_rewards = train(cfg, env, agent)
make_dir(cfg.result_path, cfg.model_path) # 创建保存结果和模型路径的文件夹
agent.save(path=cfg.model_path)
save_results(rewards, ma_rewards, tag='train', path=cfg.result_path)
plot_rewards(rewards, ma_rewards, cfg, tag="train")
# 测试
env,agent = env_agent_config(cfg)
agent.load(path=cfg.model_path)
rewards,ma_rewards = test(cfg,env,agent)
save_results(rewards,ma_rewards,tag='test',path=cfg.result_path)
plot_rewards(rewards,ma_rewards,cfg,tag="test")

5
projects/.gitignore vendored Normal file
View File

@@ -0,0 +1,5 @@
.DS_Store
.ipynb_checkpoints
__pycache__
.vscode
test.py

21
projects/LICENSE Normal file
View File

@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2020 John Jim
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

49
projects/README.md Normal file
View File

@@ -0,0 +1,49 @@
## 0、写在前面
本项目用于学习RL基础算法尽量做到: **注释详细**(经过很长时间的纠结,还是中文注释好了!!!)**结构清晰**。
代码结构主要分为以下几个脚本:
* ```[algorithm_name].py```:即保存算法的脚本,例如```dqn.py```,每种算法都会有一定的基础模块,例如```Replay Buffer```、```MLP```(多层感知机)等等;
* ```task.py```: 即保存任务的脚本,基本包括基于```argparse```模块的参数,训练以及测试函数等等;
* ```utils.py```:该脚本用于保存诸如存储结果以及画图的软件,在实际项目或研究中,推荐大家使用```Tensorboard```来保存结果,然后使用诸如```matplotlib```以及```seabron```来进一步画图。
## 运行环境
python 3.7、pytorch 1.6.0-1.9.0、gym 0.21.0
或者在```README.md```目录下执行以下命令复现环境:
```bash
conda env create -f environment.yaml
```
## 使用说明
直接运行带有```train```的py文件或ipynb文件会进行训练默认的任务
也可以运行带有```task```的py文件训练不同的任务
## 内容导航
| 算法名称 | 相关论文材料 | 环境 | 备注 |
| :--------------------------------------: | :----------------------------------------------------------: | ----------------------------------------- | :--------------------------------: |
| [On-Policy First-Visit MC](./MonteCarlo) | [medium blog](https://medium.com/analytics-vidhya/monte-carlo-methods-in-reinforcement-learning-part-1-on-policy-methods-1f004d59686a) | [Racetrack](./envs/racetrack_env.md) | |
| [Q-Learning](./QLearning) | [towardsdatascience blog](https://towardsdatascience.com/simple-reinforcement-learning-q-learning-fcddc4b6fe56),[q learning paper](https://ieeexplore.ieee.org/document/8836506) | [CliffWalking-v0](./envs/gym_info.md) | |
| [Sarsa](./Sarsa) | [geeksforgeeks blog](https://www.geeksforgeeks.org/sarsa-reinforcement-learning/) | [Racetrack](./envs/racetrack_env.md) | |
| [DQN](./DQN) | [DQN Paper](https://www.cs.toronto.edu/~vmnih/docs/dqn.pdf),[Nature DQN Paper](https://www.nature.com/articles/nature14236) | [CartPole-v0](./envs/gym_info.md) | |
| [DQN-cnn](./DQN_cnn) | [DQN Paper](https://www.cs.toronto.edu/~vmnih/docs/dqn.pdf) | [CartPole-v0](./envs/gym_info.md) | 与DQN相比使用了CNN而不是全链接网络 |
| [DoubleDQN](./DoubleDQN) | [DoubleDQN Paper](https://arxiv.org/abs/1509.06461) | [CartPole-v0](./envs/gym_info.md) | |
| [Hierarchical DQN](HierarchicalDQN) | [H-DQN Paper](https://arxiv.org/abs/1604.06057) | [CartPole-v0](./envs/gym_info.md) | |
| [PolicyGradient](./PolicyGradient) | [Lil'log](https://lilianweng.github.io/lil-log/2018/04/08/policy-gradient-algorithms.html) | [CartPole-v0](./envs/gym_info.md) | |
| [A2C](./A2C) | [A3C Paper](https://arxiv.org/abs/1602.01783) | [CartPole-v0](./envs/gym_info.md) | |
| [SAC](./SoftActorCritic) | [SAC Paper](https://arxiv.org/abs/1801.01290) | [Pendulum-v0](./envs/gym_info.md) | |
| [PPO](./PPO) | [PPO paper](https://arxiv.org/abs/1707.06347) | [CartPole-v0](./envs/gym_info.md) | |
| [DDPG](./DDPG) | [DDPG Paper](https://arxiv.org/abs/1509.02971) | [Pendulum-v0](./envs/gym_info.md) | |
| [TD3](./TD3) | [TD3 Paper](https://arxiv.org/abs/1802.09477) | [HalfCheetah-v2]((./envs/mujoco_info.md)) | |
## Refs
[RL-Adventure-2](https://github.com/higgsfield/RL-Adventure-2)
[RL-Adventure](https://github.com/higgsfield/RL-Adventure)
[Google 开源项目风格指南——中文版](https://zh-google-styleguide.readthedocs.io/en/latest/google-python-styleguide/python_style_rules/#comments)

View File

Before

Width:  |  Height:  |  Size: 36 KiB

After

Width:  |  Height:  |  Size: 36 KiB

View File

Before

Width:  |  Height:  |  Size: 76 KiB

After

Width:  |  Height:  |  Size: 76 KiB

View File

Before

Width:  |  Height:  |  Size: 58 KiB

After

Width:  |  Height:  |  Size: 58 KiB

View File

Before

Width:  |  Height:  |  Size: 37 KiB

After

Width:  |  Height:  |  Size: 37 KiB

View File

Before

Width:  |  Height:  |  Size: 17 KiB

After

Width:  |  Height:  |  Size: 17 KiB

View File

Before

Width:  |  Height:  |  Size: 24 KiB

After

Width:  |  Height:  |  Size: 24 KiB

View File

Before

Width:  |  Height:  |  Size: 121 KiB

After

Width:  |  Height:  |  Size: 121 KiB

View File

Before

Width:  |  Height:  |  Size: 112 KiB

After

Width:  |  Height:  |  Size: 112 KiB

View File

Before

Width:  |  Height:  |  Size: 311 KiB

After

Width:  |  Height:  |  Size: 311 KiB

View File

Before

Width:  |  Height:  |  Size: 180 KiB

After

Width:  |  Height:  |  Size: 180 KiB

View File

Before

Width:  |  Height:  |  Size: 13 KiB

After

Width:  |  Height:  |  Size: 13 KiB

Some files were not shown because too many files have changed in this diff Show More