Files
easy-rl/codes/common/plot.py
johnjim0816 129c0c65fa update codes
2021-11-18 15:41:27 +08:00

39 lines
1.1 KiB
Python

#!/usr/bin/env python
# coding=utf-8
'''
Author: John
Email: johnjim0816@gmail.com
Date: 2020-10-07 20:57:11
LastEditor: John
LastEditTime: 2021-09-23 12:23:01
Discription:
Environment:
'''
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.font_manager import FontProperties # 导入字体模块
def plot_rewards(rewards, ma_rewards, plot_cfg, tag='train'):
    """Plot a reward curve together with its moving-average curve.

    Args:
        rewards: Sequence of reward values, plotted against their index
            (presumably one value per episode, matching the x-axis label).
        ma_rewards: Sequence of moving-average reward values, drawn on the
            same axes as ``rewards``.
        plot_cfg: Config object. This function reads its ``device``,
            ``algo``, ``env_name``, ``save`` and ``result_path`` attributes.
        tag: Prefix of the saved file name, e.g. ``'train'`` or ``'eval'``.

    If ``plot_cfg.save`` is truthy, the figure is saved to
    ``plot_cfg.result_path + "<tag>_rewards_curve"``. The path is built by
    plain string concatenation, so ``result_path`` presumably needs to be a
    str ending in a path separator (confirm with callers). No extension is
    given, so matplotlib chooses the file format from its
    ``savefig.format`` rcParam. The figure is then displayed with
    ``plt.show()``.
    """
    sns.set()  # apply seaborn's default theme to matplotlib
    plt.figure()  # start a new figure so several plots can coexist
    plt.title("learning curve on {} of {} for {}".format(
        plot_cfg.device, plot_cfg.algo, plot_cfg.env_name))
    plt.xlabel('episodes')
    plt.plot(rewards, label='rewards')
    plt.plot(ma_rewards, label='ma rewards')
    plt.legend()
    if plot_cfg.save:
        plt.savefig(plot_cfg.result_path + "{}_rewards_curve".format(tag))
    plt.show()
def plot_losses(losses, algo="DQN", save=True, path='./'):
    """Plot a single loss curve.

    Args:
        losses: Sequence of loss values, plotted against their index.
        algo: Algorithm name shown in the figure title.
        save: If True, save the figure to ``path + "losses_curve"`` before
            showing it. No extension is given, so matplotlib chooses the
            format from its ``savefig.format`` rcParam.
        path: Prefix for the saved file. It is concatenated as a plain
            string, so it presumably should end with a path separator.
    """
    sns.set()  # apply seaborn's default theme to matplotlib
    plt.figure()  # start a new figure so several plots can coexist
    plt.title("loss curve of {}".format(algo))
    # NOTE(review): the axis label assumes one loss value per episode;
    # confirm against callers (losses may be logged per update step instead).
    plt.xlabel('episodes')
    # Legend entry names what is actually plotted (loss values).
    plt.plot(losses, label='losses')
    plt.legend()
    if save:
        plt.savefig(path + "losses_curve")
    plt.show()