Files
easy-rl/codes/TD3/agent.py
2022-05-31 01:20:58 +08:00

178 lines
4.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python
# coding=utf-8
'''
Author: JiangJi
Email: johnjim0816@gmail.com
Date: 2021-12-22 10:40:05
LastEditor: JiangJi
LastEditTime: 2021-12-22 10:43:55
Description:
'''
import copy
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from TD3.memory import ReplayBuffer
class Actor(nn.Module):
    """Policy network: a three-layer MLP whose output is scaled by ``max_action``."""

    def __init__(self, input_dim, output_dim, max_action):
        """Create the layers of the policy network.

        Args:
            input_dim (int): input dimension; here it equals n_states.
            output_dim (int): output dimension; here it equals n_actions.
            max_action (int): maximum value of an action.
        """
        super().__init__()
        # Layer names (l1..l3) are part of the state_dict keys; keep them stable.
        self.l1 = nn.Linear(input_dim, 256)
        self.l2 = nn.Linear(256, 256)
        self.l3 = nn.Linear(256, output_dim)
        self.max_action = max_action

    def forward(self, state):
        """Return the action for ``state``.

        The last layer goes through ``tanh`` and is multiplied by
        ``max_action``, so every output component lies in
        ``[-max_action, max_action]``.
        """
        hidden = torch.relu(self.l1(state))
        hidden = torch.relu(self.l2(hidden))
        squashed = torch.tanh(self.l3(hidden))
        return self.max_action * squashed
class Critic(nn.Module):
    """Two independent Q-value MLPs evaluated on the same (state, action) input."""

    def __init__(self, input_dim, output_dim):
        """Create both Q-networks.

        Args:
            input_dim (int): size of the state vector.
            output_dim (int): size of the action vector; state and action are
                concatenated, so the first layers take ``input_dim + output_dim``.
        """
        super().__init__()
        joint_dim = input_dim + output_dim
        # Q1 architecture
        self.l1 = nn.Linear(joint_dim, 256)
        self.l2 = nn.Linear(256, 256)
        self.l3 = nn.Linear(256, 1)
        # Q2 architecture
        self.l4 = nn.Linear(joint_dim, 256)
        self.l5 = nn.Linear(256, 256)
        self.l6 = nn.Linear(256, 1)

    @staticmethod
    def _evaluate(sa, first, second, last):
        """Run ``sa`` through two ReLU layers followed by a linear output layer."""
        return last(F.relu(second(F.relu(first(sa)))))

    def forward(self, state, action):
        """Return the tuple ``(q1, q2)`` for the given state/action batch."""
        sa = torch.cat((state, action), dim=1)
        first_q = self._evaluate(sa, self.l1, self.l2, self.l3)
        second_q = self._evaluate(sa, self.l4, self.l5, self.l6)
        return first_q, second_q

    def Q1(self, state, action):
        """Return only the first Q-network's estimate (layers l1-l3)."""
        sa = torch.cat((state, action), dim=1)
        return self._evaluate(sa, self.l1, self.l2, self.l3)
class TD3(object):
    """TD3 agent: an actor, a twin critic, their target copies and a replay buffer."""

    def __init__(
        self,
        input_dim,
        output_dim,
        max_action,
        cfg,
    ):
        """Build networks, optimizers and the replay buffer.

        Args:
            input_dim (int): state dimension, forwarded to Actor, Critic and
                ReplayBuffer.
            output_dim (int): action dimension, forwarded to Actor, Critic and
                ReplayBuffer.
            max_action: bound used to scale the actor output and to clamp the
                smoothed target action.
            cfg: configuration object providing ``gamma``, ``lr``,
                ``policy_noise``, ``noise_clip``, ``policy_freq``,
                ``batch_size`` and ``device``.
        """
        self.max_action = max_action
        self.gamma = cfg.gamma
        # NOTE(review): despite its name, ``lr`` is only used as the soft-update
        # coefficient of the target networks (see ``_soft_update``); the Adam
        # step size below is fixed at 3e-4. Confirm this is intended.
        self.lr = cfg.lr
        self.policy_noise = cfg.policy_noise
        self.noise_clip = cfg.noise_clip
        self.policy_freq = cfg.policy_freq
        self.batch_size = cfg.batch_size
        self.device = cfg.device
        self.total_it = 0  # number of update() calls so far

        self.actor = Actor(input_dim, output_dim, max_action).to(self.device)
        self.actor_target = copy.deepcopy(self.actor)
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=3e-4)

        self.critic = Critic(input_dim, output_dim).to(self.device)
        self.critic_target = copy.deepcopy(self.critic)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=3e-4)

        self.memory = ReplayBuffer(input_dim, output_dim)

    def choose_action(self, state):
        """Return the actor's action for a single state as a flat numpy array.

        Args:
            state: array-like observation; it is converted to float32 and
                reshaped to a batch of one row.
        """
        batch = torch.tensor(
            np.asarray(state, dtype=np.float32).reshape(1, -1),
            dtype=torch.float32,
            device=self.device,
        )
        # Inference only: skip building an autograd graph.
        with torch.no_grad():
            action = self.actor(batch)
        return action.cpu().numpy().flatten()

    def _soft_update(self, net, target_net):
        """Blend ``net`` into ``target_net`` in place with coefficient ``self.lr``."""
        with torch.no_grad():
            for param, target_param in zip(net.parameters(), target_net.parameters()):
                target_param.copy_(self.lr * param + (1 - self.lr) * target_param)

    def update(self):
        """Run one training step on a batch sampled from the replay buffer.

        The critic is updated on every call; the actor and both target
        networks are updated only every ``policy_freq`` calls.
        """
        self.total_it += 1

        # Sample replay buffer
        # (assumes sample() returns tensors already on self.device - TODO confirm)
        state, action, next_state, reward, not_done = self.memory.sample(self.batch_size)

        with torch.no_grad():
            # Select action according to the target policy and add clipped noise
            noise = (
                torch.randn_like(action) * self.policy_noise
            ).clamp(-self.noise_clip, self.noise_clip)
            next_action = (
                self.actor_target(next_state) + noise
            ).clamp(-self.max_action, self.max_action)

            # Compute the target Q value from the smaller of the two estimates
            target_Q1, target_Q2 = self.critic_target(next_state, next_action)
            target_Q = torch.min(target_Q1, target_Q2)
            target_Q = reward + not_done * self.gamma * target_Q

        # Get current Q estimates
        current_Q1, current_Q2 = self.critic(state, action)

        # Compute critic loss
        critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(current_Q2, target_Q)

        # Optimize the critic
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        # Delayed policy updates
        if self.total_it % self.policy_freq == 0:
            # Compute actor loss
            actor_loss = -self.critic.Q1(state, self.actor(state)).mean()

            # Optimize the actor
            self.actor_optimizer.zero_grad()
            actor_loss.backward()
            self.actor_optimizer.step()

            # Update the frozen target models
            self._soft_update(self.critic, self.critic_target)
            self._soft_update(self.actor, self.actor_target)

    def save(self, path):
        """Write network and optimizer state dicts to files prefixed by ``path``.

        ``path`` is concatenated directly with the file names, so it must end
        with a path separator when it denotes a directory.
        """
        torch.save(self.critic.state_dict(), path + "td3_critic")
        torch.save(self.critic_optimizer.state_dict(), path + "td3_critic_optimizer")
        torch.save(self.actor.state_dict(), path + "td3_actor")
        torch.save(self.actor_optimizer.state_dict(), path + "td3_actor_optimizer")

    def load(self, path):
        """Restore the state written by ``save`` and rebuild both target networks.

        Tensors are mapped onto ``self.device`` while loading, so a checkpoint
        saved on another device can still be loaded.
        """
        self.critic.load_state_dict(
            torch.load(path + "td3_critic", map_location=self.device)
        )
        self.critic_optimizer.load_state_dict(
            torch.load(path + "td3_critic_optimizer", map_location=self.device)
        )
        self.critic_target = copy.deepcopy(self.critic)

        self.actor.load_state_dict(
            torch.load(path + "td3_actor", map_location=self.device)
        )
        self.actor_optimizer.load_state_dict(
            torch.load(path + "td3_actor_optimizer", map_location=self.device)
        )
        self.actor_target = copy.deepcopy(self.actor)