Skip to content

Commit

Permalink
fix failure in bnn
Browse files Browse the repository at this point in the history
  • Loading branch information
wu-haoze committed Jan 6, 2024
1 parent 3a6d2b1 commit 83fbff7
Show file tree
Hide file tree
Showing 13 changed files with 43 additions and 37 deletions.
9 changes: 5 additions & 4 deletions regress/regress1/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -142,11 +142,12 @@ endforeach()
# Proof production tests

# ReLU
marabou_add_acasxu_proof_test(1 "ACASXU_experimental_v2a_5_9" "3" unsat)
# TODO: move the commented out long running test to higher regression level.
#marabou_add_acasxu_proof_test(1 "ACASXU_experimental_v2a_5_9" "3" unsat)
marabou_add_acasxu_proof_test(1 "ACASXU_experimental_v2a_5_7" "3" unsat)
marabou_add_acasxu_proof_test(1 "ACASXU_experimental_v2a_3_7" "3" unsat)
marabou_add_acasxu_proof_test(1 "ACASXU_experimental_v2a_2_9" "4" unsat)
marabou_add_acasxu_proof_test(1 "ACASXU_experimental_v2a_2_9" "3" unsat)
#marabou_add_acasxu_proof_test(1 "ACASXU_experimental_v2a_3_7" "3" unsat)
#marabou_add_acasxu_proof_test(1 "ACASXU_experimental_v2a_2_9" "4" unsat)
#marabou_add_acasxu_proof_test(1 "ACASXU_experimental_v2a_2_9" "3" unsat)
marabou_add_coav_proof_test(1 "reluBenchmark1.30941200256s_UNSAT.nnet" unsat)
marabou_add_coav_proof_test(1 "reluBenchmark0.453322172165s_UNSAT.nnet" unsat)
marabou_add_coav_proof_test(1 "reluBenchmark0.30711388588s_UNSAT.nnet" unsat)
Expand Down
2 changes: 1 addition & 1 deletion regress/regress2/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ marabou_add_acasxu_dnc_test(1 "ACASXU_experimental_v2a_3_9" "3" unsat) #
# marabou_add_acasxu_dnc_test(1 "ACASXU_experimental_v2a_3_5" "2" sat) # timeout
# marabou_add_acasxu_dnc_test(1 "ACASXU_experimental_v2a_3_9" "3" unsat) # 200sec

marabou_add_mnist_test(2 "mnist10x10.nnet" "image3_target8_epsilon0.1.txt" sat) #16sec
#marabou_add_mnist_test(2 "mnist10x10.nnet" "image3_target8_epsilon0.1.txt" sat) #16sec
marabou_add_mnist_test(2 "mnist10x10.nnet" "image3_target9_epsilon0.005.txt" unsat) #17sec

# Add all input query files in regress2/input_queries/ as tests.
Expand Down
5 changes: 3 additions & 2 deletions src/configuration/GlobalConfiguration.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ const bool GlobalConfiguration::ONLY_AUX_INITIAL_BASIS = false;
const GlobalConfiguration::ExplicitBasisBoundTighteningType GlobalConfiguration::EXPLICIT_BASIS_BOUND_TIGHTENING_TYPE =
GlobalConfiguration::COMPUTE_INVERTED_BASIS_MATRIX;
const bool GlobalConfiguration::EXPLICIT_BOUND_TIGHTENING_UNTIL_SATURATION = false;
const double GlobalConfiguration::EXPLICIT_BASIS_BOUND_TIGHTENING_ROUNDING_CONSTANT = 1e-6;

const unsigned GlobalConfiguration::REFACTORIZATION_THRESHOLD = 100;
const GlobalConfiguration::BasisFactorizationType GlobalConfiguration::BASIS_FACTORIZATION_TYPE =
Expand All @@ -112,9 +113,9 @@ const bool GlobalConfiguration::GUROBI_LOGGING = false;

// Logging - note that it is enabled only in Debug mode
const bool GlobalConfiguration::DNC_MANAGER_LOGGING = true;
const bool GlobalConfiguration::ENGINE_LOGGING = false;
const bool GlobalConfiguration::ENGINE_LOGGING = true;
const bool GlobalConfiguration::TABLEAU_LOGGING = false;
const bool GlobalConfiguration::SMT_CORE_LOGGING = false;
const bool GlobalConfiguration::SMT_CORE_LOGGING = true;
const bool GlobalConfiguration::DANTZIGS_RULE_LOGGING = false;
const bool GlobalConfiguration::BASIS_FACTORIZATION_LOGGING = false;
const bool GlobalConfiguration::PREPROCESSOR_LOGGING = false;
Expand Down
1 change: 1 addition & 0 deletions src/configuration/GlobalConfiguration.h
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ class GlobalConfiguration

// When doing bound tightening using the explicit basis matrix, should the basis matrix be inverted?
static const ExplicitBasisBoundTighteningType EXPLICIT_BASIS_BOUND_TIGHTENING_TYPE;
static const double EXPLICIT_BASIS_BOUND_TIGHTENING_ROUNDING_CONSTANT;

// When doing explicit bound tightening, should we repeat until saturation?
static const bool EXPLICIT_BOUND_TIGHTENING_UNTIL_SATURATION;
Expand Down
3 changes: 3 additions & 0 deletions src/engine/Engine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ void Engine::applySnCSplit( PiecewiseLinearCaseSplit sncSplit, String queryId )
preContextPushHook();
_smtCore.pushContext();
applySplit( sncSplit );
_boundManager.propagateTightenings();
}

bool Engine::inSnCMode() const
Expand Down Expand Up @@ -282,9 +283,11 @@ bool Engine::solve( unsigned timeoutInSeconds )
// If true, we just entered a new subproblem
if ( splitJustPerformed )
{
std::cout << "bt" << std::endl;
performBoundTighteningAfterCaseSplit();
informLPSolverOfBounds();
splitJustPerformed = false;
std::cout << "bt - done" << std::endl;
}

// Perform any SmtCore-initiated case splits
Expand Down
4 changes: 2 additions & 2 deletions src/engine/MILPEncoder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -371,8 +371,8 @@ void MILPEncoder::encodeSignConstraint( GurobiWrapper &gurobi,
ASSERT( ( FloatUtils::gte( _tableau.getLowerBound( sign->getB() ), 0 ) &&
FloatUtils::areEqual( _tableau.getLowerBound( sign->getF() ), 1 ) )
||
( FloatUtils::lt( _tableau.getUpperBound( sign->getB() ), 0 ) &&
FloatUtils::areEqual( _tableau.getLowerBound( sign->getF() ), -1 ) ) );
( FloatUtils::lte( _tableau.getUpperBound( sign->getB() ), 0 ) &&
FloatUtils::areEqual( _tableau.getUpperBound( sign->getF() ), -1 ) ) );
return;
}

Expand Down
7 changes: 7 additions & 0 deletions src/engine/MarabouMain.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "DnCMarabou.h"
#include "ConfigurationError.h"
#include "Error.h"
#include "LPSolverType.h"
#include "Marabou.h"
#include "Options.h"

Expand Down Expand Up @@ -95,6 +96,12 @@ int marabouMain( int argc, char **argv )
printf( "Proof production is not yet supported with MILP solvers, turning --milp off.\n" );
}

if ( options->getBool( Options::PRODUCE_PROOFS ) && ( options->getLPSolverType() == LPSolverType::GUROBI ) )
{
options->setString( Options::LP_SOLVER, "native" );
printf( "Proof production is not yet supported with MILP solvers, using native simplex engine.\n" );
}

if ( options->getBool( Options::DNC_MODE ) &&
options->getBool( Options::PARALLEL_DEEPSOI ) )
{
Expand Down
17 changes: 3 additions & 14 deletions src/engine/PolarityBasedDivider.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,22 +94,11 @@ void PolarityBasedDivider::createSubQueries( unsigned numNewSubqueries, const
PiecewiseLinearConstraint *PolarityBasedDivider::getPLConstraintToSplit
( const PiecewiseLinearCaseSplit &split )
{
EngineState *engineStateBeforeSplit = new EngineState();
_engine->storeState( *engineStateBeforeSplit,
TableauStateStorageLevel::STORE_BOUNDS_ONLY );
_engine->applySplit( split );
_engine->applySnCSplit( split, "" );

PiecewiseLinearConstraint *constraintToSplit = NULL;
constraintToSplit = _engine->pickSplitPLConstraintSnC( SnCDivideStrategy::Polarity );
_engine->restoreState( *engineStateBeforeSplit );
delete engineStateBeforeSplit;
_engine->getContext().pop();
_engine->postContextPopHook();
return constraintToSplit;
}

//
// Local Variables:
// compile-command: "make -C ../.. "
// tags-file-name: "../../TAGS"
// c-basic-offset: 4
// End:
//
9 changes: 4 additions & 5 deletions src/engine/RowBoundTightener.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -296,9 +296,8 @@ unsigned RowBoundTightener::tightenOnSingleInvertedBasisRow( const TableauRow &r
}
}

result += registerTighterLowerBound( y, lowerBound, row );
result += registerTighterUpperBound( y, upperBound, row );

result += registerTighterLowerBound( y, lowerBound - GlobalConfiguration::EXPLICIT_BASIS_BOUND_TIGHTENING_ROUNDING_CONSTANT, row );
result += registerTighterUpperBound( y, upperBound + GlobalConfiguration::EXPLICIT_BASIS_BOUND_TIGHTENING_ROUNDING_CONSTANT, row );
if ( FloatUtils::gt( getLowerBound( y ), getUpperBound( y ) ) )
{
ASSERT( FloatUtils::gt( _boundManager.getLowerBound( y ), _boundManager.getUpperBound( y ) ) );
Expand Down Expand Up @@ -374,8 +373,8 @@ unsigned RowBoundTightener::tightenOnSingleInvertedBasisRow( const TableauRow &r

// If a tighter bound is found, store it
xi = row._row[i]._var;
result += registerTighterLowerBound( xi, lowerBound, row );
result += registerTighterUpperBound( xi, upperBound, row );
result += registerTighterLowerBound( xi, lowerBound - GlobalConfiguration::EXPLICIT_BASIS_BOUND_TIGHTENING_ROUNDING_CONSTANT, row );
result += registerTighterUpperBound( xi, upperBound + GlobalConfiguration::EXPLICIT_BASIS_BOUND_TIGHTENING_ROUNDING_CONSTANT, row );
if ( FloatUtils::gt( getLowerBound( xi ), getUpperBound( xi ) ) )
{
ASSERT( FloatUtils::gt( _boundManager.getLowerBound( xi ), _boundManager.getUpperBound( xi ) ) );
Expand Down
7 changes: 5 additions & 2 deletions src/engine/SignConstraint.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,9 @@ void SignConstraint::notifyLowerBound( unsigned variable, double bound )
!FloatUtils::gt( bound, getLowerBound( variable ) ) )
return;

if ( variable == 1637 )
std::cout << "Updating variable LB to " << bound << std::endl;

// Otherwise - update bound
setLowerBound( variable, bound );

Expand Down Expand Up @@ -439,13 +442,13 @@ void SignConstraint::notifyUpperBound( unsigned variable, double bound )
!FloatUtils::lt( bound, getUpperBound( variable ) ) )
return;

if ( variable == 1637 )
std::cout << "Updating variable UB to " << bound << std::endl;
// Otherwise - update bound
setUpperBound( variable, bound );

if ( variable == _f && FloatUtils::lt( bound, 1 ) )
{


setPhaseStatus( PhaseStatus::SIGN_PHASE_NEGATIVE );
if ( _boundManager != nullptr )
{
Expand Down
8 changes: 4 additions & 4 deletions src/engine/SignConstraint.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@
** directory for licensing information.\endverbatim
**
** SignConstraint implements the following constraint:
** f = Sign( b ) = ( b > 0 -> f = 1 )
** /\ ( b <=0 -> f = -1 )
** f = Sign( b ) = ( b >= 0 -> f = 1 )
** /\ ( b < 0 -> f = -1 )
**
** It distinguishes two relevant phases for search:
** SIGN_PHASE_POSITIVE: b > 0 and f = 1
** SIGN_PHASE_NEGATIVE: b <=0 and f = -1
** SIGN_PHASE_POSITIVE: b >= 0 and f = 1
** SIGN_PHASE_NEGATIVE: b < 0 and f = -1
**
** The constraint is implemented as PiecewiseLinearConstraint
** and operates in two modes:
Expand Down
2 changes: 1 addition & 1 deletion src/engine/Tableau.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -470,7 +470,7 @@ bool Tableau::existsValue( unsigned variable ) const
{
if ( _lpSolverType == LPSolverType::GUROBI )
{
return _gurobi->existsAssignment( Stringf( "x%u", variable ) );
return _gurobi && _gurobi->existsAssignment( Stringf( "x%u", variable ) );
}
else
{
Expand Down
6 changes: 4 additions & 2 deletions src/engine/tests/MockEngine.h
Original file line number Diff line number Diff line change
Expand Up @@ -198,9 +198,12 @@ class MockEngine : public IEngine
}

bool _snc;
CVC4::context::Context _context;

void applySnCSplit( PiecewiseLinearCaseSplit /*split*/, String /*queryId*/)
{
_snc = true;
_context.push();
}

bool inSnCMode() const {
Expand All @@ -211,8 +214,7 @@ class MockEngine : public IEngine

bool applyAllValidConstraintCaseSplits() { return false; };

CVC4::context::Context _dontCare;
CVC4::context::Context &getContext() { return _dontCare; }
CVC4::context::Context &getContext() { return _context; }

bool consistentBounds() const { return true; }

Expand Down

0 comments on commit 83fbff7

Please sign in to comment.