-
Notifications
You must be signed in to change notification settings - Fork 94
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Split and verify an ONNX file into multiple subnets #697
base: master
Are you sure you want to change the base?
Split and verify an ONNX file into multiple subnets #697
Conversation
maraboupy/Marabou.py
Outdated
@@ -74,6 +76,20 @@ def read_onnx(filename, inputNames=None, outputNames=None, reindexOutputVars=Tru | |||
""" | |||
return MarabouNetworkONNX(filename, inputNames, outputNames, reindexOutputVars=reindexOutputVars) | |||
|
|||
def read_onnx_with_threshould(filename, inputNames=None, outputNames=None, reindexOutputVars=True, threshold=None): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Typo in first threshold
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi @MatthewDaggitt
Thanks for your proactive review!
However, this PR is WIP.
So don't worry about this so much at this point. I'm working on polishing :)
Anyway, thanks again your support, I will update as you commented.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
maraboupy/MarabouNetworkONNX.py
Outdated
@@ -185,6 +190,8 @@ def processGraph(self): | |||
# Recursively create remaining shapes and equations as needed | |||
for outputName in self.outputNames: | |||
self.makeGraphEquations(outputName, True) | |||
# if self.thresholdReached: | |||
# return |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Obsolete code to be removed?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
maraboupy/Marabou.py
Outdated
filename (str): Path to the ONNX file | ||
inputNames (list of str, optional): List of node names corresponding to inputs | ||
outputNames (list of str, optional): List of node names corresponding to outputs | ||
threshold (int, optional): Threshold for the number of linear equations |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A more readable name for this would be max_number_of_lin_equations
or something similar. The current name is really not very meaningful, and the comment doesn't clear anything up. What happens when you cross the threshold?
It would also be good to provide in the docs why you would want to limit the number of linear equations?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
from maraboupy import MarabouNetworkONNX | ||
|
||
class MarabouNetworkComposition(MarabouNetwork.MarabouNetwork): | ||
"""Constructs a MarabouNetworkComposition object from an ONNX file |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This comment could be more useful. We know that the object constructs a MarabouNetworkComposition
object. What do objects of this class actually represent? What do they do?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
|
||
network = MarabouNetworkONNX.MarabouNetworkONNX(filename, reindexOutputVars=reindexOutputVars, threshold=threshold) | ||
|
||
network.saveQuery('q1.ipq') |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Lots and lots of hard coded names here. At the very least they need to be heavily documented, but it would be better to make them configurable.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
reindexOutputVars was deleted since it is required
maraboupy/MarabouNetworkONNX.py
Outdated
@@ -217,8 +224,15 @@ def makeGraphEquations(self, nodeName, makeEquations): | |||
raise RuntimeError(err_msg) | |||
|
|||
# Compute node's shape and create Marabou equations as needed | |||
if self.thresholdReached: | |||
return |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Wouldn't doing the thresholding stuff on the C++ side with the final set of equations after the parsing of the network be a much better way of doing it? Then you would get this feature for all parsers and both Python and C++, rather than simply the ONNX network for the Python backend?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is out of scope of this PR. Maybe it will be supported in the future PR.
…xOutputVars from MarabouNetworkComposition
maraboupy/Marabou.py
Outdated
:class:`~maraboupy.MarabouNetworkComposition.MarabouNetworkComposition` | ||
""" | ||
return MarabouNetworkComposition(filename, inputNames, outputNames, | ||
maxNumberOfLinearEquations=maxNumberOfLinearEquations) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
indentation is off
maraboupy/MarabouNetworkONNX.py
Outdated
self.makeMarabouEquations(nodeName, makeEquations) | ||
|
||
if self.maxNumberOfLinearEquations is not None: | ||
if not self.thresholdReached and len(self.equList) > self.maxNumberOfLinearEquations: | ||
if self.splitNetworkAtNode(nodeName, networkNamePostSplit='post_split.onnx'): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think one should probably use a tempfile name, otherwise two threads running in parallel would be interfering
This PR is going to support a new feature to split an onnx file and verify and sequentially verify each subnet.