From 4cc12bf97f3a797b029432e42c11fac8c2e2e642 Mon Sep 17 00:00:00 2001
From: JohnJim0816
Date: Mon, 23 Nov 2020 12:13:34 +0800
Subject: [PATCH] update PolicyGradient

---
 codes/PolicyGradient/agent.py  | 64 +++++++++++++++++++++++++++++++++
 codes/PolicyGradient/env.py    | 19 ++++++++++
 codes/PolicyGradient/main.py   | 51 ++++++++++++++++++++++++++
 codes/PolicyGradient/model.py  | 29 +++++++++++++++
 codes/PolicyGradient/params.py | 20 ++++++++++
 5 files changed, 183 insertions(+)
 create mode 100644 codes/PolicyGradient/agent.py
 create mode 100644 codes/PolicyGradient/env.py
 create mode 100644 codes/PolicyGradient/main.py
 create mode 100644 codes/PolicyGradient/model.py
 create mode 100644 codes/PolicyGradient/params.py

diff --git a/codes/PolicyGradient/agent.py b/codes/PolicyGradient/agent.py
new file mode 100644
index 0000000..e2725c3
--- /dev/null
+++ b/codes/PolicyGradient/agent.py
@@ -0,0 +1,64 @@
+#!/usr/bin/env python
+# coding=utf-8
+'''
+Author: John
+Email: johnjim0816@gmail.com
+Date: 2020-11-22 23:27:44
+LastEditor: John
+LastEditTime: 2020-11-23 12:05:03
+Description:
+Environment:
+'''
+import torch
+from torch.distributions import Bernoulli
+import numpy as np
+
+from model import FCN
+
+class PolicyGradient:
+
+    def __init__(self, n_states, device='cpu', gamma=0.99, lr=0.01, batch_size=5):
+        self.gamma = gamma
+        self.device = device  # kept for reference; this simple implementation runs on the CPU
+        self.policy_net = FCN(n_states)
+        self.optimizer = torch.optim.RMSprop(self.policy_net.parameters(), lr=lr)
+        self.batch_size = batch_size
+
+    def choose_action(self, state):
+        '''sample an action (0 or 1) from the Bernoulli policy for the current state'''
+        state = torch.from_numpy(state).float()
+        probs = self.policy_net(state)
+        m = Bernoulli(probs)
+        action = m.sample()
+        action = int(action.item())  # convert to a scalar action
+        return action
+
+    def update(self, reward_pool, state_pool, action_pool):
+        # Discount rewards: accumulate backwards, resetting at episode boundaries (reward == 0)
+        running_add = 0
+        for i in reversed(range(len(reward_pool))):
+            if reward_pool[i] == 0:
+                running_add = 0
+            else:
+                running_add = running_add * self.gamma + reward_pool[i]
+            reward_pool[i] = running_add
+
+        # Normalize rewards
+        reward_mean = np.mean(reward_pool)
+        reward_std = np.std(reward_pool)
+        for i in range(len(reward_pool)):
+            reward_pool[i] = (reward_pool[i] - reward_mean) / reward_std
+
+        # Gradient descent
+        self.optimizer.zero_grad()
+
+        for i in range(len(reward_pool)):
+            state = torch.from_numpy(state_pool[i]).float()
+            action = torch.FloatTensor([action_pool[i]])
+            reward = reward_pool[i]
+
+            probs = self.policy_net(state)
+            m = Bernoulli(probs)
+            loss = -m.log_prob(action) * reward  # negative log-probability weighted by the return
+            loss.backward()
+        self.optimizer.step()
\ No newline at end of file
diff --git a/codes/PolicyGradient/env.py b/codes/PolicyGradient/env.py
new file mode 100644
index 0000000..bf67b81
--- /dev/null
+++ b/codes/PolicyGradient/env.py
@@ -0,0 +1,19 @@
+#!/usr/bin/env python
+# coding=utf-8
+'''
+Author: John
+Email: johnjim0816@gmail.com
+Date: 2020-11-22 23:23:10
+LastEditor: John
+LastEditTime: 2020-11-23 11:55:24
+Description:
+Environment:
+'''
+import gym
+
+def env_init():
+    env = gym.make('CartPole-v0')  # some examples call env.unwrapped here; it is usually not needed
+    env.seed(1)  # fix the environment's random seed
+    n_states = env.observation_space.shape[0]
+    n_actions = env.action_space.n
+    return env, n_states, n_actions
\ No newline at end of file
diff --git a/codes/PolicyGradient/main.py b/codes/PolicyGradient/main.py
new file mode 100644
index 0000000..0e19513
--- /dev/null
+++ b/codes/PolicyGradient/main.py
@@ -0,0 +1,51 @@
+#!/usr/bin/env python
+# coding=utf-8
+'''
+Author: John
+Email: johnjim0816@gmail.com
+Date: 2020-11-22 23:21:53
+LastEditor: John
+LastEditTime: 2020-11-23 12:06:15
+Description:
+Environment:
+'''
+from itertools import count
+import torch
+from env import env_init
+from params import get_args
+from agent import PolicyGradient

+def train(cfg):
+    env, n_states, n_actions = env_init()
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # detect whether a GPU is available
+    agent = PolicyGradient(n_states, device=device, lr=cfg.policy_lr)
+    '''the *_pool lists below store the transitions of a batch of episodes for the gradient update'''
+    state_pool = []  # states collected over batch_size episodes
+    action_pool = []
+    reward_pool = []
+    for i_episode in range(cfg.train_eps):
+        state = env.reset()
+        ep_reward = 0
+        for t in count():
+            action = agent.choose_action(state)  # choose an action based on the current state
+            next_state, reward, done, _ = env.step(action)
+            ep_reward += reward
+            if done:
+                reward = 0  # a zero reward marks the episode boundary for the discounting in agent.update
+            state_pool.append(state)
+            action_pool.append(float(action))
+            reward_pool.append(reward)
+            state = next_state
+            if done:
+                print('Episode:', i_episode, ' Reward:', ep_reward)
+                break
+        if i_episode > 0 and i_episode % agent.batch_size == 0:
+            agent.update(reward_pool, state_pool, action_pool)
+            state_pool = []  # clear the pools after each update
+            action_pool = []
+            reward_pool = []
+
+
+if __name__ == "__main__":
+    cfg = get_args()
+    train(cfg)
\ No newline at end of file
diff --git a/codes/PolicyGradient/model.py b/codes/PolicyGradient/model.py
new file mode 100644
index 0000000..9ca6738
--- /dev/null
+++ b/codes/PolicyGradient/model.py
@@ -0,0 +1,29 @@
+#!/usr/bin/env python
+# coding=utf-8
+'''
+Author: John
+Email: johnjim0816@gmail.com
+Date: 2020-11-22 23:18:46
+LastEditor: John
+LastEditTime: 2020-11-23 01:58:22
+Description:
+Environment:
+'''
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+class FCN(nn.Module):
+    '''fully connected network'''
+    def __init__(self, n_states):
+        super(FCN, self).__init__()
+        # 24 and 36 are the hidden layer sizes; they can be adjusted according to n_states and n_actions
+        self.fc1 = nn.Linear(n_states, 24)
+        self.fc2 = nn.Linear(24, 36)
+        self.fc3 = nn.Linear(36, 1)  # outputs the probability of taking action 1
+
+    def forward(self, x):
+        x = F.relu(self.fc1(x))
+        x = F.relu(self.fc2(x))
+        x = torch.sigmoid(self.fc3(x))
+        return x
\ No newline at end of file
diff --git a/codes/PolicyGradient/params.py b/codes/PolicyGradient/params.py
new file mode 100644
index 0000000..2a3390f
--- /dev/null
+++ b/codes/PolicyGradient/params.py
@@ -0,0 +1,20 @@
+#!/usr/bin/env python
+# coding=utf-8
+'''
+Author: John
+Email: johnjim0816@gmail.com
+Date: 2020-11-22 23:25:37
+LastEditor: John
+LastEditTime: 2020-11-22 23:32:44
+Description: stores the training hyperparameters
+Environment:
+'''
+import argparse
+
+def get_args():
+    '''training parameters'''
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--train_eps", default=1200, type=int)  # maximum number of training episodes
+    parser.add_argument("--policy_lr", default=0.01, type=float)  # learning rate of the policy network
+    config = parser.parse_args()
+    return config
\ No newline at end of file
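
With the defaults defined in params.py, training is started with python main.py --train_eps 1200 --policy_lr 0.01 (or simply python main.py). As a quick illustration of the return computation that PolicyGradient.update implements, the minimal sketch below reproduces the same discount-and-normalize step on a hand-made reward sequence, following the convention in main.py that a reward of 0 marks the end of an episode; the helper name and the toy numbers are illustrative only and are not part of the patch.

import numpy as np

def discounted_normalized_returns(reward_pool, gamma=0.99):
    '''Mirror of the first two steps in PolicyGradient.update: accumulate
    discounted returns backwards, resetting whenever reward == 0 (the episode
    boundary written by main.py), then standardize the result.'''
    returns = list(reward_pool)
    running_add = 0
    for i in reversed(range(len(returns))):
        if returns[i] == 0:
            running_add = 0
        else:
            running_add = running_add * gamma + returns[i]
        returns[i] = running_add
    returns = np.array(returns, dtype=np.float64)
    return (returns - returns.mean()) / returns.std()

if __name__ == "__main__":
    # two toy episodes with rewards 1, 1, 0 and 1, 0 (the trailing zeros are the terminal steps)
    print(discounted_normalized_returns([1.0, 1.0, 0.0, 1.0, 0.0]))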
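
The loss term built in the update loop relies on Bernoulli.log_prob returning log(p) for action 1 and log(1 - p) for action 0, so that -m.log_prob(action) * reward is the usual REINFORCE objective for a two-action policy. The short check below uses a made-up probability value in place of the FCN output to show this correspondence.

import torch
from torch.distributions import Bernoulli

if __name__ == "__main__":
    p = torch.tensor([0.7])  # stand-in for the FCN output on some state
    m = Bernoulli(p)
    for a in (0.0, 1.0):
        action = torch.tensor([a])
        # log_prob gives log(1 - p) for action 0 and log(p) for action 1
        expected = torch.log(p) if a == 1.0 else torch.log(1 - p)
        print(a, m.log_prob(action).item(), expected.item())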