update projects
This commit is contained in:
52
projects/codes/NoisyDQN/noisy_dqn.py
Normal file
52
projects/codes/NoisyDQN/noisy_dqn.py
Normal file
@@ -0,0 +1,52 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
class NoisyLinear(nn.Module):
|
||||
def __init__(self, input_dim, output_dim, std_init=0.4):
|
||||
super(NoisyLinear, self).__init__()
|
||||
|
||||
self.input_dim = input_dim
|
||||
self.output_dim = output_dim
|
||||
self.std_init = std_init
|
||||
|
||||
self.weight_mu = nn.Parameter(torch.FloatTensor(output_dim, input_dim))
|
||||
self.weight_sigma = nn.Parameter(torch.FloatTensor(output_dim, input_dim))
|
||||
self.register_buffer('weight_epsilon', torch.FloatTensor(output_dim, input_dim))
|
||||
|
||||
self.bias_mu = nn.Parameter(torch.FloatTensor(output_dim))
|
||||
self.bias_sigma = nn.Parameter(torch.FloatTensor(output_dim))
|
||||
self.register_buffer('bias_epsilon', torch.FloatTensor(output_dim))
|
||||
|
||||
self.reset_parameters()
|
||||
self.reset_noise()
|
||||
|
||||
def forward(self, x):
|
||||
if self.training:
|
||||
weight = self.weight_mu + self.weight_sigma.mul( (self.weight_epsilon))
|
||||
bias = self.bias_mu + self.bias_sigma.mul(Variable(self.bias_epsilon))
|
||||
else:
|
||||
weight = self.weight_mu
|
||||
bias = self.bias_mu
|
||||
|
||||
return F.linear(x, weight, bias)
|
||||
|
||||
def reset_parameters(self):
|
||||
mu_range = 1 / math.sqrt(self.weight_mu.size(1))
|
||||
|
||||
self.weight_mu.data.uniform_(-mu_range, mu_range)
|
||||
self.weight_sigma.data.fill_(self.std_init / math.sqrt(self.weight_sigma.size(1)))
|
||||
|
||||
self.bias_mu.data.uniform_(-mu_range, mu_range)
|
||||
self.bias_sigma.data.fill_(self.std_init / math.sqrt(self.bias_sigma.size(0)))
|
||||
|
||||
def reset_noise(self):
|
||||
epsilon_in = self._scale_noise(self.input_dim)
|
||||
epsilon_out = self._scale_noise(self.output_dim)
|
||||
|
||||
self.weight_epsilon.copy_(epsilon_out.ger(epsilon_in))
|
||||
self.bias_epsilon.copy_(self._scale_noise(self.output_dim))
|
||||
|
||||
def _scale_noise(self, size):
|
||||
x = torch.randn(size)
|
||||
x = x.sign().mul(x.abs().sqrt())
|
||||
return x
|
||||
Reference in New Issue
Block a user