Skip to content

Commit

Permalink
added datatype accessor notation
Browse files Browse the repository at this point in the history
  • Loading branch information
philzook58 committed Jul 31, 2024
1 parent 160810f commit 7aefec9
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 12 deletions.
1 change: 1 addition & 0 deletions knuckledragger/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from .kernel import lemma, axiom, define
from . import notation
31 changes: 19 additions & 12 deletions knuckledragger/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,16 +92,22 @@ def axiom(thm: z3.BoolRef, by=[]) -> Proof:
return __Proof(thm, by, admit=True)


defn: dict[z3.FuncDecl, Proof] = {}
@dataclass(frozen=True)
class Defn:
name: str
args: list[z3.ExprRef]
body: z3.ExprRef
ax: Proof


defns: dict[z3.FuncDecl, Defn] = {}
"""
defn holds definitional axioms for function symbols.
"""
z3.FuncDeclRef.defn = property(lambda self: defn[self])
z3.FuncDeclRef.defn = property(lambda self: defns[self].ax)


def define(
name: str, args: list[z3.ExprRef], defn_expr: z3.ExprRef
) -> tuple[z3.FuncDeclRef, __Proof]:
def define(name: str, args: list[z3.ExprRef], body: z3.ExprRef) -> z3.FuncDeclRef:
"""Define a non recursive definition. Useful for shorthand and abstraction.
Args:
Expand All @@ -112,17 +118,18 @@ def define(
Returns:
tuple[z3.FuncDeclRef, __Proof]: A tuple of the defined term and the proof of the definition.
"""
sorts = [arg.sort() for arg in args] + [defn_expr.sort()]
sorts = [arg.sort() for arg in args] + [body.sort()]
f = z3.Function(name, *sorts)
if len(args) > 0:
def_ax = axiom(z3.ForAll(args, f(*args) == defn_expr), by="definition")
def_ax = axiom(z3.ForAll(args, f(*args) == body), by="definition")
else:
def_ax = axiom(f(*args) == defn_expr, by="definition")
def_ax = axiom(f(*args) == body, by="definition")
# assert f not in __sig or __sig[f].eq( def_ax.thm) # Check for redefinitions. This is kind of painful. Hmm.
# Soft warning is more pleasant.
if f not in defn or defn[f].thm.eq(def_ax.thm):
defn[f] = def_ax
defn = Defn(name, args, body, def_ax)
if f not in defns or defns[f].ax.thm.eq(def_ax.thm):
defns[f] = defn
else:
print("WARNING: Redefining function", f, "from", defn[f], "to", def_ax.thm)
defn[f] = def_ax
print("WARNING: Redefining function", f, "from", defns[f].ax, "to", def_ax.thm)
defns[f] = defn
return f
27 changes: 27 additions & 0 deletions knuckledragger/notation.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
"""Importing this module will add some syntactic sugar to Z3.
- Expr overload by single dispatch
- Bool supports `&`, `|`, `~`
- Sorts supports `>>` for ArraySort
- Datatypes support accessor notation
"""

import z3
Expand Down Expand Up @@ -59,6 +61,9 @@ def __call__(self, *args, **kwargs):
mul = SortDispatch(z3.ArithRef.__mul__)
z3.ExprRef.__mul__ = lambda x, y: mul(x, y)

div = SortDispatch(z3.ArithRef.__div__)
z3.ExprRef.__truediv__ = lambda x, y: div(x, y)

and_ = SortDispatch()
z3.ExprRef.__and__ = lambda x, y: and_(x, y)

Expand All @@ -85,3 +90,25 @@ def __call__(self, *args, **kwargs):
le = SortDispatch()
z3.ExprRef.__le__ = le
"""


def lookup_cons_recog(self, k):
"""
Enable "dot" syntax for fields of z3 datatypes
"""
sort = self.sort()
recog = "is_" == k[:3] if len(k) > 3 else False
for i in range(sort.num_constructors()):
cons = sort.constructor(i)
if recog:
if cons.name() == k[3:]:
recog = sort.recognizer(i)
return recog(self)
else:
for j in range(cons.arity()):
acc = sort.accessor(i, j)
if acc.name() == k:
return acc(self)


z3.DatatypeRef.__getattr__ = lookup_cons_recog
9 changes: 9 additions & 0 deletions tests/test_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,12 @@ def test_tptp():
ArraySort(ArraySort(BoolSort(), IntSort()), IntSort()).tptp()
== "(($o > $int) > $int)"
)


def test_datatype():
Foo = Datatype("Foo")
Foo.declare("foo", ("bar", IntSort()), ("baz", BoolSort()))
Foo = Foo.create()
x = Foo.foo(1, True)
assert z3.simplify(x.bar).eq(z3.IntVal(1))
assert z3.simplify(x.baz).eq(z3.BoolVal(True))

0 comments on commit 7aefec9

Please sign in to comment.