forked from visipedia/tfrecords
-
Notifications
You must be signed in to change notification settings - Fork 0
/
create_caltech_tfrecords.py
222 lines (175 loc) · 7.96 KB
/
create_caltech_tfrecords.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
###################################################################
## INRIA - STARS 2018
## Author: Hung NGUYEN
###################################################################
"""
This script contains code to convert a Caltech-like dataset to TFRecords, with multiple
variations that can be configured and handled.
"""
#####import numpy as np
import os
from time import sleep
import cv2
import json
from create_tfrecords import create
## Path of the folder that contains the images of the dataset.
images_path = ""
## Path of the folder that contains the annotation files.
## NOTE(review): the original comment was truncated after "which should match wi";
## presumably the annotation files are expected to match the images -- confirm.
annotations_path = ""
## Image file extension; it is appended to an annotation file's base name to build
## the image file name, and is stored as the "format" field of each record.
img_extension = ".jpg"
def get_size_img(full_path_image):
    """
    Return the dimensions of an image file as read by OpenCV.
    @params:
        full_path_image - Required : path of the image file to read (Str)
    @returns:
        The ``shape`` of the array returned by ``cv2.imread`` (callers unpack it as
        height, width, channels), or ``(None, None, None)`` when the file does not
        exist or cannot be read.
    """
    if not os.path.exists(full_path_image):
        print("%r doesn't exist" % full_path_image)
        ## Nothing to decode: report the failure without calling imread at all.
        return None, None, None
    img = cv2.imread(full_path_image)
    if img is None:
        ## imread returns None for files it cannot decode.
        return None, None, None
    return img.shape
## When True, extra diagnostic messages are printed (after initialisation and once
## all annotation files have been processed).
DEBUG = False
class TFRecordCreator:
    """
    Build per-image record dictionaries from Caltech-like annotation files and pass them
    to ``create`` to write the TFRecords files.

    Each annotation file is a ``.txt`` file whose first line is skipped and whose
    following lines are whitespace separated; only lines starting with "person" are used,
    and fields 1 to 4 of such a line are read as x, y, w, h of the bounding box.
    """

    def __init__(self, annotations_path, images_path):
        """
        Check both folders and remember them with a trailing "/".
        @params:
            annotations_path - Required : folder that contains the annotation files (Str)
            images_path      - Required : folder that contains the images (Str)
        @raises:
            AssertionError when one of the two paths is not an existing folder.
        """
        assert os.path.isdir(
            images_path), "%r is not a valid folder path, please check if this folder exists" % images_path
        assert os.path.isdir(annotations_path), \
            "%r is not a valid folder path, please check if this folder exists" % annotations_path
        self._annotations_path = annotations_path
        if not self._annotations_path.endswith("/"):
            self._annotations_path = self._annotations_path + "/"
        self._images_path = images_path
        if not self._images_path.endswith("/"):
            self._images_path = self._images_path + "/"
        if DEBUG: print("Init sucessful!!")
        ## Running counter used as the "id" of the next record that gets created.
        self._id = 0

    @staticmethod
    def _build_object_dict(annotations, width, height):
        """
        Build the "object" entry of a record from the raw bounding boxes.
        @params:
            annotations - Required : list of [x, y, w, h] values as strings (List)
            width       - Required : image width used to normalise x values (Int)
            height      - Required : image height used to normalise y values (Int)
        @returns:
            Dict with the keys "count", "area", "id" and "bbox"; box coordinates and
            areas are divided by the image size.
        """
        nb_object = len(annotations)
        x_mins, x_maxs, y_mins, y_maxs, areas = [], [], [], [], []
        for anno in annotations:
            ## Values may be written as floats in the file; they are truncated to int.
            x, y, w, h = [int(float(value)) for value in anno]
            x_min = x / width
            y_min = y / height
            x_max = (x + w) / width
            y_max = (y + h) / height
            x_mins.append(x_min)
            x_maxs.append(x_max)
            y_mins.append(y_min)
            y_maxs.append(y_max)
            areas.append((x_max - x_min) * (y_max - y_min))
        return {
            "count": nb_object,
            "area": areas,
            "id": [str(z) for z in range(nb_object)],
            "bbox": {
                "xmin": x_mins,
                "xmax": x_maxs,
                "ymin": y_mins,
                "ymax": y_maxs,
                "score": [1.0] * nb_object,
                "label": [1] * nb_object,
                "conf": [1.0] * nb_object,
                "text": ["person"] * nb_object,
            },
        }

    def create_record_for_single_image(self, filename):
        """
        Build the record of the image that belongs to one annotation file.
        @params:
            filename - Required : name of the annotation file (with its ".txt" ending)
                                  inside the annotations folder (Str)
        @returns:
            (record, number_of_objects). The result is (None, 0) when the file holds no
            "person" line or when the size of the image cannot be read.
        """
        with open(os.path.join(self._annotations_path, filename)) as file:
            content = file.readlines()
        ## The first line is skipped; fields 1..4 of a "person" line are x, y, w, h.
        ## NOTE(review): startswith("person") also accepts labels that merely begin with
        ## "person" -- confirm that this is intended.
        annotations = [line.split()[1:5] for line in content[1:] if line.startswith("person")]
        if not annotations:
            ## No object annotation in current image.
            return None, 0

        ## get rid of .txt, and add the image extension unless it is already there
        stem = filename[:-4]
        image_name = stem if stem.endswith(img_extension) else stem + img_extension
        image_path = self._images_path + image_name
        height, width, channels = get_size_img(image_path)
        if height is None:
            return None, 0

        image_record = {
            "filename": image_path,
            "height": height,
            "width": width,
            "channels": channels,
            "format": img_extension,
            "class": {
                "label": 1,
                "text": "person",
                "conf": 1.
            },
            "object": TFRecordCreator._build_object_dict(annotations, width, height),
            "id": str(self._id),
        }
        self._id += 1
        return image_record, len(annotations)

    def create_records_caltech_format(self, dataset_name="train", output_path=".", num_shards=10,
                                      num_threads=5, store_images=True, save_dict=False):
        """
        Convert every ".txt" annotation file of the annotations folder and write the
        TFRecords files.
        @params:
            dataset_name - Optional : name passed to ``create``; also the base name of
                                      the json file written when save_dict is True (Str)
            output_path  - Optional : output directory passed to ``create`` (Str)
            num_shards   - Optional : number of shards passed to ``create`` (Int)
            num_threads  - Optional : number of threads passed to ``create`` (Int)
            store_images - Optional : passed to ``create`` unchanged (Bool)
            save_dict    - Optional : when True, the list of records is also dumped to
                                      "<dataset_name>.json" in the working directory (Bool)
        """
        dataset = []
        ## listdir gives no guaranteed order; sorting keeps the record ids reproducible.
        files = sorted(os.listdir(self._annotations_path))
        nb_files = len(files)
        nb_images = 0
        for index, filename in enumerate(files, start=1):
            if filename.endswith(".txt"):
                record, nb_object = self.create_record_for_single_image(filename)
                if nb_object > 0:
                    dataset.append(record)
                nb_images += 1
            # Update Progress Bar: one step per directory entry, so the last entry
            # reaches exactly 100%.
            TFRecordCreator.printProgressBar(index, nb_files, prefix='Progress:', suffix='Complete', length=50)
        if DEBUG: print("Done, ", nb_images)
        ## Start again from id 0 for the next conversion done with this instance.
        self._id = 0
        if save_dict:
            with open(dataset_name + '.json', 'w') as fp:
                json.dump(dataset, fp)
        TFRecordCreator.generate_records_file(dataset, dataset_name, output_path,
                                              num_shards, num_threads, store_images)

    @staticmethod
    def create_record_from_saved_dict(json_data_file_name, dataset_name="train", output_path=".",
                                      num_shards=10, num_threads=5, store_images=True):
        """
        Write the TFRecords files from a list of records previously saved as json.
        @params:
            json_data_file_name - Required : path of the json file to load (Str)
            dataset_name        - Optional : name passed to ``create`` (Str)
            output_path         - Optional : output directory passed to ``create`` (Str)
            num_shards          - Optional : number of shards passed to ``create`` (Int)
            num_threads         - Optional : number of threads passed to ``create`` (Int)
            store_images        - Optional : passed to ``create`` unchanged (Bool)
        """
        with open(json_data_file_name, 'r') as fp:
            saved_dict = json.load(fp)
        TFRecordCreator.generate_records_file(saved_dict, dataset_name, output_path,
                                              num_shards, num_threads, store_images)

    @staticmethod
    def generate_records_file(saved_dict, dataset_name="train", output_path=".", num_shards=10,
                              num_threads=5, store_images=True):
        """
        Call ``create`` on the records and print the images it reports as failed.
        @params:
            saved_dict   - Required : list of image records (List)
            dataset_name - Optional : name passed to ``create`` (Str)
            output_path  - Optional : output directory passed to ``create`` (Str)
            num_shards   - Optional : number of shards passed to ``create`` (Int)
            num_threads  - Optional : number of threads passed to ``create`` (Int)
            store_images - Optional : passed to ``create`` unchanged (Bool)
        """
        failed_images = create(
            dataset=saved_dict,
            dataset_name=dataset_name,
            output_directory=output_path,
            num_shards=num_shards,
            num_threads=num_threads,
            store_images=store_images
        )
        ## in case of failed:
        print("%d images failed." % (len(failed_images),))
        for image_data in failed_images:
            print("Image %s: %s" % (image_data['id'], image_data['error_msg']))

    # Print iterations progress
    @staticmethod
    def printProgressBar(iteration, total, prefix='', suffix='', decimals=1, length=100, fill='█'):
        """
        Call in a loop to create terminal progress bar
        @params:
            iteration   - Required  : current iteration (Int)
            total       - Required  : total iterations (Int)
            prefix      - Optional  : prefix string (Str)
            suffix      - Optional  : suffix string (Str)
            decimals    - Optional  : positive number of decimals in percent complete (Int)
            length      - Optional  : character length of bar (Int)
            fill        - Optional  : bar fill character (Str)
        """
        if total:
            fraction = iteration / float(total)
            filledLength = int(length * iteration // total)
        else:
            ## Nothing to iterate over: show a full bar instead of dividing by zero.
            fraction = 1.0
            filledLength = length
        percent = ("{0:." + str(decimals) + "f}").format(100 * fraction)
        bar = fill * filledLength + '-' * (length - filledLength)
        print('\r%s |%s| %s%% %s' % (prefix, bar, percent, suffix), end='\r')
        # Print New Line on Complete
        if iteration == total:
            print()