diff --git a/codes/Q-learning/agent.py b/codes/Q-learning/agent.py index 5c16c43..6c21251 100644 --- a/codes/Q-learning/agent.py +++ b/codes/Q-learning/agent.py @@ -84,11 +84,11 @@ class QLearning(object): def save(self): '''把 Q表格 的数据保存到文件中 ''' - npy_file = './Q_table.npy' + npy_file = './result/Q_table.npy' np.save(npy_file, self.Q_table) print(npy_file + ' saved.') - def load(self, npy_file='./Q_table.npy'): + def load(self, npy_file='./result/Q_table.npy'): '''从文件中读取数据到 Q表格 ''' self.Q_table = np.load(npy_file) - print(npy_file + 'loaded.') \ No newline at end of file + print(npy_file + 'loaded.')