Skip to content

Commit

Permalink
feat: toInt_abs
Browse files Browse the repository at this point in the history
We implement `toInt_abs`.
A subtle wrinkle is to note that `abs (intMin w) = intMin w`,
which complicates our proof.
  • Loading branch information
bollu committed Oct 21, 2024
1 parent b814be6 commit 981216e
Show file tree
Hide file tree
Showing 3 changed files with 172 additions and 11 deletions.
152 changes: 141 additions & 11 deletions src/Init/Data/BitVec/Lemmas.lean
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,7 @@ theorem eq_of_getMsbD_eq {x y : BitVec w}
theorem of_length_zero {x : BitVec 0} : x = 0#0 := by ext; simp

theorem toNat_zero_length (x : BitVec 0) : x.toNat = 0 := by simp [of_length_zero]
theorem toInt_zero_length (x : BitVec 0) : x.toInt = 0 := by simp [of_length_zero]
theorem getLsbD_zero_length (x : BitVec 0) : x.getLsbD i = false := by simp
theorem getMsbD_zero_length (x : BitVec 0) : x.getMsbD i = false := by simp
theorem msb_zero_length (x : BitVec 0) : x.msb = false := by simp [BitVec.msb, of_length_zero]
Expand Down Expand Up @@ -353,7 +354,16 @@ theorem msb_eq_getLsbD_last (x : BitVec w) :
@[bv_toNat] theorem getLsbD_succ_last (x : BitVec (w + 1)) :
x.getLsbD w = decide (2 ^ w ≤ x.toNat) := getLsbD_last x

@[bv_toNat] theorem msb_eq_decide (x : BitVec w) : BitVec.msb x = decide (2 ^ (w-1) ≤ x.toNat) := by

/-- An alternative to `msb_eq_decide`-/
@[bv_toNat] theorem msb_eq_decide_le_mul_two (x : BitVec w) :
BitVec.msb x = decide (2 ^ w ≤ 2 * x.toNat) := by
rw [x.msb_eq_getLsbD_last, x.getLsbD_last]
simp
rcases w with rfl | w <;> simp <;> omega

@[bv_toNat, deprecated msb_eq_decide_le_mul_two (since := "21-10-2024") ]
theorem msb_eq_decide (x : BitVec w) : BitVec.msb x = decide (2 ^ (w-1) ≤ x.toNat) := by
simp [msb_eq_getLsbD_last, getLsbD_last]

theorem toNat_ge_of_msb_true {x : BitVec n} (p : BitVec.msb x = true) : x.toNat ≥ 2^(n-1) := by
Expand Down Expand Up @@ -463,6 +473,38 @@ theorem toInt_pos_iff {w : Nat} {x : BitVec w} :
0 ≤ BitVec.toInt x ↔ 2 * x.toNat < 2 ^ w := by
simp [toInt_eq_toNat_cond]; omega

/-
If `x.msb` is false, then the value of `x` when interpreted as a 2s complement
integer is between `[0..2^n/2)`.
To avoid the corner case at `n = 0`, we phrase the bounds as `2 * x < 2^n` instead of `x < 2^(n-1)`.
-/
theorem toInt_bounds_of_msb_eq_false {x : BitVec n} (hmsb : x.msb = false) :
0 ≤ x.toInt ∧ 2 * x.toInt < 2^n := by
have := x.msb_eq_decide_le_mul_two
rw [hmsb] at this
simp only [false_eq_decide_iff, Nat.not_le] at this
rw [BitVec.toInt_eq_toNat_cond]
simp [this]
apply And.intro
· omega
· norm_cast

/-
If `x.msb` is true, then the value of `x` when interpreted as a 2s complement
integer is between `[-2^n..0).
-/
theorem toInt_bounds_of_msb_eq_true {x : BitVec n} (hmsb : x.msb = true) :
-2^n ≤ x.toInt ∧ x.toInt < 0 := by
have := x.msb_eq_decide_le_mul_two
rw [hmsb] at this
simp only [true_eq_decide_iff] at this
rw [BitVec.toInt_eq_toNat_cond]
simp [show ¬ 2 * x.toNat < 2 ^ n by omega]
apply And.intro
· norm_cast
omega
· omega

theorem eq_zero_or_eq_one (a : BitVec 1) : a = 0#1 ∨ a = 1#1 := by
obtain ⟨a, ha⟩ := a
simp only [Nat.reducePow]
Expand Down Expand Up @@ -2070,16 +2112,6 @@ theorem smod_zero {x : BitVec n} : x.smod 0#n = x := by
· simp
· by_cases h : x = 0#n <;> simp [h]

/-! ### abs -/

@[simp, bv_toNat]
theorem toNat_abs {x : BitVec w} : x.abs.toNat = if x.msb then 2^w - x.toNat else x.toNat := by
simp only [BitVec.abs, neg_eq]
by_cases h : x.msb = true
· simp only [h, ↓reduceIte, toNat_neg]
have : 2 * x.toNat ≥ 2 ^ w := BitVec.msb_eq_true_iff_two_mul_ge.mp h
rw [Nat.mod_eq_of_lt (by omega)]
· simp [h]

/-! ### mul -/

Expand Down Expand Up @@ -2643,6 +2675,46 @@ theorem toInt_neg_of_ne_intMin {x : BitVec w} (rs : x ≠ intMin w) :
have := @Nat.two_pow_pred_mul_two w (by omega)
split <;> split <;> omega


/-- The msb of `intMin w` is `true` for all `w > 0` -/
@[simp]
theorem msb_intMin : (intMin w).msb = decide (w > 0) := by
rw [intMin]
rw [msb_eq_decide]
simp
rcases w with rfl | w
· rfl
· simp
have : 0 < 2^w := Nat.pow_pos (by decide)
have : 2^w < 2^(w + 1) := by
rw [Nat.pow_succ]
omega
rw [Nat.mod_eq_of_lt (by omega)]
simp

/--
If the width is zero, then `intMin` is `0`,
and otherwise it is `-2^(n - 1)`.
-/
theorem toInt_intMin_eq_if (n : Nat) : (BitVec.intMin n).toInt =
if n = 0 then 0 else - 2^(n - 1) := by
simp [BitVec.toInt_intMin]
rcases n with rfl | n
· simp
· simp
norm_cast
have : 2^n > 0 := by exact Nat.two_pow_pos n
have : 2^n < 2^(n + 1) := by
simp [Nat.pow_succ]
omega
rw [Nat.mod_eq_of_lt (by omega)]

theorem toInt_intMin_eq_twoPow (hn : 0 < n) : (intMin n).toInt = -2^(n - 1) := by
rw [BitVec.toInt_intMin_eq_if]
simp [show ¬ n = 0 by omega]



/-! ### intMax -/

/-- The bitvector of width `w` that has the largest value when interpreted as an integer. -/
Expand Down Expand Up @@ -2674,6 +2746,64 @@ theorem getLsbD_intMax (w : Nat) : (intMax w).getLsbD i = decide (i + 1 < w) :=
· rw [Nat.sub_add_cancel (Nat.two_pow_pos (w - 1)), Nat.two_pow_pred_mod_two_pow (by omega)]


/-! ### abs -/

theorem abs_def {x : BitVec w} : x.abs = if x.msb then .neg x else x := rfl

theorem toInt_neg_eq_cases {x : BitVec n} :
(-x).toInt =
if x = intMin n
then x.toInt
else - x.toInt := by
by_cases hx : x = intMin n
· simp [hx]
· simp [hx]
rw [toInt_neg_of_ne_intMin hx]

theorem abs_eq_if (x : BitVec w) : x.abs =
if x.msb = true then
if x = BitVec.intMin w then (BitVec.intMin w) else -x
else x := by
· rw [BitVec.abs_def]
by_cases hx : x.msb = true <;> by_cases hx' : x = BitVec.intMin w <;> simp [hx, hx']


/-- info: 'BitVec.toInt_intMin_eq_twoPow' depends on axioms: [propext, Quot.sound] -/
#guard_msgs in #print axioms BitVec.toInt_intMin_eq_twoPow

theorem toInt_abs (x : BitVec w) :
x.abs.toInt = if x = (intMin w) then if w = 0 then 0 else - 2^(w - 1) else x.toInt.abs := by
rcases w with rfl | w
· simp [toInt_zero_length]
· simp only [gt_iff_lt, Nat.zero_lt_succ, Nat.add_one_ne_zero, ↓reduceIte]
rw [BitVec.abs_eq_if]
by_cases hx : x = intMin (w + 1)
· simp only [hx, reduceIte]
have := BitVec.msb_intMin (w := w + 1)
rw [this]
simp only [gt_iff_lt, Nat.zero_lt_succ, decide_True, ↓reduceIte]
rw [BitVec.toInt_intMin_eq_if]
simp
· simp only [hx, reduceIte]
rcases hmsb : x.msb
· simp only [Bool.false_eq_true, ↓reduceIte]
have := BitVec.toInt_bounds_of_msb_eq_false hmsb
rw [Int.abs_eq_self]
omega
· simp only [reduceIte]
have hxbounds := BitVec.toInt_bounds_of_msb_eq_true hmsb
rw [BitVec.toInt_neg_eq_cases]
have hxneq : x.toInt ≠ (intMin (w + 1)).toInt := by
rw [BitVec.toInt_ne]
exact hx
rw [BitVec.toInt_intMin_eq_twoPow (by omega)] at hxneq
-- TODO: remove toInt_eq_toNat_cond from simp set.
simp only [hx, reduceIte]
rw [Int.abs_eq_neg (by omega)]

/-- info: 'BitVec.toInt_abs' depends on axioms: [propext, Classical.choice, Quot.sound] -/
#guard_msgs in #print axioms BitVec.toInt_abs

/-! ### Non-overflow theorems -/

/-- If `x.toNat * y.toNat < 2^w`, then the multiplication `(x * y)` does not overflow. -/
Expand Down
7 changes: 7 additions & 0 deletions src/Init/Data/Int/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,13 @@ instance : Min Int := minOfLe

instance : Max Int := maxOfLe

/--
Return the absolute value of an integer.
-/
def abs : Int → Int
| ofNat n => .ofNat n
| negSucc n => .ofNat n.succ

end Int

/--
Expand Down
24 changes: 24 additions & 0 deletions src/Init/Data/Int/Lemmas.lean
Original file line number Diff line number Diff line change
Expand Up @@ -531,4 +531,28 @@ theorem natCast_one : ((1 : Nat) : Int) = (1 : Int) := rfl
@[simp] theorem natCast_mul (a b : Nat) : ((a * b : Nat) : Int) = (a : Int) * (b : Int) := by
simp

/-! abs lemmas -/

@[simp]
theorem abs_eq_self {x : Int} (h : x ≥ 0) : x.abs = x := by
cases x
case ofNat h =>
rfl
case negSucc h =>
contradiction

@[simp]
theorem Int.abs_zero : Int.abs 0 = 0 := rfl

@[simp]
theorem abs_eq_neg {x : Int} (h : x < 0) : x.abs = -x := by
cases x
case ofNat h =>
contradiction
case negSucc n =>
rfl

@[simp]
theorem ofNat_abs (x : Nat) : (x : Int).abs = (x : Int) := rfl

end Int

0 comments on commit 981216e

Please sign in to comment.