Skip to content

Commit

Permalink
Search sorted (#91)
Browse files Browse the repository at this point in the history
  • Loading branch information
gatesn authored Mar 9, 2024
1 parent 1da315e commit 284e31a
Show file tree
Hide file tree
Showing 12 changed files with 117 additions and 478 deletions.
290 changes: 13 additions & 277 deletions Cargo.lock

Large diffs are not rendered by default.

3 changes: 0 additions & 3 deletions vortex-array/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,6 @@ log = "0.4.20"
num-traits = "0.2.18"
num_enum = "0.7.2"
once_cell = "1.19.0"
polars-arrow = { version = "0.37.0", features = ["arrow_rs"] }
polars-core = "0.37.0"
polars-ops = { version = "0.37.0", features = ["search_sorted"] }
rand = { version = "0.8.5", features = [] }
rayon = "1.8.1"
roaring = "0.10.3"
Expand Down
6 changes: 6 additions & 0 deletions vortex-array/src/array/primitive/compute/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,15 @@ use crate::compute::cast::CastPrimitiveFn;
use crate::compute::fill::FillForwardFn;
use crate::compute::patch::PatchFn;
use crate::compute::scalar_at::ScalarAtFn;
use crate::compute::search_sorted::SearchSortedFn;
use crate::compute::ArrayCompute;

mod as_contiguous;
mod cast;
mod fill;
mod patch;
mod scalar_at;
mod search_sorted;

impl ArrayCompute for PrimitiveArray {
fn as_contiguous(&self) -> Option<&dyn AsContiguousFn> {
Expand All @@ -32,4 +34,8 @@ impl ArrayCompute for PrimitiveArray {
fn scalar_at(&self) -> Option<&dyn ScalarAtFn> {
Some(self)
}

fn search_sorted(&self) -> Option<&dyn SearchSortedFn> {
Some(self)
}
}
57 changes: 57 additions & 0 deletions vortex-array/src/array/primitive/compute/search_sorted.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
use crate::array::primitive::PrimitiveArray;
use crate::compute::search_sorted::{SearchSortedFn, SearchSortedSide};
use crate::error::VortexResult;
use crate::match_each_native_ptype;
use crate::ptype::NativePType;
use crate::scalar::Scalar;

impl SearchSortedFn for PrimitiveArray {
fn search_sorted(&self, value: &dyn Scalar, side: SearchSortedSide) -> VortexResult<usize> {
match_each_native_ptype!(self.ptype(), |$T| {
let pvalue: $T = value.try_into()?;
Ok(search_sorted(self.typed_data::<$T>(), pvalue, side))
})
}
}

fn search_sorted<T: NativePType>(arr: &[T], target: T, side: SearchSortedSide) -> usize {
match side {
SearchSortedSide::Left => search_sorted_cmp(arr, target, |a, b| a < b),
SearchSortedSide::Right => search_sorted_cmp(arr, target, |a, b| a <= b),
}
}

fn search_sorted_cmp<T: NativePType, Cmp>(arr: &[T], target: T, cmp: Cmp) -> usize
where
Cmp: Fn(T, T) -> bool + 'static,
{
let mut low = 0;
let mut high = arr.len();

while low < high {
let mid = low + (high - low) / 2;

if cmp(arr[mid], target) {
low = mid + 1;
} else {
high = mid;
}
}

low
}

#[cfg(test)]
mod test {
use super::*;

#[test]
fn test_searchsorted_primitive() {
let values = vec![1u16, 2, 3];

assert_eq!(search_sorted(&values, 0, SearchSortedSide::Left), 0);
assert_eq!(search_sorted(&values, 1, SearchSortedSide::Left), 0);
assert_eq!(search_sorted(&values, 1, SearchSortedSide::Right), 1);
assert_eq!(search_sorted(&values, 4, SearchSortedSide::Left), 3);
}
}
32 changes: 15 additions & 17 deletions vortex-array/src/array/sparse/compute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use crate::array::sparse::SparseArray;
use crate::array::{Array, ArrayRef};
use crate::compute::as_contiguous::{as_contiguous, AsContiguousFn};
use crate::compute::scalar_at::{scalar_at, ScalarAtFn};
use crate::compute::search_sorted::{search_sorted_usize, SearchSortedSide};
use crate::compute::search_sorted::{search_sorted, SearchSortedSide};
use crate::compute::ArrayCompute;
use crate::error::VortexResult;
use crate::scalar::{NullableScalar, Scalar, ScalarRef};
Expand Down Expand Up @@ -48,21 +48,19 @@ impl ScalarAtFn for SparseArray {
// First, get the index of the patch index array that is the first index
// greater than or equal to the true index
let true_patch_index = index + self.indices_offset;
search_sorted_usize(self.indices(), true_patch_index, SearchSortedSide::Left).and_then(
|idx| {
// If the value at this index is equal to the true index, then it exists in the patch index array
// and we should return the value at the corresponding index in the patch values array
scalar_at(self.indices(), idx)
.or_else(|_| Ok(NullableScalar::none(self.values().dtype().clone()).boxed()))
.and_then(usize::try_from)
.and_then(|patch_index| {
if patch_index == true_patch_index {
scalar_at(self.values(), idx)
} else {
Ok(NullableScalar::none(self.values().dtype().clone()).boxed())
}
})
},
)
search_sorted(self.indices(), true_patch_index, SearchSortedSide::Left).and_then(|idx| {
// If the value at this index is equal to the true index, then it exists in the patch index array
// and we should return the value at the corresponding index in the patch values array
scalar_at(self.indices(), idx)
.or_else(|_| Ok(NullableScalar::none(self.values().dtype().clone()).boxed()))
.and_then(usize::try_from)
.and_then(|patch_index| {
if patch_index == true_patch_index {
scalar_at(self.values(), idx)
} else {
Ok(NullableScalar::none(self.values().dtype().clone()).boxed())
}
})
})
}
}
6 changes: 3 additions & 3 deletions vortex-array/src/array/sparse/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use crate::array::{
check_slice_bounds, Array, ArrayRef, ArrowIterator, Encoding, EncodingId, EncodingRef,
};
use crate::compress::EncodingCompression;
use crate::compute::search_sorted::{search_sorted_usize, SearchSortedSide};
use crate::compute::search_sorted::{search_sorted, SearchSortedSide};
use crate::dtype::DType;
use crate::error::{VortexError, VortexResult};
use crate::formatter::{ArrayDisplay, ArrayFormatter};
Expand Down Expand Up @@ -161,8 +161,8 @@ impl Array for SparseArray {
check_slice_bounds(self, start, stop)?;

// Find the index of the first patch index that is greater than or equal to the offset of this array
let index_start_index = search_sorted_usize(self.indices(), start, SearchSortedSide::Left)?;
let index_end_index = search_sorted_usize(self.indices(), stop, SearchSortedSide::Left)?;
let index_start_index = search_sorted(self.indices(), start, SearchSortedSide::Left)?;
let index_end_index = search_sorted(self.indices(), stop, SearchSortedSide::Left)?;

Ok(SparseArray {
indices_offset: self.indices_offset + start,
Expand Down
5 changes: 5 additions & 0 deletions vortex-array/src/compute/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use crate::compute::as_contiguous::AsContiguousFn;
use crate::compute::search_sorted::SearchSortedFn;
use cast::{CastBoolFn, CastPrimitiveFn};
use fill::FillForwardFn;
use patch::PatchFn;
Expand Down Expand Up @@ -40,6 +41,10 @@ pub trait ArrayCompute {
None
}

fn search_sorted(&self) -> Option<&dyn SearchSortedFn> {
None
}

fn take(&self) -> Option<&dyn TakeFn> {
None
}
Expand Down
74 changes: 17 additions & 57 deletions vortex-array/src/compute/search_sorted.rs
Original file line number Diff line number Diff line change
@@ -1,69 +1,29 @@
use crate::array::Array;
use crate::error::VortexResult;
use crate::polars::IntoPolarsSeries;
use crate::polars::IntoPolarsValue;
use crate::scalar::ScalarRef;
use polars_core::prelude::*;
use polars_ops::prelude::*;
use crate::error::{VortexError, VortexResult};
use crate::scalar::{Scalar, ScalarRef};

pub enum SearchSortedSide {
Left,
Right,
}

impl From<SearchSortedSide> for polars_ops::prelude::SearchSortedSide {
fn from(side: SearchSortedSide) -> Self {
match side {
SearchSortedSide::Left => polars_ops::prelude::SearchSortedSide::Left,
SearchSortedSide::Right => polars_ops::prelude::SearchSortedSide::Right,
}
}
pub trait SearchSortedFn {
fn search_sorted(&self, value: &dyn Scalar, side: SearchSortedSide) -> VortexResult<usize>;
}

pub fn search_sorted_usize(
indices: &dyn Array,
index: usize,
pub fn search_sorted<T: Into<ScalarRef>>(
array: &dyn Array,
target: T,
side: SearchSortedSide,
) -> VortexResult<usize> {
let enc_scalar: ScalarRef = index.into();
// Convert index into correctly typed Arrow scalar.
let enc_scalar = enc_scalar.cast(indices.dtype())?;

let series: Series = indices.iter_arrow().into_polars();
Ok(search_sorted(
&series,
&Series::from_any_values("needle", &[enc_scalar.into_polars()], true)?,
side.into(),
false,
)?
.get(0)
.unwrap() as usize)
}

#[cfg(test)]
mod test {
use super::*;
use crate::array::ArrayRef;

#[test]
fn test_searchsorted_scalar() {
let haystack: ArrayRef = vec![1, 2, 3].into();

assert_eq!(
search_sorted_usize(haystack.as_ref(), 0, SearchSortedSide::Left).unwrap(),
0
);
assert_eq!(
search_sorted_usize(haystack.as_ref(), 1, SearchSortedSide::Left).unwrap(),
0
);
assert_eq!(
search_sorted_usize(haystack.as_ref(), 1, SearchSortedSide::Right).unwrap(),
1
);
assert_eq!(
search_sorted_usize(haystack.as_ref(), 4, SearchSortedSide::Left).unwrap(),
3
);
}
let scalar = target.into().cast(array.dtype())?;
array
.search_sorted()
.map(|f| f.search_sorted(scalar.as_ref(), side))
.unwrap_or_else(|| {
Err(VortexError::NotImplemented(
"search_sorted",
array.encoding().id(),
))
})
}
18 changes: 0 additions & 18 deletions vortex-array/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,6 @@ pub enum VortexError {
MismatchedTypes(DType, DType),
#[error("unexpected arrow data type: {0:?}")]
InvalidArrowDataType(arrow::datatypes::DataType),
#[error("polars error: {0:?}")]
PolarsError(PolarsError),
#[error("arrow error: {0:?}")]
ArrowError(ArrowError),
#[error("patch values may not be null for base dtype {0}")]
Expand Down Expand Up @@ -102,19 +100,3 @@ impl From<arrow::error::ArrowError> for VortexError {
VortexError::ArrowError(ArrowError(err))
}
}

#[derive(Debug)]
#[allow(dead_code)]
pub struct PolarsError(polars_core::error::PolarsError);

impl PartialEq for PolarsError {
fn eq(&self, _other: &Self) -> bool {
false
}
}

impl From<polars_core::error::PolarsError> for VortexError {
fn from(err: polars_core::error::PolarsError) -> Self {
VortexError::PolarsError(PolarsError(err))
}
}
1 change: 0 additions & 1 deletion vortex-array/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ pub mod dtype;
pub mod encode;
pub mod error;
pub mod formatter;
mod polars;
pub mod ptype;
mod sampling;
pub mod serde;
Expand Down
Loading

0 comments on commit 284e31a

Please sign in to comment.