diff --git a/src/Init/Data/BitVec/Bitblast.lean b/src/Init/Data/BitVec/Bitblast.lean index ff916cf7d4fe..76c8033daff5 100644 --- a/src/Init/Data/BitVec/Bitblast.lean +++ b/src/Init/Data/BitVec/Bitblast.lean @@ -482,8 +482,7 @@ theorem msb_abs {w : Nat} {x : BitVec w} : and_intros · by_cases h₃ : x = 0#w · simp [h₃] at h₂ - · simp [h₃] - · simp [h₁] + · simp [h₃, h₁] · simp [h₂] · simp [BitVec.msb, show w = 0 by omega] diff --git a/src/Init/Data/BitVec/Lemmas.lean b/src/Init/Data/BitVec/Lemmas.lean index 954c13a7e942..4079632c2992 100644 --- a/src/Init/Data/BitVec/Lemmas.lean +++ b/src/Init/Data/BitVec/Lemmas.lean @@ -2323,6 +2323,27 @@ theorem toNat_udiv {x y : BitVec n} : (x / y).toNat = x.toNat / y.toNat := by · rw [toNat_ofNat, Nat.mod_eq_of_lt] exact Nat.lt_of_le_of_lt (Nat.div_le_self ..) (by omega) + +/-- If the LHS and RHS are both positive, then `(x udiv y)` is also positive -/ +theorem msb_udiv_eq_false_of_msb_eq_false {x y : BitVec n} (hx : x.msb = false) (hy : y.msb = false) : + (x / y).msb = false := by + rw [msb_eq_decide, toNat_udiv] + rcases n with rfl | n + · simp [@of_length_zero x, @of_length_zero y] + · simp + have xLt := msb_eq_false_iff_two_mul_lt.mp hx + have yLt := msb_eq_false_iff_two_mul_lt.mp hy + apply Nat.lt_of_le_of_lt + · apply Nat.div_le_self + · omega + +/-- If the LHS and RHS are both positive, then `(x udiv y).toInt` equals `(x udiv y).toNat` -/ +theorem toInt_udiv_eq_toNat_udiv_of_msb_eq_false {x y : BitVec n} (hx : x.msb = false) (hy : y.msb = false) : + (x / y).toInt = (x / y).toNat:= by + rw [toInt_eq_msb_cond] + have := msb_udiv_eq_false_of_msb_eq_false hx hy + simp [this] + @[simp] theorem zero_udiv {x : BitVec w} : (0#w) / x = 0#w := by simp [bv_toNat] @@ -2357,6 +2378,7 @@ theorem udiv_self {x : BitVec w} : ↓reduceIte, toNat_udiv] rw [Nat.div_self (by omega), Nat.mod_eq_of_lt (by omega)] + /-! ### umod -/ theorem umod_def {x y : BitVec n} : @@ -2460,6 +2482,9 @@ theorem sdiv_self {x : BitVec w} : rcases x.msb with msb | msb <;> simp · rcases x.msb with msb | msb <;> simp [h] +theorem sdiv_eq_udiv {x y : BitVec w} (hx : x.msb = false) (hy : y.msb = false) : x.sdiv y = x / y := by + simp [sdiv_eq, hx, hy] + /-! ### smtSDiv -/ theorem smtSDiv_eq (x y : BitVec w) : smtSDiv x y = @@ -2817,6 +2842,7 @@ theorem udiv_twoPow_eq_of_lt {w : Nat} {x : BitVec w} {k : Nat} (hk : k < w) : x have : 2^k < 2^w := Nat.pow_lt_pow_of_lt (by decide) hk simp [bv_toNat, Nat.shiftRight_eq_div_pow, Nat.mod_eq_of_lt this] + /- ### cons -/ @[simp] theorem true_cons_zero : cons true 0#w = twoPow (w + 1) w := by @@ -2926,6 +2952,8 @@ theorem getElem_replicate {n w : Nat} (x : BitVec w) (h : i < w * n) : /-- The bitvector of width `w` that has the smallest value when interpreted as an integer. -/ def intMin (w : Nat) := twoPow w (w - 1) +theorem intMin_eq (w : Nat) : twoPow w (w - 1) = intMin w := by rfl + theorem getLsbD_intMin (w : Nat) : (intMin w).getLsbD i = decide (i + 1 = w) := by simp only [intMin, getLsbD_twoPow, boolToPropSimps] omega @@ -3006,6 +3034,67 @@ theorem msb_intMin {w : Nat} : (intMin w).msb = decide (0 < w) := by simp only [msb_eq_decide, toNat_intMin, decide_eq_decide] by_cases h : 0 < w <;> simp_all +/-- +If `x = 0`, then `-x = 0`, and thus `x.msb = false`. +Otherwise, if `x = intMin w`, then `-x = intMin w`, and thus `x.msb = true`. +Otherwise, the `msb`is the negation of the msb of `x`. +-/ +@[simp] +theorem msb_neg_eq_decide (x : BitVec w) : (- x).msb = (decide (x ≠ 0) && (!x.msb || (x = intMin w))) := by + rcases w with rfl | w + · simp [BitVec.eq_nil x] + · simp only [Nat.zero_lt_succ, decide_true, Bool.true_and, msb_eq_decide] + simp + by_cases hx : x = 0#(w + 1) + · simp [hx] + have : 0 < 2^w := Nat.pow_pos (by decide) + omega + · simp [hx] + have : 2^w < 2^(w + 1) := Nat.pow_lt_pow_of_lt (by simp) (by simp) + by_cases hx' : x.toNat = 2^w + · have : x = intMin (w + 1) := by + apply eq_of_toNat_eq + simp [hx'] + rw [Nat.mod_eq_of_lt (by omega)] + subst this + simp + rw [Nat.mod_eq_of_lt (by omega)] + omega + · have : x ≠ intMin (w + 1) := by + rw [toNat_ne] + simp + rw [Nat.mod_eq_of_lt (by omega)] + assumption + simp [this] + by_cases hmsb : 2 ^ w ≤ x.toNat + · simp [hmsb] + omega + · simp [hmsb] + omega + + +/-- +If both the numerator and denominator is positive, +then `(x.sdiv y).toInt` equals computing `x.toInt / y.toInt`. +-/ +theorem toInt_sdiv_eq_toInt_div_toInt_of_msb_eq_false {x y : BitVec n} (hx : x.msb = false) (hy : y.msb = false) : + (x.sdiv y).toInt = x.toInt / y.toInt := by + rcases n with rfl | n + · simp [eq_nil x, eq_nil y] + · have hxLt := msb_eq_false_iff_two_mul_lt.mp hx + have hyLt := msb_eq_false_iff_two_mul_lt.mp hy + simp [sdiv_eq, hx, hy] + rw [toInt_eq_toNat_cond] + have xToInt : x.toInt = x.toNat := by simp [toInt_eq_msb_cond, hx] + have yToInt : y.toInt = y.toNat := by simp [toInt_eq_msb_cond, hy] + simp [xToInt, yToInt] + intros h + have : x.toNat / y.toNat < 2^n := by + apply Nat.lt_of_le_of_lt + · apply Nat.div_le_self + · omega + omega + /-! ### intMax -/ /-- The bitvector of width `w` that has the largest value when interpreted as an integer. -/ @@ -3131,6 +3220,133 @@ theorem getMsbD_abs {i : Nat} {x : BitVec w} : getMsbD (x.abs) i = if x.msb then getMsbD (-x) i else getMsbD x i := by by_cases h : x.msb <;> simp [BitVec.abs, h] +/- +This characterizes signed division as unsigned division of the absolute value, +followed by an optional negation. +This is useful to port theorems from `udiv` into `sdiv`. +-/ +theorem sdiv_eq_udiv_abs {x y : BitVec n} : x.sdiv y = + (if x.msb = y.msb then id else BitVec.neg) (x.abs / y.abs) := by + rcases n with rfl | n + · simp [eq_nil x, eq_nil y] + · rw [sdiv_eq] + by_cases hx : x.msb + have hxGe := msb_eq_true_iff_two_mul_ge.mp hx + · by_cases hy : y.msb + · have hyGe := msb_eq_true_iff_two_mul_ge.mp hy + simp [hx, hy] + apply eq_of_toNat_eq + simp [hx, hy] + congr + · apply Nat.mod_eq_of_lt (by omega) + · apply Nat.mod_eq_of_lt (by omega) + · simp at hy + have hyLt := msb_eq_false_iff_two_mul_lt.mp hy + simp [hx, hy] + apply eq_of_toNat_eq + simp [hx, hy] + congr + · apply Nat.mod_eq_of_lt (by omega) + · simp at hx + have hxLt := msb_eq_false_iff_two_mul_lt.mp hx + by_cases hy : y.msb + · have hyGe := msb_eq_true_iff_two_mul_ge.mp hy + simp at hy + simp [hx, hy] + apply eq_of_toNat_eq + simp [hx, hy] + congr + apply Nat.mod_eq_of_lt (by omega) + · simp at hx hy + apply eq_of_toNat_eq + simp [hx, hy] + + +theorem neq_zero_of_msb_eq_true (x : BitVec w) (hx : x.msb = true) : x ≠ 0 := by + have := msb_eq_true_iff_two_mul_ge.mp hx + rcases w with rfl | w + · simp at this + rw [@BitVec.of_length_zero x] at hx + simp at hx + · have : x.toNat ≥ 2^w := by omega + have : 0 < 2^w := by exact Nat.two_pow_pos w + apply toNat_ne.mpr + simp + omega + +/-- +The value of in`intMin` of width `w + 1` is equal to `2^w`. +This avoids the corner case when `w = 0`, where the `toNat` value is itself `0`. +-/ +theorem toNat_intMin_of_lt (w : Nat) : (intMin (w + 1)).toNat = 2^w := by + simp + apply Nat.mod_eq_of_lt + apply Nat.pow_lt_pow_of_lt (by decide) (by omega) + + +theorem BitVec.udiv_eq_zero_of_lt (x y : BitVec w) (h : x.toNat < y.toNat ) : x / y = 0#w := by + apply eq_of_toNat_eq + simp + exact Nat.div_eq_of_lt h + +theorem sdiv_twoPow_eq_neg {w : Nat} {x : BitVec w} : x.sdiv (twoPow w (w - 1)) = +if x = twoPow w (w - 1) then 1 else 0 := by + rcases w with rfl | w + · simp + apply of_length_zero + · by_cases hx : x = (twoPow (w + 1) (w + 1 - 1)) + · rw [hx, intMin_eq] + simp + intro h + have : (intMin (w + 1)).msb = (BitVec.ofNat (w + 1) 0).msb := by simp [h] + simp [msb_intMin] at this + · rw [intMin_eq] at hx + rw [intMin_eq] + simp [hx] + by_cases hx' : x.msb <;> rw [sdiv_eq] <;> simp [hx', msb_intMin] + · have hxMsbNeg : (-x).msb = false := by + simp [msb_neg_eq_decide x] + intros h + simp [hx', hx] + have := msb_eq_false_iff_two_mul_lt.mp hxMsbNeg + apply BitVec.udiv_eq_zero_of_lt + simp only [toNat_intMin_of_lt] + omega + · apply BitVec.udiv_eq_zero_of_lt + simp only [toNat_intMin_of_lt] + simp at hx' + have := msb_eq_false_iff_two_mul_lt.mp hx' + omega + +theorem sdiv_twoPow_eq_sshiftRight_of_lt + {w : Nat} {x : BitVec w} {k : Nat} (hk : k < w - 1) : + x.sdiv (twoPow w k) = + if x.msb then -(-x) >>> k else x >>> k := by + by_cases hx : x.msb <;> simp at hx <;> simp [hx] + · simp [sdiv_eq, hx, show (k < w) by omega, show (k ≠ w - 1) by omega] + rw [udiv_twoPow_eq_of_lt] + omega + · rw [sdiv_eq_udiv] + · apply udiv_twoPow_eq_of_lt + simp [show k < w by omega] + · assumption + · rw [msb_twoPow] + simp [show k < w by omega] + omega + +theorem sdiv_twoPow_eq + {w : Nat} {x : BitVec w} {k : Nat} (hk : k < w) : + x.sdiv (twoPow w k) = + if k = w - 1 then (if x = twoPow w (w - 1) then 1 else 0) + else (if x.msb then -(-x) >>> k else x >>> k) := by + by_cases hk : k = w - 1 + · simp [hk] + apply sdiv_twoPow_eq_neg + · have hk' : k < w - 1 := by omega + simp [hk] + apply sdiv_twoPow_eq_sshiftRight_of_lt + assumption + /-! ### Decidable quantifiers -/ theorem forall_zero_iff {P : BitVec 0 → Prop} :