diff --git a/vortex-array/src/array/primitive/compute/search_sorted.rs b/vortex-array/src/array/primitive/compute/search_sorted.rs index 0355c1c795..c3d3f27119 100644 --- a/vortex-array/src/array/primitive/compute/search_sorted.rs +++ b/vortex-array/src/array/primitive/compute/search_sorted.rs @@ -17,10 +17,11 @@ impl SearchSortedFn for PrimitiveArray { #[cfg(test)] mod test { - use super::*; use crate::array::IntoArray; use crate::compute::search_sorted::search_sorted; + use super::*; + #[test] fn test_searchsorted_primitive() { let values = vec![1u16, 2, 3].into_array(); diff --git a/vortex-array/src/compute/search_sorted.rs b/vortex-array/src/compute/search_sorted.rs index f4ee59c21f..3978dd182d 100644 --- a/vortex-array/src/compute/search_sorted.rs +++ b/vortex-array/src/compute/search_sorted.rs @@ -1,9 +1,13 @@ +use std::cmp::Ordering; +use std::cmp::Ordering::{Equal, Greater, Less}; + use vortex_error::{VortexError, VortexResult}; use crate::array::Array; +use crate::compute::scalar_at::scalar_at; use crate::scalar::Scalar; -use std::cmp::Ordering; +#[derive(Debug, Copy, Clone)] pub enum SearchSortedSide { Left, Right, @@ -22,6 +26,12 @@ pub fn search_sorted>( array .search_sorted() .map(|f| f.search_sorted(&scalar, side)) + .or_else(|| { + array + .scalar_at() + .map(|_| SearchSorted::search_sorted(&array, &scalar, side)) + .map(Ok) + }) .unwrap_or_else(|| { Err(VortexError::NotImplemented( "search_sorted", @@ -30,31 +40,118 @@ pub fn search_sorted>( }) } -pub trait SearchSorted { - fn search_sorted(&self, value: &T, side: SearchSortedSide) -> usize; +pub trait IndexOrd { + fn index_cmp(&self, idx: &Idx, elem: &V) -> Option; + + fn index_lt(&self, idx: &&Idx, elem: &V) -> bool { + matches!(self.index_cmp(*idx, elem), Some(Less)) + } + + fn index_le(&self, idx: &&Idx, elem: &V) -> bool { + matches!(self.index_cmp(*idx, elem), Some(Less | Equal)) + } + + fn index_gt(&self, idx: &&Idx, elem: &V) -> bool { + matches!(self.index_cmp(*idx, elem), Some(Greater)) + } + + fn index_ge(&self, idx: &&Idx, elem: &V) -> bool { + matches!(self.index_cmp(*idx, elem), Some(Greater | Equal)) + } +} + +#[allow(clippy::len_without_is_empty)] +pub trait Len { + fn len(&self) -> usize; } -impl SearchSorted for &[T] { - fn search_sorted(&self, value: &T, side: SearchSortedSide) -> usize { +pub trait SearchSorted { + fn search_sorted(&self, value: &T, side: SearchSortedSide) -> usize + where + Self: IndexOrd, + { match side { - SearchSortedSide::Left => self - .binary_search_by(|x| { - if x < value { - Ordering::Less - } else { - Ordering::Greater - } - }) - .unwrap_or_else(|x| x), - SearchSortedSide::Right => self - .binary_search_by(|x| { - if x <= value { - Ordering::Less - } else { - Ordering::Greater - } - }) - .unwrap_or_else(|x| x), + SearchSortedSide::Left => self.search_sorted_by(|idx| { + if self.index_lt(&idx, value) { + Less + } else { + Greater + } + }), + SearchSortedSide::Right => self.search_sorted_by(|idx| { + if self.index_le(&idx, value) { + Less + } else { + Greater + } + }), + } + } + + fn search_sorted_by Ordering>(&self, f: F) -> usize; +} + +impl + Len + ?Sized, T> SearchSorted for S { + // Code adapted from Rust standard library slice::binary_search_by + fn search_sorted_by Ordering>(&self, mut f: F) -> usize { + // INVARIANTS: + // - 0 <= left <= left + size = right <= self.len() + // - f returns Less for everything in self[..left] + // - f returns Greater for everything in self[right..] + let mut size = self.len(); + let mut left = 0; + let mut right = size; + while left < right { + let mid = left + size / 2; + let cmp = f(&mid); + + left = if cmp == Less { mid + 1 } else { left }; + right = if cmp == Greater { mid } else { right }; + if cmp == Equal { + return mid; + } + + size = right - left; } + + left + } +} + +impl IndexOrd for T { + fn index_cmp(&self, idx: &usize, elem: &Scalar) -> Option { + let scalar_a = scalar_at(self, *idx).ok()?; + scalar_a.partial_cmp(elem) + } +} + +impl IndexOrd for &dyn Array { + fn index_cmp(&self, idx: &usize, elem: &Scalar) -> Option { + let scalar_a = scalar_at(*self, *idx).ok()?; + scalar_a.partial_cmp(elem) + } +} + +impl IndexOrd for [T] { + fn index_cmp(&self, idx: &usize, elem: &T) -> Option { + self[*idx].partial_cmp(elem) + } +} + +impl Len for T { + fn len(&self) -> usize { + T::len(self) + } +} + +impl Len for &dyn Array { + fn len(&self) -> usize { + Array::len(*self) + } +} + +impl Len for [T] { + fn len(&self) -> usize { + self.len() } }