-
Notifications
You must be signed in to change notification settings - Fork 21
/
test_iris.py
77 lines (64 loc) · 2.09 KB
/
test_iris.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import torch
import torch.nn as nn
from irislandmarks import IrisLandmarks
import matplotlib.pyplot as plt
import cv2
import numpy as np
def centerCropSquare(img, center, side=None, scaleWRTHeight=None):
    """Crop a square window of ``img`` centred on ``center``.

    Exactly one of ``side`` or ``scaleWRTHeight`` must be given.

    Args:
        img: Image array indexed ``[row, col, ...]``, e.g. the ``H x W x C``
            array returned by ``cv2.imread``; 2-D (grayscale) arrays also work.
        center: ``(row, col)`` pixel position of the window centre.
        side: Window side length in pixels. The window spans
            ``int(side / 2)`` pixels on each side of the centre, so an odd
            ``side`` yields a window one pixel smaller.
        scaleWRTHeight: Window side length as a fraction of the image height
            (``img.shape[0]``).

    Returns:
        The sliced region of ``img`` (for NumPy arrays, a view, not a copy).
        The window is clipped at the image borders, so near an edge the
        result can be smaller than requested and not square.

    Raises:
        ValueError: If both or neither of ``side`` / ``scaleWRTHeight`` are given.
    """
    # "Exactly one of" == the two None-tests disagree (a boolean xor).
    if (side is None) == (scaleWRTHeight is None):
        raise ValueError("pass exactly one of `side` or `scaleWRTHeight`")
    if side is None:
        half = int(img.shape[0] * scaleWRTHeight / 2)
    else:
        half = int(side / 2)
    row, col = center[0], center[1]
    # Clamp the start indices: a negative start would make Python slicing
    # count from the far edge of the image and return an empty/wrong crop.
    # Over-large end indices are already clipped by slicing itself.
    top = max(row - half, 0)
    left = max(col - half, 0)
    # `...` keeps every trailing axis, so both 2-D and H x W x C images work.
    return img[top:row + half, left:col + half, ...]
print("PyTorch version:", torch.__version__)
print("CUDA version:", torch.version.cuda)
print("cuDNN version:", torch.backends.cudnn.version())

# Use the first CUDA device when available, otherwise fall back to the CPU.
gpu = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
net = IrisLandmarks().to(gpu)
net.load_weights("irislandmarks.pth")

img = cv2.imread("test.jpg")
# cv2.imread does not raise on a missing/unreadable file; it returns None,
# which would otherwise surface later as an opaque TypeError.
if img is None:
    raise FileNotFoundError("could not read image 'test.jpg'")

# Fixed (row, col) eye centres for test.jpg (centerCropSquare indexes rows
# with center[0]). centerLeft is currently unused.
centerRight = [485, 332]
centerLeft = [479, 638]
img = centerCropSquare(img, centerRight, side=400) # 400 is 1200 (image size) * 64/192, as the detector takes a 64x64 box inside the 192 image

# OpenCV loads images as BGR; convert to RGB *before* the first display so
# matplotlib shows true colours. The network still receives the RGB image.
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
plt.imshow(img)
plt.show()

# tl = [467, 284]
# br = [504, 397]
# w = br[1] - tl[1]
# h = br[0] - tl[0]
# w = int(w*2.3)
# h = int(h*2.3)
# tl[0] -= int(h/2)
# tl[1] -= int(w/2)
# br[0] = tl[0] + h
# br[1] = tl[1] + w
# img = img[tl[0]:br[0], tl[1]:br[1]]
# plt.imshow(img)
# plt.show()
# img = np.fliplr(img) # the detector is trained on the left eye only, hence the flip
# img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
# img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)

img = cv2.resize(img, (64, 64))
# test = net.predict_on_image(img)
eye_gpu, iris_gpu = net.predict_on_image(img)
eye = eye_gpu.cpu().numpy()
iris = iris_gpu.cpu().numpy()

# Overlay the predicted points on the 64x64 input. Presumably each row is an
# (x, y, ...) landmark in input-pixel coordinates — TODO confirm against
# IrisLandmarks.predict_on_image.
plt.imshow(img, zorder=1)
x, y = eye[:, 0], eye[:, 1]
plt.scatter(x, y, zorder=2, s=1.0)
x, y = iris[:, 0], iris[:, 1]
plt.scatter(x, y, zorder=2, s=1.0, c='r')
plt.show()

# torch.onnx.export(
# net,
# (torch.randn(1,3,64,64, device='cuda'), ),
# "irislandmarks.onnx",
# input_names=("image", ),
# output_names=("preds", "conf"),
# opset_version=9
# )