Skip to content

Commit

Permalink
filling out reals. no redefinition of records.
Browse files Browse the repository at this point in the history
  • Loading branch information
philzook58 committed Aug 30, 2024
1 parent 84b3571 commit efa0b33
Show file tree
Hide file tree
Showing 4 changed files with 98 additions and 36 deletions.
5 changes: 5 additions & 0 deletions kdrag/notation.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,11 +152,15 @@ def datatype_call(self, *args):

smt.DatatypeSortRef.__call__ = datatype_call

records = {}


def Record(name, *fields, pred=None):
"""
Define a record datatype
"""
if name in records:
raise Exception("Record already defined", name)
rec = smt.Datatype(name)
rec.declare(name, *fields)
rec = rec.create()
Expand All @@ -176,6 +180,7 @@ def Record(name, *fields, pred=None):
rec,
lambda x: smt.And(pred(x), *[rec.accessor(0, n)(x).wf() for n in wf_cond]),
)
records[name] = rec

return rec

Expand Down
6 changes: 3 additions & 3 deletions kdrag/theories/int.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,6 @@ def induct_nat(P):
)


Nat = kd.Record("Nat", ("val", Z))
n, m, k = smt.Consts("n m k", Nat)
kd.notation.wf.register(Nat, lambda x: x.val >= 0)
NatI = kd.Record("NatI", ("val", Z))
n, m, k = smt.Consts("n m k", NatI)
kd.notation.wf.register(NatI, lambda x: x.val >= 0)
113 changes: 84 additions & 29 deletions kdrag/theories/real.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,12 @@
kd.notation.mul.define([f, g], smt.Lambda([x], f[x] * g[x]))
kd.notation.div.define([f, g], smt.Lambda([x], f[x] / g[x]))

add = kd.define("add", [x, y], x + y)

plus = Function("plus", R, R, R)
plus_def = axiom(ForAll([x, y], plus(x, y) == x + y), "definition")

plus_0 = lemma(ForAll([x], plus(x, 0) == x), by=[plus_def])
plus_comm = lemma(ForAll([x, y], plus(x, y) == plus(y, x)), by=[plus_def])
plus_assoc = lemma(
smt.ForAll([x, y, z], plus(x, plus(y, z)) == plus(plus(x, y), z)), by=[plus_def]
add_0 = lemma(ForAll([x], add(x, 0) == x), by=[add.defn])
add_comm = lemma(ForAll([x, y], add(x, y) == add(y, x)), by=[add.defn])
add_assoc = lemma(
smt.ForAll([x, y, z], add(x, add(y, z)) == add(add(x, y), z)), by=[add.defn]
)

mul = Function("mul", R, R, R)
Expand All @@ -35,24 +33,33 @@
ForAll([x, y, z], mul(x, mul(y, z)) == mul(mul(x, y), z)), by=[mul_def], admit=True
)
mul_distrib = lemma(
ForAll([x, y, z], mul(x, plus(y, z)) == plus(mul(x, y), mul(x, z))),
by=[mul_def, plus_def],
ForAll([x, y, z], mul(x, add(y, z)) == add(mul(x, y), mul(x, z))),
by=[mul_def, add.defn],
)


abs = kd.define("absR", [x], smt.If(x >= 0, x, -x))
sgn = kd.define("sgn", [x], smt.If(x > 0, 1, smt.If(x < 0, -1, 0)))

sgn_abs = kd.lemma(ForAll([x], x == abs(x) * sgn(x)), by=[abs.defn, sgn.defn])

abs_le = kd.lemma(
ForAll([x, y], (abs(x) <= y) == smt.And(-y <= x, x <= y)), by=[abs.defn]
)
abs_idem = kd.lemma(ForAll([x], abs(abs(x)) == abs(x)), by=[abs.defn])
abs_neg = kd.lemma(ForAll([x], abs(-x) == abs(x)), by=[abs.defn])
abs_ge_0 = kd.lemma(ForAll([x], abs(x) >= 0), by=[abs.defn])
abs_pos = kd.lemma(ForAll([x], abs(x) >= 0), by=[abs.defn])
abs_add = kd.lemma(ForAll([x, y], abs(x + y) <= abs(x) + abs(y)), by=[abs.defn])
abs_mul = kd.lemma(ForAll([x, y], abs(x * y) == abs(x) * abs(y)), by=[abs.defn])
abs_triangle = kd.lemma(
ForAll([x, y, z], abs(x - y) <= abs(x - z) + abs(z - y)), by=[abs.defn]
)


nonneg = kd.define("nonneg", [x], abs(x) == x)
nonneg_ge_0 = kd.lemma(ForAll([x], nonneg(x) == (x >= 0)), by=[nonneg.defn, abs.defn])

max = kd.define("max", [x, y], smt.If(x >= y, x, y))
max_le = kd.lemma(ForAll([x, y], (x <= y) == (max(x, y) == y)), by=[max.defn])
max_comm = kd.lemma(ForAll([x, y], max(x, y) == max(y, x)), by=[max.defn])
max_assoc = kd.lemma(
ForAll([x, y, z], max(x, max(y, z)) == max(max(x, y), z)), by=[max.defn]
Expand All @@ -62,13 +69,46 @@
max_ge_2 = kd.lemma(ForAll([x, y], max(x, y) >= y), by=[max.defn])

min = kd.define("min", [x, y], smt.If(x <= y, x, y))
min_le = kd.lemma(ForAll([x, y], (x <= y) == (min(x, y) == x)), by=[min.defn])
min_comm = kd.lemma(ForAll([x, y], min(x, y) == min(y, x)), by=[min.defn])
min_assoc = kd.lemma(
ForAll([x, y, z], min(x, min(y, z)) == min(min(x, y), z)), by=[min.defn]
)
min_idem = kd.lemma(ForAll([x], min(x, x) == x), by=[min.defn])
min_le = kd.lemma(ForAll([x, y], min(x, y) <= x), by=[min.defn])
min_le_2 = kd.lemma(ForAll([x, y], min(x, y) <= y), by=[min.defn])
min_le_2 = kd.lemma(ForAll([x, y], min(x, y) <= x), by=[min.defn])
min_le_3 = kd.lemma(ForAll([x, y], min(x, y) <= y), by=[min.defn])


n, m = smt.Ints("n m")
to_real_add = kd.lemma(
ForAll([n, m], smt.ToReal(n + m) == smt.ToReal(n) + smt.ToReal(m))
)
to_real_sub = kd.lemma(
ForAll([n, m], smt.ToReal(n - m) == smt.ToReal(n) - smt.ToReal(m))
)
to_real_mul = kd.lemma(
ForAll([n, m], smt.ToReal(n * m) == smt.ToReal(n) * smt.ToReal(m))
)
to_real_neg = kd.lemma(ForAll([n], smt.ToReal(-n) == -smt.ToReal(n)))
to_real_inj = kd.lemma(ForAll([n, m], (smt.ToReal(n) == smt.ToReal(m)) == (n == m)))
to_real_mono = kd.lemma(ForAll([n, m], (n < m) == (smt.ToReal(n) < smt.ToReal(m))))

floor = kd.define("floor", [x], smt.ToReal(smt.ToInt(x)))
n = smt.Int("n")
int_real_galois_lt = kd.lemma(ForAll([x, n], (x < smt.ToReal(n)) == (smt.ToInt(x) < n)))
int_real_galois_le = kd.lemma(
ForAll([x, n], (smt.ToReal(n) <= x) == (n <= smt.ToInt(x)))
)

_2 = kd.lemma(ForAll([x], smt.ToInt(floor(x)) == smt.ToInt(x)), by=[floor.defn])
floor_idem = kd.lemma(ForAll([x], floor(floor(x)) == floor(x)), by=[floor.defn, _2])
floor_le = kd.lemma(ForAll([x], floor(x) <= x), by=[floor.defn])
floor_gt = kd.lemma(ForAll([x], x < floor(x) + 1), by=[floor.defn])

c = kd.Calc([n, x], smt.ToReal(n) <= x)
c.eq(n <= smt.ToInt(x))
c.eq(smt.ToReal(n) <= floor(x), by=[floor.defn])
floor_minint = c.qed(defns=False)


pow = kd.define("pow", [x, y], x**y)
Expand All @@ -80,15 +120,25 @@
kd.kernel.lemma(smt.Implies(x > 0, x**0 == 1))
# pow_pos = kd.lemma(kd.QForAll([x, y], x > 0, pow(x, y) > 0), by=[pow.defn])

# Inverses
sqrt = kd.define("sqrt", [x], pow(x, 0.5))
sqr = kd.define("sqr", [x], x * x)


sqrt = kd.define("sqrt", [x], x**0.5)
_1 = kd.kernel.lemma(smt.Implies(x >= 0, x**0.5 >= 0))
sqrt_pos = kd.lemma(kd.QForAll([x], x >= 0, sqrt(x) >= 0), by=[_1], admit=True)
sqrt_define = kd.lemma(smt.ForAll([x], sqrt(x) == x**0.5), by=[sqrt.defn, pow.defn])
_1 = kd.kernel.lemma(smt.Implies(x >= 0, (x**0.5) ** 2 == x)) # forall messes it up?
sqrt_square = kd.lemma(
kd.QForAll([x], x >= 0, sqrt(x) ** 2 == x),
by=[sqrt_define, sqrt.defn, _1],
admit=True,
)
sqr_sqrt = kd.lemma(
kd.QForAll([x], x >= 0, sqr(sqrt(x)) == x), by=[sqrt_square, sqr.defn]
)
_1 = kd.kernel.lemma(smt.Implies(x >= 0, (x**2) ** 0.5 == x))
sqrt_sqr = kd.lemma(kd.QForAll([x], x >= 0, sqrt(sqr(x)) == x), by=[_1], admit=True)


exp = smt.Function("exp", kd.R, kd.R)
exp_add = kd.axiom(smt.ForAll([x, y], exp(x + y) == exp(x) * exp(y)))
Expand All @@ -104,22 +154,27 @@
exp_inv = kd.lemma(smt.ForAll([x], exp(-x) == 1 / exp(x)), by=[exp_div])


floor = kd.define("floor", [x], smt.ToReal(smt.ToInt(x)))
n = smt.Int("n")
int_real_galois_lt = kd.lemma(ForAll([x, n], (x < smt.ToReal(n)) == (smt.ToInt(x) < n)))
int_real_galois_le = kd.lemma(
ForAll([x, n], (smt.ToReal(n) <= x) == (n <= smt.ToInt(x)))
)
cos = smt.Function("cos", kd.R, kd.R)
sin = smt.Function("sin", kd.R, kd.R)

_2 = kd.lemma(ForAll([x], smt.ToInt(floor(x)) == smt.ToInt(x)), by=[floor.defn])
floor_idem = kd.lemma(ForAll([x], floor(floor(x)) == floor(x)), by=[floor.defn, _2])
floor_le = kd.lemma(ForAll([x], floor(x) <= x), by=[floor.defn])
floor_gt = kd.lemma(ForAll([x], x < floor(x) + 1), by=[floor.defn])
pythag = kd.axiom(smt.ForAll([x], cos(x) ** 2 + sin(x) ** 2 == 1))
cos_abs_le = kd.lemma(smt.ForAll([x], abs(cos(x)) <= 1), by=[pythag, abs.defn])
sin_abs_le = kd.lemma(smt.ForAll([x], abs(sin(x)) <= 1), by=[pythag, abs.defn])

c = kd.Calc([n, x], smt.ToReal(n) <= x)
c.eq(n <= smt.ToInt(x))
c.eq(smt.ToReal(n) <= floor(x), by=[floor.defn])
floor_minint = c.qed(defns=False)
cos_0 = kd.axiom(cos(0) == 1)
sin_0 = kd.lemma(sin(0) == 0, by=[pythag, cos_0])

pi = smt.Const("pi", kd.R)
pi_bnd = kd.axiom(smt.And(3.14159 < pi, pi < 3.14160))

cos_pi = kd.axiom(cos(pi) == -1)
sin_pi = kd.lemma(sin(pi) == 0, by=[pythag, cos_pi])

cos_neg = kd.axiom(smt.ForAll([x], cos(-x) == cos(x)))
sin_neg = kd.axiom(smt.ForAll([x], sin(-x) == -sin(x)))

cos_add = kd.axiom(smt.ForAll([x, y], cos(x + y) == cos(x) * cos(y) - sin(x) * sin(y)))
sin_add = kd.axiom(smt.ForAll([x, y], sin(x + y) == sin(x) * cos(y) + cos(x) * sin(y)))


EReal = smt.Datatype("EReal")
Expand Down
10 changes: 6 additions & 4 deletions tests/test_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@

import kdrag as kd
import kdrag.theories.nat

import kdrag.theories.int

import kdrag.theories.real as R

if smt.solver != smt.VAMPIRESOLVER:
Expand Down Expand Up @@ -68,10 +70,10 @@ def test_qforall():
assert kd.QForAll([x, y], x > 0, y > 0).eq(
smt.ForAll([x, y], smt.Implies(x > 0, y > 0))
)
Nat = kd.theories.int.Nat
x = smt.Const("x", Nat)
assert kd.QForAll([x], x == Nat.mk(14)).eq(
smt.ForAll([x], smt.Implies(x.wf(), x == Nat.mk(14)))
NatI = kd.theories.int.NatI
x = smt.Const("x", NatI)
assert kd.QForAll([x], x == NatI.mk(14)).eq(
smt.ForAll([x], smt.Implies(x.wf(), x == NatI.mk(14)))
)


Expand Down

0 comments on commit efa0b33

Please sign in to comment.