hot update

This commit is contained in:
johnjim0816
2022-08-22 17:50:11 +08:00
parent 0a54840828
commit ad65dd17cd
54 changed files with 1639 additions and 503 deletions


@@ -5,7 +5,7 @@ Author: John
Email: johnjim0816@gmail.com
Date: 2020-11-22 23:27:44
LastEditor: John
LastEditTime: 2022-02-10 01:25:27
LastEditTime: 2022-08-22 17:35:34
Description:
Environment:
'''
@@ -16,35 +16,27 @@ from torch.distributions import Bernoulli
from torch.autograd import Variable
import numpy as np
class MLP(nn.Module):
''' Multi-layer perceptron
Input: state dimension
Output: probability
'''
def __init__(self,input_dim,hidden_dim = 36):
super(MLP, self).__init__()
# 24 and 36 are the hidden layer sizes; they can be changed according to input_dim and n_actions
self.fc1 = nn.Linear(input_dim, hidden_dim)
self.fc2 = nn.Linear(hidden_dim,hidden_dim)
self.fc3 = nn.Linear(hidden_dim, 1) # Prob of Left
def forward(self, x):
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = F.sigmoid(self.fc3(x))
return x
class PolicyGradient:
def __init__(self, n_states,cfg):
def __init__(self, n_states,model,memory,cfg):
self.gamma = cfg.gamma
self.policy_net = MLP(n_states,hidden_dim=cfg.hidden_dim)
self.device = torch.device(cfg.device)
self.memory = memory
self.policy_net = model.to(self.device)
self.optimizer = torch.optim.RMSprop(self.policy_net.parameters(), lr=cfg.lr)
self.batch_size = cfg.batch_size
def choose_action(self,state):
def sample_action(self,state):
state = torch.from_numpy(state).float()
state = Variable(state)
probs = self.policy_net(state)
m = Bernoulli(probs) # Bernoulli distribution
action = m.sample()
action = action.data.numpy().astype(int)[0] # convert to a scalar
return action
def predict_action(self,state):
state = torch.from_numpy(state).float()
state = Variable(state)
probs = self.policy_net(state)
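
The refactored constructor now receives the network and the replay buffer from outside instead of building an MLP internally. The buffer class itself is not part of this hunk; below is a minimal sketch, assuming a PGReplay-style class whose push/sample/clear interface matches the calls the agent makes (self.memory.sample() in update() and self.memory.clear() afterwards). The class name and method signatures are assumptions, not the verbatim file contents.

class PGReplay:
    ''' hypothetical on-policy buffer assumed by the refactored agent '''
    def __init__(self):
        self.buffer = []  # stores (state, action, reward) tuples for the current batch of episodes
    def push(self, transition):
        self.buffer.append(transition)
    def sample(self):
        # return all stored transitions as three parallel sequences,
        # matching: state_pool, action_pool, reward_pool = self.memory.sample()
        return zip(*self.buffer)
    def clear(self):
        self.buffer = []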
@@ -53,7 +45,9 @@ class PolicyGradient:
action = action.data.numpy().astype(int)[0] # convert to a scalar
return action
def update(self,reward_pool,state_pool,action_pool):
def update(self):
state_pool,action_pool,reward_pool = self.memory.sample()
state_pool,action_pool,reward_pool = list(state_pool),list(action_pool),list(reward_pool)
# Discount reward
running_add = 0
for i in reversed(range(len(reward_pool))):
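
The hunk above only shows the start of the discounted-return loop. For context, here is a hedged sketch of what the complete update() is expected to do under the new interface: compute returns backwards through the stored rewards (a reward of 0 marking an episode boundary), normalize them, and take a gradient step on -log pi(a|s) * G_t. The normalization step and the exact tensor handling are assumptions inferred from the visible lines.

def update(self):
    state_pool, action_pool, reward_pool = self.memory.sample()
    state_pool, action_pool, reward_pool = list(state_pool), list(action_pool), list(reward_pool)
    # discounted returns, computed backwards; reward 0 marks an episode boundary
    running_add = 0
    for i in reversed(range(len(reward_pool))):
        if reward_pool[i] == 0:
            running_add = 0
        else:
            running_add = running_add * self.gamma + reward_pool[i]
            reward_pool[i] = running_add
    # normalize returns to reduce gradient variance (assumed step)
    reward_mean = np.mean(reward_pool)
    reward_std = np.std(reward_pool)
    for i in range(len(reward_pool)):
        reward_pool[i] = (reward_pool[i] - reward_mean) / reward_std
    # policy gradient step: maximize log pi(a|s) * G_t
    self.optimizer.zero_grad()
    for i in range(len(reward_pool)):
        state = torch.from_numpy(state_pool[i]).float()
        action = torch.FloatTensor([action_pool[i]])
        reward = reward_pool[i]
        probs = self.policy_net(state)
        m = Bernoulli(probs)
        loss = -m.log_prob(action) * reward  # negative log-likelihood weighted by the return
        loss.backward()
    self.optimizer.step()
    self.memory.clear()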
@@ -83,7 +77,11 @@ class PolicyGradient:
# print(loss)
loss.backward()
self.optimizer.step()
def save(self,path):
torch.save(self.policy_net.state_dict(), path+'pg_checkpoint.pt')
def load(self,path):
self.policy_net.load_state_dict(torch.load(path+'pg_checkpoint.pt'))
self.memory.clear()
def save_model(self,path):
from pathlib import Path
# create path
Path(path).mkdir(parents=True, exist_ok=True)
torch.save(self.policy_net.state_dict(), path+'checkpoint.pt')
def load_model(self,path):
self.policy_net.load_state_dict(torch.load(path+'checkpoint.pt'))
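
For completeness, a hypothetical end-to-end use of the renamed interface (sample_action, update, save_model). The environment, the Config fields, and the PGReplay buffer sketched earlier are illustrative assumptions and are not defined in this commit.

import gym

class Config:
    # minimal stand-in for the project's config object; field names are assumptions
    gamma = 0.99
    lr = 0.01
    batch_size = 8
    device = 'cpu'

cfg = Config()
env = gym.make('CartPole-v0')
n_states = env.observation_space.shape[0]
model = MLP(n_states, hidden_dim=36)   # the network is now built outside the agent
memory = PGReplay()                    # the buffer sketched above
agent = PolicyGradient(n_states, model, memory, cfg)

state = env.reset()
done = False
while not done:
    action = agent.sample_action(state)
    next_state, reward, done, _ = env.step(action)
    memory.push((state, float(action), 0 if done else reward))  # reward 0 marks the episode end
    state = next_state
agent.update()
agent.save_model('./outputs/')  # the path is concatenated with 'checkpoint.pt', so it should end with '/'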