Skip to content

Commit

Permalink
improved overloading, calc tactic
Browse files Browse the repository at this point in the history
  • Loading branch information
philzook58 committed Jul 25, 2024
1 parent a3334f6 commit 160810f
Show file tree
Hide file tree
Showing 6 changed files with 90 additions and 43 deletions.
30 changes: 19 additions & 11 deletions knuckledragger/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,23 +80,27 @@ def lemma(
return __Proof(thm, by, False)


def axiom(thm: z3.BoolRef, reason=[]) -> Proof:
def axiom(thm: z3.BoolRef, by=[]) -> Proof:
"""Assert an axiom.
Axioms are necessary and useful. But you must use great care.
Args:
thm: The axiom to assert.
reason: A python object explaining why the axiom should exist. Often a string explaining the axiom.
by: A python object explaining why the axiom should exist. Often a string explaining the axiom.
"""
return __Proof(thm, reason, admit=False)
return __Proof(thm, by, admit=True)


__sig = {}
defn: dict[z3.FuncDecl, Proof] = {}
"""
defn holds definitional axioms for function symbols.
"""
z3.FuncDeclRef.defn = property(lambda self: defn[self])


def define(
name: str, args: list[z3.ExprRef], defn: z3.ExprRef
name: str, args: list[z3.ExprRef], defn_expr: z3.ExprRef
) -> tuple[z3.FuncDeclRef, __Proof]:
"""Define a non recursive definition. Useful for shorthand and abstraction.
Expand All @@ -108,13 +112,17 @@ def define(
Returns:
tuple[z3.FuncDeclRef, __Proof]: A tuple of the defined term and the proof of the definition.
"""
sorts = [arg.sort() for arg in args] + [defn.sort()]
sorts = [arg.sort() for arg in args] + [defn_expr.sort()]
f = z3.Function(name, *sorts)
def_ax = axiom(z3.ForAll(args, f(*args) == defn), reason="definition")
if len(args) > 0:
def_ax = axiom(z3.ForAll(args, f(*args) == defn_expr), by="definition")
else:
def_ax = axiom(f(*args) == defn_expr, by="definition")
# assert f not in __sig or __sig[f].eq( def_ax.thm) # Check for redefinitions. This is kind of painful. Hmm.
# Soft warning is more pleasant.
if f not in __sig or __sig[f].eq(def_ax.thm):
__sig[f] = def_ax.thm
if f not in defn or defn[f].thm.eq(def_ax.thm):
defn[f] = def_ax
else:
print("WARNING: Redefining function", f, "from", __sig[f], "to", def_ax.thm)
return f, def_ax
print("WARNING: Redefining function", f, "from", defn[f], "to", def_ax.thm)
defn[f] = def_ax
return f
19 changes: 19 additions & 0 deletions knuckledragger/notation.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,25 @@ def __call__(self, *args, **kwargs):
add = SortDispatch(z3.ArithRef.__add__)
z3.ExprRef.__add__ = lambda x, y: add(x, y)

sub = SortDispatch(z3.ArithRef.__sub__)
z3.ExprRef.__sub__ = lambda x, y: sub(x, y)

mul = SortDispatch(z3.ArithRef.__mul__)
z3.ExprRef.__mul__ = lambda x, y: mul(x, y)

and_ = SortDispatch()
z3.ExprRef.__and__ = lambda x, y: and_(x, y)

or_ = SortDispatch()
z3.ExprRef.__or__ = lambda x, y: or_(x, y)

lt = SortDispatch(z3.ArithRef.__lt__)
z3.ExprRef.__lt__ = lambda x, y: lt(x, y)

le = SortDispatch(z3.ArithRef.__le__)
z3.ExprRef.__le__ = lambda x, y: le(x, y)


"""
mul = SortDispatch(z3.ArithRef.__mul__)
z3.ExprRef.__mul__ = mul
Expand Down
11 changes: 6 additions & 5 deletions knuckledragger/theories/Nat.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"""

from z3 import Datatype, ForAll, And, Implies, Consts, If, Function, IntSort, Ints
from knuckledragger import axiom, lemma
from knuckledragger import axiom, lemma, define
import knuckledragger.notation as notation

Z = IntSort()
Expand Down Expand Up @@ -41,13 +41,14 @@ def induct(P):

reflect = Function("reflect", Z, Nat)
# """reflect Z Nat maps an integer to a natural number"""
reflect_def = axiom(
ForAll([x], reflect(x) == If(x <= 0, Nat.zero, Nat.succ(reflect(x - 1))))
)
# reflect_def = axiom(
# ForAll([x], reflect(x) == If(x <= 0, Nat.zero, Nat.succ(reflect(x - 1))))
# )
reflect = define("reflect", [x], If(x <= 0, Nat.zero, Nat.succ(reflect(x - 1))))

reflect_reify = lemma(
ForAll([n], reflect(reify(n)) == n),
by=[reflect_def, reify_def, induct(lambda n: reflect(reify(n)) == n)],
by=[reflect.defn, reify_def, induct(lambda n: reflect(reify(n)) == n)],
)

reify_ge_0 = lemma(
Expand Down
14 changes: 9 additions & 5 deletions knuckledragger/theories/Real.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
from z3 import RealSort, ForAll, Function, Reals, If
import z3
from z3 import ForAll, Function
from knuckledragger import lemma, axiom

R = RealSort()
x, y, z = Reals("x y z")
R = z3.RealSort()
x, y, z = z3.Reals("x y z")

plus = Function("plus", R, R, R)
plus_def = axiom(ForAll([x, y], plus(x, y) == x + y), "definition")

plus_0 = lemma(ForAll([x], plus(x, 0) == x), by=[plus_def])
plus_comm = lemma(ForAll([x, y], plus(x, y) == plus(y, x)), by=[plus_def])
plus_assoc = lemma(
ForAll([x, y, z], plus(x, plus(y, z)) == plus(plus(x, y), z)), by=[plus_def]
z3.ForAll([x, y, z], plus(x, plus(y, z)) == plus(plus(x, y), z)), by=[plus_def]
)

mul = Function("mul", R, R, R)
Expand All @@ -28,4 +29,7 @@
)

abs = Function("abs", R, R)
abs_def = axiom(ForAll([x], abs(x) == If(x >= 0, x, -x)), "definition")
abs_def = axiom(ForAll([x], abs(x) == z3.If(x >= 0, x, -x)), "definition")

RFun = z3.ArraySort(R, R)
RSeq = z3.ArraySort(z3.IntSort(), R)
55 changes: 35 additions & 20 deletions knuckledragger/utils.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,39 @@
from numpy import isin
from knuckledragger.kernel import lemma, is_proof
import z3
from operator import eq
import subprocess
import sys


def calc(*args, vars=None, by=[], op=eq):
class Calc:
"""
calc is for equational reasoning.
One can write a sequence of formulas interspersed with useful lemmas.
Inequational chaining can be done via op=lambda x,y: x <= y
"""

def thm(lhs, rhs):
if vars == None:
return op(lhs, rhs)
else:
return z3.ForAll(vars, op(lhs, rhs))

lemmas = []
local_by = []
lhs = args[0]
for arg in args[1:]:
if is_proof(arg):
local_by.append(arg)
def __init__(self, vars, lhs):
# TODO: hyps=None for conditional rewriting. assumpt, assume=[]
self.vars = vars
self.terms = [lhs]
self.lemmas = []

def ForAll(self, body):
if len(self.vars) == 0:
return body
else:
lemmas.append(lemma(thm(lhs, arg), by=by + local_by))
lhs = arg
local_by = []
return lemma(thm(args[0], args[-1]), by=by + lemmas)
return z3.ForAll(self.vars, body)

def eq(self, rhs, by=[]):
self.lemmas.append(lemma(self.ForAll(self.terms[-1] == rhs), by=by))
self.terms.append(rhs)
return self

# TODO: lt, le, gt, ge chaining. Or custom op.

def __repr__(self):
return "... = " + repr(self.terms[-1])

def qed(self):
return lemma(self.ForAll(self.terms[0] == self.terms[-1]), by=self.lemmas)


def lemma_smt(thm: z3.BoolRef, by=[], sig=[]) -> list[str]:
Expand Down Expand Up @@ -166,3 +171,13 @@ def lemma_tptp(thm: z3.BoolRef, by=[], sig=[], timeout=None, command=None):
capture_output=True,
)
return res


def lemma_db():
"""Scan all modules for Proof objects and return a dictionary of them."""
db = {}
for modname, mod in sys.modules.items():
thms = {name: thm for name, thm in mod.__dict__.items() if is_proof(thm)}
if len(thms) > 0:
db[modname] = thms
return db
4 changes: 2 additions & 2 deletions tests/test_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,14 @@ def test_explosion():
lemma(BoolVal(True), by=[a])


from knuckledragger.utils import calc
from knuckledragger.utils import Calc


def test_calc():
x, y, z = Ints("x y z")
l1 = axiom(x == y)
l2 = axiom(y == z)
calc(x, l1, y, l2, z)
Calc([], x).eq(y, by=[l1]).eq(z, by=[l2]).qed()


def test_tptp():
Expand Down

0 comments on commit 160810f

Please sign in to comment.