178 lines
4.9 KiB
Python
178 lines
4.9 KiB
Python
#!/usr/bin/env python
|
||
# coding=utf-8
|
||
'''
|
||
Author: JiangJi
|
||
Email: johnjim0816@gmail.com
|
||
Date: 2021-12-22 10:40:05
|
||
LastEditor: JiangJi
|
||
LastEditTime: 2021-12-22 10:43:55
|
||
Description: TD3 (Twin Delayed DDPG) actor network, twin critic network, and agent.
|
||
'''
|
||
import copy
|
||
import numpy as np
|
||
import torch
|
||
import torch.nn as nn
|
||
import torch.nn.functional as F
|
||
from TD3.memory import ReplayBuffer
|
||
|
||
class Actor(nn.Module):
    """Deterministic policy network: state -> bounded action.

    Args:
        input_dim (int): size of the state vector (n_states).
        output_dim (int): size of the action vector (n_actions).
        max_action: scale applied to the tanh output, so (for a positive
            max_action) every action component lies in [-max_action, max_action].
    """

    def __init__(self, input_dim, output_dim, max_action):
        super().__init__()
        width = 256  # hidden layer size
        # Layer names l1/l2/l3 form the state_dict keys; keep them stable so
        # saved checkpoints stay loadable.
        self.l1 = nn.Linear(input_dim, width)
        self.l2 = nn.Linear(width, width)
        self.l3 = nn.Linear(width, output_dim)
        self.max_action = max_action

    def forward(self, state):
        """Return ``max_action * tanh(MLP(state))``."""
        features = F.relu(self.l2(F.relu(self.l1(state))))
        squashed = torch.tanh(self.l3(features))
        return self.max_action * squashed
|
||
|
||
|
||
class Critic(nn.Module):
    """Twin Q-network: two independent MLP heads over the concatenated (state, action).

    Args:
        input_dim (int): size of the state vector.
        output_dim (int): size of the action vector.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        joint_dim = input_dim + output_dim
        width = 256  # hidden layer size

        # First Q head (l1 -> l2 -> l3). Names are state_dict keys; keep them.
        self.l1 = nn.Linear(joint_dim, width)
        self.l2 = nn.Linear(width, width)
        self.l3 = nn.Linear(width, 1)

        # Second Q head (l4 -> l5 -> l6).
        self.l4 = nn.Linear(joint_dim, width)
        self.l5 = nn.Linear(width, width)
        self.l6 = nn.Linear(width, 1)

    @staticmethod
    def _q_head(x, fc_in, fc_mid, fc_out):
        """Apply one Linear-ReLU-Linear-ReLU-Linear head to ``x``."""
        return fc_out(F.relu(fc_mid(F.relu(fc_in(x)))))

    def forward(self, state, action):
        """Return ``(q1, q2)``, the estimates of both heads, each with one output column."""
        joint = torch.cat((state, action), dim=1)
        first = self._q_head(joint, self.l1, self.l2, self.l3)
        second = self._q_head(joint, self.l4, self.l5, self.l6)
        return first, second

    def Q1(self, state, action):
        """Return only the first head's Q estimate (the second head is not evaluated)."""
        joint = torch.cat((state, action), dim=1)
        return self._q_head(joint, self.l1, self.l2, self.l3)
|
||
|
||
|
||
class TD3(object):
    """TD3 agent: actor, twin critic, their target copies, and a replay buffer.

    Args:
        input_dim (int): state dimension (n_states).
        output_dim (int): action dimension (n_actions).
        max_action: passed to the Actor, and used to clamp target-policy
            actions to [-max_action, max_action].
        cfg: config object. Required attributes: ``gamma``, ``lr``,
            ``policy_noise``, ``noise_clip``, ``policy_freq``, ``batch_size``,
            ``device``. Optional attributes:
            ``tau``: target-network soft-update rate; defaults to ``cfg.lr``,
            which is what was historically used for that purpose.
            ``actor_lr`` / ``critic_lr``: Adam learning rates; default 3e-4.
    """

    def __init__(
        self,
        input_dim,
        output_dim,
        max_action,
        cfg,
    ):
        self.max_action = max_action
        self.gamma = cfg.gamma
        # NOTE: ``cfg.lr`` has always acted as the soft-update rate (tau), not
        # as an optimizer learning rate. It is kept for backward compatibility;
        # an explicit ``cfg.tau`` takes precedence.
        self.lr = cfg.lr
        self.tau = getattr(cfg, "tau", cfg.lr)
        self.policy_noise = cfg.policy_noise
        self.noise_clip = cfg.noise_clip
        self.policy_freq = cfg.policy_freq
        self.batch_size = cfg.batch_size
        self.device = cfg.device
        self.total_it = 0  # number of update() calls; drives delayed actor updates

        # Previously hard-coded; now overridable from cfg with the same default.
        actor_lr = getattr(cfg, "actor_lr", 3e-4)
        critic_lr = getattr(cfg, "critic_lr", 3e-4)

        self.actor = Actor(input_dim, output_dim, max_action).to(self.device)
        self.actor_target = copy.deepcopy(self.actor)
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=actor_lr)

        self.critic = Critic(input_dim, output_dim).to(self.device)
        self.critic_target = copy.deepcopy(self.critic)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=critic_lr)

        self.memory = ReplayBuffer(input_dim, output_dim)

    def choose_action(self, state):
        """Return the deterministic policy action for a single observation.

        Args:
            state: array-like observation (ndarray, list, tuple), flattened
                into one float32 row.

        Returns:
            np.ndarray: 1-D action vector produced by ``self.actor``.
        """
        # np.asarray also accepts lists/tuples, which a bare ``state.reshape`` did not.
        batch = np.asarray(state, dtype=np.float32).reshape(1, -1)
        state = torch.as_tensor(batch, device=self.device)
        # Inference only: don't build an autograd graph.
        with torch.no_grad():
            action = self.actor(state)
        return action.cpu().numpy().flatten()

    @staticmethod
    def _soft_update(net, target_net, tau):
        """Blend ``net`` into ``target_net`` in place: ``t <- tau * p + (1 - tau) * t``."""
        for param, target_param in zip(net.parameters(), target_net.parameters()):
            target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)

    def update(self):
        """Run one TD3 training step on a minibatch sampled from ``self.memory``.

        Always updates the twin critic. Every ``policy_freq`` calls it also
        updates the actor and soft-updates both target networks.
        """
        self.total_it += 1

        # Assumes sample() returns tensors already on self.device, in the order
        # (state, action, next_state, reward, not_done) -- TODO confirm in TD3.memory.
        state, action, next_state, reward, not_done = self.memory.sample(self.batch_size)

        with torch.no_grad():
            # Target policy smoothing: add clipped Gaussian noise to the target action.
            noise = (torch.randn_like(action) * self.policy_noise).clamp(
                -self.noise_clip, self.noise_clip
            )
            next_action = (self.actor_target(next_state) + noise).clamp(
                -self.max_action, self.max_action
            )

            # Clipped double-Q: bootstrap from the smaller target critic estimate.
            target_Q1, target_Q2 = self.critic_target(next_state, next_action)
            target_Q = torch.min(target_Q1, target_Q2)
            target_Q = reward + not_done * self.gamma * target_Q

        # Regress both critic heads toward the shared target.
        current_Q1, current_Q2 = self.critic(state, action)
        critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(current_Q2, target_Q)

        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        # Delayed policy and target-network updates.
        if self.total_it % self.policy_freq == 0:
            actor_loss = -self.critic.Q1(state, self.actor(state)).mean()

            self.actor_optimizer.zero_grad()
            actor_loss.backward()
            self.actor_optimizer.step()

            self._soft_update(self.critic, self.critic_target, self.tau)
            self._soft_update(self.actor, self.actor_target, self.tau)

    def save(self, path):
        """Save critic/actor weights and their optimizer states.

        ``path`` is a plain string prefix (e.g. a directory ending in "/"). The
        files written are ``td3_critic``, ``td3_critic_optimizer``, ``td3_actor``
        and ``td3_actor_optimizer``. Target networks are not saved; ``load``
        rebuilds them as copies.
        """
        torch.save(self.critic.state_dict(), path + "td3_critic")
        torch.save(self.critic_optimizer.state_dict(), path + "td3_critic_optimizer")

        torch.save(self.actor.state_dict(), path + "td3_actor")
        torch.save(self.actor_optimizer.state_dict(), path + "td3_actor_optimizer")

    def load(self, path):
        """Load what ``save`` wrote under the prefix ``path`` and reset targets to copies.

        Tensors are mapped onto ``self.device``, so a checkpoint saved on GPU
        can be loaded on a CPU-only machine.
        """
        self.critic.load_state_dict(torch.load(path + "td3_critic", map_location=self.device))
        self.critic_optimizer.load_state_dict(
            torch.load(path + "td3_critic_optimizer", map_location=self.device)
        )
        self.critic_target = copy.deepcopy(self.critic)

        self.actor.load_state_dict(torch.load(path + "td3_actor", map_location=self.device))
        self.actor_optimizer.load_state_dict(
            torch.load(path + "td3_actor_optimizer", map_location=self.device)
        )
        self.actor_target = copy.deepcopy(self.actor)
|
||
|