# Reconstructed from a collapsed git diff. The original patch adds three new
# files; their payloads are reproduced below with the diff metadata turned
# into the `--- path ---` separator comments.

# --- codes/common/model.py ---
import torch
import torch.nn as nn
import torch.nn.functional as F


class MLP(nn.Module):
    """Fully connected policy network.

    Outputs a single sigmoid value in (0, 1) — the probability of taking
    the 'left' action (binary action space, e.g. CartPole left/right).
    """

    def __init__(self, state_dim):
        """
        Args:
            state_dim: dimensionality of the environment state vector.
        """
        super(MLP, self).__init__()
        # Hidden width 36 is a tunable choice; adjust it to suit
        # state_dim / the action space. (The original comment claimed
        # sizes "24 and 36", but the code uses 36 throughout.)
        self.fc1 = nn.Linear(state_dim, 36)
        self.fc2 = nn.Linear(36, 36)
        self.fc3 = nn.Linear(36, 1)  # probability of the 'left' action

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        # torch.sigmoid replaces F.sigmoid, which is deprecated and has
        # been removed in recent PyTorch releases.
        x = torch.sigmoid(self.fc3(x))
        return x


# --- codes/common/plot.py ---
import matplotlib.pyplot as plt
import seaborn as sns


def plot_rewards(rewards, ma_rewards, tag="train",
                 algo="On-Policy First-Visit MC Control", path='./'):
    """Plot raw and moving-average reward curves and save the figure.

    Args:
        rewards: per-episode rewards.
        ma_rewards: moving-average rewards.
        tag: label appended to the saved filename (e.g. 'train'/'eval').
        algo: algorithm name shown in the plot title.
        path: directory prefix for the saved figure.
    """
    sns.set()
    plt.title("average learning curve of {}".format(algo))
    plt.xlabel('episodes')  # fixed typo: was 'epsiodes'
    plt.plot(rewards, label='rewards')
    plt.plot(ma_rewards, label='moving average rewards')
    plt.legend()
    plt.savefig(path + "rewards_curve_{}".format(tag))
    plt.show()


# --- codes/common/utils.py ---
import os
import numpy as np


def save_results(rewards, ma_rewards, tag='train', path='./results'):
    """Save episode rewards and moving-average rewards as .npy files.

    Args:
        rewards: sequence of per-episode rewards.
        ma_rewards: sequence of moving-average rewards.
        tag: filename suffix distinguishing runs (e.g. 'train'/'eval').
        path: directory to save into; created if it does not exist.
    """
    # os.path.join fixes the original bare string concatenation, which with
    # the default path produced './resultsrewards_train.npy' in the CWD
    # instead of a file inside './results/'.
    os.makedirs(path, exist_ok=True)
    np.save(os.path.join(path, 'rewards_' + tag + '.npy'), rewards)
    np.save(os.path.join(path, 'ma_rewards_' + tag + '.npy'), ma_rewards)
    print('results saved!')