Skip to content

Commit

Permalink
comparison fn
Browse files Browse the repository at this point in the history
  • Loading branch information
jdcasale committed May 20, 2024
1 parent 8b6606a commit e94c536
Showing 4 changed files with 183 additions and 0 deletions.
137 changes: 137 additions & 0 deletions vortex-array/src/array/primitive/compute/compare.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
use std::ops::BitAnd;

use arrow_buffer::BooleanBuffer;
use vortex_dtype::{match_each_native_ptype, NativePType};
use vortex_error::VortexResult;
use vortex_expr::operators::Operator;

use crate::array::bool::BoolArray;
use crate::array::primitive::PrimitiveArray;
use crate::compute::compare::CompareArraysFn;
use crate::{Array, ArrayTrait, IntoArray};

impl CompareArraysFn for PrimitiveArray {
fn compare_arrays(&self, other: &Array, predicate: Operator) -> VortexResult<Array> {
let flattened = other.clone().flatten_primitive()?;

let matching_idxs = match_each_native_ptype!(self.ptype(), |$T| {
let predicate_fn = &predicate.to_predicate::<$T>();
apply_predicate(self.typed_data::<$T>(), flattened.typed_data::<$T>(), predicate_fn)
});

let present = self
.validity()
.to_logical(self.len())
.to_present_null_buffer()?
.into_inner();
let present_other = flattened
.validity()
.to_logical(self.len())
.to_present_null_buffer()?
.into_inner();

Ok(BoolArray::from(matching_idxs.bitand(&present).bitand(&present_other)).into_array())
}
}

/// Evaluates `f` on each aligned pair of elements from `lhs` and `rhs` and packs the
/// boolean outcomes into a `BooleanBuffer`.
///
/// Pairs are formed with `zip`, so the output has one bit per pair and stops at the
/// end of the shorter slice.
fn apply_predicate<T: NativePType, F: Fn(&T, &T) -> bool>(
    lhs: &[T],
    rhs: &[T],
    f: F,
) -> BooleanBuffer {
    lhs.iter()
        .zip(rhs)
        .map(|(left, right)| f(left, right))
        .collect()
}

#[cfg(test)]
mod test {
    use itertools::Itertools;

    use super::*;
    use crate::ToArray;

    /// Returns the positions of the set bits in `filtered_primitive`, in ascending order.
    fn to_int_indices(filtered_primitive: BoolArray) -> Vec<u64> {
        filtered_primitive
            .boolean_buffer()
            .iter()
            .enumerate()
            .filter_map(|(position, is_set)| is_set.then_some(position as u64))
            .collect_vec()
    }

    /// Compares `lhs` against `rhs` with `op` and returns the positions that matched.
    fn matching_positions(lhs: &PrimitiveArray, rhs: &PrimitiveArray, op: Operator) -> Vec<u64> {
        let mask = lhs
            .compare_arrays(&rhs.to_array(), op)
            .unwrap()
            .flatten_bool()
            .unwrap();
        to_int_indices(mask)
    }

    #[test]
    fn test_basic_filter() {
        let arr = PrimitiveArray::from_nullable_vec(vec![
            Some(1i32),
            Some(2),
            Some(3),
            Some(4),
            None,
            Some(5),
            Some(6),
            Some(7),
            Some(8),
            None,
            Some(9),
            None,
        ]);

        // An array compared with itself: every non-null slot is equal, none differ.
        assert_eq!(
            matching_positions(&arr, &arr, Operator::EqualTo),
            [0u64, 1, 2, 3, 5, 6, 7, 8, 10]
        );
        assert_eq!(
            matching_positions(&arr, &arr, Operator::NotEqualTo),
            Vec::<u64>::new()
        );

        // Same null layout as `arr`; values from position 5 onwards are one larger.
        let other = PrimitiveArray::from_nullable_vec(vec![
            Some(1i32),
            Some(2),
            Some(3),
            Some(4),
            None,
            Some(6),
            Some(7),
            Some(8),
            Some(9),
            None,
            Some(10),
            None,
        ]);

        assert_eq!(
            matching_positions(&arr, &other, Operator::LessThanOrEqualTo),
            [0u64, 1, 2, 3, 5, 6, 7, 8, 10]
        );
        assert_eq!(
            matching_positions(&arr, &other, Operator::LessThan),
            [5u64, 6, 7, 8, 10]
        );

        // The mirrored comparisons with the operands swapped give the same positions.
        assert_eq!(
            matching_positions(&other, &arr, Operator::GreaterThanOrEqualTo),
            [0u64, 1, 2, 3, 5, 6, 7, 8, 10]
        );
        assert_eq!(
            matching_positions(&other, &arr, Operator::GreaterThan),
            [5u64, 6, 7, 8, 10]
        );
    }
}
10 changes: 10 additions & 0 deletions vortex-array/src/array/primitive/compute/mod.rs
Original file line number Diff line number Diff line change
@@ -2,7 +2,9 @@ use crate::array::primitive::PrimitiveArray;
use crate::compute::as_arrow::AsArrowArray;
use crate::compute::as_contiguous::AsContiguousFn;
use crate::compute::cast::CastFn;
use crate::compute::compare::CompareArraysFn;
use crate::compute::fill::FillForwardFn;
use crate::compute::filter_indices::FilterIndicesFn;
use crate::compute::scalar_at::ScalarAtFn;
use crate::compute::scalar_subtract::SubtractScalarFn;
use crate::compute::search_sorted::SearchSortedFn;
@@ -13,6 +15,7 @@ use crate::compute::ArrayCompute;
mod as_arrow;
mod as_contiguous;
mod cast;
mod compare;
mod fill;
mod filter_indices;
mod scalar_at;
@@ -34,9 +37,16 @@ impl ArrayCompute for PrimitiveArray {
Some(self)
}

/// Returns `self` as the `CompareArraysFn` implementation for this array.
fn compare_arrays(&self) -> Option<&dyn CompareArraysFn> {
Some(self)
}

/// Returns `self` as the `FillForwardFn` implementation for this array.
fn fill_forward(&self) -> Option<&dyn FillForwardFn> {
Some(self)
}
/// Returns `self` as the `FilterIndicesFn` implementation for this array.
fn filter_indices(&self) -> Option<&dyn FilterIndicesFn> {
Some(self)
}

fn scalar_at(&self) -> Option<&dyn ScalarAtFn> {
Some(self)
30 changes: 30 additions & 0 deletions vortex-array/src/compute/compare.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
use vortex_dtype::DType;
use vortex_error::{vortex_err, VortexResult};
use vortex_expr::operators::Operator;

use crate::{Array, ArrayDType};

/// Compute function for comparing an array against another array.
pub trait CompareArraysFn {
/// Compares `self` with `array` using `predicate` and returns the outcome as an `Array`.
///
/// NOTE(review): the primitive implementation returns a boolean array with one entry
/// per position; confirm that this is the contract expected of every implementor.
fn compare_arrays(&self, array: &Array, predicate: Operator) -> VortexResult<Array>;
}

/// Compares `array` against `other` using `predicate`.
///
/// The array's own `CompareArraysFn` is used when it provides one. Otherwise, an array
/// with a primitive `DType` is flattened to a primitive array and compared through that
/// implementation.
///
/// # Errors
///
/// Returns a `NotImplemented` error when the array neither provides a comparison
/// function nor has a primitive `DType`, and propagates any error raised while
/// flattening or comparing.
pub fn compare_arrays(array: &Array, other: &Array, predicate: Operator) -> VortexResult<Array> {
    // Prefer the comparison supplied by the array's own encoding, if there is one.
    let dispatched = array.with_dyn(|compute| {
        compute
            .compare_arrays()
            .map(|compare_fn| compare_fn.compare_arrays(other, predicate))
    });
    if let Some(result) = dispatched {
        return result;
    }

    // No encoding-specific comparison is available: only arrays with a primitive DType
    // can fall back to being flattened and compared as a primitive array.
    if !matches!(array.dtype(), DType::Primitive(..)) {
        return Err(vortex_err!(
            NotImplemented: "compare_arrays",
            array.encoding().id()
        ));
    }

    array
        .clone()
        .flatten_primitive()?
        .compare_arrays(other, predicate)
}
6 changes: 6 additions & 0 deletions vortex-array/src/compute/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use as_arrow::AsArrowArray;
use as_contiguous::AsContiguousFn;
use cast::CastFn;
use compare::CompareArraysFn;
use fill::FillForwardFn;
use patch::PatchFn;
use scalar_at::ScalarAtFn;
@@ -14,6 +15,7 @@ use crate::compute::scalar_subtract::SubtractScalarFn;
pub mod as_arrow;
pub mod as_contiguous;
pub mod cast;
pub mod compare;
pub mod fill;
pub mod filter_indices;
pub mod patch;
@@ -36,6 +38,10 @@ pub trait ArrayCompute {
None
}

/// Returns the array's `CompareArraysFn` implementation, if it has one.
///
/// The default returns `None`.
fn compare_arrays(&self) -> Option<&dyn CompareArraysFn> {
None
}

/// Returns the array's `FillForwardFn` implementation, if it has one.
///
/// The default returns `None`.
fn fill_forward(&self) -> Option<&dyn FillForwardFn> {
None
}

0 comments on commit e94c536

Please sign in to comment.