-
Notifications
You must be signed in to change notification settings - Fork 0
/
conv_layer.py
38 lines (30 loc) · 1.01 KB
/
conv_layer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
"""Convolutional layer
PyTorch implementation of CapsNet in Sabour, Hinton et al.'s paper
Dynamic Routing Between Capsules. NIPS 2017.
https://arxiv.org/abs/1710.09829
Author: Cedric Chee
"""
import torch
import torch.nn as nn
class ConvLayer(nn.Module):
    """Conventional Conv2d layer followed by a ReLU (or ReLU6) activation.

    Args:
        in_channel: Number of channels in the input tensor.
        out_channel: Number of feature maps produced by the convolution.
        kernel_size: Convolution kernel size, forwarded to ``nn.Conv2d``
            (an int or an ``(h, w)`` tuple).
        relu6: If True, use ``ReLU6`` (output clamped to ``[0, 6]``)
            instead of a plain ``ReLU``.
        stride: Convolution stride. Keyword-only; defaults to 1, the value
            that was previously hard-coded.
        padding: Implicit zero-padding on both sides of each spatial
            dimension. Keyword-only; defaults to 0 (no padding), matching
            the previous behaviour.

    The submodule names ``conv0`` and ``act`` are kept stable so that
    previously saved ``state_dict`` checkpoints continue to load.
    """

    def __init__(self, in_channel, out_channel, kernel_size, relu6=False,
                 *, stride=1, padding=0):
        super().__init__()
        self.conv0 = nn.Conv2d(in_channels=in_channel,
                               out_channels=out_channel,
                               kernel_size=kernel_size,
                               stride=stride,
                               padding=padding)
        # In-place activation is safe here: it only overwrites the
        # convolution's freshly allocated output, which nothing else holds.
        if relu6:
            self.act = nn.ReLU6(inplace=True)
        else:
            self.act = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply the convolution, then the activation.

        Args:
            x: Input tensor of shape ``(N, in_channel, H, W)``.

        Returns:
            Tensor of shape ``(N, out_channel, H_out, W_out)`` where
            ``H_out = floor((H + 2 * padding - kernel_h) / stride) + 1``
            (and likewise for ``W_out``). For example, with
            ``in_channel=1, out_channel=256, kernel_size=9`` and the default
            stride/padding, a ``[128, 1, 28, 28]`` batch becomes
            ``[128, 256, 20, 20]``.
        """
        out_conv0 = self.conv0(x)
        return self.act(out_conv0)