diff --git a/maraboupy/MarabouNetworkComposition.py b/maraboupy/MarabouNetworkComposition.py index 598f359fd..4e924ab65 100644 --- a/maraboupy/MarabouNetworkComposition.py +++ b/maraboupy/MarabouNetworkComposition.py @@ -8,7 +8,7 @@ All rights reserved. See the file COPYING in the top-level source directory for licensing information. -MarabouNetworkComposition represents neural networks with piecewise linear constraints derived from the ONNX format +MarabouNetworkComposition represents split subnets of a neural network with piecewise linear constraints derived from the ONNX format ''' import numpy as np @@ -94,11 +94,12 @@ def solve(self, filename="", verbose=True, options=None): Returns: (tuple): tuple containing: - exitCode (str): A string representing the exit code (unsat/TIMEOUT/ERROR/UNKNOWN/QUIT_REQUESTED). - - vals (Dict[int, float]): Empty dictionary if UNSAT, otherwise a dictionary of SATisfying values for variables + - vals (Dict[int, float]): Empty dictionary. This is for compatibility with MarabouNetwork. - stats (:class:`~maraboupy.MarabouCore.Statistics`): A Statistics object to how Marabou performed (Only for the last subnet) """ if options == None: options = MarabouCore.Options() + for i, ipqFile in enumerate(self.ipqs): # load input query ipq = Marabou.loadQuery(ipqFile) @@ -127,12 +128,12 @@ def solve(self, filename="", verbose=True, options=None): _, bounds, _ = MarabouCore.calculateBounds(ipq, options) def encodeCalculateInputBounds(self, ipq, i, bounds): - """Function to encode input variables and calculate bounds for the next subnet - + """Function to encode input variables and set bounds for the current subnet + Args: ipq (:class:`~maraboupy.MarabouCore.InputQuery`): InputQuery object to encode input variables - i (int): Index of the subnet - bounds (dict): Dictionary containing bounds for variables + i (int): Index of the previous subnet + bounds (dict): Dictionary containing bounds for variables of the previous subnet Returns: None @@ -173,7 +174,7 @@ def encodeOutput(self, ipq, i): """Function to encode output variables Args: ipq: (:class:`~maraboupy.MarabouCore.InputQuery`): InputQuery object to encode output variables - i: (int): Index of the subnet + i: (int): Index of the previous subnet Returns: None @@ -297,4 +298,4 @@ def setUpperBound(self, x, v): if any(x in arr for arr in self.inputVars) or any(x in arr for arr in self.outputVars): self.upperBounds[x] = v else: - raise RuntimeError("Can set bounds only on either input or output variables") \ No newline at end of file + raise RuntimeError("Can set bounds only on either input or output variables") diff --git a/maraboupy/test/test_network_composition.py b/maraboupy/test/test_network_composition.py index f4faf945c..f7d49d3e3 100644 --- a/maraboupy/test/test_network_composition.py +++ b/maraboupy/test/test_network_composition.py @@ -1,8 +1,5 @@ -# Tests MarabouNetwork features not tested by it's subclasses from maraboupy import Marabou -from maraboupy import MarabouCore import os -import numpy as np # Global settings OPT = Marabou.createOptions(verbosity = 0) # Turn off printing @@ -118,4 +115,4 @@ def test_one_split_unsat(): exitCode2, _, _ = network.calculateBounds(options=OPT) # exitCode2 should be also unsat - assert exitCode == exitCode2 \ No newline at end of file + assert exitCode == exitCode2