diff --git a/Cargo.lock b/Cargo.lock index b7185c6b7d..d67e1be9c1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5045,6 +5045,7 @@ dependencies = [ "vortex-buffer", "vortex-dtype", "vortex-error", + "vortex-expr", "vortex-flatbuffers", "vortex-scalar", ] diff --git a/vortex-array/Cargo.toml b/vortex-array/Cargo.toml index 1f2b783afd..838b5e105a 100644 --- a/vortex-array/Cargo.toml +++ b/vortex-array/Cargo.toml @@ -36,6 +36,7 @@ rand = { workspace = true } vortex-buffer = { path = "../vortex-buffer" } vortex-dtype = { path = "../vortex-dtype", features = ["flatbuffers", "serde"] } vortex-error = { path = "../vortex-error", features = ["flexbuffers"] } +vortex-expr = { path = "../vortex-expr" } vortex-flatbuffers = { path = "../vortex-flatbuffers" } vortex-scalar = { path = "../vortex-scalar", features = ["flatbuffers", "serde"] } serde = { workspace = true, features = ["derive"] } @@ -57,3 +58,7 @@ harness = false [[bench]] name = "scalar_subtract" harness = false + +[[bench]] +name = "filter_indices" +harness = false \ No newline at end of file diff --git a/vortex-array/benches/filter_indices.rs b/vortex-array/benches/filter_indices.rs new file mode 100644 index 0000000000..7f5e7127ff --- /dev/null +++ b/vortex-array/benches/filter_indices.rs @@ -0,0 +1,38 @@ +use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use itertools::Itertools; +use rand::distributions::Uniform; +use rand::{thread_rng, Rng}; +use vortex::IntoArray; +use vortex_dtype::field_paths::FieldPath; +use vortex_error::VortexError; +use vortex_expr::expressions::{lit, Conjunction, Disjunction}; +use vortex_expr::field_paths::FieldPathOperations; + +fn filter_indices(c: &mut Criterion) { + let mut group = c.benchmark_group("filter_indices"); + + let mut rng = thread_rng(); + let range = Uniform::new(0i64, 100_000_000); + let arr = (0..10_000_000) + .map(|_| rng.sample(range)) + .collect_vec() + .into_array(); + + let predicate = Disjunction { + conjunctions: vec![Conjunction { + predicates: vec![FieldPath::builder().build().lt(lit(50_000_000i64))], + }], + }; + + group.bench_function("vortex", |b| { + b.iter(|| { + let indices = + vortex::compute::filter_indices::filter_indices(&arr, &predicate).unwrap(); + black_box(indices); + Ok::<(), VortexError>(()) + }); + }); +} + +criterion_group!(benches, filter_indices); +criterion_main!(benches); diff --git a/vortex-array/src/array/primitive/compute/filter_indices.rs b/vortex-array/src/array/primitive/compute/filter_indices.rs new file mode 100644 index 0000000000..a49834d5d3 --- /dev/null +++ b/vortex-array/src/array/primitive/compute/filter_indices.rs @@ -0,0 +1,245 @@ +use std::ops::{BitAnd, BitOr}; + +use arrow_buffer::BooleanBuffer; +use vortex_dtype::{match_each_native_ptype, NativePType}; +use vortex_error::{vortex_bail, VortexResult}; +use vortex_expr::expressions::{Disjunction, Predicate, Value}; + +use crate::array::bool::BoolArray; +use crate::array::primitive::PrimitiveArray; +use crate::compute::filter_indices::FilterIndicesFn; +use crate::{Array, ArrayTrait, IntoArray}; + +impl FilterIndicesFn for PrimitiveArray { + fn filter_indices(&self, predicate: &Disjunction) -> VortexResult { + let conjunction_indices = predicate.conjunctions.iter().map(|conj| { + conj.predicates + .iter() + .map(|pred| indices_matching_predicate(self, pred)) + .reduce(|a, b| Ok(a?.bitand(&b?))) + .unwrap() + }); + let present_buf = self + .validity() + .to_logical(self.len()) + .to_present_null_buffer()? + .into_inner(); + + let bitset: VortexResult = conjunction_indices + .reduce(|a, b| Ok(a?.bitor(&b?))) + .map(|bitset| Ok(bitset?.bitand(&present_buf))) + .unwrap_or_else(|| Ok(BooleanBuffer::new_set(self.len()))); + + Ok(BoolArray::from(bitset?).into_array()) + } +} + +fn indices_matching_predicate( + arr: &PrimitiveArray, + predicate: &Predicate, +) -> VortexResult { + if predicate.left.head().is_some() { + vortex_bail!("Invalid path for primitive array") + } + + let rhs = match &predicate.right { + Value::Field(_) => { + vortex_bail!("Cannot apply field reference to primitive array") + } + Value::Literal(scalar) => scalar, + }; + + let matching_idxs = match_each_native_ptype!(arr.ptype(), |$T| { + let rhs_typed: $T = rhs.try_into().unwrap(); + let predicate_fn = &predicate.op.to_predicate::<$T>(); + apply_predicate(arr.typed_data::<$T>(), &rhs_typed, predicate_fn) + }); + + Ok(matching_idxs) +} + +fn apply_predicate bool>( + lhs: &[T], + rhs: &T, + f: F, +) -> BooleanBuffer { + let matches = lhs.iter().map(|lhs| f(lhs, rhs)); + BooleanBuffer::from_iter(matches) +} + +#[cfg(test)] +mod test { + use itertools::Itertools; + use vortex_dtype::field_paths::FieldPathBuilder; + use vortex_expr::expressions::{lit, Conjunction}; + use vortex_expr::field_paths::FieldPathOperations; + + use super::*; + use crate::validity::Validity; + + fn apply_conjunctive_filter(arr: &PrimitiveArray, conj: Conjunction) -> VortexResult { + arr.filter_indices(&Disjunction { + conjunctions: vec![conj], + }) + } + + fn to_int_indices(filtered_primitive: BoolArray) -> Vec { + let filtered = filtered_primitive + .boolean_buffer() + .iter() + .enumerate() + .flat_map(|(idx, v)| if v { Some(idx as u64) } else { None }) + .collect_vec(); + filtered + } + + #[test] + fn test_basic_filter() { + let arr = PrimitiveArray::from_nullable_vec(vec![ + Some(1u32), + Some(2), + Some(3), + Some(4), + None, + Some(5), + Some(6), + Some(7), + Some(8), + None, + Some(9), + None, + ]); + + let field = FieldPathBuilder::new().build(); + let filtered_primitive = apply_conjunctive_filter( + &arr, + Conjunction { + predicates: vec![field.clone().lt(lit(5u32))], + }, + ) + .unwrap() + .flatten_bool() + .unwrap(); + let filtered = to_int_indices(filtered_primitive); + assert_eq!(filtered, [0u64, 1, 2, 3]); + + let filtered_primitive = apply_conjunctive_filter( + &arr, + Conjunction { + predicates: vec![field.clone().gt(lit(5u32))], + }, + ) + .unwrap() + .flatten_bool() + .unwrap(); + let filtered = to_int_indices(filtered_primitive); + assert_eq!(filtered, [6u64, 7, 8, 10]); + + let filtered_primitive = apply_conjunctive_filter( + &arr, + Conjunction { + predicates: vec![field.clone().eq(lit(5u32))], + }, + ) + .unwrap() + .flatten_bool() + .unwrap(); + let filtered = to_int_indices(filtered_primitive); + assert_eq!(filtered, [5u64]); + + let filtered_primitive = apply_conjunctive_filter( + &arr, + Conjunction { + predicates: vec![field.clone().gte(lit(5u32))], + }, + ) + .unwrap() + .flatten_bool() + .unwrap(); + let filtered = to_int_indices(filtered_primitive); + assert_eq!(filtered, [5u64, 6, 7, 8, 10]); + + let filtered_primitive = apply_conjunctive_filter( + &arr, + Conjunction { + predicates: vec![field.clone().lte(lit(5u32))], + }, + ) + .unwrap() + .flatten_bool() + .unwrap(); + let filtered = to_int_indices(filtered_primitive); + assert_eq!(filtered, [0u64, 1, 2, 3, 5]); + } + + #[test] + fn test_multiple_predicates() { + let arr = + PrimitiveArray::from_vec(vec![1u32, 2, 3, 4, 5, 6, 7, 8, 9, 10], Validity::AllValid); + let field = FieldPathBuilder::new().build(); + let filtered_primitive = apply_conjunctive_filter( + &arr, + Conjunction { + predicates: vec![field.clone().lt(lit(5u32)), field.clone().gt(lit(2u32))], + }, + ) + .unwrap() + .flatten_bool() + .unwrap(); + let filtered = to_int_indices(filtered_primitive); + assert_eq!(filtered, [2u64, 3]) + } + + #[test] + fn test_disjoint_predicates() { + let arr = + PrimitiveArray::from_vec(vec![1u32, 2, 3, 4, 5, 6, 7, 8, 9, 10], Validity::AllValid); + let field = FieldPathBuilder::new().build(); + let filtered_primitive = apply_conjunctive_filter( + &arr, + Conjunction { + predicates: vec![field.clone().lt(lit(5u32)), field.clone().gt(lit(5u32))], + }, + ) + .unwrap() + .flatten_bool() + .unwrap(); + let filtered = to_int_indices(filtered_primitive); + let expected: [u64; 0] = []; + assert_eq!(filtered, expected) + } + + #[test] + fn test_disjunctive_predicate() { + let arr = + PrimitiveArray::from_vec(vec![1u32, 2, 3, 4, 5, 6, 7, 8, 9, 10], Validity::AllValid); + let field = FieldPathBuilder::new().build(); + let c1 = Conjunction { + predicates: vec![field.clone().lt(lit(5u32))], + }; + let c2 = Conjunction { + predicates: vec![field.clone().gt(lit(5u32))], + }; + + let disj = Disjunction { + conjunctions: vec![c1, c2], + }; + let filtered_primitive = arr.filter_indices(&disj).unwrap().flatten_bool().unwrap(); + let filtered = to_int_indices(filtered_primitive); + assert_eq!(filtered, [0u64, 1, 2, 3, 5, 6, 7, 8, 9]) + } + + #[test] + fn test_invalid_path_err() { + let arr = + PrimitiveArray::from_vec(vec![1u32, 2, 3, 4, 5, 6, 7, 8, 9, 10], Validity::AllValid); + let field = FieldPathBuilder::new().join("some_field").build(); + apply_conjunctive_filter( + &arr, + Conjunction { + predicates: vec![field.clone().lt(lit(5u32)), field.clone().gt(lit(5u32))], + }, + ) + .expect_err("Cannot apply field reference to primitive array"); + } +} diff --git a/vortex-array/src/array/primitive/compute/mod.rs b/vortex-array/src/array/primitive/compute/mod.rs index cbf64331dc..8aa141a4de 100644 --- a/vortex-array/src/array/primitive/compute/mod.rs +++ b/vortex-array/src/array/primitive/compute/mod.rs @@ -14,6 +14,7 @@ mod as_arrow; mod as_contiguous; mod cast; mod fill; +mod filter_indices; mod scalar_at; mod search_sorted; mod slice; diff --git a/vortex-array/src/compute/filter_indices.rs b/vortex-array/src/compute/filter_indices.rs new file mode 100644 index 0000000000..0f399786eb --- /dev/null +++ b/vortex-array/src/compute/filter_indices.rs @@ -0,0 +1,29 @@ +use vortex_dtype::DType; +use vortex_error::{vortex_err, VortexResult}; +use vortex_expr::expressions::Disjunction; + +use crate::{Array, ArrayDType}; + +pub trait FilterIndicesFn { + fn filter_indices(&self, predicate: &Disjunction) -> VortexResult; +} + +pub fn filter_indices(array: &Array, predicate: &Disjunction) -> VortexResult { + if let Some(matching_indices) = + array.with_dyn(|c| c.filter_indices().map(|t| t.filter_indices(predicate))) + { + return matching_indices; + } + // if filter is not implemented for the given array type, but the array has a numeric + // DType, we can flatten the array and apply filter to the flattened primitive array + match array.dtype() { + DType::Primitive(..) => { + let flat = array.clone().flatten_primitive()?; + flat.filter_indices(predicate) + } + _ => Err(vortex_err!( + NotImplemented: "filter_indices", + array.encoding().id() + )), + } +} diff --git a/vortex-array/src/compute/mod.rs b/vortex-array/src/compute/mod.rs index 9974e63236..8853d27d53 100644 --- a/vortex-array/src/compute/mod.rs +++ b/vortex-array/src/compute/mod.rs @@ -8,12 +8,14 @@ use search_sorted::SearchSortedFn; use slice::SliceFn; use take::TakeFn; +use crate::compute::filter_indices::FilterIndicesFn; use crate::compute::scalar_subtract::SubtractScalarFn; pub mod as_arrow; pub mod as_contiguous; pub mod cast; pub mod fill; +pub mod filter_indices; pub mod patch; pub mod scalar_at; pub mod scalar_subtract; @@ -38,6 +40,10 @@ pub trait ArrayCompute { None } + fn filter_indices(&self) -> Option<&dyn FilterIndicesFn> { + None + } + fn patch(&self) -> Option<&dyn PatchFn> { None } diff --git a/vortex-dtype/src/field_paths.rs b/vortex-dtype/src/field_paths.rs index d1497fc985..00f2052fe7 100644 --- a/vortex-dtype/src/field_paths.rs +++ b/vortex-dtype/src/field_paths.rs @@ -1,8 +1,6 @@ use core::fmt; use std::fmt::{Display, Formatter}; -use vortex_error::{vortex_bail, VortexResult}; - #[derive(Clone, Debug, PartialEq)] #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] pub struct FieldPath { @@ -13,6 +11,19 @@ impl FieldPath { pub fn builder() -> FieldPathBuilder { FieldPathBuilder::default() } + + pub fn head(&self) -> Option<&FieldIdentifier> { + self.field_names.first() + } + + pub fn tail(&self) -> Option { + if self.head().is_none() { + None + } else { + let new_field_names = self.field_names[1..self.field_names.len()].to_vec(); + Some(Self::builder().join_all(new_field_names).build()) + } + } } #[derive(Clone, Debug, PartialEq)] @@ -38,13 +49,16 @@ impl FieldPathBuilder { self } - pub fn build(self) -> VortexResult { - if self.field_names.is_empty() { - vortex_bail!("Cannot build empty path"); - } - Ok(FieldPath { + pub fn join_all(mut self, identifiers: Vec>) -> Self { + self.field_names + .extend(identifiers.into_iter().map(|v| v.into())); + self + } + + pub fn build(self) -> FieldPath { + FieldPath { field_names: self.field_names, - }) + } } } diff --git a/vortex-expr/src/display.rs b/vortex-expr/src/display.rs index 90a8acbee1..cc37916207 100644 --- a/vortex-expr/src/display.rs +++ b/vortex-expr/src/display.rs @@ -69,13 +69,13 @@ mod tests { assert_eq!(format!("{}", !lit(1u32).lte(f1)), "($field <= 1)"); // nested field path - let f2 = FieldPath::builder().join("field").join(0).build().unwrap(); + let f2 = FieldPath::builder().join("field").join(0).build(); assert_eq!(format!("{}", !f2.lte(lit(1u32))), "($field.[0] > 1)"); } #[test] fn test_dnf_formatting() { - let path = FieldPath::builder().join(2).join("col1").build().unwrap(); + let path = FieldPath::builder().join(2).join("col1").build(); let d1 = Conjunction { predicates: vec![ lit(1u32).lt(path.clone()), @@ -83,7 +83,7 @@ mod tests { !lit(1u32).lte(path), ], }; - let path2 = FieldPath::builder().join("col1").join(2).build().unwrap(); + let path2 = FieldPath::builder().join("col1").join(2).build(); let d2 = Conjunction { predicates: vec![ lit(2u32).lt(path2), diff --git a/vortex-expr/src/operators.rs b/vortex-expr/src/operators.rs index da762dc586..b6c6c289af 100644 --- a/vortex-expr/src/operators.rs +++ b/vortex-expr/src/operators.rs @@ -1,5 +1,7 @@ use std::ops; +use vortex_dtype::NativePType; + use crate::expressions::Predicate; #[derive(Copy, Clone, Debug, Eq, PartialEq, PartialOrd)] @@ -45,4 +47,15 @@ impl Operator { Operator::LessThanOrEqualTo => Operator::GreaterThan, } } + + pub fn to_predicate(&self) -> fn(&T, &T) -> bool { + match self { + Operator::EqualTo => PartialEq::eq, + Operator::NotEqualTo => PartialEq::ne, + Operator::GreaterThan => PartialOrd::gt, + Operator::GreaterThanOrEqualTo => PartialOrd::ge, + Operator::LessThan => PartialOrd::lt, + Operator::LessThanOrEqualTo => PartialOrd::le, + } + } }