forked from Seanlinx/mtcnn
-
Notifications
You must be signed in to change notification settings - Fork 0
/
demo.py
94 lines (78 loc) · 3.71 KB
/
demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import numpy as np
import mxnet as mx
import argparse
import cv2
import time
from core.symbol import P_Net, R_Net, O_Net
from core.imdb import IMDB
from config import config
from core.loader import TestLoader
from core.detector import Detector
from core.fcn_detector import FcnDetector
from tools.load_model import load_param
from core.MtcnnDetector import MtcnnDetector
def test_net(prefix, epoch, batch_size, ctx,
thresh=[0.6, 0.6, 0.7], min_face_size=24,
stride=2, slide_window=False):
detectors = [None, None, None]
# load pnet model
args, auxs = load_param(prefix[0], epoch[0], convert=False, ctx=ctx)
if slide_window:
PNet = Detector(P_Net("test"), 12, batch_size[0], ctx, args, auxs)
else:
PNet = FcnDetector(P_Net("test"), ctx, args, auxs)
detectors[0] = PNet
# load rnet model
args, auxs = load_param(prefix[1], epoch[0], convert=False, ctx=ctx)
RNet = Detector(R_Net("test"), 24, batch_size[1], ctx, args, auxs)
detectors[1] = RNet
# load onet model
args, auxs = load_param(prefix[2], epoch[2], convert=False, ctx=ctx)
ONet = Detector(O_Net("test"), 48, batch_size[2], ctx, args, auxs)
detectors[2] = ONet
mtcnn_detector = MtcnnDetector(detectors=detectors, ctx=ctx, min_face_size=min_face_size,
stride=stride, threshold=thresh, slide_window=slide_window)
img = cv2.imread('test01.jpg')
t1 = time.time()
boxes, boxes_c = mtcnn_detector.detect_pnet(img)
boxes, boxes_c = mtcnn_detector.detect_rnet(img, boxes_c)
boxes, boxes_c = mtcnn_detector.detect_onet(img, boxes_c)
print 'time: ',time.time() - t1
if boxes_c is not None:
draw = img.copy()
font = cv2.FONT_HERSHEY_SIMPLEX
for b in boxes_c:
cv2.rectangle(draw, (int(b[0]), int(b[1])), (int(b[2]), int(b[3])), (0, 255, 255), 1)
cv2.putText(draw, '%.3f'%b[4], (int(b[0]), int(b[1])), font, 0.4, (255, 255, 255), 1)
cv2.imshow("detection result", draw)
cv2.waitKey(0)
def parse_args():
parser = argparse.ArgumentParser(description='Test mtcnn',
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--prefix', dest='prefix', help='prefix of model name', nargs="+",
default=['model/pnet', 'model/rnet', 'model/onet'], type=str)
parser.add_argument('--epoch', dest='epoch', help='epoch number of model to load', nargs="+",
default=[16, 16, 16], type=int)
parser.add_argument('--batch_size', dest='batch_size', help='list of batch size used in prediction', nargs="+",
default=[2048, 256, 16], type=int)
parser.add_argument('--thresh', dest='thresh', help='list of thresh for pnet, rnet, onet', nargs="+",
default=[0.5, 0.5, 0.7], type=float)
parser.add_argument('--min_face', dest='min_face', help='minimum face size for detection',
default=40, type=int)
parser.add_argument('--stride', dest='stride', help='stride of sliding window',
default=2, type=int)
parser.add_argument('--sw', dest='slide_window', help='use sliding window in pnet', action='store_true')
parser.add_argument('--gpu', dest='gpu_id', help='GPU device to train with',
default=0, type=int)
args = parser.parse_args()
return args
if __name__ == '__main__':
args = parse_args()
print 'Called with argument:'
print args
ctx = mx.gpu(args.gpu_id)
if args.gpu_id == -1:
ctx = mx.cpu(0)
test_net(args.prefix, args.epoch, args.batch_size,
ctx, args.thresh, args.min_face,
args.stride, args.slide_window)