From d43e617c5e0ed142dbcd7d72089f044b3c31fdac Mon Sep 17 00:00:00 2001 From: Nicholas Gates Date: Wed, 8 May 2024 09:54:44 +0100 Subject: [PATCH] Scalar Refactor --- Cargo.lock | 1 + vortex-array/src/array/varbin/stats.rs | 12 ++- vortex-dict/src/compress.rs | 4 +- vortex-dtype/Cargo.toml | 5 +- vortex-scalar/src/extension.rs | 23 +++-- vortex-scalar/src/lib.rs | 1 + vortex-scalar/src/list.rs | 58 ++++++++--- vortex-scalar/src/primitive.rs | 87 +++++++++------- vortex-scalar/src/pvalue.rs | 132 +++++++++++++++++++++++++ vortex-scalar/src/serde/serde.rs | 44 +++++++-- vortex-scalar/src/struct_.rs | 26 +++-- vortex-scalar/src/value.rs | 70 +++---------- 12 files changed, 327 insertions(+), 136 deletions(-) create mode 100644 vortex-scalar/src/pvalue.rs diff --git a/Cargo.lock b/Cargo.lock index 04aef3e279..328d5fe9d7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1985,6 +1985,7 @@ dependencies = [ "cfg-if", "crunchy", "num-traits", + "serde", ] [[package]] diff --git a/vortex-array/src/array/varbin/stats.rs b/vortex-array/src/array/varbin/stats.rs index 121c4b3b28..c77fb4c0e4 100644 --- a/vortex-array/src/array/varbin/stats.rs +++ b/vortex-array/src/array/varbin/stats.rs @@ -124,7 +124,9 @@ impl<'a> VarBinAccumulator<'a> { #[cfg(test)] mod test { - use vortex_buffer::BufferString; + use std::ops::Deref; + + use vortex_buffer::{Buffer, BufferString}; use vortex_dtype::{DType, Nullability}; use crate::array::varbin::{OwnedVarBinArray, VarBinArray}; @@ -157,12 +159,12 @@ mod test { fn binary_stats() { let arr = array(DType::Binary(Nullability::NonNullable)); assert_eq!( - arr.statistics().compute_min::>().unwrap(), - "hello world".as_bytes().to_vec() + arr.statistics().compute_min::().unwrap().deref(), + "hello world".as_bytes() ); assert_eq!( - arr.statistics().compute_max::>().unwrap(), - "hello world this is a long string".as_bytes().to_vec() + arr.statistics().compute_max::().unwrap().deref(), + "hello world this is a long string".as_bytes() ); assert_eq!(arr.statistics().compute_run_count().unwrap(), 2); assert!(!arr.statistics().compute_is_constant().unwrap()); diff --git a/vortex-dict/src/compress.rs b/vortex-dict/src/compress.rs index ba66b723a3..ae516537f2 100644 --- a/vortex-dict/src/compress.rs +++ b/vortex-dict/src/compress.rs @@ -255,7 +255,7 @@ mod test { use vortex::compute::scalar_at::scalar_at; use vortex::ToArray; use vortex_dtype::Nullability::Nullable; - use vortex_dtype::PType; + use vortex_dtype::{DType, PType}; use vortex_scalar::Scalar; use crate::compress::{dict_encode_typed_primitive, dict_encode_varbin}; @@ -287,7 +287,7 @@ mod test { ); assert_eq!( scalar_at(&values.to_array(), 0).unwrap(), - Scalar::null(PType::I32.into()) + Scalar::null(DType::Primitive(PType::I32, Nullable)) ); assert_eq!( scalar_at(&values.to_array(), 1).unwrap(), diff --git a/vortex-dtype/Cargo.toml b/vortex-dtype/Cargo.toml index 3dcdc20299..2797da7644 100644 --- a/vortex-dtype/Cargo.toml +++ b/vortex-dtype/Cargo.toml @@ -17,7 +17,7 @@ path = "src/lib.rs" [dependencies] flatbuffers = { workspace = true, optional = true } -half = { workspace = true } +half = { workspace = true, features = ["num-traits"] } itertools = { workspace = true } num-traits = { workspace = true } prost = { workspace = true, optional = true } @@ -34,3 +34,6 @@ build-vortex = { path = "../build-vortex" } [lints] workspace = true + +[features] +serde = ["dep:serde", "half/serde"] \ No newline at end of file diff --git a/vortex-scalar/src/extension.rs b/vortex-scalar/src/extension.rs index 0558e374c9..73cc9e9461 100644 --- a/vortex-scalar/src/extension.rs +++ b/vortex-scalar/src/extension.rs @@ -4,16 +4,22 @@ use vortex_error::{vortex_bail, VortexError, VortexResult}; use crate::value::ScalarValue; use crate::Scalar; -pub struct ExtScalar<'a>(&'a Scalar); +pub struct ExtScalar<'a> { + dtype: &'a DType, + // TODO(ngates): we may need to serialize the value's dtype too so we can pull + // it out as a scalar. + value: &'a ScalarValue, +} + impl<'a> ExtScalar<'a> { #[inline] pub fn dtype(&self) -> &'a DType { - self.0.dtype() + self.dtype } /// Returns the stored value of the extension scalar. - pub fn value(&self) -> &ScalarValue { - &self.0.value + pub fn value(&self) -> &'a ScalarValue { + self.value } pub fn cast(&self, _dtype: &DType) -> VortexResult { @@ -25,11 +31,14 @@ impl<'a> TryFrom<&'a Scalar> for ExtScalar<'a> { type Error = VortexError; fn try_from(value: &'a Scalar) -> Result { - if matches!(value.dtype(), DType::Extension(..)) { - Ok(Self(value)) - } else { + if !matches!(value.dtype(), DType::Extension(..)) { vortex_bail!("Expected extension scalar, found {}", value.dtype()) } + + Ok(Self { + dtype: value.dtype(), + value: &value.value, + }) } } diff --git a/vortex-scalar/src/lib.rs b/vortex-scalar/src/lib.rs index 287bd5ea19..9331b165b6 100644 --- a/vortex-scalar/src/lib.rs +++ b/vortex-scalar/src/lib.rs @@ -8,6 +8,7 @@ mod display; mod extension; mod list; mod primitive; +mod pvalue; mod serde; mod struct_; mod utf8; diff --git a/vortex-scalar/src/list.rs b/vortex-scalar/src/list.rs index 2f07b637dd..e07e56074f 100644 --- a/vortex-scalar/src/list.rs +++ b/vortex-scalar/src/list.rs @@ -1,39 +1,65 @@ +use std::ops::Deref; +use std::sync::Arc; + use itertools::Itertools; use vortex_dtype::DType; +use vortex_dtype::Nullability::NonNullable; use vortex_error::{vortex_bail, VortexError, VortexResult}; use crate::value::ScalarValue; use crate::Scalar; -pub struct ListScalar<'a>(&'a Scalar); +pub struct ListScalar<'a> { + dtype: &'a DType, + elements: Option>, +} + impl<'a> ListScalar<'a> { #[inline] pub fn dtype(&self) -> &'a DType { - self.0.dtype() + self.dtype } #[inline] pub fn len(&self) -> usize { - self.0.value.len() + self.elements.as_ref().map(|e| e.len()).unwrap_or(0) } #[inline] pub fn is_empty(&self) -> bool { - self.len() == 0 + match self.elements.as_ref() { + None => true, + Some(l) => l.is_empty(), + } } - pub fn element(&self, idx: usize) -> Option { + pub fn element_dtype(&self) -> DType { let DType::List(element_type, _) = self.dtype() else { unreachable!(); }; - self.0.value.child(idx).map(|value| Scalar { - dtype: element_type.as_ref().clone(), - value, - }) + (*element_type).deref().clone() + } + + pub fn element(&self, idx: usize) -> Option { + self.elements + .as_ref() + .and_then(|l| l.get(idx)) + .map(|value| Scalar { + dtype: self.element_dtype(), + value: value.clone(), + }) } pub fn elements(&self) -> impl Iterator + '_ { - (0..self.len()).map(move |idx| self.element(idx).expect("incorrect length")) + self.elements + .as_ref() + .map(|e| e.as_ref()) + .unwrap_or_else(|| &[] as &[ScalarValue]) + .iter() + .map(|e| Scalar { + dtype: self.element_dtype(), + value: e.clone(), + }) } pub fn cast(&self, _dtype: &DType) -> VortexResult { @@ -45,11 +71,14 @@ impl<'a> TryFrom<&'a Scalar> for ListScalar<'a> { type Error = VortexError; fn try_from(value: &'a Scalar) -> Result { - if matches!(value.dtype(), DType::List(..)) { - Ok(Self(value)) - } else { + if !matches!(value.dtype(), DType::List(..)) { vortex_bail!("Expected list scalar, found {}", value.dtype()) } + + Ok(Self { + dtype: value.dtype(), + elements: value.value.as_list()?.cloned(), + }) } } @@ -72,7 +101,8 @@ where { fn from(value: Vec) -> Self { let scalars = value.into_iter().map(|v| Scalar::from(v)).collect_vec(); - let dtype = scalars.first().expect("Empty list").dtype().clone(); + let element_dtype = scalars.first().expect("Empty list").dtype().clone(); + let dtype = DType::List(Arc::new(element_dtype), NonNullable); Scalar { dtype, value: ScalarValue::List(scalars.into_iter().map(|s| s.value).collect_vec().into()), diff --git a/vortex-scalar/src/primitive.rs b/vortex-scalar/src/primitive.rs index 9eeba9895c..58d76f5419 100644 --- a/vortex-scalar/src/primitive.rs +++ b/vortex-scalar/src/primitive.rs @@ -1,27 +1,36 @@ use num_traits::NumCast; -use vortex_buffer::Buffer; use vortex_dtype::half::f16; use vortex_dtype::{match_each_native_ptype, DType, NativePType, Nullability, PType}; use vortex_error::{vortex_bail, vortex_err, VortexError, VortexResult}; +use crate::pvalue::PValue; use crate::value::ScalarValue; use crate::Scalar; -pub struct PrimitiveScalar<'a>(&'a Scalar); +pub struct PrimitiveScalar<'a> { + dtype: &'a DType, + ptype: PType, + pvalue: Option, +} impl<'a> PrimitiveScalar<'a> { #[inline] pub fn dtype(&self) -> &'a DType { - self.0.dtype() + self.dtype } #[inline] pub fn ptype(&self) -> PType { - PType::try_from(self.dtype()).expect("Invalid primitive scalar dtype") + self.ptype } - pub fn typed_value(&self) -> Option { - self.0.value.as_primitive::() + pub fn typed_value>(&self) -> Option { + if self.ptype != T::PTYPE { + panic!("Attempting to read {} scalar as {}", self.ptype, T::PTYPE); + } + self.pvalue + .as_ref() + .map(|pv| T::try_from(*pv).expect("checked on construction")) } pub fn cast(&self, dtype: &DType) -> VortexResult { @@ -42,44 +51,46 @@ impl<'a> TryFrom<&'a Scalar> for PrimitiveScalar<'a> { type Error = VortexError; fn try_from(value: &'a Scalar) -> Result { - if matches!(value.dtype(), DType::Primitive(..)) { - Ok(Self(value)) - } else { + if !matches!(value.dtype(), DType::Primitive(..)) { vortex_bail!("Expected primitive scalar, found {}", value.dtype()) } + + let ptype = PType::try_from(value.dtype())?; + + // Read the serialized value into the correct PValue. + // The serialized form may come back over the wire as e.g. any integer type. + let pvalue = match_each_native_ptype!(ptype, |$T| { + if let Some(pvalue) = value.value.as_pvalue()? { + Some(PValue::from(<$T>::try_from(pvalue)?)) + } else { + None + } + }); + + Ok(Self { + dtype: value.dtype(), + ptype, + pvalue, + }) } } impl Scalar { - pub fn primitive(value: T, nullability: Nullability) -> Scalar { + pub fn primitive>(value: T, nullability: Nullability) -> Scalar { Scalar { dtype: DType::Primitive(T::PTYPE, nullability), - value: ScalarValue::Buffer(value.to_le_bytes().into()), + value: ScalarValue::Primitive(value.into()), } } } -impl From for Scalar { - fn from(value: usize) -> Self { - Scalar::from(value as u64) - } -} - -impl TryFrom<&Scalar> for usize { - type Error = VortexError; - - fn try_from(value: &Scalar) -> Result { - u64::try_from(value).map(|value| value as usize) - } -} - macro_rules! primitive_scalar { ($T:ty) => { impl From<$T> for Scalar { fn from(value: $T) -> Self { Scalar { dtype: DType::Primitive(<$T>::PTYPE, Nullability::NonNullable), - value: ScalarValue::from(value), + value: ScalarValue::Primitive(value.into()), } } } @@ -89,7 +100,7 @@ macro_rules! primitive_scalar { Scalar { dtype: DType::Primitive(<$T>::PTYPE, Nullability::Nullable), value: value - .map(|v| ScalarValue::from(v)) + .map(|v| ScalarValue::Primitive(v.into())) .unwrap_or_else(|| ScalarValue::Null), } } @@ -115,24 +126,26 @@ primitive_scalar!(i8); primitive_scalar!(i16); primitive_scalar!(i32); primitive_scalar!(i64); +primitive_scalar!(f16); primitive_scalar!(f32); primitive_scalar!(f64); -impl From for Scalar { - fn from(value: f16) -> Self { - Scalar { - dtype: DType::Primitive(PType::F16, Nullability::NonNullable), - value: ScalarValue::Buffer(Buffer::from(value.to_le_bytes().to_vec())), - } +impl From for Scalar { + fn from(value: usize) -> Self { + Scalar::from(value as u64) } } -impl TryFrom<&Scalar> for f16 { +/// Read a scalar as usize. For usize only, we implicitly cast for better ergonomics. +impl TryFrom<&Scalar> for usize { type Error = VortexError; fn try_from(value: &Scalar) -> Result { - PrimitiveScalar::try_from(value)? - .typed_value::() - .ok_or_else(|| vortex_err!("Can't extract present value from null scalar")) + Ok(u64::try_from( + value + .cast(&DType::Primitive(PType::U64, Nullability::NonNullable))? + .as_ref(), + ) + .map(|v| v as usize)?) } } diff --git a/vortex-scalar/src/pvalue.rs b/vortex-scalar/src/pvalue.rs new file mode 100644 index 0000000000..b4bca018ab --- /dev/null +++ b/vortex-scalar/src/pvalue.rs @@ -0,0 +1,132 @@ +use num_traits::NumCast; +use vortex_dtype::half::f16; +use vortex_dtype::PType; +use vortex_error::vortex_err; +use vortex_error::VortexError; + +#[derive(Debug, Clone, Copy, PartialEq, PartialOrd)] +pub enum PValue { + U8(u8), + U16(u16), + U32(u32), + U64(u64), + I8(i8), + I16(i16), + I32(i32), + I64(i64), + F16(f16), + F32(f32), + F64(f64), +} + +impl PValue { + pub fn ptype(&self) -> PType { + match self { + PValue::U8(_) => PType::U8, + PValue::U16(_) => PType::U16, + PValue::U32(_) => PType::U32, + PValue::U64(_) => PType::U64, + PValue::I8(_) => PType::I8, + PValue::I16(_) => PType::I16, + PValue::I32(_) => PType::I32, + PValue::I64(_) => PType::I64, + PValue::F16(_) => PType::F16, + PValue::F32(_) => PType::F32, + PValue::F64(_) => PType::F64, + } + } +} + +macro_rules! int_pvalue { + ($T:ty, $PT:tt) => { + impl TryFrom for $T { + type Error = VortexError; + + fn try_from(value: PValue) -> Result { + match value { + PValue::U8(v) => <$T as NumCast>::from(v), + PValue::U16(v) => <$T as NumCast>::from(v), + PValue::U32(v) => <$T as NumCast>::from(v), + PValue::U64(v) => <$T as NumCast>::from(v), + PValue::I8(v) => <$T as NumCast>::from(v), + PValue::I16(v) => <$T as NumCast>::from(v), + PValue::I32(v) => <$T as NumCast>::from(v), + PValue::I64(v) => <$T as NumCast>::from(v), + _ => None, + } + .ok_or_else(|| { + vortex_err!("Cannot read primitive value {:?} as {}", value, PType::$PT) + }) + } + } + }; +} + +int_pvalue!(u8, U8); +int_pvalue!(u16, U16); +int_pvalue!(u32, U32); +int_pvalue!(u64, U64); +int_pvalue!(i8, I8); +int_pvalue!(i16, I16); +int_pvalue!(i32, I32); +int_pvalue!(i64, I64); + +macro_rules! float_pvalue { + ($T:ty, $PT:tt) => { + impl TryFrom for $T { + type Error = VortexError; + + fn try_from(value: PValue) -> Result { + match value { + PValue::F16(f) => <$T as NumCast>::from(f), + PValue::F32(f) => <$T as NumCast>::from(f), + PValue::F64(f) => <$T as NumCast>::from(f), + _ => None, + } + .ok_or_else(|| { + vortex_err!("Cannot read primitive value {:?} as {}", value, PType::$PT) + }) + } + } + }; +} + +float_pvalue!(f32, F32); +float_pvalue!(f64, F64); + +impl TryFrom for f16 { + type Error = VortexError; + + fn try_from(value: PValue) -> Result { + // We serialize f16 as u16. + match value { + PValue::U16(u) => Some(f16::from_bits(u)), + PValue::F32(f) => ::from(f), + PValue::F64(f) => ::from(f), + _ => None, + } + .ok_or_else(|| vortex_err!("Cannot read primitive value {:?} as {}", value, PType::F16)) + } +} + +macro_rules! impl_pvalue { + ($T:ty, $PT:tt) => { + impl From<$T> for PValue { + fn from(value: $T) -> Self { + PValue::$PT(value) + } + } + }; +} + +impl_pvalue!(u8, U8); +impl_pvalue!(u16, U16); +impl_pvalue!(u32, U32); +impl_pvalue!(u64, U64); +impl_pvalue!(i8, I8); +impl_pvalue!(i16, I16); +impl_pvalue!(i32, I32); +impl_pvalue!(i64, I64); +impl_pvalue!(f16, F16); +impl_pvalue!(f32, F32); +impl_pvalue!(f64, F64); diff --git a/vortex-scalar/src/serde/serde.rs b/vortex-scalar/src/serde/serde.rs index a79d2f0b95..d9666e6412 100644 --- a/vortex-scalar/src/serde/serde.rs +++ b/vortex-scalar/src/serde/serde.rs @@ -5,6 +5,7 @@ use std::fmt::Formatter; use serde::de::{Error, SeqAccess, Visitor}; use serde::{Deserialize, Deserializer, Serialize, Serializer}; +use crate::pvalue::PValue; use crate::value::ScalarValue; impl Serialize for ScalarValue { @@ -15,6 +16,7 @@ impl Serialize for ScalarValue { match self { ScalarValue::Null => ().serialize(serializer), ScalarValue::Bool(b) => b.serialize(serializer), + ScalarValue::Primitive(p) => p.serialize(serializer), ScalarValue::Buffer(buffer) => buffer.as_ref().serialize(serializer), ScalarValue::List(l) => l.serialize(serializer), } @@ -45,70 +47,70 @@ impl<'de> Deserialize<'de> for ScalarValue { where E: Error, { - Ok(ScalarValue::Buffer(v.to_le_bytes().to_vec().into())) + Ok(ScalarValue::Primitive(PValue::I8(v))) } fn visit_i16(self, v: i16) -> Result where E: Error, { - Ok(ScalarValue::Buffer(v.to_le_bytes().to_vec().into())) + Ok(ScalarValue::Primitive(PValue::I16(v))) } fn visit_i32(self, v: i32) -> Result where E: Error, { - Ok(ScalarValue::Buffer(v.to_le_bytes().to_vec().into())) + Ok(ScalarValue::Primitive(PValue::I32(v))) } fn visit_i64(self, v: i64) -> Result where E: Error, { - Ok(ScalarValue::Buffer(v.to_le_bytes().to_vec().into())) + Ok(ScalarValue::Primitive(PValue::I64(v))) } fn visit_u8(self, v: u8) -> Result where E: Error, { - Ok(ScalarValue::Buffer(v.to_le_bytes().to_vec().into())) + Ok(ScalarValue::Primitive(PValue::U8(v))) } fn visit_u16(self, v: u16) -> Result where E: Error, { - Ok(ScalarValue::Buffer(v.to_le_bytes().to_vec().into())) + Ok(ScalarValue::Primitive(PValue::U16(v))) } fn visit_u32(self, v: u32) -> Result where E: Error, { - Ok(ScalarValue::Buffer(v.to_le_bytes().to_vec().into())) + Ok(ScalarValue::Primitive(PValue::U32(v))) } fn visit_u64(self, v: u64) -> Result where E: Error, { - Ok(ScalarValue::Buffer(v.to_le_bytes().to_vec().into())) + Ok(ScalarValue::Primitive(PValue::U64(v))) } fn visit_f32(self, v: f32) -> Result where E: Error, { - Ok(ScalarValue::Buffer(v.to_le_bytes().to_vec().into())) + Ok(ScalarValue::Primitive(PValue::F32(v))) } fn visit_f64(self, v: f64) -> Result where E: Error, { - Ok(ScalarValue::Buffer(v.to_le_bytes().to_vec().into())) + Ok(ScalarValue::Primitive(PValue::F64(v))) } fn visit_str(self, v: &str) -> Result @@ -147,3 +149,25 @@ impl<'de> Deserialize<'de> for ScalarValue { deserializer.deserialize_any(ScalarValueVisitor) } } + +impl Serialize for PValue { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + match self { + PValue::U8(v) => serializer.serialize_u8(*v), + PValue::U16(v) => serializer.serialize_u16(*v), + PValue::U32(v) => serializer.serialize_u32(*v), + PValue::U64(v) => serializer.serialize_u64(*v), + PValue::I8(v) => serializer.serialize_i8(*v), + PValue::I16(v) => serializer.serialize_i16(*v), + PValue::I32(v) => serializer.serialize_i32(*v), + PValue::I64(v) => serializer.serialize_i64(*v), + // NOTE(ngates): f16's are serialized bit-wise as u16. + PValue::F16(v) => serializer.serialize_u16(v.to_bits()), + PValue::F32(v) => serializer.serialize_f32(*v), + PValue::F64(v) => serializer.serialize_f64(*v), + } + } +} diff --git a/vortex-scalar/src/struct_.rs b/vortex-scalar/src/struct_.rs index 94aa5520d8..08dcb6472c 100644 --- a/vortex-scalar/src/struct_.rs +++ b/vortex-scalar/src/struct_.rs @@ -1,22 +1,34 @@ +use std::sync::Arc; + use vortex_dtype::DType; use vortex_error::{vortex_bail, VortexError, VortexResult}; use crate::value::ScalarValue; use crate::Scalar; -pub struct StructScalar<'a>(&'a Scalar); +pub struct StructScalar<'a> { + dtype: &'a DType, + fields: Option>, +} + impl<'a> StructScalar<'a> { #[inline] pub fn dtype(&self) -> &'a DType { - self.0.dtype() + self.dtype } pub fn field_by_idx(&self, idx: usize, dtype: DType) -> Option { - self.0.value.child(idx).map(|value| Scalar { dtype, value }) + self.fields + .as_ref() + .and_then(|fields| fields.get(idx)) + .map(|field| Scalar { + dtype, + value: field.clone(), + }) } pub fn field(&self, name: &str, dtype: DType) -> Option { - let DType::Struct(struct_dtype, _) = self.0.dtype() else { + let DType::Struct(struct_dtype, _) = self.dtype() else { unreachable!() }; struct_dtype @@ -43,9 +55,11 @@ impl<'a> TryFrom<&'a Scalar> for StructScalar<'a> { fn try_from(value: &'a Scalar) -> Result { if matches!(value.dtype(), DType::Struct(..)) { - Ok(Self(value)) - } else { vortex_bail!("Expected struct scalar, found {}", value.dtype()) } + Ok(Self { + dtype: value.dtype(), + fields: value.value.as_list()?.cloned(), + }) } } diff --git a/vortex-scalar/src/value.rs b/vortex-scalar/src/value.rs index 3281137ebf..ee9c661aa3 100644 --- a/vortex-scalar/src/value.rs +++ b/vortex-scalar/src/value.rs @@ -1,23 +1,21 @@ use std::sync::Arc; -use paste::paste; use vortex_buffer::Buffer; -use vortex_dtype::half::f16; -use vortex_dtype::NativePType; use vortex_error::{vortex_err, VortexResult}; -/// Represents the internal data of a scalar value. Can only be interpreted by wrapping +use crate::pvalue::PValue; + +/// Represents the internal data of a scalar value. Must be interpreted by wrapping /// up with a DType to make a Scalar. /// -/// This is similar to serde_json::Value, but uses our own Buffer implementation for bytes, -/// an Arc<[]> for list elements, and structs are modelled as lists. -/// -/// TODO(ngates): we could support reading structs from both structs and lists in the future since -/// storing sparse structs dense with null scalars may be inefficient. +/// Note that these values can be deserialized from JSON or other formats. So a PValue may not +/// have the correct width for what the DType expects. This means primitive values must be +/// cast on-read. #[derive(Debug, Clone, PartialEq, PartialOrd)] pub enum ScalarValue { Null, Bool(bool), + Primitive(PValue), Buffer(Buffer), List(Arc<[ScalarValue]>), } @@ -31,14 +29,15 @@ impl ScalarValue { match self { ScalarValue::Null => Ok(None), ScalarValue::Bool(b) => Ok(Some(*b)), - _ => Err(vortex_err!("Not a bool scalar")), + _ => Err(vortex_err!("Expected a bool scalar, found {:?}", self)), } } - pub fn as_primitive(&self) -> Option { + pub fn as_pvalue(&self) -> VortexResult> { match self { - ScalarValue::Buffer(b) => T::try_from_le_bytes(b.as_ref()).ok(), - _ => None, + ScalarValue::Null => Ok(None), + ScalarValue::Primitive(p) => Ok(Some(*p)), + _ => Err(vortex_err!("Expected a primitive scalar, found {:?}", self)), } } @@ -46,51 +45,14 @@ impl ScalarValue { match self { ScalarValue::Null => Ok(None), ScalarValue::Buffer(b) => Ok(Some(b.clone())), - _ => Err(vortex_err!("Not a binary scalar")), - } - } - - #[allow(clippy::len_without_is_empty)] - pub fn len(&self) -> usize { - match self { - ScalarValue::List(l) => l.len(), - _ => 0, + _ => Err(vortex_err!("Expected a binary scalar, found {:?}", self)), } } - pub fn child(&self, idx: usize) -> Option { + pub fn as_list(&self) -> VortexResult>> { match self { - ScalarValue::List(l) => l.get(idx).cloned(), - _ => None, + ScalarValue::List(l) => Ok(Some(l)), + _ => Err(vortex_err!("Expected a list scalar, found {:?}", self)), } } } - -macro_rules! primitive_from_scalar_view { - ($T:ty) => { - paste! { - impl From<$T> for ScalarValue { - fn from(value: $T) -> Self { - ScalarValue::Buffer(Buffer::from(value.to_le_bytes().as_ref().to_vec())) - } - } - } - }; -} - -primitive_from_scalar_view!(u8); -primitive_from_scalar_view!(u16); -primitive_from_scalar_view!(u32); -primitive_from_scalar_view!(u64); -primitive_from_scalar_view!(i8); -primitive_from_scalar_view!(i16); -primitive_from_scalar_view!(i32); -primitive_from_scalar_view!(i64); -primitive_from_scalar_view!(f32); -primitive_from_scalar_view!(f64); - -impl From for ScalarValue { - fn from(value: f16) -> Self { - ScalarValue::Buffer(Buffer::from(value.to_le_bytes().as_ref().to_vec())) - } -}