import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class NoisyLinear(nn.Module):
    """Linear layer with learnable factorised Gaussian noise (NoisyNet).

    Implements ``y = (w_mu + w_sigma * w_eps) x + (b_mu + b_sigma * b_eps)``
    during training and the plain deterministic ``y = w_mu x + b_mu`` during
    evaluation. The noise samples (``*_epsilon``) are registered as buffers so
    they travel with ``state_dict``/``to(device)`` but receive no gradients.

    Args:
        input_dim: number of input features.
        output_dim: number of output features.
        std_init: initial scale of the sigma parameters (default 0.4, as in
            the NoisyNet paper's factorised-noise variant).
    """

    def __init__(self, input_dim, output_dim, std_init=0.4):
        super(NoisyLinear, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.std_init = std_init
        # Learnable mean and standard deviation of the weight distribution.
        self.weight_mu = nn.Parameter(torch.empty(output_dim, input_dim))
        self.weight_sigma = nn.Parameter(torch.empty(output_dim, input_dim))
        # Current noise sample: a buffer, not a parameter (no gradient).
        self.register_buffer('weight_epsilon', torch.empty(output_dim, input_dim))
        # Same mu/sigma/epsilon triple for the bias.
        self.bias_mu = nn.Parameter(torch.empty(output_dim))
        self.bias_sigma = nn.Parameter(torch.empty(output_dim))
        self.register_buffer('bias_epsilon', torch.empty(output_dim))
        self.reset_parameters()
        self.reset_noise()

    def forward(self, x):
        """Apply the (noisy, if training) affine transform to ``x``.

        Fix vs. original: the bias branch wrapped ``bias_epsilon`` in
        ``Variable``, which was never imported (NameError) and is deprecated;
        buffers are plain tensors and can be used directly.
        """
        if self.training:
            # Perturb weights/bias with the current noise sample.
            weight = self.weight_mu + self.weight_sigma * self.weight_epsilon
            bias = self.bias_mu + self.bias_sigma * self.bias_epsilon
        else:
            # Deterministic mean weights at evaluation time.
            weight = self.weight_mu
            bias = self.bias_mu
        return F.linear(x, weight, bias)

    def reset_parameters(self):
        """Initialise mu uniformly in ±1/sqrt(fan_in) and sigma to a constant.

        Matches the factorised-noise initialisation from Fortunato et al.:
        sigma = std_init / sqrt(fan_in) for weights (fan_in = input_dim) and
        std_init / sqrt(output_dim) for the bias, per the original code.
        """
        mu_range = 1 / math.sqrt(self.weight_mu.size(1))
        self.weight_mu.data.uniform_(-mu_range, mu_range)
        self.weight_sigma.data.fill_(
            self.std_init / math.sqrt(self.weight_sigma.size(1)))
        self.bias_mu.data.uniform_(-mu_range, mu_range)
        self.bias_sigma.data.fill_(
            self.std_init / math.sqrt(self.bias_sigma.size(0)))

    def reset_noise(self):
        """Draw a fresh factorised noise sample.

        Factorised Gaussian noise: one vector per input, one per output; the
        weight noise is their outer product, so only in+out (not in*out)
        random numbers are needed.
        """
        epsilon_in = self._scale_noise(self.input_dim)
        epsilon_out = self._scale_noise(self.output_dim)
        # torch.outer replaces the deprecated Tensor.ger alias.
        self.weight_epsilon.copy_(torch.outer(epsilon_out, epsilon_in))
        self.bias_epsilon.copy_(self._scale_noise(self.output_dim))

    def _scale_noise(self, size):
        """Return ``f(x) = sign(x) * sqrt(|x|)`` for ``x ~ N(0, 1)`` of ``size``."""
        x = torch.randn(size)
        return x.sign().mul(x.abs().sqrt())