From 8129640537d63deac485daaf0f2f1c09e247e928 Mon Sep 17 00:00:00 2001 From: "Haoze(Andrew) Wu" Date: Tue, 27 Feb 2024 07:28:36 -0800 Subject: [PATCH] Fix python API examples (#768) * fix index * run python examples in CI * test --- .github/workflows/ci.yml | 6 ++++++ maraboupy/examples/0_NNetExample.py | 4 ++-- maraboupy/examples/1_TensorflowExample.py | 8 ++++---- maraboupy/examples/2_ONNXExample.py | 2 +- maraboupy/examples/4_DncExample.py | 2 +- maraboupy/examples/5_DisjunctionConstraintExample.py | 2 +- 6 files changed, 15 insertions(+), 9 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index a056de1e30..d359cdb094 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -149,6 +149,12 @@ jobs: export PYTHONPATH="$PYTHONPATH:$(dirname $(find $GITHUB_WORKSPACE -name "maraboupy" -type d))" python -c "import maraboupy" + - name: Run Python API examples + run: | + export PYTHONPATH="$PYTHONPATH:../../" + for file in *py; do python -u $file; done + working-directory: maraboupy/examples + - name: Generate Python Code Coverage if: ${{ ( matrix.compiler == 'g++' ) && ( matrix.build_type == 'Debug' ) }} run: python -m pytest --cov=maraboupy --cov-report=xml maraboupy/test diff --git a/maraboupy/examples/0_NNetExample.py b/maraboupy/examples/0_NNetExample.py index ac084fdcdb..18419f6a40 100644 --- a/maraboupy/examples/0_NNetExample.py +++ b/maraboupy/examples/0_NNetExample.py @@ -30,7 +30,7 @@ # %% # Load the network from NNet file, and set a lower bound on first output variable net1 = Marabou.read_nnet(nnetFile) -net1.setLowerBound(net1.outputVars[0][0], .5) +net1.setLowerBound(net1.outputVars[0][0][0], .5) # %% # Solve Marabou query @@ -59,4 +59,4 @@ net2 = Marabou.read_nnet(nnetFile) outputsMarabou = net2.evaluateWithMarabou([inputs]) -assert max(abs(outputsMarabou.flatten() - outputsExpected)) < 1e-8 +assert max(abs(outputsMarabou[0].flatten() - outputsExpected)) < 1e-8 diff --git a/maraboupy/examples/1_TensorflowExample.py b/maraboupy/examples/1_TensorflowExample.py index 02b4cc819d..6b2b1a97f5 100644 --- a/maraboupy/examples/1_TensorflowExample.py +++ b/maraboupy/examples/1_TensorflowExample.py @@ -26,8 +26,8 @@ # Or, you can specify the operation names of the input and output operations. # The default chooses the placeholder operations as input and the last operation as output inputNames = ['Placeholder'] -outputName = 'y_out' -network = Marabou.read_tf(filename = filename, inputNames = inputNames, outputName = outputName) +outputNames = ['y_out'] +network = Marabou.read_tf(filename = filename, inputNames = inputNames, outputNames = outputNames) # %% # Get the input and output variable numbers; [0] since first dimension is batch size @@ -43,8 +43,8 @@ # %% # Set output bounds on the second output variable -network.setLowerBound(outputVars[1], 194.0) -network.setUpperBound(outputVars[1], 210.0) +network.setLowerBound(outputVars[0][1], 194.0) +network.setUpperBound(outputVars[0][1], 210.0) # %% # Call to C++ Marabou solver diff --git a/maraboupy/examples/2_ONNXExample.py b/maraboupy/examples/2_ONNXExample.py index 4680f43ab9..2dbf6f6281 100644 --- a/maraboupy/examples/2_ONNXExample.py +++ b/maraboupy/examples/2_ONNXExample.py @@ -129,4 +129,4 @@ print(onnxEval) print("\nDifference:") print(onnxEval - marabouEval) -assert max(abs(onnxEval - marabouEval)) < 1e-6 +assert max(abs(onnxEval - marabouEval).flatten()) < 1e-6 diff --git a/maraboupy/examples/4_DncExample.py b/maraboupy/examples/4_DncExample.py index 0edc69df32..87e0a2a694 100644 --- a/maraboupy/examples/4_DncExample.py +++ b/maraboupy/examples/4_DncExample.py @@ -25,7 +25,7 @@ # Load an example network and place an output constraint nnet_file_name = "../../src/input_parsers/acas_example/ACASXU_run2a_1_1_tiny_2.nnet" net = Marabou.read_nnet(nnet_file_name) -net.setLowerBound(net.outputVars[0][0], .5) +net.setLowerBound(net.outputVars[0][0][0], .5) # %% # Solve the query with DNC mode turned on, which should return satisfying variable values diff --git a/maraboupy/examples/5_DisjunctionConstraintExample.py b/maraboupy/examples/5_DisjunctionConstraintExample.py index 953c7ca10a..1fd61b123f 100644 --- a/maraboupy/examples/5_DisjunctionConstraintExample.py +++ b/maraboupy/examples/5_DisjunctionConstraintExample.py @@ -26,7 +26,7 @@ # %% # Path to NNet file -nnetFile = "./resources/nnet/mnist/mnist10x10.nnet" +nnetFile = "../../resources/nnet/mnist/mnist10x10.nnet" # %% # Load the network from NNet file, and set a lower bound on first output variable