import copy

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from TD3.memory import ReplayBuffer


# Implementation of Twin Delayed Deep Deterministic Policy Gradients (TD3)
# Paper: https://arxiv.org/abs/1802.09477
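#
# Three mechanisms from the paper are implemented below:
#   * Clipped double-Q learning: two critics are trained and the smaller of the
#     two target estimates forms the TD target.
#   * Delayed policy updates: the actor and the target networks are updated only
#     every `policy_freq` critic updates.
#   * Target policy smoothing: clipped Gaussian noise is added to the target
#     policy's action before evaluating the target critics.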


class Actor(nn.Module):
    """Deterministic policy: maps a state to an action in [-max_action, max_action]."""

    def __init__(self, state_dim, action_dim, max_action):
        super(Actor, self).__init__()

        self.l1 = nn.Linear(state_dim, 256)
        self.l2 = nn.Linear(256, 256)
        self.l3 = nn.Linear(256, action_dim)

        self.max_action = max_action

    def forward(self, state):
        a = F.relu(self.l1(state))
        a = F.relu(self.l2(a))
        # tanh bounds the output to [-1, 1]; scale to the environment's action range
        return self.max_action * torch.tanh(self.l3(a))


class Critic(nn.Module):
    """Twin critics: returns the Q1 and Q2 estimates for a (state, action) pair."""

    def __init__(self, state_dim, action_dim):
        super(Critic, self).__init__()

        # Q1 architecture
        self.l1 = nn.Linear(state_dim + action_dim, 256)
        self.l2 = nn.Linear(256, 256)
        self.l3 = nn.Linear(256, 1)

        # Q2 architecture
        self.l4 = nn.Linear(state_dim + action_dim, 256)
        self.l5 = nn.Linear(256, 256)
        self.l6 = nn.Linear(256, 1)

    def forward(self, state, action):
        sa = torch.cat([state, action], 1)

        q1 = F.relu(self.l1(sa))
        q1 = F.relu(self.l2(q1))
        q1 = self.l3(q1)

        q2 = F.relu(self.l4(sa))
        q2 = F.relu(self.l5(q2))
        q2 = self.l6(q2)
        return q1, q2

    def Q1(self, state, action):
        # Only the first critic is needed when computing the actor loss
        sa = torch.cat([state, action], 1)

        q1 = F.relu(self.l1(sa))
        q1 = F.relu(self.l2(q1))
        q1 = self.l3(q1)
        return q1
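

# --- Illustrative only ---
# A minimal sketch of the configuration object that the TD3 class below expects.
# The attribute names are exactly the ones read from `cfg` in TD3.__init__; the
# class name `TD3Config` and the default values are assumptions (common TD3
# settings), not part of the original repository.
from dataclasses import dataclass


@dataclass
class TD3Config:
    gamma: float = 0.99        # discount factor
    lr: float = 0.005          # reused below as the soft target-update rate
    policy_noise: float = 0.2  # std of the target policy smoothing noise
    noise_clip: float = 0.5    # clip range for that noise
    policy_freq: int = 2       # actor/target updates every `policy_freq` critic updates
    batch_size: int = 256      # replay buffer sample size
    device: str = "cpu"        # e.g. "cuda" when a GPU is available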


class TD3(object):
    def __init__(
        self,
        state_dim,
        action_dim,
        max_action,
        cfg,
    ):
        self.max_action = max_action
        self.gamma = cfg.gamma
        # Note: cfg.lr is stored here and used in update() as the soft
        # target-update coefficient (tau in the TD3 paper); the optimizers
        # below use a fixed learning rate of 3e-4.
        self.lr = cfg.lr
        self.policy_noise = cfg.policy_noise
        self.noise_clip = cfg.noise_clip
        self.policy_freq = cfg.policy_freq
        self.batch_size = cfg.batch_size
        self.device = cfg.device
        self.total_it = 0

        self.actor = Actor(state_dim, action_dim, max_action).to(self.device)
        self.actor_target = copy.deepcopy(self.actor)
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=3e-4)

        self.critic = Critic(state_dim, action_dim).to(self.device)
        self.critic_target = copy.deepcopy(self.critic)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=3e-4)

        self.memory = ReplayBuffer(state_dim, action_dim)

    def choose_action(self, state):
        # Greedy (noise-free) action from the current policy; exploration noise,
        # if any, is added by the caller.
        state = torch.FloatTensor(state.reshape(1, -1)).to(self.device)
        return self.actor(state).cpu().data.numpy().flatten()

    def update(self):
        self.total_it += 1

        # Sample replay buffer
        state, action, next_state, reward, not_done = self.memory.sample(self.batch_size)

        with torch.no_grad():
            # Select action according to the target policy and add clipped noise
            # (target policy smoothing)
            noise = (
                torch.randn_like(action) * self.policy_noise
            ).clamp(-self.noise_clip, self.noise_clip)

            next_action = (
                self.actor_target(next_state) + noise
            ).clamp(-self.max_action, self.max_action)

            # Compute the target Q value from the smaller of the two target critics
            target_Q1, target_Q2 = self.critic_target(next_state, next_action)
            target_Q = torch.min(target_Q1, target_Q2)
            target_Q = reward + not_done * self.gamma * target_Q

        # Get current Q estimates
        current_Q1, current_Q2 = self.critic(state, action)

        # Compute critic loss
        critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(current_Q2, target_Q)

        # Optimize the critic
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        # Delayed policy updates
        if self.total_it % self.policy_freq == 0:

            # Compute actor loss (maximize Q1 under the current policy)
            actor_loss = -self.critic.Q1(state, self.actor(state)).mean()

            # Optimize the actor
            self.actor_optimizer.zero_grad()
            actor_loss.backward()
            self.actor_optimizer.step()

            # Soft-update the frozen target models (Polyak averaging with rate self.lr)
            for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
                target_param.data.copy_(self.lr * param.data + (1 - self.lr) * target_param.data)

            for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()):
                target_param.data.copy_(self.lr * param.data + (1 - self.lr) * target_param.data)

    def save(self, path):
        torch.save(self.critic.state_dict(), path + "td3_critic")
        torch.save(self.critic_optimizer.state_dict(), path + "td3_critic_optimizer")

        torch.save(self.actor.state_dict(), path + "td3_actor")
        torch.save(self.actor_optimizer.state_dict(), path + "td3_actor_optimizer")

    def load(self, path):
        self.critic.load_state_dict(torch.load(path + "td3_critic"))
        self.critic_optimizer.load_state_dict(torch.load(path + "td3_critic_optimizer"))
        self.critic_target = copy.deepcopy(self.critic)

        self.actor.load_state_dict(torch.load(path + "td3_actor"))
        self.actor_optimizer.load_state_dict(torch.load(path + "td3_actor_optimizer"))
        self.actor_target = copy.deepcopy(self.actor)
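

# --- Illustrative only ---
# A hedged sketch of how the agent might be driven. The construction and the
# `choose_action` call rely only on what this file defines (plus the
# illustrative TD3Config above); the commented training loop additionally
# assumes the replay buffer exposes an `add(state, action, next_state, reward,
# done)` method, which is not shown in this file (only `sample` is used above).
if __name__ == "__main__":
    state_dim, action_dim, max_action = 3, 1, 1.0  # hypothetical environment sizes
    agent = TD3(state_dim, action_dim, max_action, TD3Config())

    dummy_state = np.zeros(state_dim, dtype=np.float32)
    print("greedy action:", agent.choose_action(dummy_state))

    # A typical training step would look roughly like this (assumed buffer API):
    #     action = agent.choose_action(state) + exploration_noise
    #     next_state, reward, done, info = env.step(action)
    #     agent.memory.add(state, action, next_state, reward, done)
    #     agent.update()
    # and agent.save("./") would checkpoint the networks and optimizers.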