diff --git a/backends/lean/Base/Arith/Int.lean b/backends/lean/Base/Arith/Int.lean index b041c3fb9..b215654d1 100644 --- a/backends/lean/Base/Arith/Int.lean +++ b/backends/lean/Base/Arith/Int.lean @@ -121,9 +121,9 @@ def intTacPreprocess (extraPrePreprocess extraPreprocess : Tactic.TacticM Unit) -- Pre-preprocessing extraPrePreprocess -- Apply the forward rules - intTacSaturateForward + allGoalsNoRecover intTacSaturateForward -- Extra preprocessing - extraPreprocess + allGoalsNoRecover extraPreprocess -- Reduce all the terms in the goal - note that the extra preprocessing step -- might have proven the goal, hence the `allGoals` let dsimp := @@ -195,9 +195,10 @@ def intTac (tacName : String) (splitAllDisjs splitGoalConjs : Bool) Utils.tryTac ( -- TODO: is there a simproc to simplify propositional logic? Utils.simpAll {failIfUnchanged := false, maxSteps := 75} true [``reduceIte] [] - [``and_self, ``false_implies, ``true_implies, ``Prod.mk.injEq, ``not_false_eq_true, - ``true_and, ``and_true, ``false_and, ``and_false, ``true_or, ``or_true, - ``false_or, ``or_false] []) + [``and_self, ``false_implies, ``true_implies, ``Prod.mk.injEq, + ``not_false_eq_true, ``not_true_eq_false, + ``true_and, ``and_true, ``false_and, ``and_false, + ``true_or, ``or_true,``false_or, ``or_false] []) allGoalsNoRecover (do trace[Arith] "Goal after simplification: {← getMainGoal}" Tactic.Omega.omegaTactic {}) @@ -214,7 +215,7 @@ def intTac (tacName : String) (splitAllDisjs splitGoalConjs : Bool) simpThenOmega catch _ => let g ← Tactic.getMainGoal - throwError "{tacName} failed to prove the goal below.\n\nNote that {tacName} is equivalent to:\n {tacName}_preprocess; split_all <;> simp_all only <;> omega\n\nGoal: \n{g}" + throwError "{tacName} failed to prove the goal below.\n\nNote that {tacName} is almost equivalent to:\n {tacName}_preprocess; split_all <;> simp_all only <;> omega\n\nGoal: \n{g}" elab "int_tac" args:(" split_goal"?): tactic => let splitConjs := args.raw.getArgs.size > 0 diff --git a/backends/lean/Base/Primitives/Scalar.lean b/backends/lean/Base/Primitives/Scalar.lean index 627c6b338..6ae14027d 100644 --- a/backends/lean/Base/Primitives/Scalar.lean +++ b/backends/lean/Base/Primitives/Scalar.lean @@ -1079,4 +1079,220 @@ theorem core.num.Usize.overflowing_add_spec (x y : Usize) : simp [overflowing_add, Scalar.overflowing_add, int_overflowing_add] split <;> split <;> simp_all <;> scalar_tac +-- Saturating add +def int_saturating_add (ty : ScalarTy) (x y : Int) : Int := + let r := x + y + let r := if r > Scalar.max ty then Scalar.max ty else r + let r := if r < Scalar.min ty then Scalar.min ty else r + r + +def int_saturating_add_in_bounds {ty} (x y : Scalar ty) : + let r := int_saturating_add ty x.val y.val + Scalar.min ty ≤ r ∧ r ≤ Scalar.max ty := by + simp [int_saturating_add] + split <;> constructor <;> cases ty <;> scalar_tac + +def Scalar.saturating_add {ty} (x y : Scalar ty) : Result (Scalar ty) := + let r := int_saturating_add ty x.val y.val + have h := int_saturating_add_in_bounds x y + ok ⟨ r, h.1, h.2 ⟩ + +/- [core::num::{u8}::saturating_add] -/ +def core.num.U8.saturating_add := @Scalar.saturating_add ScalarTy.U8 + +/- [core::num::{u16}::saturating_add] -/ +def core.num.U16.saturating_add := @Scalar.saturating_add ScalarTy.U16 + +/- [core::num::{u32}::saturating_add] -/ +def core.num.U32.saturating_add := @Scalar.saturating_add ScalarTy.U32 + +/- [core::num::{u64}::saturating_add] -/ +def core.num.U64.saturating_add := @Scalar.saturating_add ScalarTy.U64 + +/- [core::num::{u128}::saturating_add] -/ +def core.num.U128.saturating_add := @Scalar.saturating_add ScalarTy.U128 + +/- [core::num::{usize}::saturating_add] -/ +def core.num.Usize.saturating_add := @Scalar.saturating_add ScalarTy.Usize + +/- [core::num::{i8}::saturating_add] -/ +def core.num.I8.saturating_add := @Scalar.saturating_add ScalarTy.I8 + +/- [core::num::{i16}::saturating_add] -/ +def core.num.I16.saturating_add := @Scalar.saturating_add ScalarTy.I16 + +/- [core::num::{i32}::saturating_add] -/ +def core.num.I32.saturating_add := @Scalar.saturating_add ScalarTy.I32 + +/- [core::num::{i64}::saturating_add] -/ +def core.num.I64.saturating_add := @Scalar.saturating_add ScalarTy.I64 + +/- [core::num::{i128}::saturating_add] -/ +def core.num.I128.saturating_add := @Scalar.saturating_add ScalarTy.I128 + +/- [core::num::{isize}::saturating_add] -/ +def core.num.Isize.saturating_add := @Scalar.saturating_add ScalarTy.Isize + +@[pspec] +theorem core.num.U8.saturating_add_spec (x y : U8) : + ∃ z, saturating_add x y = ok z ∧ + if x.val + y.val > U8.max then z.val = U8.max + else z.val = x.val + y.val + := by + simp [saturating_add, Scalar.saturating_add, int_saturating_add] + split <;> split <;> split <;> scalar_tac + +@[pspec] +theorem core.num.U16.saturating_add_spec (x y : U16) : + ∃ z, saturating_add x y = ok z ∧ + if x.val + y.val > U16.max then z.val = U16.max + else z.val = x.val + y.val + := by + simp [saturating_add, Scalar.saturating_add, int_saturating_add] + split <;> split <;> split <;> scalar_tac + +@[pspec] +theorem core.num.U32.saturating_add_spec (x y : U32) : + ∃ z, saturating_add x y = ok z ∧ + if x.val + y.val > U32.max then z.val = U32.max + else z.val = x.val + y.val + := by + simp [saturating_add, Scalar.saturating_add, int_saturating_add] + split <;> split <;> split <;> scalar_tac + +@[pspec] +theorem core.num.U64.saturating_add_spec (x y : U64) : + ∃ z, saturating_add x y = ok z ∧ + if x.val + y.val > U64.max then z.val = U64.max + else z.val = x.val + y.val + := by + simp [saturating_add, Scalar.saturating_add, int_saturating_add] + split <;> split <;> split <;> scalar_tac + +@[pspec] +theorem core.num.U128.saturating_add_spec (x y : U128) : + ∃ z, saturating_add x y = ok z ∧ + if x.val + y.val > U128.max then z.val = U128.max + else z.val = x.val + y.val + := by + simp [saturating_add, Scalar.saturating_add, int_saturating_add] + split <;> split <;> split <;> scalar_tac + +@[pspec] +theorem core.num.Usize.saturating_add_spec (x y : Usize) : + ∃ z, saturating_add x y = ok z ∧ + if x.val + y.val > Usize.max then z.val = Usize.max + else z.val = x.val + y.val + := by + simp [saturating_add, Scalar.saturating_add, int_saturating_add] + split <;> split <;> split <;> scalar_tac + +-- Saturating sub +def int_saturating_sub (ty : ScalarTy) (x y : Int) : Int := + let r := x - y + let r := if r > Scalar.max ty then Scalar.max ty else r + let r := if r < Scalar.min ty then Scalar.min ty else r + r + +def int_saturating_sub_in_bounds {ty} (x y : Scalar ty) : + let r := int_saturating_sub ty x.val y.val + Scalar.min ty ≤ r ∧ r ≤ Scalar.max ty := by + simp [int_saturating_sub] + split <;> constructor <;> cases ty <;> scalar_tac + +def Scalar.saturating_sub {ty} (x y : Scalar ty) : Result (Scalar ty) := + let r := int_saturating_sub ty x.val y.val + have h := int_saturating_sub_in_bounds x y + ok ⟨ r, h.1, h.2 ⟩ + +/- [core::num::{u8}::saturating_sub] -/ +def core.num.U8.saturating_sub := @Scalar.saturating_sub ScalarTy.U8 + +/- [core::num::{u16}::saturating_sub] -/ +def core.num.U16.saturating_sub := @Scalar.saturating_sub ScalarTy.U16 + +/- [core::num::{u32}::saturating_sub] -/ +def core.num.U32.saturating_sub := @Scalar.saturating_sub ScalarTy.U32 + +/- [core::num::{u64}::saturating_sub] -/ +def core.num.U64.saturating_sub := @Scalar.saturating_sub ScalarTy.U64 + +/- [core::num::{u128}::saturating_sub] -/ +def core.num.U128.saturating_sub := @Scalar.saturating_sub ScalarTy.U128 + +/- [core::num::{usize}::saturating_sub] -/ +def core.num.Usize.saturating_sub := @Scalar.saturating_sub ScalarTy.Usize + +/- [core::num::{i8}::saturating_sub] -/ +def core.num.I8.saturating_sub := @Scalar.saturating_sub ScalarTy.I8 + +/- [core::num::{i16}::saturating_sub] -/ +def core.num.I16.saturating_sub := @Scalar.saturating_sub ScalarTy.I16 + +/- [core::num::{i32}::saturating_sub] -/ +def core.num.I32.saturating_sub := @Scalar.saturating_sub ScalarTy.I32 + +/- [core::num::{i64}::saturating_sub] -/ +def core.num.I64.saturating_sub := @Scalar.saturating_sub ScalarTy.I64 + +/- [core::num::{i128}::saturating_sub] -/ +def core.num.I128.saturating_sub := @Scalar.saturating_sub ScalarTy.I128 + +/- [core::num::{isize}::saturating_sub] -/ +def core.num.Isize.saturating_sub := @Scalar.saturating_sub ScalarTy.Isize + +@[pspec] +theorem core.num.U8.saturating_sub_spec (x y : U8) : + ∃ z, saturating_sub x y = ok z ∧ + if x.val - y.val < 0 then z.val = 0 + else z.val = x.val - y.val + := by + simp [saturating_sub, Scalar.saturating_sub, int_saturating_sub] + split <;> split <;> split <;> scalar_tac + +@[pspec] +theorem core.num.U16.saturating_sub_spec (x y : U16) : + ∃ z, saturating_sub x y = ok z ∧ + if x.val - y.val < 0 then z.val = 0 + else z.val = x.val - y.val + := by + simp [saturating_sub, Scalar.saturating_sub, int_saturating_sub] + split <;> split <;> split <;> scalar_tac + +@[pspec] +theorem core.num.U32.saturating_sub_spec (x y : U32) : + ∃ z, saturating_sub x y = ok z ∧ + if x.val - y.val < 0 then z.val = 0 + else z.val = x.val - y.val + := by + simp [saturating_sub, Scalar.saturating_sub, int_saturating_sub] + split <;> split <;> split <;> scalar_tac + +@[pspec] +theorem core.num.U64.saturating_sub_spec (x y : U64) : + ∃ z, saturating_sub x y = ok z ∧ + if x.val - y.val < 0 then z.val = 0 + else z.val = x.val - y.val + := by + simp [saturating_sub, Scalar.saturating_sub, int_saturating_sub] + split <;> split <;> split <;> scalar_tac + +@[pspec] +theorem core.num.U128.saturating_sub_spec (x y : U128) : + ∃ z, saturating_sub x y = ok z ∧ + if x.val - y.val < 0 then z.val = 0 + else z.val = x.val - y.val + := by + simp [saturating_sub, Scalar.saturating_sub, int_saturating_sub] + split <;> split <;> split <;> scalar_tac + +@[pspec] +theorem core.num.Usize.saturating_sub_spec (x y : Usize) : + ∃ z, saturating_sub x y = ok z ∧ + if x.val - y.val < 0 then z.val = 0 + else z.val = x.val - y.val + := by + simp [saturating_sub, Scalar.saturating_sub, int_saturating_sub] + split <;> split <;> split <;> scalar_tac + end Primitives diff --git a/compiler/ExtractBuiltin.ml b/compiler/ExtractBuiltin.ml index 1250e27de..85879d3f7 100644 --- a/compiler/ExtractBuiltin.ml +++ b/compiler/ExtractBuiltin.ml @@ -431,7 +431,14 @@ let mk_builtin_funs () : (pattern * bool list option * builtin_fun_info) list = mk_fun ("core::num::" ^ "{" ^ int_name ^ "}::" ^ op) ~can_fail:false ()) - [ "wrapping_add"; "wrapping_sub"; "rotate_left"; "rotate_right" ]) + [ + "saturating_add"; + "saturating_sub"; + "wrapping_add"; + "wrapping_sub"; + "rotate_left"; + "rotate_right"; + ]) all_int_names) @ List.flatten (List.map