From 107a2e8b2e53ef438a5f79f77755a78ec4c90cc6 Mon Sep 17 00:00:00 2001 From: Siddharth Date: Sat, 23 Nov 2024 07:29:08 +0000 Subject: [PATCH] feat: BitVec.toInt BitVec.signExtend (#6157) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR adds toInt theorems for BitVec.signExtend. If the current width `w` is larger than the extended width `v`, then the value when interpreted as an integer is truncated, and we compute a modulo by `2^v`. ```lean theorem toInt_signExtend_of_le (x : BitVec w) (hv : v ≤ w) : (x.signExtend v).toInt = Int.bmod (x.toNat) (2^v) ``` Co-authored-by: Siddharth Bhat Co-authored-by: Harun Khan Stacked on top of #6155 --------- Co-authored-by: Harun Khan --- src/Init/Data/BitVec/Lemmas.lean | 77 ++++++++++++++++++++++++++++++++ 1 file changed, 77 insertions(+) diff --git a/src/Init/Data/BitVec/Lemmas.lean b/src/Init/Data/BitVec/Lemmas.lean index 1043b2f4e2e5..5c931f353c4f 100644 --- a/src/Init/Data/BitVec/Lemmas.lean +++ b/src/Init/Data/BitVec/Lemmas.lean @@ -565,6 +565,10 @@ theorem zeroExtend_eq_setWidth {v : Nat} {x : BitVec w} : else simp [n_le_i, toNat_ofNat] +@[simp] theorem toInt_setWidth (x : BitVec w) : + (x.setWidth v).toInt = Int.bmod x.toNat (2^v) := by + simp [toInt_eq_toNat_bmod, toNat_setWidth, Int.emod_bmod] + theorem setWidth'_eq {x : BitVec w} (h : w ≤ v) : x.setWidth' h = x.setWidth v := by apply eq_of_toNat_eq rw [toNat_setWidth, toNat_setWidth'] @@ -1615,6 +1619,79 @@ theorem signExtend_eq_setWidth_of_lt (x : BitVec w) {v : Nat} (hv : v ≤ w): theorem signExtend_eq (x : BitVec w) : x.signExtend w = x := by rw [signExtend_eq_setWidth_of_lt _ (Nat.le_refl _), setWidth_eq] +/-- Sign extending to a larger bitwidth depends on the msb. +If the msb is false, then the result equals the original value. +If the msb is true, then we add a value of `(2^v - 2^w)`, which arises from the sign extension. -/ +theorem toNat_signExtend_of_le (x : BitVec w) {v : Nat} (hv : w ≤ v) : + (x.signExtend v).toNat = x.toNat + if x.msb then 2^v - 2^w else 0 := by + apply Nat.eq_of_testBit_eq + intro i + have ⟨k, hk⟩ := Nat.exists_eq_add_of_le hv + rw [hk, testBit_toNat, getLsbD_signExtend, Nat.pow_add, ← Nat.mul_sub_one, Nat.add_comm (x.toNat)] + by_cases hx : x.msb + · simp [hx, Nat.testBit_mul_pow_two_add _ x.isLt, testBit_toNat] + -- Case analysis on i being in the intervals [0..w), [w..w + k), [w+k..∞) + have hi : i < w ∨ (w ≤ i ∧ i < w + k) ∨ w + k ≤ i := by omega + rcases hi with hi | hi | hi + · simp [hi]; omega + · simp [hi]; omega + · simp [hi, show ¬ (i < w + k) by omega, show ¬ (i < w) by omega] + omega + · simp [hx, Nat.testBit_mul_pow_two_add _ x.isLt, testBit_toNat] + have hi : i < w ∨ (w ≤ i ∧ i < w + k) ∨ w + k ≤ i := by omega + rcases hi with hi | hi | hi + · simp [hi]; omega + · simp [hi] + · simp [hi, show ¬ (i < w + k) by omega, show ¬ (i < w) by omega, getLsbD_ge x i (by omega)] + +/-- Sign extending to a larger bitwidth depends on the msb. +If the msb is false, then the result equals the original value. +If the msb is true, then we add a value of `(2^v - 2^w)`, which arises from the sign extension. -/ +theorem toNat_signExtend (x : BitVec w) {v : Nat} : + (x.signExtend v).toNat = (x.setWidth v).toNat + if x.msb then 2^v - 2^w else 0 := by + by_cases h : v ≤ w + · have : 2^v ≤ 2^w := Nat.pow_le_pow_of_le_right Nat.two_pos h + simp [signExtend_eq_setWidth_of_lt x h, toNat_setWidth, Nat.sub_eq_zero_of_le this] + · have : 2^w ≤ 2^v := Nat.pow_le_pow_of_le_right Nat.two_pos (by omega) + rw [toNat_signExtend_of_le x (by omega), toNat_setWidth, Nat.mod_eq_of_lt (by omega)] + +/- +If the current width `w` is smaller than the extended width `v`, +then the value when interpreted as an integer does not change. +-/ +theorem toInt_signExtend_of_lt {x : BitVec w} (hv : w < v): + (x.signExtend v).toInt = x.toInt := by + simp only [toInt_eq_msb_cond, toNat_signExtend] + have : (x.signExtend v).msb = x.msb := by + rw [msb_eq_getLsbD_last, getLsbD_eq_getElem (Nat.sub_one_lt_of_lt hv)] + simp [getElem_signExtend, Nat.le_sub_one_of_lt hv] + have H : 2^w ≤ 2^v := Nat.pow_le_pow_of_le_right (by omega) (by omega) + simp only [this, toNat_setWidth, Int.natCast_add, Int.ofNat_emod, Int.natCast_mul] + by_cases h : x.msb + <;> norm_cast + <;> simp [h, Nat.mod_eq_of_lt (Nat.lt_of_lt_of_le x.isLt H)] + omega + +/- +If the current width `w` is larger than the extended width `v`, +then the value when interpreted as an integer is truncated, +and we compute a modulo by `2^v`. +-/ +theorem toInt_signExtend_of_le {x : BitVec w} (hv : v ≤ w) : + (x.signExtend v).toInt = Int.bmod x.toNat (2^v) := by + simp [signExtend_eq_setWidth_of_lt _ hv] + +/- +Interpreting the sign extension of `(x : BitVec w)` to width `v` +computes `x % 2^v` (where `%` is the balanced mod). +-/ +theorem toInt_signExtend (x : BitVec w) : + (x.signExtend v).toInt = Int.bmod x.toNat (2^(min v w)) := by + by_cases hv : v ≤ w + · simp [toInt_signExtend_of_le hv, Nat.min_eq_left hv] + · simp only [Nat.not_le] at hv + rw [toInt_signExtend_of_lt hv, Nat.min_eq_right (by omega), toInt_eq_toNat_bmod] + /-! ### append -/ theorem append_def (x : BitVec v) (y : BitVec w) :