Skip to content

Commit

Permalink
Mark saturating_add/sub as a builtin, and add Lean models to Primitiv…
Browse files Browse the repository at this point in the history
…es (#348)

* Add saturated_add/sub to list of builtins

* Add Lean model for saturating_add/sub

* With @sonmarcho, make scalar_tac more resilient when goals are solved by preprocessing
  • Loading branch information
R1kM authored Nov 7, 2024
1 parent 6377992 commit e838777
Show file tree
Hide file tree
Showing 3 changed files with 231 additions and 7 deletions.
13 changes: 7 additions & 6 deletions backends/lean/Base/Arith/Int.lean
Original file line number Diff line number Diff line change
Expand Up @@ -121,9 +121,9 @@ def intTacPreprocess (extraPrePreprocess extraPreprocess : Tactic.TacticM Unit)
-- Pre-preprocessing
extraPrePreprocess
-- Apply the forward rules
intTacSaturateForward
allGoalsNoRecover intTacSaturateForward
-- Extra preprocessing
extraPreprocess
allGoalsNoRecover extraPreprocess
-- Reduce all the terms in the goal - note that the extra preprocessing step
-- might have proven the goal, hence the `allGoals`
let dsimp :=
Expand Down Expand Up @@ -195,9 +195,10 @@ def intTac (tacName : String) (splitAllDisjs splitGoalConjs : Bool)
Utils.tryTac (
-- TODO: is there a simproc to simplify propositional logic?
Utils.simpAll {failIfUnchanged := false, maxSteps := 75} true [``reduceIte] []
[``and_self, ``false_implies, ``true_implies, ``Prod.mk.injEq, ``not_false_eq_true,
``true_and, ``and_true, ``false_and, ``and_false, ``true_or, ``or_true,
``false_or, ``or_false] [])
[``and_self, ``false_implies, ``true_implies, ``Prod.mk.injEq,
``not_false_eq_true, ``not_true_eq_false,
``true_and, ``and_true, ``false_and, ``and_false,
``true_or, ``or_true,``false_or, ``or_false] [])
allGoalsNoRecover (do
trace[Arith] "Goal after simplification: {← getMainGoal}"
Tactic.Omega.omegaTactic {})
Expand All @@ -214,7 +215,7 @@ def intTac (tacName : String) (splitAllDisjs splitGoalConjs : Bool)
simpThenOmega
catch _ =>
let g ← Tactic.getMainGoal
throwError "{tacName} failed to prove the goal below.\n\nNote that {tacName} is equivalent to:\n {tacName}_preprocess; split_all <;> simp_all only <;> omega\n\nGoal: \n{g}"
throwError "{tacName} failed to prove the goal below.\n\nNote that {tacName} is almost equivalent to:\n {tacName}_preprocess; split_all <;> simp_all only <;> omega\n\nGoal: \n{g}"

elab "int_tac" args:(" split_goal"?): tactic =>
let splitConjs := args.raw.getArgs.size > 0
Expand Down
216 changes: 216 additions & 0 deletions backends/lean/Base/Primitives/Scalar.lean
Original file line number Diff line number Diff line change
Expand Up @@ -1079,4 +1079,220 @@ theorem core.num.Usize.overflowing_add_spec (x y : Usize) :
simp [overflowing_add, Scalar.overflowing_add, int_overflowing_add]
split <;> split <;> simp_all <;> scalar_tac

-- Saturating add
def int_saturating_add (ty : ScalarTy) (x y : Int) : Int :=
let r := x + y
let r := if r > Scalar.max ty then Scalar.max ty else r
let r := if r < Scalar.min ty then Scalar.min ty else r
r

def int_saturating_add_in_bounds {ty} (x y : Scalar ty) :
let r := int_saturating_add ty x.val y.val
Scalar.min ty ≤ r ∧ r ≤ Scalar.max ty := by
simp [int_saturating_add]
split <;> constructor <;> cases ty <;> scalar_tac

def Scalar.saturating_add {ty} (x y : Scalar ty) : Result (Scalar ty) :=
let r := int_saturating_add ty x.val y.val
have h := int_saturating_add_in_bounds x y
ok ⟨ r, h.1, h.2

/- [core::num::{u8}::saturating_add] -/
def core.num.U8.saturating_add := @Scalar.saturating_add ScalarTy.U8

/- [core::num::{u16}::saturating_add] -/
def core.num.U16.saturating_add := @Scalar.saturating_add ScalarTy.U16

/- [core::num::{u32}::saturating_add] -/
def core.num.U32.saturating_add := @Scalar.saturating_add ScalarTy.U32

/- [core::num::{u64}::saturating_add] -/
def core.num.U64.saturating_add := @Scalar.saturating_add ScalarTy.U64

/- [core::num::{u128}::saturating_add] -/
def core.num.U128.saturating_add := @Scalar.saturating_add ScalarTy.U128

/- [core::num::{usize}::saturating_add] -/
def core.num.Usize.saturating_add := @Scalar.saturating_add ScalarTy.Usize

/- [core::num::{i8}::saturating_add] -/
def core.num.I8.saturating_add := @Scalar.saturating_add ScalarTy.I8

/- [core::num::{i16}::saturating_add] -/
def core.num.I16.saturating_add := @Scalar.saturating_add ScalarTy.I16

/- [core::num::{i32}::saturating_add] -/
def core.num.I32.saturating_add := @Scalar.saturating_add ScalarTy.I32

/- [core::num::{i64}::saturating_add] -/
def core.num.I64.saturating_add := @Scalar.saturating_add ScalarTy.I64

/- [core::num::{i128}::saturating_add] -/
def core.num.I128.saturating_add := @Scalar.saturating_add ScalarTy.I128

/- [core::num::{isize}::saturating_add] -/
def core.num.Isize.saturating_add := @Scalar.saturating_add ScalarTy.Isize

@[pspec]
theorem core.num.U8.saturating_add_spec (x y : U8) :
∃ z, saturating_add x y = ok z ∧
if x.val + y.val > U8.max then z.val = U8.max
else z.val = x.val + y.val
:= by
simp [saturating_add, Scalar.saturating_add, int_saturating_add]
split <;> split <;> split <;> scalar_tac

@[pspec]
theorem core.num.U16.saturating_add_spec (x y : U16) :
∃ z, saturating_add x y = ok z ∧
if x.val + y.val > U16.max then z.val = U16.max
else z.val = x.val + y.val
:= by
simp [saturating_add, Scalar.saturating_add, int_saturating_add]
split <;> split <;> split <;> scalar_tac

@[pspec]
theorem core.num.U32.saturating_add_spec (x y : U32) :
∃ z, saturating_add x y = ok z ∧
if x.val + y.val > U32.max then z.val = U32.max
else z.val = x.val + y.val
:= by
simp [saturating_add, Scalar.saturating_add, int_saturating_add]
split <;> split <;> split <;> scalar_tac

@[pspec]
theorem core.num.U64.saturating_add_spec (x y : U64) :
∃ z, saturating_add x y = ok z ∧
if x.val + y.val > U64.max then z.val = U64.max
else z.val = x.val + y.val
:= by
simp [saturating_add, Scalar.saturating_add, int_saturating_add]
split <;> split <;> split <;> scalar_tac

@[pspec]
theorem core.num.U128.saturating_add_spec (x y : U128) :
∃ z, saturating_add x y = ok z ∧
if x.val + y.val > U128.max then z.val = U128.max
else z.val = x.val + y.val
:= by
simp [saturating_add, Scalar.saturating_add, int_saturating_add]
split <;> split <;> split <;> scalar_tac

@[pspec]
theorem core.num.Usize.saturating_add_spec (x y : Usize) :
∃ z, saturating_add x y = ok z ∧
if x.val + y.val > Usize.max then z.val = Usize.max
else z.val = x.val + y.val
:= by
simp [saturating_add, Scalar.saturating_add, int_saturating_add]
split <;> split <;> split <;> scalar_tac

-- Saturating sub
def int_saturating_sub (ty : ScalarTy) (x y : Int) : Int :=
let r := x - y
let r := if r > Scalar.max ty then Scalar.max ty else r
let r := if r < Scalar.min ty then Scalar.min ty else r
r

def int_saturating_sub_in_bounds {ty} (x y : Scalar ty) :
let r := int_saturating_sub ty x.val y.val
Scalar.min ty ≤ r ∧ r ≤ Scalar.max ty := by
simp [int_saturating_sub]
split <;> constructor <;> cases ty <;> scalar_tac

def Scalar.saturating_sub {ty} (x y : Scalar ty) : Result (Scalar ty) :=
let r := int_saturating_sub ty x.val y.val
have h := int_saturating_sub_in_bounds x y
ok ⟨ r, h.1, h.2

/- [core::num::{u8}::saturating_sub] -/
def core.num.U8.saturating_sub := @Scalar.saturating_sub ScalarTy.U8

/- [core::num::{u16}::saturating_sub] -/
def core.num.U16.saturating_sub := @Scalar.saturating_sub ScalarTy.U16

/- [core::num::{u32}::saturating_sub] -/
def core.num.U32.saturating_sub := @Scalar.saturating_sub ScalarTy.U32

/- [core::num::{u64}::saturating_sub] -/
def core.num.U64.saturating_sub := @Scalar.saturating_sub ScalarTy.U64

/- [core::num::{u128}::saturating_sub] -/
def core.num.U128.saturating_sub := @Scalar.saturating_sub ScalarTy.U128

/- [core::num::{usize}::saturating_sub] -/
def core.num.Usize.saturating_sub := @Scalar.saturating_sub ScalarTy.Usize

/- [core::num::{i8}::saturating_sub] -/
def core.num.I8.saturating_sub := @Scalar.saturating_sub ScalarTy.I8

/- [core::num::{i16}::saturating_sub] -/
def core.num.I16.saturating_sub := @Scalar.saturating_sub ScalarTy.I16

/- [core::num::{i32}::saturating_sub] -/
def core.num.I32.saturating_sub := @Scalar.saturating_sub ScalarTy.I32

/- [core::num::{i64}::saturating_sub] -/
def core.num.I64.saturating_sub := @Scalar.saturating_sub ScalarTy.I64

/- [core::num::{i128}::saturating_sub] -/
def core.num.I128.saturating_sub := @Scalar.saturating_sub ScalarTy.I128

/- [core::num::{isize}::saturating_sub] -/
def core.num.Isize.saturating_sub := @Scalar.saturating_sub ScalarTy.Isize

@[pspec]
theorem core.num.U8.saturating_sub_spec (x y : U8) :
∃ z, saturating_sub x y = ok z ∧
if x.val - y.val < 0 then z.val = 0
else z.val = x.val - y.val
:= by
simp [saturating_sub, Scalar.saturating_sub, int_saturating_sub]
split <;> split <;> split <;> scalar_tac

@[pspec]
theorem core.num.U16.saturating_sub_spec (x y : U16) :
∃ z, saturating_sub x y = ok z ∧
if x.val - y.val < 0 then z.val = 0
else z.val = x.val - y.val
:= by
simp [saturating_sub, Scalar.saturating_sub, int_saturating_sub]
split <;> split <;> split <;> scalar_tac

@[pspec]
theorem core.num.U32.saturating_sub_spec (x y : U32) :
∃ z, saturating_sub x y = ok z ∧
if x.val - y.val < 0 then z.val = 0
else z.val = x.val - y.val
:= by
simp [saturating_sub, Scalar.saturating_sub, int_saturating_sub]
split <;> split <;> split <;> scalar_tac

@[pspec]
theorem core.num.U64.saturating_sub_spec (x y : U64) :
∃ z, saturating_sub x y = ok z ∧
if x.val - y.val < 0 then z.val = 0
else z.val = x.val - y.val
:= by
simp [saturating_sub, Scalar.saturating_sub, int_saturating_sub]
split <;> split <;> split <;> scalar_tac

@[pspec]
theorem core.num.U128.saturating_sub_spec (x y : U128) :
∃ z, saturating_sub x y = ok z ∧
if x.val - y.val < 0 then z.val = 0
else z.val = x.val - y.val
:= by
simp [saturating_sub, Scalar.saturating_sub, int_saturating_sub]
split <;> split <;> split <;> scalar_tac

@[pspec]
theorem core.num.Usize.saturating_sub_spec (x y : Usize) :
∃ z, saturating_sub x y = ok z ∧
if x.val - y.val < 0 then z.val = 0
else z.val = x.val - y.val
:= by
simp [saturating_sub, Scalar.saturating_sub, int_saturating_sub]
split <;> split <;> split <;> scalar_tac

end Primitives
9 changes: 8 additions & 1 deletion compiler/ExtractBuiltin.ml
Original file line number Diff line number Diff line change
Expand Up @@ -431,7 +431,14 @@ let mk_builtin_funs () : (pattern * bool list option * builtin_fun_info) list =
mk_fun
("core::num::" ^ "{" ^ int_name ^ "}::" ^ op)
~can_fail:false ())
[ "wrapping_add"; "wrapping_sub"; "rotate_left"; "rotate_right" ])
[
"saturating_add";
"saturating_sub";
"wrapping_add";
"wrapping_sub";
"rotate_left";
"rotate_right";
])
all_int_names)
@ List.flatten
(List.map
Expand Down

0 comments on commit e838777

Please sign in to comment.