forked from DMoumita/MLVS
-
Notifications
You must be signed in to change notification settings - Fork 1
/
nnenum.py
167 lines (118 loc) · 4.41 KB
/
nnenum.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
'''
nnenum vnnlib front end
usage: "python3 nnenum.py <onnx_file> <vnnlib_file> [timeout=None] [outfile=None] [processes=<auto>]"
Stanley Bak
June 2021
'''
import os
import sys

import numpy as np

from src.enumerate import enumerate_network
from src.settings import Settings
from src.result import Result
from src.onnx_network import load_onnx_network_optimized, load_onnx_network
from src.specification import Specification, DisjunctiveSpec
from src.nnenum_vnnlib import get_num_inputs_outputs, read_vnnlib_simple
def make_spec(vnnlib_filename, onnx_filename):
    '''Build the verification specifications for an ONNX network / VNNLIB pair.

    The network's input/output counts and input dtype come from
    get_num_inputs_outputs(onnx_filename). The VNNLIB file is then parsed
    with read_vnnlib_simple into entries of (box, list of (mat, rhs)).

    Returns a pair (specs, inp_dtype):
      specs     -- list of (box, spec) tuples, one per VNNLIB entry and in the
                   same order. spec is a Specification when the entry has
                   exactly one (mat, rhs) pair; otherwise it is a
                   DisjunctiveSpec over one Specification per pair.
      inp_dtype -- the input dtype reported for the ONNX model.
    '''
    num_inputs, num_outputs, inp_dtype = get_num_inputs_outputs(onnx_filename)
    parsed_entries = read_vnnlib_simple(vnnlib_filename, num_inputs, num_outputs)

    def _build(constraint_pairs):
        # a lone constraint set needs no disjunction wrapper
        if len(constraint_pairs) == 1:
            only_mat, only_rhs = constraint_pairs[0]
            return Specification(only_mat, only_rhs)

        alternatives = [Specification(m, r) for m, r in constraint_pairs]
        return DisjunctiveSpec(alternatives)

    specs = [(box, _build(pairs)) for box, pairs in parsed_entries]

    return specs, inp_dtype
def set_control_settings():
    '''Configure the global Settings for the smaller control benchmarks.

    main() selects this configuration when the network has fewer than 700
    inputs. Each (name, value) pair below is applied to Settings in the
    listed order. This is equivalent to writing `Settings.<name> = value`
    once per line.
    '''
    overrides = {
        'TIMING_STATS': False,
        'PARALLEL_ROOT_LP': False,
        'SPLIT_IF_IDLE': False,
        'PRINT_OVERAPPROX_OUTPUT': False,
        'TRY_QUICK_OVERAPPROX': True,
        'CONTRACT_ZONOTOPE_LP': True,
        'CONTRACT_LP_OPTIMIZED': True,
        'CONTRACT_LP_TRACK_WITNESSES': True,
        'OVERAPPROX_BOTH_BOUNDS': False,
        'BRANCH_MODE': Settings.BRANCH_OVERAPPROX,
        'OVERAPPROX_GEN_LIMIT_MULTIPLIER': 1.5,
        'OVERAPPROX_LP_TIMEOUT': 0.02,
        'OVERAPPROX_MIN_GEN_LIMIT': 70,
    }

    for attr_name, attr_value in overrides.items():
        setattr(Settings, attr_name, attr_value)
def set_image_settings():
    '''Configure the global Settings for the larger image benchmarks.

    main() selects this configuration when the network has 700 or more
    inputs. The overrides are applied to Settings in the listed order.
    '''
    overrides = {
        'COMPRESS_INIT_BOX': False,
        'BRANCH_MODE': Settings.BRANCH_OVERAPPROX,
        'TRY_QUICK_OVERAPPROX': False,
        'OVERAPPROX_MIN_GEN_LIMIT': np.inf,
        'SPLIT_IF_IDLE': False,
        'OVERAPPROX_LP_TIMEOUT': np.inf,
        'TIMING_STATS': True,
        # contraction doesn't help in high dimensions, so both variants are off
        'CONTRACT_ZONOTOPE': False,
        'CONTRACT_ZONOTOPE_LP': False,
    }

    for attr_name, attr_value in overrides.items():
        setattr(Settings, attr_name, attr_value)
def main():
    '''Command-line entry point.

    usage: python3 nnenum.py <onnx_file> <vnnlib_file> [timeout=None] [outfile=None] [processes=<auto>]

    timeout   -- seconds, shared across all specifications. The literal
                 "None" (any case) means no time limit.
    outfile   -- if given, the final result string is written to this file.
    processes -- if given, stored in Settings.NUM_PROCESSES.

    Each (input box, specification) pair from the VNNLIB file is checked in
    turn with enumerate_network. Checking stops at the first result that is
    not "safe", or when the remaining timeout is exhausted ("timeout"). For
    VNNCOMP21 the result is renamed: "safe" becomes "Property holds", and any
    result containing "unsafe" becomes "Property violated". The process exits
    with status Result.results.index('error') when the result is 'error'.

    Raises ValueError if the VNNLIB file yields no specifications.
    '''
    if len(sys.argv) < 3:
        print('usage: "python3 nnenum.py <onnx_file> <vnnlib_file> [timeout=None] [outfile=None] [processes=<auto>]"')
        sys.exit(1)

    onnx_filename = sys.argv[1]
    vnnlib_filename = sys.argv[2]
    timeout = None
    outfile = None

    # The usage text shows "[timeout=None]". Accept that literal, so an outfile
    # can be given without imposing a time limit (float('None') would crash).
    if len(sys.argv) >= 4 and sys.argv[3].strip().lower() != 'none':
        timeout = float(sys.argv[3])

    if len(sys.argv) >= 5:
        outfile = sys.argv[4]

    if len(sys.argv) >= 6:
        Settings.NUM_PROCESSES = int(sys.argv[5])

    # basename also handles OS-specific separators (e.g. backslashes on Windows)
    print(f"\nNetwork model: {os.path.basename(onnx_filename)}")
    print(f"Property file: {os.path.basename(vnnlib_filename)}")

    spec_list, input_dtype = make_spec(vnnlib_filename, onnx_filename)

    if not spec_list:
        # Otherwise spec_list[0][0] below would fail with an opaque IndexError.
        raise ValueError(f"no input boxes/specifications found in vnnlib file: {vnnlib_filename}")

    try:
        network = load_onnx_network_optimized(onnx_filename)
    except AssertionError:
        # cannot do optimized load due to unsupported layers
        network = load_onnx_network(onnx_filename)

    result_str = 'none'  # always overridden: spec_list is non-empty here

    # The input dimension decides which benchmark tuning to use.
    num_inputs = len(spec_list[0][0])

    if num_inputs < 700:
        set_control_settings()
    else:
        set_image_settings()

    for init_box, spec in spec_list:
        init_box = np.array(init_box, dtype=input_dtype)

        if timeout is not None:
            if timeout <= 0:
                result_str = 'timeout'
                break

            Settings.TIMEOUT = timeout

        res = enumerate_network(init_box, network, spec)
        result_str = res.result_str

        if timeout is not None:
            # the remaining budget is shared across all specifications
            timeout -= res.total_secs

        if result_str != "safe":
            break

    # rename for VNNCOMP21:
    if result_str == "safe":
        result_str = "Property holds"
    elif "unsafe" in result_str:
        result_str = "Property violated"

    if outfile is not None:
        with open(outfile, 'w', encoding='utf-8') as f:
            f.write(result_str)

    # The result is not printed to stdout here. It is reported through
    # outfile (when given) and, for 'error', through the exit status.
    if result_str == 'error':
        sys.exit(Result.results.index('error'))
if __name__ == '__main__':
    # only run the command-line front end when executed as a script, not on import
    main()