hot update
@@ -23,52 +23,71 @@
|
||||
|
||||
注:点击对应的名称会跳到[codes](./codes/)下对应的算法中,其他版本还请读者自行翻阅
|
||||
|
||||
| 算法名称 | 参考文献 | 环境 | 备注 |
|
||||
| :-----------------------: | :----------------------------------------------------------: | :--: | :--: |
|
||||
| | | | |
|
||||
| DQN-CNN | | | 待更 |
|
||||
| [SoftQ](codes/SoftQ) | [Soft Q-learning paper](https://arxiv.org/abs/1702.08165) | | |
|
||||
| [SAC](codes/SAC) | [SAC paper](https://arxiv.org/pdf/1812.05905.pdf) | | |
|
||||
| [SAC-Discrete](codes/SAC) | [SAC-Discrete paper](https://arxiv.org/pdf/1910.07207.pdf) | | |
|
||||
| SAC-V | [SAC-V paper](https://arxiv.org/abs/1801.01290) | | |
|
||||
| DSAC | [DSAC paper](https://paperswithcode.com/paper/addressing-value-estimation-errors-in) | | 待更 |
|
||||
| 算法名称 | 参考文献 | 备注 |
|
||||
| :-----------------------: | :----------------------------------------------------------: | :--: |
|
||||
| | | |
|
||||
| DQN-CNN | | 待更 |
|
||||
| [SoftQ](codes/SoftQ) | [Soft Q-learning paper](https://arxiv.org/abs/1702.08165) | |
|
||||
| [SAC](codes/SAC) | [SAC paper](https://arxiv.org/pdf/1812.05905.pdf) | |
|
||||
| [SAC-Discrete](codes/SAC) | [SAC-Discrete paper](https://arxiv.org/pdf/1910.07207.pdf) | |
|
||||
| SAC-S | [SAC-S paper](https://arxiv.org/abs/1801.01290) | |
|
||||
| DSAC | [DSAC paper](https://paperswithcode.com/paper/addressing-value-estimation-errors-in) | 待更 |
|
||||
|
||||
## 3、算法环境
|
||||
|
||||
算法环境说明请跳转[env](./codes/envs/README.md)
|
||||
|
||||
## 3、运行环境
|
||||
## 4、运行环境
|
||||
|
||||
Python 3.7、PyTorch 1.10.0、Gym 0.21.0
|
||||
主要依赖:Python 3.7、PyTorch 1.10.0、Gym 0.21.0。
|
||||
|
||||
在项目根目录下执行以下命令复现环境:
|
||||
### 4.1、创建Conda环境
|
||||
```bash
|
||||
conda create -n easyrl python=3.7
|
||||
conda activate easyrl # 激活环境
|
||||
```
|
||||
### 4.2、安装Torch
|
||||
|
||||
安装CPU版本:
|
||||
```bash
|
||||
conda install pytorch==1.10.0 torchvision==0.11.0 torchaudio==0.10.0 cpuonly -c pytorch
|
||||
```
|
||||
安装CUDA版本:
|
||||
```bash
|
||||
conda install pytorch==1.10.0 torchvision==0.11.0 torchaudio==0.10.0 cudatoolkit=11.3 -c pytorch -c conda-forge
|
||||
```
|
||||
如果安装Torch需要镜像加速的话,点击[清华镜像链接](https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/pytorch/),选择对应的操作系统,如```win-64```,然后复制链接,执行:
|
||||
```bash
|
||||
conda install pytorch==1.10.0 torchvision==0.11.0 torchaudio==0.10.0 cudatoolkit=11.3 -c https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/pytorch/win-64/
|
||||
```
|
||||
也可以使用PiP镜像安装(仅限CUDA版本):
|
||||
```bash
|
||||
pip install torch==1.10.0+cu113 torchvision==0.11.0+cu113 torchaudio==0.10.0 --extra-index-url https://download.pytorch.org/whl/cu113
|
||||
```
|
||||
### 4.3、安装其他依赖
|
||||
|
||||
项目根目录下执行:
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
如果需要使用CUDA,则需另外安装```cudatoolkit```,推荐```10.2```或者```11.3```版本的CUDA,如下:
|
||||
```bash
|
||||
conda install cudatoolkit=11.3 -c pytorch
|
||||
```
|
||||
如果conda需要镜像加速安装的话,点击[该清华镜像链接](https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/pytorch/),选择对应的操作系统,比如```win-64```,然后复制链接,执行如下命令:
|
||||
```bash
|
||||
conda install cudatoolkit=11.3 -c https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/pytorch/win-64/
|
||||
```
|
||||
执行以下Python脚本,如果返回True说明cuda安装成功:
|
||||
### 4.4、检验CUDA版本Torch安装
|
||||
|
||||
CPU版本Torch请忽略此步,执行如下Python脚本,如果返回True说明CUDA版本安装成功:
|
||||
```python
|
||||
import torch
|
||||
print(torch.cuda.is_available())
|
||||
```
|
||||
如果还是不成功,可以使用pip安装:
|
||||
```bash
|
||||
pip install torch==1.10.0+cu113 torchvision==0.11.0+cu113 torchaudio==0.10.0 --extra-index-url https://download.pytorch.org/whl/cu113
|
||||
```
|
||||
## 4、使用说明
|
||||
|
||||
## 5、使用说明
|
||||
|
||||
对于[codes](./codes/):
|
||||
* 运行带有task的py脚本
|
||||
* 运行带有```main.py```脚本
|
||||
* 执行[scripts](codes\scripts)下对应的Bash脚本,例如```sh codes/scripts/DQN_task0.sh```,推荐创建名为"easyrl"的conda环境,否则需要更改sh脚本相关信息。对于Windows系统,建议安装Git(不要更改默认安装路径,否则VS Code可能不会显示Git Bash)然后使用git bash终端,而非PowerShell或者cmd终端!
|
||||
|
||||
对于[Jupyter Notebook](./notebooks/):
|
||||
|
||||
* 直接运行对应的ipynb文件就行
|
||||
|
||||
## 5、友情说明
|
||||
## 6、友情说明
|
||||
|
||||
推荐使用VS Code做项目,入门可参考[VSCode上手指南](https://blog.csdn.net/JohnJim0/article/details/126366454)
|
||||
@@ -28,6 +28,8 @@
|
||||
\@writefile{loa}{\contentsline {algorithm}{\numberline {}{\ignorespaces }}{6}{algorithm.}\protected@file@percent }
|
||||
\@writefile{toc}{\contentsline {section}{\numberline {6}SoftQ算法}{7}{section.6}\protected@file@percent }
|
||||
\@writefile{loa}{\contentsline {algorithm}{\numberline {}{\ignorespaces }}{7}{algorithm.}\protected@file@percent }
|
||||
\@writefile{toc}{\contentsline {section}{\numberline {7}SAC算法}{8}{section.7}\protected@file@percent }
|
||||
\@writefile{toc}{\contentsline {section}{\numberline {7}SAC-S算法}{8}{section.7}\protected@file@percent }
|
||||
\@writefile{loa}{\contentsline {algorithm}{\numberline {}{\ignorespaces }}{8}{algorithm.}\protected@file@percent }
|
||||
\gdef \@abspage@last{8}
|
||||
\@writefile{toc}{\contentsline {section}{\numberline {8}SAC算法}{9}{section.8}\protected@file@percent }
|
||||
\@writefile{loa}{\contentsline {algorithm}{\numberline {}{\ignorespaces }}{9}{algorithm.}\protected@file@percent }
|
||||
\gdef \@abspage@last{9}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
This is XeTeX, Version 3.141592653-2.6-0.999993 (TeX Live 2021) (preloaded format=xelatex 2021.8.22) 22 AUG 2022 16:54
|
||||
This is XeTeX, Version 3.141592653-2.6-0.999993 (TeX Live 2021) (preloaded format=xelatex 2021.8.22) 23 AUG 2022 19:26
|
||||
entering extended mode
|
||||
restricted \write18 enabled.
|
||||
file:line:error style messages enabled.
|
||||
@@ -415,85 +415,85 @@ Package: titlesec 2019/10/16 v2.13 Sectioning titles
|
||||
) (./pseudocodes.aux)
|
||||
\openout1 = `pseudocodes.aux'.
|
||||
|
||||
LaTeX Font Info: Checking defaults for OML/cmm/m/it on input line 13.
|
||||
LaTeX Font Info: ... okay on input line 13.
|
||||
LaTeX Font Info: Checking defaults for OMS/cmsy/m/n on input line 13.
|
||||
LaTeX Font Info: ... okay on input line 13.
|
||||
LaTeX Font Info: Checking defaults for OT1/cmr/m/n on input line 13.
|
||||
LaTeX Font Info: ... okay on input line 13.
|
||||
LaTeX Font Info: Checking defaults for T1/cmr/m/n on input line 13.
|
||||
LaTeX Font Info: ... okay on input line 13.
|
||||
LaTeX Font Info: Checking defaults for TS1/cmr/m/n on input line 13.
|
||||
LaTeX Font Info: ... okay on input line 13.
|
||||
LaTeX Font Info: Checking defaults for TU/lmr/m/n on input line 13.
|
||||
LaTeX Font Info: ... okay on input line 13.
|
||||
LaTeX Font Info: Checking defaults for OMX/cmex/m/n on input line 13.
|
||||
LaTeX Font Info: ... okay on input line 13.
|
||||
LaTeX Font Info: Checking defaults for U/cmr/m/n on input line 13.
|
||||
LaTeX Font Info: ... okay on input line 13.
|
||||
LaTeX Font Info: Checking defaults for PD1/pdf/m/n on input line 13.
|
||||
LaTeX Font Info: ... okay on input line 13.
|
||||
LaTeX Font Info: Checking defaults for PU/pdf/m/n on input line 13.
|
||||
LaTeX Font Info: ... okay on input line 13.
|
||||
LaTeX Font Info: Checking defaults for OML/cmm/m/it on input line 14.
|
||||
LaTeX Font Info: ... okay on input line 14.
|
||||
LaTeX Font Info: Checking defaults for OMS/cmsy/m/n on input line 14.
|
||||
LaTeX Font Info: ... okay on input line 14.
|
||||
LaTeX Font Info: Checking defaults for OT1/cmr/m/n on input line 14.
|
||||
LaTeX Font Info: ... okay on input line 14.
|
||||
LaTeX Font Info: Checking defaults for T1/cmr/m/n on input line 14.
|
||||
LaTeX Font Info: ... okay on input line 14.
|
||||
LaTeX Font Info: Checking defaults for TS1/cmr/m/n on input line 14.
|
||||
LaTeX Font Info: ... okay on input line 14.
|
||||
LaTeX Font Info: Checking defaults for TU/lmr/m/n on input line 14.
|
||||
LaTeX Font Info: ... okay on input line 14.
|
||||
LaTeX Font Info: Checking defaults for OMX/cmex/m/n on input line 14.
|
||||
LaTeX Font Info: ... okay on input line 14.
|
||||
LaTeX Font Info: Checking defaults for U/cmr/m/n on input line 14.
|
||||
LaTeX Font Info: ... okay on input line 14.
|
||||
LaTeX Font Info: Checking defaults for PD1/pdf/m/n on input line 14.
|
||||
LaTeX Font Info: ... okay on input line 14.
|
||||
LaTeX Font Info: Checking defaults for PU/pdf/m/n on input line 14.
|
||||
LaTeX Font Info: ... okay on input line 14.
|
||||
ABD: EverySelectfont initializing macros
|
||||
LaTeX Info: Redefining \selectfont on input line 13.
|
||||
LaTeX Info: Redefining \selectfont on input line 14.
|
||||
|
||||
Package fontspec Info: Adjusting the maths setup (use [no-math] to avoid
|
||||
(fontspec) this).
|
||||
|
||||
\symlegacymaths=\mathgroup6
|
||||
LaTeX Font Info: Overwriting symbol font `legacymaths' in version `bold'
|
||||
(Font) OT1/cmr/m/n --> OT1/cmr/bx/n on input line 13.
|
||||
LaTeX Font Info: Redeclaring math accent \acute on input line 13.
|
||||
LaTeX Font Info: Redeclaring math accent \grave on input line 13.
|
||||
LaTeX Font Info: Redeclaring math accent \ddot on input line 13.
|
||||
LaTeX Font Info: Redeclaring math accent \tilde on input line 13.
|
||||
LaTeX Font Info: Redeclaring math accent \bar on input line 13.
|
||||
LaTeX Font Info: Redeclaring math accent \breve on input line 13.
|
||||
LaTeX Font Info: Redeclaring math accent \check on input line 13.
|
||||
LaTeX Font Info: Redeclaring math accent \hat on input line 13.
|
||||
LaTeX Font Info: Redeclaring math accent \dot on input line 13.
|
||||
LaTeX Font Info: Redeclaring math accent \mathring on input line 13.
|
||||
LaTeX Font Info: Redeclaring math symbol \Gamma on input line 13.
|
||||
LaTeX Font Info: Redeclaring math symbol \Delta on input line 13.
|
||||
LaTeX Font Info: Redeclaring math symbol \Theta on input line 13.
|
||||
LaTeX Font Info: Redeclaring math symbol \Lambda on input line 13.
|
||||
LaTeX Font Info: Redeclaring math symbol \Xi on input line 13.
|
||||
LaTeX Font Info: Redeclaring math symbol \Pi on input line 13.
|
||||
LaTeX Font Info: Redeclaring math symbol \Sigma on input line 13.
|
||||
LaTeX Font Info: Redeclaring math symbol \Upsilon on input line 13.
|
||||
LaTeX Font Info: Redeclaring math symbol \Phi on input line 13.
|
||||
LaTeX Font Info: Redeclaring math symbol \Psi on input line 13.
|
||||
LaTeX Font Info: Redeclaring math symbol \Omega on input line 13.
|
||||
LaTeX Font Info: Redeclaring math symbol \mathdollar on input line 13.
|
||||
LaTeX Font Info: Redeclaring symbol font `operators' on input line 13.
|
||||
(Font) OT1/cmr/m/n --> OT1/cmr/bx/n on input line 14.
|
||||
LaTeX Font Info: Redeclaring math accent \acute on input line 14.
|
||||
LaTeX Font Info: Redeclaring math accent \grave on input line 14.
|
||||
LaTeX Font Info: Redeclaring math accent \ddot on input line 14.
|
||||
LaTeX Font Info: Redeclaring math accent \tilde on input line 14.
|
||||
LaTeX Font Info: Redeclaring math accent \bar on input line 14.
|
||||
LaTeX Font Info: Redeclaring math accent \breve on input line 14.
|
||||
LaTeX Font Info: Redeclaring math accent \check on input line 14.
|
||||
LaTeX Font Info: Redeclaring math accent \hat on input line 14.
|
||||
LaTeX Font Info: Redeclaring math accent \dot on input line 14.
|
||||
LaTeX Font Info: Redeclaring math accent \mathring on input line 14.
|
||||
LaTeX Font Info: Redeclaring math symbol \Gamma on input line 14.
|
||||
LaTeX Font Info: Redeclaring math symbol \Delta on input line 14.
|
||||
LaTeX Font Info: Redeclaring math symbol \Theta on input line 14.
|
||||
LaTeX Font Info: Redeclaring math symbol \Lambda on input line 14.
|
||||
LaTeX Font Info: Redeclaring math symbol \Xi on input line 14.
|
||||
LaTeX Font Info: Redeclaring math symbol \Pi on input line 14.
|
||||
LaTeX Font Info: Redeclaring math symbol \Sigma on input line 14.
|
||||
LaTeX Font Info: Redeclaring math symbol \Upsilon on input line 14.
|
||||
LaTeX Font Info: Redeclaring math symbol \Phi on input line 14.
|
||||
LaTeX Font Info: Redeclaring math symbol \Psi on input line 14.
|
||||
LaTeX Font Info: Redeclaring math symbol \Omega on input line 14.
|
||||
LaTeX Font Info: Redeclaring math symbol \mathdollar on input line 14.
|
||||
LaTeX Font Info: Redeclaring symbol font `operators' on input line 14.
|
||||
LaTeX Font Info: Encoding `OT1' has changed to `TU' for symbol font
|
||||
(Font) `operators' in the math version `normal' on input line 13.
|
||||
(Font) `operators' in the math version `normal' on input line 14.
|
||||
LaTeX Font Info: Overwriting symbol font `operators' in version `normal'
|
||||
(Font) OT1/cmr/m/n --> TU/lmr/m/n on input line 13.
|
||||
(Font) OT1/cmr/m/n --> TU/lmr/m/n on input line 14.
|
||||
LaTeX Font Info: Encoding `OT1' has changed to `TU' for symbol font
|
||||
(Font) `operators' in the math version `bold' on input line 13.
|
||||
(Font) `operators' in the math version `bold' on input line 14.
|
||||
LaTeX Font Info: Overwriting symbol font `operators' in version `bold'
|
||||
(Font) OT1/cmr/bx/n --> TU/lmr/m/n on input line 13.
|
||||
(Font) OT1/cmr/bx/n --> TU/lmr/m/n on input line 14.
|
||||
LaTeX Font Info: Overwriting symbol font `operators' in version `normal'
|
||||
(Font) TU/lmr/m/n --> TU/lmr/m/n on input line 13.
|
||||
(Font) TU/lmr/m/n --> TU/lmr/m/n on input line 14.
|
||||
LaTeX Font Info: Overwriting math alphabet `\mathit' in version `normal'
|
||||
(Font) OT1/cmr/m/it --> TU/lmr/m/it on input line 13.
|
||||
(Font) OT1/cmr/m/it --> TU/lmr/m/it on input line 14.
|
||||
LaTeX Font Info: Overwriting math alphabet `\mathbf' in version `normal'
|
||||
(Font) OT1/cmr/bx/n --> TU/lmr/b/n on input line 13.
|
||||
(Font) OT1/cmr/bx/n --> TU/lmr/b/n on input line 14.
|
||||
LaTeX Font Info: Overwriting math alphabet `\mathsf' in version `normal'
|
||||
(Font) OT1/cmss/m/n --> TU/lmss/m/n on input line 13.
|
||||
(Font) OT1/cmss/m/n --> TU/lmss/m/n on input line 14.
|
||||
LaTeX Font Info: Overwriting math alphabet `\mathtt' in version `normal'
|
||||
(Font) OT1/cmtt/m/n --> TU/lmtt/m/n on input line 13.
|
||||
(Font) OT1/cmtt/m/n --> TU/lmtt/m/n on input line 14.
|
||||
LaTeX Font Info: Overwriting symbol font `operators' in version `bold'
|
||||
(Font) TU/lmr/m/n --> TU/lmr/b/n on input line 13.
|
||||
(Font) TU/lmr/m/n --> TU/lmr/b/n on input line 14.
|
||||
LaTeX Font Info: Overwriting math alphabet `\mathit' in version `bold'
|
||||
(Font) OT1/cmr/bx/it --> TU/lmr/b/it on input line 13.
|
||||
(Font) OT1/cmr/bx/it --> TU/lmr/b/it on input line 14.
|
||||
LaTeX Font Info: Overwriting math alphabet `\mathsf' in version `bold'
|
||||
(Font) OT1/cmss/bx/n --> TU/lmss/b/n on input line 13.
|
||||
(Font) OT1/cmss/bx/n --> TU/lmss/b/n on input line 14.
|
||||
LaTeX Font Info: Overwriting math alphabet `\mathtt' in version `bold'
|
||||
(Font) OT1/cmtt/m/n --> TU/lmtt/b/n on input line 13.
|
||||
Package hyperref Info: Link coloring OFF on input line 13.
|
||||
(Font) OT1/cmtt/m/n --> TU/lmtt/b/n on input line 14.
|
||||
Package hyperref Info: Link coloring OFF on input line 14.
|
||||
(/usr/local/texlive/2021/texmf-dist/tex/latex/hyperref/nameref.sty
|
||||
Package: nameref 2021-04-02 v2.47 Cross-referencing by name of section
|
||||
(/usr/local/texlive/2021/texmf-dist/tex/latex/refcount/refcount.sty
|
||||
@@ -503,9 +503,9 @@ Package: gettitlestring 2019/12/15 v1.6 Cleanup title references (HO)
|
||||
)
|
||||
\c@section@level=\count313
|
||||
)
|
||||
LaTeX Info: Redefining \ref on input line 13.
|
||||
LaTeX Info: Redefining \pageref on input line 13.
|
||||
LaTeX Info: Redefining \nameref on input line 13.
|
||||
LaTeX Info: Redefining \ref on input line 14.
|
||||
LaTeX Info: Redefining \pageref on input line 14.
|
||||
LaTeX Info: Redefining \nameref on input line 14.
|
||||
(./pseudocodes.out) (./pseudocodes.out)
|
||||
\@outlinefile=\write3
|
||||
\openout3 = `pseudocodes.out'.
|
||||
@@ -515,19 +515,19 @@ LaTeX Info: Redefining \nameref on input line 13.
|
||||
\openout4 = `pseudocodes.toc'.
|
||||
|
||||
LaTeX Font Info: Font shape `TU/SongtiSCLight(0)/m/sl' in size <10.95> not available
|
||||
(Font) Font shape `TU/SongtiSCLight(0)/m/it' tried instead on input line 16.
|
||||
(Font) Font shape `TU/SongtiSCLight(0)/m/it' tried instead on input line 17.
|
||||
[1
|
||||
|
||||
]
|
||||
Package hyperref Info: bookmark level for unknown algorithm defaults to 0 on input line 21.
|
||||
Package hyperref Info: bookmark level for unknown algorithm defaults to 0 on input line 22.
|
||||
[2
|
||||
|
||||
]
|
||||
LaTeX Font Info: Trying to load font information for U+msa on input line 31.
|
||||
LaTeX Font Info: Trying to load font information for U+msa on input line 32.
|
||||
(/usr/local/texlive/2021/texmf-dist/tex/latex/amsfonts/umsa.fd
|
||||
File: umsa.fd 2013/01/14 v3.01 AMS symbols A
|
||||
)
|
||||
LaTeX Font Info: Trying to load font information for U+msb on input line 31.
|
||||
LaTeX Font Info: Trying to load font information for U+msb on input line 32.
|
||||
(/usr/local/texlive/2021/texmf-dist/tex/latex/amsfonts/umsb.fd
|
||||
File: umsb.fd 2013/01/14 v3.01 AMS symbols B
|
||||
) [3
|
||||
@@ -536,38 +536,35 @@ File: umsb.fd 2013/01/14 v3.01 AMS symbols B
|
||||
|
||||
] [5
|
||||
|
||||
]
|
||||
Underfull \hbox (badness 10000) in paragraph at lines 111--112
|
||||
[] []\TU/SongtiSCLight(0)/m/n/10.95 计 算 实 际 的 $\OML/cmm/m/it/10.95 Q$ \TU/SongtiSCLight(0)/m/n/10.95 值,| 即 $\OML/cmm/m/it/10.95 y[] \OT1/cmr/m/n/10.95 =
|
||||
[]
|
||||
|
||||
[6
|
||||
] [6
|
||||
|
||||
] [7
|
||||
|
||||
] [8
|
||||
|
||||
]
|
||||
Overfull \hbox (32.54117pt too wide) in paragraph at lines 183--183
|
||||
Overfull \hbox (32.54117pt too wide) in paragraph at lines 212--212
|
||||
[][]$[]\OML/cmm/m/it/9 J[]\OT1/cmr/m/n/9 (\OML/cmm/m/it/9 ^^R\OT1/cmr/m/n/9 ) = \OMS/cmsy/m/n/9 r[]\OML/cmm/m/it/9 Q[] [] []$|
|
||||
[]
|
||||
|
||||
|
||||
Overfull \hbox (15.41673pt too wide) in paragraph at lines 184--184
|
||||
Overfull \hbox (15.41673pt too wide) in paragraph at lines 213--213
|
||||
[][]$[]\OML/cmm/m/it/9 J[]\OT1/cmr/m/n/9 (\OML/cmm/m/it/9 ^^^\OT1/cmr/m/n/9 ) = \OMS/cmsy/m/n/9 r[]\OML/cmm/m/it/9 ^^K [] [] \OT1/cmr/m/n/9 + [] \OMS/cmsy/m/n/9 r[]\OML/cmm/m/it/9 f[] []$\TU/lmr/m/n/9 ,$[][] \OT1/cmr/m/n/9 =
|
||||
[]
|
||||
|
||||
[8
|
||||
[9
|
||||
|
||||
] (./pseudocodes.aux)
|
||||
Package rerunfilecheck Info: File `pseudocodes.out' has not changed.
|
||||
(rerunfilecheck) Checksum: 4575BA7458AA23D6E696EFFE39D05727;640.
|
||||
(rerunfilecheck) Checksum: 35B5A79A86EF3BC70F1A0B3BCBEBAA13;724.
|
||||
)
|
||||
Here is how much of TeX's memory you used:
|
||||
14813 strings out of 476919
|
||||
312635 string characters out of 5821840
|
||||
653471 words of memory out of 5000000
|
||||
34563 multiletter control sequences out of 15000+600000
|
||||
413601 words of font info for 90 fonts, out of 8000000 for 9000
|
||||
14827 strings out of 476919
|
||||
313456 string characters out of 5821840
|
||||
653576 words of memory out of 5000000
|
||||
34576 multiletter control sequences out of 15000+600000
|
||||
413609 words of font info for 91 fonts, out of 8000000 for 9000
|
||||
1348 hyphenation exceptions out of 8191
|
||||
101i,13n,104p,676b,736s stack positions out of 5000i,500n,10000p,200000b,80000s
|
||||
101i,13n,104p,676b,697s stack positions out of 5000i,500n,10000p,200000b,80000s
|
||||
|
||||
Output written on pseudocodes.pdf (8 pages).
|
||||
Output written on pseudocodes.pdf (9 pages).
|
||||
|
||||
@@ -4,4 +4,5 @@
|
||||
\BOOKMARK [1][-]{section.4}{\376\377\000P\000o\000l\000i\000c\000y\000\040\000G\000r\000a\000d\000i\000e\000n\000t\173\227\154\325}{}% 4
|
||||
\BOOKMARK [1][-]{section.5}{\376\377\000D\000Q\000N\173\227\154\325}{}% 5
|
||||
\BOOKMARK [1][-]{section.6}{\376\377\000S\000o\000f\000t\000Q\173\227\154\325}{}% 6
|
||||
\BOOKMARK [1][-]{section.7}{\376\377\000S\000A\000C\173\227\154\325}{}% 7
|
||||
\BOOKMARK [1][-]{section.7}{\376\377\000S\000A\000C\000-\000S\173\227\154\325}{}% 7
|
||||
\BOOKMARK [1][-]{section.8}{\376\377\000S\000A\000C\173\227\154\325}{}% 8
|
||||
|
||||
@@ -10,6 +10,7 @@
|
||||
\usepackage{titlesec}
|
||||
\usepackage{float} % 调用该包能够使用[H]
|
||||
% \pagestyle{plain} % 去除页眉,但是保留页脚编号,都去掉plain换empty
|
||||
|
||||
\begin{document}
|
||||
\tableofcontents % 目录,注意要运行两下或者vscode保存两下才能显示
|
||||
% \singlespacing
|
||||
@@ -88,7 +89,7 @@
|
||||
\clearpage
|
||||
\section{DQN算法}
|
||||
\begin{algorithm}[H] % [H]固定位置
|
||||
\floatname{algorithm}{{DQN算法}}
|
||||
\floatname{algorithm}{{DQN算法}{\hypersetup{linkcolor=white}\footnotemark}}
|
||||
\renewcommand{\thealgorithm}{} % 去掉算法标号
|
||||
\caption{}
|
||||
\renewcommand{\algorithmicrequire}{\textbf{输入:}}
|
||||
@@ -108,13 +109,17 @@
|
||||
\STATE 更新环境状态$s_{t+1} \leftarrow s_t$
|
||||
\STATE {\bfseries 更新策略:}
|
||||
\STATE 从$D$中采样一个batch的transition
|
||||
\STATE 计算实际的$Q$值,即$y_{j}= \begin{cases}r_{j} & \text {对于终止状态} s_{j+1} \\ r_{j}+\gamma \max _{a^{\prime}} Q\left(s_{j+1}, a^{\prime} ; \theta\right) & \text {对于非终止状态} s_{j+1}\end{cases}$
|
||||
\STATE 对损失 $\left(y_{j}-Q\left(s_{j}, a_{j} ; \theta\right)\right)^{2}$关于参数$\theta$做随机梯度下降
|
||||
\STATE 计算实际的$Q$值,即$y_{j}${\hypersetup{linkcolor=white}\footnotemark}
|
||||
\STATE 对损失 $L(\theta)=\left(y_{i}-Q\left(s_{i}, a_{i} ; \theta\right)\right)^{2}$关于参数$\theta$做随机梯度下降{\hypersetup{linkcolor=white}\footnotemark}
|
||||
\ENDFOR
|
||||
\STATE 每$C$个回合复制参数$\hat{Q}\leftarrow Q$(此处也可像原论文中放到小循环中改成每$C$步,但没有每$C$个回合稳定)
|
||||
\STATE 每$C$个回合复制参数$\hat{Q}\leftarrow Q${\hypersetup{linkcolor=white}\footnotemark}
|
||||
\ENDFOR
|
||||
\end{algorithmic}
|
||||
\end{algorithm}
|
||||
\footnotetext[1]{Playing Atari with Deep Reinforcement Learning}
|
||||
\footnotetext[2]{$y_{i}= \begin{cases}r_{i} & \text {对于终止状态} s_{i+1} \\ r_{i}+\gamma \max _{a^{\prime}} Q\left(s_{i+1}, a^{\prime} ; \theta\right) & \text {对于非终止状态} s_{i+1}\end{cases}$}
|
||||
\footnotetext[3]{$\theta_i \leftarrow \theta_i - \lambda \nabla_{\theta_{i}} L_{i}\left(\theta_{i}\right)$}
|
||||
\footnotetext[4]{此处也可像原论文中放到小循环中改成每$C$步,但没有每$C$个回合稳定}
|
||||
\clearpage
|
||||
|
||||
\section{SoftQ算法}
|
||||
@@ -153,13 +158,37 @@
|
||||
\footnotetext[2]{$J_{Q}(\theta)=\mathbb{E}_{\mathbf{s}_{t} \sim q_{\mathbf{s}_{t}}, \mathbf{a}_{t} \sim q_{\mathbf{a}_{t}}}\left[\frac{1}{2}\left(\hat{Q}_{\mathrm{soft}}^{\bar{\theta}}\left(\mathbf{s}_{t}, \mathbf{a}_{t}\right)-Q_{\mathrm{soft}}^{\theta}\left(\mathbf{s}_{t}, \mathbf{a}_{t}\right)\right)^{2}\right]$}
|
||||
\footnotetext[3]{$\begin{aligned} \Delta f^{\phi}\left(\cdot ; \mathbf{s}_{t}\right)=& \mathbb{E}_{\mathbf{a}_{t} \sim \pi^{\phi}}\left[\left.\kappa\left(\mathbf{a}_{t}, f^{\phi}\left(\cdot ; \mathbf{s}_{t}\right)\right) \nabla_{\mathbf{a}^{\prime}} Q_{\mathrm{soft}}^{\theta}\left(\mathbf{s}_{t}, \mathbf{a}^{\prime}\right)\right|_{\mathbf{a}^{\prime}=\mathbf{a}_{t}}\right.\\ &\left.+\left.\alpha \nabla_{\mathbf{a}^{\prime}} \kappa\left(\mathbf{a}^{\prime}, f^{\phi}\left(\cdot ; \mathbf{s}_{t}\right)\right)\right|_{\mathbf{a}^{\prime}=\mathbf{a}_{t}}\right] \end{aligned}$}
|
||||
\clearpage
|
||||
\section{SAC-S算法}
|
||||
\begin{algorithm}[H] % [H]固定位置
|
||||
\floatname{algorithm}{{SAC-S算法}\footnotemark[1]}
|
||||
\renewcommand{\thealgorithm}{} % 去掉算法标号
|
||||
\caption{}
|
||||
\begin{algorithmic}[1] % [1]显示步数
|
||||
\STATE 初始化参数$\psi, \bar{\psi}, \theta, \phi$
|
||||
\FOR {回合数 = $1,M$}
|
||||
\FOR {时步 = $1,t$}
|
||||
\STATE 根据$\boldsymbol{a}_{t} \sim \pi_{\phi}\left(\boldsymbol{a}_{t} \mid \mathbf{s}_{t}\right)$采样动作$a_t$
|
||||
\STATE 环境反馈奖励和下一个状态,$\mathbf{s}_{t+1} \sim p\left(\mathbf{s}_{t+1} \mid \mathbf{s}_{t}, \mathbf{a}_{t}\right)$
|
||||
\STATE 存储transition到经验回放中,$\mathcal{D} \leftarrow \mathcal{D} \cup\left\{\left(\mathbf{s}_{t}, \mathbf{a}_{t}, r\left(\mathbf{s}_{t}, \mathbf{a}_{t}\right), \mathbf{s}_{t+1}\right)\right\}$
|
||||
\STATE 更新环境状态$s_{t+1} \leftarrow s_t$
|
||||
\STATE {\bfseries 更新策略:}
|
||||
\STATE $\psi \leftarrow \psi-\lambda_{V} \hat{\nabla}_{\psi} J_{V}(\psi)$
|
||||
\STATE $\theta_{i} \leftarrow \theta_{i}-\lambda_{Q} \hat{\nabla}_{\theta_{i}} J_{Q}\left(\theta_{i}\right)$ for $i \in\{1,2\}$
|
||||
\STATE $\phi \leftarrow \phi-\lambda_{\pi} \hat{\nabla}_{\phi} J_{\pi}(\phi)$
|
||||
\STATE $\bar{\psi} \leftarrow \tau \psi+(1-\tau) \bar{\psi}$
|
||||
\ENDFOR
|
||||
\ENDFOR
|
||||
\end{algorithmic}
|
||||
\end{algorithm}
|
||||
\footnotetext[1]{Soft Actor-Critic: Off-Policy Maximum Entropy Deep Reinforcement Learning with a Stochastic Actor}
|
||||
\clearpage
|
||||
\section{SAC算法}
|
||||
\begin{algorithm}[H] % [H]固定位置
|
||||
\floatname{algorithm}{{Soft Actor Critic算法}}
|
||||
\floatname{algorithm}{{SAC算法}\footnotemark[1]}
|
||||
\renewcommand{\thealgorithm}{} % 去掉算法标号
|
||||
\caption{}
|
||||
\begin{algorithmic}[1]
|
||||
\STATE 初始化两个Actor的网络参数$\theta_1,\theta_2$以及一个Critic网络参数$\phi$ % 初始化
|
||||
\STATE 初始化网络参数$\theta_1,\theta_2$以及$\phi$ % 初始化
|
||||
\STATE 复制参数到目标网络$\bar{\theta_1} \leftarrow \theta_1,\bar{\theta_2} \leftarrow \theta_2,$
|
||||
\STATE 初始化经验回放$D$
|
||||
\FOR {回合数 = $1,M$}
|
||||
@@ -170,18 +199,18 @@
|
||||
\STATE 存储transition到经验回放中,$\mathcal{D} \leftarrow \mathcal{D} \cup\left\{\left(\mathbf{s}_{t}, \mathbf{a}_{t}, r\left(\mathbf{s}_{t}, \mathbf{a}_{t}\right), \mathbf{s}_{t+1}\right)\right\}$
|
||||
\STATE 更新环境状态$s_{t+1} \leftarrow s_t$
|
||||
\STATE {\bfseries 更新策略:}
|
||||
\STATE 更新$Q$函数,$\theta_{i} \leftarrow \theta_{i}-\lambda_{Q} \hat{\nabla}_{\theta_{i}} J_{Q}\left(\theta_{i}\right)$ for $i \in\{1,2\}$\footnotemark[1]\footnotemark[2]
|
||||
\STATE 更新策略权重,$\phi \leftarrow \phi-\lambda_{\pi} \hat{\nabla}_{\phi} J_{\pi}(\phi)$ \footnotemark[3]
|
||||
\STATE 调整temperature,$\alpha \leftarrow \alpha-\lambda \hat{\nabla}_{\alpha} J(\alpha)$ \footnotemark[4]
|
||||
\STATE 更新$Q$函数,$\theta_{i} \leftarrow \theta_{i}-\lambda_{Q} \hat{\nabla}_{\theta_{i}} J_{Q}\left(\theta_{i}\right)$ for $i \in\{1,2\}$\footnotemark[2]\footnotemark[3]
|
||||
\STATE 更新策略权重,$\phi \leftarrow \phi-\lambda_{\pi} \hat{\nabla}_{\phi} J_{\pi}(\phi)$ \footnotemark[4]
|
||||
\STATE 调整temperature,$\alpha \leftarrow \alpha-\lambda \hat{\nabla}_{\alpha} J(\alpha)$ \footnotemark[5]
|
||||
\STATE 更新目标网络权重,$\bar{\theta}_{i} \leftarrow \tau \theta_{i}+(1-\tau) \bar{\theta}_{i}$ for $i \in\{1,2\}$
|
||||
\ENDFOR
|
||||
\ENDFOR
|
||||
\end{algorithmic}
|
||||
|
||||
\end{algorithmic}
|
||||
\end{algorithm}
|
||||
\footnotetext[1]{$J_{Q}(\theta)=\mathbb{E}_{\left(\mathbf{s}_{t}, \mathbf{a}_{t}\right) \sim \mathcal{D}}\left[\frac{1}{2}\left(Q_{\theta}\left(\mathbf{s}_{t}, \mathbf{a}_{t}\right)-\left(r\left(\mathbf{s}_{t}, \mathbf{a}_{t}\right)+\gamma \mathbb{E}_{\mathbf{s}_{t+1} \sim p}\left[V_{\bar{\theta}}\left(\mathbf{s}_{t+1}\right)\right]\right)\right)^{2}\right]$}
|
||||
\footnotetext[2]{$\hat{\nabla}_{\theta} J_{Q}(\theta)=\nabla_{\theta} Q_{\theta}\left(\mathbf{a}_{t}, \mathbf{s}_{t}\right)\left(Q_{\theta}\left(\mathbf{s}_{t}, \mathbf{a}_{t}\right)-\left(r\left(\mathbf{s}_{t}, \mathbf{a}_{t}\right)+\gamma\left(Q_{\bar{\theta}}\left(\mathbf{s}_{t+1}, \mathbf{a}_{t+1}\right)-\alpha \log \left(\pi_{\phi}\left(\mathbf{a}_{t+1} \mid \mathbf{s}_{t+1}\right)\right)\right)\right)\right.$}
|
||||
\footnotetext[3]{$\hat{\nabla}_{\phi} J_{\pi}(\phi)=\nabla_{\phi} \alpha \log \left(\pi_{\phi}\left(\mathbf{a}_{t} \mid \mathbf{s}_{t}\right)\right)+\left(\nabla_{\mathbf{a}_{t}} \alpha \log \left(\pi_{\phi}\left(\mathbf{a}_{t} \mid \mathbf{s}_{t}\right)\right)-\nabla_{\mathbf{a}_{t}} Q\left(\mathbf{s}_{t}, \mathbf{a}_{t}\right)\right) \nabla_{\phi} f_{\phi}\left(\epsilon_{t} ; \mathbf{s}_{t}\right)$,$\mathbf{a}_{t}=f_{\phi}\left(\epsilon_{t} ; \mathbf{s}_{t}\right)$}
|
||||
\footnotetext[4]{$J(\alpha)=\mathbb{E}_{\mathbf{a}_{t} \sim \pi_{t}}\left[-\alpha \log \pi_{t}\left(\mathbf{a}_{t} \mid \mathbf{s}_{t}\right)-\alpha \overline{\mathcal{H}}\right]$}
|
||||
\footnotetext[2]{Soft Actor-Critic Algorithms and Applications}
|
||||
\footnotetext[2]{$J_{Q}(\theta)=\mathbb{E}_{\left(\mathbf{s}_{t}, \mathbf{a}_{t}\right) \sim \mathcal{D}}\left[\frac{1}{2}\left(Q_{\theta}\left(\mathbf{s}_{t}, \mathbf{a}_{t}\right)-\left(r\left(\mathbf{s}_{t}, \mathbf{a}_{t}\right)+\gamma \mathbb{E}_{\mathbf{s}_{t+1} \sim p}\left[V_{\bar{\theta}}\left(\mathbf{s}_{t+1}\right)\right]\right)\right)^{2}\right]$}
|
||||
\footnotetext[3]{$\hat{\nabla}_{\theta} J_{Q}(\theta)=\nabla_{\theta} Q_{\theta}\left(\mathbf{a}_{t}, \mathbf{s}_{t}\right)\left(Q_{\theta}\left(\mathbf{s}_{t}, \mathbf{a}_{t}\right)-\left(r\left(\mathbf{s}_{t}, \mathbf{a}_{t}\right)+\gamma\left(Q_{\bar{\theta}}\left(\mathbf{s}_{t+1}, \mathbf{a}_{t+1}\right)-\alpha \log \left(\pi_{\phi}\left(\mathbf{a}_{t+1} \mid \mathbf{s}_{t+1}\right)\right)\right)\right)\right.$}
|
||||
\footnotetext[4]{$\hat{\nabla}_{\phi} J_{\pi}(\phi)=\nabla_{\phi} \alpha \log \left(\pi_{\phi}\left(\mathbf{a}_{t} \mid \mathbf{s}_{t}\right)\right)+\left(\nabla_{\mathbf{a}_{t}} \alpha \log \left(\pi_{\phi}\left(\mathbf{a}_{t} \mid \mathbf{s}_{t}\right)\right)-\nabla_{\mathbf{a}_{t}} Q\left(\mathbf{s}_{t}, \mathbf{a}_{t}\right)\right) \nabla_{\phi} f_{\phi}\left(\epsilon_{t} ; \mathbf{s}_{t}\right)$,$\mathbf{a}_{t}=f_{\phi}\left(\epsilon_{t} ; \mathbf{s}_{t}\right)$}
|
||||
\footnotetext[5]{$J(\alpha)=\mathbb{E}_{\mathbf{a}_{t} \sim \pi_{t}}\left[-\alpha \log \pi_{t}\left(\mathbf{a}_{t} \mid \mathbf{s}_{t}\right)-\alpha \overline{\mathcal{H}}\right]$}
|
||||
\clearpage
|
||||
\end{document}
|
||||
@@ -4,4 +4,5 @@
|
||||
\contentsline {section}{\numberline {4}Policy Gradient算法}{5}{section.4}%
|
||||
\contentsline {section}{\numberline {5}DQN算法}{6}{section.5}%
|
||||
\contentsline {section}{\numberline {6}SoftQ算法}{7}{section.6}%
|
||||
\contentsline {section}{\numberline {7}SAC算法}{8}{section.7}%
|
||||
\contentsline {section}{\numberline {7}SAC-S算法}{8}{section.7}%
|
||||
\contentsline {section}{\numberline {8}SAC算法}{9}{section.8}%
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
@Email: johnjim0816@gmail.com
|
||||
@Date: 2020-06-12 00:50:49
|
||||
@LastEditor: John
|
||||
LastEditTime: 2022-08-18 14:27:18
|
||||
LastEditTime: 2022-08-23 23:59:54
|
||||
@Discription:
|
||||
@Environment: python 3.7.7
|
||||
'''
|
||||
@@ -20,26 +20,26 @@ import math
|
||||
import numpy as np
|
||||
|
||||
class DQN:
|
||||
def __init__(self,n_actions,model,memory,cfg):
|
||||
def __init__(self,model,memory,cfg):
|
||||
|
||||
self.n_actions = n_actions
|
||||
self.device = torch.device(cfg.device)
|
||||
self.gamma = cfg.gamma
|
||||
self.n_actions = cfg['n_actions']
|
||||
self.device = torch.device(cfg['device'])
|
||||
self.gamma = cfg['gamma']
|
||||
## e-greedy parameters
|
||||
self.sample_count = 0 # sample count for epsilon decay
|
||||
self.epsilon = cfg.epsilon_start
|
||||
self.epsilon = cfg['epsilon_start']
|
||||
self.sample_count = 0
|
||||
self.epsilon_start = cfg.epsilon_start
|
||||
self.epsilon_end = cfg.epsilon_end
|
||||
self.epsilon_decay = cfg.epsilon_decay
|
||||
self.batch_size = cfg.batch_size
|
||||
self.epsilon_start = cfg['epsilon_start']
|
||||
self.epsilon_end = cfg['epsilon_end']
|
||||
self.epsilon_decay = cfg['epsilon_decay']
|
||||
self.batch_size = cfg['batch_size']
|
||||
self.policy_net = model.to(self.device)
|
||||
self.target_net = model.to(self.device)
|
||||
## copy parameters from policy net to target net
|
||||
for target_param, param in zip(self.target_net.parameters(),self.policy_net.parameters()):
|
||||
target_param.data.copy_(param.data)
|
||||
# self.target_net.load_state_dict(self.policy_net.state_dict()) # or use this to copy parameters
|
||||
self.optimizer = optim.Adam(self.policy_net.parameters(), lr=cfg.lr)
|
||||
self.optimizer = optim.Adam(self.policy_net.parameters(), lr=cfg['lr'])
|
||||
self.memory = memory
|
||||
self.update_flag = False
|
||||
|
||||
|
||||
137
projects/codes/DQN/main.py
Normal file
@@ -0,0 +1,137 @@
|
||||
import sys,os
|
||||
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE" # avoid "OMP: Error #15: Initializing libiomp5md.dll, but found libiomp5md.dll already initialized."
|
||||
curr_path = os.path.dirname(os.path.abspath(__file__)) # current path
|
||||
parent_path = os.path.dirname(curr_path) # parent path
|
||||
sys.path.append(parent_path) # add path to system path
|
||||
|
||||
import gym
|
||||
import torch
|
||||
import datetime
|
||||
import numpy as np
|
||||
import argparse
|
||||
from common.utils import save_results,all_seed
|
||||
from common.utils import plot_rewards,save_args
|
||||
from common.models import MLP
|
||||
from common.memories import ReplayBuffer
|
||||
from dqn import DQN
|
||||
|
||||
def get_args():
|
||||
""" hyperparameters
|
||||
"""
|
||||
curr_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") # obtain current time
|
||||
parser = argparse.ArgumentParser(description="hyperparameters")
|
||||
parser.add_argument('--algo_name',default='DQN',type=str,help="name of algorithm")
|
||||
parser.add_argument('--env_name',default='CartPole-v0',type=str,help="name of environment")
|
||||
parser.add_argument('--train_eps',default=200,type=int,help="episodes of training")
|
||||
parser.add_argument('--test_eps',default=20,type=int,help="episodes of testing")
|
||||
parser.add_argument('--gamma',default=0.95,type=float,help="discounted factor")
|
||||
parser.add_argument('--epsilon_start',default=0.95,type=float,help="initial value of epsilon")
|
||||
parser.add_argument('--epsilon_end',default=0.01,type=float,help="final value of epsilon")
|
||||
parser.add_argument('--epsilon_decay',default=500,type=int,help="decay rate of epsilon")
|
||||
parser.add_argument('--lr',default=0.0001,type=float,help="learning rate")
|
||||
parser.add_argument('--memory_capacity',default=100000,type=int,help="memory capacity")
|
||||
parser.add_argument('--batch_size',default=64,type=int)
|
||||
parser.add_argument('--target_update',default=4,type=int)
|
||||
parser.add_argument('--hidden_dim',default=256,type=int)
|
||||
parser.add_argument('--device',default='cpu',type=str,help="cpu or cuda")
|
||||
parser.add_argument('--seed',default=10,type=int,help="seed")
|
||||
parser.add_argument('--show_fig',default=False,type=bool,help="if show figure or not")
|
||||
parser.add_argument('--save_fig',default=True,type=bool,help="if save figure or not")
|
||||
# please manually change the following args in this script if you want
|
||||
parser.add_argument('--result_path',default=curr_path + "/outputs/" + parser.parse_args().env_name + \
|
||||
'/' + curr_time + '/results' )
|
||||
parser.add_argument('--model_path',default=curr_path + "/outputs/" + parser.parse_args().env_name + \
|
||||
'/' + curr_time + '/models' )
|
||||
args = parser.parse_args()
|
||||
args = {**vars(args)} # type(dict)
|
||||
return args
|
||||
|
||||
def env_agent_config(cfg):
|
||||
''' create env and agent
|
||||
'''
|
||||
env = gym.make(cfg['env_name']) # create env
|
||||
if cfg['seed'] !=0: # set random seed
|
||||
all_seed(env,seed=cfg["seed"])
|
||||
n_states = env.observation_space.shape[0] # state dimension
|
||||
n_actions = env.action_space.n # action dimension
|
||||
print(f"n_states: {n_states}, n_actions: {n_actions}")
|
||||
cfg.update({"n_states":n_states,"n_actions":n_actions}) # update to cfg paramters
|
||||
model = MLP(n_states,n_actions,hidden_dim=cfg["hidden_dim"])
|
||||
memory = ReplayBuffer(cfg["memory_capacity"]) # replay buffer
|
||||
agent = DQN(model,memory,cfg) # create agent
|
||||
return env, agent
|
||||
|
||||
def train(cfg, env, agent):
|
||||
''' 训练
|
||||
'''
|
||||
print("start training!")
|
||||
print(f"Env: {cfg['env_name']}, Algo: {cfg['algo_name']}, Device: {cfg['device']}")
|
||||
rewards = [] # record rewards for all episodes
|
||||
steps = []
|
||||
for i_ep in range(cfg["train_eps"]):
|
||||
ep_reward = 0 # reward per episode
|
||||
ep_step = 0
|
||||
state = env.reset() # reset and obtain initial state
|
||||
while True:
|
||||
ep_step += 1
|
||||
action = agent.sample_action(state) # sample action
|
||||
next_state, reward, done, _ = env.step(action) # update env and return transitions
|
||||
agent.memory.push(state, action, reward,
|
||||
next_state, done) # save transitions
|
||||
state = next_state # update next state for env
|
||||
agent.update() # update agent
|
||||
ep_reward += reward #
|
||||
if done:
|
||||
break
|
||||
if (i_ep + 1) % cfg["target_update"] == 0: # target net update, target_update means "C" in pseucodes
|
||||
agent.target_net.load_state_dict(agent.policy_net.state_dict())
|
||||
steps.append(ep_step)
|
||||
rewards.append(ep_reward)
|
||||
if (i_ep + 1) % 10 == 0:
|
||||
print(f'Episode: {i_ep+1}/{cfg["train_eps"]}, Reward: {ep_reward:.2f}: Epislon: {agent.epsilon:.3f}')
|
||||
print("finish training!")
|
||||
env.close()
|
||||
res_dic = {'episodes':range(len(rewards)),'rewards':rewards}
|
||||
return res_dic
|
||||
|
||||
def test(cfg, env, agent):
|
||||
print("start testing!")
|
||||
print(f"Env: {cfg.env_name}, Algo: {cfg.algo_name}, Device: {cfg.device}")
|
||||
rewards = [] # record rewards for all episodes
|
||||
steps = []
|
||||
for i_ep in range(cfg.test_eps):
|
||||
ep_reward = 0 # reward per episode
|
||||
ep_step = 0
|
||||
state = env.reset() # reset and obtain initial state
|
||||
while True:
|
||||
ep_step+=1
|
||||
action = agent.predict_action(state) # predict action
|
||||
next_state, reward, done, _ = env.step(action)
|
||||
state = next_state
|
||||
ep_reward += reward
|
||||
if done:
|
||||
break
|
||||
steps.append(ep_step)
|
||||
rewards.append(ep_reward)
|
||||
print(f'Episode: {i_ep+1}/{cfg.test_eps},Reward: {ep_reward:.2f}')
|
||||
print("finish testing!")
|
||||
env.close()
|
||||
return {'episodes':range(len(rewards)),'rewards':rewards}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
cfg = get_args()
|
||||
# training
|
||||
env, agent = env_agent_config(cfg)
|
||||
res_dic = train(cfg, env, agent)
|
||||
save_args(cfg,path = cfg['result_path']) # save parameters
|
||||
agent.save_model(path = cfg['model_path']) # save models
|
||||
save_results(res_dic, tag = 'train', path = cfg['result_path']) # save results
|
||||
plot_rewards(res_dic['rewards'], cfg, path = cfg['result_path'],tag = "train") # plot results
|
||||
# testing
|
||||
env, agent = env_agent_config(cfg) # create new env for testing, sometimes can ignore this step
|
||||
agent.load_model(path = cfg['model_path']) # load model
|
||||
res_dic = test(cfg, env, agent)
|
||||
save_results(res_dic, tag='test',
|
||||
path = cfg['result_path'])
|
||||
plot_rewards(res_dic['rewards'], cfg, path = cfg['result_path'],tag = "test")
|
||||
@@ -0,0 +1 @@
|
||||
{"algo_name": "DQN", "env_name": "CartPole-v0", "train_eps": 200, "test_eps": 20, "gamma": 0.95, "epsilon_start": 0.95, "epsilon_end": 0.01, "epsilon_decay": 500, "lr": 0.0001, "memory_capacity": 100000, "batch_size": 64, "target_update": 4, "hidden_dim": 256, "device": "cpu", "seed": 10, "result_path": "C:\\Users\\jiangji\\Desktop\\rl-tutorials\\codes\\DQN/outputs/CartPole-v0/20220823-173936/results", "model_path": "C:\\Users\\jiangji\\Desktop\\rl-tutorials\\codes\\DQN/outputs/CartPole-v0/20220823-173936/models", "show_fig": false, "save_fig": true}
|
||||
|
After Width: | Height: | Size: 27 KiB |
@@ -0,0 +1,21 @@
|
||||
episodes,rewards
|
||||
0,200.0
|
||||
1,200.0
|
||||
2,200.0
|
||||
3,200.0
|
||||
4,200.0
|
||||
5,200.0
|
||||
6,200.0
|
||||
7,200.0
|
||||
8,200.0
|
||||
9,200.0
|
||||
10,200.0
|
||||
11,200.0
|
||||
12,200.0
|
||||
13,200.0
|
||||
14,200.0
|
||||
15,200.0
|
||||
16,200.0
|
||||
17,200.0
|
||||
18,200.0
|
||||
19,200.0
|
||||
|
|
After Width: | Height: | Size: 38 KiB |
@@ -0,0 +1,201 @@
|
||||
episodes,rewards
|
||||
0,38.0
|
||||
1,16.0
|
||||
2,37.0
|
||||
3,15.0
|
||||
4,22.0
|
||||
5,34.0
|
||||
6,20.0
|
||||
7,12.0
|
||||
8,16.0
|
||||
9,14.0
|
||||
10,13.0
|
||||
11,21.0
|
||||
12,14.0
|
||||
13,12.0
|
||||
14,17.0
|
||||
15,12.0
|
||||
16,10.0
|
||||
17,14.0
|
||||
18,10.0
|
||||
19,10.0
|
||||
20,16.0
|
||||
21,9.0
|
||||
22,14.0
|
||||
23,13.0
|
||||
24,10.0
|
||||
25,9.0
|
||||
26,12.0
|
||||
27,12.0
|
||||
28,14.0
|
||||
29,11.0
|
||||
30,9.0
|
||||
31,8.0
|
||||
32,9.0
|
||||
33,11.0
|
||||
34,12.0
|
||||
35,10.0
|
||||
36,11.0
|
||||
37,10.0
|
||||
38,10.0
|
||||
39,18.0
|
||||
40,13.0
|
||||
41,15.0
|
||||
42,10.0
|
||||
43,9.0
|
||||
44,14.0
|
||||
45,14.0
|
||||
46,23.0
|
||||
47,17.0
|
||||
48,15.0
|
||||
49,15.0
|
||||
50,20.0
|
||||
51,28.0
|
||||
52,36.0
|
||||
53,36.0
|
||||
54,23.0
|
||||
55,27.0
|
||||
56,53.0
|
||||
57,19.0
|
||||
58,35.0
|
||||
59,62.0
|
||||
60,57.0
|
||||
61,38.0
|
||||
62,61.0
|
||||
63,65.0
|
||||
64,58.0
|
||||
65,43.0
|
||||
66,67.0
|
||||
67,56.0
|
||||
68,91.0
|
||||
69,128.0
|
||||
70,71.0
|
||||
71,126.0
|
||||
72,100.0
|
||||
73,200.0
|
||||
74,200.0
|
||||
75,200.0
|
||||
76,200.0
|
||||
77,200.0
|
||||
78,200.0
|
||||
79,200.0
|
||||
80,200.0
|
||||
81,200.0
|
||||
82,200.0
|
||||
83,200.0
|
||||
84,200.0
|
||||
85,200.0
|
||||
86,200.0
|
||||
87,200.0
|
||||
88,200.0
|
||||
89,200.0
|
||||
90,200.0
|
||||
91,200.0
|
||||
92,200.0
|
||||
93,200.0
|
||||
94,200.0
|
||||
95,200.0
|
||||
96,200.0
|
||||
97,200.0
|
||||
98,200.0
|
||||
99,200.0
|
||||
100,200.0
|
||||
101,200.0
|
||||
102,200.0
|
||||
103,200.0
|
||||
104,200.0
|
||||
105,200.0
|
||||
106,200.0
|
||||
107,200.0
|
||||
108,200.0
|
||||
109,200.0
|
||||
110,200.0
|
||||
111,200.0
|
||||
112,200.0
|
||||
113,200.0
|
||||
114,200.0
|
||||
115,200.0
|
||||
116,200.0
|
||||
117,200.0
|
||||
118,200.0
|
||||
119,200.0
|
||||
120,200.0
|
||||
121,200.0
|
||||
122,200.0
|
||||
123,200.0
|
||||
124,200.0
|
||||
125,200.0
|
||||
126,200.0
|
||||
127,200.0
|
||||
128,200.0
|
||||
129,200.0
|
||||
130,200.0
|
||||
131,200.0
|
||||
132,200.0
|
||||
133,200.0
|
||||
134,200.0
|
||||
135,200.0
|
||||
136,200.0
|
||||
137,200.0
|
||||
138,200.0
|
||||
139,200.0
|
||||
140,200.0
|
||||
141,200.0
|
||||
142,200.0
|
||||
143,200.0
|
||||
144,200.0
|
||||
145,200.0
|
||||
146,200.0
|
||||
147,200.0
|
||||
148,200.0
|
||||
149,200.0
|
||||
150,200.0
|
||||
151,200.0
|
||||
152,200.0
|
||||
153,200.0
|
||||
154,200.0
|
||||
155,200.0
|
||||
156,200.0
|
||||
157,200.0
|
||||
158,200.0
|
||||
159,200.0
|
||||
160,200.0
|
||||
161,200.0
|
||||
162,200.0
|
||||
163,200.0
|
||||
164,200.0
|
||||
165,200.0
|
||||
166,200.0
|
||||
167,200.0
|
||||
168,200.0
|
||||
169,200.0
|
||||
170,200.0
|
||||
171,200.0
|
||||
172,200.0
|
||||
173,200.0
|
||||
174,200.0
|
||||
175,200.0
|
||||
176,200.0
|
||||
177,200.0
|
||||
178,200.0
|
||||
179,200.0
|
||||
180,200.0
|
||||
181,200.0
|
||||
182,200.0
|
||||
183,200.0
|
||||
184,200.0
|
||||
185,200.0
|
||||
186,200.0
|
||||
187,200.0
|
||||
188,200.0
|
||||
189,200.0
|
||||
190,200.0
|
||||
191,200.0
|
||||
192,200.0
|
||||
193,200.0
|
||||
194,200.0
|
||||
195,200.0
|
||||
196,200.0
|
||||
197,200.0
|
||||
198,200.0
|
||||
199,200.0
|
||||
|
153
projects/codes/QLearning/main.py
Normal file
@@ -0,0 +1,153 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
'''
|
||||
Author: John
|
||||
Email: johnjim0816@gmail.com
|
||||
Date: 2020-09-11 23:03:00
|
||||
LastEditor: John
|
||||
LastEditTime: 2022-08-24 11:27:01
|
||||
Discription:
|
||||
Environment:
|
||||
'''
|
||||
import sys,os
|
||||
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE" # avoid "OMP: Error #15: Initializing libiomp5md.dll, but found libiomp5md.dll already initialized."
|
||||
curr_path = os.path.dirname(os.path.abspath(__file__)) # current path
|
||||
parent_path = os.path.dirname(curr_path) # parent path
|
||||
sys.path.append(parent_path) # add path to system path
|
||||
|
||||
import gym
|
||||
import datetime
|
||||
import argparse
|
||||
from envs.gridworld_env import CliffWalkingWapper,FrozenLakeWapper
|
||||
from qlearning import QLearning
|
||||
from common.utils import plot_rewards,save_args,all_seed
|
||||
from common.utils import save_results,make_dir
|
||||
|
||||
def get_args():
|
||||
curr_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") # obtain current time
|
||||
parser = argparse.ArgumentParser(description="hyperparameters")
|
||||
parser.add_argument('--algo_name',default='Q-learning',type=str,help="name of algorithm")
|
||||
parser.add_argument('--env_name',default='CliffWalking-v0',type=str,help="name of environment")
|
||||
parser.add_argument('--train_eps',default=400,type=int,help="episodes of training")
|
||||
parser.add_argument('--test_eps',default=20,type=int,help="episodes of testing")
|
||||
parser.add_argument('--gamma',default=0.90,type=float,help="discounted factor")
|
||||
parser.add_argument('--epsilon_start',default=0.95,type=float,help="initial value of epsilon")
|
||||
parser.add_argument('--epsilon_end',default=0.01,type=float,help="final value of epsilon")
|
||||
parser.add_argument('--epsilon_decay',default=300,type=int,help="decay rate of epsilon")
|
||||
parser.add_argument('--lr',default=0.1,type=float,help="learning rate")
|
||||
parser.add_argument('--device',default='cpu',type=str,help="cpu or cuda")
|
||||
parser.add_argument('--seed',default=10,type=int,help="seed")
|
||||
parser.add_argument('--show_fig',default=False,type=bool,help="if show figure or not")
|
||||
parser.add_argument('--save_fig',default=True,type=bool,help="if save figure or not")
|
||||
args = parser.parse_args()
|
||||
default_args = {'result_path':f"{curr_path}/outputs/{args.env_name}/{curr_time}/results/",
|
||||
'model_path':f"{curr_path}/outputs/{args.env_name}/{curr_time}/models/",
|
||||
}
|
||||
args = {**vars(args),**default_args} # type(dict)
|
||||
return args
|
||||
def env_agent_config(cfg):
|
||||
''' create env and agent
|
||||
'''
|
||||
if cfg['env_name'] == 'CliffWalking-v0':
|
||||
env = gym.make(cfg['env_name'])
|
||||
env = CliffWalkingWapper(env)
|
||||
if cfg['env_name'] == 'FrozenLake-v1':
|
||||
env = gym.make(cfg['env_name'],is_slippery=False)
|
||||
if cfg['seed'] !=0: # set random seed
|
||||
all_seed(env,seed=cfg["seed"])
|
||||
n_states = env.observation_space.n # state dimension
|
||||
n_actions = env.action_space.n # action dimension
|
||||
print(f"n_states: {n_states}, n_actions: {n_actions}")
|
||||
cfg.update({"n_states":n_states,"n_actions":n_actions}) # update to cfg paramters
|
||||
agent = QLearning(cfg)
|
||||
return env,agent
|
||||
|
||||
def main(cfg,env,agent,tag = 'train'):
|
||||
print(f"Start {tag}ing!")
|
||||
print(f"Env: {cfg['env_name']}, Algorithm: {cfg['algo_name']}, Device: {cfg['device']}")
|
||||
rewards = [] # 记录奖励
|
||||
for i_ep in range(cfg.train_eps):
|
||||
ep_reward = 0 # 记录每个回合的奖励
|
||||
state = env.reset() # 重置环境,即开始新的回合
|
||||
while True:
|
||||
if tag == 'train':action = agent.sample_action(state) # 根据算法采样一个动作
|
||||
else: agent.predict_action(state)
|
||||
next_state, reward, done, _ = env.step(action) # 与环境进行一次动作交互
|
||||
if tag == 'train':agent.update(state, action, reward, next_state, done) # Q学习算法更新
|
||||
state = next_state # 更新状态
|
||||
ep_reward += reward
|
||||
if done:
|
||||
break
|
||||
rewards.append(ep_reward)
|
||||
print(f"回合:{i_ep+1}/{cfg.train_eps},奖励:{ep_reward:.1f},Epsilon:{agent.epsilon}")
|
||||
print(f"Finish {tag}ing!")
|
||||
return {"rewards":rewards}
|
||||
|
||||
def train(cfg,env,agent):
|
||||
print("Start training!")
|
||||
print(f"Env: {cfg['env_name']}, Algorithm: {cfg['algo_name']}, Device: {cfg['device']}")
|
||||
rewards = [] # record rewards for all episodes
|
||||
steps = [] # record steps for all episodes
|
||||
for i_ep in range(cfg['train_eps']):
|
||||
ep_reward = 0 # reward per episode
|
||||
ep_step = 0 # step per episode
|
||||
state = env.reset() # reset and obtain initial state
|
||||
while True:
|
||||
action = agent.sample_action(state) # sample action
|
||||
next_state, reward, done, _ = env.step(action) # update env and return transitions
|
||||
agent.update(state, action, reward, next_state, done) # update agent
|
||||
state = next_state # update state
|
||||
ep_reward += reward
|
||||
ep_step += 1
|
||||
if done:
|
||||
break
|
||||
rewards.append(ep_reward)
|
||||
steps.append(ep_step)
|
||||
if (i_ep+1)%10==0:
|
||||
print(f'Episode: {i_ep+1}/{cfg["train_eps"]}, Reward: {ep_reward:.2f}, Steps:{ep_step}, Epislon: {agent.epsilon:.3f}')
|
||||
print("Finish training!")
|
||||
return {'episodes':range(len(rewards)),'rewards':rewards,'steps':steps}
|
||||
|
||||
def test(cfg,env,agent):
|
||||
print("Start testing!")
|
||||
print(f"Env: {cfg['env_name']}, Algorithm: {cfg['algo_name']}, Device: {cfg['device']}")
|
||||
rewards = [] # record rewards for all episodes
|
||||
steps = [] # record steps for all episodes
|
||||
for i_ep in range(cfg['test_eps']):
|
||||
ep_reward = 0 # reward per episode
|
||||
ep_step = 0
|
||||
state = env.reset() # reset and obtain initial state
|
||||
while True:
|
||||
action = agent.predict_action(state) # predict action
|
||||
next_state, reward, done, _ = env.step(action)
|
||||
state = next_state
|
||||
ep_reward += reward
|
||||
ep_step += 1
|
||||
if done:
|
||||
break
|
||||
rewards.append(ep_reward)
|
||||
steps.append(ep_step)
|
||||
print(f"Episode: {i_ep+1}/{cfg['test_eps']}, Steps:{ep_step}, Reward: {ep_reward:.2f}")
|
||||
print("Finish testing!")
|
||||
return {'episodes':range(len(rewards)),'rewards':rewards,'steps':steps}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
cfg = get_args()
|
||||
# training
|
||||
env, agent = env_agent_config(cfg)
|
||||
res_dic = train(cfg, env, agent)
|
||||
save_args(cfg,path = cfg['result_path']) # save parameters
|
||||
agent.save_model(path = cfg['model_path']) # save models
|
||||
save_results(res_dic, tag = 'train', path = cfg['result_path']) # save results
|
||||
plot_rewards(res_dic['rewards'], cfg, path = cfg['result_path'],tag = "train") # plot results
|
||||
# testing
|
||||
env, agent = env_agent_config(cfg) # create new env for testing, sometimes can ignore this step
|
||||
agent.load_model(path = cfg['model_path']) # load model
|
||||
res_dic = test(cfg, env, agent)
|
||||
save_results(res_dic, tag='test',
|
||||
path = cfg['result_path'])
|
||||
plot_rewards(res_dic['rewards'], cfg, path = cfg['result_path'],tag = "test")
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,19 @@
|
||||
{
|
||||
"algo_name": "Q-learning",
|
||||
"env_name": "CliffWalking-v0",
|
||||
"train_eps": 400,
|
||||
"test_eps": 20,
|
||||
"gamma": 0.9,
|
||||
"epsilon_start": 0.95,
|
||||
"epsilon_end": 0.01,
|
||||
"epsilon_decay": 300,
|
||||
"lr": 0.1,
|
||||
"device": "cpu",
|
||||
"seed": 10,
|
||||
"show_fig": false,
|
||||
"save_fig": true,
|
||||
"result_path": "/Users/jj/Desktop/rl-tutorials/codes/QLearning/outputs/CliffWalking-v0/20220824-103255/results/",
|
||||
"model_path": "/Users/jj/Desktop/rl-tutorials/codes/QLearning/outputs/CliffWalking-v0/20220824-103255/models/",
|
||||
"n_states": 48,
|
||||
"n_actions": 4
|
||||
}
|
||||
|
After Width: | Height: | Size: 24 KiB |
@@ -0,0 +1,21 @@
|
||||
episodes,rewards
|
||||
0,-13
|
||||
1,-13
|
||||
2,-13
|
||||
3,-13
|
||||
4,-13
|
||||
5,-13
|
||||
6,-13
|
||||
7,-13
|
||||
8,-13
|
||||
9,-13
|
||||
10,-13
|
||||
11,-13
|
||||
12,-13
|
||||
13,-13
|
||||
14,-13
|
||||
15,-13
|
||||
16,-13
|
||||
17,-13
|
||||
18,-13
|
||||
19,-13
|
||||
|
|
After Width: | Height: | Size: 35 KiB |
@@ -0,0 +1,401 @@
|
||||
episodes,rewards
|
||||
0,-2131
|
||||
1,-1086
|
||||
2,-586
|
||||
3,-220
|
||||
4,-154
|
||||
5,-122
|
||||
6,-150
|
||||
7,-159
|
||||
8,-164
|
||||
9,-88
|
||||
10,-195
|
||||
11,-114
|
||||
12,-60
|
||||
13,-179
|
||||
14,-101
|
||||
15,-304
|
||||
16,-96
|
||||
17,-119
|
||||
18,-113
|
||||
19,-98
|
||||
20,-106
|
||||
21,-105
|
||||
22,-77
|
||||
23,-51
|
||||
24,-105
|
||||
25,-136
|
||||
26,-100
|
||||
27,-29
|
||||
28,-79
|
||||
29,-114
|
||||
30,-82
|
||||
31,-70
|
||||
32,-75
|
||||
33,-51
|
||||
34,-94
|
||||
35,-52
|
||||
36,-93
|
||||
37,-71
|
||||
38,-73
|
||||
39,-48
|
||||
40,-52
|
||||
41,-96
|
||||
42,-46
|
||||
43,-65
|
||||
44,-57
|
||||
45,-41
|
||||
46,-104
|
||||
47,-51
|
||||
48,-181
|
||||
49,-229
|
||||
50,-39
|
||||
51,-69
|
||||
52,-53
|
||||
53,-59
|
||||
54,-26
|
||||
55,-75
|
||||
56,-31
|
||||
57,-60
|
||||
58,-63
|
||||
59,-40
|
||||
60,-35
|
||||
61,-79
|
||||
62,-42
|
||||
63,-22
|
||||
64,-73
|
||||
65,-71
|
||||
66,-18
|
||||
67,-55
|
||||
68,-29
|
||||
69,-43
|
||||
70,-70
|
||||
71,-49
|
||||
72,-42
|
||||
73,-29
|
||||
74,-81
|
||||
75,-36
|
||||
76,-38
|
||||
77,-36
|
||||
78,-52
|
||||
79,-28
|
||||
80,-42
|
||||
81,-52
|
||||
82,-66
|
||||
83,-31
|
||||
84,-27
|
||||
85,-49
|
||||
86,-28
|
||||
87,-54
|
||||
88,-34
|
||||
89,-35
|
||||
90,-50
|
||||
91,-36
|
||||
92,-36
|
||||
93,-46
|
||||
94,-34
|
||||
95,-135
|
||||
96,-39
|
||||
97,-36
|
||||
98,-26
|
||||
99,-56
|
||||
100,-40
|
||||
101,-40
|
||||
102,-26
|
||||
103,-28
|
||||
104,-31
|
||||
105,-35
|
||||
106,-26
|
||||
107,-57
|
||||
108,-44
|
||||
109,-41
|
||||
110,-31
|
||||
111,-26
|
||||
112,-25
|
||||
113,-41
|
||||
114,-32
|
||||
115,-44
|
||||
116,-30
|
||||
117,-32
|
||||
118,-30
|
||||
119,-25
|
||||
120,-23
|
||||
121,-47
|
||||
122,-24
|
||||
123,-45
|
||||
124,-39
|
||||
125,-21
|
||||
126,-43
|
||||
127,-143
|
||||
128,-26
|
||||
129,-20
|
||||
130,-32
|
||||
131,-16
|
||||
132,-24
|
||||
133,-42
|
||||
134,-25
|
||||
135,-36
|
||||
136,-19
|
||||
137,-29
|
||||
138,-43
|
||||
139,-17
|
||||
140,-150
|
||||
141,-32
|
||||
142,-34
|
||||
143,-19
|
||||
144,-26
|
||||
145,-30
|
||||
146,-31
|
||||
147,-49
|
||||
148,-33
|
||||
149,-21
|
||||
150,-17
|
||||
151,-48
|
||||
152,-34
|
||||
153,-20
|
||||
154,-20
|
||||
155,-26
|
||||
156,-21
|
||||
157,-13
|
||||
158,-40
|
||||
159,-22
|
||||
160,-26
|
||||
161,-30
|
||||
162,-29
|
||||
163,-25
|
||||
164,-26
|
||||
165,-27
|
||||
166,-21
|
||||
167,-29
|
||||
168,-24
|
||||
169,-17
|
||||
170,-22
|
||||
171,-35
|
||||
172,-35
|
||||
173,-18
|
||||
174,-135
|
||||
175,-15
|
||||
176,-23
|
||||
177,-28
|
||||
178,-25
|
||||
179,-24
|
||||
180,-29
|
||||
181,-31
|
||||
182,-24
|
||||
183,-129
|
||||
184,-45
|
||||
185,-24
|
||||
186,-17
|
||||
187,-20
|
||||
188,-21
|
||||
189,-23
|
||||
190,-15
|
||||
191,-32
|
||||
192,-22
|
||||
193,-19
|
||||
194,-17
|
||||
195,-45
|
||||
196,-15
|
||||
197,-14
|
||||
198,-14
|
||||
199,-37
|
||||
200,-23
|
||||
201,-17
|
||||
202,-19
|
||||
203,-21
|
||||
204,-23
|
||||
205,-27
|
||||
206,-14
|
||||
207,-18
|
||||
208,-23
|
||||
209,-34
|
||||
210,-23
|
||||
211,-13
|
||||
212,-25
|
||||
213,-17
|
||||
214,-13
|
||||
215,-21
|
||||
216,-29
|
||||
217,-18
|
||||
218,-24
|
||||
219,-15
|
||||
220,-27
|
||||
221,-25
|
||||
222,-21
|
||||
223,-19
|
||||
224,-17
|
||||
225,-18
|
||||
226,-13
|
||||
227,-22
|
||||
228,-14
|
||||
229,-13
|
||||
230,-29
|
||||
231,-23
|
||||
232,-15
|
||||
233,-15
|
||||
234,-14
|
||||
235,-28
|
||||
236,-25
|
||||
237,-17
|
||||
238,-23
|
||||
239,-29
|
||||
240,-15
|
||||
241,-14
|
||||
242,-15
|
||||
243,-23
|
||||
244,-15
|
||||
245,-16
|
||||
246,-19
|
||||
247,-13
|
||||
248,-16
|
||||
249,-17
|
||||
250,-25
|
||||
251,-30
|
||||
252,-13
|
||||
253,-14
|
||||
254,-15
|
||||
255,-22
|
||||
256,-14
|
||||
257,-17
|
||||
258,-126
|
||||
259,-15
|
||||
260,-21
|
||||
261,-16
|
||||
262,-23
|
||||
263,-14
|
||||
264,-13
|
||||
265,-13
|
||||
266,-19
|
||||
267,-13
|
||||
268,-19
|
||||
269,-17
|
||||
270,-17
|
||||
271,-13
|
||||
272,-19
|
||||
273,-13
|
||||
274,-13
|
||||
275,-16
|
||||
276,-22
|
||||
277,-14
|
||||
278,-15
|
||||
279,-19
|
||||
280,-34
|
||||
281,-13
|
||||
282,-15
|
||||
283,-32
|
||||
284,-13
|
||||
285,-13
|
||||
286,-13
|
||||
287,-14
|
||||
288,-16
|
||||
289,-13
|
||||
290,-13
|
||||
291,-17
|
||||
292,-13
|
||||
293,-13
|
||||
294,-22
|
||||
295,-14
|
||||
296,-15
|
||||
297,-13
|
||||
298,-13
|
||||
299,-13
|
||||
300,-16
|
||||
301,-13
|
||||
302,-14
|
||||
303,-13
|
||||
304,-13
|
||||
305,-13
|
||||
306,-24
|
||||
307,-13
|
||||
308,-13
|
||||
309,-15
|
||||
310,-13
|
||||
311,-13
|
||||
312,-13
|
||||
313,-15
|
||||
314,-13
|
||||
315,-19
|
||||
316,-15
|
||||
317,-17
|
||||
318,-13
|
||||
319,-13
|
||||
320,-13
|
||||
321,-13
|
||||
322,-13
|
||||
323,-15
|
||||
324,-13
|
||||
325,-13
|
||||
326,-13
|
||||
327,-123
|
||||
328,-13
|
||||
329,-13
|
||||
330,-13
|
||||
331,-13
|
||||
332,-13
|
||||
333,-13
|
||||
334,-13
|
||||
335,-13
|
||||
336,-16
|
||||
337,-13
|
||||
338,-23
|
||||
339,-13
|
||||
340,-13
|
||||
341,-13
|
||||
342,-13
|
||||
343,-13
|
||||
344,-13
|
||||
345,-13
|
||||
346,-13
|
||||
347,-13
|
||||
348,-13
|
||||
349,-13
|
||||
350,-134
|
||||
351,-13
|
||||
352,-13
|
||||
353,-13
|
||||
354,-13
|
||||
355,-13
|
||||
356,-13
|
||||
357,-13
|
||||
358,-13
|
||||
359,-13
|
||||
360,-15
|
||||
361,-13
|
||||
362,-13
|
||||
363,-13
|
||||
364,-13
|
||||
365,-13
|
||||
366,-13
|
||||
367,-13
|
||||
368,-13
|
||||
369,-14
|
||||
370,-13
|
||||
371,-13
|
||||
372,-13
|
||||
373,-13
|
||||
374,-13
|
||||
375,-13
|
||||
376,-13
|
||||
377,-124
|
||||
378,-13
|
||||
379,-13
|
||||
380,-13
|
||||
381,-13
|
||||
382,-13
|
||||
383,-13
|
||||
384,-13
|
||||
385,-13
|
||||
386,-13
|
||||
387,-13
|
||||
388,-13
|
||||
389,-121
|
||||
390,-13
|
||||
391,-13
|
||||
392,-13
|
||||
393,-13
|
||||
394,-13
|
||||
395,-13
|
||||
396,-13
|
||||
397,-13
|
||||
398,-17
|
||||
399,-13
|
||||
|
@@ -0,0 +1,19 @@
|
||||
{
|
||||
"algo_name": "Q-learning",
|
||||
"env_name": "FrozenLake-v1",
|
||||
"train_eps": 800,
|
||||
"test_eps": 20,
|
||||
"gamma": 0.9,
|
||||
"epsilon_start": 0.7,
|
||||
"epsilon_end": 0.1,
|
||||
"epsilon_decay": 2000,
|
||||
"lr": 0.9,
|
||||
"device": "cpu",
|
||||
"seed": 10,
|
||||
"show_fig": false,
|
||||
"save_fig": true,
|
||||
"result_path": "/Users/jj/Desktop/rl-tutorials/codes/QLearning/outputs/FrozenLake-v1/20220824-112735/results/",
|
||||
"model_path": "/Users/jj/Desktop/rl-tutorials/codes/QLearning/outputs/FrozenLake-v1/20220824-112735/models/",
|
||||
"n_states": 16,
|
||||
"n_actions": 4
|
||||
}
|
||||
|
After Width: | Height: | Size: 22 KiB |
@@ -0,0 +1,21 @@
|
||||
episodes,rewards,steps
|
||||
0,1.0,6
|
||||
1,1.0,6
|
||||
2,1.0,6
|
||||
3,1.0,6
|
||||
4,1.0,6
|
||||
5,1.0,6
|
||||
6,1.0,6
|
||||
7,1.0,6
|
||||
8,1.0,6
|
||||
9,1.0,6
|
||||
10,1.0,6
|
||||
11,1.0,6
|
||||
12,1.0,6
|
||||
13,1.0,6
|
||||
14,1.0,6
|
||||
15,1.0,6
|
||||
16,1.0,6
|
||||
17,1.0,6
|
||||
18,1.0,6
|
||||
19,1.0,6
|
||||
|
|
After Width: | Height: | Size: 53 KiB |
@@ -0,0 +1,801 @@
|
||||
episodes,rewards,steps
|
||||
0,0.0,20
|
||||
1,0.0,14
|
||||
2,0.0,13
|
||||
3,0.0,9
|
||||
4,0.0,10
|
||||
5,0.0,6
|
||||
6,0.0,11
|
||||
7,0.0,6
|
||||
8,0.0,3
|
||||
9,0.0,9
|
||||
10,0.0,11
|
||||
11,0.0,22
|
||||
12,0.0,5
|
||||
13,0.0,16
|
||||
14,0.0,4
|
||||
15,0.0,9
|
||||
16,0.0,18
|
||||
17,0.0,2
|
||||
18,0.0,4
|
||||
19,0.0,8
|
||||
20,0.0,7
|
||||
21,0.0,4
|
||||
22,0.0,22
|
||||
23,0.0,15
|
||||
24,0.0,5
|
||||
25,0.0,16
|
||||
26,0.0,7
|
||||
27,0.0,19
|
||||
28,0.0,22
|
||||
29,0.0,16
|
||||
30,0.0,11
|
||||
31,0.0,22
|
||||
32,0.0,28
|
||||
33,0.0,23
|
||||
34,0.0,4
|
||||
35,0.0,11
|
||||
36,0.0,8
|
||||
37,0.0,15
|
||||
38,0.0,5
|
||||
39,0.0,7
|
||||
40,0.0,9
|
||||
41,0.0,4
|
||||
42,0.0,3
|
||||
43,0.0,6
|
||||
44,0.0,41
|
||||
45,0.0,9
|
||||
46,0.0,23
|
||||
47,0.0,3
|
||||
48,1.0,38
|
||||
49,0.0,29
|
||||
50,0.0,17
|
||||
51,0.0,4
|
||||
52,0.0,2
|
||||
53,0.0,25
|
||||
54,0.0,6
|
||||
55,0.0,2
|
||||
56,0.0,30
|
||||
57,0.0,6
|
||||
58,0.0,7
|
||||
59,0.0,11
|
||||
60,0.0,9
|
||||
61,0.0,8
|
||||
62,0.0,23
|
||||
63,0.0,10
|
||||
64,0.0,3
|
||||
65,0.0,5
|
||||
66,0.0,7
|
||||
67,0.0,18
|
||||
68,0.0,8
|
||||
69,0.0,26
|
||||
70,0.0,6
|
||||
71,0.0,14
|
||||
72,0.0,4
|
||||
73,0.0,25
|
||||
74,0.0,21
|
||||
75,0.0,13
|
||||
76,0.0,4
|
||||
77,0.0,29
|
||||
78,0.0,21
|
||||
79,0.0,6
|
||||
80,0.0,6
|
||||
81,0.0,11
|
||||
82,0.0,21
|
||||
83,0.0,9
|
||||
84,0.0,9
|
||||
85,0.0,7
|
||||
86,0.0,48
|
||||
87,0.0,23
|
||||
88,0.0,100
|
||||
89,0.0,60
|
||||
90,0.0,7
|
||||
91,0.0,10
|
||||
92,0.0,24
|
||||
93,0.0,4
|
||||
94,0.0,7
|
||||
95,0.0,17
|
||||
96,0.0,87
|
||||
97,0.0,28
|
||||
98,0.0,7
|
||||
99,0.0,5
|
||||
100,0.0,12
|
||||
101,0.0,14
|
||||
102,0.0,6
|
||||
103,0.0,13
|
||||
104,0.0,93
|
||||
105,0.0,4
|
||||
106,0.0,50
|
||||
107,0.0,8
|
||||
108,0.0,12
|
||||
109,0.0,43
|
||||
110,0.0,30
|
||||
111,0.0,15
|
||||
112,0.0,19
|
||||
113,0.0,100
|
||||
114,0.0,82
|
||||
115,0.0,40
|
||||
116,0.0,88
|
||||
117,0.0,19
|
||||
118,0.0,30
|
||||
119,0.0,27
|
||||
120,0.0,5
|
||||
121,0.0,87
|
||||
122,0.0,9
|
||||
123,0.0,64
|
||||
124,0.0,27
|
||||
125,0.0,68
|
||||
126,0.0,81
|
||||
127,0.0,86
|
||||
128,0.0,100
|
||||
129,0.0,100
|
||||
130,0.0,27
|
||||
131,0.0,41
|
||||
132,0.0,70
|
||||
133,0.0,27
|
||||
134,0.0,6
|
||||
135,0.0,18
|
||||
136,0.0,38
|
||||
137,0.0,26
|
||||
138,0.0,36
|
||||
139,0.0,3
|
||||
140,0.0,61
|
||||
141,0.0,100
|
||||
142,0.0,4
|
||||
143,0.0,39
|
||||
144,0.0,18
|
||||
145,0.0,33
|
||||
146,0.0,29
|
||||
147,0.0,49
|
||||
148,0.0,88
|
||||
149,0.0,22
|
||||
150,0.0,65
|
||||
151,0.0,36
|
||||
152,0.0,30
|
||||
153,0.0,58
|
||||
154,0.0,43
|
||||
155,0.0,53
|
||||
156,0.0,43
|
||||
157,0.0,13
|
||||
158,0.0,8
|
||||
159,0.0,39
|
||||
160,0.0,29
|
||||
161,0.0,26
|
||||
162,0.0,60
|
||||
163,0.0,100
|
||||
164,0.0,31
|
||||
165,0.0,22
|
||||
166,0.0,100
|
||||
167,0.0,46
|
||||
168,0.0,23
|
||||
169,0.0,54
|
||||
170,0.0,8
|
||||
171,0.0,58
|
||||
172,0.0,3
|
||||
173,0.0,47
|
||||
174,0.0,16
|
||||
175,0.0,21
|
||||
176,0.0,44
|
||||
177,0.0,29
|
||||
178,0.0,100
|
||||
179,0.0,100
|
||||
180,0.0,62
|
||||
181,0.0,83
|
||||
182,0.0,26
|
||||
183,0.0,24
|
||||
184,0.0,10
|
||||
185,0.0,12
|
||||
186,0.0,40
|
||||
187,0.0,25
|
||||
188,0.0,18
|
||||
189,0.0,60
|
||||
190,0.0,100
|
||||
191,0.0,100
|
||||
192,0.0,24
|
||||
193,0.0,56
|
||||
194,0.0,71
|
||||
195,0.0,19
|
||||
196,0.0,100
|
||||
197,0.0,44
|
||||
198,0.0,41
|
||||
199,0.0,41
|
||||
200,0.0,60
|
||||
201,0.0,31
|
||||
202,0.0,34
|
||||
203,0.0,35
|
||||
204,0.0,59
|
||||
205,0.0,51
|
||||
206,0.0,100
|
||||
207,0.0,100
|
||||
208,0.0,100
|
||||
209,0.0,100
|
||||
210,0.0,37
|
||||
211,0.0,68
|
||||
212,0.0,40
|
||||
213,0.0,17
|
||||
214,0.0,79
|
||||
215,0.0,100
|
||||
216,0.0,26
|
||||
217,0.0,61
|
||||
218,0.0,25
|
||||
219,0.0,18
|
||||
220,0.0,27
|
||||
221,0.0,13
|
||||
222,0.0,100
|
||||
223,0.0,87
|
||||
224,0.0,100
|
||||
225,0.0,92
|
||||
226,0.0,100
|
||||
227,0.0,8
|
||||
228,0.0,100
|
||||
229,0.0,64
|
||||
230,0.0,17
|
||||
231,0.0,82
|
||||
232,0.0,100
|
||||
233,0.0,94
|
||||
234,0.0,7
|
||||
235,0.0,36
|
||||
236,0.0,100
|
||||
237,0.0,56
|
||||
238,0.0,17
|
||||
239,0.0,100
|
||||
240,0.0,83
|
||||
241,0.0,100
|
||||
242,0.0,100
|
||||
243,0.0,43
|
||||
244,0.0,87
|
||||
245,0.0,42
|
||||
246,0.0,80
|
||||
247,0.0,54
|
||||
248,0.0,82
|
||||
249,0.0,97
|
||||
250,0.0,65
|
||||
251,0.0,83
|
||||
252,0.0,100
|
||||
253,0.0,59
|
||||
254,0.0,100
|
||||
255,0.0,78
|
||||
256,0.0,100
|
||||
257,0.0,100
|
||||
258,0.0,43
|
||||
259,0.0,80
|
||||
260,0.0,100
|
||||
261,0.0,70
|
||||
262,0.0,94
|
||||
263,0.0,100
|
||||
264,0.0,100
|
||||
265,0.0,37
|
||||
266,0.0,11
|
||||
267,0.0,31
|
||||
268,0.0,100
|
||||
269,0.0,34
|
||||
270,0.0,32
|
||||
271,0.0,58
|
||||
272,0.0,38
|
||||
273,0.0,28
|
||||
274,0.0,100
|
||||
275,0.0,59
|
||||
276,0.0,100
|
||||
277,0.0,82
|
||||
278,0.0,51
|
||||
279,0.0,25
|
||||
280,0.0,73
|
||||
281,0.0,56
|
||||
282,0.0,55
|
||||
283,0.0,38
|
||||
284,0.0,100
|
||||
285,0.0,100
|
||||
286,0.0,92
|
||||
287,0.0,100
|
||||
288,0.0,100
|
||||
289,0.0,100
|
||||
290,0.0,37
|
||||
291,0.0,100
|
||||
292,0.0,66
|
||||
293,0.0,24
|
||||
294,0.0,17
|
||||
295,0.0,100
|
||||
296,0.0,59
|
||||
297,0.0,25
|
||||
298,0.0,73
|
||||
299,0.0,100
|
||||
300,0.0,29
|
||||
301,0.0,100
|
||||
302,0.0,72
|
||||
303,0.0,6
|
||||
304,1.0,57
|
||||
305,0.0,47
|
||||
306,0.0,48
|
||||
307,0.0,13
|
||||
308,0.0,100
|
||||
309,0.0,38
|
||||
310,0.0,100
|
||||
311,0.0,20
|
||||
312,0.0,100
|
||||
313,0.0,100
|
||||
314,0.0,5
|
||||
315,0.0,39
|
||||
316,0.0,11
|
||||
317,0.0,83
|
||||
318,0.0,42
|
||||
319,0.0,100
|
||||
320,0.0,99
|
||||
321,0.0,83
|
||||
322,0.0,28
|
||||
323,0.0,46
|
||||
324,0.0,100
|
||||
325,0.0,100
|
||||
326,0.0,62
|
||||
327,0.0,100
|
||||
328,0.0,23
|
||||
329,0.0,91
|
||||
330,0.0,53
|
||||
331,0.0,19
|
||||
332,0.0,26
|
||||
333,0.0,93
|
||||
334,0.0,38
|
||||
335,0.0,22
|
||||
336,0.0,43
|
||||
337,0.0,100
|
||||
338,0.0,90
|
||||
339,0.0,18
|
||||
340,0.0,45
|
||||
341,0.0,65
|
||||
342,1.0,22
|
||||
343,0.0,100
|
||||
344,1.0,15
|
||||
345,1.0,72
|
||||
346,0.0,5
|
||||
347,1.0,6
|
||||
348,1.0,6
|
||||
349,1.0,9
|
||||
350,1.0,8
|
||||
351,1.0,9
|
||||
352,1.0,8
|
||||
353,1.0,6
|
||||
354,1.0,6
|
||||
355,1.0,10
|
||||
356,1.0,6
|
||||
357,0.0,5
|
||||
358,0.0,3
|
||||
359,1.0,6
|
||||
360,1.0,6
|
||||
361,1.0,6
|
||||
362,1.0,6
|
||||
363,1.0,8
|
||||
364,1.0,6
|
||||
365,1.0,8
|
||||
366,1.0,6
|
||||
367,1.0,6
|
||||
368,1.0,8
|
||||
369,1.0,6
|
||||
370,1.0,6
|
||||
371,0.0,5
|
||||
372,1.0,6
|
||||
373,0.0,6
|
||||
374,1.0,6
|
||||
375,1.0,12
|
||||
376,1.0,6
|
||||
377,1.0,6
|
||||
378,1.0,9
|
||||
379,1.0,6
|
||||
380,1.0,6
|
||||
381,0.0,2
|
||||
382,0.0,3
|
||||
383,0.0,2
|
||||
384,0.0,4
|
||||
385,0.0,3
|
||||
386,1.0,7
|
||||
387,1.0,6
|
||||
388,1.0,6
|
||||
389,1.0,8
|
||||
390,1.0,9
|
||||
391,1.0,8
|
||||
392,1.0,8
|
||||
393,1.0,6
|
||||
394,1.0,6
|
||||
395,1.0,7
|
||||
396,1.0,6
|
||||
397,0.0,5
|
||||
398,0.0,5
|
||||
399,1.0,10
|
||||
400,1.0,6
|
||||
401,0.0,3
|
||||
402,1.0,6
|
||||
403,1.0,7
|
||||
404,1.0,6
|
||||
405,1.0,6
|
||||
406,1.0,6
|
||||
407,1.0,6
|
||||
408,1.0,6
|
||||
409,1.0,6
|
||||
410,1.0,6
|
||||
411,0.0,5
|
||||
412,1.0,6
|
||||
413,1.0,6
|
||||
414,0.0,2
|
||||
415,1.0,6
|
||||
416,1.0,6
|
||||
417,1.0,6
|
||||
418,1.0,6
|
||||
419,1.0,6
|
||||
420,1.0,8
|
||||
421,1.0,6
|
||||
422,1.0,6
|
||||
423,1.0,6
|
||||
424,1.0,6
|
||||
425,1.0,7
|
||||
426,0.0,5
|
||||
427,1.0,6
|
||||
428,1.0,6
|
||||
429,1.0,6
|
||||
430,1.0,8
|
||||
431,1.0,6
|
||||
432,1.0,6
|
||||
433,1.0,6
|
||||
434,1.0,6
|
||||
435,0.0,2
|
||||
436,1.0,8
|
||||
437,1.0,7
|
||||
438,1.0,6
|
||||
439,1.0,7
|
||||
440,1.0,6
|
||||
441,1.0,6
|
||||
442,0.0,3
|
||||
443,0.0,4
|
||||
444,1.0,6
|
||||
445,1.0,6
|
||||
446,1.0,7
|
||||
447,1.0,6
|
||||
448,1.0,6
|
||||
449,1.0,6
|
||||
450,1.0,6
|
||||
451,1.0,6
|
||||
452,1.0,6
|
||||
453,1.0,8
|
||||
454,1.0,6
|
||||
455,1.0,6
|
||||
456,1.0,6
|
||||
457,1.0,6
|
||||
458,1.0,6
|
||||
459,1.0,7
|
||||
460,1.0,8
|
||||
461,1.0,6
|
||||
462,1.0,7
|
||||
463,1.0,6
|
||||
464,1.0,6
|
||||
465,1.0,6
|
||||
466,1.0,6
|
||||
467,1.0,8
|
||||
468,1.0,6
|
||||
469,1.0,6
|
||||
470,1.0,8
|
||||
471,1.0,6
|
||||
472,1.0,11
|
||||
473,1.0,6
|
||||
474,1.0,6
|
||||
475,1.0,6
|
||||
476,1.0,8
|
||||
477,0.0,2
|
||||
478,1.0,7
|
||||
479,1.0,6
|
||||
480,1.0,6
|
||||
481,1.0,7
|
||||
482,1.0,6
|
||||
483,1.0,6
|
||||
484,1.0,6
|
||||
485,1.0,6
|
||||
486,0.0,3
|
||||
487,1.0,7
|
||||
488,1.0,6
|
||||
489,1.0,6
|
||||
490,1.0,6
|
||||
491,0.0,3
|
||||
492,1.0,6
|
||||
493,1.0,7
|
||||
494,1.0,12
|
||||
495,1.0,6
|
||||
496,0.0,9
|
||||
497,1.0,6
|
||||
498,1.0,6
|
||||
499,0.0,8
|
||||
500,1.0,6
|
||||
501,0.0,3
|
||||
502,0.0,5
|
||||
503,0.0,3
|
||||
504,1.0,6
|
||||
505,1.0,6
|
||||
506,1.0,6
|
||||
507,1.0,6
|
||||
508,1.0,6
|
||||
509,1.0,6
|
||||
510,1.0,6
|
||||
511,1.0,6
|
||||
512,1.0,6
|
||||
513,1.0,6
|
||||
514,0.0,2
|
||||
515,1.0,7
|
||||
516,1.0,6
|
||||
517,1.0,6
|
||||
518,1.0,6
|
||||
519,1.0,6
|
||||
520,1.0,6
|
||||
521,1.0,7
|
||||
522,0.0,4
|
||||
523,1.0,6
|
||||
524,0.0,5
|
||||
525,1.0,6
|
||||
526,1.0,6
|
||||
527,1.0,6
|
||||
528,1.0,6
|
||||
529,0.0,3
|
||||
530,1.0,6
|
||||
531,1.0,6
|
||||
532,1.0,6
|
||||
533,1.0,7
|
||||
534,1.0,8
|
||||
535,1.0,6
|
||||
536,1.0,6
|
||||
537,1.0,6
|
||||
538,1.0,6
|
||||
539,1.0,7
|
||||
540,1.0,7
|
||||
541,1.0,7
|
||||
542,1.0,8
|
||||
543,1.0,6
|
||||
544,1.0,10
|
||||
545,1.0,6
|
||||
546,1.0,6
|
||||
547,1.0,6
|
||||
548,1.0,8
|
||||
549,1.0,6
|
||||
550,1.0,6
|
||||
551,1.0,8
|
||||
552,1.0,6
|
||||
553,1.0,7
|
||||
554,1.0,6
|
||||
555,1.0,7
|
||||
556,1.0,6
|
||||
557,1.0,6
|
||||
558,1.0,7
|
||||
559,1.0,7
|
||||
560,1.0,7
|
||||
561,1.0,6
|
||||
562,1.0,6
|
||||
563,1.0,6
|
||||
564,1.0,6
|
||||
565,1.0,6
|
||||
566,1.0,6
|
||||
567,1.0,6
|
||||
568,1.0,7
|
||||
569,0.0,4
|
||||
570,1.0,8
|
||||
571,1.0,8
|
||||
572,1.0,7
|
||||
573,1.0,6
|
||||
574,1.0,8
|
||||
575,1.0,6
|
||||
576,1.0,6
|
||||
577,1.0,7
|
||||
578,1.0,6
|
||||
579,1.0,6
|
||||
580,1.0,8
|
||||
581,1.0,7
|
||||
582,1.0,6
|
||||
583,1.0,6
|
||||
584,0.0,3
|
||||
585,1.0,11
|
||||
586,1.0,6
|
||||
587,1.0,8
|
||||
588,0.0,2
|
||||
589,1.0,6
|
||||
590,1.0,6
|
||||
591,1.0,6
|
||||
592,1.0,6
|
||||
593,1.0,8
|
||||
594,1.0,6
|
||||
595,1.0,7
|
||||
596,1.0,6
|
||||
597,1.0,7
|
||||
598,1.0,6
|
||||
599,1.0,8
|
||||
600,0.0,2
|
||||
601,1.0,6
|
||||
602,1.0,7
|
||||
603,1.0,6
|
||||
604,1.0,6
|
||||
605,1.0,10
|
||||
606,1.0,7
|
||||
607,1.0,6
|
||||
608,1.0,6
|
||||
609,1.0,6
|
||||
610,1.0,6
|
||||
611,1.0,6
|
||||
612,1.0,7
|
||||
613,0.0,4
|
||||
614,1.0,7
|
||||
615,1.0,6
|
||||
616,1.0,8
|
||||
617,0.0,3
|
||||
618,1.0,6
|
||||
619,1.0,6
|
||||
620,1.0,6
|
||||
621,1.0,6
|
||||
622,0.0,2
|
||||
623,1.0,6
|
||||
624,1.0,6
|
||||
625,1.0,6
|
||||
626,1.0,6
|
||||
627,1.0,6
|
||||
628,1.0,7
|
||||
629,1.0,6
|
||||
630,1.0,6
|
||||
631,1.0,7
|
||||
632,1.0,6
|
||||
633,1.0,6
|
||||
634,1.0,6
|
||||
635,1.0,6
|
||||
636,1.0,6
|
||||
637,1.0,6
|
||||
638,1.0,6
|
||||
639,1.0,8
|
||||
640,1.0,6
|
||||
641,1.0,8
|
||||
642,1.0,7
|
||||
643,1.0,6
|
||||
644,0.0,3
|
||||
645,1.0,6
|
||||
646,1.0,7
|
||||
647,1.0,6
|
||||
648,1.0,6
|
||||
649,1.0,6
|
||||
650,1.0,10
|
||||
651,1.0,6
|
||||
652,1.0,6
|
||||
653,1.0,6
|
||||
654,1.0,6
|
||||
655,1.0,10
|
||||
656,1.0,6
|
||||
657,1.0,8
|
||||
658,1.0,8
|
||||
659,1.0,7
|
||||
660,1.0,6
|
||||
661,0.0,5
|
||||
662,0.0,2
|
||||
663,1.0,8
|
||||
664,1.0,6
|
||||
665,1.0,10
|
||||
666,1.0,6
|
||||
667,1.0,8
|
||||
668,1.0,10
|
||||
669,1.0,6
|
||||
670,1.0,6
|
||||
671,1.0,6
|
||||
672,1.0,10
|
||||
673,1.0,6
|
||||
674,0.0,4
|
||||
675,1.0,6
|
||||
676,1.0,6
|
||||
677,1.0,6
|
||||
678,1.0,15
|
||||
679,1.0,6
|
||||
680,1.0,6
|
||||
681,1.0,6
|
||||
682,1.0,6
|
||||
683,1.0,6
|
||||
684,1.0,6
|
||||
685,1.0,8
|
||||
686,1.0,6
|
||||
687,1.0,7
|
||||
688,1.0,6
|
||||
689,1.0,6
|
||||
690,1.0,8
|
||||
691,1.0,6
|
||||
692,1.0,6
|
||||
693,1.0,8
|
||||
694,1.0,8
|
||||
695,1.0,6
|
||||
696,1.0,6
|
||||
697,1.0,6
|
||||
698,1.0,10
|
||||
699,1.0,6
|
||||
700,1.0,6
|
||||
701,1.0,6
|
||||
702,1.0,6
|
||||
703,1.0,6
|
||||
704,1.0,6
|
||||
705,1.0,6
|
||||
706,1.0,8
|
||||
707,1.0,8
|
||||
708,1.0,6
|
||||
709,1.0,6
|
||||
710,0.0,2
|
||||
711,1.0,6
|
||||
712,1.0,6
|
||||
713,1.0,6
|
||||
714,1.0,8
|
||||
715,1.0,6
|
||||
716,1.0,6
|
||||
717,1.0,6
|
||||
718,1.0,6
|
||||
719,1.0,6
|
||||
720,1.0,6
|
||||
721,1.0,6
|
||||
722,1.0,6
|
||||
723,1.0,6
|
||||
724,1.0,7
|
||||
725,0.0,3
|
||||
726,1.0,7
|
||||
727,1.0,6
|
||||
728,1.0,6
|
||||
729,1.0,6
|
||||
730,0.0,2
|
||||
731,1.0,6
|
||||
732,1.0,8
|
||||
733,1.0,6
|
||||
734,1.0,6
|
||||
735,1.0,6
|
||||
736,1.0,6
|
||||
737,1.0,9
|
||||
738,1.0,6
|
||||
739,1.0,6
|
||||
740,1.0,6
|
||||
741,1.0,6
|
||||
742,1.0,6
|
||||
743,1.0,6
|
||||
744,1.0,9
|
||||
745,1.0,7
|
||||
746,0.0,4
|
||||
747,1.0,6
|
||||
748,1.0,8
|
||||
749,1.0,11
|
||||
750,1.0,6
|
||||
751,1.0,6
|
||||
752,1.0,6
|
||||
753,1.0,6
|
||||
754,1.0,6
|
||||
755,1.0,8
|
||||
756,1.0,6
|
||||
757,1.0,6
|
||||
758,1.0,8
|
||||
759,1.0,7
|
||||
760,1.0,6
|
||||
761,1.0,8
|
||||
762,1.0,6
|
||||
763,0.0,5
|
||||
764,1.0,9
|
||||
765,1.0,8
|
||||
766,1.0,8
|
||||
767,1.0,6
|
||||
768,1.0,8
|
||||
769,1.0,8
|
||||
770,1.0,6
|
||||
771,0.0,5
|
||||
772,0.0,3
|
||||
773,0.0,2
|
||||
774,1.0,8
|
||||
775,1.0,6
|
||||
776,1.0,6
|
||||
777,1.0,6
|
||||
778,1.0,6
|
||||
779,1.0,6
|
||||
780,1.0,6
|
||||
781,1.0,6
|
||||
782,1.0,6
|
||||
783,1.0,6
|
||||
784,1.0,6
|
||||
785,1.0,6
|
||||
786,1.0,6
|
||||
787,1.0,6
|
||||
788,1.0,6
|
||||
789,0.0,2
|
||||
790,1.0,6
|
||||
791,0.0,4
|
||||
792,1.0,6
|
||||
793,1.0,6
|
||||
794,1.0,6
|
||||
795,1.0,6
|
||||
796,1.0,6
|
||||
797,1.0,8
|
||||
798,0.0,5
|
||||
799,1.0,6
|
||||
|
@@ -5,7 +5,7 @@ Author: John
|
||||
Email: johnjim0816@gmail.com
|
||||
Date: 2020-09-11 23:03:00
|
||||
LastEditor: John
|
||||
LastEditTime: 2021-12-22 10:54:57
|
||||
LastEditTime: 2022-08-24 10:31:04
|
||||
Discription: use defaultdict to define Q table
|
||||
Environment:
|
||||
'''
|
||||
@@ -15,50 +15,52 @@ import torch
|
||||
from collections import defaultdict
|
||||
|
||||
class QLearning(object):
|
||||
def __init__(self,
|
||||
n_actions,cfg):
|
||||
self.n_actions = n_actions
|
||||
self.lr = cfg.lr # 学习率
|
||||
self.gamma = cfg.gamma
|
||||
self.epsilon = cfg.epsilon_start
|
||||
def __init__(self,cfg):
|
||||
self.n_actions = cfg['n_actions']
|
||||
self.lr = cfg['lr']
|
||||
self.gamma = cfg['gamma']
|
||||
self.epsilon = cfg['epsilon_start']
|
||||
self.sample_count = 0
|
||||
self.epsilon_start = cfg.epsilon_start
|
||||
self.epsilon_end = cfg.epsilon_end
|
||||
self.epsilon_decay = cfg.epsilon_decay
|
||||
self.Q_table = defaultdict(lambda: np.zeros(n_actions)) # 用嵌套字典存放状态->动作->状态-动作值(Q值)的映射,即Q表
|
||||
def sample(self, state):
|
||||
''' 采样动作,训练时用
|
||||
self.epsilon_start = cfg['epsilon_start']
|
||||
self.epsilon_end = cfg['epsilon_end']
|
||||
self.epsilon_decay = cfg['epsilon_decay']
|
||||
self.Q_table = defaultdict(lambda: np.zeros(self.n_actions)) # use nested dictionary to represent Q(s,a), here set all Q(s,a)=0 initially, not like pseudo code
|
||||
def sample_action(self, state):
|
||||
''' sample action with e-greedy policy while training
|
||||
'''
|
||||
self.sample_count += 1
|
||||
# epsilon must decay(linear,exponential and etc.) for balancing exploration and exploitation
|
||||
self.epsilon = self.epsilon_end + (self.epsilon_start - self.epsilon_end) * \
|
||||
math.exp(-1. * self.sample_count / self.epsilon_decay) # epsilon是会递减的,这里选择指数递减
|
||||
# e-greedy 策略
|
||||
math.exp(-1. * self.sample_count / self.epsilon_decay)
|
||||
if np.random.uniform(0, 1) > self.epsilon:
|
||||
action = np.argmax(self.Q_table[str(state)]) # 选择Q(s,a)最大对应的动作
|
||||
action = np.argmax(self.Q_table[str(state)]) # choose action corresponding to the maximum q value
|
||||
else:
|
||||
action = np.random.choice(self.n_actions) # 随机选择动作
|
||||
action = np.random.choice(self.n_actions) # choose action randomly
|
||||
return action
|
||||
def predict(self,state):
|
||||
''' 预测或选择动作,测试时用
|
||||
def predict_action(self,state):
|
||||
''' predict action while testing
|
||||
'''
|
||||
action = np.argmax(self.Q_table[str(state)])
|
||||
return action
|
||||
def update(self, state, action, reward, next_state, done):
|
||||
Q_predict = self.Q_table[str(state)][action]
|
||||
if done: # 终止状态
|
||||
if done: # terminal state
|
||||
Q_target = reward
|
||||
else:
|
||||
Q_target = reward + self.gamma * np.max(self.Q_table[str(next_state)])
|
||||
self.Q_table[str(state)][action] += self.lr * (Q_target - Q_predict)
|
||||
def save(self,path):
|
||||
def save_model(self,path):
|
||||
import dill
|
||||
from pathlib import Path
|
||||
# create path
|
||||
Path(path).mkdir(parents=True, exist_ok=True)
|
||||
torch.save(
|
||||
obj=self.Q_table,
|
||||
f=path+"Qleaning_model.pkl",
|
||||
pickle_module=dill
|
||||
)
|
||||
print("保存模型成功!")
|
||||
def load(self, path):
|
||||
print("Model saved!")
|
||||
def load_model(self, path):
|
||||
import dill
|
||||
self.Q_table =torch.load(f=path+'Qleaning_model.pkl',pickle_module=dill)
|
||||
print("加载模型成功!")
|
||||
print("Mode loaded!")
|
||||
27
projects/codes/SAC-S/sac.py
Normal file
@@ -0,0 +1,27 @@
|
||||
import torch
|
||||
import torch.optim as optim
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
class SAC:
|
||||
def __init__(self,n_actions,models,memory,cfg):
|
||||
self.device = cfg.device
|
||||
self.value_net = models['ValueNet'].to(self.device) # $\psi$
|
||||
self.target_value_net = models['ValueNet'].to(self.device) # $\bar{\psi}$
|
||||
self.soft_q_net = models['SoftQNet'].to(self.device) # $\theta$
|
||||
self.policy_net = models['PolicyNet'].to(self.device) # $\phi$
|
||||
self.value_optimizer = optim.Adam(self.value_net.parameters(), lr=cfg.value_lr)
|
||||
self.soft_q_optimizer = optim.Adam(self.soft_q_net.parameters(), lr=cfg.soft_q_lr)
|
||||
self.policy_optimizer = optim.Adam(self.policy_net.parameters(), lr=cfg.policy_lr)
|
||||
for target_param, param in zip(self.target_value_net.parameters(), self.value_net.parameters()):
|
||||
target_param.data.copy_(param.data)
|
||||
self.value_criterion = nn.MSELoss()
|
||||
self.soft_q_criterion = nn.MSELoss()
|
||||
def update(self):
|
||||
# sample a batch of transitions from replay buffer
|
||||
state_batch, action_batch, reward_batch, next_state_batch, done_batch = self.memory.sample(
|
||||
self.batch_size)
|
||||
state_batch = torch.tensor(np.array(state_batch), device=self.device, dtype=torch.float) # shape(batchsize,n_states)
|
||||
action_batch = torch.tensor(action_batch, device=self.device).unsqueeze(1) # shape(batchsize,1)
|
||||
reward_batch = torch.tensor(reward_batch, device=self.device, dtype=torch.float).unsqueeze(1) # shape(batchsize)
|
||||
next_state_batch = torch.tensor(np.array(next_state_batch), device=self.device, dtype=torch.float) # shape(batchsize,n_states)
|
||||
done_batch = torch.tensor(np.float32(done_batch), device=self.device).unsqueeze(1) # shape(batchsize,1)
|
||||
0
projects/codes/SAC/sacd_cnn.py
Normal file
@@ -5,7 +5,7 @@ Author: John
|
||||
Email: johnjim0816@gmail.com
|
||||
Date: 2021-03-12 16:02:24
|
||||
LastEditor: John
|
||||
LastEditTime: 2022-08-22 17:41:28
|
||||
LastEditTime: 2022-08-24 10:31:30
|
||||
Discription:
|
||||
Environment:
|
||||
'''
|
||||
@@ -64,14 +64,14 @@ def smooth(data, weight=0.9):
|
||||
def plot_rewards(rewards,cfg,path=None,tag='train'):
|
||||
sns.set()
|
||||
plt.figure() # 创建一个图形实例,方便同时多画几个图
|
||||
plt.title(f"{tag}ing curve on {cfg.device} of {cfg.algo_name} for {cfg.env_name}")
|
||||
plt.title(f"{tag}ing curve on {cfg['device']} of {cfg['algo_name']} for {cfg['env_name']}")
|
||||
plt.xlabel('epsiodes')
|
||||
plt.plot(rewards, label='rewards')
|
||||
plt.plot(smooth(rewards), label='smoothed')
|
||||
plt.legend()
|
||||
if cfg.save_fig:
|
||||
if cfg['save_fig']:
|
||||
plt.savefig(f"{path}/{tag}ing_curve.png")
|
||||
if cfg.show_fig:
|
||||
if cfg['show_fig']:
|
||||
plt.show()
|
||||
|
||||
def plot_losses(losses, algo="DQN", save=True, path='./'):
|
||||
@@ -110,12 +110,21 @@ def del_empty_dir(*paths):
|
||||
if not os.listdir(os.path.join(path, dir)):
|
||||
os.removedirs(os.path.join(path, dir))
|
||||
|
||||
class NpEncoder(json.JSONEncoder):
|
||||
def default(self, obj):
|
||||
if isinstance(obj, np.integer):
|
||||
return int(obj)
|
||||
if isinstance(obj, np.floating):
|
||||
return float(obj)
|
||||
if isinstance(obj, np.ndarray):
|
||||
return obj.tolist()
|
||||
return json.JSONEncoder.default(self, obj)
|
||||
|
||||
def save_args(args,path=None):
|
||||
# 保存参数
|
||||
args_dict = vars(args)
|
||||
# save parameters
|
||||
Path(path).mkdir(parents=True, exist_ok=True)
|
||||
with open(f"{path}/params.json", 'w') as fp:
|
||||
json.dump(args_dict, fp)
|
||||
json.dump(args, fp,cls=NpEncoder)
|
||||
print("Parameters saved!")
|
||||
|
||||
def all_seed(env,seed = 1):
|
||||
|
||||
@@ -1,6 +1,18 @@
|
||||
## 环境汇总
|
||||
# 环境说明汇总
|
||||
|
||||
## 算法SAR一览
|
||||
|
||||
说明:SAR分别指状态(S)、动作(A)以及奖励(R),下表的Reward Range表示每回合能获得的奖励范围,Steps表示环境中每回合的最大步数
|
||||
|
||||
| Environment ID | Observation Space | Action Space | Reward Range | Steps |
|
||||
| :--------------------------------: | :---------------: | :----------: | :----------: | :------: |
|
||||
| CartPole-v0 | Box(4,) | Discrete(2) | [0,200] | 200 |
|
||||
| CartPole-v1 | Box(4,) | Discrete(2) | [0,500] | 500 |
|
||||
| CliffWalking-v0 | Discrete(48) | Discrete(4) | [-inf,-13] | [13,inf] |
|
||||
| FrozenLake-v1(*is_slippery*=False) | Discrete(16) | Discrete(4) | 0 or 1 | [6,info] |
|
||||
|
||||
## 环境描述
|
||||
|
||||
[OpenAI Gym](./gym_info.md)
|
||||
[MuJoCo](./mujoco_info.md)
|
||||
|
||||
|
||||
|
||||
15
projects/codes/scripts/DQN_task0.sh
Normal file
@@ -0,0 +1,15 @@
|
||||
# run DQN on CartPole-v0
|
||||
# source conda, if you are already in proper conda environment, then comment the codes util "conda activate easyrl"
|
||||
|
||||
if [ -f "$HOME/anaconda3/etc/profile.d/conda.sh" ]; then
|
||||
echo "source file at ~/anaconda3/etc/profile.d/conda.sh"
|
||||
source ~/anaconda3/etc/profile.d/conda.sh
|
||||
elif [ -f "$HOME/opt/anaconda3/etc/profile.d/conda.sh" ]; then
|
||||
echo "source file at ~/opt/anaconda3/etc/profile.d/conda.sh"
|
||||
source ~/opt/anaconda3/etc/profile.d/conda.sh
|
||||
else
|
||||
echo 'please manually config the conda source path'
|
||||
fi
|
||||
conda activate easyrl # easyrl here can be changed to another name of conda env that you have created
|
||||
codes_dir=$(dirname $(dirname $(readlink -f "$0"))) # "codes" path
|
||||
python $codes_dir/DQN/main.py
|
||||
16
projects/codes/scripts/DQN_task1.sh
Normal file
@@ -0,0 +1,16 @@
|
||||
'''
|
||||
run DQN on CartPole-v1, not finished yet
|
||||
'''
|
||||
# source conda, if you are already in proper conda environment, then comment the codes util "conda activate easyrl"
|
||||
if [ -f "$HOME/anaconda3/etc/profile.d/conda.sh" ]; then
|
||||
echo "source file at ~/anaconda3/etc/profile.d/conda.sh"
|
||||
source ~/anaconda3/etc/profile.d/conda.sh
|
||||
elif [ -f "$HOME/opt/anaconda3/etc/profile.d/conda.sh" ]; then
|
||||
echo "source file at ~/opt/anaconda3/etc/profile.d/conda.sh"
|
||||
source ~/opt/anaconda3/etc/profile.d/conda.sh
|
||||
else
|
||||
echo 'please manually config the conda source path'
|
||||
fi
|
||||
conda activate easyrl # easyrl here can be changed to another name of conda env that you have created
|
||||
codes_dir=$(dirname $(dirname $(readlink -f "$0"))) # "codes" path
|
||||
python $codes_dir/DQN/main.py --env_name CartPole-v1 --train_eps 500 --epsilon_decay 1000 --memory_capacity 200000 --batch_size 128 --device cuda
|
||||
14
projects/codes/scripts/Qlearning_task0.sh
Normal file
@@ -0,0 +1,14 @@
|
||||
|
||||
# source conda, if you are already in proper conda environment, then comment the codes util "conda activate easyrl"
|
||||
if [ -f "$HOME/anaconda3/etc/profile.d/conda.sh" ]; then
|
||||
echo "source file at ~/anaconda3/etc/profile.d/conda.sh"
|
||||
source ~/anaconda3/etc/profile.d/conda.sh
|
||||
elif [ -f "$HOME/opt/anaconda3/etc/profile.d/conda.sh" ]; then
|
||||
echo "source file at ~/opt/anaconda3/etc/profile.d/conda.sh"
|
||||
source ~/opt/anaconda3/etc/profile.d/conda.sh
|
||||
else
|
||||
echo 'please manually config the conda source path'
|
||||
fi
|
||||
conda activate easyrl # easyrl here can be changed to another name of conda env that you have created
|
||||
codes_dir=$(dirname $(dirname $(readlink -f "$0"))) # "codes" path
|
||||
python $codes_dir/QLearning/main.py --device cpu
|
||||
14
projects/codes/scripts/Qlearning_task1.sh
Normal file
@@ -0,0 +1,14 @@
|
||||
|
||||
# source conda, if you are already in proper conda environment, then comment the codes util "conda activate easyrl"
|
||||
if [ -f "$HOME/anaconda3/etc/profile.d/conda.sh" ]; then
|
||||
echo "source file at ~/anaconda3/etc/profile.d/conda.sh"
|
||||
source ~/anaconda3/etc/profile.d/conda.sh
|
||||
elif [ -f "$HOME/opt/anaconda3/etc/profile.d/conda.sh" ]; then
|
||||
echo "source file at ~/opt/anaconda3/etc/profile.d/conda.sh"
|
||||
source ~/opt/anaconda3/etc/profile.d/conda.sh
|
||||
else
|
||||
echo 'please manually config the conda source path'
|
||||
fi
|
||||
conda activate easyrl # easyrl here can be changed to another name of conda env that you have created
|
||||
codes_dir=$(dirname $(dirname $(readlink -f "$0"))) # "codes" path
|
||||
python $codes_dir/QLearning/main.py --env_name FrozenLake-v1 --train_eps 800 --epsilon_start 0.70 --epsilon_end 0.1 --epsilon_decay 2000 --gamma 0.9 --lr 0.9 --device cpu
|
||||
|
Before Width: | Height: | Size: 317 KiB After Width: | Height: | Size: 235 KiB |
@@ -1,11 +1,8 @@
|
||||
gym==0.21.0
|
||||
torch==1.10.0
|
||||
torchvision==0.11.0
|
||||
torchaudio==0.10.0
|
||||
ipykernel==6.15.1
|
||||
jupyter==1.0.0
|
||||
matplotlib==3.5.2
|
||||
seaborn==0.11.2
|
||||
dill==0.3.5.1
|
||||
argparse==1.4.0
|
||||
pandas==1.3.5
|
||||
pandas==1.3.5
|
||||
|
||||