From ce181c4530d119629e9dbd705f710c0c69d0fc9c Mon Sep 17 00:00:00 2001 From: Robert Kruszewski Date: Mon, 11 Mar 2024 16:43:06 +0000 Subject: [PATCH] Scalars are an enum (#93) --- vortex-alp/src/compute.rs | 18 +- vortex-array/src/array/bool/compute.rs | 11 +- vortex-array/src/array/bool/mod.rs | 7 + vortex-array/src/array/chunked/compute.rs | 4 +- .../src/array/constant/compute/mod.rs | 8 +- vortex-array/src/array/constant/mod.rs | 49 +- vortex-array/src/array/constant/serde.rs | 4 +- vortex-array/src/array/constant/stats.rs | 27 +- .../src/array/primitive/compute/scalar_at.rs | 7 +- .../array/primitive/compute/search_sorted.rs | 2 +- vortex-array/src/array/primitive/mod.rs | 12 + vortex-array/src/array/primitive/stats.rs | 9 +- vortex-array/src/array/sparse/compute.rs | 11 +- vortex-array/src/array/struct_/compute.rs | 9 +- vortex-array/src/array/typed/compute.rs | 6 +- vortex-array/src/array/typed/mod.rs | 18 +- vortex-array/src/array/varbin/compute.rs | 11 +- vortex-array/src/array/varbinview/compute.rs | 7 +- vortex-array/src/arrow/compute/mod.rs | 3 - vortex-array/src/arrow/compute/repeat.rs | 79 ---- vortex-array/src/arrow/mod.rs | 1 - vortex-array/src/compute/add.rs | 6 +- vortex-array/src/compute/repeat.rs | 9 +- vortex-array/src/compute/scalar_at.rs | 6 +- vortex-array/src/compute/search_sorted.rs | 8 +- vortex-array/src/encode.rs | 4 +- vortex-array/src/polars.rs | 93 ++++ vortex-array/src/ptype.rs | 26 +- vortex-array/src/scalar/arrow.rs | 72 --- vortex-array/src/scalar/binary.rs | 85 ++-- vortex-array/src/scalar/bool.rs | 87 ++-- vortex-array/src/scalar/equal.rs | 61 --- vortex-array/src/scalar/list.rs | 137 +++--- vortex-array/src/scalar/localtime.rs | 47 +- vortex-array/src/scalar/mod.rs | 133 +++++- vortex-array/src/scalar/null.rs | 40 +- vortex-array/src/scalar/nullable.rs | 162 ------- vortex-array/src/scalar/ord.rs | 64 --- vortex-array/src/scalar/primitive.rs | 259 ++++++----- vortex-array/src/scalar/serde.rs | 422 +++++++----------- vortex-array/src/scalar/struct_.rs | 56 +-- vortex-array/src/scalar/utf8.rs | 89 ++-- vortex-array/src/serde/dtype.rs | 55 +-- vortex-array/src/serde/mod.rs | 35 +- vortex-array/src/serde/ptype.rs | 55 +++ vortex-array/src/stats.rs | 38 +- vortex-dict/src/compute.rs | 7 +- vortex-fastlanes/src/for/mod.rs | 10 +- vortex-fastlanes/src/for/serde.rs | 8 +- vortex-ree/src/compute.rs | 7 +- vortex-roaring/src/boolean/compute.rs | 7 +- vortex-roaring/src/boolean/mod.rs | 6 +- vortex-roaring/src/integer/compute.rs | 9 +- vortex-zigzag/src/compute.rs | 38 +- 54 files changed, 1058 insertions(+), 1386 deletions(-) delete mode 100644 vortex-array/src/arrow/compute/mod.rs delete mode 100644 vortex-array/src/arrow/compute/repeat.rs create mode 100644 vortex-array/src/polars.rs delete mode 100644 vortex-array/src/scalar/arrow.rs delete mode 100644 vortex-array/src/scalar/equal.rs delete mode 100644 vortex-array/src/scalar/nullable.rs delete mode 100644 vortex-array/src/scalar/ord.rs create mode 100644 vortex-array/src/serde/ptype.rs diff --git a/vortex-alp/src/compute.rs b/vortex-alp/src/compute.rs index 86426ab162..68867d26ca 100644 --- a/vortex-alp/src/compute.rs +++ b/vortex-alp/src/compute.rs @@ -6,7 +6,7 @@ use vortex::compute::scalar_at::{scalar_at, ScalarAtFn}; use vortex::compute::ArrayCompute; use vortex::dtype::{DType, FloatWidth}; use vortex::error::{VortexError, VortexResult}; -use vortex::scalar::{NullableScalar, Scalar, ScalarRef}; +use vortex::scalar::Scalar; impl ArrayCompute for ALPArray { fn scalar_at(&self) -> Option<&dyn ScalarAtFn> { @@ -15,30 +15,24 @@ impl ArrayCompute for ALPArray { } impl ScalarAtFn for ALPArray { - fn scalar_at(&self, index: usize) -> VortexResult { - if let Some(patch) = self - .patches() - .and_then(|p| scalar_at(p, index).ok()) - .and_then(|p| p.into_nonnull()) - { + fn scalar_at(&self, index: usize) -> VortexResult { + if let Some(patch) = self.patches().and_then(|p| scalar_at(p, index).ok()) { return Ok(patch); } - let Some(encoded_val) = scalar_at(self.encoded(), index)?.into_nonnull() else { - return Ok(NullableScalar::none(self.dtype().clone()).boxed()); - }; + let encoded_val = scalar_at(self.encoded(), index)?; match self.dtype() { DType::Float(FloatWidth::_32, _) => { let encoded_val: i32 = encoded_val.try_into().unwrap(); - Ok(ScalarRef::from(::decode_single( + Ok(Scalar::from(::decode_single( encoded_val, self.exponents(), ))) } DType::Float(FloatWidth::_64, _) => { let encoded_val: i64 = encoded_val.try_into().unwrap(); - Ok(ScalarRef::from(::decode_single( + Ok(Scalar::from(::decode_single( encoded_val, self.exponents(), ))) diff --git a/vortex-array/src/array/bool/compute.rs b/vortex-array/src/array/bool/compute.rs index 44dc7417c1..24dff6a6a4 100644 --- a/vortex-array/src/array/bool/compute.rs +++ b/vortex-array/src/array/bool/compute.rs @@ -1,3 +1,6 @@ +use arrow::buffer::BooleanBuffer; +use itertools::Itertools; + use crate::array::bool::BoolArray; use crate::array::downcast::DowncastArrayBuiltin; use crate::array::{Array, ArrayRef, CloneOptionalArray}; @@ -7,9 +10,7 @@ use crate::compute::fill::FillForwardFn; use crate::compute::scalar_at::ScalarAtFn; use crate::compute::ArrayCompute; use crate::error::VortexResult; -use crate::scalar::{NullableScalar, Scalar, ScalarRef}; -use arrow::buffer::BooleanBuffer; -use itertools::Itertools; +use crate::scalar::{BoolScalar, Scalar}; impl ArrayCompute for BoolArray { fn as_contiguous(&self) -> Option<&dyn AsContiguousFn> { @@ -68,11 +69,11 @@ impl CastBoolFn for BoolArray { } impl ScalarAtFn for BoolArray { - fn scalar_at(&self, index: usize) -> VortexResult { + fn scalar_at(&self, index: usize) -> VortexResult { if self.is_valid(index) { Ok(self.buffer.value(index).into()) } else { - Ok(NullableScalar::none(self.dtype().clone()).boxed()) + Ok(BoolScalar::new(None).into()) } } } diff --git a/vortex-array/src/array/bool/mod.rs b/vortex-array/src/array/bool/mod.rs index 33fad93f54..65ccd4cb4e 100644 --- a/vortex-array/src/array/bool/mod.rs +++ b/vortex-array/src/array/bool/mod.rs @@ -53,6 +53,13 @@ impl BoolArray { .unwrap_or(true) } + pub fn null(n: usize) -> Self { + BoolArray::new( + BooleanBuffer::from(vec![false; n]), + Some(BoolArray::from(vec![false; n]).boxed()), + ) + } + #[inline] pub fn buffer(&self) -> &BooleanBuffer { &self.buffer diff --git a/vortex-array/src/array/chunked/compute.rs b/vortex-array/src/array/chunked/compute.rs index 2caa7d543d..30a6c3278a 100644 --- a/vortex-array/src/array/chunked/compute.rs +++ b/vortex-array/src/array/chunked/compute.rs @@ -5,7 +5,7 @@ use crate::compute::as_contiguous::{as_contiguous, AsContiguousFn}; use crate::compute::scalar_at::{scalar_at, ScalarAtFn}; use crate::compute::ArrayCompute; use crate::error::VortexResult; -use crate::scalar::ScalarRef; +use crate::scalar::Scalar; use itertools::Itertools; impl ArrayCompute for ChunkedArray { @@ -31,7 +31,7 @@ impl AsContiguousFn for ChunkedArray { } impl ScalarAtFn for ChunkedArray { - fn scalar_at(&self, index: usize) -> VortexResult { + fn scalar_at(&self, index: usize) -> VortexResult { let (chunk_index, chunk_offset) = self.find_physical_location(index); scalar_at(self.chunks[chunk_index].as_ref(), chunk_offset) } diff --git a/vortex-array/src/array/constant/compute/mod.rs b/vortex-array/src/array/constant/compute/mod.rs index 7cf87841f4..75b8d9f053 100644 --- a/vortex-array/src/array/constant/compute/mod.rs +++ b/vortex-array/src/array/constant/compute/mod.rs @@ -4,7 +4,7 @@ use crate::compute::scalar_at::ScalarAtFn; use crate::compute::take::TakeFn; use crate::compute::ArrayCompute; use crate::error::VortexResult; -use crate::scalar::ScalarRef; +use crate::scalar::Scalar; impl ArrayCompute for ConstantArray { fn scalar_at(&self) -> Option<&dyn ScalarAtFn> { @@ -17,13 +17,13 @@ impl ArrayCompute for ConstantArray { } impl ScalarAtFn for ConstantArray { - fn scalar_at(&self, _index: usize) -> VortexResult { - Ok(dyn_clone::clone_box(self.scalar())) + fn scalar_at(&self, _index: usize) -> VortexResult { + Ok(self.scalar().clone()) } } impl TakeFn for ConstantArray { fn take(&self, indices: &dyn Array) -> VortexResult { - Ok(ConstantArray::new(dyn_clone::clone_box(self.scalar()), indices.len()).boxed()) + Ok(ConstantArray::new(self.scalar().clone(), indices.len()).boxed()) } } diff --git a/vortex-array/src/array/constant/mod.rs b/vortex-array/src/array/constant/mod.rs index 093a93cd96..bc5202acdd 100644 --- a/vortex-array/src/array/constant/mod.rs +++ b/vortex-array/src/array/constant/mod.rs @@ -1,18 +1,19 @@ use std::any::Any; use std::sync::{Arc, RwLock}; -use arrow::array::Datum; use linkme::distributed_slice; +use crate::array::bool::BoolArray; +use crate::array::primitive::PrimitiveArray; use crate::array::{ check_slice_bounds, Array, ArrayRef, ArrowIterator, Encoding, EncodingId, EncodingRef, ENCODINGS, }; -use crate::arrow::compute::repeat; use crate::dtype::DType; use crate::error::VortexResult; use crate::formatter::{ArrayDisplay, ArrayFormatter}; -use crate::scalar::{Scalar, ScalarRef}; +use crate::match_each_native_ptype; +use crate::scalar::{PScalar, Scalar}; use crate::serde::{ArraySerde, EncodingSerde}; use crate::stats::{Stats, StatsSet}; @@ -22,13 +23,13 @@ mod stats; #[derive(Debug, Clone)] pub struct ConstantArray { - scalar: ScalarRef, + scalar: Scalar, length: usize, stats: Arc>, } impl ConstantArray { - pub fn new(scalar: ScalarRef, length: usize) -> Self { + pub fn new(scalar: Scalar, length: usize) -> Self { Self { scalar, length, @@ -36,8 +37,8 @@ impl ConstantArray { } } - pub fn scalar(&self) -> &dyn Scalar { - self.scalar.as_ref() + pub fn scalar(&self) -> &Scalar { + &self.scalar } } @@ -78,8 +79,38 @@ impl Array for ConstantArray { } fn iter_arrow(&self) -> Box { - let arrow_scalar: Box = self.scalar.as_ref().into(); - Box::new(std::iter::once(repeat(arrow_scalar.as_ref(), self.length))) + let plain_array = match self.scalar() { + Scalar::Bool(b) => { + if let Some(bv) = b.value() { + BoolArray::from(vec![bv; self.len()]).boxed() + } else { + BoolArray::null(self.len()).boxed() + } + } + Scalar::Primitive(p) => { + if let Some(ps) = p.value() { + match ps { + PScalar::U8(p) => PrimitiveArray::from_value(p, self.len()).boxed(), + PScalar::U16(p) => PrimitiveArray::from_value(p, self.len()).boxed(), + PScalar::U32(p) => PrimitiveArray::from_value(p, self.len()).boxed(), + PScalar::U64(p) => PrimitiveArray::from_value(p, self.len()).boxed(), + PScalar::I8(p) => PrimitiveArray::from_value(p, self.len()).boxed(), + PScalar::I16(p) => PrimitiveArray::from_value(p, self.len()).boxed(), + PScalar::I32(p) => PrimitiveArray::from_value(p, self.len()).boxed(), + PScalar::I64(p) => PrimitiveArray::from_value(p, self.len()).boxed(), + PScalar::F16(p) => PrimitiveArray::from_value(p, self.len()).boxed(), + PScalar::F32(p) => PrimitiveArray::from_value(p, self.len()).boxed(), + PScalar::F64(p) => PrimitiveArray::from_value(p, self.len()).boxed(), + } + } else { + match_each_native_ptype!(p.ptype(), |$P| { + PrimitiveArray::null::<$P>(self.len()).boxed() + }) + } + } + _ => panic!("Unsupported scalar type {}", self.dtype()), + }; + plain_array.iter_arrow() } fn slice(&self, start: usize, stop: usize) -> VortexResult { diff --git a/vortex-array/src/array/constant/serde.rs b/vortex-array/src/array/constant/serde.rs index 47c7a46d32..f01098bcbd 100644 --- a/vortex-array/src/array/constant/serde.rs +++ b/vortex-array/src/array/constant/serde.rs @@ -24,12 +24,12 @@ mod test { use crate::array::constant::ConstantArray; use crate::array::downcast::DowncastArrayBuiltin; use crate::array::Array; - use crate::scalar::NullableScalarOption; + use crate::scalar::{PScalar, PrimitiveScalar}; use crate::serde::test::roundtrip_array; #[test] fn roundtrip() { - let arr = ConstantArray::new(NullableScalarOption(Some(42)).into(), 100); + let arr = ConstantArray::new(PrimitiveScalar::some(PScalar::I32(42)).into(), 100); let read_arr = roundtrip_array(arr.as_ref()).unwrap(); assert_eq!(arr.scalar(), read_arr.as_constant().scalar()); diff --git a/vortex-array/src/array/constant/stats.rs b/vortex-array/src/array/constant/stats.rs index 14eeefad12..744f401946 100644 --- a/vortex-array/src/array/constant/stats.rs +++ b/vortex-array/src/array/constant/stats.rs @@ -2,34 +2,31 @@ use std::collections::HashMap; use crate::array::constant::ConstantArray; use crate::array::Array; -use crate::dtype::{DType, Nullability}; +use crate::dtype::DType; use crate::error::VortexResult; -use crate::scalar::{BoolScalar, PScalar, Scalar}; +use crate::scalar::{PScalar, PrimitiveScalar, Scalar}; use crate::stats::{Stat, StatsCompute, StatsSet}; impl StatsCompute for ConstantArray { fn compute(&self, _stat: &Stat) -> VortexResult { let mut m = HashMap::from([ - (Stat::Max, dyn_clone::clone_box(self.scalar())), - (Stat::Min, dyn_clone::clone_box(self.scalar())), + (Stat::Max, self.scalar().clone()), + (Stat::Min, self.scalar().clone()), (Stat::IsConstant, true.into()), (Stat::IsSorted, true.into()), (Stat::RunCount, 1.into()), ]); - if matches!(self.dtype(), &DType::Bool(Nullability::NonNullable)) { + if matches!(self.dtype(), &DType::Bool(_)) { + let Scalar::Bool(b) = self.scalar() else { + unreachable!("Got bool dtype without bool scalar") + }; m.insert( Stat::TrueCount, - PScalar::U64( - self.len() as u64 - * self - .scalar() - .as_any() - .downcast_ref::() - .unwrap() - .value() as u64, - ) - .boxed(), + PrimitiveScalar::some(PScalar::U64( + self.len() as u64 * b.value().map(|v| v as u64).unwrap_or(0), + )) + .into(), ); } diff --git a/vortex-array/src/array/primitive/compute/scalar_at.rs b/vortex-array/src/array/primitive/compute/scalar_at.rs index bf550ef10c..8e0e3d4c78 100644 --- a/vortex-array/src/array/primitive/compute/scalar_at.rs +++ b/vortex-array/src/array/primitive/compute/scalar_at.rs @@ -1,16 +1,15 @@ use crate::array::primitive::PrimitiveArray; -use crate::array::Array; use crate::compute::scalar_at::ScalarAtFn; use crate::error::VortexResult; use crate::match_each_native_ptype; -use crate::scalar::{NullableScalar, Scalar, ScalarRef}; +use crate::scalar::{PrimitiveScalar, Scalar}; impl ScalarAtFn for PrimitiveArray { - fn scalar_at(&self, index: usize) -> VortexResult { + fn scalar_at(&self, index: usize) -> VortexResult { if self.is_valid(index) { Ok(match_each_native_ptype!(self.ptype, |$T| self.typed_data::<$T>()[index].into())) } else { - Ok(NullableScalar::none(self.dtype().clone()).boxed()) + Ok(PrimitiveScalar::none(self.ptype).into()) } } } diff --git a/vortex-array/src/array/primitive/compute/search_sorted.rs b/vortex-array/src/array/primitive/compute/search_sorted.rs index bcc8396369..4e0993b883 100644 --- a/vortex-array/src/array/primitive/compute/search_sorted.rs +++ b/vortex-array/src/array/primitive/compute/search_sorted.rs @@ -6,7 +6,7 @@ use crate::ptype::NativePType; use crate::scalar::Scalar; impl SearchSortedFn for PrimitiveArray { - fn search_sorted(&self, value: &dyn Scalar, side: SearchSortedSide) -> VortexResult { + fn search_sorted(&self, value: &Scalar, side: SearchSortedSide) -> VortexResult { match_each_native_ptype!(self.ptype(), |$T| { let pvalue: $T = value.try_into()?; Ok(search_sorted(self.typed_data::<$T>(), pvalue, side)) diff --git a/vortex-array/src/array/primitive/mod.rs b/vortex-array/src/array/primitive/mod.rs index 4be48ed899..c0b14af1d3 100644 --- a/vortex-array/src/array/primitive/mod.rs +++ b/vortex-array/src/array/primitive/mod.rs @@ -11,6 +11,7 @@ use arrow::array::{make_array, ArrayData, AsArray}; use arrow::buffer::{Buffer, NullBuffer, ScalarBuffer}; use linkme::distributed_slice; +use crate::array::bool::BoolArray; use crate::array::{ check_slice_bounds, check_validity_buffer, Array, ArrayRef, ArrowIterator, Encoding, EncodingId, EncodingRef, ENCODINGS, @@ -98,6 +99,17 @@ impl PrimitiveArray { .unwrap_or(true) } + pub fn from_value(value: T, n: usize) -> Self { + PrimitiveArray::from(iter::repeat(value).take(n).collect::>()) + } + + pub fn null(n: usize) -> Self { + PrimitiveArray::from_nullable( + iter::repeat(T::zero()).take(n).collect::>(), + Some(BoolArray::from(vec![false; n]).boxed()), + ) + } + #[inline] pub fn ptype(&self) -> &PType { &self.ptype diff --git a/vortex-array/src/array/primitive/stats.rs b/vortex-array/src/array/primitive/stats.rs index 80ee311f83..a577ea74d6 100644 --- a/vortex-array/src/array/primitive/stats.rs +++ b/vortex-array/src/array/primitive/stats.rs @@ -1,13 +1,14 @@ -use arrow::buffer::BooleanBuffer; use std::collections::HashMap; use std::mem::size_of; +use arrow::buffer::BooleanBuffer; + use crate::array::primitive::PrimitiveArray; use crate::compute::cast::cast_bool; use crate::error::VortexResult; use crate::match_each_native_ptype; use crate::ptype::NativePType; -use crate::scalar::{ListScalarVec, NullableScalar, PScalar, Scalar}; +use crate::scalar::{ListScalarVec, PScalar}; use crate::stats::{Stat, StatsCompute, StatsSet}; impl StatsCompute for PrimitiveArray { @@ -54,8 +55,8 @@ impl<'a, T: NativePType> StatsCompute for NullableValues<'a, T> { if first_non_null.is_none() { return Ok(StatsSet::from(HashMap::from([ - (Stat::Min, NullableScalar::none(T::PTYPE.into()).boxed()), - (Stat::Max, NullableScalar::none(T::PTYPE.into()).boxed()), + (Stat::Min, Option::::None.into()), + (Stat::Max, Option::::None.into()), (Stat::IsConstant, true.into()), (Stat::IsSorted, true.into()), (Stat::IsStrictSorted, true.into()), diff --git a/vortex-array/src/array/sparse/compute.rs b/vortex-array/src/array/sparse/compute.rs index 91de0d9a48..98330c3480 100644 --- a/vortex-array/src/array/sparse/compute.rs +++ b/vortex-array/src/array/sparse/compute.rs @@ -1,3 +1,5 @@ +use itertools::Itertools; + use crate::array::downcast::DowncastArrayBuiltin; use crate::array::sparse::SparseArray; use crate::array::{Array, ArrayRef}; @@ -6,8 +8,7 @@ use crate::compute::scalar_at::{scalar_at, ScalarAtFn}; use crate::compute::search_sorted::{search_sorted, SearchSortedSide}; use crate::compute::ArrayCompute; use crate::error::VortexResult; -use crate::scalar::{NullableScalar, Scalar, ScalarRef}; -use itertools::Itertools; +use crate::scalar::Scalar; impl ArrayCompute for SparseArray { fn as_contiguous(&self) -> Option<&dyn AsContiguousFn> { @@ -43,7 +44,7 @@ impl AsContiguousFn for SparseArray { } impl ScalarAtFn for SparseArray { - fn scalar_at(&self, index: usize) -> VortexResult { + fn scalar_at(&self, index: usize) -> VortexResult { // Check whether `true_patch_index` exists in the patch index array // First, get the index of the patch index array that is the first index // greater than or equal to the true index @@ -52,13 +53,13 @@ impl ScalarAtFn for SparseArray { // If the value at this index is equal to the true index, then it exists in the patch index array // and we should return the value at the corresponding index in the patch values array scalar_at(self.indices(), idx) - .or_else(|_| Ok(NullableScalar::none(self.values().dtype().clone()).boxed())) + .or_else(|_| Ok(Scalar::null(self.values().dtype()))) .and_then(usize::try_from) .and_then(|patch_index| { if patch_index == true_patch_index { scalar_at(self.values(), idx) } else { - Ok(NullableScalar::none(self.values().dtype().clone()).boxed()) + Ok(Scalar::null(self.values().dtype())) } }) }) diff --git a/vortex-array/src/array/struct_/compute.rs b/vortex-array/src/array/struct_/compute.rs index a7281ef1c3..520dcf634c 100644 --- a/vortex-array/src/array/struct_/compute.rs +++ b/vortex-array/src/array/struct_/compute.rs @@ -1,3 +1,5 @@ +use itertools::Itertools; + use crate::array::downcast::DowncastArrayBuiltin; use crate::array::struct_::StructArray; use crate::array::{Array, ArrayRef}; @@ -5,8 +7,7 @@ use crate::compute::as_contiguous::{as_contiguous, AsContiguousFn}; use crate::compute::scalar_at::{scalar_at, ScalarAtFn}; use crate::compute::ArrayCompute; use crate::error::VortexResult; -use crate::scalar::{Scalar, ScalarRef, StructScalar}; -use itertools::Itertools; +use crate::scalar::{Scalar, StructScalar}; impl ArrayCompute for StructArray { fn as_contiguous(&self) -> Option<&dyn AsContiguousFn> { @@ -38,7 +39,7 @@ impl AsContiguousFn for StructArray { } impl ScalarAtFn for StructArray { - fn scalar_at(&self, index: usize) -> VortexResult { + fn scalar_at(&self, index: usize) -> VortexResult { Ok(StructScalar::new( self.dtype.clone(), self.fields @@ -46,6 +47,6 @@ impl ScalarAtFn for StructArray { .map(|field| scalar_at(field.as_ref(), index)) .try_collect()?, ) - .boxed()) + .into()) } } diff --git a/vortex-array/src/array/typed/compute.rs b/vortex-array/src/array/typed/compute.rs index e82d8364f1..c90ccda12e 100644 --- a/vortex-array/src/array/typed/compute.rs +++ b/vortex-array/src/array/typed/compute.rs @@ -5,7 +5,7 @@ use crate::compute::as_contiguous::{as_contiguous, AsContiguousFn}; use crate::compute::scalar_at::{scalar_at, ScalarAtFn}; use crate::compute::ArrayCompute; use crate::error::VortexResult; -use crate::scalar::ScalarRef; +use crate::scalar::Scalar; use itertools::Itertools; impl ArrayCompute for TypedArray { @@ -34,8 +34,8 @@ impl AsContiguousFn for TypedArray { } impl ScalarAtFn for TypedArray { - fn scalar_at(&self, index: usize) -> VortexResult { + fn scalar_at(&self, index: usize) -> VortexResult { let underlying = scalar_at(self.array.as_ref(), index)?; - underlying.as_ref().cast(self.dtype()) + underlying.cast(self.dtype()) } } diff --git a/vortex-array/src/array/typed/mod.rs b/vortex-array/src/array/typed/mod.rs index 5965bc46bf..f3733bca88 100644 --- a/vortex-array/src/array/typed/mod.rs +++ b/vortex-array/src/array/typed/mod.rs @@ -164,7 +164,7 @@ mod test { use crate::array::Array; use crate::compute::scalar_at::scalar_at; use crate::dtype::{DType, Nullability, TimeUnit}; - use crate::scalar::{LocalTimeScalar, PScalar, Scalar}; + use crate::scalar::{LocalTimeScalar, PScalar, PrimitiveScalar}; #[test] pub fn scalar() { @@ -173,12 +173,20 @@ mod test { DType::LocalTime(TimeUnit::Us, Nullability::NonNullable), ); assert_eq!( - scalar_at(arr.as_ref(), 0).unwrap().as_ref(), - &LocalTimeScalar::new(PScalar::U64(64_799_000_000), TimeUnit::Us) as &dyn Scalar + scalar_at(arr.as_ref(), 0).unwrap(), + LocalTimeScalar::new( + PrimitiveScalar::some(PScalar::U64(64_799_000_000)), + TimeUnit::Us + ) + .into() ); assert_eq!( - scalar_at(arr.as_ref(), 1).unwrap().as_ref(), - &LocalTimeScalar::new(PScalar::U64(43_000_000_000), TimeUnit::Us) as &dyn Scalar + scalar_at(arr.as_ref(), 1).unwrap(), + LocalTimeScalar::new( + PrimitiveScalar::some(PScalar::U64(43_000_000_000)), + TimeUnit::Us + ) + .into() ); } diff --git a/vortex-array/src/array/varbin/compute.rs b/vortex-array/src/array/varbin/compute.rs index 22ae1b6ca5..ed4b7d6008 100644 --- a/vortex-array/src/array/varbin/compute.rs +++ b/vortex-array/src/array/varbin/compute.rs @@ -1,3 +1,5 @@ +use itertools::Itertools; + use crate::array::bool::BoolArray; use crate::array::downcast::DowncastArrayBuiltin; use crate::array::primitive::PrimitiveArray; @@ -10,8 +12,7 @@ use crate::compute::ArrayCompute; use crate::dtype::DType; use crate::error::VortexResult; use crate::ptype::PType; -use crate::scalar::{NullableScalar, Scalar, ScalarRef}; -use itertools::Itertools; +use crate::scalar::{BinaryScalar, Scalar, Utf8Scalar}; impl ArrayCompute for VarBinArray { fn as_contiguous(&self) -> Option<&dyn AsContiguousFn> { @@ -66,7 +67,7 @@ impl AsContiguousFn for VarBinArray { } impl ScalarAtFn for VarBinArray { - fn scalar_at(&self, index: usize) -> VortexResult { + fn scalar_at(&self, index: usize) -> VortexResult { if self.is_valid(index) { self.bytes_at(index).map(|bytes| { if matches!(self.dtype, DType::Utf8(_)) { @@ -75,8 +76,10 @@ impl ScalarAtFn for VarBinArray { bytes.into() } }) + } else if matches!(self.dtype, DType::Utf8(_)) { + Ok(Utf8Scalar::new(None).into()) } else { - Ok(NullableScalar::none(self.dtype.clone()).boxed()) + Ok(BinaryScalar::new(None).into()) } } } diff --git a/vortex-array/src/array/varbinview/compute.rs b/vortex-array/src/array/varbinview/compute.rs index ae4b3949ee..a6f440b111 100644 --- a/vortex-array/src/array/varbinview/compute.rs +++ b/vortex-array/src/array/varbinview/compute.rs @@ -1,9 +1,10 @@ use crate::array::varbinview::VarBinViewArray; +use crate::array::Array; use crate::compute::scalar_at::ScalarAtFn; use crate::compute::ArrayCompute; use crate::dtype::DType; use crate::error::VortexResult; -use crate::scalar::{NullableScalar, Scalar, ScalarRef}; +use crate::scalar::Scalar; impl ArrayCompute for VarBinViewArray { fn scalar_at(&self) -> Option<&dyn ScalarAtFn> { @@ -12,7 +13,7 @@ impl ArrayCompute for VarBinViewArray { } impl ScalarAtFn for VarBinViewArray { - fn scalar_at(&self, index: usize) -> VortexResult { + fn scalar_at(&self, index: usize) -> VortexResult { if self.is_valid(index) { self.bytes_at(index).map(|bytes| { if matches!(self.dtype, DType::Utf8(_)) { @@ -22,7 +23,7 @@ impl ScalarAtFn for VarBinViewArray { } }) } else { - Ok(NullableScalar::none(self.dtype.clone()).boxed()) + Ok(Scalar::null(self.dtype())) } } } diff --git a/vortex-array/src/arrow/compute/mod.rs b/vortex-array/src/arrow/compute/mod.rs deleted file mode 100644 index 9ca7476da5..0000000000 --- a/vortex-array/src/arrow/compute/mod.rs +++ /dev/null @@ -1,3 +0,0 @@ -pub use repeat::*; - -mod repeat; diff --git a/vortex-array/src/arrow/compute/repeat.rs b/vortex-array/src/arrow/compute/repeat.rs deleted file mode 100644 index 662775e721..0000000000 --- a/vortex-array/src/arrow/compute/repeat.rs +++ /dev/null @@ -1,79 +0,0 @@ -use std::sync::Arc; - -use arrow::array::cast::AsArray; -use arrow::array::types::{ - Float16Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type, - UInt32Type, UInt64Type, UInt8Type, -}; -use arrow::array::{ArrayRef, ArrowPrimitiveType, BooleanArray, Datum, NullArray, PrimitiveArray}; -use arrow::buffer::BooleanBuffer; -use arrow::datatypes::DataType; - -macro_rules! repeat_primitive { - ($arrow_type:ty, $arr:expr, $n:expr) => {{ - if $arr.is_null(0) { - return repeat_primitive::<$arrow_type>(None, $n) as ArrayRef; - } - - repeat_primitive::<$arrow_type>(Some($arr.as_primitive::<$arrow_type>().value(0)), $n) - as ArrayRef - }}; -} - -pub fn repeat(scalar: &dyn Datum, n: usize) -> ArrayRef { - let (arr, is_scalar) = scalar.get(); - assert!(is_scalar, "Datum was not a scalar"); - match arr.data_type() { - DataType::Null => Arc::new(NullArray::new(n)), - DataType::Boolean => { - if arr.is_valid(0) { - if arr.as_boolean().value(0) { - Arc::new(BooleanArray::from(BooleanBuffer::new_set(n))) - } else { - Arc::new(BooleanArray::from(BooleanBuffer::new_unset(n))) - } - } else { - Arc::new(BooleanArray::new_null(n)) - } - } - DataType::UInt8 => repeat_primitive!(UInt8Type, arr, n), - DataType::UInt16 => repeat_primitive!(UInt16Type, arr, n), - DataType::UInt32 => repeat_primitive!(UInt32Type, arr, n), - DataType::UInt64 => repeat_primitive!(UInt64Type, arr, n), - DataType::Int8 => repeat_primitive!(Int8Type, arr, n), - DataType::Int16 => repeat_primitive!(Int16Type, arr, n), - DataType::Int32 => repeat_primitive!(Int32Type, arr, n), - DataType::Int64 => repeat_primitive!(Int64Type, arr, n), - DataType::Float16 => repeat_primitive!(Float16Type, arr, n), - DataType::Float32 => repeat_primitive!(Float32Type, arr, n), - DataType::Float64 => repeat_primitive!(Float64Type, arr, n), - _ => todo!("Not implemented yet"), - } -} - -fn repeat_primitive( - value: Option, - n: usize, -) -> Arc> { - Arc::new( - value - .map(|v| PrimitiveArray::from_value(v, n)) - .unwrap_or_else(|| PrimitiveArray::new_null(n)), - ) -} - -#[cfg(test)] -mod test { - use crate::arrow::compute::repeat; - use arrow::array::cast::AsArray; - use arrow::array::types::UInt64Type; - use arrow::array::{Scalar, UInt64Array}; - - #[test] - fn test_repeat() { - let scalar = Scalar::new(UInt64Array::from(vec![47])); - let array = repeat(&scalar, 100); - assert_eq!(array.len(), 100); - assert_eq!(array.as_primitive::().value(50), 47); - } -} diff --git a/vortex-array/src/arrow/mod.rs b/vortex-array/src/arrow/mod.rs index e2c9df1c2f..2884ea84d3 100644 --- a/vortex-array/src/arrow/mod.rs +++ b/vortex-array/src/arrow/mod.rs @@ -4,7 +4,6 @@ use itertools::Itertools; use crate::array::ArrowIterator; pub mod aligned_iter; -pub mod compute; pub mod convert; pub trait CombineChunks { diff --git a/vortex-array/src/compute/add.rs b/vortex-array/src/compute/add.rs index f1d4158a25..f020b75610 100644 --- a/vortex-array/src/compute/add.rs +++ b/vortex-array/src/compute/add.rs @@ -1,7 +1,7 @@ use crate::array::constant::ConstantArray; use crate::array::{Array, ArrayKind, ArrayRef}; use crate::error::{VortexError, VortexResult}; -use crate::scalar::{Scalar, ScalarRef}; +use crate::scalar::Scalar; // TODO(ngates): convert this to arithmetic operations with macro over the kernel. pub fn add(lhs: &dyn Array, rhs: &dyn Array) -> VortexResult { @@ -21,7 +21,7 @@ pub fn add(lhs: &dyn Array, rhs: &dyn Array) -> VortexResult { } } -pub fn add_scalar(lhs: &dyn Array, rhs: &dyn Scalar) -> VortexResult { +pub fn add_scalar(lhs: &dyn Array, rhs: &Scalar) -> VortexResult { match ArrayKind::from(lhs) { ArrayKind::Constant(lhs) => { Ok(ConstantArray::new(add_scalars(lhs.scalar(), rhs)?, lhs.len()).boxed()) @@ -30,7 +30,7 @@ pub fn add_scalar(lhs: &dyn Array, rhs: &dyn Scalar) -> VortexResult { } } -pub fn add_scalars(_lhs: &dyn Scalar, _rhs: &dyn Scalar) -> VortexResult { +pub fn add_scalars(_lhs: &Scalar, _rhs: &Scalar) -> VortexResult { // Might need to improve this implementation... Ok(24.into()) } diff --git a/vortex-array/src/compute/repeat.rs b/vortex-array/src/compute/repeat.rs index 9b94508a63..3d383fc1c7 100644 --- a/vortex-array/src/compute/repeat.rs +++ b/vortex-array/src/compute/repeat.rs @@ -2,19 +2,18 @@ use crate::array::constant::ConstantArray; use crate::array::{Array, ArrayRef}; use crate::scalar::Scalar; -pub fn repeat(scalar: &dyn Scalar, n: usize) -> ArrayRef { - ConstantArray::new(dyn_clone::clone_box(scalar), n).boxed() +pub fn repeat(scalar: &Scalar, n: usize) -> ArrayRef { + ConstantArray::new(scalar.clone(), n).boxed() } #[cfg(test)] mod test { use super::*; - use crate::scalar::ScalarRef; #[test] fn test_repeat() { - let scalar: ScalarRef = 47.into(); - let array = repeat(scalar.as_ref(), 100); + let scalar: Scalar = 47.into(); + let array = repeat(&scalar, 100); assert_eq!(array.len(), 100); } } diff --git a/vortex-array/src/compute/scalar_at.rs b/vortex-array/src/compute/scalar_at.rs index f579246cfd..c3d601dbd0 100644 --- a/vortex-array/src/compute/scalar_at.rs +++ b/vortex-array/src/compute/scalar_at.rs @@ -1,12 +1,12 @@ use crate::array::Array; use crate::error::{VortexError, VortexResult}; -use crate::scalar::ScalarRef; +use crate::scalar::Scalar; pub trait ScalarAtFn { - fn scalar_at(&self, index: usize) -> VortexResult; + fn scalar_at(&self, index: usize) -> VortexResult; } -pub fn scalar_at(array: &dyn Array, index: usize) -> VortexResult { +pub fn scalar_at(array: &dyn Array, index: usize) -> VortexResult { if index >= array.len() { return Err(VortexError::OutOfBounds(index, 0, array.len())); } diff --git a/vortex-array/src/compute/search_sorted.rs b/vortex-array/src/compute/search_sorted.rs index d6e836be94..b002d8fb9d 100644 --- a/vortex-array/src/compute/search_sorted.rs +++ b/vortex-array/src/compute/search_sorted.rs @@ -1,6 +1,6 @@ use crate::array::Array; use crate::error::{VortexError, VortexResult}; -use crate::scalar::{Scalar, ScalarRef}; +use crate::scalar::Scalar; pub enum SearchSortedSide { Left, @@ -8,10 +8,10 @@ pub enum SearchSortedSide { } pub trait SearchSortedFn { - fn search_sorted(&self, value: &dyn Scalar, side: SearchSortedSide) -> VortexResult; + fn search_sorted(&self, value: &Scalar, side: SearchSortedSide) -> VortexResult; } -pub fn search_sorted>( +pub fn search_sorted>( array: &dyn Array, target: T, side: SearchSortedSide, @@ -19,7 +19,7 @@ pub fn search_sorted>( let scalar = target.into().cast(array.dtype())?; array .search_sorted() - .map(|f| f.search_sorted(scalar.as_ref(), side)) + .map(|f| f.search_sorted(&scalar, side)) .unwrap_or_else(|| { Err(VortexError::NotImplemented( "search_sorted", diff --git a/vortex-array/src/encode.rs b/vortex-array/src/encode.rs index 020c2133d3..ac2dfcdd9a 100644 --- a/vortex-array/src/encode.rs +++ b/vortex-array/src/encode.rs @@ -29,7 +29,7 @@ use crate::array::varbin::VarBinArray; use crate::array::{Array, ArrayRef}; use crate::arrow::convert::TryIntoDType; use crate::ptype::PType; -use crate::scalar::{NullScalar, Scalar}; +use crate::scalar::NullScalar; impl From<&Buffer> for ArrayRef { fn from(value: &Buffer) -> Self { @@ -110,7 +110,7 @@ impl From<&ArrowStructArray> for ArrayRef { impl From<&ArrowNullArray> for ArrayRef { fn from(value: &ArrowNullArray) -> Self { - ConstantArray::new(NullScalar::new().boxed(), value.len()).boxed() + ConstantArray::new(NullScalar::new().into(), value.len()).boxed() } } diff --git a/vortex-array/src/polars.rs b/vortex-array/src/polars.rs new file mode 100644 index 0000000000..2f151d22cb --- /dev/null +++ b/vortex-array/src/polars.rs @@ -0,0 +1,93 @@ +use arrow::array::{Array as ArrowArray, ArrayRef as ArrowArrayRef}; +use polars_arrow::array::from_data; +use polars_core::prelude::{AnyValue, Series}; + +use crate::array::ArrowIterator; +use crate::dtype::DType; +use crate::scalar::{BinaryScalar, BoolScalar, PScalar, Scalar, Utf8Scalar}; + +pub trait IntoPolarsSeries { + fn into_polars(self) -> Series; +} + +impl IntoPolarsSeries for ArrowArrayRef { + fn into_polars(self) -> Series { + let polars_array = from_data(&self.to_data()); + ("array", polars_array).try_into().unwrap() + } +} + +impl IntoPolarsSeries for Vec { + fn into_polars(self) -> Series { + let chunks: Vec> = + self.iter().map(|a| from_data(&a.to_data())).collect(); + ("array", chunks).try_into().unwrap() + } +} + +impl IntoPolarsSeries for Box { + fn into_polars(self) -> Series { + let chunks: Vec> = + self.map(|a| from_data(&a.to_data())).collect(); + ("array", chunks).try_into().unwrap() + } +} + +pub trait IntoPolarsValue { + fn into_polars<'a>(self) -> AnyValue<'a>; +} + +impl IntoPolarsValue for Scalar { + fn into_polars<'a>(self) -> AnyValue<'a> { + if let Some(ns) = self.as_any().downcast_ref::() { + return match ns { + NullableScalar::Some(s, _) => s.as_ref().into_polars(), + NullableScalar::None(_) => AnyValue::Null, + }; + } + + match self.dtype() { + DType::Null => AnyValue::Null, + DType::Bool(_) => { + AnyValue::Boolean(self.as_any().downcast_ref::().unwrap().value()) + } + DType::Int(_, _, _) | DType::Float(_, _) => { + match self.as_any().downcast_ref::().unwrap() { + PScalar::U8(v) => AnyValue::UInt8(*v), + PScalar::U16(v) => AnyValue::UInt16(*v), + PScalar::U32(v) => AnyValue::UInt32(*v), + PScalar::U64(v) => AnyValue::UInt64(*v), + PScalar::I8(v) => AnyValue::Int8(*v), + PScalar::I16(v) => AnyValue::Int16(*v), + PScalar::I32(v) => AnyValue::Int32(*v), + PScalar::I64(v) => AnyValue::Int64(*v), + PScalar::F16(v) => AnyValue::Float32(v.to_f32()), + PScalar::F32(v) => AnyValue::Float32(*v), + PScalar::F64(v) => AnyValue::Float64(*v), + } + } + DType::Decimal(_, _, _) => todo!(), + DType::Utf8(_) => AnyValue::StringOwned( + self.as_any() + .downcast_ref::() + .unwrap() + .value() + .into(), + ), + DType::Binary(_) => AnyValue::BinaryOwned( + self.as_any() + .downcast_ref::() + .unwrap() + .value() + .clone(), + ), + DType::LocalTime(_, _) => todo!(), + DType::LocalDate(_) => todo!(), + DType::Instant(_, _) => todo!(), + DType::ZonedDateTime(_, _) => todo!(), + DType::Struct(_, _) => todo!(), + DType::List(_, _) => todo!(), + DType::Map(_, _, _) => todo!(), + } + } +} diff --git a/vortex-array/src/ptype.rs b/vortex-array/src/ptype.rs index 7463a53bb3..b1be9e820a 100644 --- a/vortex-array/src/ptype.rs +++ b/vortex-array/src/ptype.rs @@ -1,4 +1,4 @@ -use std::fmt::{Debug, Display}; +use std::fmt::{Debug, Display, Formatter}; use std::panic::RefUnwindSafe; use arrow::datatypes::ArrowNativeType; @@ -7,7 +7,7 @@ use num_traits::{Num, NumCast}; use crate::dtype::{DType, FloatWidth, IntWidth, Signedness}; use crate::error::{VortexError, VortexResult}; -use crate::scalar::{PScalar, ScalarRef}; +use crate::scalar::{PScalar, Scalar}; #[derive(Debug, Clone, Copy, PartialEq, PartialOrd, Eq, Hash)] pub enum PType { @@ -37,8 +37,8 @@ pub trait NativePType: + RefUnwindSafe + Num + NumCast - + Into - + TryFrom + + Into + + TryFrom + Into { const PTYPE: PType; @@ -149,6 +149,24 @@ impl PType { } } +impl Display for PType { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + PType::U8 => write!(f, "u8"), + PType::U16 => write!(f, "u16"), + PType::U32 => write!(f, "u32"), + PType::U64 => write!(f, "u64"), + PType::I8 => write!(f, "i8"), + PType::I16 => write!(f, "i16"), + PType::I32 => write!(f, "i32"), + PType::I64 => write!(f, "i64"), + PType::F16 => write!(f, "f16"), + PType::F32 => write!(f, "f32"), + PType::F64 => write!(f, "f64"), + } + } +} + impl TryFrom<&DType> for PType { type Error = VortexError; diff --git a/vortex-array/src/scalar/arrow.rs b/vortex-array/src/scalar/arrow.rs deleted file mode 100644 index 13ce1700c1..0000000000 --- a/vortex-array/src/scalar/arrow.rs +++ /dev/null @@ -1,72 +0,0 @@ -use arrow::array::types::{ - Float16Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type, - UInt32Type, UInt64Type, UInt8Type, -}; -use arrow::array::Scalar as ArrowScalar; -use arrow::array::{Datum, PrimitiveArray}; - -use crate::scalar::{PScalar, Scalar}; - -impl From<&dyn Scalar> for Box { - fn from(value: &dyn Scalar) -> Self { - if let Some(pscalar) = value.as_any().downcast_ref::() { - return match pscalar { - PScalar::U8(v) => { - Box::new(ArrowScalar::new(PrimitiveArray::::from(vec![ - *v, - ]))) - } - PScalar::U16(v) => { - Box::new(ArrowScalar::new(PrimitiveArray::::from(vec![ - *v, - ]))) - } - PScalar::U32(v) => { - Box::new(ArrowScalar::new(PrimitiveArray::::from(vec![ - *v, - ]))) - } - PScalar::U64(v) => { - Box::new(ArrowScalar::new(PrimitiveArray::::from(vec![ - *v, - ]))) - } - PScalar::I8(v) => { - Box::new(ArrowScalar::new(PrimitiveArray::::from(vec![*v]))) - } - PScalar::I16(v) => { - Box::new(ArrowScalar::new(PrimitiveArray::::from(vec![ - *v, - ]))) - } - PScalar::I32(v) => { - Box::new(ArrowScalar::new(PrimitiveArray::::from(vec![ - *v, - ]))) - } - PScalar::I64(v) => { - Box::new(ArrowScalar::new(PrimitiveArray::::from(vec![ - *v, - ]))) - } - PScalar::F16(v) => { - Box::new(ArrowScalar::new(PrimitiveArray::::from(vec![ - *v, - ]))) - } - PScalar::F32(v) => { - Box::new(ArrowScalar::new(PrimitiveArray::::from(vec![ - *v, - ]))) - } - PScalar::F64(v) => { - Box::new(ArrowScalar::new(PrimitiveArray::::from(vec![ - *v, - ]))) - } - }; - } - - todo!("implement other scalar types {:?}", value) - } -} diff --git a/vortex-array/src/scalar/binary.rs b/vortex-array/src/scalar/binary.rs index fd2f725087..3975ee1edd 100644 --- a/vortex-array/src/scalar/binary.rs +++ b/vortex-array/src/scalar/binary.rs @@ -1,97 +1,68 @@ +use std::fmt::{Display, Formatter}; + use crate::dtype::{DType, Nullability}; use crate::error::{VortexError, VortexResult}; -use crate::scalar::{Scalar, ScalarRef}; -use std::any::Any; -use std::fmt::{Display, Formatter}; +use crate::scalar::Scalar; #[derive(Debug, Clone, PartialEq, PartialOrd)] pub struct BinaryScalar { - value: Vec, + value: Option>, } impl BinaryScalar { - pub fn new(value: Vec) -> Self { + pub fn new(value: Option>) -> Self { Self { value } } - pub fn value(&self) -> &Vec { - &self.value - } -} - -impl Scalar for BinaryScalar { - #[inline] - fn as_any(&self) -> &dyn Any { - self - } - - #[inline] - fn into_any(self: Box) -> Box { - self - } - - #[inline] - fn as_nonnull(&self) -> Option<&dyn Scalar> { - Some(self) + pub fn none() -> Self { + Self { value: None } } - #[inline] - fn into_nonnull(self: Box) -> Option { - Some(self) + pub fn some(value: Vec) -> Self { + Self { value: Some(value) } } - #[inline] - fn boxed(self) -> ScalarRef { - Box::new(self) + pub fn value(&self) -> Option<&[u8]> { + self.value.as_deref() } #[inline] - fn dtype(&self) -> &DType { + pub fn dtype(&self) -> &DType { &DType::Binary(Nullability::NonNullable) } - fn cast(&self, _dtype: &DType) -> VortexResult { + pub fn cast(&self, _dtype: &DType) -> VortexResult { todo!() } - fn nbytes(&self) -> usize { - self.value.len() + pub fn nbytes(&self) -> usize { + self.value().map(|s| s.len()).unwrap_or(1) } } -impl From> for ScalarRef { +impl From> for Scalar { fn from(value: Vec) -> Self { - BinaryScalar::new(value).boxed() + BinaryScalar::new(Some(value)).into() } } -impl TryFrom for Vec { +impl TryFrom for Vec { type Error = VortexError; - fn try_from(value: ScalarRef) -> Result { - let dtype = value.dtype().clone(); - let scalar = value - .into_any() - .downcast::() - .map_err(|_| VortexError::InvalidDType(dtype))?; - Ok(scalar.value) - } -} - -impl TryFrom<&dyn Scalar> for Vec { - type Error = VortexError; - - fn try_from(value: &dyn Scalar) -> Result { - if let Some(scalar) = value.as_any().downcast_ref::() { - Ok(scalar.value.clone()) - } else { - Err(VortexError::InvalidDType(value.dtype().clone())) - } + fn try_from(value: Scalar) -> VortexResult { + let Scalar::Binary(b) = value else { + return Err(VortexError::InvalidDType(value.dtype().clone())); + }; + let dtype = b.dtype().clone(); + b.value.ok_or_else(|| VortexError::InvalidDType(dtype)) } } impl Display for BinaryScalar { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "bytes[{}]", self.value.len()) + match self.value() { + None => write!(f, "bytes[none]"), + Some(b) => write!(f, "bytes[{}]", b.len()), + } } } diff --git a/vortex-array/src/scalar/bool.rs b/vortex-array/src/scalar/bool.rs index ea7c9f56bf..fb4093e727 100644 --- a/vortex-array/src/scalar/bool.rs +++ b/vortex-array/src/scalar/bool.rs @@ -1,105 +1,74 @@ -use std::any::Any; use std::fmt::{Display, Formatter}; use crate::dtype::{DType, Nullability}; use crate::error::{VortexError, VortexResult}; -use crate::scalar::{NullableScalar, Scalar, ScalarRef}; +use crate::scalar::Scalar; #[derive(Debug, Clone, PartialEq, PartialOrd)] pub struct BoolScalar { - value: bool, + value: Option, } impl BoolScalar { - pub fn new(value: bool) -> Self { + pub fn new(value: Option) -> Self { Self { value } } - pub fn value(&self) -> bool { - self.value - } -} - -impl Scalar for BoolScalar { - #[inline] - fn as_any(&self) -> &dyn Any { - self + pub fn none() -> Self { + Self { value: None } } - #[inline] - fn into_any(self: Box) -> Box { - self + pub fn some(value: bool) -> Self { + Self { value: Some(value) } } - #[inline] - fn as_nonnull(&self) -> Option<&dyn Scalar> { - Some(self) - } - - #[inline] - fn into_nonnull(self: Box) -> Option { - Some(self) - } - - #[inline] - fn boxed(self) -> ScalarRef { - Box::new(self) + pub fn value(&self) -> Option { + self.value } #[inline] - fn dtype(&self) -> &DType { + pub fn dtype(&self) -> &DType { &DType::Bool(Nullability::NonNullable) } - fn cast(&self, dtype: &DType) -> VortexResult { + pub fn cast(&self, dtype: &DType) -> VortexResult { match dtype { - DType::Bool(Nullability::NonNullable) => Ok(self.clone().boxed()), - DType::Bool(Nullability::Nullable) => { - Ok(NullableScalar::some(self.clone().boxed()).boxed()) - } + DType::Bool(_) => Ok(self.clone().into()), _ => Err(VortexError::InvalidDType(dtype.clone())), } } - fn nbytes(&self) -> usize { + pub fn nbytes(&self) -> usize { 1 } } -impl From for ScalarRef { +impl From for Scalar { #[inline] fn from(value: bool) -> Self { - BoolScalar::new(value).boxed() + BoolScalar::new(Some(value)).into() } } -impl TryFrom for bool { +impl TryFrom for bool { type Error = VortexError; - #[inline] - fn try_from(value: ScalarRef) -> VortexResult { - value.as_ref().try_into() - } -} - -impl TryFrom<&dyn Scalar> for bool { - type Error = VortexError; + fn try_from(value: Scalar) -> VortexResult { + let Scalar::Bool(b) = value else { + return Err(VortexError::InvalidDType(value.dtype().clone())); + }; - fn try_from(value: &dyn Scalar) -> VortexResult { - if let Some(bool_scalar) = value - .as_nonnull() - .and_then(|v| v.as_any().downcast_ref::()) - { - Ok(bool_scalar.value()) - } else { - Err(VortexError::InvalidDType(value.dtype().clone())) - } + b.value() + .ok_or_else(|| VortexError::InvalidDType(b.dtype().clone())) } } impl Display for BoolScalar { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self.value) + match self.value() { + None => write!(f, "null"), + Some(b) => Display::fmt(&b, f), + } } } @@ -109,7 +78,7 @@ mod test { #[test] fn into_from() { - let scalar: ScalarRef = false.into(); - assert_eq!(scalar.as_ref().try_into(), Ok(false)); + let scalar: Scalar = false.into(); + assert_eq!(scalar.try_into(), Ok(false)); } } diff --git a/vortex-array/src/scalar/equal.rs b/vortex-array/src/scalar/equal.rs deleted file mode 100644 index 8ce8a7f3a2..0000000000 --- a/vortex-array/src/scalar/equal.rs +++ /dev/null @@ -1,61 +0,0 @@ -use std::sync::Arc; - -use crate::scalar::localtime::LocalTimeScalar; -use crate::scalar::{ - BinaryScalar, BoolScalar, NullableScalar, PScalar, Scalar, ScalarRef, StructScalar, Utf8Scalar, -}; - -impl PartialEq for dyn Scalar { - fn eq(&self, that: &dyn Scalar) -> bool { - equal(self, that) - } -} - -impl PartialEq for Arc { - fn eq(&self, that: &dyn Scalar) -> bool { - equal(&**self, that) - } -} - -impl PartialEq for ScalarRef { - fn eq(&self, that: &dyn Scalar) -> bool { - equal(self.as_ref(), that) - } -} - -impl Eq for dyn Scalar {} - -macro_rules! dyn_eq { - ($ty:ty, $lhs:expr, $rhs:expr) => {{ - let lhs = $lhs.as_any().downcast_ref::<$ty>().unwrap(); - let rhs = $rhs.as_any().downcast_ref::<$ty>().unwrap(); - lhs == rhs - }}; -} - -fn equal(lhs: &dyn Scalar, rhs: &dyn Scalar) -> bool { - if lhs.dtype() != rhs.dtype() { - return false; - } - - // If the dtypes are the same then both of the scalars are either nullable or plain scalar - if let Some(ls) = lhs.as_any().downcast_ref::() { - if let Some(rs) = rhs.as_any().downcast_ref::() { - return dyn_eq!(NullableScalar, ls, rs); - } else { - unreachable!("DTypes were equal, but only one was nullable") - } - } - - use crate::dtype::DType::*; - match lhs.dtype() { - Bool(_) => dyn_eq!(BoolScalar, lhs, rhs), - Int(_, _, _) => dyn_eq!(PScalar, lhs, rhs), - Float(_, _) => dyn_eq!(PScalar, lhs, rhs), - Struct(..) => dyn_eq!(StructScalar, lhs, rhs), - Utf8(_) => dyn_eq!(Utf8Scalar, lhs, rhs), - Binary(_) => dyn_eq!(BinaryScalar, lhs, rhs), - LocalTime(_, _) => dyn_eq!(LocalTimeScalar, lhs, rhs), - _ => todo!("Equal not yet implemented for {:?} {:?}", lhs, rhs), - } -} diff --git a/vortex-array/src/scalar/list.rs b/vortex-array/src/scalar/list.rs index 9c3b2435be..012d401e82 100644 --- a/vortex-array/src/scalar/list.rs +++ b/vortex-array/src/scalar/list.rs @@ -1,142 +1,117 @@ -use std::any::Any; use std::fmt::{Display, Formatter}; use itertools::Itertools; -use crate::dtype::{DType, Nullability}; +use crate::dtype::DType; use crate::error::{VortexError, VortexResult}; -use crate::scalar::{NullableScalar, Scalar, ScalarRef}; +use crate::scalar::Scalar; -#[derive(Debug, Clone, PartialEq)] +#[derive(Debug, Clone, PartialEq, PartialOrd)] pub struct ListScalar { dtype: DType, - values: Vec, + values: Option>, } impl ListScalar { #[inline] - pub fn new(dtype: DType, values: Vec) -> Self { + pub fn new(dtype: DType, values: Option>) -> Self { Self { dtype, values } } #[inline] - pub fn values(&self) -> &[ScalarRef] { - &self.values - } -} - -impl Scalar for ListScalar { - #[inline] - fn as_any(&self) -> &dyn Any { - self - } - #[inline] - fn into_any(self: Box) -> Box { - self - } - - #[inline] - fn as_nonnull(&self) -> Option<&dyn Scalar> { - Some(self) - } - - #[inline] - fn into_nonnull(self: Box) -> Option { - Some(self) + pub fn values(&self) -> Option<&[Scalar]> { + self.values.as_deref() } #[inline] - fn boxed(self) -> ScalarRef { - Box::new(self) - } - #[inline] - fn dtype(&self) -> &DType { + pub fn dtype(&self) -> &DType { &self.dtype } - fn cast(&self, dtype: &DType) -> VortexResult { + pub fn cast(&self, dtype: &DType) -> VortexResult { match dtype { DType::List(field_dtype, n) => { - let new_fields: Vec = self - .values - .iter() - .map(|field| field.cast(field_dtype)) - .try_collect()?; - - let new_type = if new_fields.is_empty() { - dtype.clone() + let new_fields: Option> = self + .values() + .map(|v| v.iter().map(|field| field.cast(field_dtype)).try_collect()) + .transpose()?; + + let new_type = if let Some(nf) = new_fields.as_ref() { + if nf.is_empty() { + dtype.clone() + } else { + DType::List(Box::new(nf[0].dtype().clone()), *n) + } } else { - DType::List(Box::new(new_fields[0].dtype().clone()), *n) + dtype.clone() }; - let list_scalar = ListScalar::new(new_type, new_fields).boxed(); - match n { - Nullability::NonNullable => Ok(list_scalar), - Nullability::Nullable => Ok(NullableScalar::some(list_scalar).boxed()), - } + Ok(ListScalar::new(new_type, new_fields).into()) } _ => Err(VortexError::InvalidDType(dtype.clone())), } } - fn nbytes(&self) -> usize { - self.values.iter().map(|s| s.nbytes()).sum() + pub fn nbytes(&self) -> usize { + self.values() + .map(|v| v.iter().map(|s| s.nbytes()).sum()) + .unwrap_or(0) } } #[derive(Debug, Clone, PartialEq)] pub struct ListScalarVec(pub Vec); -impl> From> for ScalarRef { +impl> From> for Scalar { fn from(value: ListScalarVec) -> Self { - let values: Vec = value.0.into_iter().map(|v| v.into()).collect(); + let values: Vec = value.0.into_iter().map(|v| v.into()).collect(); if values.is_empty() { panic!("Can't implicitly convert empty list into ListScalar"); } - ListScalar::new(values[0].dtype().clone(), values).boxed() + ListScalar::new(values[0].dtype().clone(), Some(values)).into() } } -impl> TryFrom<&dyn Scalar> for ListScalarVec { +impl> TryFrom for ListScalarVec { type Error = VortexError; - fn try_from(value: &dyn Scalar) -> Result { - if let Some(list_s) = value.as_any().downcast_ref::() { - Ok(ListScalarVec( - list_s - .values - .clone() - .into_iter() - .map(|v| v.try_into()) - .try_collect()?, - )) + fn try_from(value: Scalar) -> Result { + if let Scalar::List(ls) = value { + if let Some(vs) = ls.values { + Ok(ListScalarVec( + vs.into_iter().map(|v| v.try_into()).try_collect()?, + )) + } else { + Err(VortexError::InvalidDType(ls.dtype().clone())) + } } else { Err(VortexError::InvalidDType(value.dtype().clone())) } } } -impl> TryFrom for ListScalarVec { +impl<'a, T: TryFrom<&'a Scalar, Error = VortexError>> TryFrom<&'a Scalar> for ListScalarVec { type Error = VortexError; - fn try_from(value: ScalarRef) -> Result { - let value_dtype = value.dtype().clone(); - let list_s = value - .into_any() - .downcast::() - .map_err(|_| VortexError::InvalidDType(value_dtype))?; - - Ok(ListScalarVec( - list_s - .values - .into_iter() - .map(|v| v.try_into()) - .try_collect()?, - )) + fn try_from(value: &'a Scalar) -> Result { + if let Scalar::List(ls) = value { + if let Some(vs) = ls.values() { + Ok(ListScalarVec( + vs.iter().map(|v| v.try_into()).try_collect()?, + )) + } else { + Err(VortexError::InvalidDType(ls.dtype().clone())) + } + } else { + Err(VortexError::InvalidDType(value.dtype().clone())) + } } } impl Display for ListScalar { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self.values.iter().format(", ")) + match self.values() { + None => write!(f, ""), + Some(vs) => write!(f, "{}", vs.iter().format(", ")), + } } } diff --git a/vortex-array/src/scalar/localtime.rs b/vortex-array/src/scalar/localtime.rs index 16a533bfdf..512137084f 100644 --- a/vortex-array/src/scalar/localtime.rs +++ b/vortex-array/src/scalar/localtime.rs @@ -1,25 +1,25 @@ -use crate::dtype::{DType, Nullability, TimeUnit}; -use crate::error::VortexResult; -use crate::scalar::{PScalar, Scalar, ScalarRef}; -use std::any::Any; use std::cmp::Ordering; use std::fmt::{Display, Formatter}; +use crate::dtype::{DType, Nullability, TimeUnit}; +use crate::error::VortexResult; +use crate::scalar::{PrimitiveScalar, Scalar}; + #[derive(Debug, Clone, PartialEq)] pub struct LocalTimeScalar { - value: PScalar, + value: PrimitiveScalar, dtype: DType, } impl LocalTimeScalar { - pub fn new(value: PScalar, unit: TimeUnit) -> Self { + pub fn new(value: PrimitiveScalar, unit: TimeUnit) -> Self { Self { value, dtype: DType::LocalTime(unit, Nullability::NonNullable), } } - pub fn value(&self) -> &PScalar { + pub fn value(&self) -> &PrimitiveScalar { &self.value } @@ -29,44 +29,17 @@ impl LocalTimeScalar { }; u } -} - -impl Scalar for LocalTimeScalar { - #[inline] - fn as_any(&self) -> &dyn Any { - self - } - - #[inline] - fn into_any(self: Box) -> Box { - self - } - - #[inline] - fn as_nonnull(&self) -> Option<&dyn Scalar> { - Some(self) - } - - #[inline] - fn into_nonnull(self: Box) -> Option { - Some(self) - } - - #[inline] - fn boxed(self) -> ScalarRef { - Box::new(self) - } #[inline] - fn dtype(&self) -> &DType { + pub fn dtype(&self) -> &DType { &self.dtype } - fn cast(&self, _dtype: &DType) -> VortexResult { + pub fn cast(&self, _dtype: &DType) -> VortexResult { todo!() } - fn nbytes(&self) -> usize { + pub fn nbytes(&self) -> usize { self.value.nbytes() } } diff --git a/vortex-array/src/scalar/mod.rs b/vortex-array/src/scalar/mod.rs index ecfb0dc771..4bb0c78083 100644 --- a/vortex-array/src/scalar/mod.rs +++ b/vortex-array/src/scalar/mod.rs @@ -1,57 +1,144 @@ -use std::any::Any; -use std::fmt::{Debug, Display}; +use std::fmt::{Debug, Display, Formatter}; pub use binary::*; pub use bool::*; pub use list::*; pub use localtime::*; pub use null::*; -pub use nullable::*; pub use primitive::*; pub use serde::*; pub use struct_::*; pub use utf8::*; -use crate::dtype::DType; +use crate::dtype::{DType, FloatWidth, IntWidth, Signedness}; use crate::error::VortexResult; -use crate::ptype::NativePType; +use crate::ptype::{NativePType, PType}; -mod arrow; mod binary; mod bool; -mod equal; mod list; mod localtime; mod null; -mod nullable; -mod ord; mod primitive; mod serde; mod struct_; mod utf8; -pub type ScalarRef = Box; - -pub trait Scalar: Display + Debug + dyn_clone::DynClone + Send + Sync + 'static { - fn as_any(&self) -> &dyn Any; - - fn into_any(self: Box) -> Box; +#[derive(Debug, Clone, PartialEq, PartialOrd)] +pub enum Scalar { + Binary(BinaryScalar), + Bool(BoolScalar), + List(ListScalar), + LocalTime(LocalTimeScalar), + Null(NullScalar), + Primitive(PrimitiveScalar), + Struct(StructScalar), + Utf8(Utf8Scalar), +} - fn as_nonnull(&self) -> Option<&dyn Scalar>; +macro_rules! impls_for_scalars { + ($variant:tt, $E:ty) => { + impl From<$E> for Scalar { + fn from(arr: $E) -> Self { + Self::$variant(arr) + } + } + }; +} - fn into_nonnull(self: Box) -> Option; +impls_for_scalars!(Binary, BinaryScalar); +impls_for_scalars!(Bool, BoolScalar); +impls_for_scalars!(List, ListScalar); +impls_for_scalars!(LocalTime, LocalTimeScalar); +impls_for_scalars!(Null, NullScalar); +impls_for_scalars!(Primitive, PrimitiveScalar); +impls_for_scalars!(Struct, StructScalar); +impls_for_scalars!(Utf8, Utf8Scalar); + +macro_rules! match_each_scalar { + ($self:expr, | $_:tt $scalar:ident | $($body:tt)*) => ({ + macro_rules! __with_scalar__ {( $_ $scalar:ident ) => ( $($body)* )} + match $self { + Scalar::Binary(s) => __with_scalar__! { s }, + Scalar::Bool(s) => __with_scalar__! { s }, + Scalar::List(s) => __with_scalar__! { s }, + Scalar::LocalTime(s) => __with_scalar__! { s }, + Scalar::Null(s) => __with_scalar__! { s }, + Scalar::Primitive(s) => __with_scalar__! { s }, + Scalar::Struct(s) => __with_scalar__! { s }, + Scalar::Utf8(s) => __with_scalar__! { s }, + } + }) +} - fn boxed(self) -> ScalarRef; +impl Scalar { + pub fn dtype(&self) -> &DType { + match_each_scalar! { self, |$s| $s.dtype() } + } - /// the logical type. - fn dtype(&self) -> &DType; + pub fn cast(&self, dtype: &DType) -> VortexResult { + match_each_scalar! { self, |$s| $s.cast(dtype) } + } - fn cast(&self, dtype: &DType) -> VortexResult; + pub fn nbytes(&self) -> usize { + match_each_scalar! { self, |$s| $s.nbytes() } + } - fn nbytes(&self) -> usize; + pub fn null(dtype: &DType) -> Self { + match dtype { + DType::Null => NullScalar::new().into(), + DType::Bool(_) => BoolScalar::new(None).into(), + DType::Int(w, s, _) => match (w, s) { + (IntWidth::Unknown, Signedness::Unknown | Signedness::Signed) => { + PrimitiveScalar::none(PType::I64).into() + } + (IntWidth::_8, Signedness::Unknown | Signedness::Signed) => { + PrimitiveScalar::none(PType::I8).into() + } + (IntWidth::_16, Signedness::Unknown | Signedness::Signed) => { + PrimitiveScalar::none(PType::I16).into() + } + (IntWidth::_32, Signedness::Unknown | Signedness::Signed) => { + PrimitiveScalar::none(PType::I32).into() + } + (IntWidth::_64, Signedness::Unknown | Signedness::Signed) => { + PrimitiveScalar::none(PType::I64).into() + } + (IntWidth::Unknown, Signedness::Unsigned) => { + PrimitiveScalar::none(PType::U64).into() + } + (IntWidth::_8, Signedness::Unsigned) => PrimitiveScalar::none(PType::U8).into(), + (IntWidth::_16, Signedness::Unsigned) => PrimitiveScalar::none(PType::U16).into(), + (IntWidth::_32, Signedness::Unsigned) => PrimitiveScalar::none(PType::U32).into(), + (IntWidth::_64, Signedness::Unsigned) => PrimitiveScalar::none(PType::U64).into(), + }, + DType::Decimal(_, _, _) => unimplemented!("DecimalScalar"), + DType::Float(w, _) => match w { + FloatWidth::Unknown => PrimitiveScalar::none(PType::F64).into(), + FloatWidth::_16 => PrimitiveScalar::none(PType::F16).into(), + FloatWidth::_32 => PrimitiveScalar::none(PType::F32).into(), + FloatWidth::_64 => PrimitiveScalar::none(PType::F64).into(), + }, + DType::Utf8(_) => Utf8Scalar::new(None).into(), + DType::Binary(_) => BinaryScalar::new(None).into(), + DType::LocalTime(u, _) => { + LocalTimeScalar::new(PrimitiveScalar::none(PType::U64), *u).into() + } + DType::LocalDate(_) => unimplemented!("LocalDateScalar"), + DType::Instant(_, _) => unimplemented!("InstantScalar"), + DType::ZonedDateTime(_, _) => unimplemented!("ZonedDateTimeScalar"), + DType::Struct(_, _) => StructScalar::new(dtype.clone(), vec![]).into(), + DType::List(_, _) => ListScalar::new(dtype.clone(), None).into(), + DType::Map(_, _, _) => unimplemented!("MapScalar"), + } + } } -dyn_clone::clone_trait_object!(Scalar); +impl Display for Scalar { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match_each_scalar! { self, |$s| Display::fmt($s, f) } + } +} /// Allows conversion from Enc scalars to a byte slice. pub trait AsBytes { diff --git a/vortex-array/src/scalar/null.rs b/vortex-array/src/scalar/null.rs index 6d28622992..817dbecccd 100644 --- a/vortex-array/src/scalar/null.rs +++ b/vortex-array/src/scalar/null.rs @@ -1,11 +1,10 @@ -use std::any::Any; use std::fmt::{Display, Formatter}; use crate::dtype::DType; use crate::error::VortexResult; -use crate::scalar::{NullableScalar, Scalar, ScalarRef}; +use crate::scalar::Scalar; -#[derive(Debug, Clone, PartialEq)] +#[derive(Debug, Clone, PartialEq, PartialOrd)] pub struct NullScalar; impl Default for NullScalar { @@ -19,44 +18,17 @@ impl NullScalar { pub fn new() -> Self { Self {} } -} - -impl Scalar for NullScalar { - #[inline] - fn as_any(&self) -> &dyn Any { - self - } - - #[inline] - fn into_any(self: Box) -> Box { - self - } - - #[inline] - fn as_nonnull(&self) -> Option<&dyn Scalar> { - None - } - - #[inline] - fn into_nonnull(self: Box) -> Option { - None - } - - #[inline] - fn boxed(self) -> ScalarRef { - Box::new(self) - } #[inline] - fn dtype(&self) -> &DType { + pub fn dtype(&self) -> &DType { &DType::Null } - fn cast(&self, dtype: &DType) -> VortexResult { - Ok(NullableScalar::none(dtype.clone()).boxed()) + pub fn cast(&self, _dtype: &DType) -> VortexResult { + todo!() } - fn nbytes(&self) -> usize { + pub fn nbytes(&self) -> usize { 1 } } diff --git a/vortex-array/src/scalar/nullable.rs b/vortex-array/src/scalar/nullable.rs deleted file mode 100644 index 92815f7667..0000000000 --- a/vortex-array/src/scalar/nullable.rs +++ /dev/null @@ -1,162 +0,0 @@ -use std::any::Any; -use std::fmt::{Display, Formatter}; -use std::mem::size_of; - -use crate::dtype::DType; -use crate::error::{VortexError, VortexResult}; -use crate::scalar::{NullScalar, Scalar, ScalarRef}; - -#[derive(Debug, Clone, PartialEq, PartialOrd)] -pub enum NullableScalar { - None(DType), - Some(ScalarRef, DType), -} - -impl NullableScalar { - pub fn some(scalar: ScalarRef) -> Self { - let dtype = scalar.dtype().as_nullable(); - Self::Some(scalar, dtype) - } - - pub fn none(dtype: DType) -> Self { - Self::None(dtype.as_nullable()) - } -} - -impl Scalar for NullableScalar { - #[inline] - fn as_any(&self) -> &dyn Any { - self - } - - #[inline] - fn into_any(self: Box) -> Box { - self - } - - #[inline] - fn as_nonnull(&self) -> Option<&dyn Scalar> { - match self { - Self::Some(s, _) => Some(s.as_ref()), - Self::None(_) => None, - } - } - - #[inline] - fn into_nonnull(self: Box) -> Option { - match *self { - Self::Some(s, _) => Some(s), - Self::None(_) => None, - } - } - - #[inline] - fn boxed(self) -> ScalarRef { - Box::new(self) - } - - #[inline] - fn dtype(&self) -> &DType { - match self { - Self::Some(_, dtype) => dtype, - Self::None(dtype) => dtype, - } - } - - fn cast(&self, dtype: &DType) -> VortexResult { - match self { - Self::Some(s, _dt) => { - if dtype.is_nullable() { - Ok(Self::Some(s.cast(&dtype.as_nonnullable())?, dtype.clone()).boxed()) - } else { - s.cast(&dtype.as_nonnullable()) - } - } - Self::None(_dt) => { - if dtype.is_nullable() { - Ok(Self::None(dtype.clone()).boxed()) - } else { - Err(VortexError::InvalidDType(dtype.clone())) - } - } - } - } - - fn nbytes(&self) -> usize { - match self { - NullableScalar::Some(s, _) => s.nbytes() + size_of::(), - NullableScalar::None(_) => size_of::(), - } - } -} - -impl Display for NullableScalar { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - match self { - NullableScalar::Some(p, _) => write!(f, "{}?", p), - NullableScalar::None(_) => write!(f, "null"), - } - } -} - -#[derive(Debug, Clone, PartialEq)] -pub struct NullableScalarOption(pub Option); - -impl> From> for ScalarRef { - fn from(value: NullableScalarOption) -> Self { - match value.0 { - // TODO(robert): This should return NullableScalar::None - // but that's not possible with some type that holds the associated dtype - // We need to change the bound of T to be able to get datatype from it. - None => NullScalar::new().boxed(), - Some(v) => NullableScalar::some(v.into()).boxed(), - } - } -} - -impl> TryFrom<&dyn Scalar> for NullableScalarOption { - type Error = VortexError; - - fn try_from(value: &dyn Scalar) -> Result { - let Some(ns) = value.as_any().downcast_ref::() else { - return Err(VortexError::InvalidDType(value.dtype().clone())); - }; - - Ok(NullableScalarOption(match ns { - NullableScalar::None(_) => None, - NullableScalar::Some(v, _) => Some(v.clone().try_into()?), - })) - } -} - -impl> TryFrom for NullableScalarOption { - type Error = VortexError; - - fn try_from(value: ScalarRef) -> Result { - let dtype = value.dtype().clone(); - let ns = value - .into_any() - .downcast::() - .map_err(|_| VortexError::InvalidDType(dtype))?; - - Ok(NullableScalarOption(match *ns { - NullableScalar::None(_) => None, - NullableScalar::Some(v, _) => Some(v.try_into()?), - })) - } -} - -#[cfg(test)] -mod tests { - use crate::dtype::DType; - use crate::ptype::PType; - use crate::scalar::Scalar; - - #[test] - fn test_nullable_scalar_option() { - let ns: Box = Some(10i16).into(); - let nsi32 = ns.cast(&DType::from(PType::I32)).unwrap(); - let v: i32 = nsi32.try_into().unwrap(); - assert_eq!(v, 10); - } -} diff --git a/vortex-array/src/scalar/ord.rs b/vortex-array/src/scalar/ord.rs deleted file mode 100644 index cee8b2a4e0..0000000000 --- a/vortex-array/src/scalar/ord.rs +++ /dev/null @@ -1,64 +0,0 @@ -use crate::scalar::{ - BinaryScalar, BoolScalar, LocalTimeScalar, NullableScalar, PScalar, Scalar, ScalarRef, - StructScalar, Utf8Scalar, -}; -use std::cmp::Ordering; -use std::sync::Arc; -macro_rules! dyn_ord { - ($ty:ty, $lhs:expr, $rhs:expr) => {{ - let lhs = $lhs.as_any().downcast_ref::<$ty>().unwrap(); - let rhs = $rhs.as_any().downcast_ref::<$ty>().unwrap(); - if lhs < rhs { - Ordering::Less - } else if lhs == rhs { - Ordering::Equal - } else { - Ordering::Greater - } - }}; -} - -fn cmp(lhs: &dyn Scalar, rhs: &dyn Scalar) -> Option { - if lhs.dtype() != rhs.dtype() { - return None; - } - - // If the dtypes are the same then both of the scalars are either nullable or plain scalar - if let Some(ls) = lhs.as_any().downcast_ref::() { - if let Some(rs) = rhs.as_any().downcast_ref::() { - return Some(dyn_ord!(NullableScalar, ls, rs)); - } else { - unreachable!("DTypes were equal, but only one was nullable") - } - } - - use crate::dtype::DType::*; - Some(match lhs.dtype() { - Bool(_) => dyn_ord!(BoolScalar, lhs, rhs), - Int(_, _, _) => dyn_ord!(PScalar, lhs, rhs), - Float(_, _) => dyn_ord!(PScalar, lhs, rhs), - Struct(..) => dyn_ord!(StructScalar, lhs, rhs), - Utf8(_) => dyn_ord!(Utf8Scalar, lhs, rhs), - Binary(_) => dyn_ord!(BinaryScalar, lhs, rhs), - LocalTime(_, _) => dyn_ord!(LocalTimeScalar, lhs, rhs), - _ => todo!("Cmp not yet implemented for {:?} {:?}", lhs, rhs), - }) -} - -impl PartialOrd for dyn Scalar { - fn partial_cmp(&self, that: &Self) -> Option { - cmp(self, that) - } -} - -impl PartialOrd for ScalarRef { - fn partial_cmp(&self, that: &dyn Scalar) -> Option { - cmp(self.as_ref(), that) - } -} - -impl PartialOrd for Arc { - fn partial_cmp(&self, that: &dyn Scalar) -> Option { - cmp(&**self, that) - } -} diff --git a/vortex-array/src/scalar/primitive.rs b/vortex-array/src/scalar/primitive.rs index 7e2c170308..4df4a84d21 100644 --- a/vortex-array/src/scalar/primitive.rs +++ b/vortex-array/src/scalar/primitive.rs @@ -1,4 +1,3 @@ -use std::any::Any; use std::fmt::{Display, Formatter}; use std::mem::size_of; @@ -7,9 +6,81 @@ use half::f16; use crate::dtype::{DType, Nullability}; use crate::error::{VortexError, VortexResult}; use crate::ptype::{NativePType, PType}; -use crate::scalar::{LocalTimeScalar, NullableScalar, Scalar, ScalarRef}; +use crate::scalar::{LocalTimeScalar, Scalar}; #[derive(Debug, Clone, PartialEq, PartialOrd)] +pub struct PrimitiveScalar { + ptype: PType, + value: Option, +} + +impl PrimitiveScalar { + pub fn new(ptype: PType, value: Option) -> Self { + Self { ptype, value } + } + + pub fn some(value: PScalar) -> Self { + Self { + ptype: value.ptype(), + value: Some(value), + } + } + + pub fn none(ptype: PType) -> Self { + Self { ptype, value: None } + } + + #[inline] + pub fn value(&self) -> Option { + self.value + } + + #[inline] + pub fn ptype(&self) -> PType { + self.ptype + } + + #[inline] + pub fn dtype(&self) -> &DType { + self.ptype.into() + } + + pub fn cast(&self, dtype: &DType) -> VortexResult { + let ptype: VortexResult = dtype.try_into(); + ptype + .and_then(|p| match self.value() { + None => Ok(PrimitiveScalar::none(p).into()), + Some(ps) => ps.cast_ptype(p), + }) + .or_else(|_| self.cast_dtype(dtype)) + } + + // General conversion function that handles casting primitive scalar to non-primitive. + // If target dtype can be converted to ptype you should use cast_ptype. + pub fn cast_dtype(&self, dtype: &DType) -> VortexResult { + match dtype { + DType::LocalTime(w, Nullability::NonNullable) => { + Ok(LocalTimeScalar::new(self.clone(), *w).into()) + } + _ => Err(VortexError::InvalidDType(dtype.clone())), + } + } + + pub fn nbytes(&self) -> usize { + size_of::() + } +} + +impl Display for PrimitiveScalar { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self.value() { + None => write!(f, "({}?)", self.ptype), + Some(v) => write!(f, "{}({})", v, self.ptype), + } + } +} + +#[derive(Debug, Copy, Clone, PartialEq, PartialOrd)] pub enum PScalar { U8(u8), U16(u16), @@ -41,30 +112,7 @@ impl PScalar { } } - // General conversion function that handles casting primitive scalar to non primitive. - // If target dtype can be converted to ptype you should use cast_ptype. - pub fn cast_dtype(&self, dtype: DType) -> VortexResult { - macro_rules! from_int { - ($dtype:ident , $ps:ident) => { - match $dtype { - DType::LocalTime(w, Nullability::NonNullable) => { - Ok(LocalTimeScalar::new($ps.clone(), w.clone()).boxed()) - } - _ => Err(VortexError::InvalidDType($dtype.clone())), - } - }; - } - - match self { - p @ PScalar::U32(_) - | p @ PScalar::U64(_) - | p @ PScalar::I32(_) - | p @ PScalar::I64(_) => from_int!(dtype, p), - _ => Err(VortexError::InvalidDType(dtype.clone())), - } - } - - pub fn cast_ptype(&self, ptype: PType) -> VortexResult { + pub fn cast_ptype(&self, ptype: PType) -> VortexResult { macro_rules! from_int { ($ptype:ident , $v:ident) => { match $ptype { @@ -120,49 +168,6 @@ fn is_negative(value: T) -> bool { value < T::default() } -impl Scalar for PScalar { - #[inline] - fn as_any(&self) -> &dyn Any { - self - } - - #[inline] - fn into_any(self: Box) -> Box { - self - } - - #[inline] - fn as_nonnull(&self) -> Option<&dyn Scalar> { - Some(self) - } - - #[inline] - fn into_nonnull(self: Box) -> Option { - Some(self) - } - - #[inline] - fn boxed(self) -> ScalarRef { - Box::new(self) - } - - #[inline] - fn dtype(&self) -> &DType { - self.ptype().into() - } - - fn cast(&self, dtype: &DType) -> VortexResult { - let ptype: VortexResult = dtype.try_into(); - ptype - .and_then(|p| self.cast_ptype(p)) - .or_else(|_| self.cast_dtype(dtype.clone())) - } - - fn nbytes(&self) -> usize { - size_of::() - } -} - macro_rules! pscalar { ($T:ty, $ptype:tt) => { impl From<$T> for PScalar { @@ -171,35 +176,42 @@ macro_rules! pscalar { } } - impl From<$T> for ScalarRef { + impl From<$T> for Scalar { fn from(value: $T) -> Self { - PScalar::from(value).boxed() + PrimitiveScalar::some(PScalar::from(value)).into() } } - impl TryFrom for $T { + impl TryFrom<&Scalar> for $T { type Error = VortexError; - #[inline] - fn try_from(value: ScalarRef) -> VortexResult { - value.as_ref().try_into() + fn try_from(value: &Scalar) -> VortexResult { + match value { + Scalar::Primitive(PrimitiveScalar { + ptype: _, + value: Some(pscalar), + }) => match pscalar { + PScalar::$ptype(v) => Ok(*v), + _ => Err(VortexError::InvalidDType(pscalar.ptype().into())), + }, + _ => Err(VortexError::InvalidDType(value.dtype().clone())), + } } } - impl TryFrom<&dyn Scalar> for $T { + impl TryFrom for $T { type Error = VortexError; - fn try_from(value: &dyn Scalar) -> VortexResult { - if let Some(pscalar) = value - .as_nonnull() - .and_then(|v| v.as_any().downcast_ref::()) - { - match pscalar { - PScalar::$ptype(v) => Ok(*v), + fn try_from(value: Scalar) -> VortexResult { + match value { + Scalar::Primitive(PrimitiveScalar { + ptype: _, + value: Some(pscalar), + }) => match pscalar { + PScalar::$ptype(v) => Ok(v), _ => Err(VortexError::InvalidDType(pscalar.ptype().into())), - } - } else { - Err(VortexError::InvalidDType(value.dtype().clone())) + }, + _ => Err(VortexError::InvalidDType(value.dtype().clone())), } } } @@ -218,34 +230,62 @@ pscalar!(f16, F16); pscalar!(f32, F32); pscalar!(f64, F64); -impl From> for ScalarRef { +impl From> for Scalar { fn from(value: Option) -> Self { match value { Some(value) => value.into(), - None => Box::new(NullableScalar::None(DType::from(T::PTYPE))), + None => PrimitiveScalar::new(T::PTYPE, None).into(), } } } -impl From for ScalarRef { +impl From for Scalar { #[inline] fn from(value: usize) -> Self { - PScalar::U64(value as u64).boxed() + PrimitiveScalar::new(PType::U64, Some(PScalar::U64(value as u64))).into() } } -impl TryFrom for usize { +impl TryFrom for usize { type Error = VortexError; - fn try_from(value: ScalarRef) -> VortexResult { - value.as_ref().try_into() + fn try_from(value: Scalar) -> VortexResult { + macro_rules! match_each_pscalar_integer { + ($self:expr, | $_:tt $pscalar:ident | $($body:tt)*) => ({ + macro_rules! __with_pscalar__ {( $_ $pscalar:ident ) => ( $($body)* )} + match $self { + PScalar::U8(v) => __with_pscalar__! { v }, + PScalar::U16(v) => __with_pscalar__! { v }, + PScalar::U32(v) => __with_pscalar__! { v }, + PScalar::U64(v) => __with_pscalar__! { v }, + PScalar::I8(v) => __with_pscalar__! { v }, + PScalar::I16(v) => __with_pscalar__! { v }, + PScalar::I32(v) => __with_pscalar__! { v }, + PScalar::I64(v) => __with_pscalar__! { v }, + _ => Err(VortexError::InvalidDType($self.ptype().into())), + } + }) + } + + match value { + Scalar::Primitive(PrimitiveScalar { + ptype: _, + value: Some(pscalar), + }) => match_each_pscalar_integer!(pscalar, |$V| { + if is_negative($V) { + return Err(VortexError::ComputeError("required positive integer".into())); + } + Ok($V as usize) + }), + _ => Err(VortexError::InvalidDType(value.dtype().clone())), + } } } -impl TryFrom<&dyn Scalar> for usize { +impl TryFrom<&Scalar> for usize { type Error = VortexError; - fn try_from(value: &dyn Scalar) -> VortexResult { + fn try_from(value: &Scalar) -> VortexResult { macro_rules! match_each_pscalar_integer { ($self:expr, | $_:tt $pscalar:ident | $($body:tt)*) => ({ macro_rules! __with_pscalar__ {( $_ $pscalar:ident ) => ( $($body)* )} @@ -263,18 +303,17 @@ impl TryFrom<&dyn Scalar> for usize { }) } - if let Some(pscalar) = value - .as_nonnull() - .and_then(|v| v.as_any().downcast_ref::()) - { - match_each_pscalar_integer!(pscalar, |$V| { + match value { + Scalar::Primitive(PrimitiveScalar { + ptype: _, + value: Some(pscalar), + }) => match_each_pscalar_integer!(pscalar, |$V| { if is_negative(*$V) { return Err(VortexError::ComputeError("required positive integer".into())); } Ok(*$V as usize) - }) - } else { - Err(VortexError::InvalidDType(value.dtype().clone())) + }), + _ => Err(VortexError::InvalidDType(value.dtype().clone())), } } } @@ -302,18 +341,18 @@ mod test { use crate::dtype::{DType, IntWidth, Nullability, Signedness}; use crate::error::VortexError; use crate::ptype::PType; - use crate::scalar::ScalarRef; + use crate::scalar::Scalar; #[test] fn into_from() { - let scalar: ScalarRef = 10u16.into(); - assert_eq!(scalar.as_ref().try_into(), Ok(10u16)); + let scalar: Scalar = 10u16.into(); + assert_eq!(scalar.clone().try_into(), Ok(10u16)); // All integers should be convertible to usize - assert_eq!(scalar.as_ref().try_into(), Ok(10usize)); + assert_eq!(scalar.try_into(), Ok(10usize)); - let scalar: ScalarRef = (-10i16).into(); + let scalar: Scalar = (-10i16).into(); assert_eq!( - scalar.as_ref().try_into(), + scalar.try_into(), Err::(VortexError::ComputeError( "required positive integer".into() )) @@ -322,7 +361,7 @@ mod test { #[test] fn cast() { - let scalar: ScalarRef = 10u16.into(); + let scalar: Scalar = 10u16.into(); let u32_scalar = scalar .cast(&DType::Int( IntWidth::_32, diff --git a/vortex-array/src/scalar/serde.rs b/vortex-array/src/scalar/serde.rs index 68012d9050..b426dd1819 100644 --- a/vortex-array/src/scalar/serde.rs +++ b/vortex-array/src/scalar/serde.rs @@ -1,130 +1,136 @@ use std::io; -use std::io::{ErrorKind, Read}; +use std::sync::Arc; use half::f16; use num_enum::{IntoPrimitive, TryFromPrimitive}; -use crate::dtype::{DType, FloatWidth, IntWidth, Signedness, TimeUnit}; +use crate::dtype::{DType, TimeUnit}; +use crate::ptype::PType; use crate::scalar::{ - BinaryScalar, BoolScalar, ListScalar, LocalTimeScalar, NullScalar, NullableScalar, PScalar, - Scalar, ScalarRef, StructScalar, Utf8Scalar, + BinaryScalar, BoolScalar, ListScalar, LocalTimeScalar, NullScalar, PScalar, PrimitiveScalar, + Scalar, StructScalar, Utf8Scalar, }; -use crate::serde::{DTypeReader, TimeUnitTag, WriteCtx}; +use crate::serde::{ReadCtx, TimeUnitTag, WriteCtx}; -pub struct ScalarReader<'a> { - reader: &'a mut dyn Read, +pub struct ScalarReader<'a, 'b> { + reader: &'b mut ReadCtx<'a>, } -impl<'a> ScalarReader<'a> { - pub fn new(reader: &'a mut dyn Read) -> Self { +impl<'a, 'b> ScalarReader<'a, 'b> { + pub fn new(reader: &'b mut ReadCtx<'a>) -> Self { Self { reader } } - fn read_nbytes(&mut self) -> io::Result<[u8; N]> { - let mut bytes: [u8; N] = [0; N]; - self.reader.read_exact(&mut bytes)?; - Ok(bytes) - } - - pub fn read(&mut self) -> io::Result { - let tag = ScalarTag::try_from(self.read_nbytes::<1>()?[0]) - .map_err(|e| io::Error::new(ErrorKind::InvalidData, e))?; + pub fn read(&mut self) -> io::Result { + let tag = ScalarTag::try_from(self.reader.read_nbytes::<1>()?[0]) + .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; match tag { ScalarTag::Binary => { - let len = leb128::read::unsigned(self.reader) - .map_err(|e| io::Error::new(ErrorKind::InvalidData, e))?; - let mut value = Vec::::with_capacity(len as usize); - self.reader.take(len).read_to_end(&mut value)?; - Ok(BinaryScalar::new(value).boxed()) - } - ScalarTag::Bool => Ok(BoolScalar::new(self.read_nbytes::<1>()?[0] != 0).boxed()), - ScalarTag::F16 => { - Ok(PScalar::F16(f16::from_le_bytes(self.read_nbytes::<2>()?)).boxed()) - } - ScalarTag::F32 => { - Ok(PScalar::F32(f32::from_le_bytes(self.read_nbytes::<4>()?)).boxed()) - } - ScalarTag::F64 => { - Ok(PScalar::F64(f64::from_le_bytes(self.read_nbytes::<8>()?)).boxed()) - } - ScalarTag::I16 => { - Ok(PScalar::I16(i16::from_le_bytes(self.read_nbytes::<2>()?)).boxed()) - } - ScalarTag::I32 => { - Ok(PScalar::I32(i32::from_le_bytes(self.read_nbytes::<4>()?)).boxed()) + let slice = self.reader.read_optional_slice()?; + Ok(BinaryScalar::new(slice).into()) } - ScalarTag::I64 => { - Ok(PScalar::I64(i64::from_le_bytes(self.read_nbytes::<8>()?)).boxed()) + ScalarTag::Bool => { + let is_present = self.reader.read_option_tag()?; + if is_present { + Ok(BoolScalar::some(self.reader.read_nbytes::<1>()?[0] != 0).into()) + } else { + Ok(BoolScalar::none().into()) + } } - ScalarTag::I8 => Ok(PScalar::I8(i8::from_le_bytes(self.read_nbytes::<1>()?)).boxed()), + ScalarTag::PrimitiveS => self.read_primitive_scalar().map(|p| p.into()), ScalarTag::List => { - let elems = leb128::read::unsigned(self.reader) - .map_err(|e| io::Error::new(ErrorKind::InvalidData, e))?; - if elems == 0 { - let dtype = DTypeReader::new(self.reader).read()?; - Ok(ListScalar::new(dtype, Vec::new()).boxed()) - } else { - let mut values = Vec::::with_capacity(elems as usize); - for value in values.iter_mut() { - *value = self.read()?; + let is_present = self.reader.read_option_tag()?; + if is_present { + let elems = self.reader.read_usize()?; + let mut values = Vec::with_capacity(elems); + for _ in 0..elems { + values.push(self.read()?); } - Ok(ListScalar::new(values[0].dtype().clone(), values).boxed()) + Ok(ListScalar::new(values[0].dtype().clone(), Some(values)).into()) + } else { + Ok(ListScalar::new(self.reader.dtype()?, None).into()) } } ScalarTag::LocalTime => { - let pscalar = self - .read()? - .into_any() - .downcast::() - .map_err(|_e| io::Error::new(ErrorKind::InvalidData, "invalid scalar"))?; - let time_unit = TimeUnitTag::try_from(self.read_nbytes::<1>()?[0]) - .map_err(|e| io::Error::new(ErrorKind::InvalidData, e)) + let time_unit = TimeUnitTag::try_from(self.reader.read_nbytes::<1>()?[0]) + .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e)) .map(TimeUnit::from)?; + let ps = self.read_primitive_scalar()?; - Ok(LocalTimeScalar::new(*pscalar, time_unit).boxed()) - } - ScalarTag::Null => Ok(NullScalar::new().boxed()), - ScalarTag::Nullable => { - let tag = self.read_nbytes::<1>()?[0]; - match tag { - 0x00 => Ok(NullableScalar::none(DTypeReader::new(self.reader).read()?).boxed()), - 0x01 => Ok(NullableScalar::some(self.read()?).boxed()), - _ => Err(io::Error::new( - ErrorKind::InvalidData, - "Invalid NullableScalar tag", - )), - } + Ok(LocalTimeScalar::new(ps, time_unit).into()) } + ScalarTag::Null => Ok(NullScalar::new().into()), ScalarTag::Struct => { - let dtype = DTypeReader::new(self.reader).read()?; - let DType::Struct(ns, _fs) = &dtype else { - return Err(io::Error::new(ErrorKind::InvalidData, "invalid dtype")); - }; - let mut values = Vec::::with_capacity(ns.len()); - for value in values.iter_mut() { - *value = self.read()?; + let field_num = self.reader.read_usize()?; + let mut names = Vec::with_capacity(field_num); + for _ in 0..field_num { + names.push(Arc::new( + self.reader + .read_slice() + .map(|v| unsafe { String::from_utf8_unchecked(v) })?, + )); } - Ok(StructScalar::new(dtype, values).boxed()) - } - ScalarTag::U16 => { - Ok(PScalar::U16(u16::from_le_bytes(self.read_nbytes::<2>()?)).boxed()) - } - ScalarTag::U32 => { - Ok(PScalar::U32(u32::from_le_bytes(self.read_nbytes::<4>()?)).boxed()) - } - ScalarTag::U64 => { - Ok(PScalar::U64(u64::from_le_bytes(self.read_nbytes::<8>()?)).boxed()) + let mut values = Vec::with_capacity(field_num); + for _ in 0..field_num { + values.push(self.read()?); + } + let dtypes = values.iter().map(|s| s.dtype().clone()).collect::>(); + Ok(StructScalar::new(DType::Struct(names, dtypes), values).into()) } - ScalarTag::U8 => Ok(PScalar::U8(u8::from_le_bytes(self.read_nbytes::<1>()?)).boxed()), ScalarTag::Utf8 => { - let len = leb128::read::unsigned(self.reader) - .map_err(|e| io::Error::new(ErrorKind::InvalidData, e))?; - let mut value = Vec::::with_capacity(len as usize); - self.reader.take(len).read_to_end(&mut value)?; - Ok(Utf8Scalar::new(unsafe { String::from_utf8_unchecked(value) }).boxed()) + let value = self.reader.read_optional_slice()?; + Ok( + Utf8Scalar::new(value.map(|v| unsafe { String::from_utf8_unchecked(v) })) + .into(), + ) } } } + + fn read_primitive_scalar(&mut self) -> io::Result { + let ptype = self.reader.ptype()?; + let is_present = self.reader.read_option_tag()?; + if is_present { + let pscalar = match ptype { + PType::U8 => PrimitiveScalar::some(PScalar::U8(u8::from_le_bytes( + self.reader.read_nbytes()?, + ))), + PType::U16 => PrimitiveScalar::some(PScalar::U16(u16::from_le_bytes( + self.reader.read_nbytes()?, + ))), + PType::U32 => PrimitiveScalar::some(PScalar::U32(u32::from_le_bytes( + self.reader.read_nbytes()?, + ))), + PType::U64 => PrimitiveScalar::some(PScalar::U64(u64::from_le_bytes( + self.reader.read_nbytes()?, + ))), + PType::I8 => PrimitiveScalar::some(PScalar::I8(i8::from_le_bytes( + self.reader.read_nbytes()?, + ))), + PType::I16 => PrimitiveScalar::some(PScalar::I16(i16::from_le_bytes( + self.reader.read_nbytes()?, + ))), + PType::I32 => PrimitiveScalar::some(PScalar::I32(i32::from_le_bytes( + self.reader.read_nbytes()?, + ))), + PType::I64 => PrimitiveScalar::some(PScalar::I64(i64::from_le_bytes( + self.reader.read_nbytes()?, + ))), + PType::F16 => PrimitiveScalar::some(PScalar::F16(f16::from_le_bytes( + self.reader.read_nbytes()?, + ))), + PType::F32 => PrimitiveScalar::some(PScalar::F32(f32::from_le_bytes( + self.reader.read_nbytes()?, + ))), + PType::F64 => PrimitiveScalar::some(PScalar::F64(f64::from_le_bytes( + self.reader.read_nbytes()?, + ))), + }; + Ok(pscalar) + } else { + Ok(PrimitiveScalar::none(ptype)) + } + } } pub struct ScalarWriter<'a, 'b> { @@ -136,127 +142,73 @@ impl<'a, 'b> ScalarWriter<'a, 'b> { Self { writer } } - pub fn write(&mut self, scalar: &dyn Scalar) -> io::Result<()> { - let tag = ScalarTag::from(scalar); - self.writer.write_fixed_slice([tag.into()])?; - match tag { - ScalarTag::Binary => { - let binary = scalar.as_any().downcast_ref::().unwrap(); - self.writer.write_slice(binary.value().as_slice()) - } - ScalarTag::Bool => self.writer.write_fixed_slice([scalar - .as_any() - .downcast_ref::() - .unwrap() - .value() as u8]), - ScalarTag::F16 => { - let PScalar::F16(f) = scalar.as_any().downcast_ref::().unwrap() else { - return Err(io::Error::new(ErrorKind::InvalidData, "invalid scalar")); - }; - self.writer.write_fixed_slice(f.to_le_bytes()) - } - ScalarTag::F32 => { - let PScalar::F32(f) = scalar.as_any().downcast_ref::().unwrap() else { - return Err(io::Error::new(ErrorKind::InvalidData, "invalid scalar")); - }; - self.writer.write_fixed_slice(f.to_le_bytes()) - } - ScalarTag::F64 => { - let PScalar::F64(f) = scalar.as_any().downcast_ref::().unwrap() else { - return Err(io::Error::new(ErrorKind::InvalidData, "invalid scalar")); - }; - self.writer.write_fixed_slice(f.to_le_bytes()) - } - ScalarTag::I16 => { - let PScalar::I16(i) = scalar.as_any().downcast_ref::().unwrap() else { - return Err(io::Error::new(ErrorKind::InvalidData, "invalid scalar")); - }; - self.writer.write_fixed_slice(i.to_le_bytes()) - } - ScalarTag::I32 => { - let PScalar::I32(i) = scalar.as_any().downcast_ref::().unwrap() else { - return Err(io::Error::new(ErrorKind::InvalidData, "invalid scalar")); - }; - self.writer.write_fixed_slice(i.to_le_bytes()) - } - ScalarTag::I64 => { - let PScalar::I64(i) = scalar.as_any().downcast_ref::().unwrap() else { - return Err(io::Error::new(ErrorKind::InvalidData, "invalid scalar")); - }; - self.writer.write_fixed_slice(i.to_le_bytes()) - } - ScalarTag::I8 => { - let PScalar::I8(i) = scalar.as_any().downcast_ref::().unwrap() else { - return Err(io::Error::new(ErrorKind::InvalidData, "invalid scalar")); - }; - self.writer.write_fixed_slice(i.to_le_bytes()) + pub fn write(&mut self, scalar: &Scalar) -> io::Result<()> { + self.writer + .write_fixed_slice([ScalarTag::from(scalar).into()])?; + match scalar { + Scalar::Binary(b) => self.writer.write_optional_slice(b.value()), + Scalar::Bool(b) => { + self.writer.write_option_tag(b.value().is_some())?; + if let Some(v) = b.value() { + self.writer.write_fixed_slice([v as u8])?; + } + Ok(()) } - ScalarTag::List => { - let ls = scalar.as_any().downcast_ref::().unwrap(); - self.writer.write_usize(ls.values().len())?; - if ls.values().is_empty() { - self.writer.dtype(ls.dtype())?; - Ok(()) - } else { - for elem in ls.values() { - self.write(elem.as_ref())?; + Scalar::List(ls) => { + self.writer.write_option_tag(ls.values().is_some())?; + if let Some(vs) = ls.values() { + self.writer.write_usize(vs.len())?; + for elem in vs { + self.write(elem)?; } - Ok(()) + } else { + self.writer.dtype(ls.dtype())?; } + Ok(()) } - ScalarTag::LocalTime => { - let lt = scalar.as_any().downcast_ref::().unwrap(); - self.write(lt.value())?; - self.writer - .write_fixed_slice([TimeUnitTag::from(lt.time_unit()).into()]) - } - ScalarTag::Null => Ok(()), - ScalarTag::Nullable => { - let ns = scalar.as_any().downcast_ref::().unwrap(); + Scalar::LocalTime(lt) => { self.writer - .write_option_tag(matches!(ns, NullableScalar::Some(_, _)))?; - match ns { - NullableScalar::None(d) => self.writer.dtype(d), - NullableScalar::Some(s, _) => self.write(s.as_ref()), + .write_fixed_slice([TimeUnitTag::from(lt.time_unit()).into()])?; + self.write_primitive_scalar(lt.value()) + } + Scalar::Null(_) => Ok(()), + Scalar::Primitive(p) => self.write_primitive_scalar(p), + Scalar::Struct(s) => { + let names = s.names(); + self.writer.write_usize(names.len())?; + for n in names { + self.writer.write_slice(n.as_bytes())?; } - } - ScalarTag::Struct => { - let s = scalar.as_any().downcast_ref::().unwrap(); - self.writer.dtype(s.dtype())?; for field in s.values() { - self.write(field.as_ref())?; + self.write(field)?; } Ok(()) } - ScalarTag::U16 => { - let PScalar::U16(u) = scalar.as_any().downcast_ref::().unwrap() else { - return Err(io::Error::new(ErrorKind::InvalidData, "invalid scalar")); - }; - self.writer.write_fixed_slice(u.to_le_bytes()) - } - ScalarTag::U32 => { - let PScalar::U32(u) = scalar.as_any().downcast_ref::().unwrap() else { - return Err(io::Error::new(ErrorKind::InvalidData, "invalid scalar")); - }; - self.writer.write_fixed_slice(u.to_le_bytes()) - } - ScalarTag::U64 => { - let PScalar::U64(u) = scalar.as_any().downcast_ref::().unwrap() else { - return Err(io::Error::new(ErrorKind::InvalidData, "invalid scalar")); - }; - self.writer.write_fixed_slice(u.to_le_bytes()) - } - ScalarTag::U8 => { - let PScalar::U8(u) = scalar.as_any().downcast_ref::().unwrap() else { - return Err(io::Error::new(ErrorKind::InvalidData, "invalid scalar")); - }; - self.writer.write_fixed_slice(u.to_le_bytes()) - } - ScalarTag::Utf8 => { - let utf8 = scalar.as_any().downcast_ref::().unwrap(); - self.writer.write_slice(utf8.value().as_bytes()) + Scalar::Utf8(u) => self + .writer + .write_optional_slice(u.value().map(|s| s.as_bytes())), + } + } + + fn write_primitive_scalar(&mut self, scalar: &PrimitiveScalar) -> io::Result<()> { + self.writer.ptype(scalar.ptype())?; + self.writer.write_option_tag(scalar.value().is_some())?; + if let Some(ps) = scalar.value() { + match ps { + PScalar::F16(f) => self.writer.write_fixed_slice(f.to_le_bytes())?, + PScalar::F32(f) => self.writer.write_fixed_slice(f.to_le_bytes())?, + PScalar::F64(f) => self.writer.write_fixed_slice(f.to_le_bytes())?, + PScalar::I16(i) => self.writer.write_fixed_slice(i.to_le_bytes())?, + PScalar::I32(i) => self.writer.write_fixed_slice(i.to_le_bytes())?, + PScalar::I64(i) => self.writer.write_fixed_slice(i.to_le_bytes())?, + PScalar::I8(i) => self.writer.write_fixed_slice(i.to_le_bytes())?, + PScalar::U16(u) => self.writer.write_fixed_slice(u.to_le_bytes())?, + PScalar::U32(u) => self.writer.write_fixed_slice(u.to_le_bytes())?, + PScalar::U64(u) => self.writer.write_fixed_slice(u.to_le_bytes())?, + PScalar::U8(u) => self.writer.write_fixed_slice(u.to_le_bytes())?, } } + Ok(()) } } @@ -265,62 +217,26 @@ impl<'a, 'b> ScalarWriter<'a, 'b> { enum ScalarTag { Binary, Bool, - F16, - F32, - F64, - I16, - I32, - I64, - I8, List, LocalTime, Null, - Nullable, + // TODO(robert): rename to primitive once we stop using enum for serialization + PrimitiveS, Struct, - U16, - U32, - U64, - U8, Utf8, } -impl From<&dyn Scalar> for ScalarTag { - fn from(value: &dyn Scalar) -> Self { - if value.dtype().is_nullable() { - return ScalarTag::Nullable; - } - - match value.dtype() { - DType::Null => ScalarTag::Null, - DType::Bool(_) => ScalarTag::Bool, - DType::Int(w, s, _) => match (w, s) { - (IntWidth::Unknown, Signedness::Unknown | Signedness::Signed) => ScalarTag::I64, - (IntWidth::_8, Signedness::Unknown | Signedness::Signed) => ScalarTag::I8, - (IntWidth::_16, Signedness::Unknown | Signedness::Signed) => ScalarTag::I16, - (IntWidth::_32, Signedness::Unknown | Signedness::Signed) => ScalarTag::I32, - (IntWidth::_64, Signedness::Unknown | Signedness::Signed) => ScalarTag::I64, - (IntWidth::Unknown, Signedness::Unsigned) => ScalarTag::U64, - (IntWidth::_8, Signedness::Unsigned) => ScalarTag::U8, - (IntWidth::_16, Signedness::Unsigned) => ScalarTag::U16, - (IntWidth::_32, Signedness::Unsigned) => ScalarTag::U32, - (IntWidth::_64, Signedness::Unsigned) => ScalarTag::U64, - }, - DType::Decimal(_, _, _) => unimplemented!("decimal scalar"), - DType::Float(w, _) => match w { - FloatWidth::Unknown => ScalarTag::F64, - FloatWidth::_16 => ScalarTag::F16, - FloatWidth::_32 => ScalarTag::F32, - FloatWidth::_64 => ScalarTag::F64, - }, - DType::Utf8(_) => ScalarTag::Utf8, - DType::Binary(_) => ScalarTag::Binary, - DType::LocalTime(_, _) => ScalarTag::LocalTime, - DType::LocalDate(_) => unimplemented!("local date"), - DType::Instant(_, _) => unimplemented!("instant scalar"), - DType::ZonedDateTime(_, _) => unimplemented!("zoned date time scalar"), - DType::Struct(_, _) => ScalarTag::Struct, - DType::List(_, _) => ScalarTag::List, - DType::Map(_, _, _) => unimplemented!("map scalar"), +impl From<&Scalar> for ScalarTag { + fn from(value: &Scalar) -> Self { + match value { + Scalar::Binary(_) => ScalarTag::Binary, + Scalar::Bool(_) => ScalarTag::Bool, + Scalar::List(_) => ScalarTag::List, + Scalar::LocalTime(_) => ScalarTag::LocalTime, + Scalar::Null(_) => ScalarTag::Null, + Scalar::Primitive(_) => ScalarTag::PrimitiveS, + Scalar::Struct(_) => ScalarTag::Struct, + Scalar::Utf8(_) => ScalarTag::Utf8, } } } diff --git a/vortex-array/src/scalar/struct_.rs b/vortex-array/src/scalar/struct_.rs index b000099f5a..f097304523 100644 --- a/vortex-array/src/scalar/struct_.rs +++ b/vortex-array/src/scalar/struct_.rs @@ -1,70 +1,50 @@ -use std::any::Any; use std::cmp::Ordering; use std::fmt::{Display, Formatter}; +use std::sync::Arc; use itertools::Itertools; use crate::dtype::DType; use crate::error::{VortexError, VortexResult}; -use crate::scalar::{Scalar, ScalarRef}; +use crate::scalar::Scalar; #[derive(Debug, Clone, PartialEq)] pub struct StructScalar { dtype: DType, - values: Vec, + values: Vec, } impl StructScalar { #[inline] - pub fn new(dtype: DType, values: Vec) -> Self { + pub fn new(dtype: DType, values: Vec) -> Self { Self { dtype, values } } #[inline] - pub fn values(&self) -> &[ScalarRef] { - &self.values + pub fn values(&self) -> &[Scalar] { + self.values.as_ref() } -} -impl Scalar for StructScalar { #[inline] - fn as_any(&self) -> &dyn Any { - self - } - - #[inline] - fn into_any(self: Box) -> Box { - self - } - - #[inline] - fn as_nonnull(&self) -> Option<&dyn Scalar> { - Some(self) - } - - #[inline] - fn into_nonnull(self: Box) -> Option { - Some(self) - } - - #[inline] - fn boxed(self) -> ScalarRef { - Box::new(self) + pub fn dtype(&self) -> &DType { + &self.dtype } - #[inline] - fn dtype(&self) -> &DType { - &self.dtype + pub fn names(&self) -> &[Arc] { + let DType::Struct(ns, _) = self.dtype() else { + unreachable!("Not a scalar dtype"); + }; + ns.as_slice() } - fn cast(&self, dtype: &DType) -> VortexResult { + pub fn cast(&self, dtype: &DType) -> VortexResult { match dtype { DType::Struct(names, field_dtypes) => { if field_dtypes.len() != self.values.len() { return Err(VortexError::InvalidDType(dtype.clone())); } - let new_fields: Vec = self + let new_fields: Vec = self .values .iter() .zip_eq(field_dtypes.iter()) @@ -75,14 +55,14 @@ impl Scalar for StructScalar { names.clone(), new_fields.iter().map(|x| x.dtype().clone()).collect(), ); - Ok(StructScalar::new(new_type, new_fields).boxed()) + Ok(StructScalar::new(new_type, new_fields).into()) } _ => Err(VortexError::InvalidDType(dtype.clone())), } } - fn nbytes(&self) -> usize { - self.values.iter().map(|s| s.nbytes()).sum() + pub fn nbytes(&self) -> usize { + self.values().iter().map(|s| s.nbytes()).sum() } } diff --git a/vortex-array/src/scalar/utf8.rs b/vortex-array/src/scalar/utf8.rs index 95109016d8..5153ab5511 100644 --- a/vortex-array/src/scalar/utf8.rs +++ b/vortex-array/src/scalar/utf8.rs @@ -1,103 +1,82 @@ -use std::any::Any; use std::fmt::{Display, Formatter}; use crate::dtype::{DType, Nullability}; use crate::error::{VortexError, VortexResult}; -use crate::scalar::{Scalar, ScalarRef}; +use crate::scalar::Scalar; #[derive(Debug, Clone, PartialEq, PartialOrd)] pub struct Utf8Scalar { - value: String, + value: Option, } impl Utf8Scalar { - pub fn new(value: String) -> Self { + pub fn new(value: Option) -> Self { Self { value } } - pub fn value(&self) -> &str { - self.value.as_str() - } -} - -impl Scalar for Utf8Scalar { - #[inline] - fn as_any(&self) -> &dyn Any { - self - } - #[inline] - fn into_any(self: Box) -> Box { - self - } - - #[inline] - fn as_nonnull(&self) -> Option<&dyn Scalar> { - Some(self) + pub fn value(&self) -> Option<&str> { + self.value.as_deref() } #[inline] - fn into_nonnull(self: Box) -> Option { - Some(self) - } - - #[inline] - fn boxed(self) -> ScalarRef { - Box::new(self) - } - - #[inline] - fn dtype(&self) -> &DType { + pub fn dtype(&self) -> &DType { &DType::Utf8(Nullability::NonNullable) } - fn cast(&self, _dtype: &DType) -> VortexResult { + pub fn cast(&self, _dtype: &DType) -> VortexResult { todo!() } - fn nbytes(&self) -> usize { - self.value.len() + pub fn nbytes(&self) -> usize { + self.value().map(|v| v.len()).unwrap_or(0) } } -impl From for ScalarRef { +impl From for Scalar { fn from(value: String) -> Self { - Utf8Scalar::new(value).boxed() + Utf8Scalar::new(Some(value)).into() } } -impl From<&str> for ScalarRef { +impl From<&str> for Scalar { fn from(value: &str) -> Self { - Utf8Scalar::new(value.to_string()).boxed() + Utf8Scalar::new(Some(value.to_string())).into() } } -impl TryFrom for String { +impl TryFrom for String { type Error = VortexError; - fn try_from(value: ScalarRef) -> Result { - let dtype = value.dtype().clone(); - let scalar = value - .into_any() - .downcast::() - .map_err(|_| VortexError::InvalidDType(dtype))?; - Ok(scalar.value) + fn try_from(value: Scalar) -> Result { + let Scalar::Utf8(u) = value else { + return Err(VortexError::InvalidDType(value.dtype().clone())); + }; + match u.value { + None => Err(VortexError::InvalidDType(u.dtype().clone())), + Some(s) => Ok(s), + } } } -impl TryFrom<&dyn Scalar> for String { +impl TryFrom<&Scalar> for String { type Error = VortexError; - fn try_from(value: &dyn Scalar) -> Result { - if let Some(scalar) = value.as_any().downcast_ref::() { - Ok(scalar.value().to_string()) - } else { - Err(VortexError::InvalidDType(value.dtype().clone())) + fn try_from(value: &Scalar) -> Result { + let Scalar::Utf8(u) = value else { + return Err(VortexError::InvalidDType(value.dtype().clone())); + }; + match u.value() { + None => Err(VortexError::InvalidDType(u.dtype().clone())), + Some(s) => Ok(s.to_string()), } } } impl Display for Utf8Scalar { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self.value) + match self.value() { + None => write!(f, ""), + Some(v) => Display::fmt(v, f), + } } } diff --git a/vortex-array/src/serde/dtype.rs b/vortex-array/src/serde/dtype.rs index 19caca7e11..6cfef12d28 100644 --- a/vortex-array/src/serde/dtype.rs +++ b/vortex-array/src/serde/dtype.rs @@ -17,15 +17,21 @@ impl<'a> DTypeReader<'a> { Self { reader } } - fn read_byte(&mut self) -> io::Result { - let mut buf: [u8; 1] = [0; 1]; - self.reader.read_exact(&mut buf)?; - Ok(buf[0]) + fn read_nbytes(&mut self) -> io::Result<[u8; N]> { + let mut bytes: [u8; N] = [0; N]; + self.reader.read_exact(&mut bytes)?; + Ok(bytes) + } + + fn read_usize(&mut self) -> io::Result { + leb128::read::unsigned(self.reader) + .map_err(|e| io::Error::new(ErrorKind::InvalidData, e)) + .map(|u| u as usize) } pub fn read(&mut self) -> io::Result { - let dtype = DTypeTag::try_from(self.read_byte()?) - .map_err(|e| io::Error::new(ErrorKind::InvalidInput, e))?; + let dtype = DTypeTag::try_from(self.read_nbytes::<1>()?[0]) + .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?; match dtype { DTypeTag::Null => Ok(Null), DTypeTag::Bool => Ok(Bool(self.read_nullability()?)), @@ -45,8 +51,7 @@ impl<'a> DTypeReader<'a> { DTypeTag::Binary => Ok(Binary(self.read_nullability()?)), DTypeTag::Decimal => { let nullability = self.read_nullability()?; - let mut precision_scale: [u8; 2] = [0; 2]; - self.reader.read_exact(&mut precision_scale)?; + let precision_scale: [u8; 2] = self.read_nbytes()?; Ok(Decimal( precision_scale[0], precision_scale[1] as i8, @@ -79,18 +84,16 @@ impl<'a> DTypeReader<'a> { )) } DTypeTag::Struct => { - let field_num = leb128::read::unsigned(self.reader) - .map_err(|e| io::Error::new(ErrorKind::InvalidData, e))?; - let mut names = Vec::>::with_capacity(field_num as usize); + let field_num = self.read_usize()?; + let mut names = Vec::with_capacity(field_num); for _ in 0..field_num { - let len = leb128::read::unsigned(self.reader) - .map_err(|e| io::Error::new(ErrorKind::InvalidData, e))?; - let mut name = String::with_capacity(len as usize); - self.reader.take(len).read_to_string(&mut name)?; + let len = self.read_usize()?; + let mut name = String::with_capacity(len); + self.reader.take(len as u64).read_to_string(&mut name)?; names.push(Arc::new(name)); } - let mut fields = Vec::::with_capacity(field_num as usize); + let mut fields = Vec::with_capacity(field_num); for _ in 0..field_num { fields.push(self.read()?); } @@ -100,32 +103,32 @@ impl<'a> DTypeReader<'a> { } fn read_signedness(&mut self) -> io::Result { - SignednessTag::try_from(self.read_byte()?) - .map_err(|e| io::Error::new(ErrorKind::InvalidData, e)) + SignednessTag::try_from(self.read_nbytes::<1>()?[0]) + .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e)) .map(Signedness::from) } fn read_nullability(&mut self) -> io::Result { - NullabilityTag::try_from(self.read_byte()?) - .map_err(|e| io::Error::new(ErrorKind::InvalidData, e)) + NullabilityTag::try_from(self.read_nbytes::<1>()?[0]) + .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e)) .map(Nullability::from) } fn read_int_width(&mut self) -> io::Result { - IntWidthTag::try_from(self.read_byte()?) - .map_err(|e| io::Error::new(ErrorKind::InvalidData, e)) + IntWidthTag::try_from(self.read_nbytes::<1>()?[0]) + .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e)) .map(IntWidth::from) } fn read_float_width(&mut self) -> io::Result { - FloatWidthTag::try_from(self.read_byte()?) - .map_err(|e| io::Error::new(ErrorKind::InvalidData, e)) + FloatWidthTag::try_from(self.read_nbytes::<1>()?[0]) + .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e)) .map(FloatWidth::from) } fn read_time_unit(&mut self) -> io::Result { - TimeUnitTag::try_from(self.read_byte()?) - .map_err(|e| io::Error::new(ErrorKind::InvalidData, e)) + TimeUnitTag::try_from(self.read_nbytes::<1>()?[0]) + .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e)) .map(TimeUnit::from) } } diff --git a/vortex-array/src/serde/mod.rs b/vortex-array/src/serde/mod.rs index a2bc1d517c..c0f06e0bf4 100644 --- a/vortex-array/src/serde/mod.rs +++ b/vortex-array/src/serde/mod.rs @@ -5,10 +5,13 @@ use arrow::buffer::{Buffer, MutableBuffer}; use crate::array::{Array, ArrayRef, EncodingId, ENCODINGS}; use crate::dtype::{DType, IntWidth, Nullability, Signedness}; -use crate::scalar::{Scalar, ScalarReader, ScalarRef, ScalarWriter}; +use crate::ptype::PType; +use crate::scalar::{Scalar, ScalarReader, ScalarWriter}; pub use crate::serde::dtype::{DTypeReader, DTypeWriter, TimeUnitTag}; +use crate::serde::ptype::PTypeTag; mod dtype; +mod ptype; pub trait ArraySerde { fn write(&self, ctx: &mut WriteCtx) -> io::Result<()>; @@ -70,9 +73,20 @@ impl<'a> ReadCtx<'a> { DTypeReader::new(self.r).read() } + pub fn ptype(&mut self) -> io::Result { + let typetag = PTypeTag::try_from(self.read_nbytes::<1>()?[0]) + .map_err(|e| io::Error::new(ErrorKind::InvalidInput, e))?; + Ok(typetag.into()) + } + #[inline] - pub fn scalar(&mut self) -> io::Result { - ScalarReader::new(self.r).read() + pub fn scalar(&mut self) -> io::Result { + ScalarReader::new(self).read() + } + + pub fn read_optional_slice(&mut self) -> io::Result>> { + let is_present = self.read_option_tag()?; + is_present.then(|| self.read_slice()).transpose() } pub fn read_slice(&mut self) -> io::Result> { @@ -152,7 +166,11 @@ impl<'a> WriteCtx<'a> { DTypeWriter::new(self).write(dtype) } - pub fn scalar(&mut self, scalar: &dyn Scalar) -> io::Result<()> { + pub fn ptype(&mut self, ptype: PType) -> io::Result<()> { + self.write_fixed_slice([PTypeTag::from(ptype).into()]) + } + + pub fn scalar(&mut self, scalar: &Scalar) -> io::Result<()> { ScalarWriter::new(self).write(scalar) } @@ -169,6 +187,15 @@ impl<'a> WriteCtx<'a> { self.w.write_all(slice) } + pub fn write_optional_slice(&mut self, slice: Option<&[u8]>) -> io::Result<()> { + self.write_option_tag(slice.is_some())?; + if let Some(s) = slice { + self.write_slice(s) + } else { + Ok(()) + } + } + pub fn write_buffer(&mut self, logical_len: usize, buf: &Buffer) -> io::Result<()> { self.write_usize(logical_len)?; self.w.write_all(buf.as_slice()) diff --git a/vortex-array/src/serde/ptype.rs b/vortex-array/src/serde/ptype.rs new file mode 100644 index 0000000000..4481a6b3c0 --- /dev/null +++ b/vortex-array/src/serde/ptype.rs @@ -0,0 +1,55 @@ +use num_enum::{IntoPrimitive, TryFromPrimitive}; + +use crate::ptype::PType; + +#[derive(IntoPrimitive, TryFromPrimitive)] +#[repr(u8)] +pub enum PTypeTag { + U8, + U16, + U32, + U64, + I8, + I16, + I32, + I64, + F16, + F32, + F64, +} + +impl From for PTypeTag { + fn from(value: PType) -> Self { + match value { + PType::U8 => PTypeTag::U8, + PType::U16 => PTypeTag::U16, + PType::U32 => PTypeTag::U32, + PType::U64 => PTypeTag::U64, + PType::I8 => PTypeTag::I8, + PType::I16 => PTypeTag::I16, + PType::I32 => PTypeTag::I32, + PType::I64 => PTypeTag::I64, + PType::F16 => PTypeTag::F16, + PType::F32 => PTypeTag::F32, + PType::F64 => PTypeTag::F64, + } + } +} + +impl From for PType { + fn from(value: PTypeTag) -> Self { + match value { + PTypeTag::U8 => PType::U8, + PTypeTag::U16 => PType::U16, + PTypeTag::U32 => PType::U32, + PTypeTag::U64 => PType::U64, + PTypeTag::I8 => PType::I8, + PTypeTag::I16 => PType::I16, + PTypeTag::I32 => PType::I32, + PTypeTag::I64 => PType::I64, + PTypeTag::F16 => PType::F16, + PTypeTag::F32 => PType::F32, + PTypeTag::F64 => PType::F64, + } + } +} diff --git a/vortex-array/src/stats.rs b/vortex-array/src/stats.rs index 0ccf75a21e..092f5fc7bb 100644 --- a/vortex-array/src/stats.rs +++ b/vortex-array/src/stats.rs @@ -3,12 +3,12 @@ use std::collections::hash_map::Entry; use std::collections::HashMap; use std::sync::RwLock; -use crate::dtype::DType; use itertools::Itertools; +use crate::dtype::DType; use crate::error::{VortexError, VortexResult}; use crate::ptype::NativePType; -use crate::scalar::{ListScalarVec, ScalarRef}; +use crate::scalar::{ListScalarVec, Scalar}; #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub enum Stat { @@ -24,29 +24,29 @@ pub enum Stat { } #[derive(Debug, Clone, Default)] -pub struct StatsSet(HashMap); +pub struct StatsSet(HashMap); impl StatsSet { pub fn new() -> Self { StatsSet(HashMap::new()) } - pub fn from(map: HashMap) -> Self { + pub fn from(map: HashMap) -> Self { StatsSet(map) } - pub fn of(stat: Stat, value: ScalarRef) -> Self { + pub fn of(stat: Stat, value: Scalar) -> Self { StatsSet(HashMap::from([(stat, value)])) } - fn get_as>( + fn get_as>( &self, stat: &Stat, ) -> VortexResult> { self.0.get(stat).map(|v| T::try_from(v.clone())).transpose() } - pub fn set(&mut self, stat: Stat, value: ScalarRef) { + pub fn set(&mut self, stat: Stat, value: Scalar) { self.0.insert(stat, value); } @@ -68,7 +68,7 @@ impl StatsSet { match self.0.entry(Stat::Min) { Entry::Occupied(mut e) => { if let Some(omin) = other.0.get(&Stat::Min) { - match omin.partial_cmp(e.get().as_ref()) { + match omin.partial_cmp(e.get()) { None => { e.remove(); } @@ -91,7 +91,7 @@ impl StatsSet { match self.0.entry(Stat::Max) { Entry::Occupied(mut e) => { if let Some(omin) = other.0.get(&Stat::Max) { - match omin.partial_cmp(e.get().as_ref()) { + match omin.partial_cmp(e.get()) { None => { e.remove(); } @@ -148,7 +148,7 @@ impl StatsSet { match self.0.entry(stat.clone()) { Entry::Occupied(mut e) => { if let Some(other_value) = other.get_as::(stat).unwrap() { - let self_value: usize = e.get().as_ref().try_into().unwrap(); + let self_value: usize = e.get().try_into().unwrap(); e.insert((self_value + other_value).into()); } } @@ -168,7 +168,7 @@ impl StatsSet { .unwrap() { // TODO(robert): Avoid the copy here. We could e.get_mut() but need to figure out casting - let self_value: ListScalarVec = e.get().as_ref().try_into().unwrap(); + let self_value: ListScalarVec = e.get().try_into().unwrap(); e.insert( ListScalarVec( self_value @@ -195,7 +195,7 @@ impl StatsSet { match self.0.entry(Stat::RunCount) { Entry::Occupied(mut e) => { if let Some(other_value) = other.get_as::(&Stat::RunCount).unwrap() { - let self_value: usize = e.get().as_ref().try_into().unwrap(); + let self_value: usize = e.get().try_into().unwrap(); e.insert((self_value + other_value + 1).into()); } } @@ -232,7 +232,7 @@ impl<'a> Stats<'a> { }); } - pub fn set(&self, stat: Stat, value: ScalarRef) { + pub fn set(&self, stat: Stat, value: Scalar) { self.cache.write().unwrap().set(stat, value); } @@ -240,15 +240,15 @@ impl<'a> Stats<'a> { self.cache.read().unwrap().clone() } - pub fn get(&self, stat: &Stat) -> Option { + pub fn get(&self, stat: &Stat) -> Option { self.cache.read().unwrap().0.get(stat).cloned() } - pub fn get_as>(&self, stat: &Stat) -> Option { + pub fn get_as>(&self, stat: &Stat) -> Option { self.get(stat).map(|v| T::try_from(v).unwrap()) } - pub fn get_or_compute(&self, stat: &Stat) -> Option { + pub fn get_or_compute(&self, stat: &Stat) -> Option { if let Some(value) = self.cache.read().unwrap().0.get(stat) { return Some(value.clone()); } @@ -264,18 +264,18 @@ impl<'a> Stats<'a> { pub fn get_or_compute_cast(&self, stat: &Stat) -> Option { self.get_or_compute(stat) // TODO(ngates): fix the API so we don't convert the result to optional - .and_then(|v: ScalarRef| v.cast(&DType::from(T::PTYPE)).ok()) + .and_then(|v: Scalar| v.cast(&DType::from(T::PTYPE)).ok()) .and_then(|v| T::try_from(v).ok()) } - pub fn get_or_compute_as>( + pub fn get_or_compute_as>( &self, stat: &Stat, ) -> Option { self.get_or_compute(stat).and_then(|v| T::try_from(v).ok()) } - pub fn get_or_compute_or>( + pub fn get_or_compute_or>( &self, default: T, stat: &Stat, diff --git a/vortex-dict/src/compute.rs b/vortex-dict/src/compute.rs index 288068efe4..31eb6419dd 100644 --- a/vortex-dict/src/compute.rs +++ b/vortex-dict/src/compute.rs @@ -1,8 +1,9 @@ -use crate::DictArray; use vortex::compute::scalar_at::{scalar_at, ScalarAtFn}; use vortex::compute::ArrayCompute; use vortex::error::VortexResult; -use vortex::scalar::ScalarRef; +use vortex::scalar::Scalar; + +use crate::DictArray; impl ArrayCompute for DictArray { fn scalar_at(&self) -> Option<&dyn ScalarAtFn> { @@ -11,7 +12,7 @@ impl ArrayCompute for DictArray { } impl ScalarAtFn for DictArray { - fn scalar_at(&self, index: usize) -> VortexResult { + fn scalar_at(&self, index: usize) -> VortexResult { let dict_index: usize = scalar_at(self.codes(), index)?.try_into()?; scalar_at(self.dict(), dict_index) } diff --git a/vortex-fastlanes/src/for/mod.rs b/vortex-fastlanes/src/for/mod.rs index 0ffc717f1f..33deda3eb0 100644 --- a/vortex-fastlanes/src/for/mod.rs +++ b/vortex-fastlanes/src/for/mod.rs @@ -7,7 +7,7 @@ use vortex::compute::ArrayCompute; use vortex::dtype::DType; use vortex::error::VortexResult; use vortex::formatter::{ArrayDisplay, ArrayFormatter}; -use vortex::scalar::{Scalar, ScalarRef}; +use vortex::scalar::Scalar; use vortex::serde::{ArraySerde, EncodingSerde}; use vortex::stats::{Stat, Stats, StatsCompute, StatsSet}; @@ -17,12 +17,12 @@ mod serde; #[derive(Debug, Clone)] pub struct FoRArray { child: ArrayRef, - reference: ScalarRef, + reference: Scalar, stats: Arc>, } impl FoRArray { - pub fn try_new(child: ArrayRef, reference: ScalarRef) -> VortexResult { + pub fn try_new(child: ArrayRef, reference: Scalar) -> VortexResult { // TODO(ngates): check the dtype of reference == child.dtype() Ok(Self { child, @@ -37,8 +37,8 @@ impl FoRArray { } #[inline] - pub fn reference(&self) -> &dyn Scalar { - self.reference.as_ref() + pub fn reference(&self) -> &Scalar { + &self.reference } } diff --git a/vortex-fastlanes/src/for/serde.rs b/vortex-fastlanes/src/for/serde.rs index 7a03b425c5..f14aeb494b 100644 --- a/vortex-fastlanes/src/for/serde.rs +++ b/vortex-fastlanes/src/for/serde.rs @@ -22,12 +22,14 @@ impl EncodingSerde for FoREncoding { #[cfg(test)] mod test { - use crate::FoRArray; use std::io; + use vortex::array::{Array, ArrayRef}; - use vortex::scalar::ScalarRef; + use vortex::scalar::Scalar; use vortex::serde::{ReadCtx, WriteCtx}; + use crate::FoRArray; + fn roundtrip_array(array: &dyn Array) -> io::Result { let mut buf = Vec::::new(); let mut write_ctx = WriteCtx::new(&mut buf); @@ -41,7 +43,7 @@ mod test { fn roundtrip() { let arr = FoRArray::try_new( vec![-7i64, -13, 17, 23].into(), - >::into(-7i64), + >::into(-7i64), ) .unwrap(); roundtrip_array(arr.as_ref()).unwrap(); diff --git a/vortex-ree/src/compute.rs b/vortex-ree/src/compute.rs index 94a5d75576..c77272e547 100644 --- a/vortex-ree/src/compute.rs +++ b/vortex-ree/src/compute.rs @@ -1,8 +1,9 @@ -use crate::REEArray; use vortex::compute::scalar_at::{scalar_at, ScalarAtFn}; use vortex::compute::ArrayCompute; use vortex::error::VortexResult; -use vortex::scalar::ScalarRef; +use vortex::scalar::Scalar; + +use crate::REEArray; impl ArrayCompute for REEArray { fn scalar_at(&self) -> Option<&dyn ScalarAtFn> { @@ -11,7 +12,7 @@ impl ArrayCompute for REEArray { } impl ScalarAtFn for REEArray { - fn scalar_at(&self, index: usize) -> VortexResult { + fn scalar_at(&self, index: usize) -> VortexResult { scalar_at(self.values(), self.find_physical_index(index)?) } } diff --git a/vortex-roaring/src/boolean/compute.rs b/vortex-roaring/src/boolean/compute.rs index 8f0db54106..7a41a02193 100644 --- a/vortex-roaring/src/boolean/compute.rs +++ b/vortex-roaring/src/boolean/compute.rs @@ -1,8 +1,9 @@ -use crate::RoaringBoolArray; use vortex::compute::scalar_at::ScalarAtFn; use vortex::compute::ArrayCompute; use vortex::error::VortexResult; -use vortex::scalar::ScalarRef; +use vortex::scalar::Scalar; + +use crate::RoaringBoolArray; impl ArrayCompute for RoaringBoolArray { fn scalar_at(&self) -> Option<&dyn ScalarAtFn> { @@ -11,7 +12,7 @@ impl ArrayCompute for RoaringBoolArray { } impl ScalarAtFn for RoaringBoolArray { - fn scalar_at(&self, index: usize) -> VortexResult { + fn scalar_at(&self, index: usize) -> VortexResult { if self.bitmap.contains(index as u32) { Ok(true.into()) } else { diff --git a/vortex-roaring/src/boolean/mod.rs b/vortex-roaring/src/boolean/mod.rs index 9c11cdd256..79ad8be157 100644 --- a/vortex-roaring/src/boolean/mod.rs +++ b/vortex-roaring/src/boolean/mod.rs @@ -157,7 +157,7 @@ mod test { use vortex::array::Array; use vortex::compute::scalar_at::scalar_at; use vortex::error::VortexResult; - use vortex::scalar::ScalarRef; + use vortex::scalar::Scalar; use crate::RoaringBoolArray; @@ -177,8 +177,8 @@ mod test { let bool: &dyn Array = &BoolArray::from(vec![true, false, true, true]); let array = RoaringBoolArray::encode(bool)?; - let truthy: ScalarRef = true.into(); - let falsy: ScalarRef = false.into(); + let truthy: Scalar = true.into(); + let falsy: Scalar = false.into(); assert_eq!(scalar_at(array.as_ref(), 0)?, truthy); assert_eq!(scalar_at(array.as_ref(), 1)?, falsy); diff --git a/vortex-roaring/src/integer/compute.rs b/vortex-roaring/src/integer/compute.rs index 45a97969a4..b05508d4ff 100644 --- a/vortex-roaring/src/integer/compute.rs +++ b/vortex-roaring/src/integer/compute.rs @@ -1,9 +1,10 @@ -use crate::RoaringIntArray; use vortex::compute::scalar_at::ScalarAtFn; use vortex::compute::ArrayCompute; use vortex::error::VortexResult; use vortex::ptype::PType; -use vortex::scalar::ScalarRef; +use vortex::scalar::Scalar; + +use crate::RoaringIntArray; impl ArrayCompute for RoaringIntArray { fn scalar_at(&self) -> Option<&dyn ScalarAtFn> { @@ -12,10 +13,10 @@ impl ArrayCompute for RoaringIntArray { } impl ScalarAtFn for RoaringIntArray { - fn scalar_at(&self, index: usize) -> VortexResult { + fn scalar_at(&self, index: usize) -> VortexResult { // Unwrap since we know the index is valid let bitmap_value = self.bitmap.select(index as u32).unwrap(); - let scalar: ScalarRef = match self.ptype { + let scalar: Scalar = match self.ptype { PType::U8 => (bitmap_value as u8).into(), PType::U16 => (bitmap_value as u16).into(), PType::U32 => bitmap_value.into(), diff --git a/vortex-zigzag/src/compute.rs b/vortex-zigzag/src/compute.rs index b6516e97b6..5f11e6a7d9 100644 --- a/vortex-zigzag/src/compute.rs +++ b/vortex-zigzag/src/compute.rs @@ -1,11 +1,12 @@ -use crate::ZigZagArray; +use zigzag::ZigZag; + use vortex::array::Array; use vortex::compute::scalar_at::{scalar_at, ScalarAtFn}; use vortex::compute::ArrayCompute; -use vortex::dtype::{DType, IntWidth, Signedness}; use vortex::error::{VortexError, VortexResult}; -use vortex::scalar::{NullableScalar, Scalar, ScalarRef}; -use zigzag::ZigZag; +use vortex::scalar::{PScalar, Scalar}; + +use crate::ZigZagArray; impl ArrayCompute for ZigZagArray { fn scalar_at(&self) -> Option<&dyn ScalarAtFn> { @@ -14,24 +15,19 @@ impl ArrayCompute for ZigZagArray { } impl ScalarAtFn for ZigZagArray { - fn scalar_at(&self, index: usize) -> VortexResult { + fn scalar_at(&self, index: usize) -> VortexResult { let scalar = scalar_at(self.encoded(), index)?; - let Some(scalar) = scalar.as_nonnull() else { - return Ok(NullableScalar::none(self.dtype().clone()).boxed()); - }; - match self.dtype() { - DType::Int(IntWidth::_8, Signedness::Signed, _) => { - Ok(i8::decode(scalar.try_into()?).into()) - } - DType::Int(IntWidth::_16, Signedness::Signed, _) => { - Ok(i16::decode(scalar.try_into()?).into()) - } - DType::Int(IntWidth::_32, Signedness::Signed, _) => { - Ok(i32::decode(scalar.try_into()?).into()) - } - DType::Int(IntWidth::_64, Signedness::Signed, _) => { - Ok(i64::decode(scalar.try_into()?).into()) - } + match scalar { + Scalar::Primitive(p) => match p.value() { + None => Ok(Scalar::null(self.dtype())), + Some(p) => match p { + PScalar::U8(u) => Ok(i8::decode(u).into()), + PScalar::U16(u) => Ok(i16::decode(u).into()), + PScalar::U32(u) => Ok(i32::decode(u).into()), + PScalar::U64(u) => Ok(i64::decode(u).into()), + _ => Err(VortexError::InvalidDType(self.dtype().clone())), + }, + }, _ => Err(VortexError::InvalidDType(self.dtype().clone())), } }