Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FilterIndices compute function #326

Merged
merged 17 commits into from
May 20, 2024
Prev Previous commit
Next Next commit
temp
  • Loading branch information
jdcasale committed May 17, 2024
commit f9e9268df199801549e7629554168b3eccdd47b1
132 changes: 78 additions & 54 deletions vortex-array/src/array/primitive/compute/filter_indices.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use croaring::Bitmap;
use std::ops::{BitAnd, BitOr};
use arrow_buffer::{BooleanBuffer, BooleanBufferBuilder};
use itertools::Itertools;
use vortex_dtype::{match_each_native_ptype, NativePType};
use vortex_error::{vortex_bail, VortexResult};
Expand All @@ -13,52 +14,64 @@ use crate::{Array, ArrayTrait, IntoArray};
impl FilterIndicesFn for PrimitiveArray {
fn filter_indices(&self, predicate: &Disjunction) -> VortexResult<Array> {
let mut conjunction_indices = predicate.conjunctions.iter().flat_map(|conj| {
BitmapMergeOp::All(
BooleanBufferMergeOp::All(
&mut conj
.predicates
.iter()
.map(|pred| indices_matching_predicate(self, pred).unwrap()),
)
.merge()
.merge()
});
let indices = BitmapMergeOp::Any(&mut conjunction_indices)
let indices = BooleanBufferMergeOp::Any(&mut conjunction_indices)
.merge()
.map(|bitmap| bitmap.iter().map(|idx| idx as u64).collect_vec())
.map(|bitset| bitset.iter()
.enumerate()
.flat_map(|(idx, v)| if v {
Some(idx as u64)
} else {
None
})
.collect_vec())
.unwrap_or(Vec::new());
Ok(PrimitiveArray::from_vec(indices, Validity::AllValid).into_array())
}
}

fn indices_matching_predicate(arr: &PrimitiveArray, predicate: &Predicate) -> VortexResult<Bitmap> {
fn indices_matching_predicate(arr: &PrimitiveArray, predicate: &Predicate) -> VortexResult<BooleanBuffer> {
if predicate.left.head().is_some() {
vortex_bail!("Invalid path for primitive array")
}
let validity = arr.validity();

let rhs = match &predicate.right {
Value::Field(_) => {
vortex_bail!("Right-hand-side fields not yet supported.")
}
Value::Literal(scalar) => scalar,
};

let matching_idxs: Vec<u32> = match_each_native_ptype!(arr.ptype(), |$T| {
let validity_buf = arr.validity()
.to_logical(arr.len())
.to_present_null_buffer()?.into_inner();

let matching_idxs = match_each_native_ptype!(arr.ptype(), |$T| {
let rhs_typed: $T = rhs.try_into().unwrap();
let predicate_fn = get_predicate::<$T>(&predicate.op);

arr.typed_data::<$T>().iter().enumerate().filter(|(idx, &v)| {
predicate_fn(&v, &rhs_typed)
})
.filter(|(idx, _)| validity.is_valid(idx.clone()))
//todo(@jcasale): 64-bit RoaringBitmap?
.map(|(idx, _)| idx as u32)
.collect_vec()
apply_predicate(arr.typed_data::<$T>(), &rhs_typed, predicate_fn)
});
//todo(@jcasale): 64-bit RoaringBitmap?
let mut bitmap = Bitmap::with_container_capacity(arr.len() as u32);

matching_idxs.into_iter().for_each(|idx| bitmap.add(idx));
Ok(matching_idxs.bitand(&validity_buf))
}

Ok(bitmap)
fn apply_predicate<T: NativePType, F: Fn(&T, &T) -> bool>(lhs: &[T], rhs: &T, f: F) -> BooleanBuffer {
let matches = lhs.iter()
.filter(|lhs| f(lhs, rhs))
.enumerate()
.map(|(idx, _)| idx)
.collect_vec();
let mut matching_idx_bitset = BooleanBufferBuilder::new(lhs.len());
matching_idx_bitset.resize(lhs.len());
matches.into_iter().for_each(|idx| matching_idx_bitset.set_bit(idx, true));
matching_idx_bitset.finish()
}

fn get_predicate<T: NativePType>(op: &Operator) -> fn(&T, &T) -> bool {
Expand All @@ -72,17 +85,21 @@ fn get_predicate<T: NativePType>(op: &Operator) -> fn(&T, &T) -> bool {
}
}

/// Merge an arbitrary number of bitmaps
enum BitmapMergeOp<'a> {
Any(&'a mut dyn Iterator<Item = Bitmap>),
All(&'a mut dyn Iterator<Item = Bitmap>),
/// Merge an arbitrary number of bitsets
enum BooleanBufferMergeOp<'a> {
Any(&'a mut dyn Iterator<Item=BooleanBuffer>),
All(&'a mut dyn Iterator<Item=BooleanBuffer>),
}

impl BitmapMergeOp<'_> {
fn merge(self) -> Option<Bitmap> {
impl BooleanBufferMergeOp<'_> {
fn merge(self) -> Option<BooleanBuffer> {
match self {
BitmapMergeOp::Any(bitmaps) => bitmaps.reduce(|a, b| a.or(&b)),
BitmapMergeOp::All(bitmaps) => bitmaps.reduce(|a, b| a.and(&b)),
BooleanBufferMergeOp::Any(bitsets) => bitsets.reduce(|a, b| {
a.bitor(&b)
}),
BooleanBufferMergeOp::All(bitsets) => bitsets.reduce(|a, b| {
a.bitand(&b)
}),
}
}
}
Expand All @@ -104,7 +121,14 @@ mod test {

#[test]
fn test_basic_filter() {
let arr = PrimitiveArray::from_vec(vec![1u32, 2, 3, 4, 5, 6, 7, 8, 9], Validity::AllValid);
let arr = PrimitiveArray::from_nullable_vec(vec![
Some(1u32), Some(2), Some(3), Some(4),
None,
Some(5), Some(6), Some(7), Some(8),
None,
Some(9),
None,
]);

let field = FieldPathBuilder::new().build();
let filtered_primitive = apply_conjunctive_filter(
Expand All @@ -113,9 +137,9 @@ mod test {
predicates: vec![field.clone().lt(lit(5u32))],
},
)
.unwrap()
.flatten_primitive()
.unwrap();
.unwrap()
.flatten_primitive()
.unwrap();
let filtered = filtered_primitive.typed_data::<u64>();
assert_eq!(filtered, [0u64, 1, 2, 3]);

Expand All @@ -125,47 +149,47 @@ mod test {
predicates: vec![field.clone().gt(lit(5u32))],
},
)
.unwrap()
.flatten_primitive()
.unwrap();
.unwrap()
.flatten_primitive()
.unwrap();
let filtered = filtered_primitive.typed_data::<u64>();
assert_eq!(filtered, [5u64, 6, 7, 8]);
assert_eq!(filtered, [6u64, 7, 8, 10]);

let filtered_primitive = apply_conjunctive_filter(
&arr,
Conjunction {
predicates: vec![field.clone().eq(lit(5u32))],
},
)
.unwrap()
.flatten_primitive()
.unwrap();
.unwrap()
.flatten_primitive()
.unwrap();
let filtered = filtered_primitive.typed_data::<u64>();
assert_eq!(filtered, [4]);
assert_eq!(filtered, [5]);

let filtered_primitive = apply_conjunctive_filter(
&arr,
Conjunction {
predicates: vec![field.clone().gte(lit(5u32))],
},
)
.unwrap()
.flatten_primitive()
.unwrap();
.unwrap()
.flatten_primitive()
.unwrap();
let filtered = filtered_primitive.typed_data::<u64>();
assert_eq!(filtered, [4u64, 5, 6, 7, 8]);
assert_eq!(filtered, [5u64, 6, 7, 8, 10]);

let filtered_primitive = apply_conjunctive_filter(
&arr,
Conjunction {
predicates: vec![field.clone().lte(lit(5u32))],
},
)
.unwrap()
.flatten_primitive()
.unwrap();
.unwrap()
.flatten_primitive()
.unwrap();
let filtered = filtered_primitive.typed_data::<u64>();
assert_eq!(filtered, [0u64, 1, 2, 3, 4]);
assert_eq!(filtered, [0u64, 1, 2, 3, 5]);
}

#[test]
Expand All @@ -179,9 +203,9 @@ mod test {
predicates: vec![field.clone().lt(lit(5u32)), field.clone().gt(lit(2u32))],
},
)
.unwrap()
.flatten_primitive()
.unwrap();
.unwrap()
.flatten_primitive()
.unwrap();
let filtered = filtered_primitive.typed_data::<u64>();
assert_eq!(filtered, [2, 3])
}
Expand All @@ -197,9 +221,9 @@ mod test {
predicates: vec![field.clone().lt(lit(5u32)), field.clone().gt(lit(5u32))],
},
)
.unwrap()
.flatten_primitive()
.unwrap();
.unwrap()
.flatten_primitive()
.unwrap();
let filtered = filtered_primitive.typed_data::<u64>();
let expected: [u64; 0] = [];
assert_eq!(filtered, expected)
Expand Down