Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Split and verify an ONNX file into multiple subnets #697

Open
wants to merge 14 commits into
base: master
Choose a base branch
from

Conversation

tagomaru
Copy link
Contributor

This PR is going to support a new feature to split an onnx file and verify and sequentially verify each subnet.

@@ -74,6 +76,20 @@ def read_onnx(filename, inputNames=None, outputNames=None, reindexOutputVars=Tru
"""
return MarabouNetworkONNX(filename, inputNames, outputNames, reindexOutputVars=reindexOutputVars)

def read_onnx_with_threshould(filename, inputNames=None, outputNames=None, reindexOutputVars=True, threshold=None):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Typo in first threshold

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @MatthewDaggitt
Thanks for your proactive review!
However, this PR is WIP.
So don't worry about this so much at this point. I'm working on polishing :)
Anyway, thanks again your support, I will update as you commented.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

@@ -185,6 +190,8 @@ def processGraph(self):
# Recursively create remaining shapes and equations as needed
for outputName in self.outputNames:
self.makeGraphEquations(outputName, True)
# if self.thresholdReached:
# return
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Obsolete code to be removed?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

filename (str): Path to the ONNX file
inputNames (list of str, optional): List of node names corresponding to inputs
outputNames (list of str, optional): List of node names corresponding to outputs
threshold (int, optional): Threshold for the number of linear equations
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A more readable name for this would be max_number_of_lin_equations or something similar. The current name is really not very meaningful, and the comment doesn't clear anything up. What happens when you cross the threshold?

It would also be good to provide in the docs why you would want to limit the number of linear equations?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

from maraboupy import MarabouNetworkONNX

class MarabouNetworkComposition(MarabouNetwork.MarabouNetwork):
"""Constructs a MarabouNetworkComposition object from an ONNX file
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This comment could be more useful. We know that the object constructs a MarabouNetworkComposition object. What do objects of this class actually represent? What do they do?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done


network = MarabouNetworkONNX.MarabouNetworkONNX(filename, reindexOutputVars=reindexOutputVars, threshold=threshold)

network.saveQuery('q1.ipq')
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Lots and lots of hard coded names here. At the very least they need to be heavily documented, but it would be better to make them configurable.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

reindexOutputVars was deleted since it is required

@@ -217,8 +224,15 @@ def makeGraphEquations(self, nodeName, makeEquations):
raise RuntimeError(err_msg)

# Compute node's shape and create Marabou equations as needed
if self.thresholdReached:
return
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wouldn't doing the thresholding stuff on the C++ side with the final set of equations after the parsing of the network be a much better way of doing it? Then you would get this feature for all parsers and both Python and C++, rather than simply the ONNX network for the Python backend?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is out of scope of this PR. Maybe it will be supported in the future PR.

@tagomaru tagomaru requested a review from wu-haoze January 2, 2024 07:40
:class:`~maraboupy.MarabouNetworkComposition.MarabouNetworkComposition`
"""
return MarabouNetworkComposition(filename, inputNames, outputNames,
maxNumberOfLinearEquations=maxNumberOfLinearEquations)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

indentation is off

self.makeMarabouEquations(nodeName, makeEquations)

if self.maxNumberOfLinearEquations is not None:
if not self.thresholdReached and len(self.equList) > self.maxNumberOfLinearEquations:
if self.splitNetworkAtNode(nodeName, networkNamePostSplit='post_split.onnx'):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think one should probably use a tempfile name, otherwise two threads running in parallel would be interfering

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants