Skip to content

Commit

Permalink
simp, define_fix, QForAll
Browse files Browse the repository at this point in the history
  • Loading branch information
philzook58 committed Aug 4, 2024
1 parent 9a17444 commit 547dccd
Show file tree
Hide file tree
Showing 6 changed files with 158 additions and 47 deletions.
3 changes: 3 additions & 0 deletions knuckledragger/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from . import kernel
from . import notation
from . import tactics

lemma = kernel.lemma
axiom = kernel.axiom
define = kernel.define

QForAll = notation.QForAll
QExists = notation.QExists

Calc = tactics.Calc
33 changes: 30 additions & 3 deletions knuckledragger/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,12 @@ class LemmaError(Exception):


def lemma(
thm: z3.BoolRef, by: list[Proof] = [], admit=False, timeout=1000, dump=False
thm: z3.BoolRef,
by: list[Proof] = [],
admit=False,
timeout=1000,
dump=False,
solver=z3.Solver,
) -> Proof:
"""Prove a theorem using a list of previously proved lemmas.
Expand All @@ -58,7 +63,7 @@ def lemma(
logger.warn("Admitting lemma {}".format(thm))
return __Proof(thm, by, True)
else:
s = z3.Solver()
s = solver()
s.set("timeout", timeout)
for n, p in enumerate(by):
if not isinstance(p, __Proof):
Expand Down Expand Up @@ -108,7 +113,9 @@ class Defn:


def define(name: str, args: list[z3.ExprRef], body: z3.ExprRef) -> z3.FuncDeclRef:
"""Define a non recursive definition. Useful for shorthand and abstraction.
"""
Define a non recursive definition. Useful for shorthand and abstraction. Does not currently defend against ill formed definitions.
TODO: Check for bad circularity, record dependencies
Args:
name: The name of the term to define.
Expand All @@ -133,3 +140,23 @@ def define(name: str, args: list[z3.ExprRef], body: z3.ExprRef) -> z3.FuncDeclRe
print("WARNING: Redefining function", f, "from", defns[f].ax, "to", def_ax.thm)
defns[f] = defn
return f


def define_fix(name: str, args: list[z3.ExprRef], retsort, fix_lam) -> z3.FuncDeclRef:
"""
Define a recursive definition.
"""
sorts = [arg.sort() for arg in args]
sorts.append(retsort)
f = z3.Function(name, *sorts)

# wrapper to record calls
calls = set()

def record_f(*args):
calls.add(args)
return f(*args)

defn = define(name, args, fix_lam(record_f))
# TODO: check for well foundedness/termination, custom induction principle.
return defn
53 changes: 43 additions & 10 deletions knuckledragger/notation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,29 +7,45 @@
"""

import z3

import knuckledragger as kd

z3.BoolRef.__and__ = lambda self, other: z3.And(self, other)
z3.BoolRef.__or__ = lambda self, other: z3.Or(self, other)
z3.BoolRef.__invert__ = lambda self: z3.Not(self)


def QForAll(vars, hyp, conc):
def QForAll(vs, *hyp_conc):
"""Quantified ForAll
Shorthand for `ForAll(vars, Implies(hyp, conc))`
"""
return z3.ForAll(vars, z3.Implies(hyp, conc))
Shorthand for `ForAll(vars, Implies(And(hyp[0], hyp[1], ...), conc))`
If variables have a property `wf` attached, this is added as a hypothesis.
def QExists(vars, hyp, conc):
"""
conc = hyp_conc[-1]
hyps = hyp_conc[:-1]
hyps = [v.wf for v in vs if hasattr(v, "wf")] + list(hyps)
if len(hyps) == 0:
return z3.ForAll(vs, conc)
elif len(hyps) == 1:
return z3.ForAll(vs, z3.Implies(hyps[0], conc))
else:
hyp = z3.And(hyps)
return z3.ForAll(vs, z3.Implies(hyp, conc))


def QExists(vs, *concs):
"""Quantified Exists
Shorthand for `Exists(vars, And(hyp, conc))`
Shorthand for `ForAll(vars, And(conc[0], conc[1], ...))`
If variables have a property `wf` attached, this is anded into the properties.
"""
return z3.Exists(vars, z3.And(hyp, conc))
concs = [v.wf for v in vs if hasattr(v, "wf")] + list(concs)
if len(concs) == 1:
z3.Exists(vars, concs[0])
else:
z3.Exists(vars, z3.And(concs))


z3.SortRef.__rshift__ = lambda self, other: z3.ArraySort(self, other)
Expand All @@ -41,16 +57,23 @@ class SortDispatch:
It allows for dispatching on the sort of the first argument
"""

def __init__(self, default=None):
def __init__(self, default=None, name=None):
self.methods = {}
self.default = default
self.name = name

def register(self, sort, func):
self.methods[sort] = func

def __call__(self, *args, **kwargs):
return self.methods.get(args[0].sort(), self.default)(*args, **kwargs)

def define(self, args, body):
assert isinstance(self.name, str)
defn = kd.define(self.name, args, body)
self.register(args[0].sort(), defn)
return defn


add = SortDispatch(z3.ArithRef.__add__)
z3.ExprRef.__add__ = lambda x, y: add(x, y)
Expand Down Expand Up @@ -97,3 +120,13 @@ def lookup_cons_recog(self, k):


z3.DatatypeRef.__getattr__ = lookup_cons_recog


def Record(name, *fields):
"""
Define a record datatype
"""
rec = z3.Datatype(name)
rec.declare("mk", *fields)
rec = rec.create()
return rec
34 changes: 34 additions & 0 deletions knuckledragger/tactics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import knuckledragger as kd
import z3


class Calc:
"""
calc is for equational reasoning.
One can write a sequence of formulas interspersed with useful lemmas.
"""

def __init__(self, vars, lhs):
# TODO: hyps=None for conditional rewriting. assumpt, assume=[]
self.vars = vars
self.terms = [lhs]
self.lemmas = []

def ForAll(self, body):
if len(self.vars) == 0:
return body
else:
return z3.ForAll(self.vars, body)

def eq(self, rhs, by=[]):
self.lemmas.append(kd.lemma(self.ForAll(self.terms[-1] == rhs), by=by))
self.terms.append(rhs)
return self

# TODO: lt, le, gt, ge chaining. Or custom op.

def __repr__(self):
return "... = " + repr(self.terms[-1])

def qed(self):
return kd.lemma(self.ForAll(self.terms[0] == self.terms[-1]), by=self.lemmas)
53 changes: 21 additions & 32 deletions knuckledragger/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,38 +2,27 @@
import z3
import subprocess
import sys
import knuckledragger as kd


class Calc:
"""
calc is for equational reasoning.
One can write a sequence of formulas interspersed with useful lemmas.
"""

def __init__(self, vars, lhs):
# TODO: hyps=None for conditional rewriting. assumpt, assume=[]
self.vars = vars
self.terms = [lhs]
self.lemmas = []

def ForAll(self, body):
if len(self.vars) == 0:
return body
else:
return z3.ForAll(self.vars, body)

def eq(self, rhs, by=[]):
self.lemmas.append(lemma(self.ForAll(self.terms[-1] == rhs), by=by))
self.terms.append(rhs)
return self

# TODO: lt, le, gt, ge chaining. Or custom op.
def simp(t: z3.ExprRef) -> z3.ExprRef:
expr = z3.FreshConst(t.sort(), prefix="knuckle_goal")
G = z3.Goal()
for v in kd.kernel.defns.values():
G.add(v.ax.thm)
G.add(expr == t)
G2 = z3.Then(z3.Tactic("demodulator"), z3.Tactic("simplify")).apply(G)[0]
return G2[len(G2) - 1].children()[1]

def __repr__(self):
return "... = " + repr(self.terms[-1])

def qed(self):
return lemma(self.ForAll(self.terms[0] == self.terms[-1]), by=self.lemmas)
def simp2(t: z3.ExprRef) -> z3.ExprRef:
expr = z3.FreshConst(t.sort(), prefix="knuckle_goal")
G = z3.Goal()
for v in kd.kernel.defns.values():
G.add(v.ax.thm)
G.add(expr == t)
G2 = z3.Tactic("elim-predicates").apply(G)[0]
return G2[len(G2) - 1].children()[1]


def lemma_smt(thm: z3.BoolRef, by=[], sig=[]) -> list[str]:
Expand Down Expand Up @@ -73,7 +62,7 @@ def expr_to_tptp(expr: z3.ExprRef):
elif isinstance(expr, z3.QuantifierRef):
vars, body = open_binder(expr)
body = expr_to_tptp(body)
vs = ", ".join([v.sexpr() + ":" + z3_sort_tptp(v.sort()) for v in vars])
vs = ", ".join([v.sexpr() + ":" + sort_to_tptp(v.sort()) for v in vars])
if expr.is_forall():
return f"(![{vs}] : {body})"
elif expr.is_exists():
Expand Down Expand Up @@ -123,7 +112,7 @@ def expr_to_tptp(expr: z3.ExprRef):
z3.ExprRef.tptp = expr_to_tptp


def z3_sort_tptp(sort: z3.SortRef):
def sort_to_tptp(sort: z3.SortRef):
name = sort.name()
if name == "Int":
return "$int"
Expand All @@ -133,13 +122,13 @@ def z3_sort_tptp(sort: z3.SortRef):
return "$real"
elif name == "Array":
return "({} > {})".format(
z3_sort_tptp(sort.domain()), z3_sort_tptp(sort.range())
sort_to_tptp(sort.domain()), sort_to_tptp(sort.range())
)
else:
return name.lower()


z3.SortRef.tptp = z3_sort_tptp
z3.SortRef.tptp = sort_to_tptp


def lemma_tptp(thm: z3.BoolRef, by=[], sig=[], timeout=None, command=None):
Expand Down
29 changes: 27 additions & 2 deletions tests/test_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,16 @@
from knuckledragger import lemma, axiom
import z3

import knuckledragger as kd
import knuckledragger.theories.Nat
import knuckledragger.theories.Int
import knuckledragger.theories.Real
import knuckledragger.theories.Real as R

import knuckledragger.theories.List
import knuckledragger.theories.Seq

from knuckledragger.utils import Calc
from knuckledragger import Calc
import knuckledragger.utils


def test_true_infer():
Expand Down Expand Up @@ -52,3 +54,26 @@ def test_datatype():
x = Foo.foo(1, True)
assert z3.simplify(x.bar).eq(z3.IntVal(1))
assert z3.simplify(x.baz).eq(z3.BoolVal(True))


def test_qforall():
x, y = z3.Reals("x y")
assert kd.QForAll([x], x > 0).eq(z3.ForAll([x], x > 0))
assert kd.QForAll([x], x == 10, x == 14, x > 0).eq(
z3.ForAll([x], z3.Implies(z3.And(x == 10, x == 14), x > 0))
)
assert kd.QForAll([x, y], x > 0, y > 0).eq(
z3.ForAll([x, y], z3.Implies(x > 0, y > 0))
)
x.wf = x >= 0
assert kd.QForAll([x], x == 14).eq(z3.ForAll([x], z3.Implies(x >= 0, x == 14)))


def test_simp():
assert kd.utils.simp(R.max(8, R.max(3, 4))).eq(z3.RealVal(8))
assert kd.utils.simp2(R.max(8, R.max(3, 4))).eq(z3.RealVal(8))


def test_record():
foo = kd.notation.Record("foo", ("bar", z3.IntSort()), ("baz", z3.BoolSort()))
assert z3.simplify(foo.mk(1, True).bar).eq(z3.IntVal(1))

0 comments on commit 547dccd

Please sign in to comment.