This commit is contained in:
johnjim0816
2022-07-21 22:12:19 +08:00
parent 6b3121fcff
commit e9b3e92141
21 changed files with 99 additions and 85 deletions

View File

@@ -5,7 +5,7 @@ Author: John
Email: johnjim0816@gmail.com
Date: 2021-03-12 16:02:24
LastEditor: John
LastEditTime: 2022-07-20 23:53:34
LastEditTime: 2022-07-21 21:45:33
Discription:
Environment:
'''
@@ -69,19 +69,19 @@ def plot_losses(losses, algo="DQN", save=True, path='./'):
plt.savefig(path+"losses_curve")
plt.show()
def save_results_1(dic, tag='train', path='./results'):
def save_results(dic, tag='train', path='./results'):
''' 保存奖励
'''
for key,value in dic.items():
np.save(path+'{}_{}.npy'.format(tag,key),value)
print('Results saved')
def save_results(rewards, ma_rewards, tag='train', path='./results'):
''' 保存奖励
'''
np.save(path+'{}_rewards.npy'.format(tag), rewards)
np.save(path+'{}_ma_rewards.npy'.format(tag), ma_rewards)
print('Result saved!')
# def save_results(rewards, ma_rewards, tag='train', path='./results'):
# ''' 保存奖励
# '''
# np.save(path+'{}_rewards.npy'.format(tag), rewards)
# np.save(path+'{}_ma_rewards.npy'.format(tag), ma_rewards)
# print('Result saved!')
def make_dir(*paths):