diff --git a/vortex-array/src/compute/filter.rs b/vortex-array/src/compute/filter.rs index baf6c91768..eab61522fc 100644 --- a/vortex-array/src/compute/filter.rs +++ b/vortex-array/src/compute/filter.rs @@ -1,5 +1,5 @@ use std::iter::TrustedLen; -use std::sync::OnceLock; +use std::sync::{Arc, OnceLock}; use arrow_array::BooleanArray; use arrow_buffer::{BooleanBuffer, BooleanBufferBuilder, MutableBuffer}; @@ -118,9 +118,9 @@ pub struct FilterMask { array: ArrayData, true_count: usize, range_selectivity: f64, - indices: OnceLock>, - slices: OnceLock>, - buffer: OnceLock, + indices: Arc>>, + slices: Arc>>, + buffer: Arc>, } /// We implement Clone manually to trigger population of our cached indices or slices. @@ -328,9 +328,9 @@ impl TryFrom for FilterMask { array, true_count, range_selectivity: selectivity, - indices: OnceLock::new(), - slices: OnceLock::new(), - buffer: OnceLock::new(), + indices: Arc::new(OnceLock::new()), + slices: Arc::new(OnceLock::new()), + buffer: Arc::new(OnceLock::new()), }) } } diff --git a/vortex-file/src/read/layouts/chunked.rs b/vortex-file/src/read/layouts/chunked.rs index 74ce4c0f81..cb085c5b55 100644 --- a/vortex-file/src/read/layouts/chunked.rs +++ b/vortex-file/src/read/layouts/chunked.rs @@ -409,9 +409,10 @@ mod tests { use bytes::Bytes; use flatbuffers::{root, FlatBufferBuilder}; use futures_util::TryStreamExt; - use vortex_array::array::{BoolArray, ChunkedArray, PrimitiveArray}; + use vortex_array::array::{ChunkedArray, PrimitiveArray}; + use vortex_array::compute::FilterMask; use vortex_array::{ArrayDType, ArrayLen, IntoArrayData, IntoArrayVariant}; - use vortex_dtype::{Nullability, PType}; + use vortex_dtype::PType; use vortex_expr::{BinaryExpr, Identity, Literal, Operator}; use vortex_flatbuffers::{footer, WriteFlatBuffer}; use vortex_ipc::messages::writer::MessageWriter; @@ -582,18 +583,8 @@ mod tests { snd_range.append_n(100, true); snd_range.append_n(50, false); let mut arr = [ - RowMask::try_new( - BoolArray::new(first_range.finish(), Nullability::NonNullable).into_array(), - 0, - 200, - ) - .unwrap(), - RowMask::try_new( - BoolArray::new(snd_range.finish(), Nullability::NonNullable).into_array(), - 200, - 400, - ) - .unwrap(), + RowMask::try_new(FilterMask::from(first_range.finish()), 0, 200).unwrap(), + RowMask::try_new(FilterMask::from(snd_range.finish()), 200, 400).unwrap(), RowMask::new_valid_between(400, 500), ] .into_iter() diff --git a/vortex-file/src/read/mask.rs b/vortex-file/src/read/mask.rs index a82749b91b..752ea363a1 100644 --- a/vortex-file/src/read/mask.rs +++ b/vortex-file/src/read/mask.rs @@ -1,19 +1,22 @@ use std::cmp::{max, min}; use std::fmt::{Display, Formatter}; -use vortex_array::array::{BoolArray, ConstantArray, PrimitiveArray, SparseArray}; +use arrow_buffer::BooleanBuffer; +use itertools::Itertools; +use vortex_array::array::{PrimitiveArray, SparseArray}; use vortex_array::compute::{and, filter, slice, try_cast, FilterMask}; -use vortex_array::stats::ArrayStatistics; -use vortex_array::validity::{ArrayValidity, LogicalValidity}; -use vortex_array::{ArrayDType, ArrayData, IntoArrayData, IntoArrayVariant}; +use vortex_array::validity::{ArrayValidity, LogicalValidity, Validity}; +use vortex_array::{ArrayData, IntoArrayData, IntoArrayVariant}; use vortex_dtype::Nullability::NonNullable; use vortex_dtype::{DType, PType}; -use vortex_error::{vortex_bail, VortexExpect, VortexResult, VortexUnwrap}; +use vortex_error::{vortex_bail, VortexResult, VortexUnwrap}; -/// Bitmap of selected rows within given [begin, end) row range +/// A RowMask captures a set of selected rows offset by a range. +/// +/// i.e., row zero of the inner FilterMask represents the offset row of the RowMask. #[derive(Debug, Clone)] pub struct RowMask { - bitmask: ArrayData, + mask: FilterMask, begin: usize, end: usize, } @@ -21,21 +24,9 @@ pub struct RowMask { #[cfg(test)] impl PartialEq for RowMask { fn eq(&self, other: &Self) -> bool { - use vortex_error::VortexUnwrap; self.begin == other.begin && self.end == other.end - && self - .bitmask - .clone() - .into_bool() - .vortex_unwrap() - .boolean_buffer() - == other - .bitmask - .clone() - .into_bool() - .vortex_unwrap() - .boolean_buffer() + && self.mask.to_boolean_buffer().unwrap() == other.mask.to_boolean_buffer().unwrap() } } @@ -46,32 +37,22 @@ impl Display for RowMask { } impl RowMask { - pub fn try_new(bitmask: ArrayData, begin: usize, end: usize) -> VortexResult { - if bitmask.dtype() != &DType::Bool(NonNullable) { + pub fn try_new(mask: FilterMask, begin: usize, end: usize) -> VortexResult { + if mask.len() != (end - begin) { vortex_bail!( - "bitmask must be a nonnullable bool array {}", - bitmask.dtype() - ) - } - if bitmask.len() != (end - begin) { - vortex_bail!( - "Bitmask must be the same length {} as the given range {}..{}", - bitmask.len(), + "FilterMask must be the same length {} as the given range {}..{}", + mask.len(), begin, end ); } - Ok(Self { - bitmask, - begin, - end, - }) + Ok(Self { mask, begin, end }) } /// Construct a RowMask which is valid in the given range. pub fn new_valid_between(begin: usize, end: usize) -> Self { RowMask::try_new( - ConstantArray::new(true, end - begin).into_array(), + FilterMask::from(BooleanBuffer::new_set(end - begin)), begin, end, ) @@ -81,7 +62,7 @@ impl RowMask { /// Construct a RowMask which is invalid everywhere in the given range. pub fn new_invalid_between(begin: usize, end: usize) -> Self { RowMask::try_new( - ConstantArray::new(false, end - begin).into_array(), + FilterMask::from(BooleanBuffer::new_unset(end - begin)), begin, end, ) @@ -93,11 +74,13 @@ impl RowMask { /// True-valued positions are kept by the returned mask. pub fn from_mask_array(array: &ArrayData, begin: usize, end: usize) -> VortexResult { match array.logical_validity() { - LogicalValidity::AllValid(_) => Self::try_new(array.clone(), begin, end), + LogicalValidity::AllValid(_) => { + Self::try_new(FilterMask::try_from(array.clone())?, begin, end) + } LogicalValidity::AllInvalid(_) => Ok(Self::new_invalid_between(begin, end)), LogicalValidity::Array(validity) => { let bitmask = and(array.clone(), validity)?; - Self::try_new(bitmask, begin, end) + Self::try_new(FilterMask::try_from(bitmask)?, begin, end) } } } @@ -108,32 +91,28 @@ impl RowMask { pub fn from_index_array(array: &ArrayData, begin: usize, end: usize) -> VortexResult { let indices = try_cast(array, &DType::Primitive(PType::U64, NonNullable))?.into_primitive()?; - let bools = BoolArray::from_indices( + + // TODO(ngates): should from_indices take u64? + let mask = FilterMask::from_indices( end - begin, indices .maybe_null_slice::() .iter() - .copied() - .map(|i| usize::try_from(i).vortex_unwrap()), + .map(|i| *i as usize), ); - RowMask::try_new(bools.into_array(), begin, end) + + RowMask::try_new(mask, begin, end) } /// Combine the RowMask with bitmask values resulting in new RowMask containing only values true in the bitmask pub fn and_bitmask(&self, bitmask: ArrayData) -> VortexResult { // If we are a dense all true bitmap just take the bitmask array - if self - .bitmask - .statistics() - .compute_true_count() - .map(|true_count| true_count == self.len()) - .unwrap_or(false) - { + if self.mask.true_count() == self.len() { if bitmask.len() != self.len() { vortex_bail!( "Bitmask length {} does not match our length {}", bitmask.len(), - self.bitmask.len() + self.mask.len() ); } Self::from_mask_array(&bitmask, self.begin, self.end) @@ -147,12 +126,8 @@ impl RowMask { } } - pub fn is_empty(&self) -> bool { - self.bitmask - .statistics() - .compute_true_count() - .vortex_expect("Must have true count") - == 0 + pub fn is_all_false(&self) -> bool { + self.mask.true_count() == 0 } pub fn begin(&self) -> usize { @@ -164,7 +139,11 @@ impl RowMask { } pub fn len(&self) -> usize { - self.bitmask.len() + self.mask.len() + } + + pub fn is_empty(&self) -> bool { + self.mask.is_empty() } /// Limit mask to [begin..end) range @@ -173,13 +152,13 @@ impl RowMask { let range_end = min(self.end, end); RowMask::try_new( if range_begin == self.begin && range_end == self.end { - self.bitmask.clone() + self.mask.clone() } else { - slice( - &self.bitmask, - range_begin - self.begin, - range_end - self.begin, - )? + FilterMask::from( + self.mask + .to_boolean_buffer()? + .slice(range_begin - self.begin, range_end - range_begin), + ) }, range_begin, range_end, @@ -191,8 +170,8 @@ impl RowMask { /// This function assumes that Array is no longer than the mask length and that the mask starts on same offset as the array, /// i.e. the beginning of the array corresponds to the beginning of the mask with begin = 0 pub fn filter_array(&self, array: impl AsRef) -> VortexResult> { - let true_count = self.bitmask.statistics().compute_true_count(); - if true_count.map(|tc| tc == 0).unwrap_or(false) { + let true_count = self.mask.true_count(); + if true_count == 0 { return Ok(None); } @@ -204,23 +183,18 @@ impl RowMask { &slice(array, self.begin, self.end)? }; - if true_count.map(|tc| tc == sliced.len()).unwrap_or(false) { + if true_count == sliced.len() { return Ok(Some(sliced.clone())); } - let mask = FilterMask::try_from(self.bitmask.clone())?; - filter(sliced, mask).map(Some) + filter(sliced, self.mask.clone()).map(Some) } - pub fn to_indices_array(&self) -> VortexResult { - Ok(PrimitiveArray::from( - self.bitmask - .clone() - .into_bool()? - .boolean_buffer() - .set_indices() - .map(|i| i as u64) - .collect::>(), + #[allow(deprecated)] + fn to_indices_array(&self) -> VortexResult { + Ok(PrimitiveArray::from_vec( + self.mask.iter_indices()?.map(|i| i as u64).collect_vec(), + Validity::NonNullable, ) .into_array()) } @@ -233,7 +207,7 @@ impl RowMask { self.begin ) } - RowMask::try_new(self.bitmask, self.begin - offset, self.end - offset) + RowMask::try_new(self.mask, self.begin - offset, self.end - offset) } } @@ -241,34 +215,34 @@ impl RowMask { mod tests { use arrow_buffer::BooleanBuffer; use rstest::rstest; - use vortex_array::array::{BoolArray, PrimitiveArray}; + use vortex_array::array::PrimitiveArray; + use vortex_array::compute::FilterMask; use vortex_array::{IntoArrayData, IntoArrayVariant}; - use vortex_dtype::Nullability; use vortex_error::VortexUnwrap; use crate::read::mask::RowMask; #[rstest] #[case( - RowMask::try_new(BoolArray::from_iter([true, true, true, false, false, false, false, false, true, true]).into_array(), 0, 10).unwrap(), (0, 1), - RowMask::try_new(BoolArray::from_iter([true]).into_array(), 0, 1).unwrap())] + RowMask::try_new(FilterMask::from_iter([true, true, true, false, false, false, false, false, true, true]), 0, 10).unwrap(), (0, 1), + RowMask::try_new(FilterMask::from_iter([true]), 0, 1).unwrap())] #[case( - RowMask::try_new(BoolArray::from_iter([false, false, false, false, false, true, true, true, true, true]).into_array(), 0, 10).unwrap(), (2, 5), - RowMask::try_new(BoolArray::from_iter([false, false, false]).into_array(), 2, 5).unwrap() + RowMask::try_new(FilterMask::from_iter([false, false, false, false, false, true, true, true, true, true]), 0, 10).unwrap(), (2, 5), + RowMask::try_new(FilterMask::from_iter([false, false, false]), 2, 5).unwrap() )] #[case( - RowMask::try_new(BoolArray::from_iter([true, true, true, true, false, false, false, false, false, false]).into_array(), 0, 10).unwrap(), (2, 5), - RowMask::try_new(BoolArray::from_iter([true, true, false]).into_array(), 2, 5).unwrap() + RowMask::try_new(FilterMask::from_iter([true, true, true, true, false, false, false, false, false, false]), 0, 10).unwrap(), (2, 5), + RowMask::try_new(FilterMask::from_iter([true, true, false]), 2, 5).unwrap() )] #[case( - RowMask::try_new(BoolArray::from_iter([true, true, true, false, false, true, true, false, false, false]).into_array(), 0, 10).unwrap(), (2, 6), - RowMask::try_new(BoolArray::from_iter([true, false, false, true]).into_array(), 2, 6).unwrap())] + RowMask::try_new(FilterMask::from_iter([true, true, true, false, false, true, true, false, false, false]), 0, 10).unwrap(), (2, 6), + RowMask::try_new(FilterMask::from_iter([true, false, false, true]), 2, 6).unwrap())] #[case( - RowMask::try_new(BoolArray::from_iter([false, false, false, false, false, true, true, true, true, true]).into_array(), 0, 10).unwrap(), (7, 11), - RowMask::try_new(BoolArray::from_iter([true, true, true]).into_array(), 7, 10).unwrap())] + RowMask::try_new(FilterMask::from_iter([false, false, false, false, false, true, true, true, true, true]), 0, 10).unwrap(), (7, 11), + RowMask::try_new(FilterMask::from_iter([true, true, true]), 7, 10).unwrap())] #[case( - RowMask::try_new(BoolArray::from_iter([false, true, true, true, true, true]).into_array(), 3, 9).unwrap(), (0, 5), - RowMask::try_new(BoolArray::from_iter([false, true]).into_array(), 3, 5).unwrap())] + RowMask::try_new(FilterMask::from_iter([false, true, true, true, true, true]), 3, 9).unwrap(), (0, 5), + RowMask::try_new(FilterMask::from_iter([false, true]), 3, 5).unwrap())] #[cfg_attr(miri, ignore)] fn slice(#[case] first: RowMask, #[case] range: (usize, usize), #[case] expected: RowMask) { assert_eq!(first.slice(range.0, range.1).vortex_unwrap(), expected); @@ -278,46 +252,28 @@ mod tests { #[should_panic] #[cfg_attr(miri, ignore)] fn test_new() { - RowMask::try_new( - BoolArray::new(BooleanBuffer::new_unset(10), Nullability::NonNullable).into_array(), - 5, - 10, - ) - .unwrap(); + RowMask::try_new(FilterMask::from(BooleanBuffer::new_unset(10)), 5, 10).unwrap(); } #[test] #[should_panic] #[cfg_attr(miri, ignore)] fn shift_invalid() { - RowMask::try_new( - BoolArray::from_iter([true, true, true, true, true]).into_array(), - 5, - 10, - ) - .unwrap() - .shift(7) - .unwrap(); + RowMask::try_new(FilterMask::from_iter([true, true, true, true, true]), 5, 10) + .unwrap() + .shift(7) + .unwrap(); } #[test] #[cfg_attr(miri, ignore)] fn shift() { assert_eq!( - RowMask::try_new( - BoolArray::from_iter([true, true, true, true, true]).into_array(), - 5, - 10 - ) - .unwrap() - .shift(5) - .unwrap(), - RowMask::try_new( - BoolArray::from_iter([true, true, true, true, true]).into_array(), - 0, - 5 - ) - .unwrap() + RowMask::try_new(FilterMask::from_iter([true, true, true, true, true]), 5, 10) + .unwrap() + .shift(5) + .unwrap(), + RowMask::try_new(FilterMask::from_iter([true, true, true, true, true]), 0, 5).unwrap() ); } @@ -325,10 +281,9 @@ mod tests { #[cfg_attr(miri, ignore)] fn filter_array() { let mask = RowMask::try_new( - BoolArray::from_iter([ + FilterMask::from_iter([ false, false, false, false, false, true, true, true, true, true, - ]) - .into_array(), + ]), 0, 10, ) diff --git a/vortex-file/src/read/splits.rs b/vortex-file/src/read/splits.rs index 45a3edda77..d37aa783a8 100644 --- a/vortex-file/src/read/splits.rs +++ b/vortex-file/src/read/splits.rs @@ -109,7 +109,7 @@ impl Iterator for FixedSplitIterator { Err(e) => return Some(Err(e)), }; - if sliced.is_empty() { + if sliced.is_all_false() { continue; } Some(Ok(sliced)) @@ -141,8 +141,7 @@ impl Stream for FixedSplitIterator { mod tests { use std::collections::BTreeSet; - use vortex_array::array::BoolArray; - use vortex_array::IntoArrayData; + use vortex_array::compute::FilterMask; use vortex_error::VortexResult; use crate::read::splits::FixedSplitIterator; @@ -170,10 +169,9 @@ mod tests { 10, Some( RowMask::try_new( - BoolArray::from_iter([ + FilterMask::from_iter([ false, false, false, false, true, true, false, false, false, false, - ]) - .into_array(), + ]), 0, 10, ) @@ -183,9 +181,10 @@ mod tests { mask_iter .additional_splits(&mut BTreeSet::from([0, 2, 4, 6, 8, 10])) .unwrap(); - assert_eq!( - mask_iter.collect::>>().unwrap(), - vec![RowMask::new_valid_between(4, 6)] - ); + + let actual = mask_iter.collect::>>().unwrap(); + let expected = vec![RowMask::new_valid_between(4, 6)]; + + assert_eq!(actual, expected); } }