Skip to content

Commit

Permalink
chore: cleanup and shorten names in Array.Merge (leanprover-community…
Browse files Browse the repository at this point in the history
  • Loading branch information
digama0 authored Apr 24, 2024
1 parent db73659 commit 0ad0ebc
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 82 deletions.
137 changes: 56 additions & 81 deletions Std/Data/Array/Merge.lean
Original file line number Diff line number Diff line change
Expand Up @@ -9,140 +9,115 @@ import Std.Data.Nat.Lemmas
namespace Array

/--
Merge arrays `xs` and `ys`, which must be sorted according to `compare`. The
result is sorted as well. If two (or more) elements are equal according to
`compare`, they are preserved.
`O(|xs| + |ys|)`. Merge arrays `xs` and `ys`. If the arrays are sorted according to `lt`, then the
result is sorted as well. Elements that are equivalent under `lt` (neither `lt x y` nor `lt y x`)
are all kept: no element of `xs` or `ys` is dropped.
-/
def mergeSortedPreservingDuplicates [ord : Ord α] (xs ys : Array α) :
Array α :=
let acc := Array.mkEmpty (xs.size + ys.size)
go acc 0 0
def merge (lt : α → α → Bool) (xs ys : Array α) : Array α :=
go (Array.mkEmpty (xs.size + ys.size)) 0 0
where
/-- Auxiliary definition for `mergeSortedPreservingDuplicates`. -/
/-- Auxiliary definition for `merge`. -/
go (acc : Array α) (i j : Nat) : Array α :=
if hi : i ≥ xs.size then
acc ++ ys[j:]
else if hj : j ≥ ys.size then
acc ++ xs[i:]
else
have hi : i < xs.size := Nat.lt_of_not_le hi
have hj : j < ys.size := Nat.lt_of_not_le hj
have hij : i + j < xs.size + ys.size := Nat.add_lt_add hi hj
let x := xs[i]
let y := ys[j]
if compare x y |>.isLE then
have : xs.size + ys.size - (i + 1 + j) < xs.size + ys.size - (i + j) := by
rw [show i + 1 + j = i + j + 1 by simp_arith]
exact Nat.sub_succ_lt_self _ _ hij
go (acc.push x) (i + 1) j
else
have : xs.size + ys.size - (i + j + 1) < xs.size + ys.size - (i + j) :=
Nat.sub_succ_lt_self _ _ hij
go (acc.push y) i (j + 1)
if lt x y then go (acc.push x) (i + 1) j else go (acc.push y) i (j + 1)
termination_by xs.size + ys.size - (i + j)

set_option linter.unusedVariables false in
@[deprecated merge, inherit_doc merge]
def mergeSortedPreservingDuplicates [ord : Ord α] (xs ys : Array α) : Array α :=
merge (compare · · |>.isLT) xs ys

/--
Merge arrays `xs` and `ys`, which must be sorted according to `compare` and must
not contain duplicates. Equal elements are merged using `merge`. If `merge`
respects the order (i.e. for all `x`, `y`, `y'`, `z`, if `x < y < z` and
`x < y' < z` then `x < merge y y' < z`) then the resulting array is again
sorted.
`O(|xs| + |ys|)`. Merge arrays `xs` and `ys`, which must be sorted according to `compare` and must
not contain duplicates. Equal elements are merged using `merge`. If `merge` respects the order
(i.e. for all `x`, `y`, `y'`, `z`, if `x < y < z` and `x < y' < z` then `x < merge y y' < z`)
then the resulting array is again sorted.
-/
def mergeSortedMergingDuplicates [ord : Ord α] (xs ys : Array α)
(merge : α → α → α) : Array α :=
let acc := Array.mkEmpty (xs.size + ys.size)
go acc 0 0
def mergeDedupWith [ord : Ord α] (xs ys : Array α) (merge : α → α → α) : Array α :=
go (Array.mkEmpty (xs.size + ys.size)) 0 0
where
/-- Auxiliary definition for `mergeSortedMergingDuplicates`. -/
/-- Auxiliary definition for `mergeDedupWith`. -/
go (acc : Array α) (i j : Nat) : Array α :=
if hi : i ≥ xs.size then
acc ++ ys[j:]
else if hj : j ≥ ys.size then
acc ++ xs[i:]
else
have hi : i < xs.size := Nat.lt_of_not_le hi
have hj : j < ys.size := Nat.lt_of_not_le hj
have hij : i + j < xs.size + ys.size := Nat.add_lt_add hi hj
let x := xs[i]
let y := ys[j]
match compare x y with
| Ordering.lt =>
have : xs.size + ys.size - (i + 1 + j) < xs.size + ys.size - (i + j) := by
rw [show i + 1 + j = i + j + 1 by simp_arith]
exact Nat.sub_succ_lt_self _ _ hij
go (acc.push x) (i + 1) j
| Ordering.gt =>
have : xs.size + ys.size - (i + j + 1) < xs.size + ys.size - (i + j) :=
Nat.sub_succ_lt_self _ _ hij
go (acc.push y) i (j + 1)
| Ordering.eq =>
have : xs.size + ys.size - (i + 1 + (j + 1)) < xs.size + ys.size - (i + j) := by
rw [show i + 1 + (j + 1) = i + j + 2 by simp_arith]
apply Nat.sub_add_lt_sub _ (by decide)
rw [show i + j + 2 = (i + 1) + (j + 1) by simp_arith]
exact Nat.add_le_add hi hj
go (acc.push (merge x y)) (i + 1) (j + 1)
termination_by xs.size + ys.size - (i + j)
| .lt => go (acc.push x) (i + 1) j
| .gt => go (acc.push y) i (j + 1)
| .eq => go (acc.push (merge x y)) (i + 1) (j + 1)
termination_by xs.size + ys.size - (i + j)

@[deprecated] alias mergeSortedMergingDuplicates := mergeDedupWith

/--
Merge arrays `xs` and `ys`, which must be sorted according to `compare` and must
not contain duplicates. If an element appears in both `xs` and `ys`, only one
copy is kept.
`O(|xs| + |ys|)`. Merge arrays `xs` and `ys`, which must be sorted according to `compare` and must
not contain duplicates. If an element appears in both `xs` and `ys`, only one copy is kept.
-/
@[inline]
def mergeSortedDeduplicating [ord : Ord α] (xs ys : Array α) : Array α :=
mergeSortedMergingDuplicates (ord := ord) xs ys fun x _ => x
@[inline] def mergeDedup [ord : Ord α] (xs ys : Array α) : Array α :=
mergeDedupWith (ord := ord) xs ys fun x _ => x

@[deprecated] alias mergeSortedDeduplicating := mergeDedup

set_option linter.unusedVariables false in
/--
Merge `xs` and `ys`, which do not need to be sorted. Elements which occur in
both `xs` and `ys` are only added once. If `xs` and `ys` do not contain
duplicates, then neither does the result. O(n*m)!
`O(|xs| * |ys|)`. Merge `xs` and `ys`, which do not need to be sorted. Elements which occur in
both `xs` and `ys` are only added once. If `xs` and `ys` do not contain duplicates, then neither
does the result.
-/
def mergeUnsortedDeduplicating [eq : BEq α] (xs ys : Array α) : Array α :=
def mergeUnsortedDedup [eq : BEq α] (xs ys : Array α) : Array α :=
-- Ideally we would check whether `xs` or `ys` have spare capacity, to prevent
-- copying if possible. But Lean arrays don't expose their capacity.
if xs.size < ys.size then go ys xs else go xs ys
where
/-- Auxiliary definition for `mergeUnsortedDeduplicating`. -/
@[inline]
go (xs ys : Array α) :=
/-- Auxiliary definition for `mergeUnsortedDedup`. -/
@[inline] go (xs ys : Array α) :=
let xsSize := xs.size
ys.foldl (init := xs) fun xs y =>
if xs.any (· == y) (stop := xsSize) then xs else xs.push y

@[deprecated] alias mergeUnsortedDeduplicating := mergeUnsortedDedup

/--
Replace each run `[x₁, ⋯, xₙ]` of equal elements in `xs` with
`O(|xs|)`. Replace each run `[x₁, ⋯, xₙ]` of equal elements in `xs` with
`f ⋯ (f (f x₁ x₂) x₃) ⋯ xₙ`.
-/
def mergeAdjacentDuplicates [eq : BEq α] (f : α → α → α) (xs : Array α) :
Array α :=
if h : 0 < xs.size then go #[] 1 (xs.get ⟨0, h⟩) else xs
def mergeAdjacentDups [eq : BEq α] (f : α → α → α) (xs : Array α) : Array α :=
if h : 0 < xs.size then go (mkEmpty xs.size) 1 (xs.get ⟨0, h⟩) else xs
where
/-- Auxiliary definition for `mergeAdjacentDuplicates`. -/
/-- Auxiliary definition for `mergeAdjacentDups`. -/
go (acc : Array α) (i : Nat) (hd : α) :=
if h : i < xs.size then
let x := xs[i]
if x == hd then
go acc (i + 1) (f hd x)
else
go (acc.push hd) (i + 1) x
if x == hd then go acc (i + 1) (f hd x) else go (acc.push hd) (i + 1) x
else
acc.push hd
termination_by xs.size - i

/--
Deduplicate a sorted array. The array must be sorted with to an order which
agrees with `==`, i.e. whenever `x == y` then `compare x y == .eq`.
-/
def deduplicateSorted [eq : BEq α] (xs : Array α) : Array α :=
xs.mergeAdjacentDuplicates (eq := eq) fun x _ => x
@[deprecated] alias mergeAdjacentDuplicates := mergeAdjacentDups

/--
Sort and deduplicate an array.
`O(|xs|)`. Deduplicate a sorted array. The array must be sorted with respect to an order which
agrees with `==`, i.e. whenever `x == y` then `compare x y == .eq`.
-/
def sortAndDeduplicate [ord : Ord α] (xs : Array α) : Array α :=
def dedupSorted [eq : BEq α] (xs : Array α) : Array α :=
xs.mergeAdjacentDups (eq := eq) fun x _ => x

@[deprecated] alias deduplicateSorted := dedupSorted

/-- `O(|xs| log |xs|)`. Sort and deduplicate an array. -/
def sortDedup [ord : Ord α] (xs : Array α) : Array α :=
have := ord.toBEq
deduplicateSorted <| xs.qsort (compare · · |>.isLT)
dedupSorted <| xs.qsort (compare · · |>.isLT)

@[deprecated] alias sortAndDeduplicate := sortDedup

end Array
2 changes: 1 addition & 1 deletion Std/Lean/Meta/DiscrTree.lean
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ where
/-- Auxiliary definition for `mergePreservingDuplicates`. -/
mergeChildren (cs₁ cs₂ : Array (Key × Trie α)) :
Array (Key × Trie α) :=
Array.mergeSortedMergingDuplicates
Array.mergeDedupWith
(ord := ⟨compareOn (·.fst)⟩) cs₁ cs₂
(fun (k₁, t₁) (_, t₂) => (k₁, mergePreservingDuplicates t₁ t₂))

Expand Down

0 comments on commit 0ad0ebc

Please sign in to comment.