Skip to content

Commit

Permalink
refactored cond, recrusive well formed in record, Lemma forward tactic
Browse files Browse the repository at this point in the history
  • Loading branch information
philzook58 committed Aug 24, 2024
1 parent 4829e90 commit 3a4bf86
Show file tree
Hide file tree
Showing 6 changed files with 139 additions and 45 deletions.
2 changes: 1 addition & 1 deletion knuckledragger/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

QForAll = notation.QForAll
QExists = notation.QExists
Cond = notation.Cond
cond = notation.cond
Record = notation.Record

Calc = tactics.Calc
Expand Down
99 changes: 57 additions & 42 deletions knuckledragger/notation.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class SortDispatch:
It allows for dispatching on the sort of the first argument
"""

def __init__(self, default=None, name=None):
def __init__(self, name=None, default=None):
self.methods = {}
self.default = default
self.name = name
Expand All @@ -41,19 +41,22 @@ def define(self, args, body):
return defn


add = SortDispatch(smt.ArithRef.__add__, name="add")
add = SortDispatch(name="add")
smt.ExprRef.__add__ = lambda x, y: add(x, y)

sub = SortDispatch(smt.ArithRef.__sub__, name="sub")
radd = SortDispatch(name="radd")
smt.ExprRef.__radd__ = lambda x, y: radd(x, y)

sub = SortDispatch(name="sub")
smt.ExprRef.__sub__ = lambda x, y: sub(x, y)

mul = SortDispatch(smt.ArithRef.__mul__, name="mul")
mul = SortDispatch(name="mul")
smt.ExprRef.__mul__ = lambda x, y: mul(x, y)

neg = SortDispatch(smt.ArithRef.__neg__, name="neg")
neg = SortDispatch(name="neg")
smt.ExprRef.__neg__ = lambda x: neg(x)

div = SortDispatch(smt.ArithRef.__div__, name="div_")
div = SortDispatch(name="div_")
smt.ExprRef.__truediv__ = lambda x, y: div(x, y)

and_ = SortDispatch()
Expand All @@ -62,10 +65,10 @@ def define(self, args, body):
or_ = SortDispatch()
smt.ExprRef.__or__ = lambda x, y: or_(x, y)

lt = SortDispatch(smt.ArithRef.__lt__, name="lt")
lt = SortDispatch(name="lt")
smt.ExprRef.__lt__ = lambda x, y: lt(x, y)

le = SortDispatch(smt.ArithRef.__le__, name="le")
le = SortDispatch(name="le")
smt.ExprRef.__le__ = lambda x, y: le(x, y)


Expand Down Expand Up @@ -107,7 +110,7 @@ def QExists(vs, *concs):
smt.Exists(vars, smt.And(concs))


def lookup_cons_recog(self, k):
def _lookup_constructor_recog(self, k):
"""
Enable "dot" syntax for fields of smt datatypes
"""
Expand All @@ -126,7 +129,7 @@ def lookup_cons_recog(self, k):
return acc(self)


smt.DatatypeRef.__getattr__ = lookup_cons_recog
smt.DatatypeRef.__getattr__ = _lookup_constructor_recog


def datatype_call(self, *args):
Expand All @@ -149,49 +152,61 @@ def Record(name, *fields, pred=None):
rec.declare(name, *fields)
rec = rec.create()
rec.mk = rec.constructor(0)
if pred is not None:
wf_cond = [n for (n, (_, sort)) in enumerate(fields) if sort in wf.methods]
if pred is None and len(wf_cond) == 1:
acc = rec.accessor(0, wf_cond[0])
wf.register(rec, lambda x: rec.accessor(0, acc(x).wf()))
elif pred is None and len(wf_cond) > 1:
wf.register(
rec, lambda x: smt.And(*[rec.accessor(0, n)(x).wf() for n in wf_cond])
)
elif pred is not None and len(wf_cond) == 0:
wf.register(rec, lambda x: pred(x))
elif pred is not None and len(wf_cond) > 0:
wf.register(
rec,
lambda x: smt.And(pred(x), *[rec.accessor(0, n)(x).wf() for n in wf_cond]),
)

return rec


class Cond:
"""
Cond is a useful way to build up giant if-then-else expressions.
"""
def cond(*cases, default=None) -> smt.ExprRef:
sort = cases[0][1].sort()
if default is None:
s = smt.Solver()
s.add(smt.Not(smt.Or([c for c, t in cases])))
res = s.check()
if res == smt.sat:
raise Exception("Cases not exhaustive. Fix or give default", s.model())
elif res != smt.unsat:
raise Exception("Solver error. Give default", res)
else:
default = smt.FreshConst(sort, prefix="unreachable")
acc = default
for c, t in reversed(cases):
if t.sort() != sort:
raise Exception("Sort mismatch in cond", t, sort)
acc = smt.If(c, t, acc)
return acc


class Cond:
def __init__(self):
self.clauses = []
self.cur_case = None
self.other = None
self.sort = None
self.cases = []
self.default = None

def when(self, c: smt.BoolRef) -> "Cond":
assert self.cur_case is None
assert isinstance(c, smt.BoolRef)
self.cur_case = c
def when(self, cond: smt.BoolRef):
self.cases.append((cond, None))
return self

def then(self, e: smt.ExprRef) -> "Cond":
assert self.cur_case is not None
if self.sort is not None:
assert e.sort() == self.sort
else:
self.sort = e.sort()
self.clauses.append((self.cur_case, e))
self.cur_case = None
def then(self, thn: smt.ExprRef):
self.cases[-1] = (self.cases[-1][0], thn)
return self

def otherwise(self, e: smt.ExprRef) -> smt.ExprRef:
assert self.default is None
assert self.sort == e.sort()
self.default = e
return self.expr()
def otherwise(self, els: smt.ExprRef):
self.default = els
return self

def expr(self) -> smt.ExprRef:
assert self.default is not None
assert self.cur_case is None
acc = self.default
for c, e in reversed(self.clauses):
acc = smt.If(c, e, acc)
return acc
return cond(*self.cases, default=self.default)
27 changes: 27 additions & 0 deletions knuckledragger/tactics.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,3 +168,30 @@ def lemma(
return kd.kernel.lemma(
thm, by, admit=admit, timeout=timeout, dump=dump, solver=solver
)


class Lemma:
# Isar style forward proof
def __init__(self, goal):
self.goal = goal
self.lemmas = []
self.vars = []
self.hyps = []

def intro(self, vars): # fix
self.vars.extend(vars)
return self

def assume(self, hyps):
self.hyps.extend(hyps)
return self

def _wrap(self, form):
return smt.ForAll(self.vars, smt.Implies(smt.And(self.hyps), form))

def have(self, conc, **kwargs):
self.lemmas.append(lemma(self._wrap(conc), **kwargs))
return self

def qed(self):
return lemma(self.goal, by=self.lemmas)
27 changes: 26 additions & 1 deletion knuckledragger/theories/Vec.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,34 @@
import knuckledragger as kd
import knuckledragger.theories.Real as R

norm2 = kd.notation.SortDispatch(name="norm2")
dot = kd.notation.SortDispatch(name="dot")

Vec2 = kd.Record("Vec2", ("x", kd.R), ("y", kd.R))
u, v = kd.smt.Consts("u v", Vec2)
kd.notation.add.define([u, v], Vec2(u.x + v.x, u.y + v.y))
kd.notation.sub.define([u, v], Vec2(u.x - v.x, u.y - v.y))

norm2 = kd.define("norm2", [u], u.x * u.x + u.y * u.y)
Vec2.vzero = Vec2(0, 0)
Vec2.dot = dot.define([u, v], u.x * v.x + u.y * v.y)
Vec2.norm2 = norm2.define([u], dot(u, u))


Vec2.norm_pos = kd.lemma(
kd.smt.ForAll([u], norm2(u) >= 0), by=[Vec2.norm2.defn, Vec2.dot.defn]
)
Vec2.norm_zero = kd.lemma(
kd.smt.ForAll([u], (norm2(u) == 0) == (u == Vec2.vzero)),
by=[Vec2.norm2.defn, Vec2.dot.defn],
)

dist = kd.define("dist", [u, v], R.sqrt(norm2(u - v)))

# Vec2.triangle = norm2(u - v) <= norm2(u) + norm2(v)

Vec3 = kd.Record("Vec3", ("x", kd.R), ("y", kd.R), ("z", kd.R))
u, v = kd.smt.Consts("u v", Vec3)
kd.notation.add.define([u, v], Vec3(u.x + v.x, u.y + v.y, u.z + v.z))
kd.notation.sub.define([u, v], Vec3(u.x - v.x, u.y - v.y, u.z - v.z))

norm2.define([u], u.x * u.x + u.y * u.y + u.z * u.z)
2 changes: 2 additions & 0 deletions knuckledragger/theories/zf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
import knuckledragger as kd
import knuckledragger.smt as smt
27 changes: 26 additions & 1 deletion tests/test_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,8 @@ def test_seq():
ThSeq.induct(smt.IntSort(), lambda x: x == x)


def test_cons():
"""
def test_cond():
c = kd.notation.Cond()
assert (
c.when(smt.BoolVal(True))
Expand All @@ -114,6 +115,30 @@ def test_cons():
)
)
)
"""


def test_cond():
x = smt.Real("x")
assert kd.cond(
(x > 0, 3 * x), (x < 0, 2 * x), (x == 0, 5 * x), default=smt.Real("undefined")
).eq(
smt.If(
x > 0,
3 * x,
smt.If(x < 0, 2 * x, smt.If(x == 0, 5 * x, smt.Real("undefined"))),
)
)
with pytest.raises(Exception) as _:
kd.cond((x < 0, 2 * x), (x > 0, 3 * x))


def test_Lemma():
x = smt.Int("x")
l = kd.tactics.Lemma(x != x + 1)
l.intro([smt.Int("x")])
l.have(x != x + 1)
l.qed()


def test_match():
Expand Down

0 comments on commit 3a4bf86

Please sign in to comment.