Skip to content

Commit

Permalink
write queries to SMTLIB files
Browse files Browse the repository at this point in the history
  • Loading branch information
omriisack committed Sep 24, 2024
1 parent 4f7a7da commit 58a4ed0
Show file tree
Hide file tree
Showing 8 changed files with 200 additions and 95 deletions.
16 changes: 16 additions & 0 deletions maraboupy/MarabouCore.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -525,6 +525,11 @@ void saveQuery( InputQuery &inputQuery, std::string filename )
inputQuery.saveQuery( String( filename ) );
}

void saveQueryAsSmtLib( InputQuery &query, std::string filename )
{
query.saveQueryAsSmtLib( String( filename ) );
}

void loadQuery( std::string filename, InputQuery &inputQuery )
{
return QueryLoader::loadQuery( String( filename ), inputQuery );
Expand Down Expand Up @@ -610,6 +615,17 @@ PYBIND11_MODULE( MarabouCore, m )
R"pbdoc(
Serializes the inputQuery in the given filename
Args:
inputQuery (:class:`~maraboupy.MarabouCore.InputQuery`): Marabou input query to be saved
filename (str): Name of file to save query
)pbdoc",
py::arg( "inputQuery" ),
py::arg( "filename" ) );
m.def( "saveQueryAsSmtLib",
&saveQueryAsSmtLib,
R"pbdoc(
Serializes the inputQuery in the given filename as an SMTLIB file
Args:
inputQuery (:class:`~maraboupy.MarabouCore.InputQuery`): Marabou input query to be saved
filename (str): Name of file to save query
Expand Down
7 changes: 7 additions & 0 deletions src/engine/InputQuery.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,13 @@ void InputQuery::saveQuery( const String &fileName )
delete query;
}

void InputQuery::saveQueryAsSmtLib( const String &fileName ) const
{
Query *query = generateQuery();
query->saveQueryAsSmtLib( fileName );
delete query;
}

Query *InputQuery::generateQuery() const
{
Query *query = new Query();
Expand Down
1 change: 1 addition & 0 deletions src/engine/InputQuery.h
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ class InputQuery : public IQuery
Serializes the query to a file which can then be loaded using QueryLoader.
*/
void saveQuery( const String &fileName );
void saveQueryAsSmtLib( const String &filename ) const;

/*
Generate a non-context-dependent version of the Query
Expand Down
31 changes: 31 additions & 0 deletions src/engine/Query.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -581,6 +581,37 @@ void Query::saveQuery( const String &fileName )
queryFile->close();
}

void Query::saveQueryAsSmtLib( const String &fileName ) const
{
if ( !_nlConstraints.empty() )
{
printf( "SMTLIB conversion does not support nonlinear constraints yet. Aborting "
"Conversion.\n" );
return;
}

List<Equation> equations;

Vector<double> upperBounds( _numberOfVariables, 0 );
Vector<double> lowerBounds( _numberOfVariables, 0 );

for ( unsigned i = 0; i < _numberOfVariables; ++i )
{
upperBounds[i] = _upperBounds.exists( i ) ? _upperBounds[i] : FloatUtils::infinity();
lowerBounds[i] =
_lowerBounds.exists( i ) ? _lowerBounds[i] : FloatUtils::negativeInfinity();
}

SmtLibWriter::writeToSmtLibFile( fileName,
0,
_numberOfVariables,
upperBounds,
lowerBounds,
NULL,
_equations,
_plConstraints );
}

void Query::markInputVariable( unsigned variable, unsigned inputIndex )
{
_variableToInputIndex[variable] = inputIndex;
Expand Down
1 change: 1 addition & 0 deletions src/engine/Query.h
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ class Query : public IQuery
Serializes the query to a file which can then be loaded using QueryLoader.
*/
void saveQuery( const String &fileName );
void saveQueryAsSmtLib( const String &fileName ) const;

/*
Print input and output bounds
Expand Down
96 changes: 10 additions & 86 deletions src/proofs/Checker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -309,92 +309,16 @@ Checker::getCorrespondingConstraint( const List<PiecewiseLinearCaseSplit> &split

void Checker::writeToFile()
{
List<String> leafInstance;

// Write with SmtLibWriter
unsigned b, f;
unsigned m = _proofSize;
unsigned n = _groundUpperBounds.size();

SmtLibWriter::addHeader( n, leafInstance );
SmtLibWriter::addGroundUpperBounds( _groundUpperBounds, leafInstance );
SmtLibWriter::addGroundLowerBounds( _groundLowerBounds, leafInstance );

auto tableauRow = SparseUnsortedList();

for ( unsigned i = 0; i < m; ++i )
{
tableauRow = SparseUnsortedList();
_initialTableau->getRow( i, &tableauRow );

// Fix correct size
if ( !tableauRow.getSize() && !tableauRow.empty() )
for ( auto it = tableauRow.begin(); it != tableauRow.end(); ++it )
tableauRow.incrementSize();

SmtLibWriter::addTableauRow( tableauRow, leafInstance );
}

for ( auto &constraint : _problemConstraints )
{
auto vars = constraint->getParticipatingVariables();
Vector<unsigned> conVars( vars.begin(), vars.end() );

if ( constraint->getType() == RELU )
{
b = conVars[0];
f = conVars[1];
SmtLibWriter::addReLUConstraint( b, f, constraint->getPhaseStatus(), leafInstance );
}
else if ( constraint->getType() == SIGN )
{
b = conVars[0];
f = conVars[1];
SmtLibWriter::addSignConstraint( b, f, constraint->getPhaseStatus(), leafInstance );
}
else if ( constraint->getType() == ABSOLUTE_VALUE )
{
b = conVars[0];
f = conVars[1];
SmtLibWriter::addAbsConstraint( b, f, constraint->getPhaseStatus(), leafInstance );
}
else if ( constraint->getType() == MAX )
{
MaxConstraint *maxConstraint = (MaxConstraint *)constraint;

Set<unsigned> elements = maxConstraint->getParticipatingElements();
double info = 0;

if ( constraint->getPhaseStatus() == MAX_PHASE_ELIMINATED )
info = maxConstraint->getMaxValueOfEliminatedPhases();
else if ( constraint->phaseFixed() )
info = maxConstraint->phaseToVariable( constraint->getPhaseStatus() );
else
for ( const auto &element : maxConstraint->getParticipatingVariables() )
elements.erase( element );

SmtLibWriter::addMaxConstraint( constraint->getParticipatingVariables().back(),
elements,
constraint->getPhaseStatus(),
info,
leafInstance );
}
else if ( constraint->getType() == DISJUNCTION )
SmtLibWriter::addDisjunctionConstraint(
( (DisjunctionConstraint *)constraint )->getFeasibleDisjuncts(), leafInstance );
else if ( constraint->getType() == LEAKY_RELU )
{
b = conVars[0];
f = conVars[1];
double slope = ( (LeakyReluConstraint *)constraint )->getSlope();
SmtLibWriter::addLeakyReLUConstraint(
b, f, slope, constraint->getPhaseStatus(), leafInstance );
}
}

SmtLibWriter::addFooter( leafInstance );
File file( "delegated" + std::to_string( _delegationCounter ) + ".smtlib" );
SmtLibWriter::writeInstanceToFile( file, leafInstance );
String filename = "delegated" + std::to_string( _delegationCounter ) + ".smtlib";

SmtLibWriter::writeToSmtLibFile( filename,
_proofSize,
_groundUpperBounds.size(),
_groundUpperBounds,
_groundLowerBounds,
_initialTableau,
List<Equation>(),
_problemConstraints );

++_delegationCounter;
}
Expand Down
126 changes: 120 additions & 6 deletions src/proofs/SmtLibWriter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,115 @@

#include "SmtLibWriter.h"

#include "DisjunctionConstraint.h"
#include "LeakyReluConstraint.h"
#include "MaxConstraint.h"
#include "ReluConstraint.h"
#include "SignConstraint.h"

const unsigned SmtLibWriter::SMTLIBWRITER_PRECISION =
(unsigned)std::log10( 1 / GlobalConfiguration::DEFAULT_EPSILON_FOR_COMPARISONS );


void SmtLibWriter::writeToSmtLibFile( const String &fileName,
unsigned numOfTableauRows,
unsigned numOfVariables,
const Vector<double> &upperBounds,
const Vector<double> &lowerBounds,
const SparseMatrix *tableau,
const List<Equation> &additionalEquations,
const List<PiecewiseLinearConstraint *> &problemConstraints )
{
List<String> instance;

// Write with SmtLibWriter
unsigned b, f;

SmtLibWriter::addHeader( numOfVariables, instance );
SmtLibWriter::addGroundUpperBounds( upperBounds, instance );
SmtLibWriter::addGroundLowerBounds( lowerBounds, instance );

auto tableauRow = SparseUnsortedList();

for ( unsigned i = 0; i < numOfTableauRows; ++i )
{
tableauRow = SparseUnsortedList();
tableau->getRow( i, &tableauRow );

// Fix correct size
if ( !tableauRow.getSize() && !tableauRow.empty() )
for ( auto it = tableauRow.begin(); it != tableauRow.end(); ++it )
tableauRow.incrementSize();

SmtLibWriter::addTableauRow( tableauRow, instance );
}

for ( const auto &eq : additionalEquations )
SmtLibWriter::addEquation( eq, instance, true );

for ( auto &constraint : problemConstraints )
{
auto vars = constraint->getParticipatingVariables();
Vector<unsigned> conVars( vars.begin(), vars.end() );

if ( constraint->getType() == RELU )
{
b = conVars[0];
f = conVars[1];
SmtLibWriter::addReLUConstraint( b, f, constraint->getPhaseStatus(), instance );
}
else if ( constraint->getType() == SIGN )
{
b = conVars[0];
f = conVars[1];
SmtLibWriter::addSignConstraint( b, f, constraint->getPhaseStatus(), instance );
}
else if ( constraint->getType() == ABSOLUTE_VALUE )
{
b = conVars[0];
f = conVars[1];
SmtLibWriter::addAbsConstraint( b, f, constraint->getPhaseStatus(), instance );
}
else if ( constraint->getType() == MAX )
{
MaxConstraint *maxConstraint = (MaxConstraint *)constraint;

Set<unsigned> elements = maxConstraint->getParticipatingElements();
double info = 0;

if ( constraint->getPhaseStatus() == MAX_PHASE_ELIMINATED )
info = maxConstraint->getMaxValueOfEliminatedPhases();
else if ( constraint->phaseFixed() )
info = maxConstraint->phaseToVariable( constraint->getPhaseStatus() );
// else
// for ( const auto &element : maxConstraint->getParticipatingVariables()
// )
// elements.erase( element );

SmtLibWriter::addMaxConstraint( constraint->getParticipatingVariables().back(),
elements,
constraint->getPhaseStatus(),
info,
instance );
}
else if ( constraint->getType() == DISJUNCTION )
SmtLibWriter::addDisjunctionConstraint(
( (DisjunctionConstraint *)constraint )->getFeasibleDisjuncts(), instance );
else if ( constraint->getType() == LEAKY_RELU )
{
b = conVars[0];
f = conVars[1];
double slope = ( (LeakyReluConstraint *)constraint )->getSlope();
SmtLibWriter::addLeakyReLUConstraint(
b, f, slope, constraint->getPhaseStatus(), instance );
}
}

SmtLibWriter::addFooter( instance );
File file( fileName );
SmtLibWriter::writeInstanceToFile( file, instance );
}

void SmtLibWriter::addHeader( unsigned numberOfVariables, List<String> &instance )
{
instance.append( "( set-logic QF_LRA )\n" );
Expand Down Expand Up @@ -147,7 +253,7 @@ void SmtLibWriter::addDisjunctionConstraint( const List<PiecewiseLinearCaseSplit
ASSERT( size )
// If disjucnt is a single equation or a single tightening, simply add them
if ( size == 1 && disjunct.getEquations().size() == 1 )
SmtLibWriter::addEquation( disjunct.getEquations().back(), instance );
SmtLibWriter::addEquation( disjunct.getEquations().back(), instance, false );
else if ( size == 1 && disjunct.getBoundTightenings().size() == 1 )
SmtLibWriter::addTightening( disjunct.getBoundTightenings().back(), instance );
else
Expand All @@ -160,7 +266,7 @@ void SmtLibWriter::addDisjunctionConstraint( const List<PiecewiseLinearCaseSplit
if ( counter < size - 1 )
instance.append( "( and " );
++counter;
SmtLibWriter::addEquation( eq, instance );
SmtLibWriter::addEquation( eq, instance, false );
}

for ( const auto &bound : disjunct.getBoundTightenings() )
Expand Down Expand Up @@ -248,15 +354,15 @@ void SmtLibWriter::addTableauRow( const SparseUnsortedList &row, List<String> &i
instance.append( assertRowLine + "\n" );
}

void SmtLibWriter::addGroundUpperBounds( Vector<double> &bounds, List<String> &instance )
void SmtLibWriter::addGroundUpperBounds( const Vector<double> &bounds, List<String> &instance )
{
unsigned n = bounds.size();
for ( unsigned i = 0; i < n; ++i )
instance.append( String( "( assert ( <= x" + std::to_string( i ) ) + String( " " ) +
signedValue( bounds[i] ) + " ) )\n" );
}

void SmtLibWriter::addGroundLowerBounds( Vector<double> &bounds, List<String> &instance )
void SmtLibWriter::addGroundLowerBounds( const Vector<double> &bounds, List<String> &instance )
{
unsigned n = bounds.size();
for ( unsigned i = 0; i < n; ++i )
Expand All @@ -282,7 +388,7 @@ String SmtLibWriter::signedValue( double val )
: String( "( - " + s.str() ).trimZerosFromRight() + " )";
}

void SmtLibWriter::addEquation( const Equation &eq, List<String> &instance )
void SmtLibWriter::addEquation( const Equation &eq, List<String> &instance, bool assertEquations )
{
unsigned size = eq._addends.size();

Expand All @@ -293,6 +399,9 @@ void SmtLibWriter::addEquation( const Equation &eq, List<String> &instance )

String assertRowLine = "";

if ( assertEquations )
assertRowLine += "( assert ";

if ( eq._type == Equation::EQ )
assertRowLine += "( = ";
else if ( eq._type == Equation::LE )
Expand All @@ -307,7 +416,12 @@ void SmtLibWriter::addEquation( const Equation &eq, List<String> &instance )
for ( const auto &addend : eq._addends )
{
if ( FloatUtils::isZero( addend._coefficient ) )
{
// If the last addend has coefficient zero, add 0 to close previously opened addition
if ( addend == eq._addends.back() )
assertRowLine += String( " 0 )" );
continue;
}

if ( !( addend == eq._addends.back() ) )
assertRowLine += String( " ( + " );
Expand All @@ -330,7 +444,7 @@ void SmtLibWriter::addEquation( const Equation &eq, List<String> &instance )
for ( unsigned i = 0; i < counter; ++i )
assertRowLine += String( " )" );

instance.append( assertRowLine + " " );
instance.append( assertRowLine + ( assertEquations ? " ) \n" : " " ) );
}

void SmtLibWriter::addTightening( Tightening bound, List<String> &instance )
Expand Down
Loading

0 comments on commit 58a4ed0

Please sign in to comment.