update projects
This commit is contained in:
52
projects/codes/NoisyDQN/noisy_dqn.py
Normal file
52
projects/codes/NoisyDQN/noisy_dqn.py
Normal file
@@ -0,0 +1,52 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
class NoisyLinear(nn.Module):
    """Linear layer whose weights and biases are perturbed by learnable noise.

    In training mode the effective parameters are ``mu + sigma * epsilon``,
    where ``mu`` and ``sigma`` are learnable and ``epsilon`` is a noise sample
    stored in a (non-trainable) buffer. In evaluation mode only ``mu`` is used,
    so the layer behaves like a plain linear layer.

    Args:
        input_dim: Size of each input sample.
        output_dim: Size of each output sample.
        std_init: Scale used to initialise the ``sigma`` parameters.
    """

    def __init__(self, input_dim, output_dim, std_init=0.4):
        super().__init__()

        self.input_dim = input_dim
        self.output_dim = output_dim
        self.std_init = std_init

        # Registration order is kept as in the original so that parameter
        # and state-dict ordering stay the same.
        # The tensors are allocated uninitialised; reset_parameters() and
        # reset_noise() below fill every one of them.
        self.weight_mu = nn.Parameter(torch.empty(output_dim, input_dim))
        self.weight_sigma = nn.Parameter(torch.empty(output_dim, input_dim))
        self.register_buffer('weight_epsilon', torch.empty(output_dim, input_dim))

        self.bias_mu = nn.Parameter(torch.empty(output_dim))
        self.bias_sigma = nn.Parameter(torch.empty(output_dim))
        self.register_buffer('bias_epsilon', torch.empty(output_dim))

        self.reset_parameters()
        self.reset_noise()

    def forward(self, x):
        """Apply the (noisy, when training) affine transform to ``x``."""
        if self.training:
            # The epsilon buffers are plain tensors; no Variable wrapper is
            # needed for them to take part in autograd-tracked arithmetic.
            weight = self.weight_mu + self.weight_sigma.mul(self.weight_epsilon)
            bias = self.bias_mu + self.bias_sigma.mul(self.bias_epsilon)
        else:
            weight = self.weight_mu
            bias = self.bias_mu

        # nn.functional is reachable through the file's existing `nn` import.
        return nn.functional.linear(x, weight, bias)

    def reset_parameters(self):
        """Initialise ``mu`` uniformly and ``sigma`` to a constant."""
        # Local import: the file's import block does not bring in `math`.
        import math

        mu_range = 1 / math.sqrt(self.weight_mu.size(1))

        with torch.no_grad():
            self.weight_mu.uniform_(-mu_range, mu_range)
            self.weight_sigma.fill_(
                self.std_init / math.sqrt(self.weight_sigma.size(1)))

            self.bias_mu.uniform_(-mu_range, mu_range)
            # NOTE(review): bias sigma is scaled by the output dimension
            # (size(0)) while weight sigma uses the input dimension; kept
            # as in the original -- confirm this is intended.
            self.bias_sigma.fill_(
                self.std_init / math.sqrt(self.bias_sigma.size(0)))

    def reset_noise(self):
        """Draw a fresh noise sample into the epsilon buffers."""
        epsilon_in = self._scale_noise(self.input_dim)
        epsilon_out = self._scale_noise(self.output_dim)

        # Weight noise is the outer product of per-output and per-input
        # noise vectors, giving an (output_dim, input_dim) matrix.
        self.weight_epsilon.copy_(torch.outer(epsilon_out, epsilon_in))
        self.bias_epsilon.copy_(self._scale_noise(self.output_dim))

    def _scale_noise(self, size):
        """Return ``sign(x) * sqrt(|x|)`` for ``x`` drawn from N(0, 1)."""
        x = torch.randn(size)
        x = x.sign().mul(x.abs().sqrt())
        return x
|
||||
25
projects/codes/NoisyDQN/task0_train.ipynb
Normal file
25
projects/codes/NoisyDQN/task0_train.ipynb
Normal file
@@ -0,0 +1,25 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import sys\n",
|
||||
"from pathlib import Path\n",
|
||||
"curr_path = str(Path().absolute()) # 当前路径\n",
|
||||
"parent_path = str(Path().absolute().parent) # 父路径\n",
|
||||
"sys.path.append(parent_path) # 添加路径到系统路径"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"language_info": {
|
||||
"name": "python"
|
||||
},
|
||||
"orig_nbformat": 4
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
Reference in New Issue
Block a user