diff --git a/maraboupy/MarabouCore.cpp b/maraboupy/MarabouCore.cpp index 52c232c3b..8d35646b1 100644 --- a/maraboupy/MarabouCore.cpp +++ b/maraboupy/MarabouCore.cpp @@ -525,6 +525,11 @@ void saveQuery( InputQuery &inputQuery, std::string filename ) inputQuery.saveQuery( String( filename ) ); } +void saveQueryAsSmtLib( InputQuery &query, std::string filename ) +{ + query.saveQueryAsSmtLib( String( filename ) ); +} + void loadQuery( std::string filename, InputQuery &inputQuery ) { return QueryLoader::loadQuery( String( filename ), inputQuery ); @@ -610,6 +615,17 @@ PYBIND11_MODULE( MarabouCore, m ) R"pbdoc( Serializes the inputQuery in the given filename + Args: + inputQuery (:class:`~maraboupy.MarabouCore.InputQuery`): Marabou input query to be saved + filename (str): Name of file to save query + )pbdoc", + py::arg( "inputQuery" ), + py::arg( "filename" ) ); + m.def( "saveQueryAsSmtLib", + &saveQueryAsSmtLib, + R"pbdoc( + Serializes the inputQuery in the given filename as an SMTLIB file + Args: inputQuery (:class:`~maraboupy.MarabouCore.InputQuery`): Marabou input query to be saved filename (str): Name of file to save query diff --git a/src/engine/InputQuery.cpp b/src/engine/InputQuery.cpp index ba16ac3ea..d275646b0 100644 --- a/src/engine/InputQuery.cpp +++ b/src/engine/InputQuery.cpp @@ -339,6 +339,13 @@ void InputQuery::saveQuery( const String &fileName ) delete query; } +void InputQuery::saveQueryAsSmtLib( const String &fileName ) const +{ + Query *query = generateQuery(); + query->saveQueryAsSmtLib( fileName ); + delete query; +} + Query *InputQuery::generateQuery() const { Query *query = new Query(); diff --git a/src/engine/InputQuery.h b/src/engine/InputQuery.h index 9b5a1036d..083161ee0 100644 --- a/src/engine/InputQuery.h +++ b/src/engine/InputQuery.h @@ -135,6 +135,7 @@ class InputQuery : public IQuery Serializes the query to a file which can then be loaded using QueryLoader. */ void saveQuery( const String &fileName ); + void saveQueryAsSmtLib( const String &filename ) const; /* Generate a non-context-dependent version of the Query diff --git a/src/engine/Query.cpp b/src/engine/Query.cpp index 199d3f630..77b22c9c9 100644 --- a/src/engine/Query.cpp +++ b/src/engine/Query.cpp @@ -581,6 +581,37 @@ void Query::saveQuery( const String &fileName ) queryFile->close(); } +void Query::saveQueryAsSmtLib( const String &fileName ) const +{ + if ( !_nlConstraints.empty() ) + { + printf( "SMTLIB conversion does not support nonlinear constraints yet. Aborting " + "Conversion.\n" ); + return; + } + + List equations; + + Vector upperBounds( _numberOfVariables, 0 ); + Vector lowerBounds( _numberOfVariables, 0 ); + + for ( unsigned i = 0; i < _numberOfVariables; ++i ) + { + upperBounds[i] = _upperBounds.exists( i ) ? _upperBounds[i] : FloatUtils::infinity(); + lowerBounds[i] = + _lowerBounds.exists( i ) ? _lowerBounds[i] : FloatUtils::negativeInfinity(); + } + + SmtLibWriter::writeToSmtLibFile( fileName, + 0, + _numberOfVariables, + upperBounds, + lowerBounds, + NULL, + _equations, + _plConstraints ); +} + void Query::markInputVariable( unsigned variable, unsigned inputIndex ) { _variableToInputIndex[variable] = inputIndex; diff --git a/src/engine/Query.h b/src/engine/Query.h index ca7536e17..48fa1740d 100644 --- a/src/engine/Query.h +++ b/src/engine/Query.h @@ -125,6 +125,7 @@ class Query : public IQuery Serializes the query to a file which can then be loaded using QueryLoader. */ void saveQuery( const String &fileName ); + void saveQueryAsSmtLib( const String &fileName ) const; /* Print input and output bounds diff --git a/src/proofs/Checker.cpp b/src/proofs/Checker.cpp index 8d14e4113..c153de7fe 100644 --- a/src/proofs/Checker.cpp +++ b/src/proofs/Checker.cpp @@ -309,92 +309,16 @@ Checker::getCorrespondingConstraint( const List &split void Checker::writeToFile() { - List leafInstance; - - // Write with SmtLibWriter - unsigned b, f; - unsigned m = _proofSize; - unsigned n = _groundUpperBounds.size(); - - SmtLibWriter::addHeader( n, leafInstance ); - SmtLibWriter::addGroundUpperBounds( _groundUpperBounds, leafInstance ); - SmtLibWriter::addGroundLowerBounds( _groundLowerBounds, leafInstance ); - - auto tableauRow = SparseUnsortedList(); - - for ( unsigned i = 0; i < m; ++i ) - { - tableauRow = SparseUnsortedList(); - _initialTableau->getRow( i, &tableauRow ); - - // Fix correct size - if ( !tableauRow.getSize() && !tableauRow.empty() ) - for ( auto it = tableauRow.begin(); it != tableauRow.end(); ++it ) - tableauRow.incrementSize(); - - SmtLibWriter::addTableauRow( tableauRow, leafInstance ); - } - - for ( auto &constraint : _problemConstraints ) - { - auto vars = constraint->getParticipatingVariables(); - Vector conVars( vars.begin(), vars.end() ); - - if ( constraint->getType() == RELU ) - { - b = conVars[0]; - f = conVars[1]; - SmtLibWriter::addReLUConstraint( b, f, constraint->getPhaseStatus(), leafInstance ); - } - else if ( constraint->getType() == SIGN ) - { - b = conVars[0]; - f = conVars[1]; - SmtLibWriter::addSignConstraint( b, f, constraint->getPhaseStatus(), leafInstance ); - } - else if ( constraint->getType() == ABSOLUTE_VALUE ) - { - b = conVars[0]; - f = conVars[1]; - SmtLibWriter::addAbsConstraint( b, f, constraint->getPhaseStatus(), leafInstance ); - } - else if ( constraint->getType() == MAX ) - { - MaxConstraint *maxConstraint = (MaxConstraint *)constraint; - - Set elements = maxConstraint->getParticipatingElements(); - double info = 0; - - if ( constraint->getPhaseStatus() == MAX_PHASE_ELIMINATED ) - info = maxConstraint->getMaxValueOfEliminatedPhases(); - else if ( constraint->phaseFixed() ) - info = maxConstraint->phaseToVariable( constraint->getPhaseStatus() ); - else - for ( const auto &element : maxConstraint->getParticipatingVariables() ) - elements.erase( element ); - - SmtLibWriter::addMaxConstraint( constraint->getParticipatingVariables().back(), - elements, - constraint->getPhaseStatus(), - info, - leafInstance ); - } - else if ( constraint->getType() == DISJUNCTION ) - SmtLibWriter::addDisjunctionConstraint( - ( (DisjunctionConstraint *)constraint )->getFeasibleDisjuncts(), leafInstance ); - else if ( constraint->getType() == LEAKY_RELU ) - { - b = conVars[0]; - f = conVars[1]; - double slope = ( (LeakyReluConstraint *)constraint )->getSlope(); - SmtLibWriter::addLeakyReLUConstraint( - b, f, slope, constraint->getPhaseStatus(), leafInstance ); - } - } - - SmtLibWriter::addFooter( leafInstance ); - File file( "delegated" + std::to_string( _delegationCounter ) + ".smtlib" ); - SmtLibWriter::writeInstanceToFile( file, leafInstance ); + String filename = "delegated" + std::to_string( _delegationCounter ) + ".smtlib"; + + SmtLibWriter::writeToSmtLibFile( filename, + _proofSize, + _groundUpperBounds.size(), + _groundUpperBounds, + _groundLowerBounds, + _initialTableau, + List(), + _problemConstraints ); ++_delegationCounter; } diff --git a/src/proofs/SmtLibWriter.cpp b/src/proofs/SmtLibWriter.cpp index ed05f92d3..92628707d 100644 --- a/src/proofs/SmtLibWriter.cpp +++ b/src/proofs/SmtLibWriter.cpp @@ -14,9 +14,115 @@ #include "SmtLibWriter.h" +#include "DisjunctionConstraint.h" +#include "LeakyReluConstraint.h" +#include "MaxConstraint.h" +#include "ReluConstraint.h" +#include "SignConstraint.h" + const unsigned SmtLibWriter::SMTLIBWRITER_PRECISION = (unsigned)std::log10( 1 / GlobalConfiguration::DEFAULT_EPSILON_FOR_COMPARISONS ); + +void SmtLibWriter::writeToSmtLibFile( const String &fileName, + unsigned numOfTableauRows, + unsigned numOfVariables, + const Vector &upperBounds, + const Vector &lowerBounds, + const SparseMatrix *tableau, + const List &additionalEquations, + const List &problemConstraints ) +{ + List instance; + + // Write with SmtLibWriter + unsigned b, f; + + SmtLibWriter::addHeader( numOfVariables, instance ); + SmtLibWriter::addGroundUpperBounds( upperBounds, instance ); + SmtLibWriter::addGroundLowerBounds( lowerBounds, instance ); + + auto tableauRow = SparseUnsortedList(); + + for ( unsigned i = 0; i < numOfTableauRows; ++i ) + { + tableauRow = SparseUnsortedList(); + tableau->getRow( i, &tableauRow ); + + // Fix correct size + if ( !tableauRow.getSize() && !tableauRow.empty() ) + for ( auto it = tableauRow.begin(); it != tableauRow.end(); ++it ) + tableauRow.incrementSize(); + + SmtLibWriter::addTableauRow( tableauRow, instance ); + } + + for ( const auto &eq : additionalEquations ) + SmtLibWriter::addEquation( eq, instance, true ); + + for ( auto &constraint : problemConstraints ) + { + auto vars = constraint->getParticipatingVariables(); + Vector conVars( vars.begin(), vars.end() ); + + if ( constraint->getType() == RELU ) + { + b = conVars[0]; + f = conVars[1]; + SmtLibWriter::addReLUConstraint( b, f, constraint->getPhaseStatus(), instance ); + } + else if ( constraint->getType() == SIGN ) + { + b = conVars[0]; + f = conVars[1]; + SmtLibWriter::addSignConstraint( b, f, constraint->getPhaseStatus(), instance ); + } + else if ( constraint->getType() == ABSOLUTE_VALUE ) + { + b = conVars[0]; + f = conVars[1]; + SmtLibWriter::addAbsConstraint( b, f, constraint->getPhaseStatus(), instance ); + } + else if ( constraint->getType() == MAX ) + { + MaxConstraint *maxConstraint = (MaxConstraint *)constraint; + + Set elements = maxConstraint->getParticipatingElements(); + double info = 0; + + if ( constraint->getPhaseStatus() == MAX_PHASE_ELIMINATED ) + info = maxConstraint->getMaxValueOfEliminatedPhases(); + else if ( constraint->phaseFixed() ) + info = maxConstraint->phaseToVariable( constraint->getPhaseStatus() ); + // else + // for ( const auto &element : maxConstraint->getParticipatingVariables() + // ) + // elements.erase( element ); + + SmtLibWriter::addMaxConstraint( constraint->getParticipatingVariables().back(), + elements, + constraint->getPhaseStatus(), + info, + instance ); + } + else if ( constraint->getType() == DISJUNCTION ) + SmtLibWriter::addDisjunctionConstraint( + ( (DisjunctionConstraint *)constraint )->getFeasibleDisjuncts(), instance ); + else if ( constraint->getType() == LEAKY_RELU ) + { + b = conVars[0]; + f = conVars[1]; + double slope = ( (LeakyReluConstraint *)constraint )->getSlope(); + SmtLibWriter::addLeakyReLUConstraint( + b, f, slope, constraint->getPhaseStatus(), instance ); + } + } + + SmtLibWriter::addFooter( instance ); + File file( fileName ); + SmtLibWriter::writeInstanceToFile( file, instance ); +} + void SmtLibWriter::addHeader( unsigned numberOfVariables, List &instance ) { instance.append( "( set-logic QF_LRA )\n" ); @@ -147,7 +253,7 @@ void SmtLibWriter::addDisjunctionConstraint( const List &i instance.append( assertRowLine + "\n" ); } -void SmtLibWriter::addGroundUpperBounds( Vector &bounds, List &instance ) +void SmtLibWriter::addGroundUpperBounds( const Vector &bounds, List &instance ) { unsigned n = bounds.size(); for ( unsigned i = 0; i < n; ++i ) @@ -256,7 +362,7 @@ void SmtLibWriter::addGroundUpperBounds( Vector &bounds, List &i signedValue( bounds[i] ) + " ) )\n" ); } -void SmtLibWriter::addGroundLowerBounds( Vector &bounds, List &instance ) +void SmtLibWriter::addGroundLowerBounds( const Vector &bounds, List &instance ) { unsigned n = bounds.size(); for ( unsigned i = 0; i < n; ++i ) @@ -282,7 +388,7 @@ String SmtLibWriter::signedValue( double val ) : String( "( - " + s.str() ).trimZerosFromRight() + " )"; } -void SmtLibWriter::addEquation( const Equation &eq, List &instance ) +void SmtLibWriter::addEquation( const Equation &eq, List &instance, bool assertEquations ) { unsigned size = eq._addends.size(); @@ -293,6 +399,9 @@ void SmtLibWriter::addEquation( const Equation &eq, List &instance ) String assertRowLine = ""; + if ( assertEquations ) + assertRowLine += "( assert "; + if ( eq._type == Equation::EQ ) assertRowLine += "( = "; else if ( eq._type == Equation::LE ) @@ -307,7 +416,12 @@ void SmtLibWriter::addEquation( const Equation &eq, List &instance ) for ( const auto &addend : eq._addends ) { if ( FloatUtils::isZero( addend._coefficient ) ) + { + // If the last addend has coefficient zero, add 0 to close previously opened addition + if ( addend == eq._addends.back() ) + assertRowLine += String( " 0 )" ); continue; + } if ( !( addend == eq._addends.back() ) ) assertRowLine += String( " ( + " ); @@ -330,7 +444,7 @@ void SmtLibWriter::addEquation( const Equation &eq, List &instance ) for ( unsigned i = 0; i < counter; ++i ) assertRowLine += String( " )" ); - instance.append( assertRowLine + " " ); + instance.append( assertRowLine + ( assertEquations ? " ) \n" : " " ) ); } void SmtLibWriter::addTightening( Tightening bound, List &instance ) diff --git a/src/proofs/SmtLibWriter.h b/src/proofs/SmtLibWriter.h index e3d3bf0d8..ab7787cab 100644 --- a/src/proofs/SmtLibWriter.h +++ b/src/proofs/SmtLibWriter.h @@ -95,17 +95,17 @@ class SmtLibWriter /* Adds a line representing an equation , in SMTLIB format, to the SMTLIB instance */ - static void addEquation( const Equation &eq, List &instance ); + static void addEquation( const Equation &eq, List &instance, bool assertEquations ); /* Adds lines representing the ground upper bounds, in SMTLIB format, to the SMTLIB instance */ - static void addGroundUpperBounds( Vector &bounds, List &instance ); + static void addGroundUpperBounds( const Vector &bounds, List &instance ); /* Adds lines representing the ground lower bounds, in SMTLIB format, to the SMTLIB instance */ - static void addGroundLowerBounds( Vector &bounds, List &instance ); + static void addGroundLowerBounds( const Vector &bounds, List &instance ); /* Adds lines representing a tightening, in SMTLIB format, to the SMTLIB instance @@ -121,6 +121,17 @@ class SmtLibWriter Returns a string representing the value of a double */ static String signedValue( double val ); + /* + A wrapper function calling all previous functions + */ + static void writeToSmtLibFile( const String &fileName, + unsigned numOfTableauRows, + unsigned numOfVariables, + const Vector &upperBounds, + const Vector &lowerBounds, + const SparseMatrix *tableau, + const List &additionalEquations, + const List &problemConstraints ); }; #endif //__SmtLibWriter_h__ \ No newline at end of file