-
Notifications
You must be signed in to change notification settings - Fork 8
/
test_json.py
140 lines (110 loc) · 5.12 KB
/
test_json.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
# -*- coding: utf-8 -*-
import cv2
import os
import json
import time
import argparse
import numpy as np
import tensorflow as tf
from keras.layers import Input
from keras.models import Model
from net.vgg16 import VGG16_UNet
from collections import OrderedDict
from utils.file_util import list_files, saveResult
from utils.inference_util import getDetBoxes, adjustResultCoordinates
from utils.img_util import load_image, img_resize, img_normalize, to_heat_map
class DateEncoder(json.JSONEncoder):
    """JSON encoder that can serialize NumPy values.

    Despite the (historical) name, this has nothing to do with dates: it
    converts ``np.ndarray`` to nested lists and NumPy scalar numbers to
    native Python ``int``/``float``.  Anything else falls through to the
    base encoder, which raises ``TypeError`` as usual.
    """

    def default(self, obj):
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        # Generalization: NumPy scalars (e.g. np.float32 box coordinates)
        # are also not JSON-serializable by default.
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        return super().default(obj)
def copyStateDict(state_dict):
    """Return a copy of ``state_dict`` with any leading ``module`` key
    prefix stripped (as added by DataParallel-style model wrappers).

    Bug fix: the original indexed ``list(state_dict.keys())[0]`` and so
    raised ``IndexError`` on an empty state dict; an empty input now
    simply yields an empty ``OrderedDict``.
    """
    keys = list(state_dict.keys())
    # Drop the first dotted component only when the checkpoint was saved
    # from a wrapped model ("module.layer.weight" -> "layer.weight").
    start_idx = 1 if keys and keys[0].startswith("module") else 0
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        name = ".".join(k.split(".")[start_idx:])
        new_state_dict[name] = v
    return new_state_dict
def str2bool(v):
    """Interpret a command-line string as a boolean (case-insensitive)."""
    truthy = frozenset({"yes", "y", "true", "t", "1"})
    return v.lower() in truthy
# ---------------------------------------------------------------------------
# Command-line configuration (evaluated at import time).
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser(description='CRAFT Text Detection')
parser.add_argument('--trained_model', default='weights/Syn_cus_1_5_25k.h5', type=str, help='pretrained model')
parser.add_argument('--gpu_list', type=str, default='0', help='list of gpu to use')
parser.add_argument('--text_threshold', default=0.1, type=float, help='text confidence threshold')
parser.add_argument('--low_text', default=0.3, type=float, help='text low-bound score')
parser.add_argument('--link_threshold', default=0.2, type=float, help='link confidence threshold')
parser.add_argument('--canvas_size', default=1280, type=int, help='image size for inference')
parser.add_argument('--mag_ratio', default=1., type=float, help='image magnification ratio')
# Bug fix: the original combined action='store_true' with default=True, so the
# flag was a no-op and timing output could never be disabled.  Using the
# str2bool converter (defined above, previously unused) keeps the default of
# True while making `--show_time false` actually work.
parser.add_argument('--show_time', default=True, type=str2bool, help='show processing time')
parser.add_argument('--test_folder', default=r'data/CTW/images-test',
                    type=str, help='folder path to input images')
FLAGS = parser.parse_args()

result_folder = 'results/Cus'
# makedirs(exist_ok=True) also creates a missing parent ('results/'), where
# os.mkdir would raise FileNotFoundError, and tolerates re-runs.
os.makedirs(result_folder, exist_ok=True)
def predict(model, image, text_threshold, link_threshold, low_text, target_size=600):
    """Run the detector on one image and post-process the score maps.

    Args:
        model: Keras model whose predict() yields (score_text, score_link)
            map batches.
        image: H x W x C image array.
        text_threshold: region-score confidence threshold.
        link_threshold: affinity-score confidence threshold.
        low_text: low-bound text score.
        target_size: size the longer image side is scaled to before
            inference.  Generalized from a hard-coded 600; the default
            preserves the original behavior.

    Returns:
        (boxes, ret_score_text): detected boxes mapped back to
        original-image coordinates, and a heat-map visualization
        (text map | white separator | link map).
    """
    t0 = time.time()

    # Resize so the longer side equals target_size.
    # NOTE(review): this deliberately overrides FLAGS.mag_ratio (see the
    # commented-out call below) -- confirm that is intended.
    h, w = image.shape[:2]
    mag_ratio = target_size / max(h, w)
    # img_resized, target_ratio = img_resize(image, FLAGS.mag_ratio, FLAGS.canvas_size, interpolation=cv2.INTER_LINEAR)
    img_resized, target_ratio = img_resize(image, mag_ratio, FLAGS.canvas_size, interpolation=cv2.INTER_LINEAR)
    ratio_h = ratio_w = 1 / target_ratio

    # preprocessing
    x = img_normalize(img_resized)

    # Forward pass on a batch of one; unwrap the batch dimension.
    score_text, score_link = model.predict(np.array([x]))
    score_text = score_text[0]
    score_link = score_link[0]
    t0 = time.time() - t0  # t0 now holds the inference duration
    t1 = time.time()

    # Post-processing: score maps -> boxes in resized coordinates, then
    # rescale back to the original image.
    boxes = getDetBoxes(score_text, score_link, text_threshold, link_threshold, low_text)
    boxes = adjustResultCoordinates(boxes, ratio_w, ratio_h)
    t1 = time.time() - t1  # t1 now holds the post-processing duration

    # render results (optional): text heat-map, 10px white gap, link heat-map
    render_img = score_text.copy()
    white_img = np.ones((render_img.shape[0], 10, 3), dtype=np.uint8) * 255
    ret_score_text = np.hstack((to_heat_map(render_img), white_img, to_heat_map(score_link)))
    # ret_score_text = to_heat_map(render_img)

    if FLAGS.show_time:
        print("\ninfer/postproc time : {:.3f}/{:.3f}".format(t0, t1))

    return boxes, ret_score_text
def saveJson(image_path, bboxes, dirname):
    """Write the detected boxes for one image to <dirname>/<basename>.json.

    Bug fix: the original called ``os.mknod`` before opening the file,
    which is non-portable (fails on macOS and for unprivileged users on
    some filesystems) and raises ``FileExistsError`` on a re-run;
    ``open(..., 'w')`` already creates/truncates the file.
    """
    json_path = os.path.join(dirname, os.path.basename(image_path) + '.json')
    print(json_path)
    temp = {"bboxes": bboxes}
    with open(json_path, 'w', encoding='utf-8') as f:
        # DateEncoder converts NumPy arrays/scalars to JSON-native values.
        json.dump(temp, f, ensure_ascii=False, cls=DateEncoder)
def test():
    """Detect text in every image under FLAGS.test_folder, saving for each
    input a score heat-map image, a rendered result image and a JSON file
    with the detected boxes into ``result_folder``."""
    os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.gpu_list

    # Build the network graph and load the trained weights.
    net_input = Input(shape=(None, None, 3), name='image', dtype=tf.float32)
    region_map, affinity_map = VGG16_UNet(input_tensor=net_input, weights=None)
    detector = Model(inputs=[net_input], outputs=[region_map, affinity_map])
    detector.load_weights(FLAGS.trained_model)

    # Collect the images to process.
    image_list, _, _ = list_files(FLAGS.test_folder)

    started = time.time()
    for idx, image_path in enumerate(image_list, start=1):
        print("Test image {:d}/{:d}: {:s}".format(idx, len(image_list), image_path), end='\r')
        image = load_image(image_path)

        tic = time.time()
        bboxes, score_text = predict(detector, image, FLAGS.text_threshold, FLAGS.link_threshold, FLAGS.low_text)
        # Per-image wall-clock time in milliseconds.
        print(time.time() * 1000 - tic * 1000)

        # Save the score heat-map.
        stem, _ = os.path.splitext(os.path.basename(image_path))
        mask_file = result_folder + "/res_" + stem + '_mask.jpg'
        cv2.imwrite(mask_file, score_text)

        # Save the rendered detection image; the channel flip presumably
        # converts RGB -> BGR for OpenCV-based saving (confirm in saveResult).
        saveResult(image_path, image[:, :, ::-1], bboxes, dirname=result_folder)

        # Save the boxes as JSON.
        saveJson(image_path, bboxes, dirname=result_folder)

    print("elapsed time : {}s".format(time.time() - started))
# Script entry point: run detection over the configured test folder.
if __name__ == '__main__':
    test()