Files
easy-rl/projects/codes/Sarsa/sarsa.py
2022-08-15 22:31:37 +08:00

58 lines
2.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python
# coding=utf-8
'''
Author: John
Email: johnjim0816@gmail.com
Date: 2021-03-12 16:58:16
LastEditor: John
LastEditTime: 2022-08-04 22:22:16
Description:
Environment:
'''
import numpy as np
from collections import defaultdict
import torch
import math
class Sarsa(object):
    """Tabular SARSA agent with an exponentially decaying epsilon-greedy policy.

    The Q table maps each state (used as a dictionary key, so it must be
    hashable) to a NumPy vector holding one value per action. Unseen states
    are created on first access with all-zero action values.
    """

    def __init__(self,
                 n_actions, cfg):
        """Create an agent with an empty Q table.

        Args:
            n_actions: Number of discrete actions; the length of every
                per-state value vector.
            cfg: Configuration object providing the attributes ``lr``,
                ``gamma``, ``epsilon_start``, ``epsilon_end`` and
                ``epsilon_decay``.
        """
        self.n_actions = n_actions
        self.lr = cfg.lr
        self.gamma = cfg.gamma
        self.sample_count = 0
        self.epsilon_start = cfg.epsilon_start
        self.epsilon_end = cfg.epsilon_end
        self.epsilon_decay = cfg.epsilon_decay
        # Define epsilon up front so it can be read (e.g. for logging) before
        # the first call to sample(); sample() overwrites it on every call.
        self.epsilon = cfg.epsilon_start
        self.Q = defaultdict(lambda: np.zeros(n_actions))  # Q table

    def sample(self, state):
        """Draw an action for ``state`` from the epsilon-greedy distribution.

        Every call increments ``sample_count`` and recomputes ``epsilon``.

        Args:
            state: Hashable key identifying the current state.

        Returns:
            Index of the sampled action.
        """
        self.sample_count += 1
        # Exploration rate decays exponentially from epsilon_start towards
        # epsilon_end as the number of sampled actions grows.
        self.epsilon = self.epsilon_end + (self.epsilon_start - self.epsilon_end) * \
            math.exp(-1. * self.sample_count / self.epsilon_decay)
        best_action = np.argmax(self.Q[state])
        # Spread epsilon uniformly over all actions, then give the remaining
        # (1 - epsilon) probability mass to the current greedy action.
        action_probs = np.ones(self.n_actions, dtype=float) * self.epsilon / self.n_actions
        action_probs[best_action] += (1.0 - self.epsilon)
        action = np.random.choice(np.arange(len(action_probs)), p=action_probs)
        return action

    def predict(self, state):
        """Return the greedy action (highest Q value) for ``state``.

        Args:
            state: Hashable key identifying the current state.

        Returns:
            Index of the action with the largest Q value.
        """
        return np.argmax(self.Q[state])

    def update(self, state, action, reward, next_state, next_action, done):
        """Apply one SARSA update to ``Q[state][action]``.

        Args:
            state: State in which ``action`` was taken.
            action: Index of the action taken in ``state``.
            reward: Reward received for the transition.
            next_state: State reached after the transition.
            next_action: Action chosen for ``next_state``.
            done: Whether the transition ended the episode.
        """
        Q_predict = self.Q[state][action]
        if done:
            Q_target = reward  # terminal transition: nothing to bootstrap from
        else:
            # Unlike Q-learning, SARSA bootstraps from the Q value of the
            # action chosen for the next step rather than the greedy maximum.
            Q_target = reward + self.gamma * self.Q[next_state][next_action]
        self.Q[state][action] += self.lr * (Q_target - Q_predict)

    def save(self, path):
        """Save the Q table to ``path + "sarsa_model.pkl"``.

        Args:
            path: String prefix prepended verbatim to the file name (no path
                separator is inserted).
        """
        # dill is passed as the pickle module because the Q table's default
        # factory is a lambda.
        import dill
        torch.save(
            obj=self.Q,
            f=path + "sarsa_model.pkl",
            pickle_module=dill
        )

    def load(self, path):
        """Load the Q table from ``path + 'sarsa_model.pkl'``.

        Args:
            path: String prefix prepended verbatim to the file name (no path
                separator is inserted).
        """
        # NOTE: unpickling can execute arbitrary code; only load files that
        # come from a trusted source.
        import dill
        self.Q = torch.load(f=path + 'sarsa_model.pkl', pickle_module=dill)