Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FilterIndices compute function #326

Merged
merged 17 commits into from
May 20, 2024
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions vortex-array/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ rand = { workspace = true }
vortex-buffer = { path = "../vortex-buffer" }
vortex-dtype = { path = "../vortex-dtype", features = ["flatbuffers", "serde"] }
vortex-error = { path = "../vortex-error", features = ["flexbuffers"] }
vortex-expr = { path = "../vortex-expr" }
vortex-flatbuffers = { path = "../vortex-flatbuffers" }
vortex-scalar = { path = "../vortex-scalar", features = ["flatbuffers", "serde"] }
serde = { workspace = true, features = ["derive"] }
Expand Down
246 changes: 246 additions & 0 deletions vortex-array/src/array/primitive/compute/filter_indices.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,246 @@
use itertools::Itertools;
use vortex_dtype::{match_each_native_ptype, NativePType};
use vortex_error::{vortex_bail, VortexResult};
use vortex_expr::expressions::{Disjunction, Predicate, Value};
use vortex_expr::operators::Operator;

use crate::array::primitive::PrimitiveArray;
use crate::compute::filter_indices::FilterIndicesFn;
use crate::validity::Validity;
use crate::{Array, ArrayTrait, IntoArray};

impl FilterIndicesFn for PrimitiveArray {
fn filter_indices(&self, predicate: &Disjunction) -> VortexResult<Array> {
let conjunction_indices = predicate
.conjunctions
.iter()
.map(|conj| {
MergeOp::All(
conj.predicates
.iter()
.map(|pred| indices_matching_predicate(self, pred).unwrap())
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just to echo my understanding:

right now, this instantiates a byte per value per predicate, then converts those Vec<bool> instances to iterators, then calls collect_vec again...? So we have a Vec of iterators (which are backed by Vecs), basically (N + 1) allocations of self.len bytes for N predicates, which get passed to MergeOp.

MergeOp is itself an iterator, which we collect in order to force evaluation of the All predicate, then we do it all over again with a cycle of iterators and collect calls to evaluate the Any predicate, enumerate/filter/collect to get indices.

I'm concerned that that's an awful lot of heap allocations on an extremely hot code path. From a machine efficiency point-of-view, we would ideally have at most 2 allocations and end up producing SIMD instructions to do bitwise AND on bitmaps.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a good point -- initially I wanted to avoid doing pairwise reductions in favor of a single-pass, but the allocations/vec overhead here might outweigh that benefit anyway. I'll rewrite this to be fully lazy and use bitmaps instead of vecs.

The one thing I'm not sure of here is whether we have a 64-bit bitmap in the rust croaring crate -- at first glance I didn't see one, will take a closer look tomorrow.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So it's not actually that hot. It's an allocation per predicate, so an expression of (X > 2 & X < 10) is two predicates, regardless of how big the array X is.

With the right bitset utility, you should be able to mutate-in-place instead of allocating a third bitset for the result. But without benchmarking, it's hard to intuit which will perform better.

.map(|a| a.into_iter())
.collect_vec(),
)
.collect_vec()
.into_iter()
})
.collect_vec();
let indices = MergeOp::Any(conjunction_indices)
.enumerate()
.filter(|(_, v)| *v)
.map(|(idx, _)| idx as u64)
.collect_vec();
Ok(PrimitiveArray::from_vec(indices, Validity::AllValid).into_array())
}
}

fn indices_matching_predicate(
arr: &PrimitiveArray,
predicate: &Predicate,
) -> VortexResult<Vec<bool>> {
if predicate.left.first().is_some() {
vortex_bail!("Invalid path for primitive array")
}
let validity = arr.validity();
let rhs = match &predicate.right {
Value::Field(_) => {
vortex_bail!("")
jdcasale marked this conversation as resolved.
Show resolved Hide resolved
}
Value::Literal(scalar) => scalar,
};

let matching_idxs = match_each_native_ptype!(arr.ptype(), |$T| {
let rhs_typed: $T = rhs.try_into().unwrap();
let predicate_fn = get_predicate::<$T>(&predicate.op);
arr.typed_data::<$T>().iter().enumerate().filter(|(idx, &v)| {
jdcasale marked this conversation as resolved.
Show resolved Hide resolved
predicate_fn(&v, &rhs_typed)
})
.filter(|(idx, _)| validity.is_valid(idx.clone()))
.map(|(idx, _)| idx )
.collect_vec()
});
let mut bitmap = vec![false; arr.len()];
jdcasale marked this conversation as resolved.
Show resolved Hide resolved
matching_idxs.into_iter().for_each(|idx| bitmap[idx] = true);

Ok(bitmap)
}

fn get_predicate<T: NativePType>(op: &Operator) -> fn(&T, &T) -> bool {
match op {
Operator::EqualTo => PartialEq::eq,
Operator::NotEqualTo => PartialEq::ne,
Operator::GreaterThan => PartialOrd::gt,
Operator::GreaterThanOrEqualTo => PartialOrd::ge,
Operator::LessThan => PartialOrd::lt,
Operator::LessThanOrEqualTo => PartialOrd::le,
}
}

/// Merge an arbitrary number of boolean iterators
enum MergeOp {
Any(Vec<std::vec::IntoIter<bool>>),
jdcasale marked this conversation as resolved.
Show resolved Hide resolved
All(Vec<std::vec::IntoIter<bool>>),
}

impl Iterator for MergeOp {
type Item = bool;

fn next(&mut self) -> Option<Self::Item> {
let zipped = match self {
MergeOp::Any(vecs) => vecs,
MergeOp::All(vecs) => vecs,
}
.iter_mut()
.map(|iter| iter.next())
.collect::<Option<Vec<_>>>();

match self {
MergeOp::Any(_) => zipped.map(|inner| inner.iter().any(|&v| v)),
MergeOp::All(_) => zipped.map(|inner| inner.iter().all(|&v| v)),
}
}
}

#[cfg(test)]
mod test {
use vortex_dtype::field_paths::FieldPathBuilder;
use vortex_expr::expressions::{lit, Conjunction};
use vortex_expr::field_paths::FieldPathOperations;

use super::*;
use crate::validity::Validity;

fn apply_conjunctive_filter(arr: &PrimitiveArray, conj: Conjunction) -> VortexResult<Array> {
arr.filter_indices(&Disjunction {
conjunctions: vec![conj],
})
}

#[test]
fn test_basic_filter() {
let arr = PrimitiveArray::from_vec(vec![1u32, 2, 3, 4, 5, 6, 7, 8, 9], Validity::AllValid);

let field = FieldPathBuilder::new().build();
let filtered_primitive = apply_conjunctive_filter(
&arr,
Conjunction {
predicates: vec![field.clone().lt(lit(5u32))],
},
)
.unwrap()
.flatten_primitive()
.unwrap();
let filtered = filtered_primitive.typed_data::<u64>();
assert_eq!(filtered, [0u64, 1, 2, 3]);

let filtered_primitive = apply_conjunctive_filter(
&arr,
Conjunction {
predicates: vec![field.clone().gt(lit(5u32))],
},
)
.unwrap()
.flatten_primitive()
.unwrap();
let filtered = filtered_primitive.typed_data::<u64>();
assert_eq!(filtered, [5u64, 6, 7, 8]);

let filtered_primitive = apply_conjunctive_filter(
&arr,
Conjunction {
predicates: vec![field.clone().eq(lit(5u32))],
},
)
.unwrap()
.flatten_primitive()
.unwrap();
let filtered = filtered_primitive.typed_data::<u64>();
assert_eq!(filtered, [4]);

let filtered_primitive = apply_conjunctive_filter(
&arr,
Conjunction {
predicates: vec![field.clone().gte(lit(5u32))],
},
)
.unwrap()
.flatten_primitive()
.unwrap();
let filtered = filtered_primitive.typed_data::<u64>();
assert_eq!(filtered, [4u64, 5, 6, 7, 8]);

let filtered_primitive = apply_conjunctive_filter(
&arr,
Conjunction {
predicates: vec![field.clone().lte(lit(5u32))],
},
)
.unwrap()
.flatten_primitive()
.unwrap();
let filtered = filtered_primitive.typed_data::<u64>();
assert_eq!(filtered, [0u64, 1, 2, 3, 4]);
}

#[test]
fn test_multiple_predicates() {
let arr =
PrimitiveArray::from_vec(vec![1u32, 2, 3, 4, 5, 6, 7, 8, 9, 10], Validity::AllValid);
let field = FieldPathBuilder::new().build();
let filtered_primitive = apply_conjunctive_filter(
&arr,
Conjunction {
predicates: vec![field.clone().lt(lit(5u32)), field.clone().gt(lit(2u32))],
},
)
.unwrap()
.flatten_primitive()
.unwrap();
let filtered = filtered_primitive.typed_data::<u64>();
assert_eq!(filtered, [2, 3])
}

#[test]
fn test_disjoint_predicates() {
let arr =
PrimitiveArray::from_vec(vec![1u32, 2, 3, 4, 5, 6, 7, 8, 9, 10], Validity::AllValid);
let field = FieldPathBuilder::new().build();
let filtered_primitive = apply_conjunctive_filter(
&arr,
Conjunction {
predicates: vec![field.clone().lt(lit(5u32)), field.clone().gt(lit(5u32))],
},
)
.unwrap()
.flatten_primitive()
.unwrap();
let filtered = filtered_primitive.typed_data::<u64>();
let expected: [u64; 0] = [];
assert_eq!(filtered, expected)
}

#[test]
fn test_disjunctive_predicate() {
let arr =
PrimitiveArray::from_vec(vec![1u32, 2, 3, 4, 5, 6, 7, 8, 9, 10], Validity::AllValid);
let field = FieldPathBuilder::new().build();
let c1 = Conjunction {
predicates: vec![field.clone().lt(lit(5u32))],
};
let c2 = Conjunction {
predicates: vec![field.clone().gt(lit(5u32))],
};

let disj = Disjunction {
conjunctions: vec![c1, c2],
};
let filtered_primitive = arr
.filter_indices(&disj)
.unwrap()
.flatten_primitive()
.unwrap();
let filtered = filtered_primitive.typed_data::<u64>();
assert_eq!(filtered, [0u64, 1, 2, 3, 5, 6, 7, 8, 9])
}
}
1 change: 1 addition & 0 deletions vortex-array/src/array/primitive/compute/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ mod as_arrow;
mod as_contiguous;
mod cast;
mod fill;
mod filter_indices;
mod scalar_at;
mod search_sorted;
mod slice;
Expand Down
29 changes: 29 additions & 0 deletions vortex-array/src/compute/filter_indices.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
use vortex_dtype::DType;
use vortex_error::{vortex_err, VortexResult};
use vortex_expr::expressions::Disjunction;

use crate::{Array, ArrayDType};

pub trait FilterIndicesFn {
fn filter_indices(&self, predicate: &Disjunction) -> VortexResult<Array>;
}

pub fn filter_indices(array: &Array, predicate: &Disjunction) -> VortexResult<Array> {
if let Some(subtraction_result) =
jdcasale marked this conversation as resolved.
Show resolved Hide resolved
array.with_dyn(|c| c.filter_indices().map(|t| t.filter_indices(predicate)))
{
return subtraction_result;
}
// if filter is not implemented for the given array type, but the array has a numeric
// DType, we can flatten the array and apply filter to the flattened primitive array
match array.dtype() {
DType::Primitive(..) => {
let flat = array.clone().flatten_primitive()?;
flat.filter_indices(predicate)
}
_ => Err(vortex_err!(
NotImplemented: "filter_indices",
array.encoding().id()
)),
}
}
6 changes: 6 additions & 0 deletions vortex-array/src/compute/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,14 @@ use search_sorted::SearchSortedFn;
use slice::SliceFn;
use take::TakeFn;

use crate::compute::filter_indices::FilterIndicesFn;
use crate::compute::scalar_subtract::SubtractScalarFn;

pub mod as_arrow;
pub mod as_contiguous;
pub mod cast;
pub mod fill;
pub mod filter_indices;
pub mod patch;
pub mod scalar_at;
pub mod scalar_subtract;
Expand All @@ -38,6 +40,10 @@ pub trait ArrayCompute {
None
}

fn filter_indices(&self) -> Option<&dyn FilterIndicesFn> {
None
}

fn patch(&self) -> Option<&dyn PatchFn> {
None
}
Expand Down
30 changes: 22 additions & 8 deletions vortex-dtype/src/field_paths.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
use core::fmt;
use std::fmt::{Display, Formatter};

use vortex_error::{vortex_bail, VortexResult};

#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct FieldPath {
Expand All @@ -13,6 +11,19 @@ impl FieldPath {
pub fn builder() -> FieldPathBuilder {
FieldPathBuilder::default()
}

pub fn first(&self) -> Option<&FieldIdentifier> {
jdcasale marked this conversation as resolved.
Show resolved Hide resolved
self.field_names.first()
}

pub fn tail(&self) -> Option<Self> {
if self.first().is_none() {
None
} else {
let new_field_names = self.field_names[1..self.field_names.len()].to_vec();
Some(Self::builder().join_all(new_field_names).build())
}
}
}

#[derive(Clone, Debug, PartialEq)]
Expand All @@ -38,13 +49,16 @@ impl FieldPathBuilder {
self
}

pub fn build(self) -> VortexResult<FieldPath> {
if self.field_names.is_empty() {
vortex_bail!("Cannot build empty path");
}
Ok(FieldPath {
pub fn join_all(mut self, identifiers: Vec<impl Into<FieldIdentifier>>) -> Self {
self.field_names
.extend(identifiers.into_iter().map(|v| v.into()));
self
}

pub fn build(self) -> FieldPath {
FieldPath {
field_names: self.field_names,
})
}
}
}

Expand Down
Loading