diff --git a/bench-vortex/src/vortex_utils.rs b/bench-vortex/src/vortex_utils.rs index b490e73cd5..8e05d9baa3 100644 --- a/bench-vortex/src/vortex_utils.rs +++ b/bench-vortex/src/vortex_utils.rs @@ -4,6 +4,7 @@ use std::path::PathBuf; use vortex::array::chunked::ChunkedArray; use vortex::array::struct_::StructArray; +use vortex::variants::StructArrayTrait; use vortex::ArrayDType; use vortex_dtype::DType; use vortex_error::VortexResult; diff --git a/encodings/alp/src/array.rs b/encodings/alp/src/array.rs index c62ba337bf..29eef79ea8 100644 --- a/encodings/alp/src/array.rs +++ b/encodings/alp/src/array.rs @@ -2,6 +2,7 @@ use serde::{Deserialize, Serialize}; use vortex::array::primitive::PrimitiveArray; use vortex::stats::ArrayStatisticsCompute; use vortex::validity::{ArrayValidity, LogicalValidity}; +use vortex::variants::{ArrayVariants, PrimitiveArrayTrait}; use vortex::visitor::{AcceptArrayVisitor, ArrayVisitor}; use vortex::{impl_encoding, ArrayDType, Canonical, IntoCanonical}; use vortex_dtype::PType; @@ -92,6 +93,14 @@ impl ALPArray { impl ArrayTrait for ALPArray {} +impl ArrayVariants for ALPArray { + fn as_primitive_array(&self) -> Option<&dyn PrimitiveArrayTrait> { + Some(self) + } +} + +impl PrimitiveArrayTrait for ALPArray {} + impl ArrayValidity for ALPArray { fn is_valid(&self, index: usize) -> bool { self.encoded().with_dyn(|a| a.is_valid(index)) diff --git a/encodings/byte_bool/src/lib.rs b/encodings/byte_bool/src/lib.rs index 4eb3d6a7be..01a6296e79 100644 --- a/encodings/byte_bool/src/lib.rs +++ b/encodings/byte_bool/src/lib.rs @@ -3,6 +3,7 @@ use std::mem::ManuallyDrop; use arrow_buffer::BooleanBuffer; use serde::{Deserialize, Serialize}; use vortex::array::bool::BoolArray; +use vortex::variants::{ArrayVariants, BoolArrayTrait}; use vortex::{ impl_encoding, validity::{ArrayValidity, LogicalValidity, Validity, ValidityMetadata}, @@ -73,6 +74,14 @@ impl ByteBoolArray { impl ArrayTrait for ByteBoolArray {} +impl ArrayVariants for ByteBoolArray { + fn as_bool_array(&self) -> Option<&dyn BoolArrayTrait> { + Some(self) + } +} + +impl BoolArrayTrait for ByteBoolArray {} + impl From> for ByteBoolArray { fn from(value: Vec) -> Self { Self::try_from_vec(value, Validity::AllValid).unwrap() diff --git a/encodings/datetime-parts/src/array.rs b/encodings/datetime-parts/src/array.rs index 0dd60cf671..15556b995d 100644 --- a/encodings/datetime-parts/src/array.rs +++ b/encodings/datetime-parts/src/array.rs @@ -1,6 +1,7 @@ use serde::{Deserialize, Serialize}; use vortex::stats::ArrayStatisticsCompute; use vortex::validity::{ArrayValidity, LogicalValidity}; +use vortex::variants::{ArrayVariants, ExtensionArrayTrait}; use vortex::visitor::{AcceptArrayVisitor, ArrayVisitor}; use vortex::{impl_encoding, ArrayDType, Canonical, IntoCanonical}; use vortex_error::vortex_bail; @@ -79,6 +80,14 @@ impl DateTimePartsArray { impl ArrayTrait for DateTimePartsArray {} +impl ArrayVariants for DateTimePartsArray { + fn as_extension_array(&self) -> Option<&dyn ExtensionArrayTrait> { + Some(self) + } +} + +impl ExtensionArrayTrait for DateTimePartsArray {} + impl IntoCanonical for DateTimePartsArray { fn into_canonical(self) -> VortexResult { Ok(Canonical::Extension( diff --git a/encodings/dict/src/lib.rs b/encodings/dict/src/lib.rs index e2f9228cbf..fe41c79549 100644 --- a/encodings/dict/src/lib.rs +++ b/encodings/dict/src/lib.rs @@ -9,3 +9,4 @@ mod compress; mod compute; mod dict; mod stats; +mod variants; diff --git a/encodings/dict/src/variants.rs b/encodings/dict/src/variants.rs new file mode 100644 index 0000000000..843086b18f --- /dev/null +++ b/encodings/dict/src/variants.rs @@ -0,0 +1,37 @@ +use vortex::variants::{ArrayVariants, BinaryArrayTrait, PrimitiveArrayTrait, Utf8ArrayTrait}; +use vortex::ArrayDType; +use vortex_dtype::DType; + +use crate::DictArray; + +impl ArrayVariants for DictArray { + fn as_primitive_array(&self) -> Option<&dyn PrimitiveArrayTrait> { + if matches!(self.dtype(), DType::Primitive(..)) { + Some(self) + } else { + None + } + } + + fn as_utf8_array(&self) -> Option<&dyn Utf8ArrayTrait> { + if matches!(self.dtype(), DType::Utf8(..)) { + Some(self) + } else { + None + } + } + + fn as_binary_array(&self) -> Option<&dyn BinaryArrayTrait> { + if matches!(self.dtype(), DType::Binary(..)) { + Some(self) + } else { + None + } + } +} + +impl PrimitiveArrayTrait for DictArray {} + +impl Utf8ArrayTrait for DictArray {} + +impl BinaryArrayTrait for DictArray {} diff --git a/encodings/fastlanes/src/bitpacking/mod.rs b/encodings/fastlanes/src/bitpacking/mod.rs index a0129bdb93..03afea638c 100644 --- a/encodings/fastlanes/src/bitpacking/mod.rs +++ b/encodings/fastlanes/src/bitpacking/mod.rs @@ -3,6 +3,7 @@ pub use compress::*; use vortex::array::primitive::{Primitive, PrimitiveArray}; use vortex::stats::ArrayStatisticsCompute; use vortex::validity::{ArrayValidity, LogicalValidity, Validity, ValidityMetadata}; +use vortex::variants::{ArrayVariants, PrimitiveArrayTrait}; use vortex::visitor::{AcceptArrayVisitor, ArrayVisitor}; use vortex::{impl_encoding, ArrayDType, Canonical, IntoCanonical}; use vortex_dtype::{Nullability, PType}; @@ -201,6 +202,14 @@ impl ArrayTrait for BitPackedArray { } } +impl ArrayVariants for BitPackedArray { + fn as_primitive_array(&self) -> Option<&dyn PrimitiveArrayTrait> { + Some(self) + } +} + +impl PrimitiveArrayTrait for BitPackedArray {} + #[cfg(test)] mod test { use vortex::array::primitive::PrimitiveArray; diff --git a/encodings/fastlanes/src/delta/mod.rs b/encodings/fastlanes/src/delta/mod.rs index c17d07cdab..79e2949ea0 100644 --- a/encodings/fastlanes/src/delta/mod.rs +++ b/encodings/fastlanes/src/delta/mod.rs @@ -3,6 +3,7 @@ use serde::{Deserialize, Serialize}; use vortex::stats::ArrayStatisticsCompute; use vortex::validity::ValidityMetadata; use vortex::validity::{ArrayValidity, LogicalValidity, Validity}; +use vortex::variants::{ArrayVariants, PrimitiveArrayTrait}; use vortex::visitor::{AcceptArrayVisitor, ArrayVisitor}; use vortex::{impl_encoding, ArrayDType, Canonical, IntoCanonical}; use vortex_dtype::match_each_unsigned_integer_ptype; @@ -92,6 +93,14 @@ impl DeltaArray { impl ArrayTrait for DeltaArray {} +impl ArrayVariants for DeltaArray { + fn as_primitive_array(&self) -> Option<&dyn PrimitiveArrayTrait> { + Some(self) + } +} + +impl PrimitiveArrayTrait for DeltaArray {} + impl IntoCanonical for DeltaArray { fn into_canonical(self) -> VortexResult { delta_decompress(self).map(Canonical::Primitive) diff --git a/encodings/fastlanes/src/for/mod.rs b/encodings/fastlanes/src/for/mod.rs index 009e5cfb6a..b328b60766 100644 --- a/encodings/fastlanes/src/for/mod.rs +++ b/encodings/fastlanes/src/for/mod.rs @@ -2,6 +2,7 @@ pub use compress::*; use serde::{Deserialize, Serialize}; use vortex::stats::ArrayStatisticsCompute; use vortex::validity::{ArrayValidity, LogicalValidity}; +use vortex::variants::{ArrayVariants, PrimitiveArrayTrait}; use vortex::visitor::{AcceptArrayVisitor, ArrayVisitor}; use vortex::{impl_encoding, ArrayDType, Canonical, IntoCanonical}; use vortex_dtype::PType; @@ -95,3 +96,11 @@ impl ArrayTrait for FoRArray { self.encoded().nbytes() } } + +impl ArrayVariants for FoRArray { + fn as_primitive_array(&self) -> Option<&dyn PrimitiveArrayTrait> { + Some(self) + } +} + +impl PrimitiveArrayTrait for FoRArray {} diff --git a/encodings/roaring/src/boolean/mod.rs b/encodings/roaring/src/boolean/mod.rs index 03f5809ef4..6641febb2e 100644 --- a/encodings/roaring/src/boolean/mod.rs +++ b/encodings/roaring/src/boolean/mod.rs @@ -6,6 +6,7 @@ use serde::{Deserialize, Serialize}; use vortex::array::bool::{Bool, BoolArray}; use vortex::stats::ArrayStatisticsCompute; use vortex::validity::{ArrayValidity, LogicalValidity, Validity}; +use vortex::variants::{ArrayVariants, BoolArrayTrait}; use vortex::visitor::{AcceptArrayVisitor, ArrayVisitor}; use vortex::{impl_encoding, ArrayDType, Canonical, IntoCanonical}; use vortex_buffer::Buffer; @@ -62,6 +63,14 @@ impl RoaringBoolArray { impl ArrayTrait for RoaringBoolArray {} +impl ArrayVariants for RoaringBoolArray { + fn as_bool_array(&self) -> Option<&dyn BoolArrayTrait> { + Some(self) + } +} + +impl BoolArrayTrait for RoaringBoolArray {} + impl AcceptArrayVisitor for RoaringBoolArray { fn accept(&self, _visitor: &mut dyn ArrayVisitor) -> VortexResult<()> { // TODO(ngates): should we store a buffer in memory? Or delay serialization? diff --git a/encodings/roaring/src/integer/mod.rs b/encodings/roaring/src/integer/mod.rs index 91c0dcefdd..5baaa5a275 100644 --- a/encodings/roaring/src/integer/mod.rs +++ b/encodings/roaring/src/integer/mod.rs @@ -4,6 +4,7 @@ use serde::{Deserialize, Serialize}; use vortex::array::primitive::{Primitive, PrimitiveArray}; use vortex::stats::ArrayStatisticsCompute; use vortex::validity::{ArrayValidity, LogicalValidity}; +use vortex::variants::{ArrayVariants, PrimitiveArrayTrait}; use vortex::visitor::{AcceptArrayVisitor, ArrayVisitor}; use vortex::{impl_encoding, Canonical, IntoCanonical}; use vortex_buffer::Buffer; @@ -29,7 +30,7 @@ impl RoaringIntArray { let length = bitmap.statistics().cardinality as usize; Ok(Self { typed: TypedArray::try_from_parts( - DType::Bool(NonNullable), + DType::Primitive(ptype, NonNullable), length, RoaringIntMetadata { ptype }, Some(Buffer::from(bitmap.serialize::())), @@ -64,6 +65,14 @@ impl RoaringIntArray { impl ArrayTrait for RoaringIntArray {} +impl ArrayVariants for RoaringIntArray { + fn as_primitive_array(&self) -> Option<&dyn PrimitiveArrayTrait> { + Some(self) + } +} + +impl PrimitiveArrayTrait for RoaringIntArray {} + impl ArrayValidity for RoaringIntArray { fn is_valid(&self, _index: usize) -> bool { true diff --git a/encodings/runend/src/runend.rs b/encodings/runend/src/runend.rs index 042477c1e0..6f799a7fb2 100644 --- a/encodings/runend/src/runend.rs +++ b/encodings/runend/src/runend.rs @@ -4,6 +4,7 @@ use vortex::compute::unary::scalar_at::scalar_at; use vortex::compute::{search_sorted, SearchSortedSide}; use vortex::stats::{ArrayStatistics, ArrayStatisticsCompute}; use vortex::validity::{ArrayValidity, LogicalValidity, Validity, ValidityMetadata}; +use vortex::variants::{ArrayVariants, PrimitiveArrayTrait}; use vortex::visitor::{AcceptArrayVisitor, ArrayVisitor}; use vortex::{impl_encoding, ArrayDType, Canonical, IntoArrayVariant, IntoCanonical}; use vortex_error::vortex_bail; @@ -107,6 +108,14 @@ impl RunEndArray { impl ArrayTrait for RunEndArray {} +impl ArrayVariants for RunEndArray { + fn as_primitive_array(&self) -> Option<&dyn PrimitiveArrayTrait> { + Some(self) + } +} + +impl PrimitiveArrayTrait for RunEndArray {} + impl ArrayValidity for RunEndArray { fn is_valid(&self, index: usize) -> bool { self.validity().is_valid(index) diff --git a/encodings/zigzag/src/zigzag.rs b/encodings/zigzag/src/zigzag.rs index 3c0b99f47f..eea809c95e 100644 --- a/encodings/zigzag/src/zigzag.rs +++ b/encodings/zigzag/src/zigzag.rs @@ -2,6 +2,7 @@ use serde::{Deserialize, Serialize}; use vortex::array::primitive::PrimitiveArray; use vortex::stats::ArrayStatisticsCompute; use vortex::validity::{ArrayValidity, LogicalValidity}; +use vortex::variants::{ArrayVariants, PrimitiveArrayTrait}; use vortex::visitor::{AcceptArrayVisitor, ArrayVisitor}; use vortex::{impl_encoding, ArrayDType, Canonical, IntoCanonical}; use vortex_dtype::PType; @@ -52,6 +53,14 @@ impl ZigZagArray { impl ArrayTrait for ZigZagArray {} +impl ArrayVariants for ZigZagArray { + fn as_primitive_array(&self) -> Option<&dyn PrimitiveArrayTrait> { + Some(self) + } +} + +impl PrimitiveArrayTrait for ZigZagArray {} + impl ArrayValidity for ZigZagArray { fn is_valid(&self, index: usize) -> bool { self.encoded().with_dyn(|a| a.is_valid(index)) diff --git a/vortex-array/src/array/bool/mod.rs b/vortex-array/src/array/bool/mod.rs index 5f4c160a99..2997d2f244 100644 --- a/vortex-array/src/array/bool/mod.rs +++ b/vortex-array/src/array/bool/mod.rs @@ -5,6 +5,7 @@ use vortex_buffer::Buffer; use crate::validity::{ArrayValidity, ValidityMetadata}; use crate::validity::{LogicalValidity, Validity}; +use crate::variants::{ArrayVariants, BoolArrayTrait}; use crate::visitor::{AcceptArrayVisitor, ArrayVisitor}; use crate::{impl_encoding, Canonical, IntoCanonical}; @@ -74,6 +75,14 @@ impl BoolArray { impl ArrayTrait for BoolArray {} +impl ArrayVariants for BoolArray { + fn as_bool_array(&self) -> Option<&dyn BoolArrayTrait> { + Some(self) + } +} + +impl BoolArrayTrait for BoolArray {} + impl From for BoolArray { fn from(value: BooleanBuffer) -> Self { Self::try_new(value, Validity::NonNullable).unwrap() diff --git a/vortex-array/src/array/chunked/canonical.rs b/vortex-array/src/array/chunked/canonical.rs index c42f361ee5..778a33f61e 100644 --- a/vortex-array/src/array/chunked/canonical.rs +++ b/vortex-array/src/array/chunked/canonical.rs @@ -13,6 +13,7 @@ use crate::array::struct_::StructArray; use crate::array::varbin::builder::VarBinBuilder; use crate::array::varbin::VarBinArray; use crate::validity::Validity; +use crate::variants::StructArrayTrait; use crate::{ Array, ArrayDType, ArrayValidity, Canonical, IntoArray, IntoArrayVariant, IntoCanonical, }; diff --git a/vortex-array/src/array/chunked/mod.rs b/vortex-array/src/array/chunked/mod.rs index 8323d3a3e0..b56aabe104 100644 --- a/vortex-array/src/array/chunked/mod.rs +++ b/vortex-array/src/array/chunked/mod.rs @@ -22,6 +22,7 @@ use crate::{impl_encoding, ArrayDType}; mod canonical; mod compute; mod stats; +mod variants; impl_encoding!("vortex.chunked", 11u16, Chunked); diff --git a/vortex-array/src/array/chunked/variants.rs b/vortex-array/src/array/chunked/variants.rs new file mode 100644 index 0000000000..02dfff90ef --- /dev/null +++ b/vortex-array/src/array/chunked/variants.rs @@ -0,0 +1,103 @@ +use vortex_dtype::DType; + +use crate::array::chunked::ChunkedArray; +use crate::variants::{ + ArrayVariants, BinaryArrayTrait, BoolArrayTrait, ExtensionArrayTrait, ListArrayTrait, + NullArrayTrait, PrimitiveArrayTrait, StructArrayTrait, Utf8ArrayTrait, +}; +use crate::{Array, ArrayDType, IntoArray}; + +/// Chunked arrays support all DTypes +impl ArrayVariants for ChunkedArray { + fn as_null_array(&self) -> Option<&dyn NullArrayTrait> { + if matches!(self.dtype(), DType::Null) { + Some(self) + } else { + None + } + } + + fn as_bool_array(&self) -> Option<&dyn BoolArrayTrait> { + if matches!(self.dtype(), DType::Bool(_)) { + Some(self) + } else { + None + } + } + + fn as_primitive_array(&self) -> Option<&dyn PrimitiveArrayTrait> { + if matches!(self.dtype(), DType::Primitive(..)) { + Some(self) + } else { + None + } + } + + fn as_utf8_array(&self) -> Option<&dyn Utf8ArrayTrait> { + if matches!(self.dtype(), DType::Utf8(_)) { + Some(self) + } else { + None + } + } + + fn as_binary_array(&self) -> Option<&dyn BinaryArrayTrait> { + if matches!(self.dtype(), DType::Binary(_)) { + Some(self) + } else { + None + } + } + + fn as_struct_array(&self) -> Option<&dyn StructArrayTrait> { + if matches!(self.dtype(), DType::Struct(..)) { + Some(self) + } else { + None + } + } + + fn as_list_array(&self) -> Option<&dyn ListArrayTrait> { + if matches!(self.dtype(), DType::List(..)) { + Some(self) + } else { + None + } + } + + fn as_extension_array(&self) -> Option<&dyn ExtensionArrayTrait> { + if matches!(self.dtype(), DType::Extension(..)) { + Some(self) + } else { + None + } + } +} + +impl NullArrayTrait for ChunkedArray {} + +impl BoolArrayTrait for ChunkedArray {} + +impl PrimitiveArrayTrait for ChunkedArray {} + +impl Utf8ArrayTrait for ChunkedArray {} + +impl BinaryArrayTrait for ChunkedArray {} + +impl StructArrayTrait for ChunkedArray { + fn field(&self, idx: usize) -> Option { + let mut chunks = Vec::with_capacity(self.nchunks()); + for chunk in self.chunks() { + let array = chunk.with_dyn(|a| a.as_struct_array().and_then(|s| s.field(idx)))?; + chunks.push(array); + } + let chunked = ChunkedArray::try_new(chunks, self.dtype().clone()) + .expect("should be correct dtype") + .into_array(); + Some(chunked) + } +} + +impl ListArrayTrait for ChunkedArray {} + +impl ExtensionArrayTrait for ChunkedArray {} diff --git a/vortex-array/src/array/constant/mod.rs b/vortex-array/src/array/constant/mod.rs index 706aa21b8d..005567a98c 100644 --- a/vortex-array/src/array/constant/mod.rs +++ b/vortex-array/src/array/constant/mod.rs @@ -11,6 +11,7 @@ use crate::visitor::{AcceptArrayVisitor, ArrayVisitor}; mod canonical; mod compute; mod stats; +mod variants; impl_encoding!("vortex.constant", 10u16, Constant); diff --git a/vortex-array/src/array/constant/variants.rs b/vortex-array/src/array/constant/variants.rs new file mode 100644 index 0000000000..28c9a22509 --- /dev/null +++ b/vortex-array/src/array/constant/variants.rs @@ -0,0 +1,99 @@ +use vortex_dtype::DType; +use vortex_scalar::StructScalar; + +use crate::array::constant::ConstantArray; +use crate::variants::{ + ArrayVariants, BinaryArrayTrait, BoolArrayTrait, ExtensionArrayTrait, ListArrayTrait, + NullArrayTrait, PrimitiveArrayTrait, StructArrayTrait, Utf8ArrayTrait, +}; +use crate::{Array, ArrayDType, IntoArray}; + +/// Constant arrays support all DTypes +impl ArrayVariants for ConstantArray { + fn as_null_array(&self) -> Option<&dyn NullArrayTrait> { + if matches!(self.dtype(), DType::Null) { + Some(self) + } else { + None + } + } + + fn as_bool_array(&self) -> Option<&dyn BoolArrayTrait> { + if matches!(self.dtype(), DType::Bool(_)) { + Some(self) + } else { + None + } + } + + fn as_primitive_array(&self) -> Option<&dyn PrimitiveArrayTrait> { + if matches!(self.dtype(), DType::Primitive(..)) { + Some(self) + } else { + None + } + } + + fn as_utf8_array(&self) -> Option<&dyn Utf8ArrayTrait> { + if matches!(self.dtype(), DType::Utf8(_)) { + Some(self) + } else { + None + } + } + + fn as_binary_array(&self) -> Option<&dyn BinaryArrayTrait> { + if matches!(self.dtype(), DType::Binary(_)) { + Some(self) + } else { + None + } + } + + fn as_struct_array(&self) -> Option<&dyn StructArrayTrait> { + if matches!(self.dtype(), DType::Struct(..)) { + Some(self) + } else { + None + } + } + + fn as_list_array(&self) -> Option<&dyn ListArrayTrait> { + if matches!(self.dtype(), DType::List(..)) { + Some(self) + } else { + None + } + } + + fn as_extension_array(&self) -> Option<&dyn ExtensionArrayTrait> { + if matches!(self.dtype(), DType::Extension(..)) { + Some(self) + } else { + None + } + } +} + +impl NullArrayTrait for ConstantArray {} + +impl BoolArrayTrait for ConstantArray {} + +impl PrimitiveArrayTrait for ConstantArray {} + +impl Utf8ArrayTrait for ConstantArray {} + +impl BinaryArrayTrait for ConstantArray {} + +impl StructArrayTrait for ConstantArray { + fn field(&self, idx: usize) -> Option { + StructScalar::try_from(self.scalar()) + .ok()? + .field_by_idx(idx) + .map(|scalar| ConstantArray::new(scalar, self.len()).into_array()) + } +} + +impl ListArrayTrait for ConstantArray {} + +impl ExtensionArrayTrait for ConstantArray {} diff --git a/vortex-array/src/array/extension/mod.rs b/vortex-array/src/array/extension/mod.rs index f074e9e51d..fd9c76ed84 100644 --- a/vortex-array/src/array/extension/mod.rs +++ b/vortex-array/src/array/extension/mod.rs @@ -3,6 +3,7 @@ use vortex_dtype::{ExtDType, ExtID}; use crate::stats::ArrayStatisticsCompute; use crate::validity::{ArrayValidity, LogicalValidity}; +use crate::variants::{ArrayVariants, ExtensionArrayTrait}; use crate::visitor::{AcceptArrayVisitor, ArrayVisitor}; use crate::{impl_encoding, ArrayDType, Canonical, IntoCanonical}; @@ -52,6 +53,14 @@ impl ExtensionArray { impl ArrayTrait for ExtensionArray {} +impl ArrayVariants for ExtensionArray { + fn as_extension_array(&self) -> Option<&dyn ExtensionArrayTrait> { + Some(self) + } +} + +impl ExtensionArrayTrait for ExtensionArray {} + impl IntoCanonical for ExtensionArray { fn into_canonical(self) -> VortexResult { Ok(Canonical::Extension(self)) diff --git a/vortex-array/src/array/null/mod.rs b/vortex-array/src/array/null/mod.rs index 9311d32e7c..5c2a0cbe22 100644 --- a/vortex-array/src/array/null/mod.rs +++ b/vortex-array/src/array/null/mod.rs @@ -2,6 +2,7 @@ use serde::{Deserialize, Serialize}; use crate::stats::{ArrayStatisticsCompute, Stat}; use crate::validity::{ArrayValidity, LogicalValidity, Validity}; +use crate::variants::{ArrayVariants, NullArrayTrait}; use crate::visitor::{AcceptArrayVisitor, ArrayVisitor}; use crate::{impl_encoding, Canonical, IntoCanonical}; @@ -60,3 +61,11 @@ impl ArrayTrait for NullArray { 0 } } + +impl ArrayVariants for NullArray { + fn as_null_array(&self) -> Option<&dyn NullArrayTrait> { + Some(self) + } +} + +impl NullArrayTrait for NullArray {} diff --git a/vortex-array/src/array/primitive/mod.rs b/vortex-array/src/array/primitive/mod.rs index 7b249246f7..92500614f9 100644 --- a/vortex-array/src/array/primitive/mod.rs +++ b/vortex-array/src/array/primitive/mod.rs @@ -7,6 +7,7 @@ use vortex_dtype::{match_each_native_ptype, NativePType, PType}; use vortex_error::vortex_bail; use crate::validity::{ArrayValidity, LogicalValidity, Validity, ValidityMetadata}; +use crate::variants::{ArrayVariants, PrimitiveArrayTrait}; use crate::visitor::{AcceptArrayVisitor, ArrayVisitor}; use crate::{impl_encoding, ArrayDType}; use crate::{Canonical, IntoCanonical}; @@ -158,6 +159,14 @@ impl PrimitiveArray { impl ArrayTrait for PrimitiveArray {} +impl ArrayVariants for PrimitiveArray { + fn as_primitive_array(&self) -> Option<&dyn PrimitiveArrayTrait> { + Some(self) + } +} + +impl PrimitiveArrayTrait for PrimitiveArray {} + impl From> for PrimitiveArray { fn from(values: Vec) -> Self { Self::from_vec(values, Validity::NonNullable) diff --git a/vortex-array/src/array/sparse/mod.rs b/vortex-array/src/array/sparse/mod.rs index edb2f916c7..67bfbb3f1d 100644 --- a/vortex-array/src/array/sparse/mod.rs +++ b/vortex-array/src/array/sparse/mod.rs @@ -13,6 +13,7 @@ use crate::{impl_encoding, ArrayDType, IntoArrayVariant}; mod compute; mod flatten; +mod variants; impl_encoding!("vortex.sparse", 9u16, Sparse); diff --git a/vortex-array/src/array/sparse/variants.rs b/vortex-array/src/array/sparse/variants.rs new file mode 100644 index 0000000000..298c3cefa3 --- /dev/null +++ b/vortex-array/src/array/sparse/variants.rs @@ -0,0 +1,113 @@ +use vortex_dtype::DType; +use vortex_scalar::StructScalar; + +use crate::array::sparse::SparseArray; +use crate::variants::{ + ArrayVariants, BinaryArrayTrait, BoolArrayTrait, ExtensionArrayTrait, ListArrayTrait, + NullArrayTrait, PrimitiveArrayTrait, StructArrayTrait, Utf8ArrayTrait, +}; +use crate::{Array, ArrayDType, IntoArray}; + +/// Sparse arrays support all DTypes +impl ArrayVariants for SparseArray { + fn as_null_array(&self) -> Option<&dyn NullArrayTrait> { + if matches!(self.dtype(), DType::Null) { + Some(self) + } else { + None + } + } + + fn as_bool_array(&self) -> Option<&dyn BoolArrayTrait> { + if matches!(self.dtype(), DType::Bool(_)) { + Some(self) + } else { + None + } + } + + fn as_primitive_array(&self) -> Option<&dyn PrimitiveArrayTrait> { + if matches!(self.dtype(), DType::Primitive(..)) { + Some(self) + } else { + None + } + } + + fn as_utf8_array(&self) -> Option<&dyn Utf8ArrayTrait> { + if matches!(self.dtype(), DType::Utf8(_)) { + Some(self) + } else { + None + } + } + + fn as_binary_array(&self) -> Option<&dyn BinaryArrayTrait> { + if matches!(self.dtype(), DType::Binary(_)) { + Some(self) + } else { + None + } + } + + fn as_struct_array(&self) -> Option<&dyn StructArrayTrait> { + if matches!(self.dtype(), DType::Struct(..)) { + Some(self) + } else { + None + } + } + + fn as_list_array(&self) -> Option<&dyn ListArrayTrait> { + if matches!(self.dtype(), DType::List(..)) { + Some(self) + } else { + None + } + } + + fn as_extension_array(&self) -> Option<&dyn ExtensionArrayTrait> { + if matches!(self.dtype(), DType::Extension(..)) { + Some(self) + } else { + None + } + } +} + +impl NullArrayTrait for SparseArray {} + +impl BoolArrayTrait for SparseArray {} + +impl PrimitiveArrayTrait for SparseArray {} + +impl Utf8ArrayTrait for SparseArray {} + +impl BinaryArrayTrait for SparseArray {} + +impl StructArrayTrait for SparseArray { + fn field(&self, idx: usize) -> Option { + let values = self + .values() + .with_dyn(|s| s.as_struct_array().and_then(|s| s.field(idx)))?; + let scalar = StructScalar::try_from(self.fill_value()) + .ok()? + .field_by_idx(idx)?; + + Some( + SparseArray::try_new_with_offset( + self.indices().clone(), + values, + self.len(), + self.indices_offset(), + scalar, + ) + .unwrap() + .into_array(), + ) + } +} + +impl ListArrayTrait for SparseArray {} + +impl ExtensionArrayTrait for SparseArray {} diff --git a/vortex-array/src/array/struct_/compute.rs b/vortex-array/src/array/struct_/compute.rs index 26271c6882..f86e9e0594 100644 --- a/vortex-array/src/array/struct_/compute.rs +++ b/vortex-array/src/array/struct_/compute.rs @@ -5,6 +5,7 @@ use vortex_scalar::Scalar; use crate::array::struct_::StructArray; use crate::compute::unary::scalar_at::{scalar_at, ScalarAtFn}; use crate::compute::{slice, take, ArrayCompute, SliceFn, TakeFn}; +use crate::variants::StructArrayTrait; use crate::{Array, ArrayDType, IntoArray}; impl ArrayCompute for StructArray { diff --git a/vortex-array/src/array/struct_/mod.rs b/vortex-array/src/array/struct_/mod.rs index cd6951c83a..dceec7e18d 100644 --- a/vortex-array/src/array/struct_/mod.rs +++ b/vortex-array/src/array/struct_/mod.rs @@ -4,6 +4,7 @@ use vortex_error::vortex_bail; use crate::stats::ArrayStatisticsCompute; use crate::validity::{ArrayValidity, LogicalValidity, Validity, ValidityMetadata}; +use crate::variants::{ArrayVariants, StructArrayTrait}; use crate::visitor::{AcceptArrayVisitor, ArrayVisitor}; use crate::{impl_encoding, ArrayDType}; use crate::{Canonical, IntoCanonical}; @@ -19,41 +20,6 @@ pub struct StructMetadata { } impl StructArray { - pub fn field(&self, idx: usize) -> Option { - let DType::Struct(st, _) = self.dtype() else { - unreachable!() - }; - let dtype = st.dtypes().get(idx)?; - self.array().child(idx, dtype, self.len()) - } - - pub fn field_by_name(&self, name: &str) -> Option { - let field_idx = self - .names() - .iter() - .position(|field_name| field_name.as_ref() == name); - - field_idx.and_then(|field_idx| self.field(field_idx)) - } - - pub fn names(&self) -> &FieldNames { - let DType::Struct(st, _) = self.dtype() else { - unreachable!() - }; - st.names() - } - - pub fn dtypes(&self) -> &[DType] { - let DType::Struct(st, _) = self.dtype() else { - unreachable!() - }; - st.dtypes() - } - - pub fn nfields(&self) -> usize { - self.dtypes().len() - } - pub fn validity(&self) -> Validity { self.metadata().validity.to_validity(self.array().child( self.nfields(), @@ -158,6 +124,18 @@ impl StructArray { impl ArrayTrait for StructArray {} +impl ArrayVariants for StructArray { + fn as_struct_array(&self) -> Option<&dyn StructArrayTrait> { + Some(self) + } +} + +impl StructArrayTrait for StructArray { + fn field(&self, idx: usize) -> Option { + self.array().child(idx, &self.dtypes()[idx], self.len()) + } +} + impl IntoCanonical for StructArray { /// StructEncoding is the canonical form for a [DType::Struct] array, so return self. fn into_canonical(self) -> VortexResult { @@ -196,6 +174,7 @@ mod test { use crate::array::struct_::StructArray; use crate::array::varbin::VarBinArray; use crate::validity::Validity; + use crate::variants::StructArrayTrait; use crate::IntoArray; #[test] diff --git a/vortex-array/src/array/varbin/mod.rs b/vortex-array/src/array/varbin/mod.rs index a917dea8fd..efd80e355d 100644 --- a/vortex-array/src/array/varbin/mod.rs +++ b/vortex-array/src/array/varbin/mod.rs @@ -20,6 +20,7 @@ pub mod builder; mod compute; mod flatten; mod stats; +mod variants; impl_encoding!("vortex.varbin", 4u16, VarBin); diff --git a/vortex-array/src/array/varbin/variants.rs b/vortex-array/src/array/varbin/variants.rs new file mode 100644 index 0000000000..b387e0978a --- /dev/null +++ b/vortex-array/src/array/varbin/variants.rs @@ -0,0 +1,27 @@ +use vortex_dtype::DType; + +use crate::array::varbin::VarBinArray; +use crate::variants::{ArrayVariants, BinaryArrayTrait, Utf8ArrayTrait}; +use crate::ArrayDType; + +impl ArrayVariants for VarBinArray { + fn as_utf8_array(&self) -> Option<&dyn Utf8ArrayTrait> { + if matches!(self.dtype(), DType::Utf8(..)) { + Some(self) + } else { + None + } + } + + fn as_binary_array(&self) -> Option<&dyn BinaryArrayTrait> { + if matches!(self.dtype(), DType::Binary(..)) { + Some(self) + } else { + None + } + } +} + +impl Utf8ArrayTrait for VarBinArray {} + +impl BinaryArrayTrait for VarBinArray {} diff --git a/vortex-array/src/array/varbinview/mod.rs b/vortex-array/src/array/varbinview/mod.rs index fa294d31fb..c73bcb5969 100644 --- a/vortex-array/src/array/varbinview/mod.rs +++ b/vortex-array/src/array/varbinview/mod.rs @@ -24,6 +24,7 @@ mod accessor; mod builder; mod compute; mod stats; +mod variants; #[derive(Clone, Copy, Debug)] #[repr(C, align(8))] diff --git a/vortex-array/src/array/varbinview/variants.rs b/vortex-array/src/array/varbinview/variants.rs new file mode 100644 index 0000000000..0a387e8966 --- /dev/null +++ b/vortex-array/src/array/varbinview/variants.rs @@ -0,0 +1,27 @@ +use vortex_dtype::DType; + +use crate::array::varbinview::VarBinViewArray; +use crate::variants::{ArrayVariants, BinaryArrayTrait, Utf8ArrayTrait}; +use crate::ArrayDType; + +impl ArrayVariants for VarBinViewArray { + fn as_utf8_array(&self) -> Option<&dyn Utf8ArrayTrait> { + if matches!(self.dtype(), DType::Utf8(..)) { + Some(self) + } else { + None + } + } + + fn as_binary_array(&self) -> Option<&dyn BinaryArrayTrait> { + if matches!(self.dtype(), DType::Binary(..)) { + Some(self) + } else { + None + } + } +} + +impl Utf8ArrayTrait for VarBinViewArray {} + +impl BinaryArrayTrait for VarBinViewArray {} diff --git a/vortex-array/src/canonical.rs b/vortex-array/src/canonical.rs index b3b1f00c36..1c5dc58285 100644 --- a/vortex-array/src/canonical.rs +++ b/vortex-array/src/canonical.rs @@ -26,6 +26,7 @@ use crate::arrow::wrappers::as_offset_buffer; use crate::compute::unary::cast::try_cast; use crate::encoding::ArrayEncoding; use crate::validity::ArrayValidity; +use crate::variants::StructArrayTrait; use crate::{Array, ArrayDType, IntoArray, ToArray}; /// The set of canonical array encodings, also the set of encodings that can be transferred to diff --git a/vortex-array/src/lib.rs b/vortex-array/src/lib.rs index 87dd9924da..c7cfe45303 100644 --- a/vortex-array/src/lib.rs +++ b/vortex-array/src/lib.rs @@ -30,6 +30,7 @@ use crate::iter::{ArrayIterator, ArrayIteratorAdapter}; use crate::stats::{ArrayStatistics, ArrayStatisticsCompute}; use crate::stream::{ArrayStream, ArrayStreamAdapter}; use crate::validity::ArrayValidity; +use crate::variants::ArrayVariants; use crate::visitor::{AcceptArrayVisitor, ArrayVisitor}; pub mod accessor; @@ -49,6 +50,7 @@ pub mod stream; mod tree; mod typed; pub mod validity; +pub mod variants; pub mod vendored; mod view; pub mod visitor; @@ -228,6 +230,7 @@ pub trait ArrayTrait: ArrayEncodingRef + ArrayCompute + ArrayDType + + ArrayVariants + IntoCanonical + ArrayValidity + AcceptArrayVisitor @@ -271,6 +274,23 @@ impl Array { self.encoding() .with_dyn(self, &mut |array| { + // Sanity check that the encoding implements the correct array trait + debug_assert!( + match array.dtype() { + DType::Null => array.as_null_array().is_some(), + DType::Bool(_) => array.as_bool_array().is_some(), + DType::Primitive(..) => array.as_primitive_array().is_some(), + DType::Utf8(_) => array.as_utf8_array().is_some(), + DType::Binary(_) => array.as_binary_array().is_some(), + DType::Struct(..) => array.as_struct_array().is_some(), + DType::List(..) => array.as_list_array().is_some(), + DType::Extension(..) => array.as_extension_array().is_some(), + }, + "Encoding {} does not implement the variant trait for {}", + self.encoding().id(), + array.dtype() + ); + result = Some(f(array)); Ok(()) }) diff --git a/vortex-array/src/variants.rs b/vortex-array/src/variants.rs new file mode 100644 index 0000000000..c6a3aa0038 --- /dev/null +++ b/vortex-array/src/variants.rs @@ -0,0 +1,118 @@ +use vortex_dtype::{DType, FieldNames}; + +/// This module defines array traits for each Vortex DType. +/// +/// When callers only want to make assumptions about the DType, and not about any specific +/// encoding, they can use these traits to write encoding-agnostic code. +use crate::{Array, ArrayTrait}; + +pub trait ArrayVariants { + fn as_null_array(&self) -> Option<&dyn NullArrayTrait> { + None + } + + fn as_null_array_unchecked(&self) -> &dyn NullArrayTrait { + self.as_null_array().expect("Expected NullArray") + } + + fn as_bool_array(&self) -> Option<&dyn BoolArrayTrait> { + None + } + + fn as_bool_array_unchecked(&self) -> &dyn BoolArrayTrait { + self.as_bool_array().expect("Expected BoolArray") + } + + fn as_primitive_array(&self) -> Option<&dyn PrimitiveArrayTrait> { + None + } + + fn as_primitive_array_unchecked(&self) -> &dyn PrimitiveArrayTrait { + self.as_primitive_array().expect("Expected PrimitiveArray") + } + + fn as_utf8_array(&self) -> Option<&dyn Utf8ArrayTrait> { + None + } + + fn as_utf8_array_unchecked(&self) -> &dyn Utf8ArrayTrait { + self.as_utf8_array().expect("Expected Utf8Array") + } + + fn as_binary_array(&self) -> Option<&dyn BinaryArrayTrait> { + None + } + + fn as_binary_array_unchecked(&self) -> &dyn BinaryArrayTrait { + self.as_binary_array().expect("Expected BinaryArray") + } + + fn as_struct_array(&self) -> Option<&dyn StructArrayTrait> { + None + } + + fn as_struct_array_unchecked(&self) -> &dyn StructArrayTrait { + self.as_struct_array().expect("Expected StructArray") + } + + fn as_list_array(&self) -> Option<&dyn ListArrayTrait> { + None + } + + fn as_list_array_unchecked(&self) -> &dyn ListArrayTrait { + self.as_list_array().expect("Expected ListArray") + } + + fn as_extension_array(&self) -> Option<&dyn ExtensionArrayTrait> { + None + } + + fn as_extension_array_unchecked(&self) -> &dyn ExtensionArrayTrait { + self.as_extension_array().expect("Expected ExtensionArray") + } +} + +pub trait NullArrayTrait: ArrayTrait {} + +pub trait BoolArrayTrait: ArrayTrait {} + +pub trait PrimitiveArrayTrait: ArrayTrait {} + +pub trait Utf8ArrayTrait: ArrayTrait {} + +pub trait BinaryArrayTrait: ArrayTrait {} + +pub trait StructArrayTrait: ArrayTrait { + fn names(&self) -> &FieldNames { + let DType::Struct(st, _) = self.dtype() else { + unreachable!() + }; + st.names() + } + + fn dtypes(&self) -> &[DType] { + let DType::Struct(st, _) = self.dtype() else { + unreachable!() + }; + st.dtypes() + } + + fn nfields(&self) -> usize { + self.names().len() + } + + fn field(&self, idx: usize) -> Option; + + fn field_by_name(&self, name: &str) -> Option { + let field_idx = self + .names() + .iter() + .position(|field_name| field_name.as_ref() == name); + + field_idx.and_then(|field_idx| self.field(field_idx)) + } +} + +pub trait ListArrayTrait: ArrayTrait {} + +pub trait ExtensionArrayTrait: ArrayTrait {} diff --git a/vortex-sampling-compressor/src/lib.rs b/vortex-sampling-compressor/src/lib.rs index ce8ac56f1b..9c031f1384 100644 --- a/vortex-sampling-compressor/src/lib.rs +++ b/vortex-sampling-compressor/src/lib.rs @@ -8,6 +8,7 @@ use vortex::array::struct_::{Struct, StructArray}; use vortex::compress::{check_dtype_unchanged, check_validity_unchanged, CompressionStrategy}; use vortex::compute::slice; use vortex::validity::Validity; +use vortex::variants::StructArrayTrait; use vortex::{Array, ArrayDType, ArrayDef, IntoArray, IntoCanonical}; use vortex_error::VortexResult; diff --git a/vortex-scalar/src/struct_.rs b/vortex-scalar/src/struct_.rs index 12877c4126..16144de6a8 100644 --- a/vortex-scalar/src/struct_.rs +++ b/vortex-scalar/src/struct_.rs @@ -17,23 +17,25 @@ impl<'a> StructScalar<'a> { self.dtype } - pub fn field_by_idx(&self, idx: usize, dtype: DType) -> Option { + pub fn field_by_idx(&self, idx: usize) -> Option { + let DType::Struct(st, _) = self.dtype() else { + unreachable!() + }; + self.fields .as_ref() .and_then(|fields| fields.get(idx)) .map(|field| Scalar { - dtype, + dtype: st.dtypes()[idx].clone(), value: field.clone(), }) } - pub fn field(&self, name: &str, dtype: DType) -> Option { - let DType::Struct(struct_dtype, _) = self.dtype() else { + pub fn field(&self, name: &str) -> Option { + let DType::Struct(st, _) = self.dtype() else { unreachable!() }; - struct_dtype - .find_name(name) - .and_then(|idx| self.field_by_idx(idx, dtype)) + st.find_name(name).and_then(|idx| self.field_by_idx(idx)) } pub fn cast(&self, _dtype: &DType) -> VortexResult {