-
Notifications
You must be signed in to change notification settings - Fork 2
/
data_generator.py
65 lines (48 loc) · 1.91 KB
/
data_generator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import numpy as np
import os
import cv2
import json
import h5py
from utils.categories_conversion_utils import *
from utils.directory_utils import *
class DataGenerator:
    """Batch generator that loads images from disk for training/evaluation.

    Images are read with OpenCV, converted to RGB, resized, scaled to
    ``[-1, 1]`` and paired with one-hot labels from ``category_to_one_hot``.
    """

    def __init__(self, images_folder='data/color', batch_size=24):
        """Store configuration.

        Args:
            images_folder: Directory whose sub-directories are the categories
                (used by ``save_folder_names_json``).
            batch_size: Maximum number of samples per yielded batch.
        """
        self.images_folder = images_folder
        self.batch_size = batch_size

    def preprocess_input(self, x):
        """Scale pixel values from ``[0, 255]`` to ``[-1, 1]``.

        The operations are applied in place, so ``x`` must be a mutable
        floating-point array (or a plain number); the same object is returned.
        """
        x /= 255.
        x -= 0.5
        x *= 2.
        return x

    def save_folder_names_json(self, path_to_json_file="data/category_dictionary.json"):
        """Write a JSON mapping ``{folder_name: index}`` for every category.

        Only sub-directories of ``self.images_folder`` count as categories;
        stray files (e.g. ``.DS_Store``) are ignored so they cannot shift the
        indices. Names are sorted so the mapping is deterministic.

        Args:
            path_to_json_file: Destination path of the JSON file.
        """
        category_names = sorted(
            name
            for name in os.listdir(self.images_folder)
            if os.path.isdir(os.path.join(self.images_folder, name))
        )
        category_dictionary = {name: index for index, name in enumerate(category_names)}
        with open(path_to_json_file, 'w', encoding='utf-8') as fp:
            json.dump(category_dictionary, fp)

    def prepare_batch_data(self, data, resize_shape):
        """Load, resize and normalise one batch of images.

        Args:
            data: Sequence of samples where ``sample[0]`` is the image path and
                ``sample[1]`` is the category handed to ``category_to_one_hot``.
            resize_shape: Target ``(height, width)``; the returned image array
                has shape ``(len(data), height, width, 3)``.

        Returns:
            Tuple ``(batch_data, batch_labels)``: float64 images scaled to
            ``[-1, 1]`` and a NumPy array of the per-sample one-hot labels.

        Raises:
            OSError: If OpenCV cannot read one of the image files.
        """
        height, width = resize_shape[0], resize_shape[1]
        batch_data = np.zeros((len(data), height, width, 3))
        batch_labels = []
        for i, sample in enumerate(data):
            image_path, category = sample[0], sample[1]
            batch_labels.append(category_to_one_hot(category))
            image = cv2.imread(image_path)
            if image is None:
                # imread signals failure by returning None instead of raising.
                raise OSError("Could not read image: {}".format(image_path))
            # imread returns BGR; swap to RGB (same channel swap as before).
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            # cv2.resize expects dsize as (width, height), while the buffer
            # above is laid out as (height, width).
            image = cv2.resize(image, (width, height))
            batch_data[i] = self.preprocess_input(image.astype(float))
        return batch_data, np.array(batch_labels)

    def generate_data(self, data, resize_shape, testing=False):
        """Yield ``(images, labels)`` batches, reshuffling every epoch.

        Args:
            data: Sequence of ``(image_path, category)`` samples. It is copied
                via ``np.array``; NOTE(review): with mixed str/int pairs NumPy
                coerces every field to a string dtype, so categories likely
                reach ``category_to_one_hot`` as strings — confirm intended.
            resize_shape: Target ``(height, width)`` passed to
                ``prepare_batch_data``.
            testing: If True, stop after a single pass over the data;
                otherwise loop forever (as Keras-style generators expect).

        Yields:
            ``(images, labels)`` tuples of at most ``self.batch_size`` samples;
            the last batch of an epoch may be smaller.

        Raises:
            ValueError: If ``data`` is empty and ``testing`` is False, since
                the infinite loop would otherwise never yield.
        """
        data = np.array(data)
        num_samples = len(data)
        if num_samples == 0 and not testing:
            raise ValueError("generate_data() needs at least one sample")
        while True:
            # Shuffles rows of our private copy; the caller's data is untouched.
            np.random.shuffle(data)
            for start in range(0, num_samples, self.batch_size):
                batch = data[start:start + self.batch_size]
                images, labels = self.prepare_batch_data(batch, resize_shape)
                yield images, labels
            if testing:
                break