Skip to content

Commit

Permalink
SearchSorted VTable (#1414)
Browse files Browse the repository at this point in the history
  • Loading branch information
gatesn authored Nov 20, 2024
1 parent dddf5f2 commit c09295c
Show file tree
Hide file tree
Showing 11 changed files with 300 additions and 190 deletions.
2 changes: 1 addition & 1 deletion bench-vortex/src/bin/notimplemented.rs
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ fn compute_funcs(encodings: &[ArrayData]) {
impls.push(bool_to_cell(
arr.with_dyn(|a| a.subtract_scalar().is_some()),
));
impls.push(bool_to_cell(arr.with_dyn(|a| a.search_sorted().is_some())));
impls.push(bool_to_cell(arr.encoding().search_sorted_fn().is_some()));
impls.push(bool_to_cell(arr.encoding().slice_fn().is_some()));
impls.push(bool_to_cell(arr.encoding().take_fn().is_some()));
table.add_row(Row::new(impls));
Expand Down
10 changes: 5 additions & 5 deletions encodings/fastlanes/src/bitpacking/compute/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,7 @@ mod search_sorted;
mod slice;
mod take;

impl ArrayCompute for BitPackedArray {
fn search_sorted(&self) -> Option<&dyn SearchSortedFn> {
Some(self)
}
}
impl ArrayCompute for BitPackedArray {}

impl ComputeVTable for BitPackedEncoding {
fn filter_fn(&self) -> Option<&dyn FilterFn<ArrayData>> {
Expand All @@ -27,6 +23,10 @@ impl ComputeVTable for BitPackedEncoding {
Some(self)
}

fn search_sorted_fn(&self) -> Option<&dyn SearchSortedFn<ArrayData>> {
Some(self)
}

fn slice_fn(&self) -> Option<&dyn SliceFn<ArrayData>> {
Some(self)
}
Expand Down
48 changes: 28 additions & 20 deletions encodings/fastlanes/src/bitpacking/compute/search_sorted.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,46 +17,53 @@ use vortex_dtype::{match_each_unsigned_integer_ptype, NativePType};
use vortex_error::{VortexError, VortexExpect as _, VortexResult};
use vortex_scalar::Scalar;

use crate::{unpack_single_primitive, BitPackedArray};
use crate::{unpack_single_primitive, BitPackedArray, BitPackedEncoding};

impl SearchSortedFn for BitPackedArray {
fn search_sorted(&self, value: &Scalar, side: SearchSortedSide) -> VortexResult<SearchResult> {
match_each_unsigned_integer_ptype!(self.ptype(), |$P| {
search_sorted_typed::<$P>(self, value, side)
impl SearchSortedFn<BitPackedArray> for BitPackedEncoding {
fn search_sorted(
&self,
array: &BitPackedArray,
value: &Scalar,
side: SearchSortedSide,
) -> VortexResult<SearchResult> {
match_each_unsigned_integer_ptype!(array.ptype(), |$P| {
search_sorted_typed::<$P>(array, value, side)
})
}

fn search_sorted_usize(
&self,
array: &BitPackedArray,
value: usize,
side: SearchSortedSide,
) -> VortexResult<SearchResult> {
match_each_unsigned_integer_ptype!(self.ptype(), |$P| {
match_each_unsigned_integer_ptype!(array.ptype(), |$P| {
// NOTE: conversion may truncate silently.
if let Some(pvalue) = num_traits::cast::<usize, $P>(value) {
search_sorted_native(self, pvalue, side)
search_sorted_native(array, pvalue, side)
} else {
// provided u64 is too large to fit in the provided PType, value must be off
// the right end of the array.
Ok(SearchResult::NotFound(self.len()))
Ok(SearchResult::NotFound(array.len()))
}
})
}

fn search_sorted_many(
&self,
array: &BitPackedArray,
values: &[Scalar],
sides: &[SearchSortedSide],
) -> VortexResult<Vec<SearchResult>> {
match_each_unsigned_integer_ptype!(self.ptype(), |$P| {
let searcher = BitPackedSearch::<'_, $P>::new(self);
match_each_unsigned_integer_ptype!(array.ptype(), |$P| {
let searcher = BitPackedSearch::<'_, $P>::new(array);

values
.iter()
.zip(sides.iter().copied())
.map(|(value, side)| {
// Unwrap to native value
let unwrapped_value: $P = value.cast(self.dtype())?.try_into()?;
let unwrapped_value: $P = value.cast(array.dtype())?.try_into()?;

Ok(searcher.search_sorted(&unwrapped_value, side))
})
Expand All @@ -66,11 +73,12 @@ impl SearchSortedFn for BitPackedArray {

fn search_sorted_usize_many(
&self,
array: &BitPackedArray,
values: &[usize],
sides: &[SearchSortedSide],
) -> VortexResult<Vec<SearchResult>> {
match_each_unsigned_integer_ptype!(self.ptype(), |$P| {
let searcher = BitPackedSearch::<'_, $P>::new(self);
match_each_unsigned_integer_ptype!(array.ptype(), |$P| {
let searcher = BitPackedSearch::<'_, $P>::new(array);

values
.iter()
Expand Down Expand Up @@ -200,7 +208,7 @@ impl<T> Len for BitPackedSearch<'_, T> {
mod test {
use vortex_array::array::PrimitiveArray;
use vortex_array::compute::{
search_sorted, search_sorted_many, slice, SearchResult, SearchSortedFn, SearchSortedSide,
search_sorted, search_sorted_many, slice, SearchResult, SearchSortedSide,
};
use vortex_array::IntoArrayData;
use vortex_dtype::Nullability;
Expand Down Expand Up @@ -261,12 +269,12 @@ mod test {
)
.unwrap();

let found = bitpacked
.search_sorted(
&Scalar::primitive(1u64, Nullability::Nullable),
SearchSortedSide::Left,
)
.unwrap();
let found = search_sorted(
bitpacked.as_ref(),
Scalar::primitive(1u64, Nullability::Nullable),
SearchSortedSide::Left,
)
.unwrap();
assert_eq!(found, SearchResult::Found(0));
}

Expand Down
23 changes: 14 additions & 9 deletions encodings/fastlanes/src/for/compute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,7 @@ use vortex_scalar::{PValue, Scalar};

use crate::{FoRArray, FoREncoding};

impl ArrayCompute for FoRArray {
fn search_sorted(&self) -> Option<&dyn SearchSortedFn> {
Some(self)
}
}
impl ArrayCompute for FoRArray {}

impl ComputeVTable for FoREncoding {
fn filter_fn(&self) -> Option<&dyn FilterFn<ArrayData>> {
Expand All @@ -29,6 +25,10 @@ impl ComputeVTable for FoREncoding {
Some(self)
}

fn search_sorted_fn(&self) -> Option<&dyn SearchSortedFn<ArrayData>> {
Some(self)
}

fn slice_fn(&self) -> Option<&dyn SliceFn<ArrayData>> {
Some(self)
}
Expand Down Expand Up @@ -99,10 +99,15 @@ impl SliceFn<FoRArray> for FoREncoding {
}
}

impl SearchSortedFn for FoRArray {
fn search_sorted(&self, value: &Scalar, side: SearchSortedSide) -> VortexResult<SearchResult> {
match_each_integer_ptype!(self.ptype(), |$P| {
search_sorted_typed::<$P>(self, value, side)
impl SearchSortedFn<FoRArray> for FoREncoding {
fn search_sorted(
&self,
array: &FoRArray,
value: &Scalar,
side: SearchSortedSide,
) -> VortexResult<SearchResult> {
match_each_integer_ptype!(array.ptype(), |$P| {
search_sorted_typed::<$P>(array, value, side)
})
}
}
Expand Down
63 changes: 6 additions & 57 deletions vortex-array/src/array/constant/compute/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
mod boolean;

use std::cmp::Ordering;
mod search_sorted;

use vortex_error::VortexResult;
use vortex_scalar::Scalar;
Expand All @@ -10,18 +9,14 @@ use crate::array::ConstantEncoding;
use crate::compute::unary::ScalarAtFn;
use crate::compute::{
scalar_cmp, ArrayCompute, BinaryBooleanFn, ComputeVTable, FilterFn, FilterMask, MaybeCompareFn,
Operator, SearchResult, SearchSortedFn, SearchSortedSide, SliceFn, TakeFn, TakeOptions,
Operator, SearchSortedFn, SliceFn, TakeFn, TakeOptions,
};
use crate::{ArrayData, ArrayLen, IntoArrayData};

impl ArrayCompute for ConstantArray {
fn compare(&self, other: &ArrayData, operator: Operator) -> Option<VortexResult<ArrayData>> {
MaybeCompareFn::maybe_compare(self, other, operator)
}

fn search_sorted(&self) -> Option<&dyn SearchSortedFn> {
Some(self)
}
}

impl ComputeVTable for ConstantEncoding {
Expand All @@ -43,6 +38,10 @@ impl ComputeVTable for ConstantEncoding {
Some(self)
}

fn search_sorted_fn(&self) -> Option<&dyn SearchSortedFn<ArrayData>> {
Some(self)
}

fn slice_fn(&self) -> Option<&dyn SliceFn<ArrayData>> {
Some(self)
}
Expand Down Expand Up @@ -81,23 +80,6 @@ impl FilterFn<ConstantArray> for ConstantEncoding {
}
}

impl SearchSortedFn for ConstantArray {
fn search_sorted(&self, value: &Scalar, side: SearchSortedSide) -> VortexResult<SearchResult> {
match self
.scalar_value()
.partial_cmp(value.value())
.unwrap_or(Ordering::Less)
{
Ordering::Greater => Ok(SearchResult::NotFound(0)),
Ordering::Less => Ok(SearchResult::NotFound(self.len())),
Ordering::Equal => match side {
SearchSortedSide::Left => Ok(SearchResult::Found(0)),
SearchSortedSide::Right => Ok(SearchResult::Found(self.len())),
},
}
}
}

impl MaybeCompareFn for ConstantArray {
fn maybe_compare(
&self,
Expand All @@ -111,36 +93,3 @@ impl MaybeCompareFn for ConstantArray {
})
}
}

#[cfg(test)]
mod test {
use crate::array::constant::ConstantArray;
use crate::compute::{search_sorted, SearchResult, SearchSortedSide};
use crate::IntoArrayData;

#[test]
pub fn search() {
let cst = ConstantArray::new(42, 5000).into_array();
assert_eq!(
search_sorted(&cst, 33, SearchSortedSide::Left).unwrap(),
SearchResult::NotFound(0)
);
assert_eq!(
search_sorted(&cst, 55, SearchSortedSide::Left).unwrap(),
SearchResult::NotFound(5000)
);
}

#[test]
pub fn search_equals() {
let cst = ConstantArray::new(42, 5000).into_array();
assert_eq!(
search_sorted(&cst, 42, SearchSortedSide::Left).unwrap(),
SearchResult::Found(0)
);
assert_eq!(
search_sorted(&cst, 42, SearchSortedSide::Right).unwrap(),
SearchResult::Found(5000)
);
}
}
63 changes: 63 additions & 0 deletions vortex-array/src/array/constant/compute/search_sorted.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
use std::cmp::Ordering;

use vortex_error::VortexResult;
use vortex_scalar::Scalar;

use crate::array::{ConstantArray, ConstantEncoding};
use crate::compute::{SearchResult, SearchSortedFn, SearchSortedSide};
use crate::ArrayLen;

impl SearchSortedFn<ConstantArray> for ConstantEncoding {
fn search_sorted(
&self,
array: &ConstantArray,
value: &Scalar,
side: SearchSortedSide,
) -> VortexResult<SearchResult> {
match array
.scalar_value()
.partial_cmp(value.value())
.unwrap_or(Ordering::Less)
{
Ordering::Greater => Ok(SearchResult::NotFound(0)),
Ordering::Less => Ok(SearchResult::NotFound(array.len())),
Ordering::Equal => match side {
SearchSortedSide::Left => Ok(SearchResult::Found(0)),
SearchSortedSide::Right => Ok(SearchResult::Found(array.len())),
},
}
}
}

#[cfg(test)]
mod test {
use crate::array::constant::ConstantArray;
use crate::compute::{search_sorted, SearchResult, SearchSortedSide};
use crate::IntoArrayData;

#[test]
pub fn search() {
let cst = ConstantArray::new(42, 5000).into_array();
assert_eq!(
search_sorted(&cst, 33, SearchSortedSide::Left).unwrap(),
SearchResult::NotFound(0)
);
assert_eq!(
search_sorted(&cst, 55, SearchSortedSide::Left).unwrap(),
SearchResult::NotFound(5000)
);
}

#[test]
pub fn search_equals() {
let cst = ConstantArray::new(42, 5000).into_array();
assert_eq!(
search_sorted(&cst, 42, SearchSortedSide::Left).unwrap(),
SearchResult::Found(0)
);
assert_eq!(
search_sorted(&cst, 42, SearchSortedSide::Right).unwrap(),
SearchResult::Found(5000)
);
}
}
8 changes: 4 additions & 4 deletions vortex-array/src/array/primitive/compute/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,6 @@ impl ArrayCompute for PrimitiveArray {
fn subtract_scalar(&self) -> Option<&dyn SubtractScalarFn> {
Some(self)
}

fn search_sorted(&self) -> Option<&dyn SearchSortedFn> {
Some(self)
}
}

impl ComputeVTable for PrimitiveEncoding {
Expand All @@ -50,6 +46,10 @@ impl ComputeVTable for PrimitiveEncoding {
Some(self)
}

fn search_sorted_fn(&self) -> Option<&dyn SearchSortedFn<ArrayData>> {
Some(self)
}

fn slice_fn(&self) -> Option<&dyn SliceFn<ArrayData>> {
Some(self)
}
Expand Down
Loading

0 comments on commit c09295c

Please sign in to comment.