Skip to content

Commit

Permalink
Implement generic search sorted using scalar_at
Browse files Browse the repository at this point in the history
  • Loading branch information
robert3005 committed Mar 28, 2024
1 parent f3ce3ac commit a0b94ff
Show file tree
Hide file tree
Showing 2 changed files with 122 additions and 24 deletions.
3 changes: 2 additions & 1 deletion vortex-array/src/array/primitive/compute/search_sorted.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,11 @@ impl SearchSortedFn for PrimitiveArray {

#[cfg(test)]
mod test {
use super::*;
use crate::array::IntoArray;
use crate::compute::search_sorted::search_sorted;

use super::*;

#[test]
fn test_searchsorted_primitive() {
let values = vec![1u16, 2, 3].into_array();
Expand Down
143 changes: 120 additions & 23 deletions vortex-array/src/compute/search_sorted.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
use std::cmp::Ordering;
use std::cmp::Ordering::{Equal, Greater, Less};

use vortex_error::{VortexError, VortexResult};

use crate::array::Array;
use crate::compute::scalar_at::scalar_at;
use crate::scalar::Scalar;
use std::cmp::Ordering;

#[derive(Debug, Copy, Clone)]
pub enum SearchSortedSide {
Left,
Right,
Expand All @@ -22,6 +26,12 @@ pub fn search_sorted<T: Into<Scalar>>(
array
.search_sorted()
.map(|f| f.search_sorted(&scalar, side))
.or_else(|| {
array
.scalar_at()
.map(|_| SearchSorted::search_sorted(&array, &scalar, side))
.map(Ok)
})
.unwrap_or_else(|| {
Err(VortexError::NotImplemented(
"search_sorted",
Expand All @@ -30,31 +40,118 @@ pub fn search_sorted<T: Into<Scalar>>(
})
}

pub trait SearchSorted<T> {
fn search_sorted(&self, value: &T, side: SearchSortedSide) -> usize;
pub trait IndexOrd<Idx: ?Sized, V> {
fn index_cmp(&self, idx: &Idx, elem: &V) -> Option<Ordering>;

fn index_lt(&self, idx: &&Idx, elem: &V) -> bool {
matches!(self.index_cmp(*idx, elem), Some(Less))
}

fn index_le(&self, idx: &&Idx, elem: &V) -> bool {
matches!(self.index_cmp(*idx, elem), Some(Less | Equal))
}

fn index_gt(&self, idx: &&Idx, elem: &V) -> bool {
matches!(self.index_cmp(*idx, elem), Some(Greater))
}

fn index_ge(&self, idx: &&Idx, elem: &V) -> bool {
matches!(self.index_cmp(*idx, elem), Some(Greater | Equal))
}
}

#[allow(clippy::len_without_is_empty)]
pub trait Len {
fn len(&self) -> usize;
}

impl<T: PartialOrd> SearchSorted<T> for &[T] {
fn search_sorted(&self, value: &T, side: SearchSortedSide) -> usize {
pub trait SearchSorted<T, I> {
fn search_sorted(&self, value: &T, side: SearchSortedSide) -> usize
where
Self: IndexOrd<I, T>,
{
match side {
SearchSortedSide::Left => self
.binary_search_by(|x| {
if x < value {
Ordering::Less
} else {
Ordering::Greater
}
})
.unwrap_or_else(|x| x),
SearchSortedSide::Right => self
.binary_search_by(|x| {
if x <= value {
Ordering::Less
} else {
Ordering::Greater
}
})
.unwrap_or_else(|x| x),
SearchSortedSide::Left => self.search_sorted_by(|idx| {
if self.index_lt(&idx, value) {
Less
} else {
Greater
}
}),
SearchSortedSide::Right => self.search_sorted_by(|idx| {
if self.index_le(&idx, value) {
Less
} else {
Greater
}
}),
}
}

fn search_sorted_by<F: FnMut(&I) -> Ordering>(&self, f: F) -> usize;
}

impl<S: IndexOrd<usize, T> + Len + ?Sized, T> SearchSorted<T, usize> for S {
// Code adapted from Rust standard library slice::binary_search_by
fn search_sorted_by<F: FnMut(&usize) -> Ordering>(&self, mut f: F) -> usize {
// INVARIANTS:
// - 0 <= left <= left + size = right <= self.len()
// - f returns Less for everything in self[..left]
// - f returns Greater for everything in self[right..]
let mut size = self.len();
let mut left = 0;
let mut right = size;
while left < right {
let mid = left + size / 2;
let cmp = f(&mid);

left = if cmp == Less { mid + 1 } else { left };
right = if cmp == Greater { mid } else { right };
if cmp == Equal {
return mid;
}

size = right - left;
}

left
}
}

impl<T: Array> IndexOrd<usize, Scalar> for T {
fn index_cmp(&self, idx: &usize, elem: &Scalar) -> Option<Ordering> {
let scalar_a = scalar_at(self, *idx).ok()?;
scalar_a.partial_cmp(elem)
}
}

impl IndexOrd<usize, Scalar> for &dyn Array {
fn index_cmp(&self, idx: &usize, elem: &Scalar) -> Option<Ordering> {
let scalar_a = scalar_at(*self, *idx).ok()?;
scalar_a.partial_cmp(elem)
}
}

impl<T: PartialOrd> IndexOrd<usize, T> for [T] {
fn index_cmp(&self, idx: &usize, elem: &T) -> Option<Ordering> {
self[*idx].partial_cmp(elem)
}
}

impl<T: Array> Len for T {
fn len(&self) -> usize {
T::len(self)
}
}

impl Len for &dyn Array {
fn len(&self) -> usize {
Array::len(*self)
}
}

impl<T> Len for [T] {
fn len(&self) -> usize {
self.len()
}
}

0 comments on commit a0b94ff

Please sign in to comment.