mlp.py
import torch
from typing import List, Tuple
from torch import nn

class Linear(nn.Module):
    r"""Applies a linear transformation to the incoming data: :math:`y = xA^T + b`

    Args:
        in_features: size of each input sample
        out_features: size of each output sample

    Shape:
        - Input: :math:`(*, H_{in})` where :math:`*` means any number of
          dimensions including none and :math:`H_{in} = \text{in\_features}`.
        - Output: :math:`(*, H_{out})` where all but the last dimension
          are the same shape as the input and :math:`H_{out} = \text{out\_features}`.

    Examples::

        >>> m = Linear(20, 30)
        >>> input = torch.randn(128, 20)
        >>> output = m(input)
        >>> print(output.size())
        torch.Size([128, 30])
    """

    def __init__(self, in_features: int, out_features: int) -> None:
        super(Linear, self).__init__()
        self.weight = nn.parameter.Parameter(torch.empty((out_features, in_features)))
        self.bias = nn.parameter.Parameter(torch.empty(out_features))

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """
        :param input: [bsz, in_features]
        :return: output [bsz, out_features]
        """
        # Equivalent to torch.nn.functional.linear(input, self.weight, self.bias).
        output = torch.matmul(input, torch.transpose(self.weight, 0, 1)) + self.bias
        return output
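
# Note: Linear leaves its parameters as uninitialized torch.empty tensors; the MLP below
# initializes them explicitly via _initialize_linear_layer (Xavier-normal weights, zero biases).
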
class MLP(torch.nn.Module):
    def __init__(self, input_size: int, hidden_sizes: List[int], num_classes: int, activation: str = "relu"):
        super(MLP, self).__init__()
        self.input_size = input_size
        self.hidden_sizes = hidden_sizes
        assert len(hidden_sizes) >= 1, "You should have at least one hidden layer"
        self.num_classes = num_classes
        self.activation = activation
        assert activation in ['tanh', 'relu', 'sigmoid'], "Invalid choice of activation"
        self.hidden_layers, self.output_layer = self._build_layers(input_size, hidden_sizes, num_classes)

        # Initialization
        self._initialize_linear_layer(self.output_layer)
        for layer in self.hidden_layers:
            self._initialize_linear_layer(layer)

    def _build_layers(self, input_size: int,
                      hidden_sizes: List[int],
                      num_classes: int) -> Tuple[nn.ModuleList, nn.Module]:
        """
        Build the layers for the MLP. Be aware of handling corner cases.
        :param input_size: An int
        :param hidden_sizes: A list of ints. E.g., [32, 32] means two hidden layers with 32 units each.
        :param num_classes: An int
        :return:
            hidden_layers: nn.ModuleList. Within the list, each item has type nn.Module
            output_layer: nn.Module
        """
        hidden_layers = nn.ModuleList([Linear(input_size, hidden_sizes[0])])
        for i in range(len(hidden_sizes) - 1):
            hidden_layers.append(Linear(hidden_sizes[i], hidden_sizes[i + 1]))
        output_layer = Linear(hidden_sizes[-1], num_classes)
        return hidden_layers, output_layer

    def activation_fn(self, activation: str, inputs: torch.Tensor) -> torch.Tensor:
        """ Apply the non-linearity named by `activation` to `inputs`. """
        if activation == "tanh":
            return torch.tanh(inputs)
        elif activation == "relu":
            return torch.nn.functional.relu(inputs)
        elif activation == "sigmoid":
            return torch.sigmoid(inputs)
        raise ValueError(f"Invalid choice of activation: {activation}")

    def _initialize_linear_layer(self, module: Linear) -> None:
        """ Set the bias to zeros and the weight to Glorot (Xavier) normal. """
        # Xavier normal draws weights from N(0, std^2) with std = sqrt(2 / (fan_in + fan_out)).
        nn.init.xavier_normal_(module.weight)
        nn.init.zeros_(module.bias)

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        """ Forward images and compute logits.
        1. The images are first flattened to vectors.
        2. Forward the result through each layer in self.hidden_layers, applying activation_fn.
        3. Finally forward the result through the output_layer.
        :param images: [batch, channels, width, height]
        :return logits: [batch, num_classes]
        """
        m = nn.Flatten(start_dim=1)
        x = m(images)
        for i, l in enumerate(self.hidden_layers):
            x = self.activation_fn(self.activation, l(x))
        x = self.output_layer(x)
        return x
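

# A minimal usage sketch (not part of the original module): build the MLP for
# hypothetical MNIST-like inputs (1 x 28 x 28, 10 classes) and check the logit shape.
if __name__ == "__main__":
    torch.manual_seed(0)
    model = MLP(input_size=28 * 28, hidden_sizes=[128, 64], num_classes=10, activation="relu")
    images = torch.randn(32, 1, 28, 28)  # [batch, channels, width, height]
    logits = model(images)
    print(logits.shape)  # expected: torch.Size([32, 10])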