From d943392d543cd2d590946abed9bd962904f8cd94 Mon Sep 17 00:00:00 2001 From: Nicholas Gates Date: Wed, 27 Nov 2024 10:51:27 -0500 Subject: [PATCH] Remove uses of with_dyn for validity (#1487) --- encodings/alp/src/alp/compute/mod.rs | 3 +- encodings/alp/src/alp_rd/compute/scalar_at.rs | 3 +- encodings/datetime-parts/src/array.rs | 6 +- encodings/dict/src/array.rs | 4 +- .../src/bitpacking/compute/scalar_at.rs | 3 +- vortex-array/src/array/chunked/mod.rs | 6 +- vortex-array/src/array/sparse/mod.rs | 17 ++---- vortex-array/src/compress.rs | 6 +- vortex-array/src/compute/scalar_at.rs | 3 +- vortex-array/src/data/mod.rs | 41 +++++-------- vortex-array/src/data/owned.rs | 26 +-------- vortex-array/src/data/viewed.rs | 24 +------- vortex-array/src/nbytes.rs | 2 +- vortex-array/src/tree.rs | 58 +++++++++---------- vortex-file/src/read/mask.rs | 4 +- 15 files changed, 76 insertions(+), 130 deletions(-) diff --git a/encodings/alp/src/alp/compute/mod.rs b/encodings/alp/src/alp/compute/mod.rs index 944b003716..950ceb3436 100644 --- a/encodings/alp/src/alp/compute/mod.rs +++ b/encodings/alp/src/alp/compute/mod.rs @@ -4,6 +4,7 @@ use vortex_array::compute::{ filter, scalar_at, slice, take, CompareFn, ComputeVTable, FilterFn, FilterMask, ScalarAtFn, SliceFn, TakeFn, TakeOptions, }; +use vortex_array::validity::ArrayValidity; use vortex_array::variants::PrimitiveArrayTrait; use vortex_array::{ArrayDType, ArrayData, IntoArrayData}; use vortex_error::VortexResult; @@ -36,7 +37,7 @@ impl ComputeVTable for ALPEncoding { impl ScalarAtFn for ALPEncoding { fn scalar_at(&self, array: &ALPArray, index: usize) -> VortexResult { if let Some(patches) = array.patches() { - if patches.with_dyn(|a| a.is_valid(index)) { + if patches.is_valid(index) { // We need to make sure the value is actually in the patches array return scalar_at(&patches, index); } diff --git a/encodings/alp/src/alp_rd/compute/scalar_at.rs b/encodings/alp/src/alp_rd/compute/scalar_at.rs index 72f6c5511c..c3932c9d6d 100644 --- a/encodings/alp/src/alp_rd/compute/scalar_at.rs +++ b/encodings/alp/src/alp_rd/compute/scalar_at.rs @@ -1,4 +1,5 @@ use vortex_array::compute::{scalar_at, ScalarAtFn}; +use vortex_array::validity::ArrayValidity; use vortex_error::VortexResult; use vortex_scalar::Scalar; @@ -10,7 +11,7 @@ impl ScalarAtFn for ALPRDEncoding { // The left value can either be a direct value, or an exception. // The exceptions array represents exception positions with non-null values. let left: u16 = match array.left_parts_exceptions() { - Some(exceptions) if exceptions.with_dyn(|a| a.is_valid(index)) => { + Some(exceptions) if exceptions.is_valid(index) => { scalar_at(&exceptions, index)?.try_into()? } _ => { diff --git a/encodings/datetime-parts/src/array.rs b/encodings/datetime-parts/src/array.rs index c0c628a57f..1c24369b25 100644 --- a/encodings/datetime-parts/src/array.rs +++ b/encodings/datetime-parts/src/array.rs @@ -5,7 +5,7 @@ use vortex_array::array::StructArray; use vortex_array::compute::try_cast; use vortex_array::encoding::ids; use vortex_array::stats::{Stat, StatisticsVTable, StatsSet}; -use vortex_array::validity::{LogicalValidity, Validity, ValidityVTable}; +use vortex_array::validity::{ArrayValidity, LogicalValidity, Validity, ValidityVTable}; use vortex_array::variants::{ArrayVariants, ExtensionArrayTrait}; use vortex_array::visitor::{ArrayVisitor, VisitorVTable}; use vortex_array::{ @@ -105,9 +105,7 @@ impl DateTimePartsArray { pub fn validity(&self) -> Validity { if self.dtype().is_nullable() { - self.days() - .with_dyn(|a| a.logical_validity()) - .into_validity() + self.days().logical_validity().into_validity() } else { Validity::NonNullable } diff --git a/encodings/dict/src/array.rs b/encodings/dict/src/array.rs index 0b5064d4d6..c8f3c6d024 100644 --- a/encodings/dict/src/array.rs +++ b/encodings/dict/src/array.rs @@ -6,7 +6,7 @@ use vortex_array::array::BoolArray; use vortex_array::compute::{scalar_at, take, TakeOptions}; use vortex_array::encoding::ids; use vortex_array::stats::StatsSet; -use vortex_array::validity::{LogicalValidity, ValidityVTable}; +use vortex_array::validity::{ArrayValidity, LogicalValidity, ValidityVTable}; use vortex_array::variants::PrimitiveArrayTrait; use vortex_array::visitor::{ArrayVisitor, VisitorVTable}; use vortex_array::{ @@ -91,7 +91,7 @@ impl ValidityVTable for DictEncoding { .as_ref() .try_into() .vortex_expect("Failed to convert dictionary code to usize"); - array.values().with_dyn(|a| a.is_valid(values_index)) + array.values().is_valid(values_index) } fn logical_validity(&self, array: &DictArray) -> LogicalValidity { diff --git a/encodings/fastlanes/src/bitpacking/compute/scalar_at.rs b/encodings/fastlanes/src/bitpacking/compute/scalar_at.rs index 4ddd200ff3..93695a96c4 100644 --- a/encodings/fastlanes/src/bitpacking/compute/scalar_at.rs +++ b/encodings/fastlanes/src/bitpacking/compute/scalar_at.rs @@ -1,4 +1,5 @@ use vortex_array::compute::{scalar_at, ScalarAtFn}; +use vortex_array::validity::ArrayValidity; use vortex_array::ArrayDType; use vortex_error::VortexResult; use vortex_scalar::Scalar; @@ -9,7 +10,7 @@ impl ScalarAtFn for BitPackedEncoding { fn scalar_at(&self, array: &BitPackedArray, index: usize) -> VortexResult { if let Some(patches) = array.patches() { // NB: All non-null values are considered patches - if patches.with_dyn(|a| a.is_valid(index)) { + if patches.is_valid(index) { return scalar_at(&patches, index)?.cast(array.dtype()); } } diff --git a/vortex-array/src/array/chunked/mod.rs b/vortex-array/src/array/chunked/mod.rs index 142ea8fc46..146d4726db 100644 --- a/vortex-array/src/array/chunked/mod.rs +++ b/vortex-array/src/array/chunked/mod.rs @@ -20,7 +20,7 @@ use crate::iter::{ArrayIterator, ArrayIteratorAdapter}; use crate::stats::ArrayStatistics; use crate::stream::{ArrayStream, ArrayStreamAdapter}; use crate::validity::Validity::NonNullable; -use crate::validity::{LogicalValidity, Validity, ValidityVTable}; +use crate::validity::{ArrayValidity, LogicalValidity, Validity, ValidityVTable}; use crate::visitor::{ArrayVisitor, VisitorVTable}; use crate::{ impl_encoding, ArrayDType, ArrayData, ArrayLen, ArrayTrait, IntoArrayData, IntoCanonical, @@ -230,13 +230,13 @@ impl ValidityVTable for ChunkedEncoding { .unwrap_or_else(|e| { vortex_panic!(e, "ChunkedArray: is_valid failed to find chunk {}", index) }) - .with_dyn(|a| a.is_valid(offset_in_chunk)) + .is_valid(offset_in_chunk) } fn logical_validity(&self, array: &ChunkedArray) -> LogicalValidity { let validity = array .chunks() - .map(|a| a.with_dyn(|arr| arr.logical_validity())) + .map(|a| a.logical_validity()) .collect::(); validity.to_logical(array.len()) } diff --git a/vortex-array/src/array/sparse/mod.rs b/vortex-array/src/array/sparse/mod.rs index df6bb552d5..0c6f4db45d 100644 --- a/vortex-array/src/array/sparse/mod.rs +++ b/vortex-array/src/array/sparse/mod.rs @@ -9,7 +9,7 @@ use crate::array::constant::ConstantArray; use crate::compute::{scalar_at, search_sorted, SearchResult, SearchSortedSide}; use crate::encoding::ids; use crate::stats::{ArrayStatistics, Stat, StatisticsVTable, StatsSet}; -use crate::validity::{LogicalValidity, ValidityVTable}; +use crate::validity::{ArrayValidity, LogicalValidity, ValidityVTable}; use crate::variants::PrimitiveArrayTrait; use crate::visitor::{ArrayVisitor, VisitorVTable}; use crate::{ @@ -213,7 +213,7 @@ impl ValidityVTable for SparseEncoding { fn is_valid(&self, array: &SparseArray, index: usize) -> bool { match array.search_index(index).map(SearchResult::to_found) { Ok(None) => !array.fill_scalar().is_null(), - Ok(Some(idx)) => array.values().with_dyn(|a| a.is_valid(idx)), + Ok(Some(idx)) => array.values().is_valid(idx), Err(e) => vortex_panic!(e, "Error while finding index {} in sparse array", index), } } @@ -234,9 +234,7 @@ impl ValidityVTable for SparseEncoding { // existing values. SparseArray::try_new_with_offset( array.indices(), - array - .values() - .with_dyn(|a| a.logical_validity().into_array()), + array.values().logical_validity().into_array(), array.len(), array.indices_offset(), true.into(), @@ -257,6 +255,7 @@ mod test { use crate::array::sparse::SparseArray; use crate::compute::{scalar_at, slice, try_cast}; + use crate::validity::ArrayValidity; use crate::{ArrayData, IntoArrayData, IntoArrayVariant}; fn nullable_fill() -> Scalar { @@ -348,11 +347,7 @@ mod test { #[test] pub fn sparse_logical_validity() { let array = sparse_array(nullable_fill()); - let validity = array - .with_dyn(|a| a.logical_validity()) - .into_array() - .into_bool() - .unwrap(); + let validity = array.logical_validity().into_array().into_bool().unwrap(); assert_eq!( validity.boolean_buffer().iter().collect_vec(), [false, false, true, false, false, true, false, false, true, false] @@ -365,7 +360,7 @@ mod test { assert_eq!( array - .with_dyn(|a| a.logical_validity()) + .logical_validity() .into_array() .into_bool() .unwrap() diff --git a/vortex-array/src/compress.rs b/vortex-array/src/compress.rs index afcb64edfc..7498486644 100644 --- a/vortex-array/src/compress.rs +++ b/vortex-array/src/compress.rs @@ -17,8 +17,10 @@ pub fn check_validity_unchanged(arr: &ArrayData, compressed: &ArrayData) { let _ = compressed; #[cfg(debug_assertions)] { - let old_validity = arr.with_dyn(|a| a.logical_validity().len()); - let new_validity = compressed.with_dyn(|a| a.logical_validity().len()); + use crate::validity::ArrayValidity; + + let old_validity = arr.logical_validity().len(); + let new_validity = compressed.logical_validity().len(); debug_assert!( old_validity == new_validity, diff --git a/vortex-array/src/compute/scalar_at.rs b/vortex-array/src/compute/scalar_at.rs index d0a4ceaf01..14b78e23f1 100644 --- a/vortex-array/src/compute/scalar_at.rs +++ b/vortex-array/src/compute/scalar_at.rs @@ -2,6 +2,7 @@ use vortex_error::{vortex_bail, vortex_err, VortexError, VortexResult}; use vortex_scalar::Scalar; use crate::encoding::Encoding; +use crate::validity::ArrayValidity; use crate::{ArrayDType, ArrayData}; /// Implementation of scalar_at for an encoding. @@ -33,7 +34,7 @@ pub fn scalar_at(array: impl AsRef, index: usize) -> VortexResult EncodingRef { match &self.0 { - InnerArrayData::Owned(d) => d.encoding(), - InnerArrayData::Viewed(v) => v.encoding(), + InnerArrayData::Owned(d) => d.encoding, + InnerArrayData::Viewed(v) => v.encoding, } } @@ -138,27 +138,14 @@ impl ArrayData { #[allow(clippy::same_name_method)] pub fn len(&self) -> usize { match &self.0 { - InnerArrayData::Owned(d) => d.len(), - InnerArrayData::Viewed(v) => v.len(), + InnerArrayData::Owned(d) => d.len, + InnerArrayData::Viewed(v) => v.len, } } /// Check whether the array has any data pub fn is_empty(&self) -> bool { - match &self.0 { - InnerArrayData::Owned(d) => d.is_empty(), - InnerArrayData::Viewed(v) => v.is_empty(), - } - } - - /// Return whether the element at the given index is valid (true) or null (false). - fn is_valid(&self, index: usize) -> bool { - self.encoding().is_valid(self, index) - } - - /// Return the logical validity of the array. - fn logical_validity(&self) -> LogicalValidity { - self.encoding().logical_validity(self) + self.len() == 0 } /// Whether the array is of a canonical encoding. @@ -207,7 +194,7 @@ impl ArrayData { /// Returns a Vec of Arrays with all the array's child arrays. pub fn children(&self) -> Vec { match &self.0 { - InnerArrayData::Owned(d) => d.children().iter().cloned().collect_vec(), + InnerArrayData::Owned(d) => d.children().to_vec(), InnerArrayData::Viewed(v) => v.children(), } } @@ -395,8 +382,8 @@ impl Display for ArrayData { impl> ArrayDType for T { fn dtype(&self) -> &DType { match &self.as_ref().0 { - InnerArrayData::Owned(d) => d.dtype(), - InnerArrayData::Viewed(v) => v.dtype(), + InnerArrayData::Owned(d) => &d.dtype, + InnerArrayData::Viewed(v) => &v.dtype, } } } @@ -412,20 +399,22 @@ impl> ArrayLen for T { } impl> ArrayValidity for A { + /// Return whether the element at the given index is valid (true) or null (false). fn is_valid(&self, index: usize) -> bool { - self.as_ref().is_valid(index) + ValidityVTable::::is_valid(self.as_ref().encoding(), self.as_ref(), index) } + /// Return the logical validity of the array. fn logical_validity(&self) -> LogicalValidity { - self.as_ref().logical_validity() + ValidityVTable::::logical_validity(self.as_ref().encoding(), self.as_ref()) } } impl> ArrayStatistics for T { fn statistics(&self) -> &(dyn Statistics + '_) { match &self.as_ref().0 { - InnerArrayData::Owned(d) => d.statistics(), - InnerArrayData::Viewed(v) => v.statistics(), + InnerArrayData::Owned(d) => d, + InnerArrayData::Viewed(v) => v, } } diff --git a/vortex-array/src/data/owned.rs b/vortex-array/src/data/owned.rs index db440fcc81..b6377a6fa2 100644 --- a/vortex-array/src/data/owned.rs +++ b/vortex-array/src/data/owned.rs @@ -22,22 +22,6 @@ pub(super) struct OwnedArrayData { } impl OwnedArrayData { - pub fn encoding(&self) -> EncodingRef { - self.encoding - } - - pub fn dtype(&self) -> &DType { - &self.dtype - } - - pub fn len(&self) -> usize { - self.len - } - - pub fn is_empty(&self) -> bool { - self.len == 0 - } - pub fn metadata(&self) -> &Arc { &self.metadata } @@ -63,7 +47,7 @@ impl OwnedArrayData { child.dtype(), dtype, "child {index} requested with incorrect dtype for encoding {}", - self.encoding().id().as_ref(), + self.encoding.id().as_ref(), ); assert_eq!( child.len(), @@ -80,13 +64,9 @@ impl OwnedArrayData { self.children.len() } - pub fn children(&self) -> &[ArrayData] { + pub fn children(&self) -> &Arc<[ArrayData]> { &self.children } - - pub fn statistics(&self) -> &dyn Statistics { - self - } } impl Statistics for OwnedArrayData { @@ -136,7 +116,7 @@ impl Statistics for OwnedArrayData { } let computed = self - .encoding() + .encoding .compute_statistics(&ArrayData::from(self.clone()), stat) .ok()?; diff --git a/vortex-array/src/data/viewed.rs b/vortex-array/src/data/viewed.rs index 16b69302bf..69c6bfaaaa 100644 --- a/vortex-array/src/data/viewed.rs +++ b/vortex-array/src/data/viewed.rs @@ -46,22 +46,6 @@ impl ViewedArrayData { } } - pub fn encoding(&self) -> EncodingRef { - self.encoding - } - - pub fn dtype(&self) -> &DType { - &self.dtype - } - - pub fn len(&self) -> usize { - self.len - } - - pub fn is_empty(&self) -> bool { - self.len == 0 - } - pub fn metadata_bytes(&self) -> Option<&[u8]> { self.flatbuffer().metadata().map(|m| m.bytes()) } @@ -109,7 +93,7 @@ impl ViewedArrayData { pub fn children(&self) -> Vec { let mut collector = ChildrenCollector::default(); - self.encoding() + self.encoding .accept(&ArrayData::from(self.clone()), &mut collector) .vortex_expect("Failed to get children"); collector.children @@ -120,10 +104,6 @@ impl ViewedArrayData { .buffer_index() .map(|idx| &self.buffers[idx as usize]) } - - pub fn statistics(&self) -> &dyn Statistics { - self - } } #[derive(Default, Debug)] @@ -217,7 +197,7 @@ impl Statistics for ViewedArrayData { return Some(s); } - self.encoding() + self.encoding .compute_statistics(&ArrayData::from(self.clone()), stat) .ok()? .get(stat) diff --git a/vortex-array/src/nbytes.rs b/vortex-array/src/nbytes.rs index 57b3dab4ce..fd4f861061 100644 --- a/vortex-array/src/nbytes.rs +++ b/vortex-array/src/nbytes.rs @@ -32,7 +32,7 @@ struct NBytesVisitor(usize); impl ArrayVisitor for NBytesVisitor { fn visit_child(&mut self, _name: &str, array: &ArrayData) -> VortexResult<()> { - self.0 += array.with_dyn(|a| a.nbytes()); + self.0 += array.nbytes(); Ok(()) } diff --git a/vortex-array/src/tree.rs b/vortex-array/src/tree.rs index 6fcab34f89..0f67cc423d 100644 --- a/vortex-array/src/tree.rs +++ b/vortex-array/src/tree.rs @@ -43,39 +43,37 @@ pub struct TreeFormatter<'a, 'b: 'a> { /// control over how their metadata etc is displayed. impl<'a, 'b: 'a> ArrayVisitor for TreeFormatter<'a, 'b> { fn visit_child(&mut self, name: &str, array: &ArrayData) -> VortexResult<()> { - array.with_dyn(|a| { - let nbytes = a.nbytes(); - let total_size = self.total_size.unwrap_or(nbytes); - writeln!( - self.fmt, - "{}{}: {} nbytes={} ({:.2}%)", - self.indent, - name, - array, - format_size(nbytes, DECIMAL), - 100f64 * nbytes as f64 / total_size as f64 - )?; - self.indent(|i| writeln!(i.fmt, "{}metadata: {}", i.indent, array.array_metadata()))?; - - let old_total_size = self.total_size; - if array.is_encoding(ChunkedEncoding.id()) { - // Clear the total size so each chunk is treated as a new root. - self.total_size = None - } else { - self.total_size = Some(total_size); - } + let nbytes = array.nbytes(); + let total_size = self.total_size.unwrap_or(nbytes); + writeln!( + self.fmt, + "{}{}: {} nbytes={} ({:.2}%)", + self.indent, + name, + array, + format_size(nbytes, DECIMAL), + 100f64 * nbytes as f64 / total_size as f64 + )?; + self.indent(|i| writeln!(i.fmt, "{}metadata: {}", i.indent, array.array_metadata()))?; - self.indent(|i| { - array - .encoding() - .accept(array, i) - .map_err(fmt::Error::custom) - }) - .map_err(VortexError::from)?; + let old_total_size = self.total_size; + if array.is_encoding(ChunkedEncoding.id()) { + // Clear the total size so each chunk is treated as a new root. + self.total_size = None + } else { + self.total_size = Some(total_size); + } - self.total_size = old_total_size; - Ok(()) + self.indent(|i| { + array + .encoding() + .accept(array, i) + .map_err(fmt::Error::custom) }) + .map_err(VortexError::from)?; + + self.total_size = old_total_size; + Ok(()) } fn visit_buffer(&mut self, buffer: &Buffer) -> VortexResult<()> { diff --git a/vortex-file/src/read/mask.rs b/vortex-file/src/read/mask.rs index 594df9aefd..3a7bc918df 100644 --- a/vortex-file/src/read/mask.rs +++ b/vortex-file/src/read/mask.rs @@ -5,7 +5,7 @@ use arrow_buffer::BooleanBuffer; use vortex_array::array::{BoolArray, PrimitiveArray, SparseArray}; use vortex_array::compute::{and, filter, slice, take, try_cast, FilterMask, TakeOptions}; use vortex_array::stats::ArrayStatistics; -use vortex_array::validity::LogicalValidity; +use vortex_array::validity::{ArrayValidity, LogicalValidity}; use vortex_array::{ArrayDType, ArrayData, IntoArrayData, IntoArrayVariant}; use vortex_dtype::Nullability::NonNullable; use vortex_dtype::{DType, PType}; @@ -110,7 +110,7 @@ impl RowMask { /// /// True-valued positions are kept by the returned mask. pub fn from_mask_array(array: &ArrayData, begin: usize, end: usize) -> VortexResult { - match array.with_dyn(|a| a.logical_validity()) { + match array.logical_validity() { LogicalValidity::AllValid(_) => Self::try_new(array.clone(), begin, end), LogicalValidity::AllInvalid(_) => Ok(Self::new_invalid_between(begin, end)), LogicalValidity::Array(validity) => {