From 8c71d92a2e4a980a29f014a9a6ffb98c3a77403f Mon Sep 17 00:00:00 2001 From: Josh Casale Date: Mon, 20 May 2024 16:13:35 +0100 Subject: [PATCH] bool array impl --- .../src/array/bool/compute/compare.rs | 114 ++++++++++++++++++ vortex-array/src/array/bool/compute/mod.rs | 6 + 2 files changed, 120 insertions(+) create mode 100644 vortex-array/src/array/bool/compute/compare.rs diff --git a/vortex-array/src/array/bool/compute/compare.rs b/vortex-array/src/array/bool/compute/compare.rs new file mode 100644 index 0000000000..6b250ff9cd --- /dev/null +++ b/vortex-array/src/array/bool/compute/compare.rs @@ -0,0 +1,114 @@ +use std::ops::{BitAnd, BitOr, BitXor, Not}; + +use vortex_error::VortexResult; +use vortex_expr::operators::Operator; + +use crate::array::bool::BoolArray; +use crate::compute::compare::CompareArraysFn; +use crate::{Array, ArrayTrait, IntoArray}; + +impl CompareArraysFn for BoolArray { + fn compare_arrays(&self, other: &Array, op: Operator) -> VortexResult { + let flattened = other.clone().flatten_bool()?; + let lhs = self.boolean_buffer(); + let rhs = flattened.boolean_buffer(); + let result_buf = match op { + Operator::EqualTo => lhs.bitxor(&rhs).not(), + Operator::NotEqualTo => lhs.bitxor(&rhs), + + Operator::GreaterThan => lhs.bitand(&rhs).bitxor(&lhs), + Operator::GreaterThanOrEqualTo => { + let gt = lhs.bitand(&rhs).bitxor(&lhs); + let eq = &lhs.bitxor(&rhs).not(); + gt.bitor(eq) + } + Operator::LessThan => lhs.bitor(&rhs).bitand(&lhs).not(), + Operator::LessThanOrEqualTo => { + let eq = lhs.bitxor(&rhs).not(); + let lt = lhs.bitor(&rhs).bitand(&lhs).not(); + lt.bitor(&eq) + } + }; + let present_buf = self + .validity() + .to_logical(self.len()) + .to_present_null_buffer()? + .into_inner(); + + Ok(BoolArray::from(result_buf.bitand(&present_buf)).into_array()) + } +} + +#[cfg(test)] +mod test { + use itertools::Itertools; + + use super::*; + use crate::validity::Validity; + use crate::ToArray; + + fn to_int_indices(filtered_primitive: BoolArray) -> Vec { + let filtered = filtered_primitive + .boolean_buffer() + .iter() + .enumerate() + .flat_map(|(idx, v)| if v { Some(idx as u64) } else { None }) + .collect_vec(); + filtered + } + + #[test] + fn test_basic_filter() { + let arr = BoolArray::from_vec( + vec![true, true, false, true, false], + Validity::Array(BoolArray::from(vec![false, true, true, true, true]).into_array()), + ); + + let matches = arr + .compare_arrays(&arr.to_array(), Operator::EqualTo) + .unwrap() + .flatten_bool() + .unwrap(); + assert_eq!(to_int_indices(matches), [1u64, 2, 3, 4]); + + let matches = arr + .compare_arrays(&arr.to_array(), Operator::NotEqualTo) + .unwrap() + .flatten_bool() + .unwrap(); + assert_eq!(to_int_indices(matches), []); + + let other = BoolArray::from_vec( + vec![false, false, false, true, true], + Validity::Array(BoolArray::from(vec![false, true, true, true, true]).into_array()), + ); + + let matches = arr + .compare_arrays(&other.to_array(), Operator::LessThanOrEqualTo) + .unwrap() + .flatten_bool() + .unwrap(); + assert_eq!(to_int_indices(matches), [2u64, 3, 4]); + + let matches = arr + .compare_arrays(&other.to_array(), Operator::LessThan) + .unwrap() + .flatten_bool() + .unwrap(); + assert_eq!(to_int_indices(matches), [2u64, 4]); + + let matches = other + .compare_arrays(&arr.to_array(), Operator::GreaterThanOrEqualTo) + .unwrap() + .flatten_bool() + .unwrap(); + assert_eq!(to_int_indices(matches), [2u64, 3, 4]); + + let matches = other + .compare_arrays(&arr.to_array(), Operator::GreaterThan) + .unwrap() + .flatten_bool() + .unwrap(); + assert_eq!(to_int_indices(matches), [4u64]); + } +} diff --git a/vortex-array/src/array/bool/compute/mod.rs b/vortex-array/src/array/bool/compute/mod.rs index 35dd3e4a15..91c8d75b71 100644 --- a/vortex-array/src/array/bool/compute/mod.rs +++ b/vortex-array/src/array/bool/compute/mod.rs @@ -1,6 +1,7 @@ use crate::array::bool::BoolArray; use crate::compute::as_arrow::AsArrowArray; use crate::compute::as_contiguous::AsContiguousFn; +use crate::compute::compare::CompareArraysFn; use crate::compute::fill::FillForwardFn; use crate::compute::scalar_at::ScalarAtFn; use crate::compute::slice::SliceFn; @@ -9,6 +10,7 @@ use crate::compute::ArrayCompute; mod as_arrow; mod as_contiguous; +mod compare; mod fill; mod flatten; mod scalar_at; @@ -24,6 +26,10 @@ impl ArrayCompute for BoolArray { Some(self) } + fn compare_arrays(&self) -> Option<&dyn CompareArraysFn> { + Some(self) + } + fn fill_forward(&self) -> Option<&dyn FillForwardFn> { Some(self) }