From fc929b4e04dc6f88812f9c30023b4ccbab17b628 Mon Sep 17 00:00:00 2001 From: Andrew Duffy Date: Wed, 12 Jun 2024 15:24:37 +0100 Subject: [PATCH] NullArray + statsset cleanup (#350) Add first-class NullArray that maps back/forth with Arrow. Cleans up StatsSet with all-nulls stats set a bit --- vortex-array/src/array/bool/stats.rs | 19 +--- vortex-array/src/array/chunked/flatten.rs | 7 +- vortex-array/src/array/constant/as_arrow.rs | 50 ---------- vortex-array/src/array/constant/mod.rs | 1 - vortex-array/src/array/mod.rs | 1 + vortex-array/src/array/null/as_arrow.rs | 48 ++++++++++ vortex-array/src/array/null/compute.rs | 101 ++++++++++++++++++++ vortex-array/src/array/null/mod.rs | 68 +++++++++++++ vortex-array/src/array/primitive/stats.rs | 29 +----- vortex-array/src/array/varbin/stats.rs | 15 +-- vortex-array/src/arrow/array.rs | 5 +- vortex-array/src/flatten.rs | 4 +- vortex-array/src/stats/statsset.rs | 36 +++++++ 13 files changed, 267 insertions(+), 117 deletions(-) delete mode 100644 vortex-array/src/array/constant/as_arrow.rs create mode 100644 vortex-array/src/array/null/as_arrow.rs create mode 100644 vortex-array/src/array/null/compute.rs create mode 100644 vortex-array/src/array/null/mod.rs diff --git a/vortex-array/src/array/bool/stats.rs b/vortex-array/src/array/bool/stats.rs index 712777da0e..b983e14f3e 100644 --- a/vortex-array/src/array/bool/stats.rs +++ b/vortex-array/src/array/bool/stats.rs @@ -1,14 +1,12 @@ use std::collections::HashMap; use arrow_buffer::BooleanBuffer; -use vortex_dtype::{DType, Nullability}; use vortex_error::VortexResult; -use vortex_scalar::Scalar; use crate::array::bool::BoolArray; use crate::stats::{ArrayStatisticsCompute, Stat, StatsSet}; use crate::validity::{ArrayValidity, LogicalValidity}; -use crate::{ArrayTrait, IntoArray}; +use crate::{ArrayDType, ArrayTrait, IntoArray}; impl ArrayStatisticsCompute for BoolArray { fn compute_statistics(&self, stat: Stat) -> VortexResult { @@ -18,7 +16,7 @@ impl ArrayStatisticsCompute 
for BoolArray { match self.logical_validity() { LogicalValidity::AllValid(_) => self.boolean_buffer().compute_statistics(stat), - LogicalValidity::AllInvalid(v) => all_null_stats(v), + LogicalValidity::AllInvalid(v) => Ok(StatsSet::nulls(v, self.dtype())), LogicalValidity::Array(a) => NullableBools( &self.boolean_buffer(), &a.into_array().flatten_bool()?.boolean_buffer(), @@ -28,19 +26,6 @@ impl ArrayStatisticsCompute for BoolArray { } } -fn all_null_stats(len: usize) -> VortexResult { - Ok(StatsSet::from(HashMap::from([ - (Stat::Min, Scalar::null(DType::Bool(Nullability::Nullable))), - (Stat::Max, Scalar::null(DType::Bool(Nullability::Nullable))), - (Stat::IsConstant, true.into()), - (Stat::IsSorted, true.into()), - (Stat::IsStrictSorted, (len < 2).into()), - (Stat::RunCount, 1.into()), - (Stat::NullCount, len.into()), - (Stat::TrueCount, 0.into()), - ]))) -} - struct NullableBools<'a>(&'a BooleanBuffer, &'a BooleanBuffer); impl ArrayStatisticsCompute for NullableBools<'_> { diff --git a/vortex-array/src/array/chunked/flatten.rs b/vortex-array/src/array/chunked/flatten.rs index c70b5083d6..7a374ae0a7 100644 --- a/vortex-array/src/array/chunked/flatten.rs +++ b/vortex-array/src/array/chunked/flatten.rs @@ -2,13 +2,12 @@ use arrow_buffer::{BooleanBuffer, MutableBuffer, ScalarBuffer}; use itertools::Itertools; use vortex_dtype::{match_each_native_ptype, DType, Nullability, PType, StructDType}; use vortex_error::{vortex_bail, ErrString, VortexResult}; -use vortex_scalar::Scalar; use crate::accessor::ArrayAccessor; use crate::array::bool::BoolArray; use crate::array::chunked::ChunkedArray; -use crate::array::constant::ConstantArray; use crate::array::extension::ExtensionArray; +use crate::array::null::NullArray; use crate::array::primitive::PrimitiveArray; use crate::array::r#struct::StructArray; use crate::array::varbin::builder::VarBinBuilder; @@ -73,8 +72,8 @@ pub(crate) fn try_flatten_chunks(chunks: Vec, dtype: DType) -> VortexResu } DType::Null => { let len = 
chunks.iter().map(|chunk| chunk.len()).sum(); - let const_array = ConstantArray::new(Scalar::null(DType::Null), len); - Ok(Flattened::Null(const_array)) + let null_array = NullArray::new(len); + Ok(Flattened::Null(null_array)) } } } diff --git a/vortex-array/src/array/constant/as_arrow.rs b/vortex-array/src/array/constant/as_arrow.rs deleted file mode 100644 index f2e75cb0e0..0000000000 --- a/vortex-array/src/array/constant/as_arrow.rs +++ /dev/null @@ -1,50 +0,0 @@ -//! Implementation of the [AsArrowArray] trait for [ConstantArray] that is representing -//! [DType::Null] values. - -use std::sync::Arc; - -use arrow_array::{ArrayRef as ArrowArrayRef, NullArray}; -use vortex_dtype::DType; -use vortex_error::{vortex_bail, VortexResult}; - -use crate::array::constant::ConstantArray; -use crate::compute::as_arrow::AsArrowArray; -use crate::{ArrayDType, ArrayTrait}; - -impl AsArrowArray for ConstantArray { - fn as_arrow(&self) -> VortexResult { - if self.dtype() != &DType::Null { - vortex_bail!(InvalidArgument: "only null ConstantArrays convert to arrow"); - } - - let arrow_null = NullArray::new(self.len()); - Ok(Arc::new(arrow_null)) - } -} - -#[cfg(test)] -mod test { - use arrow_array::{Array, NullArray}; - - use crate::array::constant::ConstantArray; - use crate::arrow::FromArrowArray; - use crate::compute::as_arrow::AsArrowArray; - use crate::{ArrayData, IntoArray}; - - #[test] - fn test_round_trip() { - let arrow_nulls = NullArray::new(10); - let vortex_nulls = ArrayData::from_arrow(&arrow_nulls, true).into_array(); - - assert_eq!( - *ConstantArray::try_from(vortex_nulls) - .unwrap() - .as_arrow() - .unwrap() - .as_any() - .downcast_ref::() - .unwrap(), - arrow_nulls - ); - } -} diff --git a/vortex-array/src/array/constant/mod.rs b/vortex-array/src/array/constant/mod.rs index d3344fe6c9..0292398166 100644 --- a/vortex-array/src/array/constant/mod.rs +++ b/vortex-array/src/array/constant/mod.rs @@ -7,7 +7,6 @@ use crate::impl_encoding; use crate::stats::Stat; use 
crate::validity::{ArrayValidity, LogicalValidity}; use crate::visitor::{AcceptArrayVisitor, ArrayVisitor}; -mod as_arrow; mod compute; mod flatten; mod stats; diff --git a/vortex-array/src/array/mod.rs b/vortex-array/src/array/mod.rs index 90f61f24bf..7541113171 100644 --- a/vortex-array/src/array/mod.rs +++ b/vortex-array/src/array/mod.rs @@ -4,6 +4,7 @@ pub mod chunked; pub mod constant; pub mod datetime; pub mod extension; +pub mod null; pub mod primitive; pub mod sparse; pub mod r#struct; diff --git a/vortex-array/src/array/null/as_arrow.rs b/vortex-array/src/array/null/as_arrow.rs new file mode 100644 index 0000000000..f00fdef8be --- /dev/null +++ b/vortex-array/src/array/null/as_arrow.rs @@ -0,0 +1,48 @@ +//! Implementation of the [AsArrowArray] trait for [NullArray], which represents +//! [DType::Null] values. + +use std::sync::Arc; + +use arrow_array::{ArrayRef as ArrowArrayRef, NullArray as ArrowNullArray}; +use vortex_error::VortexResult; + +use crate::array::null::NullArray; +use crate::compute::as_arrow::AsArrowArray; +use crate::ArrayTrait; + +impl AsArrowArray for NullArray { + fn as_arrow(&self) -> VortexResult<ArrowArrayRef> { + let arrow_null = ArrowNullArray::new(self.len()); + Ok(Arc::new(arrow_null)) + } +} + +#[cfg(test)] +mod test { + use arrow_array::{Array, NullArray as ArrowNullArray}; + + use crate::array::null::NullArray; + use crate::arrow::FromArrowArray; + use crate::compute::as_arrow::AsArrowArray; + use crate::validity::{ArrayValidity, LogicalValidity}; + use crate::{ArrayData, ArrayTrait, IntoArray}; + + #[test] + fn test_round_trip() { + let arrow_nulls = ArrowNullArray::new(10); + let vortex_nulls = ArrayData::from_arrow(&arrow_nulls, true).into_array(); + + let vortex_nulls = NullArray::try_from(vortex_nulls).unwrap(); + assert_eq!(vortex_nulls.len(), 10); + assert!(matches!( + vortex_nulls.logical_validity(), + LogicalValidity::AllInvalid(10) + )); + + let to_arrow = vortex_nulls.as_arrow().unwrap(); + assert_eq!( + 
*to_arrow.as_any().downcast_ref::<ArrowNullArray>().unwrap(), + arrow_nulls + ); + } +} diff --git a/vortex-array/src/array/null/compute.rs b/vortex-array/src/array/null/compute.rs new file mode 100644 index 0000000000..c627f29b85 --- /dev/null +++ b/vortex-array/src/array/null/compute.rs @@ -0,0 +1,103 @@ +use vortex_dtype::{match_each_integer_ptype, DType}; +use vortex_error::VortexResult; +use vortex_scalar::Scalar; + +use crate::array::null::NullArray; +use crate::compute::scalar_at::ScalarAtFn; +use crate::compute::slice::SliceFn; +use crate::compute::take::TakeFn; +use crate::compute::ArrayCompute; +use crate::{Array, ArrayTrait, IntoArray}; + +impl ArrayCompute for NullArray { + fn scalar_at(&self) -> Option<&dyn ScalarAtFn> { + Some(self) + } + + fn slice(&self) -> Option<&dyn SliceFn> { + Some(self) + } + + fn take(&self) -> Option<&dyn TakeFn> { + Some(self) + } +} + +impl SliceFn for NullArray { + fn slice(&self, start: usize, stop: usize) -> VortexResult<Array> { + // `stop` is exclusive, so `stop == self.len()` (slicing to the end) is valid. + assert!(start <= stop, "slice start cannot be greater than stop"); + assert!(stop <= self.len(), "cannot slice past end of the array"); + Ok(NullArray::new(stop - start).into_array()) + } +} + +impl ScalarAtFn for NullArray { + fn scalar_at(&self, index: usize) -> VortexResult<Scalar> { + assert!(index < self.len(), "cannot index past end of the array"); + + Ok(Scalar::null(DType::Null)) + } +} + +impl TakeFn for NullArray { + fn take(&self, indices: &Array) -> VortexResult<Array> { + let indices = indices.clone().flatten_primitive()?; + + // Enforce all indices are valid + match_each_integer_ptype!(indices.ptype(), |$T| { + for index in indices.scalar_buffer::<$T>().iter() { + assert!((*index as usize) < self.len(), "cannot take past end of the array"); + } + }); + + Ok(NullArray::new(indices.len()).into_array()) + } +} + +#[cfg(test)] +mod test { + use vortex_dtype::DType; + + use crate::array::null::NullArray; + use crate::compute::scalar_at::scalar_at; + use crate::compute::slice::slice; + use crate::compute::take::take; + use crate::validity::{ArrayValidity, LogicalValidity}; + use 
crate::{ArrayTrait, IntoArray}; + + #[test] + fn test_slice_nulls() { + let nulls = NullArray::new(10).into_array(); + let sliced = NullArray::try_from(slice(&nulls, 0, 4).unwrap()).unwrap(); + + assert_eq!(sliced.len(), 4); + assert!(matches!( + sliced.logical_validity(), + LogicalValidity::AllInvalid(4) + )); + } + + #[test] + fn test_take_nulls() { + let nulls = NullArray::new(10).into_array(); + let taken = + NullArray::try_from(take(&nulls, &vec![0u64, 2, 4, 6, 8].into_array()).unwrap()) + .unwrap(); + + assert_eq!(taken.len(), 5); + assert!(matches!( + taken.logical_validity(), + LogicalValidity::AllInvalid(5) + )); + } + + #[test] + fn test_scalar_at_nulls() { + let nulls = NullArray::new(10); + + let scalar = scalar_at(&nulls.into_array(), 0).unwrap(); + assert!(scalar.is_null()); + assert_eq!(scalar.dtype().clone(), DType::Null); + } +} diff --git a/vortex-array/src/array/null/mod.rs b/vortex-array/src/array/null/mod.rs new file mode 100644 index 0000000000..7eae7c4a2a --- /dev/null +++ b/vortex-array/src/array/null/mod.rs @@ -0,0 +1,68 @@ +use serde::{Deserialize, Serialize}; + +use crate::stats::{ArrayStatisticsCompute, Stat}; +use crate::validity::{ArrayValidity, LogicalValidity, Validity}; +use crate::visitor::{AcceptArrayVisitor, ArrayVisitor}; +use crate::{impl_encoding, ArrayFlatten}; + +mod as_arrow; +mod compute; + +impl_encoding!("vortex.null", Null); + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct NullMetadata { + len: usize, +} + +impl NullArray { + pub fn new(len: usize) -> Self { + Self::try_from_parts( + DType::Null, + NullMetadata { len }, + Arc::new([]), + StatsSet::nulls(len, &DType::Null), + ) + .expect("NullArray::new cannot fail") + } +} + +impl ArrayFlatten for NullArray { + fn flatten(self) -> VortexResult { + Ok(Flattened::Null(self)) + } +} + +impl ArrayValidity for NullArray { + fn is_valid(&self, _: usize) -> bool { + false + } + + fn logical_validity(&self) -> LogicalValidity { + 
LogicalValidity::AllInvalid(self.len()) + } +} + +impl ArrayStatisticsCompute for NullArray { + fn compute_statistics(&self, _stat: Stat) -> VortexResult { + Ok(StatsSet::nulls(self.len(), &DType::Null)) + } +} + +impl AcceptArrayVisitor for NullArray { + fn accept(&self, visitor: &mut dyn ArrayVisitor) -> VortexResult<()> { + visitor.visit_validity(&Validity::AllInvalid) + } +} + +impl ArrayTrait for NullArray { + fn len(&self) -> usize { + self.metadata().len + } + + fn nbytes(&self) -> usize { + 0 + } +} + +impl EncodingCompression for NullEncoding {} diff --git a/vortex-array/src/array/primitive/stats.rs b/vortex-array/src/array/primitive/stats.rs index 9e573bbfa7..247f709d01 100644 --- a/vortex-array/src/array/primitive/stats.rs +++ b/vortex-array/src/array/primitive/stats.rs @@ -4,8 +4,7 @@ use std::mem::size_of; use arrow_buffer::buffer::BooleanBuffer; use num_traits::PrimInt; use vortex_dtype::half::f16; -use vortex_dtype::Nullability::Nullable; -use vortex_dtype::{match_each_native_ptype, DType, NativePType}; +use vortex_dtype::{match_each_native_ptype, NativePType}; use vortex_error::VortexResult; use vortex_scalar::Scalar; @@ -13,6 +12,7 @@ use crate::array::primitive::PrimitiveArray; use crate::stats::{ArrayStatisticsCompute, Stat, StatsSet}; use crate::validity::ArrayValidity; use crate::validity::LogicalValidity; +use crate::ArrayDType; use crate::IntoArray; trait PStatsType: NativePType + Into + BitWidth {} @@ -24,7 +24,7 @@ impl ArrayStatisticsCompute for PrimitiveArray { match_each_native_ptype!(self.ptype(), |$P| { match self.logical_validity() { LogicalValidity::AllValid(_) => self.typed_data::<$P>().compute_statistics(stat), - LogicalValidity::AllInvalid(v) => all_null_stats::<$P>(v), + LogicalValidity::AllInvalid(v) => Ok(StatsSet::nulls(v, self.dtype())), LogicalValidity::Array(a) => NullableValues( self.typed_data::<$P>(), &a.into_array().flatten_bool()?.boolean_buffer(), @@ -46,29 +46,6 @@ impl ArrayStatisticsCompute for &[T] { } } -fn 
all_null_stats(len: usize) -> VortexResult { - Ok(StatsSet::from(HashMap::from([ - ( - Stat::Min, - Scalar::null(DType::Primitive(T::PTYPE, Nullable)), - ), - ( - Stat::Max, - Scalar::null(DType::Primitive(T::PTYPE, Nullable)), - ), - (Stat::IsConstant, true.into()), - (Stat::IsSorted, true.into()), - (Stat::IsStrictSorted, (len < 2).into()), - (Stat::RunCount, 1.into()), - (Stat::NullCount, len.into()), - (Stat::BitWidthFreq, vec![0; size_of::() * 8 + 1].into()), - ( - Stat::TrailingZeroFreq, - vec![size_of::() * 8; size_of::() * 8 + 1].into(), - ), - ]))) -} - struct NullableValues<'a, T: PStatsType>(&'a [T], &'a BooleanBuffer); impl<'a, T: PStatsType> ArrayStatisticsCompute for NullableValues<'a, T> { diff --git a/vortex-array/src/array/varbin/stats.rs b/vortex-array/src/array/varbin/stats.rs index da67d324f0..5a8ff73b15 100644 --- a/vortex-array/src/array/varbin/stats.rs +++ b/vortex-array/src/array/varbin/stats.rs @@ -3,7 +3,6 @@ use std::collections::HashMap; use vortex_dtype::DType; use vortex_error::VortexResult; -use vortex_scalar::Scalar; use crate::accessor::ArrayAccessor; use crate::array::varbin::{varbin_scalar, VarBinArray}; @@ -37,22 +36,10 @@ pub fn compute_stats(iter: &mut dyn Iterator>, dtype: &DTyp acc.n_nulls(leading_nulls); acc.finish(dtype) } else { - all_null_stats(leading_nulls, dtype) + StatsSet::nulls(leading_nulls, dtype) } } -fn all_null_stats(len: usize, dtype: &DType) -> StatsSet { - StatsSet::from(HashMap::from([ - (Stat::Min, Scalar::null(dtype.clone())), - (Stat::Max, Scalar::null(dtype.clone())), - (Stat::IsConstant, true.into()), - (Stat::IsSorted, true.into()), - (Stat::IsStrictSorted, (len < 2).into()), - (Stat::RunCount, 1.into()), - (Stat::NullCount, len.into()), - ])) -} - pub struct VarBinAccumulator<'a> { min: &'a [u8], max: &'a [u8], diff --git a/vortex-array/src/arrow/array.rs b/vortex-array/src/arrow/array.rs index 912acf6644..470ac4682a 100644 --- a/vortex-array/src/arrow/array.rs +++ b/vortex-array/src/arrow/array.rs 
@@ -22,11 +22,10 @@ use arrow_schema::{DataType, TimeUnit}; use itertools::Itertools; use vortex_dtype::DType; use vortex_dtype::NativePType; -use vortex_scalar::Scalar; use crate::array::bool::BoolArray; -use crate::array::constant::ConstantArray; use crate::array::datetime::LocalDateTimeArray; +use crate::array::null::NullArray; use crate::array::primitive::PrimitiveArray; use crate::array::r#struct::StructArray; use crate::array::varbin::VarBinArray; @@ -195,7 +194,7 @@ impl FromArrowArray<&ArrowStructArray> for ArrayData { impl FromArrowArray<&ArrowNullArray> for ArrayData { fn from_arrow(value: &ArrowNullArray, nullable: bool) -> Self { assert!(nullable); - ConstantArray::new(Scalar::null(DType::Null), value.len()).into_array_data() + NullArray::new(value.len()).into_array_data() } } diff --git a/vortex-array/src/flatten.rs b/vortex-array/src/flatten.rs index da991117a7..b6d5ef7b1d 100644 --- a/vortex-array/src/flatten.rs +++ b/vortex-array/src/flatten.rs @@ -1,8 +1,8 @@ use vortex_error::VortexResult; use crate::array::bool::BoolArray; -use crate::array::constant::ConstantArray; use crate::array::extension::ExtensionArray; +use crate::array::null::NullArray; use crate::array::primitive::PrimitiveArray; use crate::array::r#struct::StructArray; use crate::array::varbin::VarBinArray; @@ -12,7 +12,7 @@ use crate::{Array, IntoArray}; /// The set of encodings that can be converted to Arrow with zero-copy. 
pub enum Flattened { - Null(ConstantArray), + Null(NullArray), Bool(BoolArray), Primitive(PrimitiveArray), Struct(StructArray), diff --git a/vortex-array/src/stats/statsset.rs b/vortex-array/src/stats/statsset.rs index 4cd84f80cc..cc5dec36d5 100644 --- a/vortex-array/src/stats/statsset.rs +++ b/vortex-array/src/stats/statsset.rs @@ -3,6 +3,7 @@ use std::collections::HashMap; use enum_iterator::all; use itertools::Itertools; +use vortex_dtype::DType; use vortex_error::VortexError; use vortex_scalar::Scalar; @@ -26,6 +27,41 @@ impl StatsSet { } } + /// Specialized constructor for the case where the StatsSet represents + /// an array of the given [DType] whose values are all null. + pub fn nulls(len: usize, dtype: &DType) -> Self { + let mut stats = HashMap::from([ + (Stat::Min, Scalar::null(dtype.clone())), + (Stat::Max, Scalar::null(dtype.clone())), + (Stat::IsConstant, true.into()), + (Stat::IsSorted, true.into()), + (Stat::IsStrictSorted, (len < 2).into()), + (Stat::RunCount, 1.into()), + (Stat::NullCount, len.into()), + ]); + + // Add any DType-specific stats. + match dtype { + DType::Bool(_) => { + stats.insert(Stat::TrueCount, 0.into()); + } + DType::Primitive(ptype, _) => { + // Both histograms have one bucket per bit count in 0..=(byte_width * 8). + stats.insert( + Stat::BitWidthFreq, + vec![0; ptype.byte_width() * 8 + 1].into(), + ); + stats.insert( + Stat::TrailingZeroFreq, + vec![ptype.byte_width() * 8; ptype.byte_width() * 8 + 1].into(), + ); + } + _ => {} + } + + Self::from(stats) + } + pub fn of(stat: Stat, value: Scalar) -> Self { Self::from(HashMap::from([(stat, value)])) }