-
Notifications
You must be signed in to change notification settings - Fork 5
/
mini_batch_loader.py
95 lines (73 loc) · 3.11 KB
/
mini_batch_loader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import os
import numpy as np
import cv2
class MiniBatchLoader(object):
    """Load image mini-batches named in list files as float32 NCHW arrays.

    Each list file holds one image path per line, relative to
    ``image_dir_path``; entries that do not resolve to an existing file are
    skipped.  Training batches are randomly flipped, rotated and cropped to
    ``crop_size`` x ``crop_size``; testing batches are returned at the
    image's full resolution and must contain exactly one image.
    """

    def __init__(self, train_path, test_path, image_dir_path, crop_size):
        """Read the training and testing path lists.

        Args:
            train_path: text file listing training images, one per line.
            test_path: text file listing testing images, one per line.
            image_dir_path: directory the listed paths are joined onto.
            crop_size: side length of the square crops used for training.
        """
        self.training_path_infos = self.read_paths(train_path, image_dir_path)
        self.testing_path_infos = self.read_paths(test_path, image_dir_path)
        self.crop_size = crop_size

    @staticmethod
    def path_label_generator(txt_path, src_path):
        """Yield ``os.path.join(src_path, line)`` for each existing file.

        Lines of *txt_path* are stripped of surrounding whitespace; blank
        lines and entries that are not existing files are silently skipped.
        """
        # ``with`` guarantees the list file is closed once iteration ends.
        with open(txt_path) as list_file:
            for line in list_file:
                src_full_path = os.path.join(src_path, line.strip())
                if os.path.isfile(src_full_path):
                    yield src_full_path

    @staticmethod
    def count_paths(path):
        """Return the number of lines in *path*, blank lines included."""
        with open(path) as list_file:
            return sum(1 for _ in list_file)

    @staticmethod
    def read_paths(txt_path, src_path):
        """Return the existing image paths named in *txt_path* as a list."""
        return list(MiniBatchLoader.path_label_generator(txt_path, src_path))

    def load_training_data(self, indices):
        """Load an augmented batch from the training list at *indices*."""
        return self.load_data(self.training_path_infos, indices, augment=True)

    def load_testing_data(self, indices):
        """Load an un-augmented, single-image batch from the testing list."""
        return self.load_data(self.testing_path_infos, indices)

    @staticmethod
    def _read_image(path):
        """Decode *path* with cv2.imread; raise RuntimeError if that fails."""
        img = cv2.imread(path)
        if img is None:
            raise RuntimeError("invalid image: {i}".format(i=path))
        return img

    @staticmethod
    def _random_offset(size, crop_size):
        """Return a uniform random crop start in ``0..size - crop_size``.

        Raises:
            ValueError: if *size* is smaller than *crop_size*.
        """
        max_offset = size - crop_size
        if max_offset < 0:
            raise ValueError(
                "image dimension {s} is smaller than crop size {c}".format(
                    s=size, c=crop_size))
        # randint's upper bound is exclusive: +1 keeps the last valid offset
        # reachable and lets images exactly crop_size wide/high through.
        return np.random.randint(max_offset + 1)

    def load_data(self, path_infos, indices, augment=False):
        """Load the images ``path_infos[i] for i in indices`` as one batch.

        Pixel values are scaled by 1/255 and laid out as
        (batch, channel, height, width) float32.

        Args:
            path_infos: list of image file paths.
            indices: positions in *path_infos* to load.
            augment: if True, randomly flip (p~0.5), rotate by up to +/-10
                degrees (p~0.5) and crop each image to ``self.crop_size``.
                If False, exactly one index is allowed and the image is
                returned at full size.

        Returns:
            numpy float32 array of shape (len(indices), 3, H, W).

        Raises:
            RuntimeError: an image cannot be decoded, or more/fewer than one
                index is given without augmentation.
            ValueError: an augmented image is smaller than the crop size.
        """
        mini_batch_size = len(indices)
        # cv2.imread's default flag decodes to 3 channels.
        in_channels = 3
        if augment:
            crop = self.crop_size
            xs = np.zeros((mini_batch_size, in_channels, crop, crop),
                          dtype=np.float32)
            for i, index in enumerate(indices):
                img = self._read_image(path_infos[index])
                h, w, _ = img.shape
                if np.random.rand() > 0.5:
                    img = np.fliplr(img)
                if np.random.rand() > 0.5:
                    angle = 10 * np.random.rand()
                    if np.random.rand() > 0.5:
                        angle *= -1
                    rotation = cv2.getRotationMatrix2D((w / 2, h / 2), angle, 1)
                    img = cv2.warpAffine(img, rotation, (w, h))
                x_offset = self._random_offset(w, crop)
                y_offset = self._random_offset(h, crop)
                patch = img[y_offset:y_offset + crop, x_offset:x_offset + crop]
                # HWC -> CHW; the float64 quotient is cast on assignment.
                xs[i, :, :, :] = np.transpose(patch, (2, 0, 1)) / 255
            return xs
        if mini_batch_size == 1:
            img = self._read_image(path_infos[indices[0]])
            h, w, _ = img.shape
            xs = np.zeros((mini_batch_size, in_channels, h, w), dtype=np.float32)
            xs[0, :, :, :] = np.transpose((img / 255).astype(np.float32), (2, 0, 1))
            return xs
        raise RuntimeError("mini batch size must be 1 when testing")