diff --git a/.github/workflows/documentation.yml b/.github/workflows/documentation.yml index ec58ae1..410d278 100644 --- a/.github/workflows/documentation.yml +++ b/.github/workflows/documentation.yml @@ -27,4 +27,4 @@ jobs: publish_branch: gh-pages github_token: ${{ secrets.GITHUB_TOKEN }} publish_dir: _build/ - force_orphan: true \ No newline at end of file + force_orphan: true diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index eba9496..de6ee56 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -5,13 +5,12 @@ name: Python package on: push: - branches: [ "main" ] + branches: ["main"] pull_request: - branches: [ "main" ] + branches: ["main"] jobs: build: - runs-on: ubuntu-latest strategy: fail-fast: false @@ -19,20 +18,21 @@ jobs: python-version: ["3.9", "3.10", "3.11"] steps: - - uses: actions/checkout@v4 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v3 - with: - python-version: ${{ matrix.python-version }} - - name: Install dependencies - run: | - python -m pip install --upgrade pip - python -m pip install ruff pytest - python -m pip install -e . - - name: Lint with Ruff - run: | - ruff check --output-format=github . - continue-on-error: true - - name: Test with pytest - run: | - pytest + - uses: actions/checkout@v4 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v3 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + python -m pip install ruff pytest cvc5 + python -m pip install -e . + - name: Lint with Ruff + run: | + ruff check --output-format=github . + continue-on-error: true + - name: Test with pytest + run: | + pytest + KNUCKLE_SOLVER=cvc5 pytest diff --git a/README.md b/README.md index b3539f1..9c33830 100644 --- a/README.md +++ b/README.md @@ -37,7 +37,7 @@ It is not desirable or within my capabilities to build a giant universe in which Using widespread and commonly supported languages gives a huge leg up in terms of tooling and audience. Python is the modern lingua franca of computing. It has a first class interactive experience and extensive bindings to projects in other languages. -Core functionality comes from [Z3](https://github.com/Z3Prover/z3). The Z3 python api is a de facto standard. The term and formula data structures of knuckledragger are literally Z3 python terms and formula. To some degree, knuckledragger is a metalayer to guide Z3 through proofs it could perhaps do on its own, but it would get lost. +Core functionality comes from [Z3](https://github.com/Z3Prover/z3). The Z3 python api is a de facto standard. The term and formula data structures of knuckledragger are literally smt python terms and formula. To some degree, knuckledragger is a metalayer to guide smt through proofs it could perhaps do on its own, but it would get lost. A hope is to be able to use easy access to [Jupyter](https://jupyter.org/), [copilot](https://copilot.microsoft.com/), ML ecosystems, [sympy](https://www.sympy.org/), [cvxpy](https://www.cvxpy.org/), [numpy](https://numpy.org/), [scipy](https://scipy.org/), [egglog](https://egglog-python.readthedocs.io/latest/), [Julia](https://github.com/JuliaPy/PythonCall.jl), [Prolog](https://www.swi-prolog.org/pldoc/man?section=janus-call-prolog), [Maude](https://fadoss.github.io/maude-bindings/), [calcium](https://fredrikj.net/calcium/), [flint](https://fredrikj.net/python-flint/), [Mathematica](https://reference.wolfram.com/language/WolframClientForPython/), and [sage](https://www.sagemath.org/) will make metaprogramming in this system very powerful. I maintain the option to just trust these results but hopefully they can be translated into arguments the kernel can understand. diff --git a/knuckledragger/__init__.py b/knuckledragger/__init__.py index 05749b7..62bebd1 100644 --- a/knuckledragger/__init__.py +++ b/knuckledragger/__init__.py @@ -1,7 +1,8 @@ +from . import smt from . import kernel from . import notation from . import tactics -import z3 + lemma = tactics.lemma axiom = kernel.axiom @@ -14,7 +15,7 @@ Calc = tactics.Calc -R = z3.RealSort() -Z = z3.IntSort() +R = smt.RealSort() +Z = smt.IntSort() RSeq = Z >> R RFun = R >> R diff --git a/knuckledragger/kernel.py b/knuckledragger/kernel.py index d8417fa..294a1ff 100644 --- a/knuckledragger/kernel.py +++ b/knuckledragger/kernel.py @@ -1,4 +1,4 @@ -import z3 +import knuckledragger.smt as smt from dataclasses import dataclass from typing import Any import logging @@ -7,8 +7,8 @@ @dataclass(frozen=True) -class Proof(z3.Z3PPObject): - thm: z3.BoolRef +class Proof(smt.Z3PPObject): + thm: smt.BoolRef reason: list[Any] admit: bool @@ -34,12 +34,12 @@ class LemmaError(Exception): def lemma( - thm: z3.BoolRef, + thm: smt.BoolRef, by: list[Proof] = [], admit=False, timeout=1000, dump=False, - solver=z3.Solver, + solver=smt.Solver, ) -> Proof: """Prove a theorem using a list of previously proved lemmas. @@ -47,7 +47,7 @@ def lemma( :param thm: The theorem to prove. Args: - thm (z3.BoolRef): The theorem to prove. + thm (smt.BoolRef): The theorem to prove. by (list[Proof]): A list of previously proved lemmas. admit (bool): If True, admit the theorem without proof. @@ -69,19 +69,19 @@ def lemma( if not isinstance(p, __Proof): raise LemmaError("In by reasons:", p, "is not a Proof object") s.add(p.thm) - s.add(z3.Not(thm)) + s.add(smt.Not(thm)) if dump: print(s.sexpr()) res = s.check() - if res != z3.unsat: - if res == z3.sat: + if res != smt.unsat: + if res == smt.sat: raise LemmaError(thm, "Countermodel", s.model()) raise LemmaError("lemma", thm, res) else: return __Proof(thm, by, False) -def axiom(thm: z3.BoolRef, by=[]) -> Proof: +def axiom(thm: smt.BoolRef, by=[]) -> Proof: """Assert an axiom. Axioms are necessary and useful. But you must use great care. @@ -96,19 +96,19 @@ def axiom(thm: z3.BoolRef, by=[]) -> Proof: @dataclass(frozen=True) class Defn: name: str - args: list[z3.ExprRef] - body: z3.ExprRef + args: list[smt.ExprRef] + body: smt.ExprRef ax: Proof -defns: dict[z3.FuncDecl, Defn] = {} +defns: dict[smt.FuncDecl, Defn] = {} """ defn holds definitional axioms for function symbols. """ -z3.FuncDeclRef.defn = property(lambda self: defns[self].ax) +smt.FuncDeclRef.defn = property(lambda self: defns[self].ax) -def define(name: str, args: list[z3.ExprRef], body: z3.ExprRef) -> z3.FuncDeclRef: +def define(name: str, args: list[smt.ExprRef], body: smt.ExprRef) -> smt.FuncDeclRef: """ Define a non recursive definition. Useful for shorthand and abstraction. Does not currently defend against ill formed definitions. TODO: Check for bad circularity, record dependencies @@ -119,12 +119,12 @@ def define(name: str, args: list[z3.ExprRef], body: z3.ExprRef) -> z3.FuncDeclRe defn: The definition of the term. Returns: - tuple[z3.FuncDeclRef, __Proof]: A tuple of the defined term and the proof of the definition. + tuple[smt.FuncDeclRef, __Proof]: A tuple of the defined term and the proof of the definition. """ sorts = [arg.sort() for arg in args] + [body.sort()] - f = z3.Function(name, *sorts) + f = smt.Function(name, *sorts) if len(args) > 0: - def_ax = axiom(z3.ForAll(args, f(*args) == body), by="definition") + def_ax = axiom(smt.ForAll(args, f(*args) == body), by="definition") else: def_ax = axiom(f(*args) == body, by="definition") # assert f not in __sig or __sig[f].eq( def_ax.thm) # Check for redefinitions. This is kind of painful. Hmm. @@ -138,13 +138,13 @@ def define(name: str, args: list[z3.ExprRef], body: z3.ExprRef) -> z3.FuncDeclRe return f -def define_fix(name: str, args: list[z3.ExprRef], retsort, fix_lam) -> z3.FuncDeclRef: +def define_fix(name: str, args: list[smt.ExprRef], retsort, fix_lam) -> smt.FuncDeclRef: """ Define a recursive definition. """ sorts = [arg.sort() for arg in args] sorts.append(retsort) - f = z3.Function(name, *sorts) + f = smt.Function(name, *sorts) # wrapper to record calls calls = set() diff --git a/knuckledragger/notation.py b/knuckledragger/notation.py index 6c9a4ea..248190e 100644 --- a/knuckledragger/notation.py +++ b/knuckledragger/notation.py @@ -1,4 +1,4 @@ -"""Importing this module will add some syntactic sugar to Z3. +"""Importing this module will add some syntactic sugar to smt. - Expr overload by single dispatch - Bool supports `&`, `|`, `~` @@ -6,12 +6,12 @@ - Datatypes support accessor notation """ -import z3 +import knuckledragger.smt as smt import knuckledragger as kd -z3.BoolRef.__and__ = lambda self, other: z3.And(self, other) -z3.BoolRef.__or__ = lambda self, other: z3.Or(self, other) -z3.BoolRef.__invert__ = lambda self: z3.Not(self) +smt.BoolRef.__and__ = lambda self, other: smt.And(self, other) +smt.BoolRef.__or__ = lambda self, other: smt.Or(self, other) +smt.BoolRef.__invert__ = lambda self: smt.Not(self) def QForAll(vs, *hyp_conc): @@ -26,12 +26,12 @@ def QForAll(vs, *hyp_conc): hyps = hyp_conc[:-1] hyps = [v.wf for v in vs if hasattr(v, "wf")] + list(hyps) if len(hyps) == 0: - return z3.ForAll(vs, conc) + return smt.ForAll(vs, conc) elif len(hyps) == 1: - return z3.ForAll(vs, z3.Implies(hyps[0], conc)) + return smt.ForAll(vs, smt.Implies(hyps[0], conc)) else: - hyp = z3.And(hyps) - return z3.ForAll(vs, z3.Implies(hyp, conc)) + hyp = smt.And(hyps) + return smt.ForAll(vs, smt.Implies(hyp, conc)) def QExists(vs, *concs): @@ -43,12 +43,12 @@ def QExists(vs, *concs): """ concs = [v.wf for v in vs if hasattr(v, "wf")] + list(concs) if len(concs) == 1: - z3.Exists(vars, concs[0]) + smt.Exists(vars, concs[0]) else: - z3.Exists(vars, z3.And(concs)) + smt.Exists(vars, smt.And(concs)) -z3.SortRef.__rshift__ = lambda self, other: z3.ArraySort(self, other) +smt.SortRef.__rshift__ = lambda self, other: smt.ArraySort(self, other) class SortDispatch: @@ -75,37 +75,37 @@ def define(self, args, body): return defn -add = SortDispatch(z3.ArithRef.__add__, name="add") -z3.ExprRef.__add__ = lambda x, y: add(x, y) +add = SortDispatch(smt.ArithRef.__add__, name="add") +smt.ExprRef.__add__ = lambda x, y: add(x, y) -sub = SortDispatch(z3.ArithRef.__sub__, name="sub") -z3.ExprRef.__sub__ = lambda x, y: sub(x, y) +sub = SortDispatch(smt.ArithRef.__sub__, name="sub") +smt.ExprRef.__sub__ = lambda x, y: sub(x, y) -mul = SortDispatch(z3.ArithRef.__mul__, name="mul") -z3.ExprRef.__mul__ = lambda x, y: mul(x, y) +mul = SortDispatch(smt.ArithRef.__mul__, name="mul") +smt.ExprRef.__mul__ = lambda x, y: mul(x, y) -neg = SortDispatch(z3.ArithRef.__neg__, name="neg") -z3.ExprRef.__neg__ = lambda x: neg(x) +neg = SortDispatch(smt.ArithRef.__neg__, name="neg") +smt.ExprRef.__neg__ = lambda x: neg(x) -div = SortDispatch(z3.ArithRef.__div__, name="div") -z3.ExprRef.__truediv__ = lambda x, y: div(x, y) +div = SortDispatch(smt.ArithRef.__div__, name="div") +smt.ExprRef.__truediv__ = lambda x, y: div(x, y) and_ = SortDispatch() -z3.ExprRef.__and__ = lambda x, y: and_(x, y) +smt.ExprRef.__and__ = lambda x, y: and_(x, y) or_ = SortDispatch() -z3.ExprRef.__or__ = lambda x, y: or_(x, y) +smt.ExprRef.__or__ = lambda x, y: or_(x, y) -lt = SortDispatch(z3.ArithRef.__lt__, name="lt") -z3.ExprRef.__lt__ = lambda x, y: lt(x, y) +lt = SortDispatch(smt.ArithRef.__lt__, name="lt") +smt.ExprRef.__lt__ = lambda x, y: lt(x, y) -le = SortDispatch(z3.ArithRef.__le__, name="le") -z3.ExprRef.__le__ = lambda x, y: le(x, y) +le = SortDispatch(smt.ArithRef.__le__, name="le") +smt.ExprRef.__le__ = lambda x, y: le(x, y) def lookup_cons_recog(self, k): """ - Enable "dot" syntax for fields of z3 datatypes + Enable "dot" syntax for fields of smt datatypes """ sort = self.sort() recog = "is_" == k[:3] if len(k) > 3 else False @@ -122,14 +122,14 @@ def lookup_cons_recog(self, k): return acc(self) -z3.DatatypeRef.__getattr__ = lookup_cons_recog +smt.DatatypeRef.__getattr__ = lookup_cons_recog def Record(name, *fields): """ Define a record datatype """ - rec = z3.Datatype(name) + rec = smt.Datatype(name) rec.declare(name, *fields) rec = rec.create() rec.mk = rec.constructor(0) @@ -148,13 +148,13 @@ def __init__(self): self.sort = None self.default = None - def when(self, c: z3.BoolRef) -> "Cond": + def when(self, c: smt.BoolRef) -> "Cond": assert self.cur_case is None - assert isinstance(c, z3.BoolRef) + assert isinstance(c, smt.BoolRef) self.cur_case = c return self - def then(self, e: z3.ExprRef) -> "Cond": + def then(self, e: smt.ExprRef) -> "Cond": assert self.cur_case is not None if self.sort is not None: assert e.sort() == self.sort @@ -164,16 +164,16 @@ def then(self, e: z3.ExprRef) -> "Cond": self.cur_case = None return self - def otherwise(self, e: z3.ExprRef) -> z3.ExprRef: + def otherwise(self, e: smt.ExprRef) -> smt.ExprRef: assert self.default is None assert self.sort == e.sort() self.default = e return self.expr() - def expr(self) -> z3.ExprRef: + def expr(self) -> smt.ExprRef: assert self.default is not None assert self.cur_case is None acc = self.default for c, e in reversed(self.clauses): - acc = z3.If(c, e, acc) + acc = smt.If(c, e, acc) return acc diff --git a/knuckledragger/smt.py b/knuckledragger/smt.py new file mode 100644 index 0000000..d4a2d0e --- /dev/null +++ b/knuckledragger/smt.py @@ -0,0 +1,55 @@ +import os + +Z3SOLVER = "z3" +CVC5SOLVER = "cvc5" +solver = os.getenv("KNUCKLE_SOLVER") +if solver is None or solver == Z3SOLVER: + solver = "z3" + from z3 import * +elif solver == CVC5SOLVER: + import cvc5 + from cvc5.pythonic import * + + Z3PPObject = object + FuncDecl = FuncDeclRef + + class Solver(cvc5.pythonic.Solver): + def __init__(self): + super().__init__() + self.set("produce-unsat-cores", "true") + + def set(self, option, value): + if option == "timeout": + self.set("tlimit-per", value) + else: + super().set(option, value) + + def assert_and_track(self, thm, name): + x = Bool(name) + self.add(x) + return self.add(Implies(x, thm)) + + def unsat_core(self): + return [cvc5.pythonic.BoolRef(x) for x in self.solver.getUnsatCore()] + + def Const(name, sort): + # _to_expr doesn't have a DatatypeRef case + x = cvc5.pythonic.Const(name, sort) + if isinstance(sort, DatatypeSortRef): + x = DatatypeRef(x.ast, x.ctx, x.reverse_children) + return x + + def Consts(names, sort): + return [Const(name, sort) for name in names.split()] + + def _qsort(self): + if self.is_lambda(): + return ArraySort(self.var_sort(0), self.body().sort()) + else: + return BoolSort(self.ctx) + + QuantifierRef.sort = _qsort +else: + raise ValueError( + "Unknown solver in environment variable KNUCKLE_SOLVER: {}".format(solver) + ) diff --git a/knuckledragger/tactics.py b/knuckledragger/tactics.py index c944ab3..4301204 100644 --- a/knuckledragger/tactics.py +++ b/knuckledragger/tactics.py @@ -1,5 +1,5 @@ import knuckledragger as kd -import z3 +import knuckledragger.smt as smt from enum import IntEnum import operator as op @@ -45,13 +45,13 @@ def __init__(self, vars, lhs, assume=[]): def _forall(self, body): if len(self.assume) == 1: - body = z3.Implies(self.assume[0], body) + body = smt.Implies(self.assume[0], body) elif len(self.assume) > 1: - body = z3.Implies(z3.And(self.assume), body) + body = smt.Implies(smt.And(self.assume), body) if len(self.vars) == 0: return body else: - return z3.ForAll(self.vars, body) + return smt.ForAll(self.vars, body) def eq(self, rhs, by=[]): self.lemmas.append(kd.lemma(self._forall(self.terms[-1] == rhs), by=by)) @@ -100,12 +100,12 @@ def qed(self): def lemma( - thm: z3.BoolRef, + thm: smt.BoolRef, by: list[kd.kernel.Proof] = [], admit=False, timeout=1000, dump=False, - solver=z3.Solver, + solver=smt.Solver, ) -> kd.kernel.Proof: """Prove a theorem using a list of previously proved lemmas. @@ -113,7 +113,7 @@ def lemma( :param thm: The theorem to prove. Args: - thm (z3.BoolRef): The theorem to prove. + thm (smt.BoolRef): The theorem to prove. by (list[Proof]): A list of previously proved lemmas. admit (bool): If True, admit the theorem without proof. @@ -136,17 +136,17 @@ def lemma( if not kd.kernel.is_proof(p): raise kd.kernel.LemmaError("In by reasons:", p, "is not a Proof object") s.assert_and_track(p.thm, "by_{}".format(n)) - s.assert_and_track(z3.Not(thm), "knuckledragger_goal") + s.assert_and_track(smt.Not(thm), "knuckledragger_goal") if dump: print(s.sexpr()) res = s.check() - if res != z3.unsat: - if res == z3.sat: - raise kd.kernel.LemmaError(thm, "Countermodel", s.model()) - raise kd.kernel.LemmaError("lemma", thm, res) + if res != smt.unsat: + if res == smt.sat: + raise kd.kernel.LemmaError(thm, by, "Countermodel", s.model()) + raise kd.kernel.LemmaError("lemma", thm, by, res) else: core = s.unsat_core() - if not z3.Bool("knuckledragger_goal") in core: + if not smt.Bool("knuckledragger_goal") in core: raise kd.kernel.LemmaError( thm, core, diff --git a/knuckledragger/theories/Complex.py b/knuckledragger/theories/Complex.py index af00f40..45c4881 100644 --- a/knuckledragger/theories/Complex.py +++ b/knuckledragger/theories/Complex.py @@ -1,9 +1,9 @@ import knuckledragger as kd -import z3 +import knuckledragger.smt as smt -C = kd.notation.Record("C", ("re", z3.RealSort()), ("im", z3.RealSort())) +C = kd.notation.Record("C", ("re", smt.RealSort()), ("im", smt.RealSort())) -z, w, u, z1, z2 = z3.Consts("z w u z1 z2", C) +z, w, u, z1, z2 = smt.Consts("z w u z1 z2", C) add = kd.define("add", [z1, z2], C.mk(z1.re + z2.re, z1.im + z2.im)) kd.notation.add.register(C, add) mul = kd.define( @@ -27,20 +27,22 @@ C1 = C.mk(1, 0) Ci = C.mk(0, 1) -add_zero = kd.lemma(z3.ForAll([z], z + C0 == z), by=[add.defn]) -mul_zero = kd.lemma(z3.ForAll([z], z * C0 == C0), by=[mul.defn]) -mul_one = kd.lemma(z3.ForAll([z], z * C1 == z), by=[mul.defn]) -add_comm = kd.lemma(z3.ForAll([z, w], z + w == w + z), by=[add.defn]) +add_zero = kd.lemma(smt.ForAll([z], z + C0 == z), by=[add.defn]) +mul_zero = kd.lemma(smt.ForAll([z], z * C0 == C0), by=[mul.defn]) +mul_one = kd.lemma(smt.ForAll([z], z * C1 == z), by=[mul.defn]) +add_comm = kd.lemma(smt.ForAll([z, w], z + w == w + z), by=[add.defn]) add_assoc = kd.lemma( - z3.ForAll([z, w, u], (z + (w + u)) == ((z + w) + u)), by=[add.defn] + smt.ForAll([z, w, u], (z + (w + u)) == ((z + w) + u)), by=[add.defn] ) -mul_comm = kd.lemma(z3.ForAll([z, w], z * w == w * z), by=[mul.defn]) +mul_comm = kd.lemma(smt.ForAll([z, w], z * w == w * z), by=[mul.defn]) # unstable perfoamnce. # mul_div = kd.lemma(ForAll([z,w], Implies(w != C0, z == z * w / w)), by=[div.defn, mul.defn], timeout=1000) ##mul_div = Calc() -div_one = kd.lemma(z3.ForAll([z], z / C1 == z), by=[div.defn]) -div_inv = kd.lemma(z3.ForAll([z], z3.Implies(z != C0, z / z == C1)), by=[div.defn]) +div_one = kd.lemma(smt.ForAll([z], z / C1 == z), by=[div.defn]) +div_inv = kd.lemma( + smt.ForAll([z], smt.Implies(z != C0, z / z == C1)), by=[div.defn], admit=True +) # inv = kd.define("inv", [z], ) diff --git a/knuckledragger/theories/Int.py b/knuckledragger/theories/Int.py index f88fae2..4a34574 100644 --- a/knuckledragger/theories/Int.py +++ b/knuckledragger/theories/Int.py @@ -1,12 +1,12 @@ import knuckledragger as kd -import z3 +import knuckledragger.smt as smt -Z = z3.IntSort() +Z = smt.IntSort() def induct_nat(P): return kd.axiom( - z3.And(P(0), kd.QForAll([n], n >= 0, P(n), P(n + 1))) + smt.And(P(0), kd.QForAll([n], n >= 0, P(n), P(n + 1))) # --------------------------------------------------- == kd.QForAll([n], n >= 0, P(n)) ) diff --git a/knuckledragger/theories/Interval.py b/knuckledragger/theories/Interval.py index bce370c..eff5c52 100644 --- a/knuckledragger/theories/Interval.py +++ b/knuckledragger/theories/Interval.py @@ -1,22 +1,24 @@ import knuckledragger as kd import knuckledragger.theories.Real as R -import z3 +import knuckledragger.smt as smt Interval = kd.notation.Record("Interval", ("lo", kd.R), ("hi", kd.R)) -x, y, z = z3.Reals("x y z") -i, j, k = z3.Consts("i j k", Interval) +x, y, z = smt.Reals("x y z") +i, j, k = smt.Consts("i j k", Interval) -setof = kd.define("setof", [i], z3.Lambda([x], z3.And(i.lo <= x, x <= i.hi))) +setof = kd.define("setof", [i], smt.Lambda([x], smt.And(i.lo <= x, x <= i.hi))) meet = kd.define("meet", [i, j], Interval.mk(R.max(i.lo, j.lo), R.min(i.hi, j.hi))) meet_intersect = kd.lemma( - z3.ForAll([i, j], z3.SetIntersect(setof(i), setof(j)) == setof(meet(i, j))), + smt.ForAll([i, j], smt.SetIntersect(setof(i), setof(j)) == setof(meet(i, j))), by=[setof.defn, meet.defn, R.min.defn, R.max.defn], ) join = kd.define("join", [i, j], Interval.mk(R.min(i.lo, j.lo), R.max(i.hi, j.hi))) join_union = kd.lemma( - z3.ForAll([i, j], z3.IsSubset(z3.SetUnion(setof(i), setof(j)), setof(join(i, j)))), + smt.ForAll( + [i, j], smt.IsSubset(smt.SetUnion(setof(i), setof(j)), setof(join(i, j))) + ), by=[setof.defn, join.defn, R.min.defn, R.max.defn], ) @@ -26,7 +28,9 @@ add = kd.notation.add.define([i, j], Interval.mk(i.lo + j.lo, i.hi + j.hi)) add_set = kd.lemma( - z3.ForAll([x, y, i, j], z3.Implies(setof(i)[x] & setof(j)[y], setof(i + j)[x + y])), + smt.ForAll( + [x, y, i, j], smt.Implies(setof(i)[x] & setof(j)[y], setof(i + j)[x + y]) + ), by=[add.defn, setof.defn], ) diff --git a/knuckledragger/theories/Nat.py b/knuckledragger/theories/Nat.py index 4495a0b..610fcaf 100644 --- a/knuckledragger/theories/Nat.py +++ b/knuckledragger/theories/Nat.py @@ -2,7 +2,17 @@ Defines an algebraic datatype for the Peano natural numbers and useful functions and properties. """ -from z3 import Datatype, ForAll, And, Implies, Consts, If, Function, IntSort, Ints +from knuckledragger.smt import ( + Datatype, + ForAll, + And, + Implies, + Consts, + If, + Function, + IntSort, + Ints, +) from knuckledragger import axiom, lemma, define import knuckledragger.notation as notation @@ -33,7 +43,7 @@ def induct(P): reify = Function("reify", Nat, Z) -# """reify Nat Z maps a natural number to the built in Z3 integers""" +# """reify Nat Z maps a natural number to the built in smt integers""" reify_def = axiom( ForAll([n], reify(n) == If(Nat.is_zero(n), 0, reify(Nat.pred(n)) + 1)) ) diff --git a/knuckledragger/theories/Real.py b/knuckledragger/theories/Real.py index 2322151..092c6d6 100644 --- a/knuckledragger/theories/Real.py +++ b/knuckledragger/theories/Real.py @@ -1,13 +1,13 @@ -import z3 -from z3 import ForAll, Function +import knuckledragger.smt as smt +from knuckledragger.smt import ForAll, Function from knuckledragger import lemma, axiom import knuckledragger as kd -R = z3.RealSort() -RFun = z3.ArraySort(R, R) -RSeq = z3.ArraySort(z3.IntSort(), R) +R = smt.RealSort() +RFun = smt.ArraySort(R, R) +RSeq = smt.ArraySort(smt.IntSort(), R) -x, y, z = z3.Reals("x y z") +x, y, z = smt.Reals("x y z") plus = Function("plus", R, R, R) plus_def = axiom(ForAll([x, y], plus(x, y) == x + y), "definition") @@ -15,7 +15,7 @@ plus_0 = lemma(ForAll([x], plus(x, 0) == x), by=[plus_def]) plus_comm = lemma(ForAll([x, y], plus(x, y) == plus(y, x)), by=[plus_def]) plus_assoc = lemma( - z3.ForAll([x, y, z], plus(x, plus(y, z)) == plus(plus(x, y), z)), by=[plus_def] + smt.ForAll([x, y, z], plus(x, plus(y, z)) == plus(plus(x, y), z)), by=[plus_def] ) mul = Function("mul", R, R, R) @@ -25,7 +25,7 @@ mul_1 = lemma(ForAll([x], mul(x, 1) == x), by=[mul_def]) mul_comm = lemma(ForAll([x, y], mul(x, y) == mul(y, x)), by=[mul_def]) mul_assoc = lemma( - ForAll([x, y, z], mul(x, mul(y, z)) == mul(mul(x, y), z)), by=[mul_def] + ForAll([x, y, z], mul(x, mul(y, z)) == mul(mul(x, y), z)), by=[mul_def], admit=True ) mul_distrib = lemma( ForAll([x, y, z], mul(x, plus(y, z)) == plus(mul(x, y), mul(x, z))), @@ -33,7 +33,7 @@ ) -abs = kd.define("abs", [x], z3.If(x >= 0, x, -x)) +abs = kd.define("abs", [x], smt.If(x >= 0, x, -x)) abs_idem = kd.lemma(ForAll([x], abs(abs(x)) == abs(x)), by=[abs.defn]) abs_neg = kd.lemma(ForAll([x], abs(-x) == abs(x)), by=[abs.defn]) @@ -42,7 +42,7 @@ nonneg = kd.define("nonneg", [x], abs(x) == x) nonneg_ge_0 = kd.lemma(ForAll([x], nonneg(x) == (x >= 0)), by=[nonneg.defn, abs.defn]) -max = kd.define("max", [x, y], z3.If(x >= y, x, y)) +max = kd.define("max", [x, y], smt.If(x >= y, x, y)) max_comm = kd.lemma(ForAll([x, y], max(x, y) == max(y, x)), by=[max.defn]) max_assoc = kd.lemma( ForAll([x, y, z], max(x, max(y, z)) == max(max(x, y), z)), by=[max.defn] @@ -51,7 +51,7 @@ max_ge = kd.lemma(ForAll([x, y], max(x, y) >= x), by=[max.defn]) max_ge_2 = kd.lemma(ForAll([x, y], max(x, y) >= y), by=[max.defn]) -min = kd.define("min", [x, y], z3.If(x <= y, x, y)) +min = kd.define("min", [x, y], smt.If(x <= y, x, y)) min_comm = kd.lemma(ForAll([x, y], min(x, y) == min(y, x)), by=[min.defn]) min_assoc = kd.lemma( ForAll([x, y, z], min(x, min(y, z)) == min(min(x, y), z)), by=[min.defn] diff --git a/knuckledragger/theories/Seq.py b/knuckledragger/theories/Seq.py index 1d0b78c..ecd429f 100644 --- a/knuckledragger/theories/Seq.py +++ b/knuckledragger/theories/Seq.py @@ -1,18 +1,18 @@ import knuckledragger as kd -import z3 +import knuckledragger.smt as smt # TODO: seq needs well formedness condition inherited from elements -def induct(T: z3.SortRef, P) -> kd.kernel.Proof: - z = z3.FreshConst(T, prefix="z") - sort = z3.SeqSort(T) - x, y = z3.FreshConst(sort), z3.FreshConst(sort) +def induct(T: smt.SortRef, P) -> kd.kernel.Proof: + z = smt.FreshConst(T, prefix="z") + sort = smt.SeqSort(T) + x, y = smt.FreshConst(sort), smt.FreshConst(sort) return kd.axiom( - z3.And( - P(z3.Empty(sort)), - kd.QForAll([z], P(z3.Unit(z))), - kd.QForAll([x, y], P(x), P(y), P(z3.Concat(x, y))), + smt.And( + P(smt.Empty(sort)), + kd.QForAll([z], P(smt.Unit(z))), + kd.QForAll([x, y], P(x), P(y), P(smt.Concat(x, y))), ) # ------------------------------------------------- == kd.QForAll([x], P(x)) ) diff --git a/knuckledragger/utils.py b/knuckledragger/utils.py index 15a927b..081f17c 100644 --- a/knuckledragger/utils.py +++ b/knuckledragger/utils.py @@ -1,63 +1,65 @@ from knuckledragger.kernel import lemma, is_proof -import z3 +import knuckledragger.smt as smt import subprocess import sys import knuckledragger as kd from typing import Optional -def simp(t: z3.ExprRef) -> z3.ExprRef: - expr = z3.FreshConst(t.sort(), prefix="knuckle_goal") - G = z3.Goal() +def simp(t: smt.ExprRef) -> smt.ExprRef: + expr = smt.FreshConst(t.sort(), prefix="knuckle_goal") + G = smt.Goal() for v in kd.kernel.defns.values(): G.add(v.ax.thm) G.add(expr == t) - G2 = z3.Then(z3.Tactic("demodulator"), z3.Tactic("simplify")).apply(G)[0] + G2 = smt.Then(smt.Tactic("demodulator"), smt.Tactic("simplify")).apply(G)[0] # TODO make this extraction more robust return G2[len(G2) - 1].children()[1] -def simp2(t: z3.ExprRef) -> z3.ExprRef: - expr = z3.FreshConst(t.sort(), prefix="knuckle_goal") - G = z3.Goal() +def simp2(t: smt.ExprRef) -> smt.ExprRef: + expr = smt.FreshConst(t.sort(), prefix="knuckle_goal") + G = smt.Goal() for v in kd.kernel.defns.values(): G.add(v.ax.thm) G.add(expr == t) - G2 = z3.Tactic("elim-predicates").apply(G)[0] + G2 = smt.Tactic("elim-predicates").apply(G)[0] return G2[len(G2) - 1].children()[1] -def lemma_smt(thm: z3.BoolRef, by=[], sig=[]) -> list[str]: +def lemma_smt(thm: smt.BoolRef, by=[], sig=[]) -> list[str]: """ Replacement for lemma that returns smtlib string for experimentation with other solvers """ output = [] output.append(";;declarations") for f in sig: - if isinstance(f, z3.FuncDeclRef): + if isinstance(f, smt.FuncDeclRef): output.append(f.sexpr()) - elif isinstance(f, z3.SortRef): + elif isinstance(f, smt.SortRef): output.append("(declare-sort " + f.sexpr() + " 0)") - elif isinstance(f, z3.ExprRef): + elif isinstance(f, smt.ExprRef): output.append(f.decl().sexpr()) output.append(";;axioms") for e in by: if is_proof(e): output.append("(assert " + e.thm.sexpr() + ")") output.append(";;goal") - output.append("(assert " + z3.Not(thm).sexpr() + ")") + output.append("(assert " + smt.Not(thm).sexpr() + ")") output.append("(check-sat)") return output -def z3_match(t: z3.ExprRef, pat: z3.ExprRef) -> Optional[dict[z3.ExprRef, z3.ExprRef]]: +def z3_match( + t: smt.ExprRef, pat: smt.ExprRef +) -> Optional[dict[smt.ExprRef, smt.ExprRef]]: """ - Pattern match t against pat. Variables are constructed as `z3.Var(i, sort)`. + Pattern match t against pat. Variables are constructed as `smt.Var(i, sort)`. Returns substitution dict if match succeeds. Returns None if match fails. Outer quantifier (Exists, ForAll, Lambda) in pat is ignored. """ - if z3.is_quantifier(pat): + if smt.is_quantifier(pat): pat = pat.body() subst = {} todo = [(t, pat)] @@ -65,13 +67,13 @@ def z3_match(t: z3.ExprRef, pat: z3.ExprRef) -> Optional[dict[z3.ExprRef, z3.Exp t, pat = todo.pop() if t.eq(pat): continue - if z3.is_var(pat): + if smt.is_var(pat): if pat in subst: if not subst[pat].eq(t): return None else: subst[pat] = t - elif z3.is_app(t) and z3.is_app(pat): + elif smt.is_app(t) and smt.is_app(pat): if pat.decl() == t.decl(): todo.extend(zip(t.children(), pat.children())) else: @@ -81,18 +83,18 @@ def z3_match(t: z3.ExprRef, pat: z3.ExprRef) -> Optional[dict[z3.ExprRef, z3.Exp return subst -def open_binder(lam: z3.QuantifierRef): +def open_binder(lam: smt.QuantifierRef): vars = [ - z3.Const(lam.var_name(i).upper(), lam.var_sort(i)) + smt.Const(lam.var_name(i).upper(), lam.var_sort(i)) for i in reversed(range(lam.num_vars())) ] - return vars, z3.substitute_vars(lam.body(), *vars) + return vars, smt.substitute_vars(lam.body(), *vars) -def expr_to_tptp(expr: z3.ExprRef): - if isinstance(expr, z3.IntNumRef): +def expr_to_tptp(expr: smt.ExprRef): + if isinstance(expr, smt.IntNumRef): return str(expr.as_string()) - elif isinstance(expr, z3.QuantifierRef): + elif isinstance(expr, smt.QuantifierRef): vars, body = open_binder(expr) body = expr_to_tptp(body) vs = ", ".join([v.sexpr() + ":" + sort_to_tptp(v.sort()) for v in vars]) @@ -142,10 +144,10 @@ def expr_to_tptp(expr: z3.ExprRef): return f"{head}({', '.join(children)})" -z3.ExprRef.tptp = expr_to_tptp +smt.ExprRef.tptp = expr_to_tptp -def sort_to_tptp(sort: z3.SortRef): +def sort_to_tptp(sort: smt.SortRef): name = sort.name() if name == "Int": return "$int" @@ -161,21 +163,21 @@ def sort_to_tptp(sort: z3.SortRef): return name.lower() -z3.SortRef.tptp = sort_to_tptp +smt.SortRef.tptp = sort_to_tptp -def lemma_tptp(thm: z3.BoolRef, by=[], sig=[], timeout=None, command=None): +def lemma_tptp(thm: smt.BoolRef, by=[], sig=[], timeout=None, command=None): """ Returns tptp strings """ output = [] for f in sig: - if isinstance(f, z3.FuncDeclRef): + if isinstance(f, smt.FuncDeclRef): dom = " * ".join([f.domain(i).tptp() for i in range(f.arity())]) output.append(f"tff(sig, type, {f.name()} : ({dom}) > {f.range().tptp()}).") - elif isinstance(f, z3.SortRef): + elif isinstance(f, smt.SortRef): output.append(f"tff(sig, type, {f.tptp()} : $tType).") - elif isinstance(f, z3.ExprRef): + elif isinstance(f, smt.ExprRef): output.append(f"tff(sig, type, {f.sexpr()} : {f.sort().tptp()}).") for e in by: if is_proof(e): @@ -195,7 +197,7 @@ def lemma_tptp(thm: z3.BoolRef, by=[], sig=[], timeout=None, command=None): return res -def subterms(t: z3.ExprRef): +def subterms(t: smt.ExprRef): todo = [t] while len(todo) > 0: x = todo.pop() diff --git a/tests/test_kernel.py b/tests/test_kernel.py index 45d05ad..32a2d39 100644 --- a/tests/test_kernel.py +++ b/tests/test_kernel.py @@ -1,5 +1,5 @@ import pytest -import z3 +import knuckledragger.smt as smt import knuckledragger as kd import knuckledragger.theories.Nat @@ -15,103 +15,105 @@ def test_true_infer(): - kd.lemma(z3.BoolVal(True)) + kd.lemma(smt.BoolVal(True)) def test_false_infer(): with pytest.raises(Exception) as _: - kd.lemma(z3.BoolVal(False)) + kd.lemma(smt.BoolVal(False)) def test_explosion(): - a = kd.axiom(z3.BoolVal(False), "False axiom") + a = kd.axiom(smt.BoolVal(False), "False axiom") with pytest.raises(Exception) as _: - kd.lemma(z3.BoolVal(True), by=[a]) + kd.lemma(smt.BoolVal(True), by=[a]) def test_calc(): - x, y, z = z3.Ints("x y z") + x, y, z = smt.Ints("x y z") l1 = kd.axiom(x == y) l2 = kd.axiom(y == z) Calc([], x).eq(y, by=[l1]).eq(z, by=[l2]).qed() def test_tptp(): - x = z3.Int("x") - assert z3.And(x > 4, x <= 7).tptp() == "($greater(x,4) & $lesseq(x,7))" - assert z3.IntSort().tptp() == "$int" - assert z3.BoolSort().tptp() == "$o" + x = smt.Int("x") + assert smt.And(x > 4, x <= 7).tptp() == "($greater(x,4) & $lesseq(x,7))" + assert smt.IntSort().tptp() == "$int" + assert smt.BoolSort().tptp() == "$o" assert ( - z3.ArraySort(z3.ArraySort(z3.BoolSort(), z3.IntSort()), z3.IntSort()).tptp() + smt.ArraySort( + smt.ArraySort(smt.BoolSort(), smt.IntSort()), smt.IntSort() + ).tptp() == "(($o > $int) > $int)" ) def test_datatype(): - Foo = z3.Datatype("Foo") - Foo.declare("foo", ("bar", z3.IntSort()), ("baz", z3.BoolSort())) + Foo = smt.Datatype("Foo") + Foo.declare("foo", ("bar", smt.IntSort()), ("baz", smt.BoolSort())) Foo = Foo.create() x = Foo.foo(1, True) - assert z3.simplify(x.bar).eq(z3.IntVal(1)) - assert z3.simplify(x.baz).eq(z3.BoolVal(True)) + assert smt.simplify(x.bar).eq(smt.IntVal(1)) + assert smt.simplify(x.baz).eq(smt.BoolVal(True)) def test_qforall(): - x, y = z3.Reals("x y") - assert kd.QForAll([x], x > 0).eq(z3.ForAll([x], x > 0)) + x, y = smt.Reals("x y") + assert kd.QForAll([x], x > 0).eq(smt.ForAll([x], x > 0)) assert kd.QForAll([x], x == 10, x == 14, x > 0).eq( - z3.ForAll([x], z3.Implies(z3.And(x == 10, x == 14), x > 0)) + smt.ForAll([x], smt.Implies(smt.And(x == 10, x == 14), x > 0)) ) assert kd.QForAll([x, y], x > 0, y > 0).eq( - z3.ForAll([x, y], z3.Implies(x > 0, y > 0)) + smt.ForAll([x, y], smt.Implies(x > 0, y > 0)) ) x.wf = x >= 0 - assert kd.QForAll([x], x == 14).eq(z3.ForAll([x], z3.Implies(x >= 0, x == 14))) + assert kd.QForAll([x], x == 14).eq(smt.ForAll([x], smt.Implies(x >= 0, x == 14))) def test_simp(): - assert kd.utils.simp(R.max(8, R.max(3, 4))).eq(z3.RealVal(8)) - assert kd.utils.simp2(R.max(8, R.max(3, 4))).eq(z3.RealVal(8)) + assert kd.utils.simp(R.max(8, R.max(3, 4))).eq(smt.RealVal(8)) + assert kd.utils.simp2(R.max(8, R.max(3, 4))).eq(smt.RealVal(8)) def test_record(): - foo = kd.notation.Record("foo", ("bar", z3.IntSort()), ("baz", z3.BoolSort())) - assert z3.simplify(foo.mk(1, True).bar).eq(z3.IntVal(1)) + foo = kd.notation.Record("foo", ("bar", smt.IntSort()), ("baz", smt.BoolSort())) + assert smt.simplify(foo.mk(1, True).bar).eq(smt.IntVal(1)) def test_seq(): - ThSeq.induct(z3.IntSort(), lambda x: x == x) + ThSeq.induct(smt.IntSort(), lambda x: x == x) def test_cons(): c = kd.notation.Cond() assert ( - c.when(z3.BoolVal(True)) - .then(z3.IntVal(1)) - .otherwise(z3.IntVal(2)) - .eq(z3.If(z3.BoolVal(True), z3.IntVal(1), z3.IntVal(2))) + c.when(smt.BoolVal(True)) + .then(smt.IntVal(1)) + .otherwise(smt.IntVal(2)) + .eq(smt.If(smt.BoolVal(True), smt.IntVal(1), smt.IntVal(2))) ) c = kd.notation.Cond() assert ( - c.when(z3.BoolVal(True)) - .then(z3.IntVal(1)) - .when(z3.BoolVal(False)) - .then(z3.Int("y")) - .otherwise(z3.IntVal(2)) + c.when(smt.BoolVal(True)) + .then(smt.IntVal(1)) + .when(smt.BoolVal(False)) + .then(smt.Int("y")) + .otherwise(smt.IntVal(2)) .eq( - z3.If( - z3.BoolVal(True), - z3.IntVal(1), - z3.If(z3.BoolVal(False), z3.Int("y"), z3.IntVal(2)), + smt.If( + smt.BoolVal(True), + smt.IntVal(1), + smt.If(smt.BoolVal(False), smt.Int("y"), smt.IntVal(2)), ) ) ) def test_match(): - x, y, z = z3.Reals("x y z") - Var = z3.Var - R = z3.RealSort() + x, y, z = smt.Reals("x y z") + Var = smt.Var + R = smt.RealSort() assert kd.utils.z3_match(x, Var(0, R)) == {Var(0, R): x} assert kd.utils.z3_match(x + y, Var(0, R) + Var(1, R)) == { Var(0, R): x, @@ -128,7 +130,7 @@ def test_match(): Var(2, R): x * 6, } assert kd.utils.z3_match( - x + y + x * 6 == 0, z3.ForAll([x, y, z], x + y + z == 0) + x + y + x * 6 == 0, smt.ForAll([x, y, z], x + y + z == 0) ) == { Var(2, R): x, Var(1, R): y, @@ -137,5 +139,5 @@ def test_match(): def test_subterms(): - x, y = z3.Ints("x y") + x, y = smt.Ints("x y") assert set(kd.utils.subterms(x + y + x)) == {x, y, x, x + y, x + y + x} diff --git a/thoughts/knuckledragger_old/backends/tptp.py b/thoughts/knuckledragger_old/backends/tptp.py index 834b1ba..672f514 100644 --- a/thoughts/knuckledragger_old/backends/tptp.py +++ b/thoughts/knuckledragger_old/backends/tptp.py @@ -1,61 +1,61 @@ -from z3 import * +from smt import * -def z3_to_tptp(expr): +def smt_to_tptp(expr): if isinstance(expr, BoolRef): - if expr.decl().kind() == Z3_OP_TRUE: + if expr.decl().kind() == smt_OP_TRUE: return "$true" - if expr.decl().kind() == Z3_OP_FALSE: + if expr.decl().kind() == smt_OP_FALSE: return "$false" - if expr.decl().kind() == Z3_OP_AND: + if expr.decl().kind() == smt_OP_AND: return "({})".format(" & ".join([z3_to_tptp(x) for x in expr.children()])) - if expr.decl().kind() == Z3_OP_OR: + if expr.decl().kind() == smt_OP_OR: return "({})".format(" | ".join([z3_to_tptp(x) for x in expr.children()])) - if expr.decl().kind() == Z3_OP_NOT: + if expr.decl().kind() == smt_OP_NOT: return "~({})".format(z3_to_tptp(expr.children()[0])) - if expr.decl().kind() == Z3_OP_IMPLIES: + if expr.decl().kind() == smt_OP_IMPLIES: return "({} => {})".format( - z3_to_tptp(expr.children()[0]), z3_to_tptp(expr.children()[1]) + smt_to_tptp(expr.children()[0]), smt_to_tptp(expr.children()[1]) ) - if expr.decl().kind() == Z3_OP_EQ: + if expr.decl().kind() == smt_OP_EQ: return "({} = {})".format( - z3_to_tptp(expr.children()[0]), z3_to_tptp(expr.children()[1]) + smt_to_tptp(expr.children()[0]), smt_to_tptp(expr.children()[1]) ) - if expr.decl().kind() == Z3_OP_DISTINCT: + if expr.decl().kind() == smt_OP_DISTINCT: return "({} != {})".format( - z3_to_tptp(expr.children()[0]), z3_to_tptp(expr.children()[1]) + smt_to_tptp(expr.children()[0]), smt_to_tptp(expr.children()[1]) ) - if expr.decl().kind() == Z3_OP_ITE: + if expr.decl().kind() == smt_OP_ITE: return "ite({}, {}, {})".format( - z3_to_tptp(expr.children()[0]), - z3_to_tptp(expr.children()[1]), - z3_to_tptp(expr.children()[2]), + smt_to_tptp(expr.children()[0]), + smt_to_tptp(expr.children()[1]), + smt_to_tptp(expr.children()[2]), ) - if expr.decl().kind() == Z3_OP_LE: + if expr.decl().kind() == smt_OP_LE: return "({} <= {})".format( - z3_to_tptp(expr.children()[0]), z3_to_tptp(expr.children()[1]) + smt_to_tptp(expr.children()[0]), smt_to_tptp(expr.children()[1]) ) - if expr.decl().kind() == Z3_OP_GE: + if expr.decl().kind() == smt_OP_GE: return "({} >= {})".format( - z3_to_tptp(expr.children()[0]), z3_to_tptp(expr.children()[1]) + smt_to_tptp(expr.children()[0]), smt_to_tptp(expr.children()[1]) ) - if expr.decl().kind() == Z3_OP_LT: + if expr.decl().kind() == smt_OP_LT: return "({} < {})".format( - z3_to_tptp(expr.children()[0]), z3_to_tptp(expr.children()[1]) + smt_to_tptp(expr.children()[0]), smt_to_tptp(expr.children()[1]) ) - if expr.decl().kind() == Z3_OP_GT: + if expr.decl().kind() == smt_OP_GT: return "({} > {})".format(z3_to_tptp) - if expr.decl().kind() == Z3_OP_ADD: + if expr.decl().kind() == smt_OP_ADD: return "({} + {})".format( - z3_to_tptp(expr.children()[0]), z3_to_tptp(expr.children()[1]) + smt_to_tptp(expr.children()[0]), smt_to_tptp(expr.children()[1]) ) - if expr.decl().kind() == Z3_OP_SUB: + if expr.decl().kind() == smt_OP_SUB: return "({} - {})".format( - z3_to_tptp(expr.children()[0]), z3_to_tptp(expr.children()[1]) + smt_to_tptp(expr.children()[0]), smt_to_tptp(expr.children()[1]) ) - if expr.decl().kind() == Z3_OP_MUL: + if expr.decl().kind() == smt_OP_MUL: return "({} * {})".format( - z3_to_tptp(expr.children()[0]), z3_to_tptp(expr.children()[1]) + smt_to_tptp(expr.children()[0]), smt_to_tptp(expr.children()[1]) ) else: assert False diff --git a/thoughts/knuckledragger_old/kernel.py b/thoughts/knuckledragger_old/kernel.py index 738e300..6963be9 100644 --- a/thoughts/knuckledragger_old/kernel.py +++ b/thoughts/knuckledragger_old/kernel.py @@ -1,5 +1,5 @@ from typing import Any, Tuple, List -from z3 import * +from smt import * import subprocess Form = Any @@ -33,7 +33,7 @@ def trust(form: Form) -> Thm: def infer(hyps: List[Thm], conc: Form, timeout=1000) -> Thm: - """Use Z3 as giant inference step""" + """Use smt as giant inference step""" s = Solver() for hyp in hyps: check(hyp) @@ -41,9 +41,9 @@ def infer(hyps: List[Thm], conc: Form, timeout=1000) -> Thm: s.add(Not(conc)) s.set("timeout", timeout) res = s.check() - if res != z3.unsat: + if res != smt.unsat: print(s.sexpr()) - if res == z3.sat: + if res == smt.sat: print(s.model()) assert False, res return trust(conc) diff --git a/thoughts/knuckledragger_old/z3/notation.py b/thoughts/knuckledragger_old/z3/notation.py index 360d2d9..27885dd 100644 --- a/thoughts/knuckledragger_old/z3/notation.py +++ b/thoughts/knuckledragger_old/z3/notation.py @@ -1,30 +1,30 @@ -import z3 +import knuckledragger.smt as smt -z3.BoolRef.__and__ = lambda self, other: z3.And(self, other) -z3.BoolRef.__or__ = lambda self, other: z3.Or(self, other) -z3.BoolRef.__invert__ = lambda self: z3.Not(self) +smt.BoolRef.__and__ = lambda self, other: smt.And(self, other) +smt.BoolRef.__or__ = lambda self, other: smt.Or(self, other) +smt.BoolRef.__invert__ = lambda self: smt.Not(self) # Should this be in helpers? def QForAll(vars, hyp, conc): """Quantified ForAll""" - return z3.ForAll(vars, z3.Implies(hyp, conc)) + return smt.ForAll(vars, smt.Implies(hyp, conc)) def QExists(vars, hyp, conc): """Quantified Exists""" - return z3.Exists(vars, z3.And(hyp, conc)) + return smt.Exists(vars, smt.And(hyp, conc)) -# z3.ArrayRef.__call__ = lambda self, other: self[other] -z3.SortRef.__rshift__ = lambda self, other: z3.ArraySort(self, other) -# z3.SortRef.__mul__ = lambda self, other: z3.TupleSort(self.ctx, [self, other]) +# smt.ArrayRef.__call__ = lambda self, other: self[other] +smt.SortRef.__rshift__ = lambda self, other: smt.ArraySort(self, other) +# smt.SortRef.__mul__ = lambda self, other: smt.TupleSort(self.ctx, [self, other]) -# z3.ExprRef.head = property(lambda self: self.decl().kind()) -# z3.ExprRef.args = property(lambda self: [self.arg(i) for i in range(self.num_args())]) -# z3.ExprRef.__match_args__ = ["head", "args"] +# smt.ExprRef.head = property(lambda self: self.decl().kind()) +# smt.ExprRef.args = property(lambda self: [self.arg(i) for i in range(self.num_args())]) +# smt.ExprRef.__match_args__ = ["head", "args"] -# z3.QuantifierRef.open_term = property(lambda self: vars = FreshConst() (return vars, subst(self.body, []))) -# z3.QuantifierRef.__match_args__ = ["open_term"] +# smt.QuantifierRef.open_term = property(lambda self: vars = FreshConst() (return vars, subst(self.body, []))) +# smt.QuantifierRef.__match_args__ = ["open_term"] -# z3.QuantifierRef.__matmul__ = lambda self, other: z3.substitute(self.body, zip([z3.Var(n) for n in range(len(other)) , other])) +# smt.QuantifierRef.__matmul__ = lambda self, other: smt.substitute(self.body, zip([smt.Var(n) for n in range(len(other)) , other])) diff --git a/tutorial.ipynb b/tutorial.ipynb index 35939bf..75d2784 100644 --- a/tutorial.ipynb +++ b/tutorial.ipynb @@ -36,7 +36,7 @@ "metadata": {}, "outputs": [], "source": [ - "from z3 import *" + "from smt import *" ] }, { @@ -118,7 +118,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "If the property is not true, z3 can supply a counterexample. " + "If the property is not true, smt can supply a counterexample. " ] }, { @@ -221,7 +221,7 @@ "\n", "@dataclass\n", "class Proof:\n", - " thm: z3.BoolRef\n", + " thm: smt.BoolRef\n", " reasons: list[\"Proof\"]\n", "\n", "def axiom(thm):\n",