Skip to content

Commit

Permalink
SearchSorted can return whether search resulted in exact match (#226)
Browse files Browse the repository at this point in the history
  • Loading branch information
robert3005 authored Apr 12, 2024
1 parent 2c2f6df commit 5e623f9
Show file tree
Hide file tree
Showing 6 changed files with 193 additions and 112 deletions.
20 changes: 14 additions & 6 deletions vortex-array/src/array/primitive/compute/search_sorted.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
use vortex_error::VortexResult;

use crate::array::primitive::compute::PrimitiveTrait;
use crate::compute::search_sorted::SearchSorted;
use crate::compute::search_sorted::{SearchResult, SearchSorted};
use crate::compute::search_sorted::{SearchSortedFn, SearchSortedSide};
use crate::ptype::NativePType;
use crate::scalar::Scalar;

impl<T: NativePType> SearchSortedFn for &dyn PrimitiveTrait<T> {
fn search_sorted(&self, value: &Scalar, side: SearchSortedSide) -> VortexResult<usize> {
fn search_sorted(&self, value: &Scalar, side: SearchSortedSide) -> VortexResult<SearchResult> {
let pvalue: T = value.try_into()?;
Ok(self.typed_data().search_sorted(&pvalue, side))
}
Expand All @@ -24,19 +24,27 @@ mod test {
let values = vec![1u16, 2, 3].into_array();

assert_eq!(
search_sorted(&values, 0, SearchSortedSide::Left).unwrap(),
search_sorted(&values, 0, SearchSortedSide::Left)
.unwrap()
.to_index(),
0
);
assert_eq!(
search_sorted(&values, 1, SearchSortedSide::Left).unwrap(),
search_sorted(&values, 1, SearchSortedSide::Left)
.unwrap()
.to_index(),
0
);
assert_eq!(
search_sorted(&values, 1, SearchSortedSide::Right).unwrap(),
search_sorted(&values, 1, SearchSortedSide::Right)
.unwrap()
.to_index(),
1
);
assert_eq!(
search_sorted(&values, 4, SearchSortedSide::Left).unwrap(),
search_sorted(&values, 4, SearchSortedSide::Left)
.unwrap()
.to_index(),
3
);
}
Expand Down
55 changes: 12 additions & 43 deletions vortex-array/src/array/sparse/compute/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ use crate::array::{Array, ArrayRef};
use crate::compute::as_contiguous::{as_contiguous, AsContiguousFn};
use crate::compute::flatten::{flatten_primitive, FlattenFn, FlattenedArray};
use crate::compute::scalar_at::{scalar_at, ScalarAtFn};
use crate::compute::search_sorted::{search_sorted, SearchSortedSide};
use crate::compute::slice::SliceFn;
use crate::compute::take::{take, TakeFn};
use crate::compute::ArrayCompute;
Expand Down Expand Up @@ -182,51 +181,21 @@ fn take_search_sorted(
array: &SparseArray,
indices: &PrimitiveArray,
) -> VortexResult<(PrimitiveArray, PrimitiveArray)> {
// adjust the input indices (to take) by the internal index offset of the array
let adjusted_indices = match_each_integer_ptype!(indices.ptype(), |$P| {
indices.typed_data::<$P>()
.iter()
.map(|i| *i as usize + array.indices_offset())
.collect::<Vec<_>>()
});

// TODO(robert): Use binary search instead of search_sorted + take and index validation to avoid extra work
// search_sorted for the adjusted indices (need to validate that they are an exact match still)
let physical_indices = adjusted_indices
.iter()
.map(|i| search_sorted(array.indices(), *i, SearchSortedSide::Left).map(|s| s as u64))
.collect::<VortexResult<Vec<_>>>()?;

// filter out indices that are out of bounds, which will cause the take to fail
let (adjusted_indices, physical_indices): (Vec<usize>, Vec<u64>) = adjusted_indices
.iter()
.zip_eq(physical_indices)
.filter(|(_, phys_idx)| *phys_idx < array.indices().len() as u64)
.unzip();

let physical_indices = PrimitiveArray::from(physical_indices);
let taken_indices = flatten_primitive(&take(array.indices(), &physical_indices)?)?;
let exact_matches: Vec<bool> = match_each_integer_ptype!(taken_indices.ptype(), |$P| {
taken_indices
let resolved = match_each_integer_ptype!(indices.ptype(), |$P| {
indices
.typed_data::<$P>()
.iter()
.zip_eq(adjusted_indices)
.map(|(taken_idx, adj_idx)| *taken_idx as usize == adj_idx)
.collect()
.enumerate()
.map(|(pos, i)| {
array
.find_index(*i as usize)
.map(|r| r.map(|ii| (pos as u64, ii as u64)))
})
.filter_map_ok(|r| r)
.collect::<VortexResult<Vec<_>>>()?
});
let (positions, patch_indices): (Vec<u64>, Vec<u64>) = physical_indices
.typed_data::<u64>()
.iter()
.enumerate()
.filter_map(|(i, phy_idx)| {
// search_sorted != binary search, so we need to filter out indices that weren't found
if exact_matches[i] {
Some((i as u64, *phy_idx))
} else {
None
}
})
.unzip();

let (positions, patch_indices): (Vec<u64>, Vec<u64>) = resolved.into_iter().unzip();
Ok((
PrimitiveArray::from(positions),
PrimitiveArray::from(patch_indices),
Expand Down
6 changes: 4 additions & 2 deletions vortex-array/src/array/sparse/compute/slice.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@ use crate::compute::slice::{slice, SliceFn};
impl SliceFn for SparseArray {
fn slice(&self, start: usize, stop: usize) -> VortexResult<ArrayRef> {
// Find the index of the first patch index that is greater than or equal to the offset of this array
let index_start_index = search_sorted(self.indices(), start, SearchSortedSide::Left)?;
let index_end_index = search_sorted(self.indices(), stop, SearchSortedSide::Left)?;
let index_start_index =
search_sorted(self.indices(), start, SearchSortedSide::Left)?.to_index();
let index_end_index =
search_sorted(self.indices(), stop, SearchSortedSide::Left)?.to_index();

Ok(SparseArray::try_new_with_offset(
slice(self.indices(), index_start_index, index_end_index)?,
Expand Down
23 changes: 6 additions & 17 deletions vortex-array/src/array/sparse/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ use crate::array::constant::ConstantArray;
use crate::array::{Array, ArrayRef};
use crate::compress::EncodingCompression;
use crate::compute::flatten::flatten_primitive;
use crate::compute::scalar_at::scalar_at;
use crate::compute::search_sorted::{search_sorted, SearchSortedSide};
use crate::compute::ArrayCompute;
use crate::encoding::{Encoding, EncodingId, EncodingRef, ENCODINGS};
Expand Down Expand Up @@ -92,22 +91,12 @@ impl SparseArray {

/// Returns the position of a given index in the indices array if it exists.
pub fn find_index(&self, index: usize) -> VortexResult<Option<usize>> {
let true_index = self.indices_offset + index;

// TODO(ngates): replace this with a binary search that tells us if we get an exact match.
let idx = search_sorted(self.indices(), true_index, SearchSortedSide::Left)?;
if idx >= self.indices().len() {
return Ok(None);
}

// If the value at this index is equal to the true index, then it exists in the
// indices array.
let patch_index: usize = scalar_at(self.indices(), idx)?.try_into()?;
if true_index == patch_index {
Ok(Some(idx))
} else {
Ok(None)
}
search_sorted(
self.indices(),
self.indices_offset + index,
SearchSortedSide::Left,
)
.map(|r| r.to_found())
}

/// Return indices as a vector of usize with the indices_offset applied.
Expand Down
190 changes: 153 additions & 37 deletions vortex-array/src/compute/search_sorted.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,37 @@ pub enum SearchSortedSide {
Right,
}

/// Outcome of a sorted search: an index plus whether the searched value was present.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum SearchResult {
    /// The value is present; the payload is the index selected by the search.
    Found(usize),
    /// The value is absent; the payload is the index at which the search stopped.
    NotFound(usize),
}

impl SearchResult {
    /// Returns the index only when the value was found, `None` otherwise.
    pub fn to_found(self) -> Option<usize> {
        if let Self::Found(position) = self {
            Some(position)
        } else {
            None
        }
    }

    /// Returns the carried index regardless of whether the value was found.
    pub fn to_index(self) -> usize {
        // Both variants carry exactly one index, so a single irrefutable pattern covers them.
        let (Self::Found(position) | Self::NotFound(position)) = self;
        position
    }
}

pub trait SearchSortedFn {
fn search_sorted(&self, value: &Scalar, side: SearchSortedSide) -> VortexResult<usize>;
fn search_sorted(&self, value: &Scalar, side: SearchSortedSide) -> VortexResult<SearchResult>;
}

pub fn search_sorted<T: Into<Scalar>>(
array: &dyn Array,
target: T,
side: SearchSortedSide,
) -> VortexResult<usize> {
) -> VortexResult<SearchResult> {
let scalar = target.into().cast(array.dtype())?;
array.with_compute(|c| {
if let Some(search_sorted) = c.search_sorted() {
Expand Down Expand Up @@ -65,56 +87,97 @@ pub trait Len {
}

/// Binary search that reports both an index and whether the value was present.
pub trait SearchSorted<T> {
    /// Searches for `value` and returns the index for the requested `side`.
    ///
    /// Two comparators are passed to [`search_sorted_by`](Self::search_sorted_by):
    /// - a three-way comparator built on `index_cmp`, falling back to `Less` when `index_cmp`
    ///   yields no ordering, which detects whether `value` occurs at all;
    /// - a two-way comparator that never returns `Equal` (`index_lt` for `Left`, `index_le`
    ///   for `Right`), which locates the boundary of the run of equal elements.
    fn search_sorted(&self, value: &T, side: SearchSortedSide) -> SearchResult
    where
        Self: IndexOrd<T>,
    {
        match side {
            SearchSortedSide::Left => self.search_sorted_by(
                |idx| self.index_cmp(idx, value).unwrap_or(Less),
                |idx| {
                    if self.index_lt(idx, value) {
                        Less
                    } else {
                        Greater
                    }
                },
                side,
            ),
            SearchSortedSide::Right => self.search_sorted_by(
                |idx| self.index_cmp(idx, value).unwrap_or(Less),
                |idx| {
                    if self.index_le(idx, value) {
                        Less
                    } else {
                        Greater
                    }
                },
                side,
            ),
        }
    }

    /// Searches using caller-supplied comparators.
    ///
    /// `find` is used to locate the element if it exists. When it does, `side_find` is used to
    /// locate the desired index amongst the equal values, on the end selected by `side`.
    fn search_sorted_by<F: FnMut(usize) -> Ordering, N: FnMut(usize) -> Ordering>(
        &self,
        find: F,
        side_find: N,
        side: SearchSortedSide,
    ) -> SearchResult;
}

impl<S: IndexOrd<T> + Len + ?Sized, T> SearchSorted<T> for S {
// Code adapted from Rust standard library slice::binary_search_by
fn search_sorted_by<F: FnMut(usize) -> Ordering>(&self, mut f: F) -> usize {
// INVARIANTS:
// - 0 <= left <= left + size = right <= self.len()
// - f returns Less for everything in self[..left]
// - f returns Greater for everything in self[right..]
let mut size = self.len();
let mut left = 0;
let mut right = size;
while left < right {
let mid = left + size / 2;
let cmp = f(mid);

left = if cmp == Less { mid + 1 } else { left };
right = if cmp == Greater { mid } else { right };
if cmp == Equal {
return mid;
fn search_sorted_by<F: FnMut(usize) -> Ordering, N: FnMut(usize) -> Ordering>(
&self,
find: F,
side_find: N,
side: SearchSortedSide,
) -> SearchResult {
match search_sorted_side_idx(find, 0, self.len()) {
SearchResult::Found(found) => {
let idx_search = match side {
SearchSortedSide::Left => search_sorted_side_idx(side_find, 0, found),
SearchSortedSide::Right => search_sorted_side_idx(side_find, found, self.len()),
};
match idx_search {
SearchResult::NotFound(i) => SearchResult::Found(i),
_ => unreachable!(
"searching amongst equal values should never return Found result"
),
}
}
s => s,
}
}
}

// Code adapted from Rust standard library slice::binary_search_by
/// Binary-searches the half-open index range `from..to` with the comparator `find`.
///
/// `find(idx)` returns `Less` to continue to the right of `idx`, `Greater` to continue to the
/// left of it, and `Equal` to stop immediately.
///
/// Returns `SearchResult::Found(idx)` for the first probed index at which `find` returns
/// `Equal` (not necessarily the first or last such index in the range), otherwise
/// `SearchResult::NotFound(idx)` with the index at which the range collapsed.
///
/// `to - from` is computed on `usize`, so callers must pass `from <= to`.
fn search_sorted_side_idx<F: FnMut(usize) -> Ordering>(
    mut find: F,
    from: usize,
    to: usize,
) -> SearchResult {
    // INVARIANTS:
    // - from <= left <= left + size = right <= to
    // - find returns Less for every index in from..left
    // - find returns Greater for every index in right..to
    let mut size = to - from;
    let mut left = from;
    let mut right = to;
    while left < right {
        let mid = left + size / 2;
        let cmp = find(mid);

        left = if cmp == Less { mid + 1 } else { left };
        right = if cmp == Greater { mid } else { right };
        if cmp == Equal {
            return SearchResult::Found(mid);
        }

        // Recompute only after both bounds moved so `left + size == right` holds again.
        size = right - left;
    }

    SearchResult::NotFound(left)
}

impl IndexOrd<Scalar> for &dyn Array {
Expand Down Expand Up @@ -142,3 +205,56 @@ impl<T> Len for [T] {
self.len()
}
}

#[cfg(test)]
mod test {
    use crate::compute::search_sorted::{SearchResult, SearchSorted, SearchSortedSide};

    // Each case searches for a value that occurs four times in a row. `Left` must land on the
    // first occurrence, `Right` one past the last occurrence.

    #[test]
    fn left_side_equal() {
        // Run of equal values in the middle of the array.
        let haystack = [0, 1, 2, 2, 2, 2, 3, 4, 5, 6, 7, 8, 9];
        let hit = haystack.search_sorted(&2, SearchSortedSide::Left);
        assert_eq!(haystack[hit.to_index()], 2);
        assert_eq!(hit, SearchResult::Found(2));
    }

    #[test]
    fn right_side_equal() {
        // Run of equal values in the middle of the array.
        let haystack = [0, 1, 2, 2, 2, 2, 3, 4, 5, 6, 7, 8, 9];
        let hit = haystack.search_sorted(&2, SearchSortedSide::Right);
        assert_eq!(haystack[hit.to_index() - 1], 2);
        assert_eq!(hit, SearchResult::Found(6));
    }

    #[test]
    fn left_side_equal_beginning() {
        // Run of equal values at the start of the array.
        let haystack = [0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9];
        let hit = haystack.search_sorted(&0, SearchSortedSide::Left);
        assert_eq!(haystack[hit.to_index()], 0);
        assert_eq!(hit, SearchResult::Found(0));
    }

    #[test]
    fn right_side_equal_beginning() {
        // Run of equal values at the start of the array.
        let haystack = [0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9];
        let hit = haystack.search_sorted(&0, SearchSortedSide::Right);
        assert_eq!(haystack[hit.to_index() - 1], 0);
        assert_eq!(hit, SearchResult::Found(4));
    }

    #[test]
    fn left_side_equal_end() {
        // Run of equal values at the end of the array.
        let haystack = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 9, 9];
        let hit = haystack.search_sorted(&9, SearchSortedSide::Left);
        assert_eq!(haystack[hit.to_index()], 9);
        assert_eq!(hit, SearchResult::Found(9));
    }

    #[test]
    fn right_side_equal_end() {
        // Run of equal values at the end of the array; the result is one past the last element.
        let haystack = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 9, 9];
        let hit = haystack.search_sorted(&9, SearchSortedSide::Right);
        assert_eq!(haystack[hit.to_index() - 1], 9);
        assert_eq!(hit, SearchResult::Found(13));
    }
}
Loading

0 comments on commit 5e623f9

Please sign in to comment.