diff --git a/vortex-array/src/array/primitive/compute/search_sorted.rs b/vortex-array/src/array/primitive/compute/search_sorted.rs index 0d6ef271a5..b2095e3854 100644 --- a/vortex-array/src/array/primitive/compute/search_sorted.rs +++ b/vortex-array/src/array/primitive/compute/search_sorted.rs @@ -1,13 +1,13 @@ use vortex_error::VortexResult; use crate::array::primitive::compute::PrimitiveTrait; -use crate::compute::search_sorted::SearchSorted; +use crate::compute::search_sorted::{SearchResult, SearchSorted}; use crate::compute::search_sorted::{SearchSortedFn, SearchSortedSide}; use crate::ptype::NativePType; use crate::scalar::Scalar; impl SearchSortedFn for &dyn PrimitiveTrait { - fn search_sorted(&self, value: &Scalar, side: SearchSortedSide) -> VortexResult { + fn search_sorted(&self, value: &Scalar, side: SearchSortedSide) -> VortexResult { let pvalue: T = value.try_into()?; Ok(self.typed_data().search_sorted(&pvalue, side)) } @@ -24,19 +24,27 @@ mod test { let values = vec![1u16, 2, 3].into_array(); assert_eq!( - search_sorted(&values, 0, SearchSortedSide::Left).unwrap(), + search_sorted(&values, 0, SearchSortedSide::Left) + .unwrap() + .to_index(), 0 ); assert_eq!( - search_sorted(&values, 1, SearchSortedSide::Left).unwrap(), + search_sorted(&values, 1, SearchSortedSide::Left) + .unwrap() + .to_index(), 0 ); assert_eq!( - search_sorted(&values, 1, SearchSortedSide::Right).unwrap(), + search_sorted(&values, 1, SearchSortedSide::Right) + .unwrap() + .to_index(), 1 ); assert_eq!( - search_sorted(&values, 4, SearchSortedSide::Left).unwrap(), + search_sorted(&values, 4, SearchSortedSide::Left) + .unwrap() + .to_index(), 3 ); } diff --git a/vortex-array/src/array/sparse/compute/mod.rs b/vortex-array/src/array/sparse/compute/mod.rs index aff8f77150..62f84dc2b9 100644 --- a/vortex-array/src/array/sparse/compute/mod.rs +++ b/vortex-array/src/array/sparse/compute/mod.rs @@ -11,7 +11,6 @@ use crate::array::{Array, ArrayRef}; use crate::compute::as_contiguous::{as_contiguous, AsContiguousFn}; use crate::compute::flatten::{flatten_primitive, FlattenFn, FlattenedArray}; use crate::compute::scalar_at::{scalar_at, ScalarAtFn}; -use crate::compute::search_sorted::{search_sorted, SearchSortedSide}; use crate::compute::slice::SliceFn; use crate::compute::take::{take, TakeFn}; use crate::compute::ArrayCompute; @@ -182,51 +181,21 @@ fn take_search_sorted( array: &SparseArray, indices: &PrimitiveArray, ) -> VortexResult<(PrimitiveArray, PrimitiveArray)> { - // adjust the input indices (to take) by the internal index offset of the array - let adjusted_indices = match_each_integer_ptype!(indices.ptype(), |$P| { - indices.typed_data::<$P>() - .iter() - .map(|i| *i as usize + array.indices_offset()) - .collect::>() - }); - - // TODO(robert): Use binary search instead of search_sorted + take and index validation to avoid extra work - // search_sorted for the adjusted indices (need to validate that they are an exact match still) - let physical_indices = adjusted_indices - .iter() - .map(|i| search_sorted(array.indices(), *i, SearchSortedSide::Left).map(|s| s as u64)) - .collect::>>()?; - - // filter out indices that are out of bounds, which will cause the take to fail - let (adjusted_indices, physical_indices): (Vec, Vec) = adjusted_indices - .iter() - .zip_eq(physical_indices) - .filter(|(_, phys_idx)| *phys_idx < array.indices().len() as u64) - .unzip(); - - let physical_indices = PrimitiveArray::from(physical_indices); - let taken_indices = flatten_primitive(&take(array.indices(), &physical_indices)?)?; - let exact_matches: Vec = match_each_integer_ptype!(taken_indices.ptype(), |$P| { - taken_indices + let resolved = match_each_integer_ptype!(indices.ptype(), |$P| { + indices .typed_data::<$P>() .iter() - .zip_eq(adjusted_indices) - .map(|(taken_idx, adj_idx)| *taken_idx as usize == adj_idx) - .collect() + .enumerate() + .map(|(pos, i)| { + array + .find_index(*i as usize) + .map(|r| r.map(|ii| (pos as u64, ii as u64))) + }) + .filter_map_ok(|r| r) + .collect::>>()? }); - let (positions, patch_indices): (Vec, Vec) = physical_indices - .typed_data::() - .iter() - .enumerate() - .filter_map(|(i, phy_idx)| { - // search_sorted != binary search, so we need to filter out indices that weren't found - if exact_matches[i] { - Some((i as u64, *phy_idx)) - } else { - None - } - }) - .unzip(); + + let (positions, patch_indices): (Vec, Vec) = resolved.into_iter().unzip(); Ok(( PrimitiveArray::from(positions), PrimitiveArray::from(patch_indices), diff --git a/vortex-array/src/array/sparse/compute/slice.rs b/vortex-array/src/array/sparse/compute/slice.rs index 1870c4756f..29c2ee43a0 100644 --- a/vortex-array/src/array/sparse/compute/slice.rs +++ b/vortex-array/src/array/sparse/compute/slice.rs @@ -8,8 +8,10 @@ use crate::compute::slice::{slice, SliceFn}; impl SliceFn for SparseArray { fn slice(&self, start: usize, stop: usize) -> VortexResult { // Find the index of the first patch index that is greater than or equal to the offset of this array - let index_start_index = search_sorted(self.indices(), start, SearchSortedSide::Left)?; - let index_end_index = search_sorted(self.indices(), stop, SearchSortedSide::Left)?; + let index_start_index = + search_sorted(self.indices(), start, SearchSortedSide::Left)?.to_index(); + let index_end_index = + search_sorted(self.indices(), stop, SearchSortedSide::Left)?.to_index(); Ok(SparseArray::try_new_with_offset( slice(self.indices(), index_start_index, index_end_index)?, diff --git a/vortex-array/src/array/sparse/mod.rs b/vortex-array/src/array/sparse/mod.rs index f1d5622160..596b596f49 100644 --- a/vortex-array/src/array/sparse/mod.rs +++ b/vortex-array/src/array/sparse/mod.rs @@ -8,7 +8,6 @@ use crate::array::constant::ConstantArray; use crate::array::{Array, ArrayRef}; use crate::compress::EncodingCompression; use crate::compute::flatten::flatten_primitive; -use crate::compute::scalar_at::scalar_at; use crate::compute::search_sorted::{search_sorted, SearchSortedSide}; use crate::compute::ArrayCompute; use crate::encoding::{Encoding, EncodingId, EncodingRef, ENCODINGS}; @@ -92,22 +91,12 @@ impl SparseArray { /// Returns the position of a given index in the indices array if it exists. pub fn find_index(&self, index: usize) -> VortexResult> { - let true_index = self.indices_offset + index; - - // TODO(ngates): replace this with a binary search that tells us if we get an exact match. - let idx = search_sorted(self.indices(), true_index, SearchSortedSide::Left)?; - if idx >= self.indices().len() { - return Ok(None); - } - - // If the value at this index is equal to the true index, then it exists in the - // indices array. - let patch_index: usize = scalar_at(self.indices(), idx)?.try_into()?; - if true_index == patch_index { - Ok(Some(idx)) - } else { - Ok(None) - } + search_sorted( + self.indices(), + self.indices_offset + index, + SearchSortedSide::Left, + ) + .map(|r| r.to_found()) } /// Return indices as a vector of usize with the indices_offset applied. diff --git a/vortex-array/src/compute/search_sorted.rs b/vortex-array/src/compute/search_sorted.rs index 90f21b8023..469452aa54 100644 --- a/vortex-array/src/compute/search_sorted.rs +++ b/vortex-array/src/compute/search_sorted.rs @@ -13,15 +13,37 @@ pub enum SearchSortedSide { Right, } +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +pub enum SearchResult { + Found(usize), + NotFound(usize), +} + +impl SearchResult { + pub fn to_found(self) -> Option { + match self { + SearchResult::Found(i) => Some(i), + SearchResult::NotFound(_) => None, + } + } + + pub fn to_index(self) -> usize { + match self { + SearchResult::Found(i) => i, + SearchResult::NotFound(i) => i, + } + } +} + pub trait SearchSortedFn { - fn search_sorted(&self, value: &Scalar, side: SearchSortedSide) -> VortexResult; + fn search_sorted(&self, value: &Scalar, side: SearchSortedSide) -> VortexResult; } pub fn search_sorted>( array: &dyn Array, target: T, side: SearchSortedSide, -) -> VortexResult { +) -> VortexResult { let scalar = target.into().cast(array.dtype())?; array.with_compute(|c| { if let Some(search_sorted) = c.search_sorted() { @@ -65,56 +87,97 @@ pub trait Len { } pub trait SearchSorted { - fn search_sorted(&self, value: &T, side: SearchSortedSide) -> usize + fn search_sorted(&self, value: &T, side: SearchSortedSide) -> SearchResult where Self: IndexOrd, { match side { - SearchSortedSide::Left => self.search_sorted_by(|idx| { - if self.index_lt(idx, value) { - Less - } else { - Greater - } - }), - SearchSortedSide::Right => self.search_sorted_by(|idx| { - if self.index_le(idx, value) { - Less - } else { - Greater - } - }), + SearchSortedSide::Left => self.search_sorted_by( + |idx| self.index_cmp(idx, value).unwrap_or(Less), + |idx| { + if self.index_lt(idx, value) { + Less + } else { + Greater + } + }, + side, + ), + SearchSortedSide::Right => self.search_sorted_by( + |idx| self.index_cmp(idx, value).unwrap_or(Less), + |idx| { + if self.index_le(idx, value) { + Less + } else { + Greater + } + }, + side, + ), } } - fn search_sorted_by Ordering>(&self, f: F) -> usize; + /// find function is used to find the element if it exists, if element exists side_find will be used to find desired index amongst equal values + fn search_sorted_by Ordering, N: FnMut(usize) -> Ordering>( + &self, + find: F, + side_find: N, + side: SearchSortedSide, + ) -> SearchResult; } impl + Len + ?Sized, T> SearchSorted for S { - // Code adapted from Rust standard library slice::binary_search_by - fn search_sorted_by Ordering>(&self, mut f: F) -> usize { - // INVARIANTS: - // - 0 <= left <= left + size = right <= self.len() - // - f returns Less for everything in self[..left] - // - f returns Greater for everything in self[right..] - let mut size = self.len(); - let mut left = 0; - let mut right = size; - while left < right { - let mid = left + size / 2; - let cmp = f(mid); - - left = if cmp == Less { mid + 1 } else { left }; - right = if cmp == Greater { mid } else { right }; - if cmp == Equal { - return mid; + fn search_sorted_by Ordering, N: FnMut(usize) -> Ordering>( + &self, + find: F, + side_find: N, + side: SearchSortedSide, + ) -> SearchResult { + match search_sorted_side_idx(find, 0, self.len()) { + SearchResult::Found(found) => { + let idx_search = match side { + SearchSortedSide::Left => search_sorted_side_idx(side_find, 0, found), + SearchSortedSide::Right => search_sorted_side_idx(side_find, found, self.len()), + }; + match idx_search { + SearchResult::NotFound(i) => SearchResult::Found(i), + _ => unreachable!( + "searching amongst equal values should never return Found result" + ), + } } + s => s, + } + } +} + +// Code adapted from Rust standard library slice::binary_search_by +fn search_sorted_side_idx Ordering>( + mut find: F, + from: usize, + to: usize, +) -> SearchResult { + // INVARIANTS: + // - from <= left <= left + size = right <= to + // - f returns Less for everything in self[..left] + // - f returns Greater for everything in self[right..] + let mut size = to - from; + let mut left = from; + let mut right = to; + while left < right { + let mid = left + size / 2; + let cmp = find(mid); - size = right - left; + left = if cmp == Less { mid + 1 } else { left }; + right = if cmp == Greater { mid } else { right }; + if cmp == Equal { + return SearchResult::Found(mid); } - left + size = right - left; } + + SearchResult::NotFound(left) } impl IndexOrd for &dyn Array { @@ -142,3 +205,56 @@ impl Len for [T] { self.len() } } + +#[cfg(test)] +mod test { + use crate::compute::search_sorted::{SearchResult, SearchSorted, SearchSortedSide}; + + #[test] + fn left_side_equal() { + let arr = [0, 1, 2, 2, 2, 2, 3, 4, 5, 6, 7, 8, 9]; + let res = arr.search_sorted(&2, SearchSortedSide::Left); + assert_eq!(arr[res.to_index()], 2); + assert_eq!(res, SearchResult::Found(2)); + } + + #[test] + fn right_side_equal() { + let arr = [0, 1, 2, 2, 2, 2, 3, 4, 5, 6, 7, 8, 9]; + let res = arr.search_sorted(&2, SearchSortedSide::Right); + assert_eq!(arr[res.to_index() - 1], 2); + assert_eq!(res, SearchResult::Found(6)); + } + + #[test] + fn left_side_equal_beginning() { + let arr = [0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9]; + let res = arr.search_sorted(&0, SearchSortedSide::Left); + assert_eq!(arr[res.to_index()], 0); + assert_eq!(res, SearchResult::Found(0)); + } + + #[test] + fn right_side_equal_beginning() { + let arr = [0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9]; + let res = arr.search_sorted(&0, SearchSortedSide::Right); + assert_eq!(arr[res.to_index() - 1], 0); + assert_eq!(res, SearchResult::Found(4)); + } + + #[test] + fn left_side_equal_end() { + let arr = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 9, 9]; + let res = arr.search_sorted(&9, SearchSortedSide::Left); + assert_eq!(arr[res.to_index()], 9); + assert_eq!(res, SearchResult::Found(9)); + } + + #[test] + fn right_side_equal_end() { + let arr = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 9, 9]; + let res = arr.search_sorted(&9, SearchSortedSide::Right); + assert_eq!(arr[res.to_index() - 1], 9); + assert_eq!(res, SearchResult::Found(13)); + } +} diff --git a/vortex-ree/src/ree.rs b/vortex-ree/src/ree.rs index 3ea323bda8..c87d112d83 100644 --- a/vortex-ree/src/ree.rs +++ b/vortex-ree/src/ree.rs @@ -3,7 +3,7 @@ use std::sync::{Arc, RwLock}; use vortex::array::{Array, ArrayKind, ArrayRef}; use vortex::compress::EncodingCompression; use vortex::compute::scalar_at::scalar_at; -use vortex::compute::search_sorted::SearchSortedSide; +use vortex::compute::search_sorted::{search_sorted, SearchSortedSide}; use vortex::compute::ArrayCompute; use vortex::encoding::{Encoding, EncodingId, EncodingRef}; use vortex::formatter::{ArrayDisplay, ArrayFormatter}; @@ -12,7 +12,7 @@ use vortex::stats::{Stat, Stats, StatsCompute, StatsSet}; use vortex::validity::Validity; use vortex::validity::{OwnedValidity, ValidityView}; use vortex::view::{AsView, ToOwnedView}; -use vortex::{compute, impl_array, ArrayWalker}; +use vortex::{impl_array, ArrayWalker}; use vortex_error::{vortex_bail, vortex_err, VortexResult}; use vortex_schema::DType; @@ -68,11 +68,8 @@ impl REEArray { } pub fn find_physical_index(&self, index: usize) -> VortexResult { - compute::search_sorted::search_sorted( - self.ends(), - index + self.offset, - SearchSortedSide::Right, - ) + search_sorted(self.ends(), index + self.offset, SearchSortedSide::Right) + .map(|s| s.to_index()) } pub fn encode(array: &dyn Array) -> VortexResult {