-
Notifications
You must be signed in to change notification settings - Fork 0
/
generate_patches.py
106 lines (89 loc) · 4.51 KB
/
generate_patches.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import argparse
import glob
from PIL import Image
import PIL
import random
from utils import *
import six
from six.moves import xrange
import numpy as np
# Training patches are stored as raw uint8 pixel values in the 0-255 range
# (changed on 12-11-17; earlier versions stored values normalized to 0-1).

# How many randomly augmented copies are drawn from each patch position.
# Lowered from 4 to 1 on 12-11-17.
DATA_AUG_TIMES = 1


def _build_parser():
    """Create the command-line parser used by this patch-generation script.

    Returns:
        An ``argparse.ArgumentParser`` whose parsed namespace exposes
        ``src_dir``, ``save_dir``, ``pat_size``, ``stride``, ``step``,
        ``bat_size``, ``from_file`` and ``num_pic``.
    """
    cli = argparse.ArgumentParser(description='')

    # Where to read source images from and where to write the patch array.
    cli.add_argument('--src_dir', dest='src_dir',
                     default='./data/Train400', help='dir of data')
    cli.add_argument('--save_dir', dest='save_dir',
                     default='./data', help='dir of patches')

    # Patch geometry and batching.
    cli.add_argument('--patch_size', dest='pat_size', type=int,
                     default=40, help='patch size')
    cli.add_argument('--stride', dest='stride', type=int,
                     default=10, help='stride')
    cli.add_argument('--step', dest='step', type=int,
                     default=0, help='step')
    cli.add_argument('--batch_size', dest='bat_size', type=int,
                     default=128, help='batch size')

    # Options for checking the generated output; generate_patches itself
    # does not read them.
    cli.add_argument('--from_file', dest='from_file',
                     default="./data/img_clean_pats.npy",
                     help='get pic from file')
    cli.add_argument('--num_pic', dest='num_pic', type=int,
                     default=10, help='number of pic to pick')
    return cli


parser = _build_parser()
args = parser.parse_args()
def _patch_origins(length, patch_size, step, stride):
    """Return the start offsets of every complete patch along one image axis.

    An offset ``o`` is produced only when ``o + patch_size <= length``. Both
    the sizing pass and the extraction pass of ``generate_patches`` use this
    helper, so the allocated patch count always matches the patches written.
    """
    return range(step, length - patch_size + 1, stride)


def _load_gray(path):
    """Open the image at *path*, convert it to 8-bit grayscale ('L'), and close the file."""
    with Image.open(path) as img:
        return img.convert('L')


def generate_patches(isDebug=False):
    """Cut grayscale training patches from ``args.src_dir`` and save them.

    Steps:
      1. Every ``*.png`` in ``args.src_dir`` is converted to grayscale and
         resized by each factor in ``scales``.
      2. Square ``args.pat_size`` patches are taken every ``args.stride``
         pixels, starting at ``args.step``, along both axes.
      3. Each patch position is sampled ``DATA_AUG_TIMES`` times, and every
         sample goes through ``data_augmentation`` with a random mode drawn
         from 0..7 inclusive.

    The total number of patches is rounded up to a multiple of
    ``args.bat_size``. The surplus slots are filled by cycling through the
    real patches. The resulting uint8 array has shape
    ``(numPatches, pat_size, pat_size, 1)`` and holds raw 0-255 pixel values.
    It is written with ``np.save`` to ``<args.save_dir>/img_clean_pats.npy``.

    Args:
        isDebug: if True, only the first 10 images (in sorted path order)
            are used.
    """
    import os  # explicit, rather than relying on ``from utils import *``

    # Sorted so the image order (and the isDebug subset) is reproducible;
    # glob's own ordering is arbitrary.
    filepaths = sorted(glob.glob(os.path.join(args.src_dir, '*.png')))
    if isDebug:
        filepaths = filepaths[:10]
    scales = [1, 0.9, 0.8, 0.7]
    pat = args.pat_size

    def origins(length):
        # Valid patch start offsets along an axis of the given length.
        return _patch_origins(length, pat, args.step, args.stride)

    # Pass 1: count patches from image sizes alone. PIL's Image.open is lazy
    # and .size is (width, height), so nothing is decoded or resized here.
    count = 0
    for path in filepaths:
        with Image.open(path) as img:
            width, height = img.size
        for scale in scales:
            # Same size expression that pass 2 feeds to resize().
            new_w, new_h = int(width * scale), int(height * scale)
            count += len(origins(new_h)) * len(origins(new_w))

    origin_patch_num = count * DATA_AUG_TIMES
    # Round up to a whole number of batches. Integer division is required:
    # with `/` on Python 3 the old formula gave origin_patch_num + bat_size
    # instead of the next multiple of bat_size.
    num_batches = (origin_patch_num + args.bat_size - 1) // args.bat_size
    numPatches = num_batches * args.bat_size
    print(numPatches)

    # Pass 2: extract the patches into a 4-D array (N, pat, pat, 1).
    inputs = np.zeros((numPatches, pat, pat, 1), dtype="uint8")
    count = 0
    for path in filepaths:
        img = _load_gray(path)
        for scale in scales:
            newsize = (int(img.size[0] * scale), int(img.size[1] * scale))
            img_s = img.resize(newsize, resample=PIL.Image.BICUBIC)
            # np.array(<PIL image>) has shape (height, width). Append a channel
            # axis; reshaping to (width, height, 1) would scramble the pixels
            # of any non-square image.
            arr = np.array(img_s, dtype="uint8")[:, :, np.newaxis]
            rows = origins(arr.shape[0])
            cols = origins(arr.shape[1])
            for _ in range(DATA_AUG_TIMES):
                for x in rows:
                    for y in cols:
                        # data_augmentation comes from utils (star import);
                        # presumably it applies the flip/rotation chosen by
                        # the mode argument. TODO confirm.
                        inputs[count, :, :, :] = data_augmentation(
                            arr[x:x + pat, y:y + pat, :], random.randint(0, 7))
                        count += 1

    # Pad the final batch by cycling through the real patches. Copying
    # inputs[:to_pad] would pull in still-empty (all-zero) slots whenever
    # to_pad > count.
    if 0 < count < numPatches:
        to_pad = numPatches - count
        inputs[count:, :, :, :] = inputs[np.arange(to_pad) % count]

    # Values are kept as raw uint8 (0-255); no normalization is applied.
    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)
    np.save(os.path.join(args.save_dir, "img_clean_pats"), inputs)
if __name__ == '__main__':
    # Script entry point: build the patch dataset from the full image list
    # (debug truncation disabled).
    generate_patches(isDebug=False)