diff --git a/src/Init/Data/BitVec/Lemmas.lean b/src/Init/Data/BitVec/Lemmas.lean index 8c7471551e0f..937e870cc6a1 100644 --- a/src/Init/Data/BitVec/Lemmas.lean +++ b/src/Init/Data/BitVec/Lemmas.lean @@ -1240,6 +1240,48 @@ theorem toNat_ushiftRight_lt (x : BitVec w) (n : Nat) (hn : n ≤ w) : · apply hn · apply Nat.pow_pos (by decide) + +/-- Shifting right by `n`, which is larger than the bitwidth `w` produces `0. -/ +theorem ushiftRight_eq_zero {x : BitVec w} {n : Nat} (hn : w ≤ n) : + x >>> n = 0#w := by + simp [bv_toNat] + have : 2^w ≤ 2^n := Nat.pow_le_pow_of_le Nat.one_lt_two hn + rw [Nat.shiftRight_eq_div_pow, Nat.div_eq_of_lt (by omega)] + + +/-- +Unsigned shift right by at least one bit makes the value of the bitvector less than or equal to `2^(w-1)`, +makes the interpretation of the bitvector `Int` and `Nat` agree. +-/ +theorem toInt_ushiftRight_of_lt {x : BitVec w} {n : Nat} (hn : 0 < n) : + (x >>> n).toInt = x.toNat >>> n := by + rw [toInt_eq_toNat_cond] + simp only [toNat_ushiftRight, ite_eq_left_iff, Nat.not_lt] + intros h + by_cases hn: n ≤ w + · have h1 := Nat.mul_lt_mul_of_pos_left (toNat_ushiftRight_lt x n hn) Nat.two_pos + simp only [toNat_ushiftRight, Nat.zero_lt_succ, Nat.mul_lt_mul_left] at h1 + have : 2 ^ (w - n).succ ≤ 2^ w := Nat.pow_le_pow_of_le (by decide) (by omega) + have := show 2 * x.toNat >>> n < 2 ^ w by + omega + omega + · have : x.toNat >>> n = 0 := by + apply Nat.shiftRight_eq_zero + have : 2^w ≤ 2^n := Nat.pow_le_pow_of_le (by decide) (by omega) + omega + simp [this] at h + omega + +@[simp] +theorem toFin_uShiftRight {x : BitVec w} {n : Nat} : + (x >>> n).toFin = x.toFin / (Fin.ofNat' (2^w) (2^n)) := by + apply Fin.eq_of_val_eq + by_cases hn : n < w + · simp [Nat.shiftRight_eq_div_pow, Nat.mod_eq_of_lt (Nat.pow_lt_pow_of_lt Nat.one_lt_two hn)] + · simp only [Nat.not_lt] at hn + rw [ushiftRight_eq_zero _ _ (by omega)] + simp [Nat.dvd_iff_mod_eq_zero.mp (Nat.pow_dvd_pow 2 hn)] + @[simp] theorem getMsbD_ushiftRight {x : BitVec w} {i n : Nat} : (x >>> n).getMsbD i = (decide (i < w) && (!decide (i < n) && x.getMsbD (i - n))) := by diff --git a/src/Init/Data/Int/Bitwise/Lemmas.lean b/src/Init/Data/Int/Bitwise/Lemmas.lean index d2741d67f05b..9b7b41597a88 100644 --- a/src/Init/Data/Int/Bitwise/Lemmas.lean +++ b/src/Init/Data/Int/Bitwise/Lemmas.lean @@ -34,4 +34,8 @@ theorem shiftRight_eq_div_pow (m : Int) (n : Nat) : theorem zero_shiftRight (n : Nat) : (0 : Int) >>> n = 0 := by simp [Int.shiftRight_eq_div_pow] +@[simp] +theorem shiftRight_zero (n : Int) : n >>> 0 = n := by + simp [Int.shiftRight_eq_div_pow] + end Int diff --git a/src/Init/Data/Nat/Bitwise/Basic.lean b/src/Init/Data/Nat/Bitwise/Basic.lean index 5395c15ffa0f..edc8d67656aa 100644 --- a/src/Init/Data/Nat/Bitwise/Basic.lean +++ b/src/Init/Data/Nat/Bitwise/Basic.lean @@ -71,6 +71,9 @@ theorem shiftRight_eq_div_pow (m : Nat) : ∀ n, m >>> n = m / 2 ^ n rw [shiftRight_add, shiftRight_eq_div_pow m k] simp [Nat.div_div_eq_div_mul, ← Nat.pow_succ, shiftRight_succ] +theorem shiftRight_eq_zero (m n : Nat) (hn : m < 2^n) : m >>> n = 0 := by + simp [Nat.shiftRight_eq_div_pow, Nat.div_eq_of_lt hn] + /-! ### testBit We define an operation for testing individual bits in the binary representation