Skip to content

Commit

Permalink
added theories from notes
Browse files Browse the repository at this point in the history
  • Loading branch information
philzook58 committed Aug 5, 2024
1 parent 547dccd commit c02b952
Show file tree
Hide file tree
Showing 11 changed files with 198 additions and 29 deletions.
8 changes: 7 additions & 1 deletion knuckledragger/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,18 @@
from . import kernel
from . import notation
from . import tactics
import z3

lemma = kernel.lemma
lemma = tactics.lemma
axiom = kernel.axiom
define = kernel.define

QForAll = notation.QForAll
QExists = notation.QExists

Calc = tactics.Calc

R = z3.RealSort()
Z = z3.IntSort()
RSeq = Z >> R
RFun = R >> R
12 changes: 4 additions & 8 deletions knuckledragger/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Any
import logging

logger = logging.getLogger("knuckeldragger")
logger = logging.getLogger("knuckledragger")


@dataclass(frozen=True)
Expand Down Expand Up @@ -65,11 +65,11 @@ def lemma(
else:
s = solver()
s.set("timeout", timeout)
for n, p in enumerate(by):
for p in by:
if not isinstance(p, __Proof):
raise LemmaError("In by reasons:", p, "is not a Proof object")
s.assert_and_track(p.thm, "by_{}".format(n))
s.assert_and_track(z3.Not(thm), "knuckledragger_goal")
s.add(p.thm)
s.add(z3.Not(thm))
if dump:
print(s.sexpr())
res = s.check()
Expand All @@ -78,10 +78,6 @@ def lemma(
raise LemmaError(thm, "Countermodel", s.model())
raise LemmaError("lemma", thm, res)
else:
core = s.unsat_core()
assert z3.Bool("knuckledragger_goal") in core
if len(core) < len(by) + 1:
print("WARNING: Unneeded assumptions. Used", core)
return __Proof(thm, by, False)


Expand Down
12 changes: 6 additions & 6 deletions knuckledragger/notation.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,16 +75,16 @@ def define(self, args, body):
return defn


add = SortDispatch(z3.ArithRef.__add__)
add = SortDispatch(z3.ArithRef.__add__, name="add")
z3.ExprRef.__add__ = lambda x, y: add(x, y)

sub = SortDispatch(z3.ArithRef.__sub__)
sub = SortDispatch(z3.ArithRef.__sub__, name="sub")
z3.ExprRef.__sub__ = lambda x, y: sub(x, y)

mul = SortDispatch(z3.ArithRef.__mul__)
mul = SortDispatch(z3.ArithRef.__mul__, name="mul")
z3.ExprRef.__mul__ = lambda x, y: mul(x, y)

div = SortDispatch(z3.ArithRef.__div__)
div = SortDispatch(z3.ArithRef.__div__, name="div")
z3.ExprRef.__truediv__ = lambda x, y: div(x, y)

and_ = SortDispatch()
Expand All @@ -93,10 +93,10 @@ def define(self, args, body):
or_ = SortDispatch()
z3.ExprRef.__or__ = lambda x, y: or_(x, y)

lt = SortDispatch(z3.ArithRef.__lt__)
lt = SortDispatch(z3.ArithRef.__lt__, name="lt")
z3.ExprRef.__lt__ = lambda x, y: lt(x, y)

le = SortDispatch(z3.ArithRef.__le__)
le = SortDispatch(z3.ArithRef.__le__, name="le")
z3.ExprRef.__le__ = lambda x, y: le(x, y)


Expand Down
60 changes: 60 additions & 0 deletions knuckledragger/tactics.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,63 @@ def __repr__(self):

def qed(self):
return kd.lemma(self.ForAll(self.terms[0] == self.terms[-1]), by=self.lemmas)


def lemma(
thm: z3.BoolRef,
by: list[kd.kernel.Proof] = [],
admit=False,
timeout=1000,
dump=False,
solver=z3.Solver,
) -> kd.kernel.Proof:
"""Prove a theorem using a list of previously proved lemmas.
In essence `prove(Implies(by, thm))`.
:param thm: The theorem to prove.
Args:
thm (z3.BoolRef): The theorem to prove.
by (list[Proof]): A list of previously proved lemmas.
admit (bool): If True, admit the theorem without proof.
Returns:
Proof: A proof object of thm
>>> lemma(BoolVal(True))
>>> lemma(RealVal(1) >= RealVal(0))
"""
if admit:
return kd.kernel.lemma(
thm, by, admit=admit, timeout=timeout, dump=dump, solver=solver
)
else:
s = solver()
s.set("timeout", timeout)
for n, p in enumerate(by):
if not kd.kernel.is_proof(p):
raise kd.kernel.LemmaError("In by reasons:", p, "is not a Proof object")
s.assert_and_track(p.thm, "by_{}".format(n))
s.assert_and_track(z3.Not(thm), "knuckledragger_goal")
if dump:
print(s.sexpr())
res = s.check()
if res != z3.unsat:
if res == z3.sat:
raise kd.kernel.LemmaError(thm, "Countermodel", s.model())
raise kd.kernel.LemmaError("lemma", thm, res)
else:
core = s.unsat_core()
if not z3.Bool("knuckledragger_goal") in core:

Check failure on line 84 in knuckledragger/tactics.py

View workflow job for this annotation

GitHub Actions / build (3.9)

Ruff (E713)

knuckledragger/tactics.py:84:20: E713 Test for membership should be `not in`

Check failure on line 84 in knuckledragger/tactics.py

View workflow job for this annotation

GitHub Actions / build (3.10)

Ruff (E713)

knuckledragger/tactics.py:84:20: E713 Test for membership should be `not in`

Check failure on line 84 in knuckledragger/tactics.py

View workflow job for this annotation

GitHub Actions / build (3.11)

Ruff (E713)

knuckledragger/tactics.py:84:20: E713 Test for membership should be `not in`
raise kd.kernel.LemmaError(
thm,
core,
"Inconsistent lemmas. Goal is not used for proof. Something has gone awry.",
)
if len(core) < len(by) + 1:
print("WARNING: Unneeded assumptions. Used", core, thm)
return kd.kernel.lemma(
thm, by, admit=admit, timeout=timeout, dump=dump, solver=solver
)
48 changes: 48 additions & 0 deletions knuckledragger/theories/Complex.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import knuckledragger as kd
import z3

C = kd.notation.Record("C", ("re", z3.RealSort()), ("im", z3.RealSort()))

z, w, u, z1, z2 = z3.Consts("z w u z1 z2", C)
add = kd.define("add", [z1, z2], C.mk(z1.re + z2.re, z1.im + z2.im))
kd.notation.add.register(C, add)
mul = kd.define(
"mul", [z1, z2], C.mk(z1.re * z2.re - z1.im * z2.im, z1.re * z2.im + z1.im * z2.re)
)
kd.notation.mul.register(C, mul)
conj = kd.define("conj", [z], C.mk(z.re, -z.im))


div = kd.define(
"div",
[z1, z2],
C.mk(
(z1.re * z2.re + z1.im * z2.im) / (z2.re**2 + z2.im**2),
(z1.im * z2.re - z1.re * z2.im) / (z2.re**2 + z2.im**2),
),
)
kd.notation.div.register(C, div)

C0 = C.mk(0, 0)
C1 = C.mk(1, 0)

add_zero = kd.lemma(z3.ForAll([z], z + C0 == z), by=[add.defn])
mul_zero = kd.lemma(z3.ForAll([z], z * C0 == C0), by=[mul.defn])
mul_one = kd.lemma(z3.ForAll([z], z * C1 == z), by=[mul.defn])
add_comm = kd.lemma(z3.ForAll([z, w], z + w == w + z), by=[add.defn])
add_assoc = kd.lemma(
z3.ForAll([z, w, u], (z + (w + u)) == ((z + w) + u)), by=[add.defn]
)
mul_comm = kd.lemma(z3.ForAll([z, w], z * w == w * z), by=[mul.defn])

# unstable perfoamnce.
# mul_div = kd.lemma(ForAll([z,w], Implies(w != C0, z == z * w / w)), by=[div.defn, mul.defn], timeout=1000)
##mul_div = Calc()
div_one = kd.lemma(z3.ForAll([z], z / C1 == z), by=[div.defn])
div_inv = kd.lemma(z3.ForAll([z], z3.Implies(z != C0, z / z == C1)), by=[div.defn])

# inv = kd.define("inv", [z], )

# conjugate
# polar
norm2 = kd.define("norm2", [z], z * conj(z))
33 changes: 33 additions & 0 deletions knuckledragger/theories/Interval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import knuckledragger as kd
import knuckledragger.theories.Real as R
import z3

Interval = kd.notation.Record("Interval", ("lo", kd.R), ("hi", kd.R))
x, y, z = z3.Reals("x y z")
i, j, k = z3.Consts("i j k", Interval)

setof = kd.define("setof", [i], z3.Lambda([x], z3.And(i.lo <= x, x <= i.hi)))

meet = kd.define("meet", [i, j], Interval.mk(R.max(i.lo, j.lo), R.min(i.hi, j.hi)))
meet_intersect = kd.lemma(
z3.ForAll([i, j], z3.SetIntersect(setof(i), setof(j)) == setof(meet(i, j))),
by=[setof.defn, meet.defn, R.min.defn, R.max.defn],
)

join = kd.define("join", [i, j], Interval.mk(R.min(i.lo, j.lo), R.max(i.hi, j.hi)))
join_union = kd.lemma(
z3.ForAll([i, j], z3.IsSubset(z3.SetUnion(setof(i), setof(j)), setof(join(i, j)))),
by=[setof.defn, join.defn, R.min.defn, R.max.defn],
)


width = kd.define("width", [i], i.hi - i.lo)
mid = kd.define("mid", [i], (i.lo + i.hi) / 2)

add = kd.notation.add.define([i, j], Interval.mk(i.lo + j.lo, i.hi + j.hi))
add_set = kd.lemma(
z3.ForAll([x, y, i, j], z3.Implies(setof(i)[x] & setof(j)[y], setof(i + j)[x + y])),
by=[add.defn, setof.defn],
)

sub = kd.notation.sub.define([i, j], Interval.mk(i.lo - j.hi, i.hi - j.lo))
9 changes: 9 additions & 0 deletions knuckledragger/theories/Real.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,12 @@
max_idem = kd.lemma(ForAll([x], max(x, x) == x), by=[max.defn])
max_ge = kd.lemma(ForAll([x, y], max(x, y) >= x), by=[max.defn])
max_ge_2 = kd.lemma(ForAll([x, y], max(x, y) >= y), by=[max.defn])

min = kd.define("min", [x, y], z3.If(x <= y, x, y))
min_comm = kd.lemma(ForAll([x, y], min(x, y) == min(y, x)), by=[min.defn])
min_assoc = kd.lemma(
ForAll([x, y, z], min(x, min(y, z)) == min(min(x, y), z)), by=[min.defn]
)
min_idem = kd.lemma(ForAll([x], min(x, x) == x), by=[min.defn])
min_le = kd.lemma(ForAll([x, y], min(x, y) <= x), by=[min.defn])
min_le_2 = kd.lemma(ForAll([x, y], min(x, y) <= y), by=[min.defn])
22 changes: 17 additions & 5 deletions knuckledragger/theories/Seq.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,18 @@
from knuckledragger.theories.Nat import Nat
from knuckledragger.theories.Real import R
from z3 import ArraySort
import knuckledragger as kd
import z3

"""A Sequence type of Nat -> R"""
Seq = ArraySort(Nat, R)
# TODO: seq needs well formedness condition inherited from elements


def induct(T: z3.SortRef, P) -> kd.kernel.Proof:
z = z3.FreshConst(T, prefix="z")
sort = z3.SeqSort(T)
x, y = z3.FreshConst(sort), z3.FreshConst(sort)
return kd.axiom(
z3.And(
P(z3.Empty(sort)),
kd.QForAll([z], P(z3.Unit(z))),
kd.QForAll([x, y], P(x), P(y), P(z3.Concat(x, y))),
) # -------------------------------------------------
== kd.QForAll([x], P(x))
)
File renamed without changes.
1 change: 1 addition & 0 deletions knuckledragger/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ def simp(t: z3.ExprRef) -> z3.ExprRef:
G.add(v.ax.thm)
G.add(expr == t)
G2 = z3.Then(z3.Tactic("demodulator"), z3.Tactic("simplify")).apply(G)[0]
# TODO make this extraction more robust
return G2[len(G2) - 1].children()[1]


Expand Down
22 changes: 13 additions & 9 deletions tests/test_kernel.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,38 @@
import pytest
from knuckledragger import lemma, axiom
import z3

import knuckledragger as kd
import knuckledragger.theories.Nat
import knuckledragger.theories.Int
import knuckledragger.theories.Real as R
import knuckledragger.theories.Complex as C

Check failure on line 8 in tests/test_kernel.py

View workflow job for this annotation

GitHub Actions / build (3.9)

Ruff (F401)

tests/test_kernel.py:8:43: F401 `knuckledragger.theories.Complex` imported but unused

Check failure on line 8 in tests/test_kernel.py

View workflow job for this annotation

GitHub Actions / build (3.10)

Ruff (F401)

tests/test_kernel.py:8:43: F401 `knuckledragger.theories.Complex` imported but unused

Check failure on line 8 in tests/test_kernel.py

View workflow job for this annotation

GitHub Actions / build (3.11)

Ruff (F401)

tests/test_kernel.py:8:43: F401 `knuckledragger.theories.Complex` imported but unused
import knuckledragger.theories.Interval

import knuckledragger.theories.List
import knuckledragger.theories.Seq
import knuckledragger.theories.Seq as ThSeq

from knuckledragger import Calc
import knuckledragger.utils


def test_true_infer():
lemma(z3.BoolVal(True))
kd.lemma(z3.BoolVal(True))


def test_false_infer():
with pytest.raises(Exception) as _:
lemma(z3.BoolVal(False))
kd.lemma(z3.BoolVal(False))


def test_explosion():
a = axiom(z3.BoolVal(False), "False axiom")
a = kd.axiom(z3.BoolVal(False), "False axiom")
with pytest.raises(Exception) as _:
lemma(z3.BoolVal(True), by=[a])
kd.lemma(z3.BoolVal(True), by=[a])


def test_calc():
x, y, z = z3.Ints("x y z")
l1 = axiom(x == y)
l2 = axiom(y == z)
l1 = kd.axiom(x == y)
l2 = kd.axiom(y == z)
Calc([], x).eq(y, by=[l1]).eq(z, by=[l2]).qed()


Expand Down Expand Up @@ -77,3 +77,7 @@ def test_simp():
def test_record():
foo = kd.notation.Record("foo", ("bar", z3.IntSort()), ("baz", z3.BoolSort()))
assert z3.simplify(foo.mk(1, True).bar).eq(z3.IntVal(1))


def test_seq():
ThSeq.induct(z3.IntSort(), lambda x: x == x)

0 comments on commit c02b952

Please sign in to comment.