Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: convert x.sdiv (2^k) into x.sshiftRight k #35

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions src/Init/Data/BitVec/Bitblast.lean
Original file line number Diff line number Diff line change
Expand Up @@ -482,8 +482,7 @@ theorem msb_abs {w : Nat} {x : BitVec w} :
and_intros
· by_cases h₃ : x = 0#w
· simp [h₃] at h₂
· simp [h₃]
· simp [h₁]
· simp [h₃, h₁]
· simp [h₂]
· simp [BitVec.msb, show w = 0 by omega]

Expand Down
216 changes: 216 additions & 0 deletions src/Init/Data/BitVec/Lemmas.lean
Original file line number Diff line number Diff line change
Expand Up @@ -2323,6 +2323,27 @@ theorem toNat_udiv {x y : BitVec n} : (x / y).toNat = x.toNat / y.toNat := by
· rw [toNat_ofNat, Nat.mod_eq_of_lt]
exact Nat.lt_of_le_of_lt (Nat.div_le_self ..) (by omega)


/-- If the LHS and RHS are both positive, then `(x udiv y)` is also positive -/
theorem msb_udiv_eq_false_of_msb_eq_false {x y : BitVec n} (hx : x.msb = false) (hy : y.msb = false) :
(x / y).msb = false := by
rw [msb_eq_decide, toNat_udiv]
rcases n with rfl | n
· simp [@of_length_zero x, @of_length_zero y]
· simp
have xLt := msb_eq_false_iff_two_mul_lt.mp hx
have yLt := msb_eq_false_iff_two_mul_lt.mp hy
apply Nat.lt_of_le_of_lt
· apply Nat.div_le_self
· omega

/-- If the LHS and RHS are both positive, then `(x udiv y).toInt` equals `(x udiv y).toNat` -/
theorem toInt_udiv_eq_toNat_udiv_of_msb_eq_false {x y : BitVec n} (hx : x.msb = false) (hy : y.msb = false) :
(x / y).toInt = (x / y).toNat:= by
rw [toInt_eq_msb_cond]
have := msb_udiv_eq_false_of_msb_eq_false hx hy
simp [this]

@[simp]
theorem zero_udiv {x : BitVec w} : (0#w) / x = 0#w := by
simp [bv_toNat]
Expand Down Expand Up @@ -2357,6 +2378,7 @@ theorem udiv_self {x : BitVec w} :
↓reduceIte, toNat_udiv]
rw [Nat.div_self (by omega), Nat.mod_eq_of_lt (by omega)]


/-! ### umod -/

theorem umod_def {x y : BitVec n} :
Expand Down Expand Up @@ -2460,6 +2482,9 @@ theorem sdiv_self {x : BitVec w} :
rcases x.msb with msb | msb <;> simp
· rcases x.msb with msb | msb <;> simp [h]

theorem sdiv_eq_udiv {x y : BitVec w} (hx : x.msb = false) (hy : y.msb = false) : x.sdiv y = x / y := by
simp [sdiv_eq, hx, hy]

/-! ### smtSDiv -/

theorem smtSDiv_eq (x y : BitVec w) : smtSDiv x y =
Expand Down Expand Up @@ -2817,6 +2842,7 @@ theorem udiv_twoPow_eq_of_lt {w : Nat} {x : BitVec w} {k : Nat} (hk : k < w) : x
have : 2^k < 2^w := Nat.pow_lt_pow_of_lt (by decide) hk
simp [bv_toNat, Nat.shiftRight_eq_div_pow, Nat.mod_eq_of_lt this]


/- ### cons -/

@[simp] theorem true_cons_zero : cons true 0#w = twoPow (w + 1) w := by
Expand Down Expand Up @@ -2926,6 +2952,8 @@ theorem getElem_replicate {n w : Nat} (x : BitVec w) (h : i < w * n) :
/-- The bitvector of width `w` that has the smallest value when interpreted as an integer. -/
def intMin (w : Nat) := twoPow w (w - 1)

theorem intMin_eq (w : Nat) : twoPow w (w - 1) = intMin w := by rfl

theorem getLsbD_intMin (w : Nat) : (intMin w).getLsbD i = decide (i + 1 = w) := by
simp only [intMin, getLsbD_twoPow, boolToPropSimps]
omega
Expand Down Expand Up @@ -3006,6 +3034,67 @@ theorem msb_intMin {w : Nat} : (intMin w).msb = decide (0 < w) := by
simp only [msb_eq_decide, toNat_intMin, decide_eq_decide]
by_cases h : 0 < w <;> simp_all

/--
If `x = 0`, then `-x = 0`, and thus `x.msb = false`.
Otherwise, if `x = intMin w`, then `-x = intMin w`, and thus `x.msb = true`.
Otherwise, the `msb`is the negation of the msb of `x`.
-/
@[simp]
theorem msb_neg_eq_decide (x : BitVec w) : (- x).msb = (decide (x ≠ 0) && (!x.msb || (x = intMin w))) := by
rcases w with rfl | w
· simp [BitVec.eq_nil x]
· simp only [Nat.zero_lt_succ, decide_true, Bool.true_and, msb_eq_decide]
simp
by_cases hx : x = 0#(w + 1)
· simp [hx]
have : 0 < 2^w := Nat.pow_pos (by decide)
omega
· simp [hx]
have : 2^w < 2^(w + 1) := Nat.pow_lt_pow_of_lt (by simp) (by simp)
by_cases hx' : x.toNat = 2^w
· have : x = intMin (w + 1) := by
apply eq_of_toNat_eq
simp [hx']
rw [Nat.mod_eq_of_lt (by omega)]
subst this
simp
rw [Nat.mod_eq_of_lt (by omega)]
omega
· have : x ≠ intMin (w + 1) := by
rw [toNat_ne]
simp
rw [Nat.mod_eq_of_lt (by omega)]
assumption
simp [this]
by_cases hmsb : 2 ^ w ≤ x.toNat
· simp [hmsb]
omega
· simp [hmsb]
omega


/--
If both the numerator and denominator is positive,
then `(x.sdiv y).toInt` equals computing `x.toInt / y.toInt`.
-/
theorem toInt_sdiv_eq_toInt_div_toInt_of_msb_eq_false {x y : BitVec n} (hx : x.msb = false) (hy : y.msb = false) :
(x.sdiv y).toInt = x.toInt / y.toInt := by
rcases n with rfl | n
· simp [eq_nil x, eq_nil y]
· have hxLt := msb_eq_false_iff_two_mul_lt.mp hx
have hyLt := msb_eq_false_iff_two_mul_lt.mp hy
simp [sdiv_eq, hx, hy]
rw [toInt_eq_toNat_cond]
have xToInt : x.toInt = x.toNat := by simp [toInt_eq_msb_cond, hx]
have yToInt : y.toInt = y.toNat := by simp [toInt_eq_msb_cond, hy]
simp [xToInt, yToInt]
intros h
have : x.toNat / y.toNat < 2^n := by
apply Nat.lt_of_le_of_lt
· apply Nat.div_le_self
· omega
omega

/-! ### intMax -/

/-- The bitvector of width `w` that has the largest value when interpreted as an integer. -/
Expand Down Expand Up @@ -3131,6 +3220,133 @@ theorem getMsbD_abs {i : Nat} {x : BitVec w} :
getMsbD (x.abs) i = if x.msb then getMsbD (-x) i else getMsbD x i := by
by_cases h : x.msb <;> simp [BitVec.abs, h]

/-
This characterizes signed division as unsigned division of the absolute value,
followed by an optional negation.
This is useful to port theorems from `udiv` into `sdiv`.
-/
theorem sdiv_eq_udiv_abs {x y : BitVec n} : x.sdiv y =
(if x.msb = y.msb then id else BitVec.neg) (x.abs / y.abs) := by
rcases n with rfl | n
· simp [eq_nil x, eq_nil y]
· rw [sdiv_eq]
by_cases hx : x.msb
have hxGe := msb_eq_true_iff_two_mul_ge.mp hx
· by_cases hy : y.msb
· have hyGe := msb_eq_true_iff_two_mul_ge.mp hy
simp [hx, hy]
apply eq_of_toNat_eq
simp [hx, hy]
congr
· apply Nat.mod_eq_of_lt (by omega)
· apply Nat.mod_eq_of_lt (by omega)
· simp at hy
have hyLt := msb_eq_false_iff_two_mul_lt.mp hy
simp [hx, hy]
apply eq_of_toNat_eq
simp [hx, hy]
congr
· apply Nat.mod_eq_of_lt (by omega)
· simp at hx
have hxLt := msb_eq_false_iff_two_mul_lt.mp hx
by_cases hy : y.msb
· have hyGe := msb_eq_true_iff_two_mul_ge.mp hy
simp at hy
simp [hx, hy]
apply eq_of_toNat_eq
simp [hx, hy]
congr
apply Nat.mod_eq_of_lt (by omega)
· simp at hx hy
apply eq_of_toNat_eq
simp [hx, hy]


theorem neq_zero_of_msb_eq_true (x : BitVec w) (hx : x.msb = true) : x ≠ 0 := by
have := msb_eq_true_iff_two_mul_ge.mp hx
rcases w with rfl | w
· simp at this
rw [@BitVec.of_length_zero x] at hx
simp at hx
· have : x.toNat ≥ 2^w := by omega
have : 0 < 2^w := by exact Nat.two_pow_pos w
apply toNat_ne.mpr
simp
omega

/--
The value of in`intMin` of width `w + 1` is equal to `2^w`.
This avoids the corner case when `w = 0`, where the `toNat` value is itself `0`.
-/
theorem toNat_intMin_of_lt (w : Nat) : (intMin (w + 1)).toNat = 2^w := by
simp
apply Nat.mod_eq_of_lt
apply Nat.pow_lt_pow_of_lt (by decide) (by omega)


theorem BitVec.udiv_eq_zero_of_lt (x y : BitVec w) (h : x.toNat < y.toNat ) : x / y = 0#w := by
apply eq_of_toNat_eq
simp
exact Nat.div_eq_of_lt h

theorem sdiv_twoPow_eq_neg {w : Nat} {x : BitVec w} : x.sdiv (twoPow w (w - 1)) =
if x = twoPow w (w - 1) then 1 else 0 := by
rcases w with rfl | w
· simp
apply of_length_zero
· by_cases hx : x = (twoPow (w + 1) (w + 1 - 1))
· rw [hx, intMin_eq]
simp
intro h
have : (intMin (w + 1)).msb = (BitVec.ofNat (w + 1) 0).msb := by simp [h]
simp [msb_intMin] at this
· rw [intMin_eq] at hx
rw [intMin_eq]
simp [hx]
by_cases hx' : x.msb <;> rw [sdiv_eq] <;> simp [hx', msb_intMin]
· have hxMsbNeg : (-x).msb = false := by
simp [msb_neg_eq_decide x]
intros h
simp [hx', hx]
have := msb_eq_false_iff_two_mul_lt.mp hxMsbNeg
apply BitVec.udiv_eq_zero_of_lt
simp only [toNat_intMin_of_lt]
omega
· apply BitVec.udiv_eq_zero_of_lt
simp only [toNat_intMin_of_lt]
simp at hx'
have := msb_eq_false_iff_two_mul_lt.mp hx'
omega

theorem sdiv_twoPow_eq_sshiftRight_of_lt
{w : Nat} {x : BitVec w} {k : Nat} (hk : k < w - 1) :
x.sdiv (twoPow w k) =
if x.msb then -(-x) >>> k else x >>> k := by
by_cases hx : x.msb <;> simp at hx <;> simp [hx]
· simp [sdiv_eq, hx, show (k < w) by omega, show (k ≠ w - 1) by omega]
rw [udiv_twoPow_eq_of_lt]
omega
· rw [sdiv_eq_udiv]
· apply udiv_twoPow_eq_of_lt
simp [show k < w by omega]
· assumption
· rw [msb_twoPow]
simp [show k < w by omega]
omega

theorem sdiv_twoPow_eq
{w : Nat} {x : BitVec w} {k : Nat} (hk : k < w) :
x.sdiv (twoPow w k) =
if k = w - 1 then (if x = twoPow w (w - 1) then 1 else 0)
else (if x.msb then -(-x) >>> k else x >>> k) := by
by_cases hk : k = w - 1
· simp [hk]
apply sdiv_twoPow_eq_neg
· have hk' : k < w - 1 := by omega
simp [hk]
apply sdiv_twoPow_eq_sshiftRight_of_lt
assumption

/-! ### Decidable quantifiers -/

theorem forall_zero_iff {P : BitVec 0 → Prop} :
Expand Down
Loading