-
Notifications
You must be signed in to change notification settings - Fork 21
/
irislandmarks.py
204 lines (159 loc) · 6.79 KB
/
irislandmarks.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
class Print(nn.Module):
    """Debugging pass-through layer that prints the shape of its input.

    Insert it into an ``nn.Sequential`` to trace tensor shapes while
    developing; it returns its input unchanged.

    Args:
        description: Optional label printed on its own line before the
            shape. Nothing extra is printed when it is ``None``.
    """

    def __init__(self, description=None):
        # nn.Module.__init__ must run before any attribute assignment so the
        # module's internal bookkeeping exists when __setattr__ consults it.
        super().__init__()
        self.description = description

    def forward(self, x):
        """Print the optional description and ``x.shape``, then return ``x``."""
        if self.description is not None:
            print(self.description)
        print(x.shape)
        return x
class IrisBlock(nn.Module):
    """Residual building block of the iris-landmark network.

    Residual branch:
        conv (kernel = stride, stride = stride) -> PReLU
        -> depthwise ``kernel_size`` conv -> pointwise 1x1 conv
    Shortcut branch: the input itself, max-pooled when ``stride == 2`` and
    zero-padded along the channel axis when ``out_channels > in_channels``.
    The two branches are summed and passed through a final PReLU.

    The submodule names (``convAct``, ``dwConvConv``, ``act``) determine the
    ``state_dict`` keys and must stay unchanged to load converted weights.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels. Must be >= ``in_channels``
            because the shortcut can only be padded, not truncated.
        kernel_size: Kernel size of the depthwise conv. Must be odd so the
            symmetric padding keeps the spatial size unchanged.
        stride: 1 (keep spatial size) or 2 (halve spatial size).

    Raises:
        ValueError: if any argument would make the residual sum in
            ``forward`` combine tensors of mismatched shapes.
    """

    def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3, stride: int = 1):
        super().__init__()
        if out_channels < in_channels:
            raise ValueError(
                f"out_channels ({out_channels}) must be >= in_channels ({in_channels})"
            )
        if stride not in (1, 2):
            raise ValueError(f"stride must be 1 or 2, got {stride}")
        if kernel_size % 2 == 0:
            raise ValueError(f"kernel_size must be odd, got {kernel_size}")

        self.stride = stride
        # Number of zero channels appended to the shortcut so it matches the
        # residual branch's channel count.
        self.channel_pad = out_channels - in_channels
        mid_channels = out_channels // 2
        # "Same" padding for the stride-1 depthwise conv with an odd kernel:
        # the residual branch keeps the spatial size produced by convAct.
        padding = (kernel_size - 1) // 2

        if stride == 2:
            # Downsamples the shortcut to match the stride-2 convAct output.
            self.max_pool = nn.MaxPool2d(kernel_size=stride, stride=stride)

        self.convAct = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=mid_channels,
                      kernel_size=stride, stride=stride, padding=0, bias=True),
            nn.PReLU(mid_channels),
        )
        self.dwConvConv = nn.Sequential(
            # Depthwise conv: one filter per channel (groups == channels).
            nn.Conv2d(in_channels=mid_channels, out_channels=mid_channels,
                      kernel_size=kernel_size, stride=1, padding=padding,
                      groups=mid_channels, bias=True),
            # Pointwise 1x1 conv expanding to out_channels.
            nn.Conv2d(in_channels=mid_channels, out_channels=out_channels,
                      kernel_size=1, stride=1, padding=0, bias=True),
        )
        self.act = nn.PReLU(out_channels)

    def forward(self, x):
        """Apply the block; ``x`` is expected in NCHW layout with ``in_channels`` channels."""
        h = self.convAct(x)
        if self.stride == 2:
            x = self.max_pool(x)
        h = self.dwConvConv(h)
        if self.channel_pad > 0:
            # F.pad pads from the last dim backwards: (W_left, W_right,
            # H_top, H_bottom, C_front, C_back) -> zero channels appended.
            x = F.pad(x, (0, 0, 0, 0, 0, self.channel_pad), "constant", 0)
        return self.act(h + x)
class IrisLandmarks(nn.Module):
    """The iris landmark model from MediaPipe, ported to PyTorch.

    Because we won't be training this model, it doesn't need to have
    batchnorm layers. These have already been "folded" into the conv
    weights by TFLite.

    The conversion to PyTorch is fairly straightforward, but there are
    some small differences between TFLite and PyTorch in how they handle
    padding on conv layers with stride 2 (see ``forward``).

    This version works on batches, while the MediaPipe version can only
    handle a single image at a time.

    The network takes 64x64 3-channel crops and has two heads sharing one
    backbone:
        * ``split_eye``  -> 213 values, reshaped to 71 points x 3 values
        * ``split_iris`` -> 15 values, reshaped to 5 points x 3 values
    """

    def __init__(self):
        super().__init__()
        # NOTE(review): not read by any code in this class; kept because
        # external callers may inspect it.
        self.min_score_thresh = 0.75
        self._define_layers()

    def _define_layers(self):
        """Build the shared backbone and the eye / iris heads.

        Attribute names and layer order define the ``state_dict`` keys and
        must match the converted weights file.
        """
        # (b, 3, 65, 65) padded input -> (b, 128, 8, 8)
        self.backbone = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=2, padding=0, bias=True),
            nn.PReLU(64),
            IrisBlock(64, 64),
            IrisBlock(64, 64),
            IrisBlock(64, 64),
            IrisBlock(64, 64),
            IrisBlock(64, 128, stride=2),
            IrisBlock(128, 128),
            IrisBlock(128, 128),
            IrisBlock(128, 128),
            IrisBlock(128, 128),
            IrisBlock(128, 128, stride=2),
        )
        # (b, 128, 8, 8) -> (b, 213, 1, 1)
        self.split_eye = nn.Sequential(
            IrisBlock(128, 128),
            IrisBlock(128, 128),
            IrisBlock(128, 128, stride=2),
            IrisBlock(128, 128),
            IrisBlock(128, 128),
            IrisBlock(128, 128, stride=2),
            IrisBlock(128, 128),
            IrisBlock(128, 128),
            nn.Conv2d(in_channels=128, out_channels=213, kernel_size=2, stride=1, padding=0, bias=True),
        )
        # (b, 128, 8, 8) -> (b, 15, 1, 1)
        self.split_iris = nn.Sequential(
            IrisBlock(128, 128),
            IrisBlock(128, 128),
            IrisBlock(128, 128, stride=2),
            IrisBlock(128, 128),
            IrisBlock(128, 128),
            IrisBlock(128, 128, stride=2),
            IrisBlock(128, 128),
            IrisBlock(128, 128),
            nn.Conv2d(in_channels=128, out_channels=15, kernel_size=2, stride=1, padding=0, bias=True),
        )

    def forward(self, x):
        """Run the network on a preprocessed batch.

        Args:
            x: float tensor of shape (b, 3, 64, 64).

        Returns:
            A list ``[e, i]``: eye output of shape (b, 213) and iris output
            of shape (b, 15).
        """
        # TFLite uses slightly different padding on the first conv layer
        # than PyTorch, so do it manually: one extra column on the right and
        # one extra row at the bottom (F.pad order: W_left, W_right, H_top,
        # H_bottom).
        x = F.pad(x, [0, 1, 0, 1], "constant", 0)
        b = x.shape[0]  # batch size, needed for reshaping later

        x = self.backbone(x)    # (b, 128, 8, 8)
        e = self.split_eye(x)   # (b, 213, 1, 1)
        e = e.reshape(b, -1)    # (b, 213)
        i = self.split_iris(x)  # (b, 15, 1, 1)
        i = i.reshape(b, -1)    # (b, 15)
        return [e, i]

    def _device(self):
        """Which device (CPU or GPU) is being used by this model?"""
        return self.backbone[0].weight.device

    def load_weights(self, path):
        """Load a converted state dict from ``path`` and switch to eval mode.

        Tensors are mapped onto the device the model currently lives on, so
        a checkpoint saved from a GPU can be loaded on a CPU-only machine.

        Security: ``torch.load`` unpickles the file; only load checkpoints
        from trusted sources.
        """
        state_dict = torch.load(path, map_location=self._device())
        self.load_state_dict(state_dict)
        self.eval()

    def _preprocess(self, x):
        """Convert image pixels to float32 in the range [0, 1].

        Assumes 8-bit pixel values in [0, 255] — TODO confirm for callers
        that pass float input.
        """
        # The [-1, 1] scaling (x.float() / 127.5 - 1.0) was also tried;
        # NOTE: the [0.0, 1.0] range seems to give better results.
        return x.float() / 255.0

    def predict_on_image(self, img):
        """Makes a prediction on a single image.

        Arguments:
            img: a NumPy array of shape (H, W, 3) or a PyTorch tensor of
                 shape (3, H, W). The image's height and width should be
                 64 pixels.

        Returns:
            A tuple ``(eye, iris)`` of tensors with shapes (1, 71, 3) and
            (1, 5, 3); see ``predict_on_batch``.
        """
        if isinstance(img, np.ndarray):
            # ascontiguousarray copies views with negative strides (e.g. a
            # BGR->RGB flip via img[..., ::-1]), which torch.from_numpy rejects.
            img = torch.from_numpy(np.ascontiguousarray(img)).permute((2, 0, 1))
        return self.predict_on_batch(img.unsqueeze(0))

    def predict_on_batch(self, x):
        """Makes a prediction on a batch of images.

        Arguments:
            x: a NumPy array of shape (b, H, W, 3) or a PyTorch tensor of
               shape (b, 3, H, W). The height and width should be 64 pixels.
               Pixel values are divided by 255 (see ``_preprocess``).

        Returns:
            A tuple ``(eye, iris)``:
              - eye:  tensor of shape (b, 71, 3) from the ``split_eye`` head
              - iris: tensor of shape (b, 5, 3) from the ``split_iris`` head
            Each point carries 3 values (presumably x, y, z — confirm
            against the MediaPipe model documentation).
        """
        if isinstance(x, np.ndarray):
            # See predict_on_image: guard against negative-stride views.
            x = torch.from_numpy(np.ascontiguousarray(x)).permute((0, 3, 1, 2))

        assert x.shape[1] == 3, f"expected 3 channels, got shape {tuple(x.shape)}"
        assert x.shape[2] == 64, f"expected height 64, got shape {tuple(x.shape)}"
        assert x.shape[3] == 64, f"expected width 64, got shape {tuple(x.shape)}"

        # 1. Preprocess the images into tensors:
        x = x.to(self._device())
        x = self._preprocess(x)

        # 2. Run the neural network (inference only, no autograd graph):
        with torch.no_grad():
            eye, iris = self(x)

        # 3. Postprocess the raw predictions into per-point triples:
        return eye.view(-1, 71, 3), iris.view(-1, 5, 3)