diff --git a/knuckledragger/__init__.py b/knuckledragger/__init__.py index 86c3d04..45df83b 100644 --- a/knuckledragger/__init__.py +++ b/knuckledragger/__init__.py @@ -1,8 +1,9 @@ from . import kernel from . import notation from . import tactics +import z3 -lemma = kernel.lemma +lemma = tactics.lemma axiom = kernel.axiom define = kernel.define @@ -10,3 +11,8 @@ QExists = notation.QExists Calc = tactics.Calc + +R = z3.RealSort() +Z = z3.IntSort() +RSeq = Z >> R +RFun = R >> R diff --git a/knuckledragger/kernel.py b/knuckledragger/kernel.py index 6fe8089..d8417fa 100644 --- a/knuckledragger/kernel.py +++ b/knuckledragger/kernel.py @@ -3,7 +3,7 @@ from typing import Any import logging -logger = logging.getLogger("knuckeldragger") +logger = logging.getLogger("knuckledragger") @dataclass(frozen=True) @@ -65,11 +65,11 @@ def lemma( else: s = solver() s.set("timeout", timeout) - for n, p in enumerate(by): + for p in by: if not isinstance(p, __Proof): raise LemmaError("In by reasons:", p, "is not a Proof object") - s.assert_and_track(p.thm, "by_{}".format(n)) - s.assert_and_track(z3.Not(thm), "knuckledragger_goal") + s.add(p.thm) + s.add(z3.Not(thm)) if dump: print(s.sexpr()) res = s.check() @@ -78,10 +78,6 @@ def lemma( raise LemmaError(thm, "Countermodel", s.model()) raise LemmaError("lemma", thm, res) else: - core = s.unsat_core() - assert z3.Bool("knuckledragger_goal") in core - if len(core) < len(by) + 1: - print("WARNING: Unneeded assumptions. Used", core) return __Proof(thm, by, False) diff --git a/knuckledragger/notation.py b/knuckledragger/notation.py index 4982b73..17f0582 100644 --- a/knuckledragger/notation.py +++ b/knuckledragger/notation.py @@ -75,16 +75,16 @@ def define(self, args, body): return defn -add = SortDispatch(z3.ArithRef.__add__) +add = SortDispatch(z3.ArithRef.__add__, name="add") z3.ExprRef.__add__ = lambda x, y: add(x, y) -sub = SortDispatch(z3.ArithRef.__sub__) +sub = SortDispatch(z3.ArithRef.__sub__, name="sub") z3.ExprRef.__sub__ = lambda x, y: sub(x, y) -mul = SortDispatch(z3.ArithRef.__mul__) +mul = SortDispatch(z3.ArithRef.__mul__, name="mul") z3.ExprRef.__mul__ = lambda x, y: mul(x, y) -div = SortDispatch(z3.ArithRef.__div__) +div = SortDispatch(z3.ArithRef.__div__, name="div") z3.ExprRef.__truediv__ = lambda x, y: div(x, y) and_ = SortDispatch() @@ -93,10 +93,10 @@ def define(self, args, body): or_ = SortDispatch() z3.ExprRef.__or__ = lambda x, y: or_(x, y) -lt = SortDispatch(z3.ArithRef.__lt__) +lt = SortDispatch(z3.ArithRef.__lt__, name="lt") z3.ExprRef.__lt__ = lambda x, y: lt(x, y) -le = SortDispatch(z3.ArithRef.__le__) +le = SortDispatch(z3.ArithRef.__le__, name="le") z3.ExprRef.__le__ = lambda x, y: le(x, y) diff --git a/knuckledragger/tactics.py b/knuckledragger/tactics.py index f0bafe0..be51bac 100644 --- a/knuckledragger/tactics.py +++ b/knuckledragger/tactics.py @@ -32,3 +32,63 @@ def __repr__(self): def qed(self): return kd.lemma(self.ForAll(self.terms[0] == self.terms[-1]), by=self.lemmas) + + +def lemma( + thm: z3.BoolRef, + by: list[kd.kernel.Proof] = [], + admit=False, + timeout=1000, + dump=False, + solver=z3.Solver, +) -> kd.kernel.Proof: + """Prove a theorem using a list of previously proved lemmas. + + In essence `prove(Implies(by, thm))`. + + :param thm: The theorem to prove. + Args: + thm (z3.BoolRef): The theorem to prove. + by (list[Proof]): A list of previously proved lemmas. + admit (bool): If True, admit the theorem without proof. + + Returns: + Proof: A proof object of thm + + >>> lemma(BoolVal(True)) + + >>> lemma(RealVal(1) >= RealVal(0)) + + """ + if admit: + return kd.kernel.lemma( + thm, by, admit=admit, timeout=timeout, dump=dump, solver=solver + ) + else: + s = solver() + s.set("timeout", timeout) + for n, p in enumerate(by): + if not kd.kernel.is_proof(p): + raise kd.kernel.LemmaError("In by reasons:", p, "is not a Proof object") + s.assert_and_track(p.thm, "by_{}".format(n)) + s.assert_and_track(z3.Not(thm), "knuckledragger_goal") + if dump: + print(s.sexpr()) + res = s.check() + if res != z3.unsat: + if res == z3.sat: + raise kd.kernel.LemmaError(thm, "Countermodel", s.model()) + raise kd.kernel.LemmaError("lemma", thm, res) + else: + core = s.unsat_core() + if not z3.Bool("knuckledragger_goal") in core: + raise kd.kernel.LemmaError( + thm, + core, + "Inconsistent lemmas. Goal is not used for proof. Something has gone awry.", + ) + if len(core) < len(by) + 1: + print("WARNING: Unneeded assumptions. Used", core, thm) + return kd.kernel.lemma( + thm, by, admit=admit, timeout=timeout, dump=dump, solver=solver + ) diff --git a/knuckledragger/theories/Complex.py b/knuckledragger/theories/Complex.py new file mode 100644 index 0000000..42329c0 --- /dev/null +++ b/knuckledragger/theories/Complex.py @@ -0,0 +1,48 @@ +import knuckledragger as kd +import z3 + +C = kd.notation.Record("C", ("re", z3.RealSort()), ("im", z3.RealSort())) + +z, w, u, z1, z2 = z3.Consts("z w u z1 z2", C) +add = kd.define("add", [z1, z2], C.mk(z1.re + z2.re, z1.im + z2.im)) +kd.notation.add.register(C, add) +mul = kd.define( + "mul", [z1, z2], C.mk(z1.re * z2.re - z1.im * z2.im, z1.re * z2.im + z1.im * z2.re) +) +kd.notation.mul.register(C, mul) +conj = kd.define("conj", [z], C.mk(z.re, -z.im)) + + +div = kd.define( + "div", + [z1, z2], + C.mk( + (z1.re * z2.re + z1.im * z2.im) / (z2.re**2 + z2.im**2), + (z1.im * z2.re - z1.re * z2.im) / (z2.re**2 + z2.im**2), + ), +) +kd.notation.div.register(C, div) + +C0 = C.mk(0, 0) +C1 = C.mk(1, 0) + +add_zero = kd.lemma(z3.ForAll([z], z + C0 == z), by=[add.defn]) +mul_zero = kd.lemma(z3.ForAll([z], z * C0 == C0), by=[mul.defn]) +mul_one = kd.lemma(z3.ForAll([z], z * C1 == z), by=[mul.defn]) +add_comm = kd.lemma(z3.ForAll([z, w], z + w == w + z), by=[add.defn]) +add_assoc = kd.lemma( + z3.ForAll([z, w, u], (z + (w + u)) == ((z + w) + u)), by=[add.defn] +) +mul_comm = kd.lemma(z3.ForAll([z, w], z * w == w * z), by=[mul.defn]) + +# unstable perfoamnce. +# mul_div = kd.lemma(ForAll([z,w], Implies(w != C0, z == z * w / w)), by=[div.defn, mul.defn], timeout=1000) +##mul_div = Calc() +div_one = kd.lemma(z3.ForAll([z], z / C1 == z), by=[div.defn]) +div_inv = kd.lemma(z3.ForAll([z], z3.Implies(z != C0, z / z == C1)), by=[div.defn]) + +# inv = kd.define("inv", [z], ) + +# conjugate +# polar +norm2 = kd.define("norm2", [z], z * conj(z)) diff --git a/knuckledragger/theories/Interval.py b/knuckledragger/theories/Interval.py new file mode 100644 index 0000000..bce370c --- /dev/null +++ b/knuckledragger/theories/Interval.py @@ -0,0 +1,33 @@ +import knuckledragger as kd +import knuckledragger.theories.Real as R +import z3 + +Interval = kd.notation.Record("Interval", ("lo", kd.R), ("hi", kd.R)) +x, y, z = z3.Reals("x y z") +i, j, k = z3.Consts("i j k", Interval) + +setof = kd.define("setof", [i], z3.Lambda([x], z3.And(i.lo <= x, x <= i.hi))) + +meet = kd.define("meet", [i, j], Interval.mk(R.max(i.lo, j.lo), R.min(i.hi, j.hi))) +meet_intersect = kd.lemma( + z3.ForAll([i, j], z3.SetIntersect(setof(i), setof(j)) == setof(meet(i, j))), + by=[setof.defn, meet.defn, R.min.defn, R.max.defn], +) + +join = kd.define("join", [i, j], Interval.mk(R.min(i.lo, j.lo), R.max(i.hi, j.hi))) +join_union = kd.lemma( + z3.ForAll([i, j], z3.IsSubset(z3.SetUnion(setof(i), setof(j)), setof(join(i, j)))), + by=[setof.defn, join.defn, R.min.defn, R.max.defn], +) + + +width = kd.define("width", [i], i.hi - i.lo) +mid = kd.define("mid", [i], (i.lo + i.hi) / 2) + +add = kd.notation.add.define([i, j], Interval.mk(i.lo + j.lo, i.hi + j.hi)) +add_set = kd.lemma( + z3.ForAll([x, y, i, j], z3.Implies(setof(i)[x] & setof(j)[y], setof(i + j)[x + y])), + by=[add.defn, setof.defn], +) + +sub = kd.notation.sub.define([i, j], Interval.mk(i.lo - j.hi, i.hi - j.lo)) diff --git a/knuckledragger/theories/Real.py b/knuckledragger/theories/Real.py index 68367d6..499236e 100644 --- a/knuckledragger/theories/Real.py +++ b/knuckledragger/theories/Real.py @@ -50,3 +50,12 @@ max_idem = kd.lemma(ForAll([x], max(x, x) == x), by=[max.defn]) max_ge = kd.lemma(ForAll([x, y], max(x, y) >= x), by=[max.defn]) max_ge_2 = kd.lemma(ForAll([x, y], max(x, y) >= y), by=[max.defn]) + +min = kd.define("min", [x, y], z3.If(x <= y, x, y)) +min_comm = kd.lemma(ForAll([x, y], min(x, y) == min(y, x)), by=[min.defn]) +min_assoc = kd.lemma( + ForAll([x, y, z], min(x, min(y, z)) == min(min(x, y), z)), by=[min.defn] +) +min_idem = kd.lemma(ForAll([x], min(x, x) == x), by=[min.defn]) +min_le = kd.lemma(ForAll([x, y], min(x, y) <= x), by=[min.defn]) +min_le_2 = kd.lemma(ForAll([x, y], min(x, y) <= y), by=[min.defn]) diff --git a/knuckledragger/theories/Seq.py b/knuckledragger/theories/Seq.py index 5e23d37..1d0b78c 100644 --- a/knuckledragger/theories/Seq.py +++ b/knuckledragger/theories/Seq.py @@ -1,6 +1,18 @@ -from knuckledragger.theories.Nat import Nat -from knuckledragger.theories.Real import R -from z3 import ArraySort +import knuckledragger as kd +import z3 -"""A Sequence type of Nat -> R""" -Seq = ArraySort(Nat, R) +# TODO: seq needs well formedness condition inherited from elements + + +def induct(T: z3.SortRef, P) -> kd.kernel.Proof: + z = z3.FreshConst(T, prefix="z") + sort = z3.SeqSort(T) + x, y = z3.FreshConst(sort), z3.FreshConst(sort) + return kd.axiom( + z3.And( + P(z3.Empty(sort)), + kd.QForAll([z], P(z3.Unit(z))), + kd.QForAll([x, y], P(x), P(y), P(z3.Concat(x, y))), + ) # ------------------------------------------------- + == kd.QForAll([x], P(x)) + ) diff --git a/knuckledragger/theories/List.py b/knuckledragger/theories/Vec.py similarity index 100% rename from knuckledragger/theories/List.py rename to knuckledragger/theories/Vec.py diff --git a/knuckledragger/utils.py b/knuckledragger/utils.py index b35cb82..d544975 100644 --- a/knuckledragger/utils.py +++ b/knuckledragger/utils.py @@ -12,6 +12,7 @@ def simp(t: z3.ExprRef) -> z3.ExprRef: G.add(v.ax.thm) G.add(expr == t) G2 = z3.Then(z3.Tactic("demodulator"), z3.Tactic("simplify")).apply(G)[0] + # TODO make this extraction more robust return G2[len(G2) - 1].children()[1] diff --git a/tests/test_kernel.py b/tests/test_kernel.py index 6b25863..b84d27e 100644 --- a/tests/test_kernel.py +++ b/tests/test_kernel.py @@ -1,38 +1,38 @@ import pytest -from knuckledragger import lemma, axiom import z3 import knuckledragger as kd import knuckledragger.theories.Nat import knuckledragger.theories.Int import knuckledragger.theories.Real as R +import knuckledragger.theories.Complex as C +import knuckledragger.theories.Interval -import knuckledragger.theories.List -import knuckledragger.theories.Seq +import knuckledragger.theories.Seq as ThSeq from knuckledragger import Calc import knuckledragger.utils def test_true_infer(): - lemma(z3.BoolVal(True)) + kd.lemma(z3.BoolVal(True)) def test_false_infer(): with pytest.raises(Exception) as _: - lemma(z3.BoolVal(False)) + kd.lemma(z3.BoolVal(False)) def test_explosion(): - a = axiom(z3.BoolVal(False), "False axiom") + a = kd.axiom(z3.BoolVal(False), "False axiom") with pytest.raises(Exception) as _: - lemma(z3.BoolVal(True), by=[a]) + kd.lemma(z3.BoolVal(True), by=[a]) def test_calc(): x, y, z = z3.Ints("x y z") - l1 = axiom(x == y) - l2 = axiom(y == z) + l1 = kd.axiom(x == y) + l2 = kd.axiom(y == z) Calc([], x).eq(y, by=[l1]).eq(z, by=[l2]).qed() @@ -77,3 +77,7 @@ def test_simp(): def test_record(): foo = kd.notation.Record("foo", ("bar", z3.IntSort()), ("baz", z3.BoolSort())) assert z3.simplify(foo.mk(1, True).bar).eq(z3.IntVal(1)) + + +def test_seq(): + ThSeq.induct(z3.IntSort(), lambda x: x == x)