From c09295c9a69b86215ef09885e65a7373a3c45475 Mon Sep 17 00:00:00 2001 From: Nicholas Gates Date: Wed, 20 Nov 2024 22:54:54 +0000 Subject: [PATCH] SearchSorted VTable (#1414) --- bench-vortex/src/bin/notimplemented.rs | 2 +- .../fastlanes/src/bitpacking/compute/mod.rs | 10 +- .../src/bitpacking/compute/search_sorted.rs | 48 +++-- encodings/fastlanes/src/for/compute.rs | 23 ++- .../src/array/constant/compute/mod.rs | 63 +----- .../array/constant/compute/search_sorted.rs | 63 ++++++ .../src/array/primitive/compute/mod.rs | 8 +- .../array/primitive/compute/search_sorted.rs | 33 +-- vortex-array/src/array/sparse/compute/mod.rs | 33 +-- vortex-array/src/compute/mod.rs | 14 +- vortex-array/src/compute/search_sorted.rs | 193 ++++++++++++------ 11 files changed, 300 insertions(+), 190 deletions(-) create mode 100644 vortex-array/src/array/constant/compute/search_sorted.rs diff --git a/bench-vortex/src/bin/notimplemented.rs b/bench-vortex/src/bin/notimplemented.rs index 5c7b644330..9d931ddbac 100644 --- a/bench-vortex/src/bin/notimplemented.rs +++ b/bench-vortex/src/bin/notimplemented.rs @@ -202,7 +202,7 @@ fn compute_funcs(encodings: &[ArrayData]) { impls.push(bool_to_cell( arr.with_dyn(|a| a.subtract_scalar().is_some()), )); - impls.push(bool_to_cell(arr.with_dyn(|a| a.search_sorted().is_some()))); + impls.push(bool_to_cell(arr.encoding().search_sorted_fn().is_some())); impls.push(bool_to_cell(arr.encoding().slice_fn().is_some())); impls.push(bool_to_cell(arr.encoding().take_fn().is_some())); table.add_row(Row::new(impls)); diff --git a/encodings/fastlanes/src/bitpacking/compute/mod.rs b/encodings/fastlanes/src/bitpacking/compute/mod.rs index 6562a3bf21..0d9b84deb6 100644 --- a/encodings/fastlanes/src/bitpacking/compute/mod.rs +++ b/encodings/fastlanes/src/bitpacking/compute/mod.rs @@ -12,11 +12,7 @@ mod search_sorted; mod slice; mod take; -impl ArrayCompute for BitPackedArray { - fn search_sorted(&self) -> Option<&dyn SearchSortedFn> { - Some(self) - } -} +impl ArrayCompute for BitPackedArray {} impl ComputeVTable for BitPackedEncoding { fn filter_fn(&self) -> Option<&dyn FilterFn> { @@ -27,6 +23,10 @@ impl ComputeVTable for BitPackedEncoding { Some(self) } + fn search_sorted_fn(&self) -> Option<&dyn SearchSortedFn> { + Some(self) + } + fn slice_fn(&self) -> Option<&dyn SliceFn> { Some(self) } diff --git a/encodings/fastlanes/src/bitpacking/compute/search_sorted.rs b/encodings/fastlanes/src/bitpacking/compute/search_sorted.rs index 689847afc2..d91889f4e0 100644 --- a/encodings/fastlanes/src/bitpacking/compute/search_sorted.rs +++ b/encodings/fastlanes/src/bitpacking/compute/search_sorted.rs @@ -17,46 +17,53 @@ use vortex_dtype::{match_each_unsigned_integer_ptype, NativePType}; use vortex_error::{VortexError, VortexExpect as _, VortexResult}; use vortex_scalar::Scalar; -use crate::{unpack_single_primitive, BitPackedArray}; +use crate::{unpack_single_primitive, BitPackedArray, BitPackedEncoding}; -impl SearchSortedFn for BitPackedArray { - fn search_sorted(&self, value: &Scalar, side: SearchSortedSide) -> VortexResult { - match_each_unsigned_integer_ptype!(self.ptype(), |$P| { - search_sorted_typed::<$P>(self, value, side) +impl SearchSortedFn for BitPackedEncoding { + fn search_sorted( + &self, + array: &BitPackedArray, + value: &Scalar, + side: SearchSortedSide, + ) -> VortexResult { + match_each_unsigned_integer_ptype!(array.ptype(), |$P| { + search_sorted_typed::<$P>(array, value, side) }) } fn search_sorted_usize( &self, + array: &BitPackedArray, value: usize, side: SearchSortedSide, ) -> VortexResult { - match_each_unsigned_integer_ptype!(self.ptype(), |$P| { + match_each_unsigned_integer_ptype!(array.ptype(), |$P| { // NOTE: conversion may truncate silently. if let Some(pvalue) = num_traits::cast::(value) { - search_sorted_native(self, pvalue, side) + search_sorted_native(array, pvalue, side) } else { // provided u64 is too large to fit in the provided PType, value must be off // the right end of the array. - Ok(SearchResult::NotFound(self.len())) + Ok(SearchResult::NotFound(array.len())) } }) } fn search_sorted_many( &self, + array: &BitPackedArray, values: &[Scalar], sides: &[SearchSortedSide], ) -> VortexResult> { - match_each_unsigned_integer_ptype!(self.ptype(), |$P| { - let searcher = BitPackedSearch::<'_, $P>::new(self); + match_each_unsigned_integer_ptype!(array.ptype(), |$P| { + let searcher = BitPackedSearch::<'_, $P>::new(array); values .iter() .zip(sides.iter().copied()) .map(|(value, side)| { // Unwrap to native value - let unwrapped_value: $P = value.cast(self.dtype())?.try_into()?; + let unwrapped_value: $P = value.cast(array.dtype())?.try_into()?; Ok(searcher.search_sorted(&unwrapped_value, side)) }) @@ -66,11 +73,12 @@ impl SearchSortedFn for BitPackedArray { fn search_sorted_usize_many( &self, + array: &BitPackedArray, values: &[usize], sides: &[SearchSortedSide], ) -> VortexResult> { - match_each_unsigned_integer_ptype!(self.ptype(), |$P| { - let searcher = BitPackedSearch::<'_, $P>::new(self); + match_each_unsigned_integer_ptype!(array.ptype(), |$P| { + let searcher = BitPackedSearch::<'_, $P>::new(array); values .iter() @@ -200,7 +208,7 @@ impl Len for BitPackedSearch<'_, T> { mod test { use vortex_array::array::PrimitiveArray; use vortex_array::compute::{ - search_sorted, search_sorted_many, slice, SearchResult, SearchSortedFn, SearchSortedSide, + search_sorted, search_sorted_many, slice, SearchResult, SearchSortedSide, }; use vortex_array::IntoArrayData; use vortex_dtype::Nullability; @@ -261,12 +269,12 @@ mod test { ) .unwrap(); - let found = bitpacked - .search_sorted( - &Scalar::primitive(1u64, Nullability::Nullable), - SearchSortedSide::Left, - ) - .unwrap(); + let found = search_sorted( + bitpacked.as_ref(), + Scalar::primitive(1u64, Nullability::Nullable), + SearchSortedSide::Left, + ) + .unwrap(); assert_eq!(found, SearchResult::Found(0)); } diff --git a/encodings/fastlanes/src/for/compute.rs b/encodings/fastlanes/src/for/compute.rs index 03663c90b7..9ab725229d 100644 --- a/encodings/fastlanes/src/for/compute.rs +++ b/encodings/fastlanes/src/for/compute.rs @@ -14,11 +14,7 @@ use vortex_scalar::{PValue, Scalar}; use crate::{FoRArray, FoREncoding}; -impl ArrayCompute for FoRArray { - fn search_sorted(&self) -> Option<&dyn SearchSortedFn> { - Some(self) - } -} +impl ArrayCompute for FoRArray {} impl ComputeVTable for FoREncoding { fn filter_fn(&self) -> Option<&dyn FilterFn> { @@ -29,6 +25,10 @@ impl ComputeVTable for FoREncoding { Some(self) } + fn search_sorted_fn(&self) -> Option<&dyn SearchSortedFn> { + Some(self) + } + fn slice_fn(&self) -> Option<&dyn SliceFn> { Some(self) } @@ -99,10 +99,15 @@ impl SliceFn for FoREncoding { } } -impl SearchSortedFn for FoRArray { - fn search_sorted(&self, value: &Scalar, side: SearchSortedSide) -> VortexResult { - match_each_integer_ptype!(self.ptype(), |$P| { - search_sorted_typed::<$P>(self, value, side) +impl SearchSortedFn for FoREncoding { + fn search_sorted( + &self, + array: &FoRArray, + value: &Scalar, + side: SearchSortedSide, + ) -> VortexResult { + match_each_integer_ptype!(array.ptype(), |$P| { + search_sorted_typed::<$P>(array, value, side) }) } } diff --git a/vortex-array/src/array/constant/compute/mod.rs b/vortex-array/src/array/constant/compute/mod.rs index 2e5926479a..e44c314d30 100644 --- a/vortex-array/src/array/constant/compute/mod.rs +++ b/vortex-array/src/array/constant/compute/mod.rs @@ -1,6 +1,5 @@ mod boolean; - -use std::cmp::Ordering; +mod search_sorted; use vortex_error::VortexResult; use vortex_scalar::Scalar; @@ -10,7 +9,7 @@ use crate::array::ConstantEncoding; use crate::compute::unary::ScalarAtFn; use crate::compute::{ scalar_cmp, ArrayCompute, BinaryBooleanFn, ComputeVTable, FilterFn, FilterMask, MaybeCompareFn, - Operator, SearchResult, SearchSortedFn, SearchSortedSide, SliceFn, TakeFn, TakeOptions, + Operator, SearchSortedFn, SliceFn, TakeFn, TakeOptions, }; use crate::{ArrayData, ArrayLen, IntoArrayData}; @@ -18,10 +17,6 @@ impl ArrayCompute for ConstantArray { fn compare(&self, other: &ArrayData, operator: Operator) -> Option> { MaybeCompareFn::maybe_compare(self, other, operator) } - - fn search_sorted(&self) -> Option<&dyn SearchSortedFn> { - Some(self) - } } impl ComputeVTable for ConstantEncoding { @@ -43,6 +38,10 @@ impl ComputeVTable for ConstantEncoding { Some(self) } + fn search_sorted_fn(&self) -> Option<&dyn SearchSortedFn> { + Some(self) + } + fn slice_fn(&self) -> Option<&dyn SliceFn> { Some(self) } @@ -81,23 +80,6 @@ impl FilterFn for ConstantEncoding { } } -impl SearchSortedFn for ConstantArray { - fn search_sorted(&self, value: &Scalar, side: SearchSortedSide) -> VortexResult { - match self - .scalar_value() - .partial_cmp(value.value()) - .unwrap_or(Ordering::Less) - { - Ordering::Greater => Ok(SearchResult::NotFound(0)), - Ordering::Less => Ok(SearchResult::NotFound(self.len())), - Ordering::Equal => match side { - SearchSortedSide::Left => Ok(SearchResult::Found(0)), - SearchSortedSide::Right => Ok(SearchResult::Found(self.len())), - }, - } - } -} - impl MaybeCompareFn for ConstantArray { fn maybe_compare( &self, @@ -111,36 +93,3 @@ impl MaybeCompareFn for ConstantArray { }) } } - -#[cfg(test)] -mod test { - use crate::array::constant::ConstantArray; - use crate::compute::{search_sorted, SearchResult, SearchSortedSide}; - use crate::IntoArrayData; - - #[test] - pub fn search() { - let cst = ConstantArray::new(42, 5000).into_array(); - assert_eq!( - search_sorted(&cst, 33, SearchSortedSide::Left).unwrap(), - SearchResult::NotFound(0) - ); - assert_eq!( - search_sorted(&cst, 55, SearchSortedSide::Left).unwrap(), - SearchResult::NotFound(5000) - ); - } - - #[test] - pub fn search_equals() { - let cst = ConstantArray::new(42, 5000).into_array(); - assert_eq!( - search_sorted(&cst, 42, SearchSortedSide::Left).unwrap(), - SearchResult::Found(0) - ); - assert_eq!( - search_sorted(&cst, 42, SearchSortedSide::Right).unwrap(), - SearchResult::Found(5000) - ); - } -} diff --git a/vortex-array/src/array/constant/compute/search_sorted.rs b/vortex-array/src/array/constant/compute/search_sorted.rs new file mode 100644 index 0000000000..e6ffbb5237 --- /dev/null +++ b/vortex-array/src/array/constant/compute/search_sorted.rs @@ -0,0 +1,63 @@ +use std::cmp::Ordering; + +use vortex_error::VortexResult; +use vortex_scalar::Scalar; + +use crate::array::{ConstantArray, ConstantEncoding}; +use crate::compute::{SearchResult, SearchSortedFn, SearchSortedSide}; +use crate::ArrayLen; + +impl SearchSortedFn for ConstantEncoding { + fn search_sorted( + &self, + array: &ConstantArray, + value: &Scalar, + side: SearchSortedSide, + ) -> VortexResult { + match array + .scalar_value() + .partial_cmp(value.value()) + .unwrap_or(Ordering::Less) + { + Ordering::Greater => Ok(SearchResult::NotFound(0)), + Ordering::Less => Ok(SearchResult::NotFound(array.len())), + Ordering::Equal => match side { + SearchSortedSide::Left => Ok(SearchResult::Found(0)), + SearchSortedSide::Right => Ok(SearchResult::Found(array.len())), + }, + } + } +} + +#[cfg(test)] +mod test { + use crate::array::constant::ConstantArray; + use crate::compute::{search_sorted, SearchResult, SearchSortedSide}; + use crate::IntoArrayData; + + #[test] + pub fn search() { + let cst = ConstantArray::new(42, 5000).into_array(); + assert_eq!( + search_sorted(&cst, 33, SearchSortedSide::Left).unwrap(), + SearchResult::NotFound(0) + ); + assert_eq!( + search_sorted(&cst, 55, SearchSortedSide::Left).unwrap(), + SearchResult::NotFound(5000) + ); + } + + #[test] + pub fn search_equals() { + let cst = ConstantArray::new(42, 5000).into_array(); + assert_eq!( + search_sorted(&cst, 42, SearchSortedSide::Left).unwrap(), + SearchResult::Found(0) + ); + assert_eq!( + search_sorted(&cst, 42, SearchSortedSide::Right).unwrap(), + SearchResult::Found(5000) + ); + } +} diff --git a/vortex-array/src/array/primitive/compute/mod.rs b/vortex-array/src/array/primitive/compute/mod.rs index eb67205ff5..4f8f95642e 100644 --- a/vortex-array/src/array/primitive/compute/mod.rs +++ b/vortex-array/src/array/primitive/compute/mod.rs @@ -27,10 +27,6 @@ impl ArrayCompute for PrimitiveArray { fn subtract_scalar(&self) -> Option<&dyn SubtractScalarFn> { Some(self) } - - fn search_sorted(&self) -> Option<&dyn SearchSortedFn> { - Some(self) - } } impl ComputeVTable for PrimitiveEncoding { @@ -50,6 +46,10 @@ impl ComputeVTable for PrimitiveEncoding { Some(self) } + fn search_sorted_fn(&self) -> Option<&dyn SearchSortedFn> { + Some(self) + } + fn slice_fn(&self) -> Option<&dyn SliceFn> { Some(self) } diff --git a/vortex-array/src/array/primitive/compute/search_sorted.rs b/vortex-array/src/array/primitive/compute/search_sorted.rs index 9bb7589840..11a390d649 100644 --- a/vortex-array/src/array/primitive/compute/search_sorted.rs +++ b/vortex-array/src/array/primitive/compute/search_sorted.rs @@ -6,23 +6,29 @@ use vortex_error::VortexResult; use vortex_scalar::Scalar; use crate::array::primitive::PrimitiveArray; +use crate::array::PrimitiveEncoding; use crate::compute::{IndexOrd, Len, SearchResult, SearchSorted, SearchSortedFn, SearchSortedSide}; use crate::validity::Validity; use crate::variants::PrimitiveArrayTrait; use crate::{ArrayDType, ArrayLen}; -impl SearchSortedFn for PrimitiveArray { - fn search_sorted(&self, value: &Scalar, side: SearchSortedSide) -> VortexResult { - match_each_native_ptype!(self.ptype(), |$T| { - match self.validity() { +impl SearchSortedFn for PrimitiveEncoding { + fn search_sorted( + &self, + array: &PrimitiveArray, + value: &Scalar, + side: SearchSortedSide, + ) -> VortexResult { + match_each_native_ptype!(array.ptype(), |$T| { + match array.validity() { Validity::NonNullable | Validity::AllValid => { - let pvalue: $T = value.cast(self.dtype())?.try_into()?; - Ok(SearchSortedPrimitive::new(self).search_sorted(&pvalue, side)) + let pvalue: $T = value.cast(array.dtype())?.try_into()?; + Ok(SearchSortedPrimitive::new(array).search_sorted(&pvalue, side)) } Validity::AllInvalid => Ok(SearchResult::NotFound(0)), Validity::Array(_) => { - let pvalue: $T = value.cast(self.dtype())?.try_into()?; - Ok(SearchSortedNullsLast::new(self).search_sorted(&pvalue, side)) + let pvalue: $T = value.cast(array.dtype())?.try_into()?; + Ok(SearchSortedNullsLast::new(array).search_sorted(&pvalue, side)) } } }) @@ -31,26 +37,27 @@ impl SearchSortedFn for PrimitiveArray { #[allow(clippy::cognitive_complexity)] fn search_sorted_usize( &self, + array: &PrimitiveArray, value: usize, side: SearchSortedSide, ) -> VortexResult { - match_each_native_ptype!(self.ptype(), |$T| { + match_each_native_ptype!(array.ptype(), |$T| { if let Some(pvalue) = num_traits::cast::(value) { - match self.validity() { + match array.validity() { Validity::NonNullable | Validity::AllValid => { // null-free search - Ok(SearchSortedPrimitive::new(self).search_sorted(&pvalue, side)) + Ok(SearchSortedPrimitive::new(array).search_sorted(&pvalue, side)) } Validity::AllInvalid => Ok(SearchResult::NotFound(0)), Validity::Array(_) => { // null-aware search - Ok(SearchSortedNullsLast::new(self).search_sorted(&pvalue, side)) + Ok(SearchSortedNullsLast::new(array).search_sorted(&pvalue, side)) } } } else { // provided u64 is too large to fit in the provided PType, value must be off // the right end of the array. - Ok(SearchResult::NotFound(self.len())) + Ok(SearchResult::NotFound(array.len())) } }) } diff --git a/vortex-array/src/array/sparse/compute/mod.rs b/vortex-array/src/array/sparse/compute/mod.rs index 1775ba2f1c..bb88ab3a9b 100644 --- a/vortex-array/src/array/sparse/compute/mod.rs +++ b/vortex-array/src/array/sparse/compute/mod.rs @@ -15,19 +15,21 @@ use crate::{ArrayData, IntoArrayData, IntoArrayVariant}; mod slice; mod take; -impl ArrayCompute for SparseArray { - fn search_sorted(&self) -> Option<&dyn SearchSortedFn> { - Some(self) - } -} +impl ArrayCompute for SparseArray {} impl ComputeVTable for SparseEncoding { fn filter_fn(&self) -> Option<&dyn FilterFn> { Some(self) } + fn scalar_at_fn(&self) -> Option<&dyn ScalarAtFn> { Some(self) } + + fn search_sorted_fn(&self) -> Option<&dyn SearchSortedFn> { + Some(self) + } + fn slice_fn(&self) -> Option<&dyn SliceFn> { Some(self) } @@ -46,21 +48,26 @@ impl ScalarAtFn for SparseEncoding { } } -impl SearchSortedFn for SparseArray { - fn search_sorted(&self, value: &Scalar, side: SearchSortedSide) -> VortexResult { - search_sorted(&self.values(), value.clone(), side).and_then(|sr| { - let sidx = sr.to_offsets_index(self.metadata().indices_len); - let index: usize = scalar_at(self.indices(), sidx)?.as_ref().try_into()?; +impl SearchSortedFn for SparseEncoding { + fn search_sorted( + &self, + array: &SparseArray, + value: &Scalar, + side: SearchSortedSide, + ) -> VortexResult { + search_sorted(&array.values(), value.clone(), side).and_then(|sr| { + let sidx = sr.to_offsets_index(array.metadata().indices_len); + let index: usize = scalar_at(array.indices(), sidx)?.as_ref().try_into()?; Ok(match sr { SearchResult::Found(i) => SearchResult::Found( - if i == self.metadata().indices_len { + if i == array.metadata().indices_len { index + 1 } else { index - } - self.indices_offset(), + } - array.indices_offset(), ), SearchResult::NotFound(i) => SearchResult::NotFound( - if i == 0 { index } else { index + 1 } - self.indices_offset(), + if i == 0 { index } else { index + 1 } - array.indices_offset(), ), }) }) diff --git a/vortex-array/src/compute/mod.rs b/vortex-array/src/compute/mod.rs index 6013efd055..18259bb367 100644 --- a/vortex-array/src/compute/mod.rs +++ b/vortex-array/src/compute/mod.rs @@ -69,6 +69,13 @@ pub trait ComputeVTable { None } + /// Perform a search over an ordered array. + /// + /// See: [SearchSortedFn]. + fn search_sorted_fn(&self) -> Option<&dyn SearchSortedFn> { + None + } + /// Perform zero-copy slicing of an array. /// /// See: [SliceFn]. @@ -100,11 +107,4 @@ pub trait ArrayCompute { fn subtract_scalar(&self) -> Option<&dyn SubtractScalarFn> { None } - - /// Perform a search over an ordered array. - /// - /// See: [SearchSortedFn]. - fn search_sorted(&self) -> Option<&dyn SearchSortedFn> { - None - } } diff --git a/vortex-array/src/compute/search_sorted.rs b/vortex-array/src/compute/search_sorted.rs index 36ff3c1d2d..31dfe877be 100644 --- a/vortex-array/src/compute/search_sorted.rs +++ b/vortex-array/src/compute/search_sorted.rs @@ -4,10 +4,11 @@ use std::fmt::{Debug, Display, Formatter}; use std::hint; use itertools::Itertools; -use vortex_error::{vortex_bail, VortexResult}; +use vortex_error::{vortex_bail, vortex_err, VortexError, VortexResult}; use vortex_scalar::Scalar; use crate::compute::unary::scalar_at; +use crate::encoding::Encoding; use crate::{ArrayDType, ArrayData}; #[derive(Debug, Copy, Clone)] @@ -97,33 +98,41 @@ impl Display for SearchResult { /// Searches for value assuming the array is sorted. /// /// For nullable arrays we assume that the nulls are sorted last, i.e. they're the greatest value -pub trait SearchSortedFn { - fn search_sorted(&self, value: &Scalar, side: SearchSortedSide) -> VortexResult; +pub trait SearchSortedFn { + fn search_sorted( + &self, + array: &Array, + value: &Scalar, + side: SearchSortedSide, + ) -> VortexResult; fn search_sorted_usize( &self, + array: &Array, value: usize, side: SearchSortedSide, ) -> VortexResult { let usize_scalar = Scalar::from(value); - self.search_sorted(&usize_scalar, side) + self.search_sorted(array, &usize_scalar, side) } /// Bulk search for many values. fn search_sorted_many( &self, + array: &Array, values: &[Scalar], sides: &[SearchSortedSide], ) -> VortexResult> { values .iter() .zip(sides.iter()) - .map(|(value, side)| self.search_sorted(value, *side)) + .map(|(value, side)| self.search_sorted(array, value, *side)) .try_collect() } fn search_sorted_usize_many( &self, + array: &Array, values: &[usize], sides: &[SearchSortedSide], ) -> VortexResult> { @@ -131,11 +140,77 @@ pub trait SearchSortedFn { .iter() .copied() .zip(sides.iter().copied()) - .map(|(value, side)| self.search_sorted_usize(value, side)) + .map(|(value, side)| self.search_sorted_usize(array, value, side)) .try_collect() } } +impl SearchSortedFn for E +where + E: SearchSortedFn, + for<'a> &'a E::Array: TryFrom<&'a ArrayData, Error = VortexError>, +{ + fn search_sorted( + &self, + array: &ArrayData, + value: &Scalar, + side: SearchSortedSide, + ) -> VortexResult { + let array_ref = <&E::Array>::try_from(array)?; + let encoding = array + .encoding() + .as_any() + .downcast_ref::() + .ok_or_else(|| vortex_err!("Mismatched encoding"))?; + SearchSortedFn::search_sorted(encoding, array_ref, value, side) + } + + fn search_sorted_usize( + &self, + array: &ArrayData, + value: usize, + side: SearchSortedSide, + ) -> VortexResult { + let array_ref = <&E::Array>::try_from(array)?; + let encoding = array + .encoding() + .as_any() + .downcast_ref::() + .ok_or_else(|| vortex_err!("Mismatched encoding"))?; + SearchSortedFn::search_sorted_usize(encoding, array_ref, value, side) + } + + fn search_sorted_many( + &self, + array: &ArrayData, + values: &[Scalar], + sides: &[SearchSortedSide], + ) -> VortexResult> { + let array_ref = <&E::Array>::try_from(array)?; + let encoding = array + .encoding() + .as_any() + .downcast_ref::() + .ok_or_else(|| vortex_err!("Mismatched encoding"))?; + SearchSortedFn::search_sorted_many(encoding, array_ref, values, sides) + } + + fn search_sorted_usize_many( + &self, + array: &ArrayData, + values: &[usize], + sides: &[SearchSortedSide], + ) -> VortexResult> { + let array_ref = <&E::Array>::try_from(array)?; + let encoding = array + .encoding() + .as_any() + .downcast_ref::() + .ok_or_else(|| vortex_err!("Mismatched encoding"))?; + SearchSortedFn::search_sorted_usize_many(encoding, array_ref, values, sides) + } +} + pub fn search_sorted>( array: &ArrayData, target: T, @@ -146,20 +221,19 @@ pub fn search_sorted>( vortex_bail!("Search sorted with null value is not supported"); } - array.with_dyn(|a| { - if let Some(search_sorted) = a.search_sorted() { - return search_sorted.search_sorted(&scalar, side); - } + if let Some(f) = array.encoding().search_sorted_fn() { + return f.search_sorted(array, &scalar, side); + } - if array.encoding().scalar_at_fn().is_some() { - return Ok(array.search_sorted(&scalar, side)); - } + // Fallback to a generic search_sorted using scalar_at + if array.encoding().scalar_at_fn().is_some() { + return Ok(SearchSorted::search_sorted(array, &scalar, side)); + } - vortex_bail!( - NotImplemented: "search_sorted", - array.encoding().id() - ) - }) + vortex_bail!( + NotImplemented: "search_sorted", + array.encoding().id() + ) } pub fn search_sorted_usize( @@ -167,19 +241,20 @@ pub fn search_sorted_usize( target: usize, side: SearchSortedSide, ) -> VortexResult { - array.with_dyn(|a| { - if let Some(search_sorted) = a.search_sorted() { - search_sorted.search_sorted_usize(target, side) - } else if a.encoding().scalar_at_fn().is_some() { - let scalar = Scalar::primitive(target as u64, array.dtype().nullability()); - Ok(array.search_sorted(&scalar, side)) - } else { - vortex_bail!( - NotImplemented: "search_sorted_usize", - array.encoding().id() - ) - } - }) + if let Some(f) = array.encoding().search_sorted_fn() { + return f.search_sorted_usize(array, target, side); + } + + // Fallback to a generic search_sorted using scalar_at + if array.encoding().scalar_at_fn().is_some() { + let scalar = Scalar::primitive(target as u64, array.dtype().nullability()); + return Ok(SearchSorted::search_sorted(array, &scalar, side)); + } + + vortex_bail!( + NotImplemented: "search_sorted_usize", + array.encoding().id() + ) } /// Search for many elements in the array. @@ -188,23 +263,21 @@ pub fn search_sorted_many + Clone>( targets: &[T], sides: &[SearchSortedSide], ) -> VortexResult> { - array.with_dyn(|a| { - if let Some(search_sorted) = a.search_sorted() { - let values: Vec = targets - .iter() - .map(|t| t.clone().into().cast(array.dtype())) - .try_collect()?; - - search_sorted.search_sorted_many(&values, sides) - } else { - // Call in loop and collect - targets - .iter() - .zip(sides.iter().copied()) - .map(|(target, side)| search_sorted(array, target.clone(), side)) - .try_collect() - } - }) + if let Some(f) = array.encoding().search_sorted_fn() { + let values: Vec = targets + .iter() + .map(|t| t.clone().into().cast(array.dtype())) + .try_collect()?; + + return f.search_sorted_many(array, &values, sides); + } + + // Call in loop and collect + targets + .iter() + .zip(sides.iter().copied()) + .map(|(target, side)| search_sorted(array, target.clone(), side)) + .try_collect() } // Native functions for each of the values, cast up to u64 or down to something lower. @@ -213,19 +286,17 @@ pub fn search_sorted_usize_many( targets: &[usize], sides: &[SearchSortedSide], ) -> VortexResult> { - array.with_dyn(|a| { - if let Some(search_sorted) = a.search_sorted() { - search_sorted.search_sorted_usize_many(targets, sides) - } else { - // Call in loop and collect - targets - .iter() - .copied() - .zip(sides.iter().copied()) - .map(|(target, side)| search_sorted_usize(array, target, side)) - .try_collect() - } - }) + if let Some(f) = array.encoding().search_sorted_fn() { + return f.search_sorted_usize_many(array, targets, sides); + } + + // Call in loop and collect + targets + .iter() + .copied() + .zip(sides.iter().copied()) + .map(|(target, side)| search_sorted_usize(array, target, side)) + .try_collect() } pub trait IndexOrd {