-
Notifications
You must be signed in to change notification settings - Fork 3
/
run_cnn.py
102 lines (73 loc) · 2.82 KB
/
run_cnn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import argparse
import mlproof as mlp
import mlproof.nets as nets
import numpy as np
import os
import cPickle as pickle
import sys
# NOTE(review): presumably raised so that pickling the trained network object
# graph (pickle.dump in main) does not hit the default recursion limit --
# confirm. Per the sys docs, a limit this high means runaway recursion may
# crash the interpreter instead of raising RecursionError.
sys.setrecursionlimit(1000000000)

# Network-name prefixes that select_inputs() knows how to feed.
_SUPPORTED_PREFIXES = ('MergeNet', 'RGB')


def rgb_border_prefix(border):
    """Map the --border option to the border_prefix passed to Patch.load_rgba.

    'larger_border_overlap' maps to 'larger_border'; every other value maps
    to 'border'.
    """
    if border == 'larger_border_overlap':
        return 'larger_border'
    return 'border'


def select_inputs(cnn_name, border, X):
    """Return the input that network type *cnn_name* expects for one data split.

    - MergeNetThreeLeg*: dict with 'image_input', 'prob_input' and
      'binary_input', taken from X['image'], X['prob'] and X['merged_array'].
    - MergeNet* (four legs): the same dict plus 'border_input' = X[border].
    - RGBA*: X unchanged.
    - RGB* (not RGBA): X[:, :-1, :, :], i.e. the last entry along axis 1 is
      dropped (presumably the alpha channel -- confirm against Patch.load_rgba).

    Args:
        cnn_name: network class name; only its prefix is inspected.
        border: key of the border field used by the four-leg MergeNet.
        X: one data split as returned by the mlproof Patch loaders.

    Raises:
        ValueError: if cnn_name matches none of the supported prefixes.
    """
    # More specific prefixes must be tested first: 'MergeNetThreeLeg' also
    # starts with 'MergeNet', and 'RGBA' also starts with 'RGB'.
    if cnn_name.startswith('MergeNetThreeLeg'):
        return {'image_input': X['image'],
                'prob_input': X['prob'],
                'binary_input': X['merged_array']}
    if cnn_name.startswith('MergeNet'):
        return {'image_input': X['image'],
                'prob_input': X['prob'],
                'binary_input': X['merged_array'],
                'border_input': X[border]}
    if cnn_name.startswith('RGBA'):
        return X
    if cnn_name.startswith('RGB'):
        return X[:, :-1, :, :]
    raise ValueError('unsupported network type: %s' % cnn_name)


def main(argv=None):
    """Train a network on a patch dataset, record its test accuracy, save it.

    Writes ``test_<accuracy>.txt`` and ``net.p`` (the pickled, fitted network)
    into ``~/nets/<name>_<patchpath>_<border>_<desc>``, creating it if needed.

    Args:
        argv: argument list to parse; None means sys.argv[1:] (argparse's
            default behaviour).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-n", "--name", type=str, help="the cnn type", default='MergeNet')
    parser.add_argument("-p", "--patchpath", type=str, help="the patch folder in the datapath", default='cylinder_small1')
    parser.add_argument("-b", "--border", type=str, help="the border to use", default='larger_border_overlap')
    parser.add_argument("-d", "--desc", type=str, help="the description", default='test')
    args = parser.parse_args(argv)

    cnn_name = args.name
    patch_path = args.patchpath
    border = args.border
    desc = args.desc

    # Fail fast, before creating folders or loading data, for network types
    # that select_inputs() cannot feed.
    if not cnn_name.startswith(_SUPPORTED_PREFIXES):
        parser.error('unsupported network type: %s (expected a name starting '
                     'with one of %s)' % (cnn_name, ', '.join(_SUPPORTED_PREFIXES)))

    # getattr rather than eval: the command-line value is only ever treated
    # as an attribute name of the nets module, never executed as code.
    cnn_class = getattr(nets, cnn_name, None)
    if cnn_class is None:
        parser.error('unknown network type: %s' % cnn_name)

    output_folder = os.path.expanduser(
        os.path.join('~', 'nets', '_'.join([cnn_name, patch_path, border, desc])))
    if not os.path.exists(output_folder):
        os.makedirs(output_folder)

    # create CNN
    cnn = cnn_class()

    # load data
    if cnn_name.startswith('RGB'):
        X_train, y_train, X_test, y_test = mlp.Patch.load_rgba(
            patch_path, border_prefix=rgb_border_prefix(border))
    else:
        X_train, y_train, X_test, y_test = mlp.Patch.load(patch_path)

    # train and test inputs
    X_train_input = select_inputs(cnn_name, border, X_train)
    X_test_input = select_inputs(cnn_name, border, X_test)

    # train
    cnn = cnn.fit(X_train_input, y_train)

    # test
    test_accuracy = cnn.score(X_test_input, y_test)
    print(test_accuracy)
    with open(os.path.join(output_folder, 'test_' + str(test_accuracy) + '.txt'), 'w') as f:
        f.write(str(test_accuracy))

    # store CNN (protocol -1 selects the highest available pickle protocol)
    with open(os.path.join(output_folder, 'net.p'), 'wb') as f:
        pickle.dump(cnn, f, -1)

    print('All done.')


if __name__ == '__main__':
    main()