Merge branch 'master' of github.com:datawhalechina/easy-rl

This commit is contained in:
qiwang067
2022-08-23 16:20:16 +08:00
54 changed files with 1639 additions and 503 deletions

View File

@@ -1,6 +1,6 @@
## 0、写在前面
本项目用于学习RL基础算法,主要面向对象为RL初学者、需要结合RL的非专业学习者,尽量做到:**(中文)注释详细**、**结构清晰**。
本项目用于学习RL基础算法,主要面向对象为RL初学者、需要结合RL的非专业学习者,尽量做到:**注释详细**、**结构清晰**。
注意:本项目为实战内容,建议首先掌握相关算法的一些理论基础,再来享用本项目,理论教程参考本人参与编写的[蘑菇书](https://github.com/datawhalechina/easy-rl)。
@@ -10,25 +10,65 @@
项目内容主要包含以下几个部分:
* [Jupyter Notebook](./notebooks/):使用Notebook写的算法,有比较详细的实战引导,推荐新手食用;
* [codes](./assets/):这些是基于Python脚本写的算法,风格比较接近实际项目的写法,推荐有一定代码基础的人阅读,下面会说明其具体的一些架构;
* [codes](./codes/):这些是基于Python脚本写的算法,风格比较接近实际项目的写法,推荐有一定代码基础的人阅读,下面会说明其具体的一些架构;
* [parl](./PARL/):应业务需求,写了一些基于百度飞桨平台和```parl```模块的RL实例;
* [附件](./assets/):目前包含强化学习各算法的中文伪代码。
[codes](./codes/)结构主要分为以下几个脚本:
* ```[algorithm_name].py```:即保存算法的脚本,例如```dqn.py```,每种算法都会有一定的基础模块,例如```Replay Buffer```、```MLP```(多层感知机)等等;
* ```task.py```: 即保存任务的脚本,基本包括基于```argparse```模块的参数,训练以及测试函数等等;
* ```task.py```: 即保存任务的脚本,基本包括基于```argparse```模块的参数,训练以及测试函数等等,其中训练函数即```train```遵循伪代码而设计,想读懂代码可从该函数入手
* ```utils.py```:该脚本用于存放诸如存储结果以及画图的工具函数,在实际项目或研究中,推荐大家使用```Tensorboard```来保存结果,然后使用诸如```matplotlib```以及```seaborn```来进一步画图(见下方示意)。
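下面给出一个用```Tensorboard```记录奖励、再用```matplotlib```画曲线的极简示意(仅为演示思路,其中的日志路径与奖励数值均为假设,并非本项目```utils.py```的实现):
```python
import os
from torch.utils.tensorboard import SummaryWriter
import matplotlib.pyplot as plt

rewards = [10.0, 25.0, 60.0, 200.0]        # 假设的每回合奖励
writer = SummaryWriter("outputs/tb_logs")  # 假设的日志目录
for ep, r in enumerate(rewards):
    writer.add_scalar("train/ep_reward", r, ep)  # 写入标量,训练时可在Tensorboard网页端实时查看
writer.close()

os.makedirs("outputs", exist_ok=True)
plt.plot(rewards)                          # 训练结束后再用matplotlib画奖励曲线
plt.xlabel("episode")
plt.ylabel("reward")
plt.savefig("outputs/rewards_curve.png")
```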
## 2、算法列表
## 2、运行环境
注:点击对应的名称会跳到[codes](./codes/)下对应的算法中,其他版本还请读者自行翻阅
python 3.7、pytorch 1.6.0-1.9.0、gym 0.21.0
| 算法名称 | 参考文献 | 环境 | 备注 |
| :-----------------------: | :----------------------------------------------------------: | :--: | :--: |
| | | | |
| DQN-CNN | | | 待更 |
| [SoftQ](codes/SoftQ) | [Soft Q-learning paper](https://arxiv.org/abs/1702.08165) | | |
| [SAC](codes/SAC) | [SAC paper](https://arxiv.org/pdf/1812.05905.pdf) | | |
| [SAC-Discrete](codes/SAC) | [SAC-Discrete paper](https://arxiv.org/pdf/1910.07207.pdf) | | |
| SAC-V | [SAC-V paper](https://arxiv.org/abs/1801.01290) | | |
| DSAC | [DSAC paper](https://paperswithcode.com/paper/addressing-value-estimation-errors-in) | | 待更 |
## 3、运行环境
Python 3.7、PyTorch 1.10.0、Gym 0.21.0
在项目根目录下执行以下命令复现环境:
```bash
pip install -r requirements.txt
```
## 3、使用说明
如果需要使用CUDA,则需另外安装```cudatoolkit```,推荐```10.2```或者```11.3```版本的CUDA,如下:
```bash
conda install cudatoolkit=11.3 -c pytorch
```
如果conda需要镜像加速安装的话,点击[该清华镜像链接](https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/pytorch/),选择对应的操作系统,比如```win-64```,然后复制链接,执行如下命令:
```bash
conda install cudatoolkit=11.3 -c https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/pytorch/win-64/
```
执行以下Python脚本,如果返回True,说明CUDA安装成功:
```python
import torch
print(torch.cuda.is_available())
```
如果还是不成功,可以使用pip安装:
```bash
pip install torch==1.10.0+cu113 torchvision==0.11.0+cu113 torchaudio==0.10.0 --extra-index-url https://download.pytorch.org/whl/cu113
```
## 4、使用说明
直接运行带有```train```的py文件或ipynb文件,即可训练默认的任务;
也可以运行带有```task```的py文件,训练不同的任务。
对于[codes](./codes/):
* 运行带有task的py脚本(示例命令见下)
对于[Jupyter Notebook](./notebooks/):
* 直接运行对应的ipynb文件就行
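以命令行运行为例,下面给出一个示例命令(其中的脚本名```task.py```仅为示意,请以各算法文件夹下实际的脚本名为准):
```bash
# 在项目根目录下执行,训练codes下的DQN(脚本名为假设)
python codes/DQN/task.py
```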
## 5、友情说明
推荐使用VS Code做项目,入门可参考[VSCode上手指南](https://blog.csdn.net/JohnJim0/article/details/126366454)

View File

@@ -1,4 +1,33 @@
\relax
\@writefile{loa}{\contentsline {algorithm}{\numberline {}{\ignorespaces }}{1}{}\protected@file@percent }
\@writefile{loa}{\contentsline {algorithm}{\numberline {}{\ignorespaces }}{2}{}\protected@file@percent }
\gdef \@abspage@last{2}
\providecommand\hyper@newdestlabel[2]{}
\providecommand\HyperFirstAtBeginDocument{\AtBeginDocument}
\HyperFirstAtBeginDocument{\ifx\hyper@anchor\@undefined
\global\let\oldcontentsline\contentsline
\gdef\contentsline#1#2#3#4{\oldcontentsline{#1}{#2}{#3}}
\global\let\oldnewlabel\newlabel
\gdef\newlabel#1#2{\newlabelxx{#1}#2}
\gdef\newlabelxx#1#2#3#4#5#6{\oldnewlabel{#1}{{#2}{#3}}}
\AtEndDocument{\ifx\hyper@anchor\@undefined
\let\contentsline\oldcontentsline
\let\newlabel\oldnewlabel
\fi}
\fi}
\global\let\hyper@last\relax
\gdef\HyperFirstAtBeginDocument#1{#1}
\providecommand*\HyPL@Entry[1]{}
\HyPL@Entry{0<</S/D>>}
\@writefile{toc}{\contentsline {section}{\numberline {1}模版备用}{2}{section.1}\protected@file@percent }
\@writefile{loa}{\contentsline {algorithm}{\numberline {}{\ignorespaces }}{2}{algorithm.}\protected@file@percent }
\@writefile{toc}{\contentsline {section}{\numberline {2}Q learning算法}{3}{section.2}\protected@file@percent }
\@writefile{loa}{\contentsline {algorithm}{\numberline {}{\ignorespaces }}{3}{algorithm.}\protected@file@percent }
\@writefile{toc}{\contentsline {section}{\numberline {3}Sarsa算法}{4}{section.3}\protected@file@percent }
\@writefile{loa}{\contentsline {algorithm}{\numberline {}{\ignorespaces }}{4}{algorithm.}\protected@file@percent }
\@writefile{toc}{\contentsline {section}{\numberline {4}Policy Gradient算法}{5}{section.4}\protected@file@percent }
\@writefile{loa}{\contentsline {algorithm}{\numberline {}{\ignorespaces }}{5}{algorithm.}\protected@file@percent }
\@writefile{toc}{\contentsline {section}{\numberline {5}DQN算法}{6}{section.5}\protected@file@percent }
\@writefile{loa}{\contentsline {algorithm}{\numberline {}{\ignorespaces }}{6}{algorithm.}\protected@file@percent }
\@writefile{toc}{\contentsline {section}{\numberline {6}SoftQ算法}{7}{section.6}\protected@file@percent }
\@writefile{loa}{\contentsline {algorithm}{\numberline {}{\ignorespaces }}{7}{algorithm.}\protected@file@percent }
\@writefile{toc}{\contentsline {section}{\numberline {7}SAC算法}{8}{section.7}\protected@file@percent }
\@writefile{loa}{\contentsline {algorithm}{\numberline {}{\ignorespaces }}{8}{algorithm.}\protected@file@percent }
\gdef \@abspage@last{8}

View File

@@ -1,4 +1,4 @@
This is XeTeX, Version 3.141592653-2.6-0.999993 (TeX Live 2021) (preloaded format=xelatex 2021.8.22) 15 AUG 2022 15:05
This is XeTeX, Version 3.141592653-2.6-0.999993 (TeX Live 2021) (preloaded format=xelatex 2021.8.22) 22 AUG 2022 16:54
entering extended mode
restricted \write18 enabled.
file:line:error style messages enabled.
@@ -292,107 +292,282 @@ LaTeX Font Info: Redeclaring font encoding OMS on input line 744.
\mathdisplay@stack=\toks24
LaTeX Info: Redefining \[ on input line 2923.
LaTeX Info: Redefining \] on input line 2924.
) (/usr/local/texlive/2021/texmf-dist/tex/latex/hyperref/hyperref.sty
Package: hyperref 2021-02-27 v7.00k Hypertext links for LaTeX
(/usr/local/texlive/2021/texmf-dist/tex/generic/ltxcmds/ltxcmds.sty
Package: ltxcmds 2020-05-10 v1.25 LaTeX kernel commands for general use (HO)
) (/usr/local/texlive/2021/texmf-dist/tex/generic/iftex/iftex.sty
Package: iftex 2020/03/06 v1.0d TeX engine tests
) (/usr/local/texlive/2021/texmf-dist/tex/generic/pdftexcmds/pdftexcmds.sty
Package: pdftexcmds 2020-06-27 v0.33 Utility functions of pdfTeX for LuaTeX (HO)
(/usr/local/texlive/2021/texmf-dist/tex/generic/infwarerr/infwarerr.sty
Package: infwarerr 2019/12/03 v1.5 Providing info/warning/error messages (HO)
)
Package pdftexcmds Info: \pdf@primitive is available.
Package pdftexcmds Info: \pdf@ifprimitive is available.
Package pdftexcmds Info: \pdfdraftmode not found.
) (/usr/local/texlive/2021/texmf-dist/tex/generic/kvsetkeys/kvsetkeys.sty
Package: kvsetkeys 2019/12/15 v1.18 Key value parser (HO)
) (/usr/local/texlive/2021/texmf-dist/tex/generic/kvdefinekeys/kvdefinekeys.sty
Package: kvdefinekeys 2019-12-19 v1.6 Define keys (HO)
) (/usr/local/texlive/2021/texmf-dist/tex/generic/pdfescape/pdfescape.sty
Package: pdfescape 2019/12/09 v1.15 Implements pdfTeX's escape features (HO)
) (/usr/local/texlive/2021/texmf-dist/tex/latex/hycolor/hycolor.sty
Package: hycolor 2020-01-27 v1.10 Color options for hyperref/bookmark (HO)
) (/usr/local/texlive/2021/texmf-dist/tex/latex/letltxmacro/letltxmacro.sty
Package: letltxmacro 2019/12/03 v1.6 Let assignment for LaTeX macros (HO)
) (/usr/local/texlive/2021/texmf-dist/tex/latex/auxhook/auxhook.sty
Package: auxhook 2019-12-17 v1.6 Hooks for auxiliary files (HO)
) (/usr/local/texlive/2021/texmf-dist/tex/latex/kvoptions/kvoptions.sty
Package: kvoptions 2020-10-07 v3.14 Key value format for package options (HO)
)
\@linkdim=\dimen183
\Hy@linkcounter=\count301
\Hy@pagecounter=\count302
(/usr/local/texlive/2021/texmf-dist/tex/latex/hyperref/pd1enc.def
File: pd1enc.def 2021-02-27 v7.00k Hyperref: PDFDocEncoding definition (HO)
) (/usr/local/texlive/2021/texmf-dist/tex/latex/hyperref/hyperref-langpatches.def
File: hyperref-langpatches.def 2021-02-27 v7.00k Hyperref: patches for babel languages
) (/usr/local/texlive/2021/texmf-dist/tex/generic/intcalc/intcalc.sty
Package: intcalc 2019/12/15 v1.3 Expandable calculations with integers (HO)
) (/usr/local/texlive/2021/texmf-dist/tex/generic/etexcmds/etexcmds.sty
Package: etexcmds 2019/12/15 v1.7 Avoid name clashes with e-TeX commands (HO)
)
\Hy@SavedSpaceFactor=\count303
(/usr/local/texlive/2021/texmf-dist/tex/latex/hyperref/puenc.def
File: puenc.def 2021-02-27 v7.00k Hyperref: PDF Unicode definition (HO)
)
Package hyperref Info: Option `unicode' set `true' on input line 4073.
Package hyperref Info: Hyper figures OFF on input line 4192.
Package hyperref Info: Link nesting OFF on input line 4197.
Package hyperref Info: Hyper index ON on input line 4200.
Package hyperref Info: Plain pages OFF on input line 4207.
Package hyperref Info: Backreferencing OFF on input line 4212.
Package hyperref Info: Implicit mode ON; LaTeX internals redefined.
Package hyperref Info: Bookmarks ON on input line 4445.
\c@Hy@tempcnt=\count304
(/usr/local/texlive/2021/texmf-dist/tex/latex/url/url.sty
\Urlmuskip=\muskip18
Package: url 2013/09/16 ver 3.4 Verb mode for urls, etc.
)
LaTeX Info: Redefining \url on input line 4804.
\XeTeXLinkMargin=\dimen184
(/usr/local/texlive/2021/texmf-dist/tex/generic/bitset/bitset.sty
Package: bitset 2019/12/09 v1.3 Handle bit-vector datatype (HO)
(/usr/local/texlive/2021/texmf-dist/tex/generic/bigintcalc/bigintcalc.sty
Package: bigintcalc 2019/12/15 v1.5 Expandable calculations on big integers (HO)
))
\Fld@menulength=\count305
\Field@Width=\dimen185
\Fld@charsize=\dimen186
Package hyperref Info: Hyper figures OFF on input line 6075.
Package hyperref Info: Link nesting OFF on input line 6080.
Package hyperref Info: Hyper index ON on input line 6083.
Package hyperref Info: backreferencing OFF on input line 6090.
Package hyperref Info: Link coloring OFF on input line 6095.
Package hyperref Info: Link coloring with OCG OFF on input line 6100.
Package hyperref Info: PDF/A mode OFF on input line 6105.
LaTeX Info: Redefining \ref on input line 6145.
LaTeX Info: Redefining \pageref on input line 6149.
(/usr/local/texlive/2021/texmf-dist/tex/latex/base/atbegshi-ltx.sty
Package: atbegshi-ltx 2020/08/17 v1.0a Emulation of the original atbegshi package
with kernel methods
)
\Hy@abspage=\count306
\c@Item=\count307
\c@Hfootnote=\count308
)
Package hyperref Info: Driver (autodetected): hxetex.
(/usr/local/texlive/2021/texmf-dist/tex/latex/hyperref/hxetex.def
File: hxetex.def 2021-02-27 v7.00k Hyperref driver for XeTeX
(/usr/local/texlive/2021/texmf-dist/tex/generic/stringenc/stringenc.sty
Package: stringenc 2019/11/29 v1.12 Convert strings between diff. encodings (HO)
)
\pdfm@box=\box54
\c@Hy@AnnotLevel=\count309
\HyField@AnnotCount=\count310
\Fld@listcount=\count311
\c@bookmark@seq@number=\count312
(/usr/local/texlive/2021/texmf-dist/tex/latex/rerunfilecheck/rerunfilecheck.sty
Package: rerunfilecheck 2019/12/05 v1.9 Rerun checks for auxiliary files (HO)
(/usr/local/texlive/2021/texmf-dist/tex/latex/base/atveryend-ltx.sty
Package: atveryend-ltx 2020/08/19 v1.0a Emulation of the original atvery package
with kernel methods
) (/usr/local/texlive/2021/texmf-dist/tex/generic/uniquecounter/uniquecounter.sty
Package: uniquecounter 2019/12/15 v1.4 Provide unlimited unique counter (HO)
)
Package uniquecounter Info: New unique counter `rerunfilecheck' on input line 286.
)
\Hy@SectionHShift=\skip63
) (/usr/local/texlive/2021/texmf-dist/tex/latex/setspace/setspace.sty
Package: setspace 2011/12/19 v6.7a set line spacing
) (/usr/local/texlive/2021/texmf-dist/tex/latex/titlesec/titlesec.sty
Package: titlesec 2019/10/16 v2.13 Sectioning titles
\ttl@box=\box55
\beforetitleunit=\skip64
\aftertitleunit=\skip65
\ttl@plus=\dimen187
\ttl@minus=\dimen188
\ttl@toksa=\toks25
\titlewidth=\dimen189
\titlewidthlast=\dimen190
\titlewidthfirst=\dimen191
) (./pseudocodes.aux)
\openout1 = `pseudocodes.aux'.
LaTeX Font Info: Checking defaults for OML/cmm/m/it on input line 9.
LaTeX Font Info: ... okay on input line 9.
LaTeX Font Info: Checking defaults for OMS/cmsy/m/n on input line 9.
LaTeX Font Info: ... okay on input line 9.
LaTeX Font Info: Checking defaults for OT1/cmr/m/n on input line 9.
LaTeX Font Info: ... okay on input line 9.
LaTeX Font Info: Checking defaults for T1/cmr/m/n on input line 9.
LaTeX Font Info: ... okay on input line 9.
LaTeX Font Info: Checking defaults for TS1/cmr/m/n on input line 9.
LaTeX Font Info: ... okay on input line 9.
LaTeX Font Info: Checking defaults for TU/lmr/m/n on input line 9.
LaTeX Font Info: ... okay on input line 9.
LaTeX Font Info: Checking defaults for OMX/cmex/m/n on input line 9.
LaTeX Font Info: ... okay on input line 9.
LaTeX Font Info: Checking defaults for U/cmr/m/n on input line 9.
LaTeX Font Info: ... okay on input line 9.
LaTeX Font Info: Checking defaults for OML/cmm/m/it on input line 13.
LaTeX Font Info: ... okay on input line 13.
LaTeX Font Info: Checking defaults for OMS/cmsy/m/n on input line 13.
LaTeX Font Info: ... okay on input line 13.
LaTeX Font Info: Checking defaults for OT1/cmr/m/n on input line 13.
LaTeX Font Info: ... okay on input line 13.
LaTeX Font Info: Checking defaults for T1/cmr/m/n on input line 13.
LaTeX Font Info: ... okay on input line 13.
LaTeX Font Info: Checking defaults for TS1/cmr/m/n on input line 13.
LaTeX Font Info: ... okay on input line 13.
LaTeX Font Info: Checking defaults for TU/lmr/m/n on input line 13.
LaTeX Font Info: ... okay on input line 13.
LaTeX Font Info: Checking defaults for OMX/cmex/m/n on input line 13.
LaTeX Font Info: ... okay on input line 13.
LaTeX Font Info: Checking defaults for U/cmr/m/n on input line 13.
LaTeX Font Info: ... okay on input line 13.
LaTeX Font Info: Checking defaults for PD1/pdf/m/n on input line 13.
LaTeX Font Info: ... okay on input line 13.
LaTeX Font Info: Checking defaults for PU/pdf/m/n on input line 13.
LaTeX Font Info: ... okay on input line 13.
ABD: EverySelectfont initializing macros
LaTeX Info: Redefining \selectfont on input line 9.
LaTeX Info: Redefining \selectfont on input line 13.
Package fontspec Info: Adjusting the maths setup (use [no-math] to avoid
(fontspec) this).
\symlegacymaths=\mathgroup6
LaTeX Font Info: Overwriting symbol font `legacymaths' in version `bold'
(Font) OT1/cmr/m/n --> OT1/cmr/bx/n on input line 9.
LaTeX Font Info: Redeclaring math accent \acute on input line 9.
LaTeX Font Info: Redeclaring math accent \grave on input line 9.
LaTeX Font Info: Redeclaring math accent \ddot on input line 9.
LaTeX Font Info: Redeclaring math accent \tilde on input line 9.
LaTeX Font Info: Redeclaring math accent \bar on input line 9.
LaTeX Font Info: Redeclaring math accent \breve on input line 9.
LaTeX Font Info: Redeclaring math accent \check on input line 9.
LaTeX Font Info: Redeclaring math accent \hat on input line 9.
LaTeX Font Info: Redeclaring math accent \dot on input line 9.
LaTeX Font Info: Redeclaring math accent \mathring on input line 9.
LaTeX Font Info: Redeclaring math symbol \Gamma on input line 9.
LaTeX Font Info: Redeclaring math symbol \Delta on input line 9.
LaTeX Font Info: Redeclaring math symbol \Theta on input line 9.
LaTeX Font Info: Redeclaring math symbol \Lambda on input line 9.
LaTeX Font Info: Redeclaring math symbol \Xi on input line 9.
LaTeX Font Info: Redeclaring math symbol \Pi on input line 9.
LaTeX Font Info: Redeclaring math symbol \Sigma on input line 9.
LaTeX Font Info: Redeclaring math symbol \Upsilon on input line 9.
LaTeX Font Info: Redeclaring math symbol \Phi on input line 9.
LaTeX Font Info: Redeclaring math symbol \Psi on input line 9.
LaTeX Font Info: Redeclaring math symbol \Omega on input line 9.
LaTeX Font Info: Redeclaring math symbol \mathdollar on input line 9.
LaTeX Font Info: Redeclaring symbol font `operators' on input line 9.
(Font) OT1/cmr/m/n --> OT1/cmr/bx/n on input line 13.
LaTeX Font Info: Redeclaring math accent \acute on input line 13.
LaTeX Font Info: Redeclaring math accent \grave on input line 13.
LaTeX Font Info: Redeclaring math accent \ddot on input line 13.
LaTeX Font Info: Redeclaring math accent \tilde on input line 13.
LaTeX Font Info: Redeclaring math accent \bar on input line 13.
LaTeX Font Info: Redeclaring math accent \breve on input line 13.
LaTeX Font Info: Redeclaring math accent \check on input line 13.
LaTeX Font Info: Redeclaring math accent \hat on input line 13.
LaTeX Font Info: Redeclaring math accent \dot on input line 13.
LaTeX Font Info: Redeclaring math accent \mathring on input line 13.
LaTeX Font Info: Redeclaring math symbol \Gamma on input line 13.
LaTeX Font Info: Redeclaring math symbol \Delta on input line 13.
LaTeX Font Info: Redeclaring math symbol \Theta on input line 13.
LaTeX Font Info: Redeclaring math symbol \Lambda on input line 13.
LaTeX Font Info: Redeclaring math symbol \Xi on input line 13.
LaTeX Font Info: Redeclaring math symbol \Pi on input line 13.
LaTeX Font Info: Redeclaring math symbol \Sigma on input line 13.
LaTeX Font Info: Redeclaring math symbol \Upsilon on input line 13.
LaTeX Font Info: Redeclaring math symbol \Phi on input line 13.
LaTeX Font Info: Redeclaring math symbol \Psi on input line 13.
LaTeX Font Info: Redeclaring math symbol \Omega on input line 13.
LaTeX Font Info: Redeclaring math symbol \mathdollar on input line 13.
LaTeX Font Info: Redeclaring symbol font `operators' on input line 13.
LaTeX Font Info: Encoding `OT1' has changed to `TU' for symbol font
(Font) `operators' in the math version `normal' on input line 9.
(Font) `operators' in the math version `normal' on input line 13.
LaTeX Font Info: Overwriting symbol font `operators' in version `normal'
(Font) OT1/cmr/m/n --> TU/lmr/m/n on input line 9.
(Font) OT1/cmr/m/n --> TU/lmr/m/n on input line 13.
LaTeX Font Info: Encoding `OT1' has changed to `TU' for symbol font
(Font) `operators' in the math version `bold' on input line 9.
(Font) `operators' in the math version `bold' on input line 13.
LaTeX Font Info: Overwriting symbol font `operators' in version `bold'
(Font) OT1/cmr/bx/n --> TU/lmr/m/n on input line 9.
(Font) OT1/cmr/bx/n --> TU/lmr/m/n on input line 13.
LaTeX Font Info: Overwriting symbol font `operators' in version `normal'
(Font) TU/lmr/m/n --> TU/lmr/m/n on input line 9.
(Font) TU/lmr/m/n --> TU/lmr/m/n on input line 13.
LaTeX Font Info: Overwriting math alphabet `\mathit' in version `normal'
(Font) OT1/cmr/m/it --> TU/lmr/m/it on input line 9.
(Font) OT1/cmr/m/it --> TU/lmr/m/it on input line 13.
LaTeX Font Info: Overwriting math alphabet `\mathbf' in version `normal'
(Font) OT1/cmr/bx/n --> TU/lmr/b/n on input line 9.
(Font) OT1/cmr/bx/n --> TU/lmr/b/n on input line 13.
LaTeX Font Info: Overwriting math alphabet `\mathsf' in version `normal'
(Font) OT1/cmss/m/n --> TU/lmss/m/n on input line 9.
(Font) OT1/cmss/m/n --> TU/lmss/m/n on input line 13.
LaTeX Font Info: Overwriting math alphabet `\mathtt' in version `normal'
(Font) OT1/cmtt/m/n --> TU/lmtt/m/n on input line 9.
(Font) OT1/cmtt/m/n --> TU/lmtt/m/n on input line 13.
LaTeX Font Info: Overwriting symbol font `operators' in version `bold'
(Font) TU/lmr/m/n --> TU/lmr/b/n on input line 9.
(Font) TU/lmr/m/n --> TU/lmr/b/n on input line 13.
LaTeX Font Info: Overwriting math alphabet `\mathit' in version `bold'
(Font) OT1/cmr/bx/it --> TU/lmr/b/it on input line 9.
(Font) OT1/cmr/bx/it --> TU/lmr/b/it on input line 13.
LaTeX Font Info: Overwriting math alphabet `\mathsf' in version `bold'
(Font) OT1/cmss/bx/n --> TU/lmss/b/n on input line 9.
(Font) OT1/cmss/bx/n --> TU/lmss/b/n on input line 13.
LaTeX Font Info: Overwriting math alphabet `\mathtt' in version `bold'
(Font) OT1/cmtt/m/n --> TU/lmtt/b/n on input line 9.
LaTeX Font Info: Trying to load font information for U+msa on input line 20.
(Font) OT1/cmtt/m/n --> TU/lmtt/b/n on input line 13.
Package hyperref Info: Link coloring OFF on input line 13.
(/usr/local/texlive/2021/texmf-dist/tex/latex/hyperref/nameref.sty
Package: nameref 2021-04-02 v2.47 Cross-referencing by name of section
(/usr/local/texlive/2021/texmf-dist/tex/latex/refcount/refcount.sty
Package: refcount 2019/12/15 v3.6 Data extraction from label references (HO)
) (/usr/local/texlive/2021/texmf-dist/tex/generic/gettitlestring/gettitlestring.sty
Package: gettitlestring 2019/12/15 v1.6 Cleanup title references (HO)
)
\c@section@level=\count313
)
LaTeX Info: Redefining \ref on input line 13.
LaTeX Info: Redefining \pageref on input line 13.
LaTeX Info: Redefining \nameref on input line 13.
(./pseudocodes.out) (./pseudocodes.out)
\@outlinefile=\write3
\openout3 = `pseudocodes.out'.
(./pseudocodes.toc)
\tf@toc=\write4
\openout4 = `pseudocodes.toc'.
LaTeX Font Info: Font shape `TU/SongtiSCLight(0)/m/sl' in size <10.95> not available
(Font) Font shape `TU/SongtiSCLight(0)/m/it' tried instead on input line 16.
[1
]
Package hyperref Info: bookmark level for unknown algorithm defaults to 0 on input line 21.
[2
]
LaTeX Font Info: Trying to load font information for U+msa on input line 31.
(/usr/local/texlive/2021/texmf-dist/tex/latex/amsfonts/umsa.fd
File: umsa.fd 2013/01/14 v3.01 AMS symbols A
)
LaTeX Font Info: Trying to load font information for U+msb on input line 20.
LaTeX Font Info: Trying to load font information for U+msb on input line 31.
(/usr/local/texlive/2021/texmf-dist/tex/latex/amsfonts/umsb.fd
File: umsb.fd 2013/01/14 v3.01 AMS symbols B
)
Overfull \hbox (38.0069pt too wide) in paragraph at lines 32--33
[] []\TU/SongtiSCLight(0)/m/n/10.95 计 算 实 际 的 $\OML/cmm/m/it/10.95 Q$ \TU/SongtiSCLight(0)/m/n/10.95 值,| 即 $\OML/cmm/m/it/10.95 y[] \OT1/cmr/m/n/10.95 = []$
) [3
] [4
] [5
]
Underfull \hbox (badness 10000) in paragraph at lines 111--112
[] []\TU/SongtiSCLight(0)/m/n/10.95 计 算 实 际 的 $\OML/cmm/m/it/10.95 Q$ \TU/SongtiSCLight(0)/m/n/10.95 值,| 即 $\OML/cmm/m/it/10.95 y[] \OT1/cmr/m/n/10.95 =
[]
[1
[6
] [2
] [7
] (./pseudocodes.aux) )
]
Overfull \hbox (32.54117pt too wide) in paragraph at lines 183--183
[][]$[]\OML/cmm/m/it/9 J[]\OT1/cmr/m/n/9 (\OML/cmm/m/it/9 ^^R\OT1/cmr/m/n/9 ) = \OMS/cmsy/m/n/9 r[]\OML/cmm/m/it/9 Q[] [] []$|
[]
Overfull \hbox (15.41673pt too wide) in paragraph at lines 184--184
[][]$[]\OML/cmm/m/it/9 J[]\OT1/cmr/m/n/9 (\OML/cmm/m/it/9 ^^^\OT1/cmr/m/n/9 ) = \OMS/cmsy/m/n/9 r[]\OML/cmm/m/it/9 ^^K [] [] \OT1/cmr/m/n/9 + [] \OMS/cmsy/m/n/9 r[]\OML/cmm/m/it/9 f[] []$\TU/lmr/m/n/9 ,$[][] \OT1/cmr/m/n/9 =
[]
[8
] (./pseudocodes.aux)
Package rerunfilecheck Info: File `pseudocodes.out' has not changed.
(rerunfilecheck) Checksum: 4575BA7458AA23D6E696EFFE39D05727;640.
)
Here is how much of TeX's memory you used:
7847 strings out of 476919
208964 string characters out of 5821840
529246 words of memory out of 5000000
27739 multiletter control sequences out of 15000+600000
410995 words of font info for 73 fonts, out of 8000000 for 9000
14813 strings out of 476919
312635 string characters out of 5821840
653471 words of memory out of 5000000
34563 multiletter control sequences out of 15000+600000
413601 words of font info for 90 fonts, out of 8000000 for 9000
1348 hyphenation exceptions out of 8191
101i,11n,104p,414b,663s stack positions out of 5000i,500n,10000p,200000b,80000s
101i,13n,104p,676b,736s stack positions out of 5000i,500n,10000p,200000b,80000s
Output written on pseudocodes.pdf (2 pages).
Output written on pseudocodes.pdf (8 pages).

View File

@@ -0,0 +1,7 @@
\BOOKMARK [1][-]{section.1}{\376\377\152\041\162\110\131\007\165\050}{}% 1
\BOOKMARK [1][-]{section.2}{\376\377\000Q\000\040\000l\000e\000a\000r\000n\000i\000n\000g\173\227\154\325}{}% 2
\BOOKMARK [1][-]{section.3}{\376\377\000S\000a\000r\000s\000a\173\227\154\325}{}% 3
\BOOKMARK [1][-]{section.4}{\376\377\000P\000o\000l\000i\000c\000y\000\040\000G\000r\000a\000d\000i\000e\000n\000t\173\227\154\325}{}% 4
\BOOKMARK [1][-]{section.5}{\376\377\000D\000Q\000N\173\227\154\325}{}% 5
\BOOKMARK [1][-]{section.6}{\376\377\000S\000o\000f\000t\000Q\173\227\154\325}{}% 6
\BOOKMARK [1][-]{section.7}{\376\377\000S\000A\000C\173\227\154\325}{}% 7

View File

@@ -4,17 +4,96 @@
\usepackage{algorithmic}
\usepackage{amssymb}
\usepackage{amsmath}
\usepackage{hyperref}
% \usepackage[hidelinks]{hyperref} 去除超链接的红色框
\usepackage{setspace}
\usepackage{titlesec}
\usepackage{float} % 调用该包能够使用[H]
% \pagestyle{plain} % 去除页眉但保留页脚编号;若都要去掉,将plain换成empty
\begin{document}
\begin{algorithm}
\tableofcontents % 目录,注意要编译两次(或者vscode保存两次)才能显示
% \singlespacing
\clearpage
\section{模版备用}
\begin{algorithm}[H] % [H]固定位置
\floatname{algorithm}{{算法}}
\renewcommand{\thealgorithm}{} % 去掉算法标号
\caption{}
\begin{algorithmic}[1] % [1]显示步数
\STATE 测试
\end{algorithmic}
\end{algorithm}
\clearpage
\section{Q learning算法}
\begin{algorithm}[H] % [H]固定位置
\floatname{algorithm}{{Q-learning算法}\footnotemark[1]}
\renewcommand{\thealgorithm}{} % 去掉算法标号
\caption{}
\begin{algorithmic}[1] % [1]显示步数
\STATE 初始化Q表$Q(s,a)$为任意值,但其中$Q(s_{terminal},\cdot)=0$,即终止状态对应的Q值为0
\FOR {回合数 = $1,M$}
\STATE 重置环境,获得初始状态$s_1$
\FOR {时步 = $1,T$}
\STATE 根据$\varepsilon$-greedy策略采样动作$a_t$
\STATE 环境根据$a_t$反馈奖励$r_t$和下一个状态$s_{t+1}$
\STATE {\bfseries 更新策略:}
\STATE $Q(s_t,a_t) \leftarrow Q(s_t,a_t)+\alpha[r_t+\gamma\max _{a}Q(s_{t+1},a)-Q(s_t,a_t)]$
\STATE 更新状态$s_t \leftarrow s_{t+1}$
\ENDFOR
\ENDFOR
\end{algorithmic}
\end{algorithm}
\footnotetext[1]{Reinforcement Learning: An Introduction}
\clearpage
\section{Sarsa算法}
\begin{algorithm}[H] % [H]固定位置
\floatname{algorithm}{{Sarsa算法}\footnotemark[1]}
\renewcommand{\thealgorithm}{} % 去掉算法标号
\caption{}
\begin{algorithmic}[1] % [1]显示步数
\STATE 初始化Q表$Q(s,a)$为任意值,但其中$Q(s_{terminal},\cdot)=0$,即终止状态对应的Q值为0
\FOR {回合数 = $1,M$}
\STATE 重置环境,获得初始状态$s_1$
\STATE 根据$\varepsilon$-greedy策略采样初始动作$a_1$
\FOR {时步 = $1,T$}
\STATE 环境根据$a_t$反馈奖励$r_t$和下一个状态$s_{t+1}$
\STATE 根据$\varepsilon$-greedy策略和$s_{t+1}$采样动作$a_{t+1}$
\STATE {\bfseries 更新策略:}
\STATE $Q(s_t,a_t) \leftarrow Q(s_t,a_t)+\alpha[r_t+\gamma Q(s_{t+1},a_{t+1})-Q(s_t,a_t)]$
\STATE 更新状态$s_t \leftarrow s_{t+1}$
\STATE 更新动作$a_t \leftarrow a_{t+1}$
\ENDFOR
\ENDFOR
\end{algorithmic}
\end{algorithm}
\footnotetext[1]{Reinforcement Learning: An Introduction}
\clearpage
\section{Policy Gradient算法}
\begin{algorithm}[H] % [H]固定位置
\floatname{algorithm}{{REINFORCE算法Monte-Carlo Policy Gradient}\footnotemark[1]}
\renewcommand{\thealgorithm}{} % 去掉算法标号
\caption{}
\begin{algorithmic}[1] % [1]显示步数
\STATE 初始化策略参数$\boldsymbol{\theta} \in \mathbb{R}^{d^{\prime}}($ e.g., to $\mathbf{0})$
\FOR {回合数 = $1,M$}
\STATE 根据策略$\pi(\cdot \mid \cdot, \boldsymbol{\theta})$采样一个(或几个)回合的transition
\FOR {时步 = $1,T$}
\STATE 计算回报$G \leftarrow \sum_{k=t+1}^{T} \gamma^{k-t-1} R_{k}$
\STATE 更新策略$\boldsymbol{\theta} \leftarrow {\boldsymbol{\theta}+\alpha \gamma^{t}} G \nabla \ln \pi\left(A_{t} \mid S_{t}, \boldsymbol{\theta}\right)$
\ENDFOR
\ENDFOR
\end{algorithmic}
\end{algorithm}
\footnotetext[1]{Reinforcement Learning: An Introduction}
\clearpage
\section{DQN算法}
\begin{algorithm}[H] % [H]固定位置
\floatname{algorithm}{{DQN算法}}
\renewcommand{\thealgorithm}{} % 去掉算法标号
\caption{}
\renewcommand{\algorithmicrequire}{\textbf{输入:}}
\renewcommand{\algorithmicensure}{\textbf{输出:}}
\begin{algorithmic}
\begin{algorithmic}[1]
% \REQUIRE $n \geq 0 \vee x \neq 0$ % 输入
% \ENSURE $y = x^n$ % 输出
\STATE 初始化策略网络参数$\theta$ % 初始化
@@ -24,40 +103,85 @@
\STATE 重置环境,获得初始状态$s_t$
\FOR {时步 = $1,T$}
\STATE 根据$\varepsilon$-greedy策略采样动作$a_t$
\STATE 环境根据$a_t$反馈奖励$s_t$和下一个状态$s_{t+1}$
\STATE 环境根据$a_t$反馈奖励$r_t$和下一个状态$s_{t+1}$
\STATE 存储transition即$(s_t,a_t,r_t,s_{t+1})$到经验回放$D$
\STATE 更新环境状态$s_t \leftarrow s_{t+1}$
\STATE {\bfseries 更新策略:}
\STATE$D$中采样一个batch的transition
\STATE 计算实际的$Q$值,即$y_{j}= \begin{cases}r_{j} & \text {对于终止状态} s_{j+1} \\ r_{j}+\gamma \max _{a^{\prime}} Q\left(s_{j+1}, a^{\prime} ; \theta\right) & \text {对于非终止状态} s_{j+1}\end{cases}$
\STATE 对损失 $\left(y_{j}-Q\left(s_{j}, a_{j} ; \theta\right)\right)^{2}$关于参数$\theta$做随机梯度下降
\STATE$C$步复制参数$\hat{Q} \leftarrow Q$
\ENDFOR
\STATE$C$个回合复制参数$\hat{Q}\leftarrow Q$(此处也可像原论文中放到小循环中改成每$C$步,但没有每$C$个回合稳定)
\ENDFOR
\end{algorithmic}
\end{algorithm}
\clearpage
\begin{algorithm}
\section{SoftQ算法}
\begin{algorithm}[H]
\floatname{algorithm}{{SoftQ算法}}
\renewcommand{\thealgorithm}{} % 去掉算法标号
\caption{}
\begin{algorithmic}
\begin{algorithmic}[1]
\STATE 初始化参数$\theta$$\phi$% 初始化
\STATE 复制参数$\bar{\theta} \leftarrow \theta, \bar{\phi} \leftarrow \phi$
\STATE 初始化经验回放$D$
\FOR {回合数 = $1,M$}
\FOR {时步 = $1,T$}
\STATE 根据$a_{t} \leftarrow f^{\phi}\left(\xi ; \mathbf{s}_{t}\right)$采样动作,其中$\xi \sim \mathcal{N}(\mathbf{0}, \boldsymbol{I})$
\STATE 根据$\mathbf{a}_{t} \leftarrow f^{\phi}\left(\xi ; \mathbf{s}_{t}\right)$采样动作,其中$\xi \sim \mathcal{N}(\mathbf{0}, \boldsymbol{I})$
\STATE 环境根据$a_t$反馈奖励$r_t$和下一个状态$s_{t+1}$
\STATE 存储transition即$(s_t,a_t,r_t,s_{t+1})$到经验回放$D$
\STATE 更新环境状态$s_t \leftarrow s_{t+1}$
\STATE 待完善
\STATE {\bfseries 更新soft Q函数参数}
\STATE 对于每个$s^{(i)}_{t+1}$采样$\left\{\mathbf{a}^{(i, j)}\right\}_{j=0}^{M} \sim q_{\mathbf{a}^{\prime}}$
\STATE 计算empirical soft values $V_{\mathrm{soft}}^{\theta}\left(\mathbf{s}_{t}\right)$\footnotemark[1]
\STATE 计算empirical gradient $J_{Q}(\theta)$\footnotemark[2]
\STATE 根据$J_{Q}(\theta)$使用ADAM更新参数$\theta$
\STATE {\bfseries 更新策略:}
\STATE 对于每个$s^{(i)}_{t}$采样$\left\{\xi^{(i, j)}\right\}_{j=0}^{M} \sim \mathcal{N}(\mathbf{0}, \boldsymbol{I})$
\STATE 计算$\mathbf{a}_{t}^{(i, j)}=f^{\phi}\left(\xi^{(i, j)}, \mathbf{s}_{t}^{(i)}\right)$
\STATE 使用经验估计计算$\Delta f^{\phi}\left(\cdot ; \mathbf{s}_{t}\right)$\footnotemark[3]
\STATE 计算经验估计$\frac{\partial J_{\pi}\left(\phi ; \mathbf{s}_{t}\right)}{\partial \phi} \propto \mathbb{E}_{\xi}\left[\Delta f^{\phi}\left(\xi ; \mathbf{s}_{t}\right) \frac{\partial f^{\phi}\left(\xi ; \mathbf{s}_{t}\right)}{\partial \phi}\right]$,即$\hat{\nabla}_{\phi} J_{\pi}$
\STATE 根据$\hat{\nabla}_{\phi} J_{\pi}$使用ADAM更新参数$\phi$
\STATE
\ENDFOR
\ENDFOR
\STATE$C$个回合复制参数$\bar{\theta} \leftarrow \theta, \bar{\phi} \leftarrow \phi$
\ENDFOR
\end{algorithmic}
\end{algorithm}
\footnotetext[1]{$V_{\mathrm{soft}}^{\theta}\left(\mathbf{s}_{t}\right)=\alpha \log \mathbb{E}_{q_{\mathbf{a}^{\prime}}}\left[\frac{\exp \left(\frac{1}{\alpha} Q_{\mathrm{soft}}^{\theta}\left(\mathbf{s}_{t}, \mathbf{a}^{\prime}\right)\right)}{q_{\mathbf{a}^{\prime}}\left(\mathbf{a}^{\prime}\right)}\right]$}
\footnotetext[2]{$J_{Q}(\theta)=\mathbb{E}_{\mathbf{s}_{t} \sim q_{\mathbf{s}_{t}}, \mathbf{a}_{t} \sim q_{\mathbf{a}_{t}}}\left[\frac{1}{2}\left(\hat{Q}_{\mathrm{soft}}^{\bar{\theta}}\left(\mathbf{s}_{t}, \mathbf{a}_{t}\right)-Q_{\mathrm{soft}}^{\theta}\left(\mathbf{s}_{t}, \mathbf{a}_{t}\right)\right)^{2}\right]$}
\footnotetext[3]{$\begin{aligned} \Delta f^{\phi}\left(\cdot ; \mathbf{s}_{t}\right)=& \mathbb{E}_{\mathbf{a}_{t} \sim \pi^{\phi}}\left[\left.\kappa\left(\mathbf{a}_{t}, f^{\phi}\left(\cdot ; \mathbf{s}_{t}\right)\right) \nabla_{\mathbf{a}^{\prime}} Q_{\mathrm{soft}}^{\theta}\left(\mathbf{s}_{t}, \mathbf{a}^{\prime}\right)\right|_{\mathbf{a}^{\prime}=\mathbf{a}_{t}}\right.\\ &\left.+\left.\alpha \nabla_{\mathbf{a}^{\prime}} \kappa\left(\mathbf{a}^{\prime}, f^{\phi}\left(\cdot ; \mathbf{s}_{t}\right)\right)\right|_{\mathbf{a}^{\prime}=\mathbf{a}_{t}}\right] \end{aligned}$}
\clearpage
\section{SAC算法}
\begin{algorithm}[H] % [H]固定位置
\floatname{algorithm}{{Soft Actor Critic算法}}
\renewcommand{\thealgorithm}{} % 去掉算法标号
\caption{}
\begin{algorithmic}[1]
\STATE 初始化两个Critic(即Q网络)的参数$\theta_1,\theta_2$以及一个Actor(即策略网络)的参数$\phi$ % 初始化
\STATE 复制参数到目标网络$\bar{\theta}_1 \leftarrow \theta_1,\bar{\theta}_2 \leftarrow \theta_2$
\STATE 初始化经验回放$D$
\FOR {回合数 = $1,M$}
\STATE 重置环境,获得初始状态$s_t$
\FOR {时步 = $1,T$}
\STATE 根据$\boldsymbol{a}_{t} \sim \pi_{\phi}\left(\boldsymbol{a}_{t} \mid \mathbf{s}_{t}\right)$采样动作$a_t$
\STATE 环境反馈奖励和下一个状态,$\mathbf{s}_{t+1} \sim p\left(\mathbf{s}_{t+1} \mid \mathbf{s}_{t}, \mathbf{a}_{t}\right)$
\STATE 存储transition到经验回放中$\mathcal{D} \leftarrow \mathcal{D} \cup\left\{\left(\mathbf{s}_{t}, \mathbf{a}_{t}, r\left(\mathbf{s}_{t}, \mathbf{a}_{t}\right), \mathbf{s}_{t+1}\right)\right\}$
\STATE 更新环境状态$s_t \leftarrow s_{t+1}$
\STATE {\bfseries 更新策略:}
\STATE 更新$Q$函数,$\theta_{i} \leftarrow \theta_{i}-\lambda_{Q} \hat{\nabla}_{\theta_{i}} J_{Q}\left(\theta_{i}\right)$ for $i \in\{1,2\}$\footnotemark[1]\footnotemark[2]
\STATE 更新策略权重,$\phi \leftarrow \phi-\lambda_{\pi} \hat{\nabla}_{\phi} J_{\pi}(\phi)$ \footnotemark[3]
\STATE 调整temperature$\alpha \leftarrow \alpha-\lambda \hat{\nabla}_{\alpha} J(\alpha)$ \footnotemark[4]
\STATE 更新目标网络权重,$\bar{\theta}_{i} \leftarrow \tau \theta_{i}+(1-\tau) \bar{\theta}_{i}$ for $i \in\{1,2\}$
\ENDFOR
\ENDFOR
\end{algorithmic}
\end{algorithm}
\footnotetext[1]{$J_{Q}(\theta)=\mathbb{E}_{\left(\mathbf{s}_{t}, \mathbf{a}_{t}\right) \sim \mathcal{D}}\left[\frac{1}{2}\left(Q_{\theta}\left(\mathbf{s}_{t}, \mathbf{a}_{t}\right)-\left(r\left(\mathbf{s}_{t}, \mathbf{a}_{t}\right)+\gamma \mathbb{E}_{\mathbf{s}_{t+1} \sim p}\left[V_{\bar{\theta}}\left(\mathbf{s}_{t+1}\right)\right]\right)\right)^{2}\right]$}
\footnotetext[2]{$\hat{\nabla}_{\theta} J_{Q}(\theta)=\nabla_{\theta} Q_{\theta}\left(\mathbf{a}_{t}, \mathbf{s}_{t}\right)\left(Q_{\theta}\left(\mathbf{s}_{t}, \mathbf{a}_{t}\right)-\left(r\left(\mathbf{s}_{t}, \mathbf{a}_{t}\right)+\gamma\left(Q_{\bar{\theta}}\left(\mathbf{s}_{t+1}, \mathbf{a}_{t+1}\right)-\alpha \log \left(\pi_{\phi}\left(\mathbf{a}_{t+1} \mid \mathbf{s}_{t+1}\right)\right)\right)\right)\right.$}
\footnotetext[3]{$\hat{\nabla}_{\phi} J_{\pi}(\phi)=\nabla_{\phi} \alpha \log \left(\pi_{\phi}\left(\mathbf{a}_{t} \mid \mathbf{s}_{t}\right)\right)+\left(\nabla_{\mathbf{a}_{t}} \alpha \log \left(\pi_{\phi}\left(\mathbf{a}_{t} \mid \mathbf{s}_{t}\right)\right)-\nabla_{\mathbf{a}_{t}} Q\left(\mathbf{s}_{t}, \mathbf{a}_{t}\right)\right) \nabla_{\phi} f_{\phi}\left(\epsilon_{t} ; \mathbf{s}_{t}\right)$,$\mathbf{a}_{t}=f_{\phi}\left(\epsilon_{t} ; \mathbf{s}_{t}\right)$}
\footnotetext[4]{$J(\alpha)=\mathbb{E}_{\mathbf{a}_{t} \sim \pi_{t}}\left[-\alpha \log \pi_{t}\left(\mathbf{a}_{t} \mid \mathbf{s}_{t}\right)-\alpha \overline{\mathcal{H}}\right]$}
\clearpage
\end{document}

View File

@@ -0,0 +1,7 @@
\contentsline {section}{\numberline {1}模版备用}{2}{section.1}%
\contentsline {section}{\numberline {2}Q learning算法}{3}{section.2}%
\contentsline {section}{\numberline {3}Sarsa算法}{4}{section.3}%
\contentsline {section}{\numberline {4}Policy Gradient算法}{5}{section.4}%
\contentsline {section}{\numberline {5}DQN算法}{6}{section.5}%
\contentsline {section}{\numberline {6}SoftQ算法}{7}{section.6}%
\contentsline {section}{\numberline {7}SAC算法}{8}{section.7}%

View File

@@ -1,218 +0,0 @@
# DQN
## 原理简介
DQN是Q-learning算法的优化和延伸。Q-learning中使用有限的Q表存储值的信息,而DQN中则用神经网络替代Q表存储信息,这样更适用于高维的情况,相关知识基础可参考[datawhale李宏毅笔记-Q学习](https://datawhalechina.github.io/easy-rl/#/chapter6/chapter6)。
论文方面主要可以参考两篇:一篇是2013年谷歌DeepMind团队的[Playing Atari with Deep Reinforcement Learning](https://www.cs.toronto.edu/~vmnih/docs/dqn.pdf),另一篇是他们团队后来在Nature杂志上发表的[Human-level control through deep reinforcement learning](https://web.stanford.edu/class/psych209/Readings/MnihEtAlHassibis15NatureControlDeepRL.pdf)。后者在算法层面增加了target q-net,也可以叫做Nature DQN。
Nature DQN使用了两个Q网络:一个当前Q网络𝑄用来选择动作、更新模型参数,另一个目标Q网络𝑄′用于计算目标Q值。目标Q网络的网络参数不需要迭代更新,而是每隔一段时间从当前Q网络𝑄复制过来,即延时更新,这样可以减少目标Q值和当前Q值的相关性。
要注意的是两个Q网络的结构是一模一样的。这样才可以复制网络参数。Nature DQN和[Playing Atari with Deep Reinforcement Learning](https://www.cs.toronto.edu/~vmnih/docs/dqn.pdf)相比除了用一个新的相同结构的目标Q网络来计算目标Q值以外其余部分基本是完全相同的。细节也可参考[强化学习Deep Q-Learning进阶之Nature DQN](https://www.cnblogs.com/pinard/p/9756075.html)。
另可参考:https://blog.csdn.net/JohnJim0/article/details/109557173
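为便于理解"延时复制参数"这一点,下面给出一个极简的PyTorch示意(网络结构、变量名与复制间隔```C```均为随意假设,并非本项目源码):
```python
import torch.nn as nn

policy_net = nn.Linear(4, 2)   # 当前Q网络(这里用单层线性网络代替)
target_net = nn.Linear(4, 2)   # 目标Q网络,结构必须与当前网络完全相同,才能直接复制参数
C = 4                          # 每隔C个回合同步一次(假设值)
for i_episode in range(20):
    # ……正常的采样与梯度更新只作用在policy_net上……
    if (i_episode + 1) % C == 0:
        target_net.load_state_dict(policy_net.state_dict())  # 延时更新:整体复制参数
```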
## 伪代码
<img src="assets/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L0pvaG5KaW0w,size_16,color_FFFFFF,t_70.png" alt="img" style="zoom:50%;" />
## 代码实现
### RL接口
首先是强化学习训练的基本接口,即通用的训练模式:
```python
for i_episode in range(MAX_EPISODES):
    state = env.reset()  # reset环境状态
    for i_step in range(MAX_STEPS):
        action = agent.choose_action(state)  # 根据当前环境state选择action
        next_state, reward, done, _ = env.step(action)  # 更新环境参数
        agent.memory.push(state, action, reward, next_state, done)  # 将state等这些transition存入memory
        agent.update()  # 每步更新网络
        state = next_state  # 跳转到下一个状态
        if done:
            break
```
每个episode加一个MAX_STEPS,也可以使用while not done。加这个max_steps,有时是因为比如gym环境的训练目标就是在200个step内达到200的reward,或者当完成一个episode需要的步数较多时也可以设置。基本流程跟所有伪代码一致,如下:
1. agent选择动作;
2. 环境根据agent的动作反馈出next_state和reward;
3. agent进行更新,如有memory就会将transition(包含state、reward、action等)存入memory中;
4. 跳转到下一个状态;
5. 如果done了就跳出循环,进行下一个episode的训练。
想要实现完整的算法,还需要创建Qnet、ReplayBuffer等类:
### 两个Q网络
上文讲了Nature DQN中有两个Q网络,一个是policy_net,一个是延时更新的target_net,两个网络的结构是一模一样的,如下(见```model.py```),注意DQN使用的Qnet就是全连接网络,即FCN:
```python
import torch.nn as nn
import torch.nn.functional as F

class FCN(nn.Module):
    def __init__(self, n_states=4, n_actions=18):
        """ 初始化q网络,为全连接网络
            n_states: 输入的feature即环境的state数目
            n_actions: 输出的action总个数
        """
        super(FCN, self).__init__()
        self.fc1 = nn.Linear(n_states, 128)   # 输入层
        self.fc2 = nn.Linear(128, 128)        # 隐藏层
        self.fc3 = nn.Linear(128, n_actions)  # 输出层

    def forward(self, x):
        # 各层对应的激活函数
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)
```
输入为n_states,输出为n_actions,中间包含一个128维的隐藏层,这里根据需要可增加隐藏层的维度和数量,然后一般使用relu激活函数,这里跟深度学习的网络设置是一样的。
### Replay Buffer
然后就是Replay Memory了,其作用主要是克服经验数据的相关性(correlated data)和非平稳分布(non-stationary distribution)问题,实现如下(见```memory.py```):
```python
import random
import numpy as np

class ReplayBuffer:
    def __init__(self, capacity):
        self.capacity = capacity
        self.buffer = []
        self.position = 0

    def push(self, state, action, reward, next_state, done):
        if len(self.buffer) < self.capacity:
            self.buffer.append(None)
        self.buffer[self.position] = (state, action, reward, next_state, done)
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        batch = random.sample(self.buffer, batch_size)
        state, action, reward, next_state, done = zip(*batch)
        return state, action, reward, next_state, done

    def __len__(self):
        return len(self.buffer)
```
参数capacity表示buffer的容量,主要包括push和sample两个方法:push是将transition放到memory中,sample是从memory中随机抽取一个batch的transition,用法示意见下。
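下面是该```ReplayBuffer```的一个简单用法示意(假设使用上面定义的类,数据为随意构造,仅为演示push与sample的调用方式):
```python
buffer = ReplayBuffer(capacity=1000)
for i in range(64):  # 随意构造64条transition
    buffer.push(state=[0.1, 0.2, 0.3, 0.4], action=0, reward=1.0,
                next_state=[0.1, 0.2, 0.3, 0.5], done=False)
states, actions, rewards, next_states, dones = buffer.sample(batch_size=32)
print(len(buffer), len(states))  # 分别为64和32
```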
### Agent类
在```agent.py```中我们定义强化学习算法类,包括```choose_action```(选择动作,使用e-greedy策略时会多一个```predict```函数,下面会讲到)和```update```(更新)等函数。
在类中建立两个网络以及optimizer和memory:
```python
self.policy_net = MLP(n_states, n_actions, hidden_dim=cfg.hidden_dim).to(self.device)
self.target_net = MLP(n_states, n_actions, hidden_dim=cfg.hidden_dim).to(self.device)
for target_param, param in zip(self.target_net.parameters(), self.policy_net.parameters()):  # copy params from policy net
    target_param.data.copy_(param.data)
self.optimizer = optim.Adam(self.policy_net.parameters(), lr=cfg.lr)
self.memory = ReplayBuffer(cfg.memory_capacity)
```
然后是选择action
```python
def choose_action(self, state):
    '''选择动作
    '''
    self.frame_idx += 1
    if random.random() > self.epsilon(self.frame_idx):
        action = self.predict(state)
    else:
        action = random.randrange(self.n_actions)
    return action
```
这里使用e-greedy策略,即设置一个参数epsilon,如果生成的随机数大于epsilon,就根据网络预测选择action,否则随机选择action。这个epsilon是会逐渐减小的,可以使用线性或者指数减小的方式(示意见下),但不会减小到零,这样在训练稳定时还能保持一定的探索,这部分可以学习探索与利用(exploration and exploitation)相关知识。
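指数衰减的一种常见写法示意如下(三个参数取本项目默认配置中的0.95、0.01、500,frame_idx的取值为随意示例):
```python
import math

epsilon_start, epsilon_end, epsilon_decay = 0.95, 0.01, 500
for frame_idx in (0, 500, 2000):  # 采样次数越多,epsilon越接近epsilon_end但不为零
    epsilon = epsilon_end + (epsilon_start - epsilon_end) * math.exp(-frame_idx / epsilon_decay)
    print(frame_idx, round(epsilon, 3))
```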
上面讲到的预测函数,其实就是根据state选取Q值最大的action,如下:
```python
def predict(self, state):
    with torch.no_grad():
        state = torch.tensor([state], device=self.device, dtype=torch.float32)
        q_values = self.policy_net(state)
        action = q_values.max(1)[1].item()  # 选择Q值最大的动作
    return action
```
然后是更新函数了:
```python
def update(self):
    if len(self.memory) < self.batch_size:
        return
    # 从memory中随机采样transition
    state_batch, action_batch, reward_batch, next_state_batch, done_batch = self.memory.sample(
        self.batch_size)
    '''转为张量
    例如tensor([[-4.5543e-02, -2.3910e-01, 1.8344e-02, 2.3158e-01],...,[-1.8615e-02, -2.3921e-01, -1.1791e-02, 2.3400e-01]])'''
    state_batch = torch.tensor(
        state_batch, device=self.device, dtype=torch.float)
    action_batch = torch.tensor(action_batch, device=self.device).unsqueeze(
        1)  # 例如tensor([[1],...,[0]])
    reward_batch = torch.tensor(
        reward_batch, device=self.device, dtype=torch.float)  # tensor([1., 1.,...,1])
    next_state_batch = torch.tensor(
        next_state_batch, device=self.device, dtype=torch.float)
    done_batch = torch.tensor(np.float32(
        done_batch), device=self.device)
    '''计算当前(s_t,a)对应的Q(s_t, a)'''
    '''torch.gather:对于a=torch.Tensor([[1,2],[3,4]]),那么a.gather(1,torch.Tensor([[0],[1]]))=torch.Tensor([[1],[3]])'''
    q_values = self.policy_net(state_batch).gather(
        dim=1, index=action_batch)  # 等价于self.forward
    # 计算所有next states的V(s_{t+1}),即通过target_net计算下一状态下最大的Q值
    next_q_values = self.target_net(next_state_batch).max(
        1)[0].detach()  # 比如tensor([ 0.0060, -0.0171,...,])
    # 计算 expected_q_value
    # 对于终止状态,此时done_batch[0]=1, 对应的expected_q_value等于reward
    expected_q_values = reward_batch + \
        self.gamma * next_q_values * (1-done_batch)
    # self.loss = F.smooth_l1_loss(q_values, expected_q_values.unsqueeze(1))  # 计算 Huber loss
    loss = nn.MSELoss()(q_values, expected_q_values.unsqueeze(1))  # 计算 均方误差loss
    # 优化模型
    self.optimizer.zero_grad()  # zero_grad清除上一步所有旧的gradients from the last step
    # loss.backward()使用backpropagation计算loss相对于所有parameters(需要gradients)的微分
    loss.backward()
    # for param in self.policy_net.parameters():  # clip防止梯度爆炸
    #     param.grad.data.clamp_(-1, 1)
    self.optimizer.step()  # 更新模型
```
更新遵循伪代码的以下部分:
<img src="assets/image-20210507162813393.png" alt="image-20210507162813393" style="zoom:50%;" />
首先从replay buffer中选取一个batch的数据,计算loss,然后进行minibatch SGD。
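更新所依据的目标值与损失和上文伪代码一致(其中$\hat{Q}$表示延时更新的目标网络),即:

$$
y_{j}=\begin{cases} r_{j} & \text{对于终止状态 } s_{j+1} \\ r_{j}+\gamma \max _{a^{\prime}} \hat{Q}\left(s_{j+1}, a^{\prime} ; \theta\right) & \text{对于非终止状态 } s_{j+1}\end{cases},\qquad L(\theta)=\left(y_{j}-Q\left(s_{j}, a_{j} ; \theta\right)\right)^{2}
$$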
然后是保存与加载模型的部分,如下:
```python
def save(self, path):
    torch.save(self.target_net.state_dict(), path+'dqn_checkpoint.pth')

def load(self, path):
    self.target_net.load_state_dict(torch.load(path+'dqn_checkpoint.pth'))
    for target_param, param in zip(self.target_net.parameters(), self.policy_net.parameters()):
        param.data.copy_(target_param.data)
```
### 实验结果
训练结果如下:
<img src="assets/train_rewards_curve.png" alt="train_rewards_curve" style="zoom: 67%;" />
<img src="assets/eval_rewards_curve.png" alt="eval_rewards_curve" style="zoom:67%;" />
## 参考
[with torch.no_grad()](https://www.jianshu.com/p/1cea017f5d11)

Binary file not shown.

Before

Width:  |  Height:  |  Size: 36 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 76 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 58 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 37 KiB

View File

@@ -5,7 +5,7 @@
@Email: johnjim0816@gmail.com
@Date: 2020-06-12 00:50:49
@LastEditor: John
LastEditTime: 2022-08-11 09:52:23
LastEditTime: 2022-08-18 14:27:18
@Discription:
@Environment: python 3.7.7
'''
@@ -23,10 +23,10 @@ class DQN:
def __init__(self,n_actions,model,memory,cfg):
self.n_actions = n_actions
self.device = torch.device(cfg.device) # cpu or cuda
self.gamma = cfg.gamma # 奖励的折扣因子
# e-greedy策略相关参数
self.sample_count = 0 # 用于epsilon的衰减计数
self.device = torch.device(cfg.device)
self.gamma = cfg.gamma
## e-greedy parameters
self.sample_count = 0 # sample count for epsilon decay
self.epsilon = cfg.epsilon_start
self.sample_count = 0
self.epsilon_start = cfg.epsilon_start
@@ -35,61 +35,78 @@ class DQN:
self.batch_size = cfg.batch_size
self.policy_net = model.to(self.device)
self.target_net = model.to(self.device)
for target_param, param in zip(self.target_net.parameters(),self.policy_net.parameters()): # 复制参数到目标网路targe_net
## copy parameters from policy net to target net
for target_param, param in zip(self.target_net.parameters(),self.policy_net.parameters()):
target_param.data.copy_(param.data)
self.optimizer = optim.Adam(self.policy_net.parameters(), lr=cfg.lr) # 优化器
self.memory = memory # 经验回放
# self.target_net.load_state_dict(self.policy_net.state_dict()) # or use this to copy parameters
self.optimizer = optim.Adam(self.policy_net.parameters(), lr=cfg.lr)
self.memory = memory
self.update_flag = False
def sample(self, state):
''' 选择动作
def sample_action(self, state):
''' sample action with e-greedy policy
'''
self.sample_count += 1
# epsilon must decay(linear,exponential and etc.) for balancing exploration and exploitation
self.epsilon = self.epsilon_end + (self.epsilon_start - self.epsilon_end) * \
math.exp(-1. * self.sample_count / self.epsilon_decay) # epsilon是会递减的这里选择指数递减
math.exp(-1. * self.sample_count / self.epsilon_decay)
if random.random() > self.epsilon:
with torch.no_grad():
state = torch.tensor(state, device=self.device, dtype=torch.float32).unsqueeze(dim=0)
q_values = self.policy_net(state)
action = q_values.max(1)[1].item() # 选择Q值最大的动作
action = q_values.max(1)[1].item() # choose action corresponding to the maximum q value
else:
action = random.randrange(self.n_actions)
return action
def predict(self,state):
def predict_action(self,state):
''' predict action
'''
with torch.no_grad():
state = torch.tensor(state, device=self.device, dtype=torch.float32).unsqueeze(dim=0)
q_values = self.policy_net(state)
action = q_values.max(1)[1].item() # 选择Q值最大的动作
action = q_values.max(1)[1].item() # choose action corresponding to the maximum q value
return action
def update(self):
if len(self.memory) < self.batch_size: # 当memory中不满足一个批量时不更新策略
if len(self.memory) < self.batch_size: # when transitions in memory do not fill a batch, do not update
return
# 从经验回放中(replay memory)中随机采样一个批量的转移(transition)
else:
if not self.update_flag:
print("begin to update!")
self.update_flag = True
# sample a batch of transitions from replay buffer
state_batch, action_batch, reward_batch, next_state_batch, done_batch = self.memory.sample(
self.batch_size)
state_batch = torch.tensor(np.array(state_batch), device=self.device, dtype=torch.float)
action_batch = torch.tensor(action_batch, device=self.device).unsqueeze(1)
reward_batch = torch.tensor(reward_batch, device=self.device, dtype=torch.float)
next_state_batch = torch.tensor(np.array(next_state_batch), device=self.device, dtype=torch.float)
done_batch = torch.tensor(np.float32(done_batch), device=self.device)
q_values = self.policy_net(state_batch).gather(dim=1, index=action_batch) # 计算当前状态(s_t,a)对应的Q(s_t, a)
next_q_values = self.target_net(next_state_batch).max(1)[0].detach() # 计算下一时刻的状态(s_t_,a)对应的Q值
# 计算期望的Q值对于终止状态此时done_batch[0]=1, 对应的expected_q_value等于reward
expected_q_values = reward_batch + self.gamma * next_q_values * (1-done_batch)
loss = nn.MSELoss()(q_values, expected_q_values.unsqueeze(1)) # 计算均方根损失
# 优化更新模型
state_batch = torch.tensor(np.array(state_batch), device=self.device, dtype=torch.float) # shape(batchsize,n_states)
action_batch = torch.tensor(action_batch, device=self.device).unsqueeze(1) # shape(batchsize,1)
reward_batch = torch.tensor(reward_batch, device=self.device, dtype=torch.float).unsqueeze(1) # shape(batchsize)
next_state_batch = torch.tensor(np.array(next_state_batch), device=self.device, dtype=torch.float) # shape(batchsize,n_states)
done_batch = torch.tensor(np.float32(done_batch), device=self.device).unsqueeze(1) # shape(batchsize,1)
# print(state_batch.shape,action_batch.shape,reward_batch.shape,next_state_batch.shape,done_batch.shape)
# compute current Q(s_t,a_t), i.e. Q(s_j, a_j; θ) in the pseudocode
q_value_batch = self.policy_net(state_batch).gather(dim=1, index=action_batch) # shape(batchsize,1),requires_grad=True
# print(q_values.requires_grad)
# compute max Q(s_t+1, a') over actions a'; next_max_q_value comes from the target net and is treated as a constant in the update formula below, thus should detach to requires_grad=False
next_max_q_value_batch = self.target_net(next_state_batch).max(1)[0].detach().unsqueeze(1)
# print(q_values.shape,next_q_values.shape)
# compute expected q value; for terminal states, done_batch=1 and expected_q_value equals the reward correspondingly
expected_q_value_batch = reward_batch + self.gamma * next_max_q_value_batch* (1-done_batch)
# print(expected_q_value_batch.shape,expected_q_value_batch.requires_grad)
loss = nn.MSELoss()(q_value_batch, expected_q_value_batch) # both shapes are (batchsize,1)
# backpropagation
self.optimizer.zero_grad()
loss.backward()
for param in self.policy_net.parameters(): # clip防止梯度爆炸
# clip to avoid gradient explosion
for param in self.policy_net.parameters():
param.grad.data.clamp_(-1, 1)
self.optimizer.step()
def save(self, path):
def save_model(self, path):
from pathlib import Path
# create path
Path(path).mkdir(parents=True, exist_ok=True)
torch.save(self.target_net.state_dict(), path+'checkpoint.pth')
torch.save(self.target_net.state_dict(), f"{path}/checkpoint.pt")
def load(self, path):
self.target_net.load_state_dict(torch.load(path+'checkpoint.pth'))
def load_model(self, path):
self.target_net.load_state_dict(torch.load(f"{path}/checkpoint.pt"))
for target_param, param in zip(self.target_net.parameters(), self.policy_net.parameters()):
param.data.copy_(target_param.data)

View File

@@ -1 +0,0 @@
{"algo_name": "DQN", "env_name": "CartPole-v0", "train_eps": 200, "test_eps": 20, "gamma": 0.95, "epsilon_start": 0.95, "epsilon_end": 0.01, "epsilon_decay": 500, "lr": 0.0001, "memory_capacity": 100000, "batch_size": 64, "target_update": 4, "hidden_dim": 256, "device": "cpu", "result_path": "/Users/jj/Desktop/rl-tutorials/codes/DQN/outputs/CartPole-v0/20220815-185119/results/", "model_path": "/Users/jj/Desktop/rl-tutorials/codes/DQN/outputs/CartPole-v0/20220815-185119/models/", "show_fig": false, "save_fig": true}

Binary file not shown.

Before

Width:  |  Height:  |  Size: 44 KiB

View File

@@ -0,0 +1 @@
{"algo_name": "DQN", "env_name": "CartPole-v0", "train_eps": 200, "test_eps": 20, "gamma": 0.95, "epsilon_start": 0.95, "epsilon_end": 0.01, "epsilon_decay": 500, "lr": 0.0001, "memory_capacity": 100000, "batch_size": 64, "target_update": 4, "hidden_dim": 256, "device": "cpu", "seed": 10, "result_path": "/Users/jj/Desktop/rl-tutorials/codes/DQN/outputs/CartPole-v0/20220818-143132/results", "model_path": "/Users/jj/Desktop/rl-tutorials/codes/DQN/outputs/CartPole-v0/20220818-143132/models", "show_fig": false, "save_fig": true}

View File

@@ -0,0 +1,21 @@
episodes,rewards
0,200.0
1,200.0
2,200.0
3,200.0
4,200.0
5,200.0
6,200.0
7,200.0
8,200.0
9,200.0
10,200.0
11,200.0
12,200.0
13,200.0
14,200.0
15,200.0
16,200.0
17,200.0
18,200.0
19,200.0

Binary file not shown.

After

Width:  |  Height:  |  Size: 38 KiB

View File

@@ -0,0 +1,201 @@
episodes,rewards
0,38.0
1,16.0
2,37.0
3,15.0
4,22.0
5,34.0
6,20.0
7,12.0
8,16.0
9,14.0
10,13.0
11,21.0
12,14.0
13,12.0
14,17.0
15,12.0
16,10.0
17,14.0
18,10.0
19,10.0
20,16.0
21,9.0
22,14.0
23,13.0
24,10.0
25,9.0
26,12.0
27,12.0
28,14.0
29,11.0
30,9.0
31,8.0
32,9.0
33,11.0
34,12.0
35,10.0
36,11.0
37,10.0
38,10.0
39,18.0
40,13.0
41,15.0
42,10.0
43,9.0
44,14.0
45,14.0
46,23.0
47,17.0
48,15.0
49,15.0
50,20.0
51,28.0
52,36.0
53,36.0
54,23.0
55,27.0
56,53.0
57,19.0
58,35.0
59,62.0
60,57.0
61,38.0
62,61.0
63,65.0
64,58.0
65,43.0
66,67.0
67,56.0
68,91.0
69,128.0
70,71.0
71,126.0
72,100.0
73,200.0
74,200.0
75,200.0
76,200.0
77,200.0
78,200.0
79,200.0
80,200.0
81,200.0
82,200.0
83,200.0
84,200.0
85,200.0
86,200.0
87,200.0
88,200.0
89,200.0
90,200.0
91,200.0
92,200.0
93,200.0
94,200.0
95,200.0
96,200.0
97,200.0
98,200.0
99,200.0
100,200.0
101,200.0
102,200.0
103,200.0
104,200.0
105,200.0
106,200.0
107,200.0
108,200.0
109,200.0
110,200.0
111,200.0
112,200.0
113,200.0
114,200.0
115,200.0
116,200.0
117,200.0
118,200.0
119,200.0
120,200.0
121,200.0
122,200.0
123,200.0
124,200.0
125,200.0
126,200.0
127,200.0
128,200.0
129,200.0
130,200.0
131,200.0
132,200.0
133,200.0
134,200.0
135,200.0
136,200.0
137,200.0
138,200.0
139,200.0
140,200.0
141,200.0
142,200.0
143,200.0
144,200.0
145,200.0
146,200.0
147,200.0
148,200.0
149,200.0
150,200.0
151,200.0
152,200.0
153,200.0
154,200.0
155,200.0
156,200.0
157,200.0
158,200.0
159,200.0
160,200.0
161,200.0
162,200.0
163,200.0
164,200.0
165,200.0
166,200.0
167,200.0
168,200.0
169,200.0
170,200.0
171,200.0
172,200.0
173,200.0
174,200.0
175,200.0
176,200.0
177,200.0
178,200.0
179,200.0
180,200.0
181,200.0
182,200.0
183,200.0
184,200.0
185,200.0
186,200.0
187,200.0
188,200.0
189,200.0
190,200.0
191,200.0
192,200.0
193,200.0
194,200.0
195,200.0
196,200.0
197,200.0
198,200.0
199,200.0

View File

@@ -1,23 +1,23 @@
import sys,os
curr_path = os.path.dirname(os.path.abspath(__file__)) # 当前文件所在绝对路径
parent_path = os.path.dirname(curr_path) # 父路径
sys.path.append(parent_path) # 添加路径到系统路径
curr_path = os.path.dirname(os.path.abspath(__file__)) # current path
parent_path = os.path.dirname(curr_path) # parent path
sys.path.append(parent_path) # add path to system path
import gym
import torch
import datetime
import numpy as np
import argparse
from common.utils import save_results
from common.utils import save_results,all_seed
from common.utils import plot_rewards,save_args
from common.models import MLP
from common.memories import ReplayBuffer
from dqn import DQN
def get_args():
""" 超参数
""" hyperparameters
"""
curr_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") # 获取当前时间
curr_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") # obtain current time
parser = argparse.ArgumentParser(description="hyperparameters")
parser.add_argument('--algo_name',default='DQN',type=str,help="name of algorithm")
parser.add_argument('--env_name',default='CartPole-v0',type=str,help="name of environment")
@@ -33,102 +33,101 @@ def get_args():
parser.add_argument('--target_update',default=4,type=int)
parser.add_argument('--hidden_dim',default=256,type=int)
parser.add_argument('--device',default='cpu',type=str,help="cpu or cuda")
parser.add_argument('--seed',default=10,type=int,help="seed")
parser.add_argument('--result_path',default=curr_path + "/outputs/" + parser.parse_args().env_name + \
'/' + curr_time + '/results/' )
'/' + curr_time + '/results' )
parser.add_argument('--model_path',default=curr_path + "/outputs/" + parser.parse_args().env_name + \
'/' + curr_time + '/models/' )
'/' + curr_time + '/models' )
parser.add_argument('--show_fig',default=False,type=bool,help="if show figure or not")
parser.add_argument('--save_fig',default=True,type=bool,help="if save figure or not")
args = parser.parse_args()
return args
def env_agent_config(cfg,seed=1):
''' 创建环境和智能体
def env_agent_config(cfg):
''' create env and agent
'''
env = gym.make(cfg.env_name) # 创建环境
n_states = env.observation_space.shape[0] # 状态维度
n_actions = env.action_space.n # 动作维度
print(f"状态数:{n_states},动作数:{n_actions}")
env = gym.make(cfg.env_name) # create env
if cfg.seed !=0: # set random seed
all_seed(env,seed=cfg.seed)
n_states = env.observation_space.shape[0] # state dimension
n_actions = env.action_space.n # action dimension
print(f"state dim: {n_states}, action dim: {n_actions}")
model = MLP(n_states,n_actions,hidden_dim=cfg.hidden_dim)
memory = ReplayBuffer(cfg.memory_capacity) # 经验回放
agent = DQN(n_actions,model,memory,cfg) # 创建智能体
if seed !=0: # 设置随机种子
torch.manual_seed(seed)
env.seed(seed)
np.random.seed(seed)
memory = ReplayBuffer(cfg.memory_capacity) # replay buffer
agent = DQN(n_actions,model,memory,cfg) # create agent
return env, agent
def train(cfg, env, agent):
''' training
'''
print("开始训练!")
print(f"回合:{cfg.env_name}, 算法:{cfg.algo_name}, 设备:{cfg.device}")
rewards = [] # 记录所有回合的奖励
print("start training!")
print(f"Env: {cfg.env_name}, Algo: {cfg.algo_name}, Device: {cfg.device}")
rewards = [] # record rewards for all episodes
steps = []
for i_ep in range(cfg.train_eps):
ep_reward = 0 # 记录一回合内的奖励
ep_reward = 0 # reward per episode
ep_step = 0
state = env.reset() # 重置环境,返回初始状态
state = env.reset() # reset and obtain initial state
while True:
ep_step += 1
action = agent.sample(state) # 选择动作
next_state, reward, done, _ = env.step(action) # 更新环境,返回transition
action = agent.sample_action(state) # sample action
next_state, reward, done, _ = env.step(action) # update env and return transitions
agent.memory.push(state, action, reward,
next_state, done) # 保存transition
state = next_state # 更新下一个状态
agent.update() # 更新智能体
ep_reward += reward # 累加奖励
next_state, done) # save transitions
state = next_state # update next state for env
agent.update() # update agent
ep_reward += reward # accumulate reward
if done:
break
if (i_ep + 1) % cfg.target_update == 0: # 智能体目标网络更新
if (i_ep + 1) % cfg.target_update == 0: # target net update; target_update is the "C" in the pseudocode
agent.target_net.load_state_dict(agent.policy_net.state_dict())
steps.append(ep_step)
rewards.append(ep_reward)
if (i_ep + 1) % 10 == 0:
print(f'回合:{i_ep+1}/{cfg.train_eps},奖励:{ep_reward:.2f}Epislon{agent.epsilon:.3f}')
print("完成训练!")
print(f'Episode: {i_ep+1}/{cfg.train_eps}, Reward: {ep_reward:.2f}, Epsilon: {agent.epsilon:.3f}')
print("finish training!")
env.close()
res_dic = {'rewards':rewards}
res_dic = {'episodes':range(len(rewards)),'rewards':rewards}
return res_dic
def test(cfg, env, agent):
print("开始测试!")
print(f"回合:{cfg.env_name}, 算法:{cfg.algo_name}, 设备:{cfg.device}")
rewards = [] # 记录所有回合的奖励
print("start testing!")
print(f"Env: {cfg.env_name}, Algo: {cfg.algo_name}, Device: {cfg.device}")
rewards = [] # record rewards for all episodes
steps = []
for i_ep in range(cfg.test_eps):
ep_reward = 0 # 记录一回合内的奖励
ep_reward = 0 # reward per episode
ep_step = 0
state = env.reset() # 重置环境,返回初始状态
state = env.reset() # reset and obtain initial state
while True:
ep_step+=1
action = agent.predict(state) # 选择动作
next_state, reward, done, _ = env.step(action) # 更新环境返回transition
state = next_state # 更新下一个状态
ep_reward += reward # 累加奖励
action = agent.predict_action(state) # predict action
next_state, reward, done, _ = env.step(action)
state = next_state
ep_reward += reward
if done:
break
steps.append(ep_step)
rewards.append(ep_reward)
print(f'回合:{i_ep+1}/{cfg.test_eps}奖励:{ep_reward:.2f}')
print("完成测试")
print(f'Episode: {i_ep+1}/{cfg.test_eps}, Reward: {ep_reward:.2f}')
print("finish testing!")
env.close()
return {'rewards':rewards}
return {'episodes':range(len(rewards)),'rewards':rewards}
if __name__ == "__main__":
cfg = get_args()
# 训练
# training
env, agent = env_agent_config(cfg)
res_dic = train(cfg, env, agent)
save_args(cfg,path = cfg.result_path) # 保存参数到模型路径上
agent.save(path = cfg.model_path) # 保存模型
save_results(res_dic, tag = 'train', path = cfg.result_path)
plot_rewards(res_dic['rewards'], cfg, path = cfg.result_path,tag = "train")
# 测试
env, agent = env_agent_config(cfg) # 也可以不加,加这一行的是为了避免训练之后环境可能会出现问题,因此新建一个环境用于测试
agent.load(path = cfg.model_path) # 导入模型
save_args(cfg,path = cfg.result_path) # save parameters
agent.save_model(path = cfg.model_path) # save models
save_results(res_dic, tag = 'train', path = cfg.result_path) # save results
plot_rewards(res_dic['rewards'], cfg, path = cfg.result_path,tag = "train") # plot results
# testing
env, agent = env_agent_config(cfg) # create new env for testing, sometimes can ignore this step
agent.load_model(path = cfg.model_path) # load model
res_dic = test(cfg, env, agent)
save_results(res_dic, tag='test',
path = cfg.result_path) # 保存结果
plot_rewards(res_dic['rewards'], cfg, path = cfg.result_path,tag = "test") # 画出结果
path = cfg.result_path)
plot_rewards(res_dic['rewards'], cfg, path = cfg.result_path,tag = "test")
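Note on the renamed interface above: task.py now calls agent.sample_action during training and agent.predict_action during testing, while dqn.py itself is not part of this diff. Below is a minimal sketch of what such a pair typically looks like for DQN (epsilon-greedy while training, greedy argmax while testing); the schedule values and helper names are illustrative assumptions, not the repository's code.

import math
import random
import torch

def epsilon_by_step(step, eps_start=0.95, eps_end=0.01, eps_decay=500):
    # assumed exponential decay schedule; the real schedule lives in dqn.py
    return eps_end + (eps_start - eps_end) * math.exp(-step / eps_decay)

def sample_action(policy_net, state, n_actions, step, device="cpu"):
    # epsilon-greedy action used during training
    if random.random() < epsilon_by_step(step):
        return random.randrange(n_actions)
    with torch.no_grad():
        s = torch.tensor(state, dtype=torch.float, device=device).unsqueeze(0)
        return policy_net(s).argmax(dim=1).item()

def predict_action(policy_net, state, device="cpu"):
    # greedy action used during testing
    with torch.no_grad():
        s = torch.tensor(state, dtype=torch.float, device=device).unsqueeze(0)
        return policy_net(s).argmax(dim=1).item()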

Binary image file not shown (before: 28 KiB)

View File

@@ -0,0 +1,16 @@
{
"algo_name": "PolicyGradient",
"env_name": "CartPole-v0",
"train_eps": 200,
"test_eps": 20,
"gamma": 0.99,
"lr": 0.005,
"update_fre": 8,
"hidden_dim": 36,
"device": "cpu",
"seed": 1,
"result_path": "/Users/jj/Desktop/rl-tutorials/codes/PolicyGradient/outputs/CartPole-v0/20220822-174059/results/",
"model_path": "/Users/jj/Desktop/rl-tutorials/codes/PolicyGradient/outputs/CartPole-v0/20220822-174059/models/",
"save_fig": true,
"show_fig": false
}

Binary image file not shown (after: 35 KiB)

View File

@@ -0,0 +1,21 @@
episodes,rewards
0,200.0
1,200.0
2,165.0
3,200.0
4,200.0
5,200.0
6,200.0
7,200.0
8,200.0
9,200.0
10,200.0
11,168.0
12,200.0
13,200.0
14,200.0
15,115.0
16,198.0
17,200.0
18,200.0
19,200.0

Binary image file not shown (after: 66 KiB)

View File

@@ -0,0 +1,201 @@
episodes,rewards
0,26.0
1,53.0
2,10.0
3,37.0
4,22.0
5,21.0
6,12.0
7,34.0
8,38.0
9,40.0
10,23.0
11,14.0
12,16.0
13,25.0
14,15.0
15,23.0
16,11.0
17,28.0
18,21.0
19,62.0
20,33.0
21,27.0
22,15.0
23,17.0
24,26.0
25,35.0
26,26.0
27,14.0
28,42.0
29,45.0
30,34.0
31,39.0
32,31.0
33,17.0
34,42.0
35,41.0
36,31.0
37,39.0
38,28.0
39,12.0
40,36.0
41,33.0
42,47.0
43,40.0
44,63.0
45,36.0
46,64.0
47,79.0
48,49.0
49,40.0
50,65.0
51,47.0
52,51.0
53,30.0
54,26.0
55,41.0
56,86.0
57,61.0
58,38.0
59,200.0
60,49.0
61,70.0
62,61.0
63,101.0
64,200.0
65,152.0
66,108.0
67,46.0
68,72.0
69,87.0
70,27.0
71,126.0
72,46.0
73,25.0
74,14.0
75,42.0
76,38.0
77,55.0
78,42.0
79,51.0
80,67.0
81,83.0
82,178.0
83,115.0
84,140.0
85,97.0
86,85.0
87,61.0
88,153.0
89,200.0
90,200.0
91,200.0
92,200.0
93,64.0
94,200.0
95,200.0
96,157.0
97,128.0
98,160.0
99,35.0
100,140.0
101,113.0
102,200.0
103,154.0
104,200.0
105,200.0
106,200.0
107,198.0
108,137.0
109,200.0
110,200.0
111,102.0
112,200.0
113,200.0
114,200.0
115,200.0
116,148.0
117,200.0
118,200.0
119,200.0
120,200.0
121,200.0
122,194.0
123,200.0
124,200.0
125,200.0
126,183.0
127,200.0
128,200.0
129,200.0
130,200.0
131,200.0
132,200.0
133,200.0
134,200.0
135,200.0
136,93.0
137,96.0
138,84.0
139,103.0
140,79.0
141,104.0
142,82.0
143,105.0
144,200.0
145,200.0
146,171.0
147,200.0
148,200.0
149,200.0
150,200.0
151,197.0
152,133.0
153,142.0
154,147.0
155,156.0
156,131.0
157,181.0
158,163.0
159,146.0
160,200.0
161,176.0
162,200.0
163,173.0
164,177.0
165,200.0
166,200.0
167,200.0
168,200.0
169,200.0
170,200.0
171,200.0
172,200.0
173,200.0
174,200.0
175,200.0
176,200.0
177,200.0
178,200.0
179,200.0
180,200.0
181,200.0
182,200.0
183,200.0
184,200.0
185,200.0
186,200.0
187,200.0
188,200.0
189,200.0
190,200.0
191,200.0
192,200.0
193,200.0
194,200.0
195,200.0
196,190.0
197,200.0
198,189.0
199,200.0

View File

@@ -5,7 +5,7 @@ Author: John
Email: johnjim0816@gmail.com
Date: 2020-11-22 23:27:44
LastEditor: John
LastEditTime: 2022-02-10 01:25:27
LastEditTime: 2022-08-22 17:35:34
Description:
Environment:
'''
@@ -16,35 +16,27 @@ from torch.distributions import Bernoulli
from torch.autograd import Variable
import numpy as np
class MLP(nn.Module):
''' multi-layer perceptron
input: state dimension
output: probability
'''
def __init__(self,input_dim,hidden_dim = 36):
super(MLP, self).__init__()
# 24 and 36 are the hidden layer sizes, adjustable according to input_dim and n_actions
self.fc1 = nn.Linear(input_dim, hidden_dim)
self.fc2 = nn.Linear(hidden_dim,hidden_dim)
self.fc3 = nn.Linear(hidden_dim, 1) # Prob of Left
def forward(self, x):
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = F.sigmoid(self.fc3(x))
return x
class PolicyGradient:
def __init__(self, n_states,cfg):
def __init__(self, n_states,model,memory,cfg):
self.gamma = cfg.gamma
self.policy_net = MLP(n_states,hidden_dim=cfg.hidden_dim)
self.device = torch.device(cfg.device)
self.memory = memory
self.policy_net = model.to(self.device)
self.optimizer = torch.optim.RMSprop(self.policy_net.parameters(), lr=cfg.lr)
self.batch_size = cfg.batch_size
def choose_action(self,state):
def sample_action(self,state):
state = torch.from_numpy(state).float()
state = Variable(state)
probs = self.policy_net(state)
m = Bernoulli(probs) # Bernoulli distribution
action = m.sample()
action = action.data.numpy().astype(int)[0] # convert to scalar
return action
def predict_action(self,state):
state = torch.from_numpy(state).float()
state = Variable(state)
probs = self.policy_net(state)
@@ -53,7 +45,9 @@ class PolicyGradient:
action = action.data.numpy().astype(int)[0] # convert to scalar
return action
def update(self,reward_pool,state_pool,action_pool):
def update(self):
state_pool,action_pool,reward_pool= self.memory.sample()
state_pool,action_pool,reward_pool = list(state_pool),list(action_pool),list(reward_pool)
# Discount reward
running_add = 0
for i in reversed(range(len(reward_pool))):
@@ -83,7 +77,11 @@ class PolicyGradient:
# print(loss)
loss.backward()
self.optimizer.step()
def save(self,path):
torch.save(self.policy_net.state_dict(), path+'pg_checkpoint.pt')
def load(self,path):
self.policy_net.load_state_dict(torch.load(path+'pg_checkpoint.pt'))
self.memory.clear()
def save_model(self,path):
from pathlib import Path
# create path
Path(path).mkdir(parents=True, exist_ok=True)
torch.save(self.policy_net.state_dict(), path+'checkpoint.pt')
def load_model(self,path):
self.policy_net.load_state_dict(torch.load(path+'checkpoint.pt'))
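The update() above discounts the pooled rewards in reverse with a running sum before normalising them. A minimal self-contained sketch of that discounting step (the gamma value is an example; the repository may additionally reset the running sum at episode boundaries, which this sketch omits):

def discount_rewards(rewards, gamma=0.99):
    # returns[i] = rewards[i] + gamma * returns[i + 1], computed back to front
    returns = [0.0] * len(rewards)
    running_add = 0.0
    for i in reversed(range(len(rewards))):
        running_add = running_add * gamma + rewards[i]
        returns[i] = running_add
    return returns

print(discount_rewards([0.0, 0.0, 1.0], gamma=0.9))  # ~[0.81, 0.9, 1.0], up to float rounding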

View File

@@ -5,7 +5,7 @@ Author: John
Email: johnjim0816@gmail.com
Date: 2020-11-22 23:21:53
LastEditor: John
LastEditTime: 2022-07-21 21:44:00
LastEditTime: 2022-08-22 17:40:07
Description:
Environment:
'''
@@ -19,10 +19,11 @@ import torch
import datetime
import argparse
from itertools import count
import torch.nn.functional as F
from pg import PolicyGradient
from common.utils import save_results, make_dir
from common.utils import plot_rewards
from common.utils import save_results, make_dir,all_seed,save_args,plot_rewards
from common.models import MLP
from common.memories import PGReplay
def get_args():
@@ -32,112 +33,107 @@ def get_args():
parser = argparse.ArgumentParser(description="hyperparameters")
parser.add_argument('--algo_name',default='PolicyGradient',type=str,help="name of algorithm")
parser.add_argument('--env_name',default='CartPole-v0',type=str,help="name of environment")
parser.add_argument('--train_eps',default=300,type=int,help="episodes of training")
parser.add_argument('--train_eps',default=200,type=int,help="episodes of training")
parser.add_argument('--test_eps',default=20,type=int,help="episodes of testing")
parser.add_argument('--gamma',default=0.99,type=float,help="discounted factor")
parser.add_argument('--lr',default=0.01,type=float,help="learning rate")
parser.add_argument('--batch_size',default=8,type=int)
parser.add_argument('--lr',default=0.005,type=float,help="learning rate")
parser.add_argument('--update_fre',default=8,type=int)
parser.add_argument('--hidden_dim',default=36,type=int)
parser.add_argument('--device',default='cpu',type=str,help="cpu or cuda")
parser.add_argument('--seed',default=1,type=int,help="seed")
parser.add_argument('--result_path',default=curr_path + "/outputs/" + parser.parse_args().env_name + \
'/' + curr_time + '/results/' )
parser.add_argument('--model_path',default=curr_path + "/outputs/" + parser.parse_args().env_name + \
'/' + curr_time + '/models/' ) # path to save models
parser.add_argument('--save_fig',default=True,type=bool,help="if save figure or not")
args = parser.parse_args()
parser.add_argument('--save_fig',default=True,type=bool,help="if save figure or not")
parser.add_argument('--show_fig',default=False,type=bool,help="if show figure or not")
args = parser.parse_args([])
return args
class PGNet(MLP):
''' instead of outputting an action, PGNet outputs action probabilities, so we can use class inheritance from MLP here
'''
def forward(self, x):
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = F.sigmoid(self.fc3(x))
return x
def env_agent_config(cfg,seed=1):
def env_agent_config(cfg):
env = gym.make(cfg.env_name)
env.seed(seed)
if cfg.seed !=0: # set random seed
all_seed(env,seed=cfg.seed)
n_states = env.observation_space.shape[0]
agent = PolicyGradient(n_states,cfg)
n_actions = env.action_space.n # action dimension
print(f"state dim: {n_states}, action dim: {n_actions}")
model = PGNet(n_states,1,hidden_dim=cfg.hidden_dim)
memory = PGReplay()
agent = PolicyGradient(n_states,model,memory,cfg)
return env,agent
def train(cfg,env,agent):
print('Start training!')
print(f'Env:{cfg.env_name}, Algorithm:{cfg.algo_name}, Device:{cfg.device}')
state_pool = [] # temp states pool per several episodes
action_pool = []
reward_pool = []
print(f'Env:{cfg.env_name}, Algo:{cfg.algo_name}, Device:{cfg.device}')
rewards = []
ma_rewards = []
for i_ep in range(cfg.train_eps):
state = env.reset()
ep_reward = 0
for _ in count():
action = agent.choose_action(state) # 根据当前环境state选择action
action = agent.sample_action(state) # sample action
next_state, reward, done, _ = env.step(action)
ep_reward += reward
if done:
reward = 0
state_pool.append(state)
action_pool.append(float(action))
reward_pool.append(reward)
agent.memory.push((state,float(action),reward))
state = next_state
if done:
print(f'Episode{i_ep+1}/{cfg.train_eps}, Reward:{ep_reward:.2f}')
break
if i_ep > 0 and i_ep % cfg.batch_size == 0:
agent.update(reward_pool,state_pool,action_pool)
state_pool = []
action_pool = []
reward_pool = []
if (i_ep+1) % cfg.update_fre == 0:
agent.update()
rewards.append(ep_reward)
if ma_rewards:
ma_rewards.append(
0.9*ma_rewards[-1]+0.1*ep_reward)
else:
ma_rewards.append(ep_reward)
print('Finish training!')
env.close() # close environment
return rewards, ma_rewards
res_dic = {'episodes':range(len(rewards)),'rewards':rewards}
return res_dic
def test(cfg,env,agent):
print('开始测试!')
print(f'环境:{cfg.env_name}, 算法:{cfg.algo_name}, 设备:{cfg.device}')
print("start testing!")
print(f"Env: {cfg.env_name}, Algo: {cfg.algo_name}, Device: {cfg.device}")
rewards = []
ma_rewards = []
for i_ep in range(cfg.test_eps):
state = env.reset()
ep_reward = 0
for _ in count():
action = agent.choose_action(state) # 根据当前环境state选择action
action = agent.predict_action(state)
next_state, reward, done, _ = env.step(action)
ep_reward += reward
if done:
reward = 0
state = next_state
if done:
print('回合:{}/{}, 奖励:{}'.format(i_ep + 1, cfg.train_eps, ep_reward))
print(f'Episode: {i_ep+1}/{cfg.test_eps}, Reward: {ep_reward:.2f}')
break
rewards.append(ep_reward)
if ma_rewards:
ma_rewards.append(
0.9*ma_rewards[-1]+0.1*ep_reward)
else:
ma_rewards.append(ep_reward)
print('完成测试!')
print("finish testing!")
env.close()
return rewards, ma_rewards
return {'episodes':range(len(rewards)),'rewards':rewards}
if __name__ == "__main__":
cfg = Config()
# 训练
cfg = get_args()
env, agent = env_agent_config(cfg)
rewards, ma_rewards = train(cfg, env, agent)
make_dir(cfg.result_path, cfg.model_path) # 创建保存结果和模型路径的文件夹
agent.save(path=cfg.model_path) # 保存模型
save_results(rewards, ma_rewards, tag='train',
path=cfg.result_path) # 保存结果
plot_rewards(rewards, ma_rewards, cfg, tag="train") # 画出结果
# 测试
env, agent = env_agent_config(cfg)
agent.load(path=cfg.model_path) # 导入模型
rewards, ma_rewards = test(cfg, env, agent)
save_results(rewards, ma_rewards, tag='test',
path=cfg.result_path) # 保存结果
plot_rewards(rewards, ma_rewards, cfg, tag="test") # 画出结果
res_dic = train(cfg, env, agent)
save_args(cfg,path = cfg.result_path) # save parameters
agent.save_model(path = cfg.model_path) # save models
save_results(res_dic, tag = 'train', path = cfg.result_path) # save results
plot_rewards(res_dic['rewards'], cfg, path = cfg.result_path,tag = "train") # plot results
# testing
env, agent = env_agent_config(cfg) # create new env for testing, sometimes can ignore this step
agent.load_model(path = cfg.model_path) # load model
res_dic = test(cfg, env, agent)
save_results(res_dic, tag='test',
path = cfg.result_path)
plot_rewards(res_dic['rewards'], cfg, path = cfg.result_path,tag = "test")
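For reference, PGNet above ends in a sigmoid and outputs a single action probability, and pg.py turns it into an action by sampling a Bernoulli distribution. A tiny sketch of that sampling step with an example probability:

import torch
from torch.distributions import Bernoulli

probs = torch.tensor([0.7])         # example network output (a single action probability)
action = Bernoulli(probs).sample()  # tensor([1.]) with probability 0.7, otherwise tensor([0.])
print(int(action.item()))           # CartPole-v0 expects an integer action in {0, 1}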

View File

@@ -0,0 +1,4 @@
class SAC:
def __init__(self,n_actions,model,memory,cfg):
pass

View File

@@ -0,0 +1 @@
{"algo_name": "SoftQ", "env_name": "CartPole-v0", "train_eps": 200, "test_eps": 20, "max_steps": 200, "gamma": 0.99, "alpha": 4, "lr": 0.0001, "memory_capacity": 50000, "batch_size": 128, "target_update": 2, "device": "cpu", "seed": 10, "result_path": "/Users/jj/Desktop/rl-tutorials/codes/SoftQ/outputs/CartPole-v0/20220818-154333/results/", "model_path": "/Users/jj/Desktop/rl-tutorials/codes/SoftQ/outputs/CartPole-v0/20220818-154333/models/", "show_fig": false, "save_fig": true}

Binary image file not shown (after: 31 KiB)

View File

@@ -0,0 +1,21 @@
episodes,rewards
0,200.0
1,200.0
2,200.0
3,200.0
4,200.0
5,200.0
6,200.0
7,200.0
8,199.0
9,200.0
10,200.0
11,200.0
12,200.0
13,200.0
14,200.0
15,200.0
16,200.0
17,200.0
18,200.0
19,200.0

Binary image file not shown (after: 62 KiB)

View File

@@ -0,0 +1,201 @@
episodes,rewards
0,21.0
1,23.0
2,24.0
3,27.0
4,33.0
5,18.0
6,47.0
7,18.0
8,18.0
9,21.0
10,26.0
11,31.0
12,11.0
13,17.0
14,22.0
15,16.0
16,17.0
17,34.0
18,20.0
19,11.0
20,50.0
21,15.0
22,11.0
23,39.0
24,11.0
25,28.0
26,37.0
27,26.0
28,63.0
29,18.0
30,17.0
31,13.0
32,9.0
33,15.0
34,13.0
35,21.0
36,17.0
37,22.0
38,20.0
39,31.0
40,9.0
41,10.0
42,11.0
43,15.0
44,18.0
45,10.0
46,30.0
47,14.0
48,36.0
49,26.0
50,21.0
51,15.0
52,9.0
53,14.0
54,10.0
55,27.0
56,14.0
57,15.0
58,22.0
59,12.0
60,20.0
61,10.0
62,12.0
63,29.0
64,11.0
65,13.0
66,27.0
67,50.0
68,29.0
69,40.0
70,29.0
71,18.0
72,27.0
73,11.0
74,15.0
75,10.0
76,13.0
77,11.0
78,17.0
79,13.0
80,18.0
81,24.0
82,15.0
83,34.0
84,11.0
85,35.0
86,26.0
87,9.0
88,19.0
89,19.0
90,16.0
91,25.0
92,18.0
93,37.0
94,46.0
95,88.0
96,26.0
97,55.0
98,43.0
99,141.0
100,89.0
101,151.0
102,47.0
103,56.0
104,64.0
105,56.0
106,49.0
107,87.0
108,58.0
109,55.0
110,57.0
111,165.0
112,31.0
113,200.0
114,57.0
115,107.0
116,46.0
117,45.0
118,64.0
119,69.0
120,67.0
121,65.0
122,47.0
123,63.0
124,134.0
125,60.0
126,89.0
127,99.0
128,51.0
129,109.0
130,131.0
131,156.0
132,118.0
133,185.0
134,86.0
135,149.0
136,138.0
137,143.0
138,114.0
139,130.0
140,139.0
141,106.0
142,135.0
143,164.0
144,156.0
145,155.0
146,200.0
147,186.0
148,64.0
149,200.0
150,135.0
151,135.0
152,168.0
153,200.0
154,200.0
155,200.0
156,167.0
157,198.0
158,188.0
159,200.0
160,200.0
161,200.0
162,200.0
163,200.0
164,200.0
165,200.0
166,200.0
167,200.0
168,189.0
169,200.0
170,146.0
171,200.0
172,200.0
173,200.0
174,115.0
175,170.0
176,200.0
177,200.0
178,178.0
179,200.0
180,200.0
181,200.0
182,200.0
183,200.0
184,200.0
185,200.0
186,120.0
187,200.0
188,200.0
189,200.0
190,200.0
191,200.0
192,200.0
193,200.0
194,200.0
195,200.0
196,200.0
197,200.0
198,200.0
199,200.0

View File

@@ -0,0 +1,71 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import deque
import random
from torch.distributions import Categorical
import gym
import numpy as np
class SoftQ:
def __init__(self,n_actions,model,memory,cfg):
self.memory = memory
self.alpha = cfg.alpha
self.gamma = cfg.gamma # discount factor
self.batch_size = cfg.batch_size
self.device = torch.device(cfg.device)
self.policy_net = model.to(self.device)
self.target_net = model.to(self.device)
self.target_net.load_state_dict(self.policy_net.state_dict()) # copy parameters
self.optimizer = torch.optim.Adam(self.policy_net.parameters(), lr=cfg.lr)
self.losses = [] # save losses
def sample_action(self,state):
state = torch.FloatTensor(state).unsqueeze(0).to(self.device)
with torch.no_grad():
q = self.policy_net(state)
v = self.alpha * torch.log(torch.sum(torch.exp(q/self.alpha), dim=1, keepdim=True)).squeeze()
dist = torch.exp((q-v)/self.alpha)
dist = dist / torch.sum(dist)
c = Categorical(dist)
a = c.sample()
return a.item()
def predict_action(self,state):
state = torch.tensor(np.array(state), device=self.device, dtype=torch.float).unsqueeze(0)
with torch.no_grad():
q = self.policy_net(state)
v = self.alpha * torch.log(torch.sum(torch.exp(q/self.alpha), dim=1, keepdim=True)).squeeze()
dist = torch.exp((q-v)/self.alpha)
dist = dist / torch.sum(dist)
c = Categorical(dist)
a = c.sample()
return a.item()
def update(self):
if len(self.memory) < self.batch_size: # do not update until the buffer holds at least one full batch
return
state_batch, action_batch, reward_batch, next_state_batch, done_batch = self.memory.sample(self.batch_size)
state_batch = torch.tensor(np.array(state_batch), device=self.device, dtype=torch.float) # shape(batchsize,n_states)
action_batch = torch.tensor(np.array(action_batch), device=self.device, dtype=torch.float).unsqueeze(1) # shape(batchsize,1)
reward_batch = torch.tensor(np.array(reward_batch), device=self.device, dtype=torch.float).unsqueeze(1) # shape(batchsize,1)
next_state_batch = torch.tensor(np.array(next_state_batch), device=self.device, dtype=torch.float) # shape(batchsize,n_states)
done_batch = torch.tensor(np.array(done_batch), device=self.device, dtype=torch.float).unsqueeze(1) # shape(batchsize,1)
# print(state_batch.shape,action_batch.shape,reward_batch.shape,next_state_batch.shape,done_batch.shape)
with torch.no_grad():
next_q = self.target_net(next_state_batch)
next_v = self.alpha * torch.log(torch.sum(torch.exp(next_q/self.alpha), dim=1, keepdim=True))
y = reward_batch + (1 - done_batch ) * self.gamma * next_v
loss = F.mse_loss(self.policy_net(state_batch).gather(1, action_batch.long()), y)
self.losses.append(loss)
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
def save_model(self, path):
from pathlib import Path
# create path
Path(path).mkdir(parents=True, exist_ok=True)
torch.save(self.target_net.state_dict(), path+'checkpoint.pth')
def load_model(self, path):
self.target_net.load_state_dict(torch.load(path+'checkpoint.pth'))
for target_param, param in zip(self.target_net.parameters(), self.policy_net.parameters()):
param.data.copy_(target_param.data)
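A note on the soft value used above: with v = alpha * log(sum(exp(q / alpha))), the distribution exp((q - v) / alpha) is exactly softmax(q / alpha), so the extra dist / dist.sum() step is only a numerical safeguard. A quick self-contained check (alpha and q are example values); torch.logsumexp gives a numerically safer way to compute the same v:

import torch

alpha = 4.0
q = torch.tensor([[1.0, 2.0, 0.5]])  # example Q-values for 3 actions
v = alpha * torch.log(torch.sum(torch.exp(q / alpha), dim=1, keepdim=True))
v_stable = alpha * torch.logsumexp(q / alpha, dim=1, keepdim=True)  # same value, safer for large |q| / alpha
dist = torch.exp((q - v) / alpha)
print(dist.sum())                                             # ~1.0: already a proper distribution
print(torch.allclose(dist, torch.softmax(q / alpha, dim=1)))  # True
print(torch.allclose(v, v_stable))                            # True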

View File

@@ -0,0 +1,142 @@
import sys,os
curr_path = os.path.dirname(os.path.abspath(__file__)) # current path
parent_path = os.path.dirname(curr_path) # parent path
sys.path.append(parent_path) # add path to system path
import argparse
import datetime
import gym
import torch
import random
import numpy as np
import torch.nn as nn
from common.memories import ReplayBufferQue
from common.models import MLP
from common.utils import save_results,all_seed,plot_rewards,save_args
from softq import SoftQ
def get_args():
""" hyperparameters
"""
curr_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") # obtain current time
parser = argparse.ArgumentParser(description="hyperparameters")
parser.add_argument('--algo_name',default='SoftQ',type=str,help="name of algorithm")
parser.add_argument('--env_name',default='CartPole-v0',type=str,help="name of environment")
parser.add_argument('--train_eps',default=200,type=int,help="episodes of training")
parser.add_argument('--test_eps',default=20,type=int,help="episodes of testing")
parser.add_argument('--max_steps',default=200,type=int,help="maximum steps per episode")
parser.add_argument('--gamma',default=0.99,type=float,help="discounted factor")
parser.add_argument('--alpha',default=4,type=float,help="alpha")
parser.add_argument('--lr',default=0.0001,type=float,help="learning rate")
parser.add_argument('--memory_capacity',default=50000,type=int,help="memory capacity")
parser.add_argument('--batch_size',default=128,type=int)
parser.add_argument('--target_update',default=2,type=int)
parser.add_argument('--device',default='cpu',type=str,help="cpu or cuda")
parser.add_argument('--seed',default=10,type=int,help="seed")
parser.add_argument('--result_path',default=curr_path + "/outputs/" + parser.parse_args().env_name + \
'/' + curr_time + '/results/' )
parser.add_argument('--model_path',default=curr_path + "/outputs/" + parser.parse_args().env_name + \
'/' + curr_time + '/models/' )
parser.add_argument('--show_fig',default=False,type=bool,help="if show figure or not")
parser.add_argument('--save_fig',default=True,type=bool,help="if save figure or not")
args = parser.parse_args()
return args
class SoftQNetwork(nn.Module):
'''Actually almost same to common.models.MLP
'''
def __init__(self,input_dim,output_dim):
super(SoftQNetwork,self).__init__()
self.fc1 = nn.Linear(input_dim, 64)
self.relu = nn.ReLU()
self.fc2 = nn.Linear(64, 256)
self.fc3 = nn.Linear(256, output_dim)
def forward(self, x):
x = self.relu(self.fc1(x))
x = self.relu(self.fc2(x))
x = self.fc3(x)
return x
def env_agent_config(cfg):
''' create env and agent
'''
env = gym.make(cfg.env_name) # create env
if cfg.seed !=0: # set random seed
all_seed(env,seed=cfg.seed)
n_states = env.observation_space.shape[0] # state dimension
n_actions = env.action_space.n # action dimension
print(f"state dim: {n_states}, action dim: {n_actions}")
# model = MLP(n_states,n_actions)
model = SoftQNetwork(n_states,n_actions)
memory = ReplayBufferQue(cfg.memory_capacity) # replay buffer
agent = SoftQ(n_actions,model,memory,cfg) # create agent
return env, agent
def train(cfg, env, agent):
''' training
'''
print("start training!")
print(f"Env: {cfg.env_name}, Algo: {cfg.algo_name}, Device: {cfg.device}")
rewards = [] # record rewards for all episodes
steps = [] # record steps for all episodes, sometimes needed
for i_ep in range(cfg.train_eps):
ep_reward = 0 # reward per episode
ep_step = 0
state = env.reset() # reset and obtain initial state
while True:
# for _ in range(cfg.max_steps):
ep_step += 1
action = agent.sample_action(state) # sample action
next_state, reward, done, _ = env.step(action) # update env and return transitions
agent.memory.push((state, action, reward, next_state, done)) # save transitions
state = next_state # update next state for env
agent.update() # update agent
ep_reward += reward
if done:
break
if (i_ep + 1) % cfg.target_update == 0: # target net update; target_update is the "C" in the pseudocode
agent.target_net.load_state_dict(agent.policy_net.state_dict())
steps.append(ep_step)
rewards.append(ep_reward)
if (i_ep + 1) % 10 == 0:
print(f'Episode: {i_ep+1}/{cfg.train_eps}, Reward: {ep_reward:.2f}')
print("finish training!")
res_dic = {'episodes':range(len(rewards)),'rewards':rewards}
return res_dic
def test(cfg, env, agent):
print("start testing!")
print(f"Env: {cfg.env_name}, Algo: {cfg.algo_name}, Device: {cfg.device}")
rewards = [] # record rewards for all episodes
for i_ep in range(cfg.test_eps):
ep_reward = 0 # reward per episode
state = env.reset() # reset and obtain initial state
while True:
action = agent.predict_action(state) # predict action
next_state, reward, done, _ = env.step(action)
state = next_state
ep_reward += reward
if done:
break
rewards.append(ep_reward)
print(f'Episode: {i_ep+1}/{cfg.test_eps}, Reward: {ep_reward:.2f}')
print("finish testing!")
env.close()
return {'episodes':range(len(rewards)),'rewards':rewards}
if __name__ == "__main__":
cfg = get_args()
# training
env, agent = env_agent_config(cfg)
res_dic = train(cfg, env, agent)
save_args(cfg,path = cfg.result_path) # save parameters
agent.save_model(path = cfg.model_path) # save model
save_results(res_dic, tag = 'train', path = cfg.result_path)
plot_rewards(res_dic['rewards'], cfg, path = cfg.result_path,tag = "train")
# testing
env, agent = env_agent_config(cfg) # optionally create a fresh env for testing, to avoid issues with the env after training
agent.load_model(path = cfg.model_path) # load model
res_dic = test(cfg, env, agent)
save_results(res_dic, tag='test',
path = cfg.result_path) # save results
plot_rewards(res_dic['rewards'], cfg, path = cfg.result_path,tag = "test") # plot results
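The loop above refreshes the target network with a hard copy every cfg.target_update episodes (the "C" of the DQN-style pseudocode). A soft, Polyak-averaged update is a common alternative; a minimal sketch for comparison, with an example tau (not part of this repository's scripts):

import copy
import torch.nn as nn

policy_net = nn.Linear(4, 2)            # stand-in for the Q-network
target_net = copy.deepcopy(policy_net)

def soft_update(target, source, tau=0.005):
    # target <- tau * source + (1 - tau) * target, applied parameter-wise
    for t_param, s_param in zip(target.parameters(), source.parameters()):
        t_param.data.copy_(tau * s_param.data + (1.0 - tau) * t_param.data)

soft_update(target_net, policy_net)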

View File

@@ -5,11 +5,12 @@
@Email: johnjim0816@gmail.com
@Date: 2020-06-10 15:27:16
@LastEditor: John
LastEditTime: 2021-09-15 14:52:37
LastEditTime: 2022-08-22 17:23:21
@Description:
@Environment: python 3.7.7
'''
import random
from collections import deque
class ReplayBuffer:
def __init__(self, capacity):
self.capacity = capacity # capacity of the replay buffer
@@ -34,3 +35,40 @@ class ReplayBuffer:
'''
return len(self.buffer)
class ReplayBufferQue:
def __init__(self, capacity: int) -> None:
self.capacity = capacity
self.buffer = deque(maxlen=self.capacity)
def push(self,transitions):
'''append one transition tuple to the buffer
Args:
transitions (tuple): e.g. (state, action, reward, next_state, done)
'''
self.buffer.append(transitions)
def sample(self, batch_size: int, sequential: bool = False):
if batch_size > len(self.buffer):
batch_size = len(self.buffer)
if sequential: # sequential sampling
rand = random.randint(0, len(self.buffer) - batch_size)
batch = [self.buffer[i] for i in range(rand, rand + batch_size)]
return zip(*batch)
else:
batch = random.sample(self.buffer, batch_size)
return zip(*batch)
def clear(self):
self.buffer.clear()
def __len__(self):
return len(self.buffer)
class PGReplay(ReplayBufferQue):
'''replay buffer for policy-gradient methods; each update samples all stored transitions
'''
def __init__(self):
self.buffer = deque()
def sample(self):
''' sample all the transitions
'''
batch = list(self.buffer)
return zip(*batch)
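A short usage sketch of the two buffers added above (assuming it is run from the codes directory so that common.memories is importable): ReplayBufferQue holds a bounded deque of off-policy transitions and samples a random batch, while PGReplay collects a whole on-policy rollout and returns it in one call before being cleared.

from common.memories import ReplayBufferQue, PGReplay

buffer = ReplayBufferQue(capacity=100)
for t in range(5):
    # dummy CartPole-like transition: (state, action, reward, next_state, done)
    buffer.push(([0.0, 0.0, 0.0, 0.0], t % 2, 1.0, [0.1, 0.0, 0.0, 0.0], False))
states, actions, rewards, next_states, dones = buffer.sample(batch_size=4)

pg_buffer = PGReplay()
pg_buffer.push(([0.0, 0.0, 0.0, 0.0], 1.0, 1.0))  # (state, action, reward), as in the PG task above
state_pool, action_pool, reward_pool = pg_buffer.sample()
pg_buffer.clear()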

View File

@@ -5,7 +5,7 @@ Author: John
Email: johnjim0816@gmail.com
Date: 2021-03-12 16:02:24
LastEditor: John
LastEditTime: 2022-08-15 18:11:27
LastEditTime: 2022-08-22 17:41:28
Description:
Environment:
'''
@@ -15,6 +15,7 @@ from pathlib import Path
import matplotlib.pyplot as plt
import seaborn as sns
import json
import pandas as pd
from matplotlib.font_manager import FontProperties # import font module
@@ -84,12 +85,12 @@ def plot_losses(losses, algo="DQN", save=True, path='./'):
plt.savefig(path+"losses_curve")
plt.show()
def save_results(dic, tag='train', path = None):
def save_results(res_dic, tag='train', path = None):
''' save rewards
'''
Path(path).mkdir(parents=True, exist_ok=True)
for key,value in dic.items():
np.save(path+'{}_{}.npy'.format(tag,key),value)
df = pd.DataFrame(res_dic)
df.to_csv(f"{path}/{tag}ing_results.csv",index=None)
print('Results saved')
@@ -115,4 +116,26 @@ def save_args(args,path=None):
Path(path).mkdir(parents=True, exist_ok=True)
with open(f"{path}/params.json", 'w') as fp:
json.dump(args_dict, fp)
print("参数已保存!")
print("Parameters saved!")
def all_seed(env,seed = 1):
''' all-purpose seeding helper for RL; note where it is called: it should come right after the env is created
Args:
env (gym.Env): environment to seed
seed (int, optional): random seed. Defaults to 1.
'''
import torch
import numpy as np
import random
print(f"seed = {seed}")
env.seed(seed) # env config
np.random.seed(seed)
random.seed(seed)
torch.manual_seed(seed) # config for CPU
torch.cuda.manual_seed(seed) # config for GPU
os.environ['PYTHONHASHSEED'] = str(seed) # config for python scripts
# config for cudnn
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.enabled = False
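Since save_results above now writes one CSV per tag via pandas (f"{path}/{tag}ing_results.csv", with episodes and rewards columns), the saved file can be reloaded later for further analysis or re-plotting. A small sketch, with a hypothetical output path:

import pandas as pd

df = pd.read_csv("outputs/CartPole-v0/20220822-174059/results/training_results.csv")
print(df["rewards"].rolling(10).mean().tail())  # 10-episode moving average of the rewards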

View File

@@ -1,10 +1,11 @@
gym==0.21.0
torch==1.9.0
torchvision==0.10.0
torchaudio==0.9.0
torch==1.10.0
torchvision==0.11.0
torchaudio==0.10.0
ipykernel==6.15.1
jupyter==1.0.0
matplotlib==3.5.2
seaborn==0.11.2
dill==0.3.5.1
argparse==1.4.0
argparse==1.4.0
pandas==1.3.5