import copy

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from TD3.memory import ReplayBuffer

# Implementation of Twin Delayed Deep Deterministic Policy Gradients (TD3)
# Paper: https://arxiv.org/abs/1802.09477


class Actor(nn.Module):
    """Deterministic policy network mapping a state batch to an action batch.

    Two hidden ReLU layers of width 256 feed a tanh output layer whose result
    is multiplied by ``max_action``, so every output component lies in
    ``[-max_action, max_action]``.
    """

    def __init__(self, state_dim, action_dim, max_action):
        """Build the layers.

        Args:
            state_dim: Size of the last dimension of the state input.
            action_dim: Number of action components produced.
            max_action: Scale applied to the tanh output.
        """
        super().__init__()

        self.l1 = nn.Linear(state_dim, 256)
        self.l2 = nn.Linear(256, 256)
        self.l3 = nn.Linear(256, action_dim)

        self.max_action = max_action

    def forward(self, state):
        """Return the scaled tanh action for ``state``."""
        hidden = F.relu(self.l1(state))
        hidden = F.relu(self.l2(hidden))
        return self.max_action * torch.tanh(self.l3(hidden))


class Critic(nn.Module):
    """Pair of independent Q-networks evaluated on the same (state, action).

    Layers ``l1``-``l3`` form Q1 and ``l4``-``l6`` form Q2; both take the
    concatenation of state and action along dimension 1 and emit one value
    per row.
    """

    def __init__(self, state_dim, action_dim):
        """Build both Q-networks.

        Args:
            state_dim: Number of state features per row.
            action_dim: Number of action features per row.
        """
        super().__init__()

        # Q1 architecture
        self.l1 = nn.Linear(state_dim + action_dim, 256)
        self.l2 = nn.Linear(256, 256)
        self.l3 = nn.Linear(256, 1)

        # Q2 architecture
        self.l4 = nn.Linear(state_dim + action_dim, 256)
        self.l5 = nn.Linear(256, 256)
        self.l6 = nn.Linear(256, 1)

    def forward(self, state, action):
        """Return ``(q1, q2)``, the two value estimates for the batch."""
        sa = torch.cat([state, action], 1)

        q1 = F.relu(self.l1(sa))
        q1 = F.relu(self.l2(q1))
        q1 = self.l3(q1)

        q2 = F.relu(self.l4(sa))
        q2 = F.relu(self.l5(q2))
        q2 = self.l6(q2)
        return q1, q2

    def Q1(self, state, action):
        """Return only the first network's estimate (used for the actor loss)."""
        sa = torch.cat([state, action], 1)

        q1 = F.relu(self.l1(sa))
        q1 = F.relu(self.l2(q1))
        return self.l3(q1)


class TD3:
    """TD3 agent: actor, twin critic, their target copies and a replay buffer.

    Hyper-parameters are read from ``cfg`` (``gamma``, ``lr``,
    ``policy_noise``, ``noise_clip``, ``policy_freq``, ``batch_size``,
    ``device``).
    """

    def __init__(
        self,
        state_dim,
        action_dim,
        max_action,
        cfg,
    ):
        """Create networks, optimizers and the replay buffer.

        Args:
            state_dim: Number of state features.
            action_dim: Number of action components.
            max_action: Bound used to scale and clamp actions.
            cfg: Object exposing the hyper-parameter attributes listed in the
                class docstring.
        """
        self.max_action = max_action
        self.gamma = cfg.gamma
        # NOTE(review): ``lr`` is not passed to the optimizers below; it is
        # used as the target-network interpolation coefficient in
        # ``_soft_update``. Confirm this is intended before renaming.
        self.lr = cfg.lr
        self.policy_noise = cfg.policy_noise
        self.noise_clip = cfg.noise_clip
        self.policy_freq = cfg.policy_freq
        self.batch_size = cfg.batch_size
        self.device = cfg.device
        self.total_it = 0  # number of update() calls so far

        self.actor = Actor(state_dim, action_dim, max_action).to(self.device)
        self.actor_target = copy.deepcopy(self.actor)
        # NOTE(review): optimizer step size is fixed at 3e-4, independent of cfg.
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=3e-4)

        self.critic = Critic(state_dim, action_dim).to(self.device)
        self.critic_target = copy.deepcopy(self.critic)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=3e-4)

        self.memory = ReplayBuffer(state_dim, action_dim)

    def choose_action(self, state):
        """Return the actor's action for a single state as a flat NumPy array.

        Args:
            state: Object with a ``reshape`` method whose result
                ``torch.FloatTensor`` accepts (presumably a NumPy array --
                verify against callers).

        Returns:
            1-D ``numpy.ndarray`` holding the actor output.
        """
        state = torch.FloatTensor(state.reshape(1, -1)).to(self.device)
        # Inference only: skip building an autograd graph for the forward pass.
        with torch.no_grad():
            action = self.actor(state)
        return action.cpu().numpy().flatten()

    def update(self):
        """Run one training iteration on a batch sampled from ``self.memory``.

        The critic is optimized on every call. The actor and both target
        networks are updated only when ``total_it`` is a multiple of
        ``policy_freq``.
        """
        self.total_it += 1

        # Sample replay buffer
        state, action, next_state, reward, not_done = self.memory.sample(self.batch_size)

        with torch.no_grad():
            # Select action according to policy and add clipped noise
            noise = (
                torch.randn_like(action) * self.policy_noise
            ).clamp(-self.noise_clip, self.noise_clip)

            next_action = (
                self.actor_target(next_state) + noise
            ).clamp(-self.max_action, self.max_action)

            # Compute the target Q value (minimum of the two target critics)
            target_Q1, target_Q2 = self.critic_target(next_state, next_action)
            target_Q = torch.min(target_Q1, target_Q2)
            target_Q = reward + not_done * self.gamma * target_Q

        # Get current Q estimates
        current_Q1, current_Q2 = self.critic(state, action)

        # Compute critic loss
        critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(current_Q2, target_Q)

        # Optimize the critic
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        # Delayed policy updates
        if self.total_it % self.policy_freq != 0:
            return

        # Compute actor loss
        actor_loss = -self.critic.Q1(state, self.actor(state)).mean()

        # Optimize the actor
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        # Update the frozen target models
        self._soft_update(self.critic, self.critic_target)
        self._soft_update(self.actor, self.actor_target)

    def _soft_update(self, net, target_net):
        """Blend ``net`` parameters into ``target_net`` in place.

        Each target parameter becomes
        ``self.lr * param + (1 - self.lr) * target_param``.
        """
        for param, target_param in zip(net.parameters(), target_net.parameters()):
            target_param.data.copy_(
                self.lr * param.data + (1 - self.lr) * target_param.data
            )

    def save(self, path):
        """Write network and optimizer state dicts to files prefixed by ``path``.

        ``path`` is concatenated directly with each file name, so include a
        trailing separator if it names a directory.
        """
        torch.save(self.critic.state_dict(), path + "td3_critic")
        torch.save(self.critic_optimizer.state_dict(), path + "td3_critic_optimizer")

        torch.save(self.actor.state_dict(), path + "td3_actor")
        torch.save(self.actor_optimizer.state_dict(), path + "td3_actor_optimizer")

    def load(self, path):
        """Restore state written by :meth:`save` and rebuild the target networks.

        Tensors are mapped onto ``self.device`` while loading so a checkpoint
        saved on one device (e.g. a GPU) can be restored on another. The
        targets are re-created as copies of the loaded networks.
        """
        device = self.device

        self.critic.load_state_dict(
            torch.load(path + "td3_critic", map_location=device)
        )
        self.critic_optimizer.load_state_dict(
            torch.load(path + "td3_critic_optimizer", map_location=device)
        )
        self.critic_target = copy.deepcopy(self.critic)

        self.actor.load_state_dict(
            torch.load(path + "td3_actor", map_location=device)
        )
        self.actor_optimizer.load_state_dict(
            torch.load(path + "td3_actor_optimizer", map_location=device)
        )
        self.actor_target = copy.deepcopy(self.actor)