diff --git a/src/api/MainSolver.cc b/src/api/MainSolver.cc index 6d1043df8..cda6aa342 100644 --- a/src/api/MainSolver.cc +++ b/src/api/MainSolver.cc @@ -9,6 +9,7 @@ #include "MainSolver.h" #include +#include #include #include #include @@ -333,16 +334,21 @@ sstat MainSolver::check() { StopWatch sw(query_timer); } if (isLastFrameUnsat()) { return s_False; } - sstat rval = simplifyFormulas(); + sstat rval; + try { + rval = simplifyFormulas(); + } catch (opensmt::NonLinException const & error) { + reasonUnknown = error.what(); + rval = s_Undef; + return rval; + } if (config.dump_query()) printCurrentAssertionsAsQuery(); if (rval == s_Undef) { try { rval = solve(); - } catch (std::overflow_error const & error) { - rval = s_Error; - } catch (opensmt::LANonLinearException const & error) { + } catch (std::overflow_error const & error) { rval = s_Error; } catch (opensmt::NonLinException const & error) { reasonUnknown = error.what(); rval = s_Undef; } diff --git a/src/common/CMakeLists.txt b/src/common/CMakeLists.txt index ed6e244aa..fcb542f1a 100644 --- a/src/common/CMakeLists.txt +++ b/src/common/CMakeLists.txt @@ -24,5 +24,5 @@ include(numbers/CMakeLists.txt) install(FILES StringMap.h Timer.h inttypes.h IColor.h TreeOps.h FlaPartitionMap.h PartitionInfo.h Partitions.h ApiException.h TypeUtils.h - NatSet.h ScopedVector.h TermNames.h + NatSet.h ScopedVector.h TermNames.h NonLinException.h DESTINATION ${INSTALL_HEADERS_DIR}/common) diff --git a/src/common/NonLinException.h b/src/common/NonLinException.h new file mode 100644 index 000000000..25ea8e4aa --- /dev/null +++ b/src/common/NonLinException.h @@ -0,0 +1,21 @@ +// +// Created by Konstantin Britikov on 27.11.2024. +// + +#ifndef OPENSMT_NONLINEXCEPTION_H +#define OPENSMT_NONLINEXCEPTION_H + +namespace opensmt { +class NonLinException : public std::runtime_error { +public: + NonLinException(std::string_view const reason_) : runtime_error(std::string(reason_)) { + msg = "Term " + std::string(reason_) + " is non-linear"; + } + virtual char const * what() const noexcept override { return msg.c_str(); } + +private: + std::string msg; +}; +} + +#endif // OPENSMT_NONLINEXCEPTION_H diff --git a/src/logics/ArithLogic.cc b/src/logics/ArithLogic.cc index 02b602fab..492eb2273 100644 --- a/src/logics/ArithLogic.cc +++ b/src/logics/ArithLogic.cc @@ -114,8 +114,8 @@ ArithLogic::ArithLogic(Logic_t type) sym_Real_MINUS(declareFun_NoScoping_LeftAssoc(tk_real_minus, sort_REAL, {sort_REAL, sort_REAL})), sym_Real_PLUS(declareFun_Commutative_NoScoping_LeftAssoc(tk_real_plus, sort_REAL, {sort_REAL, sort_REAL})), sym_Real_TIMES(declareFun_Commutative_NoScoping_LeftAssoc(tk_real_times, sort_REAL, {sort_REAL, sort_REAL})), - sym_Real_TIMES_LIN(declareFun_Multiplication_Duplicate(tk_real_times, sort_REAL, {sort_REAL, sort_REAL})), - sym_Real_TIMES_NONLIN(declareFun_Multiplication_Duplicate(tk_real_times, sort_REAL, {sort_REAL, sort_REAL})), + sym_Real_TIMES_LIN(declareFun_Multiplication_LinNonlin(tk_real_times, sort_REAL, {sort_REAL, sort_REAL})), + sym_Real_TIMES_NONLIN(declareFun_Multiplication_LinNonlin(tk_real_times, sort_REAL, {sort_REAL, sort_REAL})), sym_Real_DIV(declareFun_NoScoping_LeftAssoc(tk_real_div, sort_REAL, {sort_REAL, sort_REAL})), sym_Real_EQ(sortToEquality[sort_REAL]), sym_Real_LEQ(declareFun_NoScoping_Chainable(tk_real_leq, sort_BOOL, {sort_REAL, sort_REAL})), @@ -136,8 +136,8 @@ ArithLogic::ArithLogic(Logic_t type) sym_Int_MINUS(declareFun_NoScoping_LeftAssoc(tk_int_minus, sort_INT, {sort_INT, sort_INT})), sym_Int_PLUS(declareFun_Commutative_NoScoping_LeftAssoc(tk_int_plus, sort_INT, {sort_INT, sort_INT})), sym_Int_TIMES(declareFun_Commutative_NoScoping_LeftAssoc(tk_int_times, sort_INT, {sort_INT, sort_INT})), - sym_Int_TIMES_LIN(declareFun_Multiplication_Duplicate(tk_int_times, sort_INT, {sort_INT, sort_INT})), - sym_Int_TIMES_NONLIN(declareFun_Multiplication_Duplicate(tk_int_times, sort_INT, {sort_INT, sort_INT})), + sym_Int_TIMES_LIN(declareFun_Multiplication_LinNonlin(tk_int_times, sort_INT, {sort_INT, sort_INT})), + sym_Int_TIMES_NONLIN(declareFun_Multiplication_LinNonlin(tk_int_times, sort_INT, {sort_INT, sort_INT})), sym_Int_DIV(declareFun_NoScoping_LeftAssoc(tk_int_div, sort_INT, {sort_INT, sort_INT})), sym_Int_MOD(declareFun_NoScoping(tk_int_mod, sort_INT, {sort_INT, sort_INT})), sym_Int_EQ(sortToEquality[sort_INT]), @@ -184,11 +184,11 @@ PTRef ArithLogic::getMinusOneForSort(SRef sort) const { } bool ArithLogic::isLinearFactor(PTRef tr) const { - if (isNumConst(tr) || isNumVarLike(tr)) { return true; } + if (isNumConst(tr) || isMonomial(tr)) { return true; } if (isTimesLin(tr)) { Pterm const & term = getPterm(tr); return term.size() == 2 && - ((isNumConst(term[0]) && (isNumVarLike(term[1]))) || (isNumConst(term[1]) && (isNumVarLike(term[0])))); + ((isNumConst(term[0]) && (isMonomial(term[1]))) || (isNumConst(term[1]) && (isMonomial(term[0])))); } return false; } @@ -238,19 +238,16 @@ pair> ArithLogic::getConstantAndFactors(PTRef sum) const { } pair ArithLogic::splitPolyTerm(PTRef term) const { - assert(isTimes(term) || isNumVarLike(term) || isConstant(term)); + assert(isTimes(term) || isMonomial(term) || isConstant(term)); if (isTimesLin(term)) { assert(getPterm(term).size() == 2); PTRef fac = getPterm(term)[0]; PTRef var = getPterm(term)[1]; if (not isConstant(fac)) { std::swap(fac, var); } assert(isConstant(fac)); - assert(isNumVarLike(var) || isTimesNonlin(var)); + assert(isMonomial(var)); return {var, fac}; - } else if (isTimesNonlin(term)) { - PTRef one = yieldsSortInt(term) ? getTerm_IntOne() : getTerm_RealOne(); - return {term, one}; - } else if (isNumVarLike(term)) { + } else if (isMonomial(term)) { assert(yieldsSortInt(term) or yieldsSortReal(term)); PTRef var = term; PTRef fac = yieldsSortInt(term) ? getTerm_IntOne() : getTerm_RealOne(); @@ -263,7 +260,7 @@ pair ArithLogic::splitPolyTerm(PTRef term) const { // Normalize a product of the form (* a v) to either v or (* -1 v) PTRef ArithLogic::normalizeMul(PTRef mul) { - assert(isTimesDefined(mul)); + assert(isTimesLinNonlin(mul)); auto [v, c] = splitPolyTerm(mul); if (getNumConst(c) < 0) { return mkNeg(v); @@ -426,7 +423,7 @@ lbool ArithLogic::arithmeticElimination(vec const & top_level_arith, Subs auto coeff = logic.getNumConst(c); poly.addTerm(var, std::move(coeff)); } else { - assert(logic.isPlus(polyTerm) || logic.isTimesDefined(polyTerm)); + assert(logic.isPlus(polyTerm) || logic.isTimesLinNonlin(polyTerm)); for (PTRef factor : logic.getPterm(polyTerm)) { auto [var, c] = logic.splitPolyTerm(factor); auto coeff = logic.getNumConst(c); @@ -495,14 +492,14 @@ pair ArithLogic::retrieveSubstitutions(vec const } uint32_t LessThan_deepPTRef::getVarIdFromProduct(PTRef tr) const { - assert(l.isTimesDefined(tr)); + assert(l.isTimesLinNonlin(tr)); auto [v, c] = l.splitPolyTerm(tr); return v.x; } bool LessThan_deepPTRef::operator()(PTRef x_, PTRef y_) const { - uint32_t id_x = l.isTimesDefined(x_) ? getVarIdFromProduct(x_) : x_.x; - uint32_t id_y = l.isTimesDefined(y_) ? getVarIdFromProduct(y_) : y_.x; + uint32_t id_x = l.isTimesLinNonlin(x_) ? getVarIdFromProduct(x_) : x_.x; + uint32_t id_y = l.isTimesLinNonlin(y_) ? getVarIdFromProduct(y_) : y_.x; return id_x < id_y; } @@ -517,10 +514,10 @@ bool ArithLogic::isBuiltinFunction(SymRef const sr) const { return Logic::isBuiltinFunction(sr); } bool ArithLogic::isNumTerm(PTRef tr) const { - if (isNumVarLike(tr)) return true; + if (isMonomial(tr)) return true; Pterm const & t = getPterm(tr); if (t.size() == 2 && isTimesLin(tr)) - return (isNumVarLike(t[0]) && isConstant(t[1])) || (isNumVarLike(t[1]) && isConstant(t[0])); + return (isMonomial(t[0]) && isConstant(t[1])) || (isMonomial(t[1]) && isConstant(t[0])); else if (t.size() == 0) return isNumVar(tr) || isConstant(tr); else @@ -557,7 +554,7 @@ PTRef ArithLogic::mkNeg(PTRef tr) { SRef returnSort = getSortRef(tr); return mkFun(getTimesLinForSort(returnSort), {tr, getMinusOneForSort(returnSort)}); } - if (isNumVarLike(symref)) { + if (isMonomial(symref)) { auto sortRef = getSortRef(symref); return mkFun(getTimesLinForSort(sortRef), {tr, getMinusOneForSort(sortRef)}); } @@ -744,9 +741,9 @@ PTRef ArithLogic::mkBinaryLeq(PTRef lhs, PTRef rhs) { Number const & v = this->getNumConst(sum_tmp); return v.sign() < 0 ? getTerm_false() : getTerm_true(); } - if (isNumVarLike(sum_tmp) || - isTimesDefined(sum_tmp)) { // "sum_tmp = c * v", just scale to "v" or "-v" without changing the sign - sum_tmp = isTimesDefined(sum_tmp) ? normalizeMul(sum_tmp) : sum_tmp; + if (isMonomial(sum_tmp) || + isTimesLinNonlin(sum_tmp)) { // "sum_tmp = c * v", just scale to "v" or "-v" without changing the sign + sum_tmp = isTimesLinNonlin(sum_tmp) ? normalizeMul(sum_tmp) : sum_tmp; return mkFun(getLeqForSort(argSort), {getZeroForSort(argSort), sum_tmp}); } else if (isPlus(sum_tmp)) { // Normalize the sum @@ -817,7 +814,7 @@ PTRef ArithLogic::mkBinaryEq(PTRef lhs, PTRef rhs) { if (isConstant(diff)) { Number const & v = this->getNumConst(diff); return v.isZero() ? getTerm_true() : getTerm_false(); - } else if (isNumVarLike(diff) || isTimesDefined(diff)) { + } else if (isMonomial(diff) || isTimesLin(diff)) { auto [var, constant] = splitPolyTerm(diff); return Logic::mkBinaryEq(getZeroForSort(eqSort), var); // Avoid anything that calls Logic::mkEq as this would create a loop @@ -1014,7 +1011,7 @@ void SimplifyConst::simplify(SymRef s, vec const & args, SymRef & s_new, } // // A single argument for the operator, and the operator is identity // // in that case - if (args_new.size() == 1 && (l.isPlus(s_new) || l.isTimesDefined(s_new))) { + if (args_new.size() == 1 && (l.isPlus(s_new) || l.isTimesLinNonlin(s_new))) { PTRef ch_tr = args_new[0]; args_new.clear(); s_new = l.getPterm(ch_tr).symb(); @@ -1233,7 +1230,7 @@ pair ArithLogic::sumToNormalizedIntPair(PTRef sum) { coeffs.reserve(varFactors.size()); for (PTRef factor : varFactors) { auto [var, coeff] = splitPolyTerm(factor); - assert((ArithLogic::isNumVarLike(var) || ArithLogic::isTimesDefined(var)) and isNumConst(coeff)); + assert((ArithLogic::isMonomial(var) || ArithLogic::isTimesLin(var)) and isNumConst(coeff)); vars.push(var); coeffs.push_back(getNumConst(coeff)); } @@ -1372,8 +1369,8 @@ std::pair ArithLogic::leqToConstantAndTerm(PTRef leq) { } bool ArithLogic::hasNegativeLeadingVariable(PTRef poly) const { - if (isNumConst(poly) or isNumVarLike(poly)) { return false; } - if (isTimesDefined(poly)) { + if (isNumConst(poly) or isMonomial(poly)) { return false; } + if (isTimesLinNonlin(poly)) { auto [var, constant] = splitPolyTerm(poly); return isNegative(getNumConst(constant)); } diff --git a/src/logics/ArithLogic.h b/src/logics/ArithLogic.h index 10d9a387d..b1e8c2cf0 100644 --- a/src/logics/ArithLogic.h +++ b/src/logics/ArithLogic.h @@ -164,13 +164,13 @@ class ArithLogic : public Logic { bool isIntNeg(SymRef sr) const { return sr == sym_Int_NEG; } bool isRealNeg(SymRef sr) const { return sr == sym_Real_NEG; } - bool isTimes(SymRef sr) const { return isTimesLin(sr) or isTimesNonlin(sr) or isTimesUnknown(sr); }; - bool isTimesDefined(SymRef sr) const { return isTimesLin(sr) or isTimesNonlin(sr); }; + bool isTimes(SymRef sr) const { return isTimesLin(sr) or isTimesNonlin(sr) or isTimesUnparsed(sr); }; + bool isTimesLinNonlin(SymRef sr) const { return isTimesLin(sr) or isTimesNonlin(sr); }; bool isTimesLin(SymRef sr) const { return isIntTimesLin(sr) or isRealTimesLin(sr); } - bool isTimesUnknown(SymRef sr) const { return isIntTimes(sr) or isRealTimes(sr); } + bool isTimesUnparsed(SymRef sr) const { return isIntTimes(sr) or isRealTimes(sr); } bool isTimesNonlin(SymRef sr) const { return isIntTimesNonlin(sr) or isRealTimesNonlin(sr); } bool isTimes(PTRef tr) const { return isTimes(getPterm(tr).symb()); } - bool isTimesDefined(PTRef tr) const { return isTimesDefined(getPterm(tr).symb()); }; + bool isTimesLinNonlin(PTRef tr) const { return isTimesLinNonlin(getPterm(tr).symb()); }; bool isTimesLin(PTRef tr) const { return isTimesLin(getPterm(tr).symb()); } bool isTimesNonlin(PTRef tr) const { return isTimesNonlin(getPterm(tr).symb()); } bool isIntTimesLin(PTRef tr) const { return isIntTimesLin(getPterm(tr).symb()); } @@ -229,10 +229,10 @@ class ArithLogic : public Logic { bool isNumVar(SymRef sr) const { return isVar(sr) and (yieldsSortInt(sr) or yieldsSortReal(sr)); } bool isNumVar(PTRef tr) const { return isNumVar(getPterm(tr).symb()); } - bool isNumVarLike(SymRef sr) const { - return yieldsSortNum(sr) and not isPlus(sr) and not isTimes(sr) and not isNumConst(sr); + bool isMonomial(SymRef sr) const { + return yieldsSortNum(sr) and not isPlus(sr) and not isTimesLin(sr) and not isNumConst(sr); } - bool isNumVarLike(PTRef tr) const { return isNumVarLike(getPterm(tr).symb()); } + bool isMonomial(PTRef tr) const { return isMonomial(getPterm(tr).symb()); } bool isZero(SymRef sr) const { return isIntZero(sr) or isRealZero(sr); } bool isZero(PTRef tr) const { return isZero(getSymRef(tr)); } @@ -389,6 +389,10 @@ class ArithLogic : public Logic { PTRef mkBinaryGeq(PTRef lhs, PTRef rhs) { return mkBinaryLeq(rhs, lhs); } PTRef mkBinaryLt(PTRef lhs, PTRef rhs) { return mkNot(mkBinaryGeq(lhs, rhs)); } PTRef mkBinaryGt(PTRef lhs, PTRef rhs) { return mkNot(mkBinaryLeq(lhs, rhs)); } + SymRef declareFun_Multiplication_LinNonlin(std::string const & s, SRef rsort, vec const & args) { + SymRef sr = sym_store.newInternalSymb(s.c_str(), rsort, args, SymConf::CommutativeNoScopingLeftAssoc); + return sr; + } PTRef mkBinaryEq(PTRef lhs, PTRef rhs) override; pair sumToNormalizedPair(PTRef sum); pair sumToNormalizedIntPair(PTRef sum); diff --git a/src/logics/Logic.h b/src/logics/Logic.h index b5d0cbc10..9a144ee61 100644 --- a/src/logics/Logic.h +++ b/src/logics/Logic.h @@ -186,10 +186,6 @@ class Logic { SymRef declareFun_Commutative_NoScoping_LeftAssoc(std::string const & s, SRef rsort, vec const & args) { return declareFun(s, rsort, args, SymConf::CommutativeNoScopingLeftAssoc); } - SymRef declareFun_Multiplication_Duplicate(std::string const & s, SRef rsort, vec const & args) { - SymRef sr = sym_store.newUnparsableSymb(s.c_str(), rsort, args, SymConf::CommutativeNoScopingLeftAssoc); - return sr; - } SymRef declareFun_Commutative_NoScoping_Chainable(std::string const & s, SRef rsort, vec const & args) { return declareFun(s, rsort, args, SymConf::CommutativeNoScopingChainable); } diff --git a/src/rewriters/DivModRewriter.h b/src/rewriters/DivModRewriter.h index e7c5f2495..5123fbb5c 100644 --- a/src/rewriters/DivModRewriter.h +++ b/src/rewriters/DivModRewriter.h @@ -11,6 +11,7 @@ #include "Rewriter.h" #include +#include #include namespace opensmt { @@ -36,7 +37,7 @@ class DivModConfig : public DefaultRewriterConfig { PTRef modVar = divMod.mod; PTRef rewritten = logic.isIntDiv(symRef) ? divVar : modVar; if (not inCache) { - if (logic.isNonlin(term)) { return term; } + if (!logic.isConstant(divisor)) throw NonLinException(logic.pp(term)); // collect the definitions to add assert(logic.isConstant(divisor)); auto divisorVal = logic.getNumConst(divisor); diff --git a/src/simplifiers/LA.cc b/src/simplifiers/LA.cc index ec1a51fa8..ea964bc9a 100644 --- a/src/simplifiers/LA.cc +++ b/src/simplifiers/LA.cc @@ -64,7 +64,7 @@ void LAExpression::initialize(PTRef e, bool do_canonize) { curr_const.emplace_back(std::move(new_c)); } else { // Otherwise it is a variable, Ite, UF or constant - assert(logic.isNumVarLike(t) || logic.isConstant(t) || logic.isUF(t)); + assert(logic.isMonomial(t) || logic.isConstant(t) || logic.isUF(t)); if (logic.isConstant(t)) { const Real tval = logic.getNumConst(t); polynome[PTRef_Undef] += tval * c; diff --git a/src/symbols/SymStore.h b/src/symbols/SymStore.h index b32600f00..2a5a85f40 100644 --- a/src/symbols/SymStore.h +++ b/src/symbols/SymStore.h @@ -41,12 +41,13 @@ class SymStore { SymStore & operator=(SymStore &&) = default; // Constructs a new symbol. - SymRef newSymb(char const * fname, SRef rsort, vec const & args, SymbolConfig const & symConfig, - bool subSymb = false); + SymRef newSymb(char const * fname, SRef rsort, vec const & args, SymbolConfig const & symConfig) { + return newSymb(fname, rsort, args, symConfig, false); + }; SymRef newSymb(char const * fname, SRef rsort, vec const & args) { - return newSymb(fname, rsort, args, SymConf::Default); + return newSymb(fname, rsort, args, SymConf::Default, false); } - SymRef newUnparsableSymb(char const * fname, SRef rsort, vec const & args, SymbolConfig const & symConfig) { + SymRef newInternalSymb(char const * fname, SRef rsort, vec const & args, SymbolConfig const & symConfig) { return newSymb(fname, rsort, args, symConfig, true); } bool contains(char const * fname) const { return symbolTable.has(fname); } @@ -78,6 +79,9 @@ class SymStore { vec symbols; SymbolAllocator ta{1024}; vec idToName; + + SymRef newSymb(char const * fname, SRef rsort, vec const & args, SymbolConfig const & symConfig, + bool subSymb); }; } // namespace opensmt diff --git a/src/tsolvers/lasolver/LASolver.cc b/src/tsolvers/lasolver/LASolver.cc index b4cd98124..4f1345746 100644 --- a/src/tsolvers/lasolver/LASolver.cc +++ b/src/tsolvers/lasolver/LASolver.cc @@ -11,6 +11,7 @@ #include "CutCreator.h" #include +#include #include #include @@ -91,7 +92,7 @@ void LASolver::isProperLeq(PTRef tr) assert(logic.isLeq(tr)); auto [cons, sum] = logic.leqToConstantAndTerm(tr); assert(logic.isConstant(cons)); - assert(logic.isNumVar(sum) || logic.isPlus(sum) || logic.isTimesDefined(sum) || logic.isMod(logic.getPterm(sum).symb()) || + assert(logic.isNumVar(sum) || logic.isPlus(sum) || logic.isTimesLinNonlin(sum) || logic.isMod(logic.getPterm(sum).symb()) || logic.isRealDiv(sum) || logic.isIntDiv(sum)); (void) cons; (void)sum; } @@ -287,16 +288,11 @@ std::unique_ptr LASolver::expressionToLVarPoly(PTRef term) // // Returns internalized reference for the term LVRef LASolver::registerArithmeticTerm(PTRef expr) { - if (logic.isNonlin(expr)) { - auto termStr = logic.pp(expr); - throw LANonLinearException(termStr.c_str()); - } else if(logic.isTimesLin(expr) || logic.isPlus(expr)) { + if (logic.isTimesNonlin(expr)) throw NonLinException(logic.pp(expr)); + else if(logic.isTimesLin(expr) || logic.isPlus(expr)) { Pterm const & subterms = logic.getPterm(expr); for(auto subterm: subterms) { - if (logic.isNonlin(subterm)) { - auto termStr = logic.pp(subterm); - throw LANonLinearException(termStr.c_str()); - } + if (logic.isNonlin(subterm)) throw NonLinException(logic.pp(subterm)); } } LVRef x = LVRef::Undef; diff --git a/src/tsolvers/lasolver/LASolver.h b/src/tsolvers/lasolver/LASolver.h index bad5a03fc..f10e8cab8 100644 --- a/src/tsolvers/lasolver/LASolver.h +++ b/src/tsolvers/lasolver/LASolver.h @@ -24,16 +24,6 @@ #include namespace opensmt { -class LANonLinearException : public std::runtime_error { -public: - LANonLinearException(char const * reason_) : runtime_error(reason_) { - msg = "Term " + std::string(reason_) + " is non-linear"; - } - virtual char const * what() const noexcept override { return msg.c_str(); } - -private: - std::string msg; -}; class LAVarStore; class Delta;