"""NoisyLinear: an ``nn.Linear``-style layer with learnable, factorised Gaussian
parameter noise (NoisyNet-style exploration)."""
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
class NoisyLinear(nn.Module):
    """Linear layer with learnable, factorised Gaussian noise on weights/biases.

    During training the effective parameters are::

        weight = weight_mu + weight_sigma * weight_epsilon
        bias   = bias_mu   + bias_sigma   * bias_epsilon

    where the ``*_epsilon`` buffers hold factorised noise samples refreshed by
    :meth:`reset_noise`. In eval mode only the ``mu`` parameters are used, so
    the layer degenerates to a plain linear transform.

    Args:
        input_dim: number of input features.
        output_dim: number of output features.
        std_init: initial scale of the sigma (noise-magnitude) parameters.
    """

    def __init__(self, input_dim, output_dim, std_init=0.4):
        super().__init__()

        self.input_dim = input_dim
        self.output_dim = output_dim
        self.std_init = std_init

        # Learnable mean and std-dev of the weight distribution. The epsilon
        # noise samples are registered as buffers so they are not trained but
        # still move with the module (device transfer / state_dict).
        self.weight_mu = nn.Parameter(torch.FloatTensor(output_dim, input_dim))
        self.weight_sigma = nn.Parameter(torch.FloatTensor(output_dim, input_dim))
        self.register_buffer('weight_epsilon', torch.FloatTensor(output_dim, input_dim))

        self.bias_mu = nn.Parameter(torch.FloatTensor(output_dim))
        self.bias_sigma = nn.Parameter(torch.FloatTensor(output_dim))
        self.register_buffer('bias_epsilon', torch.FloatTensor(output_dim))

        self.reset_parameters()
        self.reset_noise()

    def forward(self, x):
        """Apply the affine transform to ``x`` (noisy while training)."""
        if self.training:
            # Bug fix: the original wrapped bias_epsilon in the long-removed
            # torch.autograd.Variable (an undefined name here -> NameError on
            # any training-mode forward) and left a stray empty paren pair on
            # the weight line. The epsilon buffers are plain tensors and can
            # be used directly.
            weight = self.weight_mu + self.weight_sigma * self.weight_epsilon
            bias = self.bias_mu + self.bias_sigma * self.bias_epsilon
        else:
            weight = self.weight_mu
            bias = self.bias_mu

        return F.linear(x, weight, bias)

    def reset_parameters(self):
        """Initialise mu uniformly in +-1/sqrt(fan_in), sigma to a constant."""
        mu_range = 1 / math.sqrt(self.weight_mu.size(1))

        self.weight_mu.data.uniform_(-mu_range, mu_range)
        self.weight_sigma.data.fill_(self.std_init / math.sqrt(self.weight_sigma.size(1)))

        self.bias_mu.data.uniform_(-mu_range, mu_range)
        # NOTE(review): bias sigma is scaled by sqrt(output_dim) here while the
        # weight sigma uses sqrt(input_dim); kept as-is to preserve behaviour.
        self.bias_sigma.data.fill_(self.std_init / math.sqrt(self.bias_sigma.size(0)))

    def reset_noise(self):
        """Draw fresh factorised noise: weight_eps = f(eps_out) outer f(eps_in)."""
        epsilon_in = self._scale_noise(self.input_dim)
        epsilon_out = self._scale_noise(self.output_dim)

        self.weight_epsilon.copy_(epsilon_out.ger(epsilon_in))
        self.bias_epsilon.copy_(self._scale_noise(self.output_dim))

    def _scale_noise(self, size):
        """Return f(x) = sign(x) * sqrt(|x|) applied to standard-normal noise."""
        x = torch.randn(size)
        return x.sign().mul(x.abs().sqrt())