-
Notifications
You must be signed in to change notification settings - Fork 3
/
store_patches_rgba.py
80 lines (56 loc) · 2.12 KB
/
store_patches_rgba.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import cPickle as pickle
import mahotas as mh
import numpy as np
import os
import time
import sys
import mlproof as mlp
# Directory holding the <split>_<lo>_<hi>_{error,correct}_patches.p pickles;
# run() also writes its .npz outputs here. PATCH_PATH keeps a trailing slash
# because file names are appended to it by plain string concatenation.
_PATCH_DIR = '~/patches/cylinder1_rgba'
PATCH_PATH = os.path.expanduser(_PATCH_DIR + '/')
# Same directory without the trailing slash.
# NOTE(review): not referenced anywhere else in this file; run() writes to
# PATCH_PATH instead.
OUTPUT_PATH = os.path.expanduser(_PATCH_DIR)
def shuffle_in_unison_inplace(a, b):
    """Shuffle two arrays along axis 0 with one shared random permutation.

    Despite the name, this is NOT in place: integer-array indexing (``a[p]``)
    builds new arrays, so *a* and *b* are left unchanged and the shuffled
    copies are returned.

    Args:
        a, b: Same-length objects supporting ``len`` and integer-array
            indexing (e.g. numpy arrays).

    Returns:
        Tuple ``(a_shuffled, b_shuffled)``; row k of both outputs comes from
        the same original row.

    Raises:
        ValueError: if ``len(a) != len(b)``.
    """
    # Explicit check rather than ``assert`` so it survives ``python -O``;
    # without it a longer *b* would be silently truncated to len(a) rows.
    if len(a) != len(b):
        raise ValueError(
            'length mismatch: {0} != {1}'.format(len(a), len(b)))
    p = np.random.permutation(len(a))
    return a[p], b[p]
def _copy_patch_file(pickle_path, borderprefix, label, p_rgba, p_target, start):
    """Copy every patch in *pickle_path* into the preallocated buffers.

    Each patch is indexed by key: 'image', 'prob', 'merged_array' and
    ``borderprefix + '_overlap'`` become channels 0-3 of one row of *p_rgba*
    (each must broadcast to that row's 75x75 slice), and the matching entry
    of *p_target* is set to *label*.

    Returns:
        Index of the first row not written (``start`` + patches copied).

    Raises:
        ValueError: if the file holds more patches than rows remain, i.e. the
            hard-coded patch count in run() is too small.
    """
    # NOTE(review): pickle.load can execute arbitrary code from the file;
    # only run this on trusted, locally generated patch files.
    with open(pickle_path, 'rb') as f:
        patches = pickle.load(f)
    n_rows = len(p_rgba)
    i = start
    for p in patches:
        if i >= n_rows:
            raise ValueError(
                '{0}: more patches than the {1} preallocated rows; '
                'increase NO_PATCHES'.format(pickle_path, n_rows))
        p_rgba[i, 0] = p['image']
        p_rgba[i, 1] = p['prob']
        p_rgba[i, 2] = p['merged_array']
        p_rgba[i, 3] = p[borderprefix + '_overlap']
        p_target[i] = label
        i += 1
    return i


def run(outname, borderprefix):
    """Pack one split's error/correct patch pickles into 4-channel .npz files.

    For each (lo, hi) range of the split, reads
    ``PATCH_PATH + '<outname>_<lo>_<hi>_error_patches.p'`` (label 1) and
    ``..._correct_patches.p`` (label 0). It stacks the channels image / prob /
    merged_array / ``<borderprefix>_overlap`` into a float32 array of shape
    (N, 4, 75, 75), then writes to PATCH_PATH:

    * ``<outname>_<borderprefix>_unshuffled.npz`` and
      ``<outname>_<borderprefix>_targets_unshuffled.npz``
    * ``<outname>_<borderprefix>.npz`` and
      ``<outname>_<borderprefix>_targets.npz`` (same rows, shuffled with one
      shared permutation)

    Every file stores its array under the key ``'rgba'``, including the
    target files. That key is kept unchanged because downstream loaders
    presumably read it.

    Args:
        outname: 'test' or 'train'.
        borderprefix: The fourth channel is read from
            ``patch[borderprefix + '_overlap']``.

    Raises:
        ValueError: for an unknown *outname*, or if the pickles hold more
            patches than the hard-coded count for the split.
    """
    if outname == 'test':
        NO_PATCHES = 13408 * 2
        groups = [(outname, 250, 300)]
    elif outname == 'train':
        NO_PATCHES = 266088 * 2
        groups = [(outname, lo, lo + 50) for lo in range(0, 250, 50)]
    else:
        # Fail fast instead of hitting a NameError on NO_PATCHES below.
        raise ValueError(
            "outname must be 'test' or 'train', got {0!r}".format(outname))

    # Channels kept separate rather than raveled: (patch, channel, row, col).
    p_rgba = np.zeros((NO_PATCHES, 4, 75, 75), dtype=np.float32)
    p_target = np.zeros(NO_PATCHES)

    i = 0
    for g in groups:
        prefix = PATCH_PATH + '_'.join(str(part) for part in g)
        # Important: error patches are class 1, correct patches class 0.
        i = _copy_patch_file(prefix + '_error_patches.p', borderprefix, 1,
                             p_rgba, p_target, i)
        i = _copy_patch_file(prefix + '_correct_patches.p', borderprefix, 0,
                             p_rgba, p_target, i)

    if i < NO_PATCHES:
        # Unfilled rows would otherwise be saved as all-zero patches with
        # label 0, i.e. bogus "correct" samples; drop them.
        print('warning: expected {0} patches, loaded {1}; trimming'.format(
            NO_PATCHES, i))
        p_rgba = p_rgba[:i]
        p_target = p_target[:i]

    out = PATCH_PATH + outname + '_' + borderprefix
    print('saving')
    np.savez(out + '_unshuffled.npz', rgba=p_rgba)
    np.savez(out + '_targets_unshuffled.npz', rgba=p_target)
    print('Done!')

    # NOTE: shuffle_in_unison_inplace indexes with a permutation array, which
    # returns copies, so peak memory here is roughly twice p_rgba's size.
    shuffled_rgba, shuffled_target = shuffle_in_unison_inplace(p_rgba, p_target)
    print('saving')
    np.savez(out + '.npz', rgba=shuffled_rgba)
    np.savez(out + '_targets.npz', rgba=shuffled_target)
    print('Done!')
# Script entry point: python store_patches_rgba.py {test|train} BORDERPREFIX
# Guarded so importing this module (e.g. to reuse shuffle_in_unison_inplace)
# does not start a conversion run.
if __name__ == '__main__':
    if len(sys.argv) < 3:
        # sys.exit with a string prints it to stderr and exits with status 1,
        # instead of an IndexError traceback on sys.argv.
        sys.exit('usage: {0} {{test|train}} BORDERPREFIX'.format(sys.argv[0]))
    run(sys.argv[1], sys.argv[2])