41 lines
1.5 KiB
Python
41 lines
1.5 KiB
Python
#!/usr/bin/env python
|
|
# coding=utf-8
|
|
'''
|
|
@Author: John
|
|
@Email: johnjim0816@gmail.com
|
|
@Date: 2020-06-11 12:18:12
|
|
@LastEditor: John
|
|
@LastEditTime: 2020-06-11 17:23:45
|
|
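@Description: Three-layer convolutional network (Conv2d -> BatchNorm2d -> ReLU, three times) with a single linear output head.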
|
|
@Environment: python 3.7.7
|
|
'''
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
|
|
class CNN(nn.Module):
    """Three-layer convolutional network with a linear output head.

    Each stage is Conv2d(kernel_size=5, stride=2, no padding) -> BatchNorm2d
    -> ReLU. The final feature map is flattened per sample and fed to a single
    ``nn.Linear`` that produces ``n_outputs`` values per sample.

    Args:
        h: Input image height in pixels.
        w: Input image width in pixels.
        n_outputs: Number of values produced per sample by the head.
        in_channels: Number of channels in the input image. Defaults to 3,
            the value that was previously hard-coded.

    Raises:
        ValueError: If ``h`` or ``w`` is too small to survive three
            kernel_size=5, stride=2 convolutions (each must be >= 29).
    """

    def __init__(self, h, w, n_outputs, in_channels=3):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, 16, kernel_size=5, stride=2)
        self.bn1 = nn.BatchNorm2d(16)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=5, stride=2)
        self.bn2 = nn.BatchNorm2d(32)
        self.conv3 = nn.Conv2d(32, 32, kernel_size=5, stride=2)
        self.bn3 = nn.BatchNorm2d(32)

        def conv2d_size_out(size, kernel_size=5, stride=2):
            # Output length of an unpadded, undilated Conv2d along one axis.
            return (size - (kernel_size - 1) - 1) // stride + 1

        def stack_size_out(size, axis):
            # Push one spatial axis through all three convolutions. Reject
            # sizes that collapse below 1: once an intermediate size goes
            # negative, convw * convh can come out positive (e.g. h=w=5 gives
            # (-2) * (-2)) and silently build a head for an image that
            # forward() cannot actually process.
            out = size
            for _ in range(3):
                out = conv2d_size_out(out)
                if out < 1:
                    raise ValueError(
                        f"input {axis}={size} is too small for three "
                        "kernel_size=5, stride=2 convolutions; need at least 29"
                    )
            return out

        # The Linear head's input size depends on the spatial size left after
        # the convolutions, and therefore on the input image size.
        convh = stack_size_out(h, "h")
        convw = stack_size_out(w, "w")
        linear_input_size = convw * convh * 32
        self.head = nn.Linear(linear_input_size, n_outputs)

    def forward(self, x):
        """Map a batch of images to per-sample outputs.

        Called either with a single sample (a batch of one) to choose the
        next action, or with a full batch during optimization.

        Args:
            x: 4-D tensor of shape (batch, in_channels, h, w). BatchNorm2d
                requires 4-D input, and the head only matches when the spatial
                size equals the (h, w) passed to ``__init__``.

        Returns:
            Tensor of shape (batch, n_outputs).
        """
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        # Flatten everything except the batch dimension for the Linear head.
        return self.head(x.flatten(1))