forked from OAID/Tengine
-
Notifications
You must be signed in to change notification settings - Fork 0
/
classification.py
71 lines (62 loc) · 2.53 KB
/
classification.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
from tengine import tg
import numpy as np
import cv2
import argparse
import os
import time
import numpy as np
# Fallback values for the command-line options defined below.
DEFAULT_LABEL_FILE = "./synset_words.txt"
DEFAULT_IMG_H = 224
DEFAULT_IMG_W = 224
DEFAULT_SCALE = 0.017
# Per-channel mean values subtracted from the image. Presumably in B, G, R
# order because cv2.imread yields BGR images -- confirm against the model's
# training preprocessing.
DEFAULT_MEAN1 = 104.007
DEFAULT_MEAN2 = 116.669
DEFAULT_MEAN3 = 122.679

# Command-line interface. Every option is kept as a string here; main()
# converts the numeric ones (image size, mean values, scale) itself.
parser = argparse.ArgumentParser(description='classification')
parser.add_argument(
    '-m', '--model',
    type=str,
    default='./mobilenet.tmfile',
)
parser.add_argument(
    '-i', '--image',
    type=str,
    default='./cat.jpg',
)
parser.add_argument(
    '-g', '--image-size',
    type=str,
    default='{},{}'.format(DEFAULT_IMG_H, DEFAULT_IMG_W),
    help='image size: height, width',
)
parser.add_argument(
    '-w', '--mean-value',
    type=str,
    default=','.join(str(m) for m in (DEFAULT_MEAN1, DEFAULT_MEAN2, DEFAULT_MEAN3)),
    help='mean value: mean1, mean2, mean3',
)
parser.add_argument(
    '-s', '--scale',
    type=str,
    default=str(DEFAULT_SCALE),
)
parser.add_argument(
    '-l', '--label',
    type=str,
    default=DEFAULT_LABEL_FILE,
    help='the default path of labels.txt (e.g. synset_words.txt)',
)
def get_current_time():
    """Return the current wall-clock time in milliseconds (as a float)."""
    seconds = time.time()
    return seconds * 1000
def read_labels_file(fname, encoding='utf-8'):
    """Return every line of a label file with surrounding whitespace stripped.

    One entry is produced per line, blank lines included, so a list index
    equals the 0-based line number. Callers index this list with the class
    id coming out of the network.

    Args:
        fname: path to the label file (e.g. synset_words.txt).
        encoding: text encoding used to decode the file. Defaults to UTF-8
            (a superset of ASCII) rather than the platform-dependent locale
            encoding that a bare ``open()`` would use.

    Returns:
        list[str]: the stripped lines, in file order.
    """
    with open(fname, 'r', encoding=encoding) as fin:
        # Iterate the file object lazily instead of materializing readlines().
        return [line.strip() for line in fin]
def _preprocess_image(image_file, img_h, img_w, mean_value, scale):
    """Load *image_file* and turn it into the float32 network input.

    The image is resized to (img_h, img_w). The three *mean_value* entries are
    subtracted per channel, and the result is multiplied by *scale*. Finally it
    is transposed from OpenCV's channel-last (H, W, C) layout to (C, H, W).

    Raises:
        ValueError: if OpenCV cannot decode the file.
    """
    data = cv2.imread(image_file)
    if data is None:
        # cv2.imread reports unreadable/unsupported files by returning None
        # instead of raising; fail here with a clear message rather than
        # deep inside cv2.resize.
        raise ValueError(f'Image: {image_file} could not be decoded')
    # cv2.resize takes the target size as (width, height).
    data = cv2.resize(data, (img_w, img_h))
    img_mean = np.array(mean_value).reshape((1, 1, 3))
    data = ((data - img_mean) * scale).astype(np.float32)
    # HWC -> CHW. The transpose is a strided view, so copy it into a
    # C-contiguous buffer (presumably what the Tengine tensor buffer expects).
    return np.ascontiguousarray(data.transpose((2, 0, 1)))


def main(args):
    """Classify one image with a Tengine model and print the top-5 predictions.

    Args:
        args: namespace from ``parser.parse_args()``. Reads ``model``,
            ``image``, ``label``, ``image_size`` ("H,W"), ``mean_value``
            ("m1,m2,m3") and ``scale``.

    Raises:
        FileNotFoundError: if the label file, image or model does not exist.
        ValueError: if ``mean_value`` does not have exactly three entries, or
            the image cannot be decoded.
    """
    image_file = args.image
    tm_file = args.model
    # Explicit checks instead of ``assert``: asserts are stripped under
    # ``python -O`` and must not be used to validate user input.
    for path, what in ((args.label, 'Label File'), (image_file, 'Image'), (tm_file, 'Model')):
        if not os.path.exists(path):
            raise FileNotFoundError(f'{what}: {path} not found')

    labels = read_labels_file(args.label)
    img_h, img_w = map(int, args.image_size.split(','))
    scale = float(args.scale)
    mean_value = [float(v) for v in args.mean_value.split(',')]
    if len(mean_value) != 3:
        raise ValueError('The number of mean_value should be 3, e.g. 104.007,116.669,122.679')

    data = _preprocess_image(image_file, img_h, img_w, mean_value, scale)
    assert data.dtype == np.float32  # internal invariant guaranteed by astype()

    graph = tg.Graph(None, 'tengine', tm_file)
    input_tensor = graph.getInputTensor(0, 0)
    input_tensor.shape = [1, 3, img_h, img_w]  # batch of one CHW image
    graph.preRun()
    input_tensor.buf = data
    graph.run(1)  # 1 is blocking

    output_tensor = graph.getOutputTensor(0, 0)
    # Flatten so argsort ranks individual class scores even if the buffer
    # comes back with a leading batch dimension; a no-op for a 1-D buffer.
    output = np.array(output_tensor.buf).ravel()
    k = 5
    # Indices of the k highest scores, best first.
    idx = output.argsort()[-1:-k-1:-1]
    for i in idx:
        print(labels[i], output[i])
if __name__ == '__main__':
    # Script entry point: parse sys.argv with the module-level parser and run.
    main(parser.parse_args())