From 4f28c1087b93641bfa2a531f3b3aed304e1d2789 Mon Sep 17 00:00:00 2001 From: Nicholas Gates Date: Tue, 30 Apr 2024 13:44:03 +0100 Subject: [PATCH 1/4] DType Structs --- bench-vortex/src/lib.rs | 4 +- bench-vortex/src/vortex_utils.rs | 4 +- vortex-array/src/array/composite/array.rs | 13 ++- vortex-array/src/array/composite/mod.rs | 2 +- vortex-array/src/array/composite/typed.rs | 2 +- vortex-array/src/array/struct/mod.rs | 14 +-- vortex-array/src/arrow/dtype.rs | 18 ++-- vortex-dtype/Cargo.toml | 2 +- vortex-dtype/src/composite.rs | 62 +++++++++++++ vortex-dtype/src/deserialize.rs | 40 ++------- vortex-dtype/src/dtype.rs | 105 +++++++++++----------- vortex-dtype/src/lib.rs | 19 ++-- vortex-dtype/src/nullability.rs | 62 +++++++++++++ vortex-dtype/src/serde.rs | 87 +++++++++++------- vortex-dtype/src/serialize.rs | 34 +++---- vortex-ipc/src/reader.rs | 8 +- vortex-scalar/src/lib.rs | 3 +- vortex-scalar/src/serde.rs | 16 ++-- vortex-scalar/src/struct_.rs | 20 ++--- 19 files changed, 301 insertions(+), 214 deletions(-) create mode 100644 vortex-dtype/src/composite.rs create mode 100644 vortex-dtype/src/nullability.rs diff --git a/bench-vortex/src/lib.rs b/bench-vortex/src/lib.rs index d8621fb82b..0f7baad33e 100644 --- a/bench-vortex/src/lib.rs +++ b/bench-vortex/src/lib.rs @@ -184,13 +184,13 @@ pub struct CompressionRunStats { impl CompressionRunStats { pub fn to_results(&self, dataset_name: String) -> Vec { - let DType::Struct(ns, fs) = &self.schema else { + let DType::Struct { names, dtypes } = &self.schema else { unreachable!() }; self.compressed_sizes .iter() - .zip_eq(ns.iter().zip_eq(fs)) + .zip_eq(names.iter().zip_eq(dtypes)) .map( |(&size, (column_name, column_type))| CompressionRunResults { dataset_name: dataset_name.clone(), diff --git a/bench-vortex/src/vortex_utils.rs b/bench-vortex/src/vortex_utils.rs index b266ac07af..9ac215f9eb 100644 --- a/bench-vortex/src/vortex_utils.rs +++ b/bench-vortex/src/vortex_utils.rs @@ -16,11 +16,11 @@ pub fn vortex_chunk_sizes(path: &Path) -> VortexResult { let file = File::open(path)?; let total_compressed_size = file.metadata()?.size(); let vortex = open_vortex(path)?; - let DType::Struct(ns, _) = vortex.dtype() else { + let DType::Struct { names, .. } = vortex.dtype() else { unreachable!() }; - let mut compressed_sizes = vec![0; ns.len()]; + let mut compressed_sizes = vec![0; names.len()]; let chunked_array = ChunkedArray::try_from(vortex).unwrap(); for chunk in chunked_array.chunks() { let struct_arr = StructArray::try_from(chunk).unwrap(); diff --git a/vortex-array/src/array/composite/array.rs b/vortex-array/src/array/composite/array.rs index f3a1897cad..2db8f73d8a 100644 --- a/vortex-array/src/array/composite/array.rs +++ b/vortex-array/src/array/composite/array.rs @@ -1,6 +1,6 @@ use flatbuffers::root; use vortex_dtype::flatbuffers as fb; -use vortex_dtype::{CompositeID, DTypeSerdeContext}; +use vortex_dtype::CompositeID; use vortex_error::{vortex_err, VortexResult}; use vortex_flatbuffers::{FlatBufferToBytes, ReadFlatBuffer}; @@ -34,7 +34,7 @@ impl TrySerializeArrayMetadata for CompositeMetadata { let mut fb = flexbuffers::Builder::default(); { let mut elems = fb.start_vector(); - elems.push(self.ext.id().0); + elems.push(self.ext.id().as_ref()); self.underlying_dtype .with_flatbuffer_bytes(|b| elems.push(flexbuffers::Blob(b))); elems.push(flexbuffers::Blob(self.underlying_metadata.as_ref())); @@ -53,9 +53,8 @@ impl TryDeserializeArrayMetadata<'_> for CompositeMetadata { .ok_or_else(|| vortex_err!("Unrecognized composite extension: {}", ext_id))?; let dtype_blob = elems.index(1).expect("missing dtype").as_blob(); - let ctx = DTypeSerdeContext::new(vec![]); // FIXME: composite_ids let underlying_dtype = DType::read_flatbuffer( - &ctx, + &(), &root::(dtype_blob.0).expect("invalid dtype"), )?; @@ -77,8 +76,8 @@ impl TryDeserializeArrayMetadata<'_> for CompositeMetadata { impl<'a> CompositeArray<'a> { pub fn new(id: CompositeID, metadata: Arc<[u8]>, underlying: Array<'a>) -> Self { - let dtype = DType::Composite(id, underlying.dtype().is_nullable().into()); - let ext = find_extension(id.0).expect("Unrecognized composite extension"); + let dtype = DType::Composite(id.clone(), underlying.dtype().is_nullable().into()); + let ext = find_extension(id.as_ref()).expect("Unrecognized composite extension"); Self::try_from_parts( dtype, CompositeMetadata { @@ -101,7 +100,7 @@ impl CompositeArray<'_> { #[inline] pub fn extension(&self) -> CompositeExtensionRef { - find_extension(self.id().0).expect("Unrecognized composite extension") + find_extension(self.id().as_ref()).expect("Unrecognized composite extension") } pub fn underlying_metadata(&self) -> &Arc<[u8]> { diff --git a/vortex-array/src/array/composite/mod.rs b/vortex-array/src/array/composite/mod.rs index 09f6bb5a5f..3cfa6de678 100644 --- a/vortex-array/src/array/composite/mod.rs +++ b/vortex-array/src/array/composite/mod.rs @@ -13,7 +13,7 @@ pub static VORTEX_COMPOSITE_EXTENSIONS: [&'static dyn CompositeExtension] = [..] pub fn find_extension(id: &str) -> Option<&'static dyn CompositeExtension> { VORTEX_COMPOSITE_EXTENSIONS .iter() - .find(|ext| ext.id().0 == id) + .find(|ext| ext.id().as_ref() == id) .copied() } diff --git a/vortex-array/src/array/composite/typed.rs b/vortex-array/src/array/composite/typed.rs index 7b4a158e5e..33983a3e5b 100644 --- a/vortex-array/src/array/composite/typed.rs +++ b/vortex-array/src/array/composite/typed.rs @@ -88,7 +88,7 @@ macro_rules! impl_composite { pub struct [<$T Extension>]; impl [<$T Extension>] { - pub const ID: CompositeID = CompositeID($id); + pub const ID: CompositeID = CompositeID::new($id); pub fn dtype(nullability: Nullability) -> DType { DType::Composite(Self::ID, nullability) diff --git a/vortex-array/src/array/struct/mod.rs b/vortex-array/src/array/struct/mod.rs index b07e07b0c6..e405772d92 100644 --- a/vortex-array/src/array/struct/mod.rs +++ b/vortex-array/src/array/struct/mod.rs @@ -19,25 +19,25 @@ pub struct StructMetadata { impl StructArray<'_> { pub fn child(&self, idx: usize) -> Option { - let DType::Struct(_, fields) = self.dtype() else { + let DType::Struct { dtypes, .. } = self.dtype() else { unreachable!() }; - let dtype = fields.get(idx)?; + let dtype = dtypes.get(idx)?; self.array().child(idx, dtype) } pub fn names(&self) -> &FieldNames { - let DType::Struct(names, _fields) = self.dtype() else { + let DType::Struct { names, .. } = self.dtype() else { unreachable!() }; names } pub fn fields(&self) -> &[DType] { - let DType::Struct(_names, fields) = self.dtype() else { + let DType::Struct { dtypes, .. } = self.dtype() else { unreachable!() }; - fields.as_slice() + dtypes.as_slice() } pub fn nfields(&self) -> usize { @@ -61,9 +61,9 @@ impl StructArray<'_> { vortex_bail!("Expected all struct fields to have length {}", length); } - let field_dtypes: Vec<_> = fields.iter().map(|d| d.dtype()).cloned().collect(); + let dtypes: Vec<_> = fields.iter().map(|d| d.dtype()).cloned().collect(); Self::try_from_parts( - DType::Struct(names, field_dtypes), + DType::Struct { names, dtypes }, StructMetadata { length }, fields.into_iter().map(|a| a.into_array_data()).collect(), HashMap::default(), diff --git a/vortex-array/src/arrow/dtype.rs b/vortex-array/src/arrow/dtype.rs index 5a06ea18d2..f2f7018c7e 100644 --- a/vortex-array/src/arrow/dtype.rs +++ b/vortex-array/src/arrow/dtype.rs @@ -40,18 +40,18 @@ impl TryFromArrowType<&DataType> for PType { impl FromArrowType for DType { fn from_arrow(value: SchemaRef) -> Self { - DType::Struct( - value + DType::Struct { + names: value .fields() .iter() .map(|f| Arc::new(f.name().clone())) .collect(), - value + dtypes: value .fields() .iter() .map(|f| DType::from_arrow(f.as_ref())) .collect_vec(), - ) + } } } @@ -81,13 +81,13 @@ impl FromArrowType<&Field> for DType { DataType::List(e) | DataType::LargeList(e) => { List(Box::new(DType::from_arrow(e.as_ref())), nullability) } - DataType::Struct(f) => Struct( - f.iter().map(|f| Arc::new(f.name().clone())).collect(), - f.iter() + DataType::Struct(f) => Struct { + names: f.iter().map(|f| Arc::new(f.name().clone())).collect(), + dtypes: f + .iter() .map(|f| DType::from_arrow(f.as_ref())) .collect_vec(), - ), - DataType::Decimal128(p, s) | DataType::Decimal256(p, s) => Decimal(*p, *s, nullability), + }, _ => unimplemented!("Arrow data type not yet supported: {:?}", field.data_type()), } } diff --git a/vortex-dtype/Cargo.toml b/vortex-dtype/Cargo.toml index e129426438..21effb90f6 100644 --- a/vortex-dtype/Cargo.toml +++ b/vortex-dtype/Cargo.toml @@ -21,7 +21,7 @@ half = { workspace = true } itertools = { workspace = true } linkme = { workspace = true } num-traits = { workspace = true } -serde = { workspace = true, optional = true } +serde = { workspace = true, optional = true, features = ["rc"] } thiserror = { workspace = true } vortex-error = { path = "../vortex-error" } vortex-flatbuffers = { path = "../vortex-flatbuffers" } diff --git a/vortex-dtype/src/composite.rs b/vortex-dtype/src/composite.rs new file mode 100644 index 0000000000..118a0ed3a8 --- /dev/null +++ b/vortex-dtype/src/composite.rs @@ -0,0 +1,62 @@ +use std::fmt::{Display, Formatter}; + +use linkme::distributed_slice; +use serde::{Deserialize, Deserializer}; +use vortex_error::{vortex_err, VortexError}; + +#[derive(Debug, Clone, PartialEq, Eq, Ord, PartialOrd, Hash)] +#[cfg_attr(feature = "serde", derive(::serde::Serialize))] +pub struct CompositeID(&'static str); + +impl CompositeID { + pub const fn new(id: &'static str) -> Self { + Self(id) + } +} + +impl<'a> TryFrom<&'a str> for CompositeID { + type Error = VortexError; + + fn try_from(value: &'a str) -> Result { + find_composite_dtype(value) + .map(|cdt| CompositeID(cdt.id())) + .ok_or_else(|| vortex_err!("CompositeID not found for the given id: {}", value)) + } +} + +impl AsRef for CompositeID { + fn as_ref(&self) -> &str { + self.0.as_ref() + } +} + +impl Display for CompositeID { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0) + } +} + +#[cfg(feature = "serde")] +impl<'de> Deserialize<'de> for CompositeID { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + CompositeID::try_from(<&'de str>::deserialize(deserializer)?) + .map_err(serde::de::Error::custom) + } +} + +pub trait CompositeDType { + fn id(&self) -> &'static str; +} + +#[distributed_slice] +pub static VORTEX_COMPOSITE_DTYPES: [&'static dyn CompositeDType] = [..]; + +pub fn find_composite_dtype(id: &str) -> Option<&'static dyn CompositeDType> { + VORTEX_COMPOSITE_DTYPES + .iter() + .find(|ext| ext.id() == id) + .copied() +} diff --git a/vortex-dtype/src/deserialize.rs b/vortex-dtype/src/deserialize.rs index 4ef3b2a406..f3fb86d5d2 100644 --- a/vortex-dtype/src/deserialize.rs +++ b/vortex-dtype/src/deserialize.rs @@ -6,31 +6,11 @@ use vortex_flatbuffers::ReadFlatBuffer; use crate::{flatbuffers as fb, Nullability}; use crate::{CompositeID, DType}; -#[allow(dead_code)] -pub struct DTypeSerdeContext { - composite_ids: Arc<[CompositeID]>, -} - -impl DTypeSerdeContext { - pub fn new(composite_ids: Vec) -> Self { - Self { - composite_ids: composite_ids.into(), - } - } - - pub fn find_composite_id(&self, id: &str) -> Option { - self.composite_ids.iter().find(|c| c.0 == id).copied() - } -} - -impl ReadFlatBuffer for DType { +impl ReadFlatBuffer<()> for DType { type Source<'a> = fb::DType<'a>; type Error = VortexError; - fn read_flatbuffer( - ctx: &DTypeSerdeContext, - fb: &Self::Source<'_>, - ) -> Result { + fn read_flatbuffer(ctx: &(), fb: &Self::Source<'_>) -> Result { match fb.type_type() { fb::Type::Null => Ok(DType::Null), fb::Type::Bool => Ok(DType::Bool( @@ -43,14 +23,6 @@ impl ReadFlatBuffer for DType { fb_primitive.nullability().try_into()?, )) } - fb::Type::Decimal => { - let fb_decimal = fb.type__as_decimal().unwrap(); - Ok(DType::Decimal( - fb_decimal.precision(), - fb_decimal.scale(), - fb_decimal.nullability().try_into()?, - )) - } fb::Type::Binary => Ok(DType::Binary( fb.type__as_binary().unwrap().nullability().try_into()?, )), @@ -73,19 +45,17 @@ impl ReadFlatBuffer for DType { .iter() .map(|n| Arc::new(n.to_string())) .collect::>(); - let fields: Vec = fb_struct + let dtypes: Vec = fb_struct .fields() .unwrap() .iter() .map(|f| DType::read_flatbuffer(ctx, &f)) .collect::>>()?; - Ok(DType::Struct(names, fields)) + Ok(DType::Struct { names, dtypes }) } fb::Type::Composite => { let fb_composite = fb.type__as_composite().unwrap(); - let id = ctx - .find_composite_id(fb_composite.id().unwrap()) - .ok_or_else(|| vortex_err!("Couldn't find composite id"))?; + let id = CompositeID::try_from(fb_composite.id().unwrap())?; Ok(DType::Composite(id, fb_composite.nullability().try_into()?)) } _ => Err(vortex_err!("Unknown DType variant")), diff --git a/vortex-dtype/src/dtype.rs b/vortex-dtype/src/dtype.rs index 0c9411812f..f094ef0ce1 100644 --- a/vortex-dtype/src/dtype.rs +++ b/vortex-dtype/src/dtype.rs @@ -3,63 +3,63 @@ use std::hash::Hash; use std::sync::Arc; use itertools::Itertools; +use serde::{Deserialize, Serialize}; use DType::*; -use crate::{CompositeID, PType}; - -#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash, Ord, PartialOrd)] -#[cfg_attr(feature = "serde", derive(::serde::Serialize, ::serde::Deserialize))] -pub enum Nullability { - #[default] - NonNullable, - Nullable, -} - -impl From for Nullability { - fn from(value: bool) -> Self { - if value { - Nullability::Nullable - } else { - Nullability::NonNullable - } - } -} - -impl From for bool { - fn from(value: Nullability) -> Self { - match value { - Nullability::NonNullable => false, - Nullability::Nullable => true, - } - } -} - -impl Display for Nullability { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - match self { - Nullability::NonNullable => write!(f, ""), - Nullability::Nullable => write!(f, "?"), - } - } -} +use crate::{CompositeID, Nullability, PType}; pub type FieldNames = Vec>; pub type Metadata = Vec; -#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] pub enum DType { Null, Bool(Nullability), + #[serde(with = "primitive_serde")] Primitive(PType, Nullability), - Decimal(u8, i8, Nullability), Utf8(Nullability), Binary(Nullability), - Struct(FieldNames, Vec), + Struct { + names: FieldNames, + dtypes: Vec, + }, List(Box, Nullability), Composite(CompositeID, Nullability), } +#[cfg(feature = "serde")] +mod primitive_serde { + use serde::{Deserialize, Deserializer, Serialize, Serializer}; + + use crate::{Nullability, PType}; + + #[derive(Serialize, Deserialize)] + struct PrimitiveSerde { + ptype: PType, + n: Nullability, + } + + pub fn serialize(ptype: &PType, n: &Nullability, serializer: S) -> Result + where + S: Serializer, + { + PrimitiveSerde { + ptype: *ptype, + n: *n, + } + .serialize(serializer) + } + + pub fn deserialize<'de, D>(deserializer: D) -> Result<(PType, Nullability), D::Error> + where + D: Deserializer<'de>, + { + let PrimitiveSerde { ptype, n } = PrimitiveSerde::deserialize(deserializer)?; + Ok((ptype, n)) + } +} + impl DType { pub const BYTES: DType = Primitive(PType::U8, Nullability::NonNullable); @@ -77,10 +77,9 @@ impl DType { Null => true, Bool(n) => matches!(n, Nullable), Primitive(_, n) => matches!(n, Nullable), - Decimal(_, _, n) => matches!(n, Nullable), Utf8(n) => matches!(n, Nullable), Binary(n) => matches!(n, Nullable), - Struct(_, fs) => fs.iter().all(|f| f.is_nullable()), + Struct { dtypes, .. } => dtypes.iter().all(|dt| dt.is_nullable()), List(_, n) => matches!(n, Nullable), Composite(_, n) => matches!(n, Nullable), } @@ -99,15 +98,17 @@ impl DType { Null => Null, Bool(_) => Bool(nullability), Primitive(p, _) => Primitive(*p, nullability), - Decimal(s, p, _) => Decimal(*s, *p, nullability), Utf8(_) => Utf8(nullability), Binary(_) => Binary(nullability), - Struct(n, fs) => Struct( - n.clone(), - fs.iter().map(|f| f.with_nullability(nullability)).collect(), - ), + Struct { names, dtypes } => Struct { + names: names.clone(), + dtypes: dtypes + .iter() + .map(|dt| dt.with_nullability(nullability)) + .collect(), + }, List(c, _) => List(c.clone(), nullability), - Composite(id, _) => Composite(*id, nullability), + Composite(id, _) => Composite(id.clone(), nullability), } } @@ -122,14 +123,14 @@ impl Display for DType { Null => write!(f, "null"), Bool(n) => write!(f, "bool{}", n), Primitive(p, n) => write!(f, "{}{}", p, n), - Decimal(p, s, n) => write!(f, "decimal({}, {}){}", p, s, n), Utf8(n) => write!(f, "utf8{}", n), Binary(n) => write!(f, "binary{}", n), - Struct(n, dt) => write!( + Struct { names, dtypes } => write!( f, "{{{}}}", - n.iter() - .zip(dt.iter()) + names + .iter() + .zip(dtypes.iter()) .map(|(n, dt)| format!("{}={}", n, dt)) .join(", ") ), diff --git a/vortex-dtype/src/lib.rs b/vortex-dtype/src/lib.rs index eb293e0223..7ab5709623 100644 --- a/vortex-dtype/src/lib.rs +++ b/vortex-dtype/src/lib.rs @@ -1,25 +1,16 @@ -use std::fmt::{Display, Formatter}; - pub use dtype::*; pub use half; pub use ptype::*; mod deserialize; +pub use composite::*; mod dtype; mod ptype; -mod serde; +// mod serde; +mod composite; +mod nullability; mod serialize; -pub use deserialize::*; - -#[derive(Debug, Clone, Copy, PartialEq, Eq, Ord, PartialOrd, Hash)] -#[cfg_attr(feature = "serde", derive(::serde::Serialize, ::serde::Deserialize))] -pub struct CompositeID(pub &'static str); - -impl Display for CompositeID { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self.0) - } -} +pub use nullability::*; pub mod flatbuffers { #[allow(unused_imports)] diff --git a/vortex-dtype/src/nullability.rs b/vortex-dtype/src/nullability.rs new file mode 100644 index 0000000000..b2fa7d55cb --- /dev/null +++ b/vortex-dtype/src/nullability.rs @@ -0,0 +1,62 @@ +use std::fmt::{Display, Formatter}; + +#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash, Ord, PartialOrd)] +pub enum Nullability { + #[default] + NonNullable, + Nullable, +} + +impl From for Nullability { + fn from(value: bool) -> Self { + if value { + Nullability::Nullable + } else { + Nullability::NonNullable + } + } +} + +impl From for bool { + fn from(value: Nullability) -> Self { + match value { + Nullability::NonNullable => false, + Nullability::Nullable => true, + } + } +} + +impl Display for Nullability { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + Nullability::NonNullable => write!(f, ""), + Nullability::Nullable => write!(f, "?"), + } + } +} + +#[cfg(feature = "serde")] +impl serde::Serialize for Nullability { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + match self { + Nullability::NonNullable => serializer.serialize_bool(false), + Nullability::Nullable => serializer.serialize_bool(true), + } + } +} + +#[cfg(feature = "serde")] +impl<'de> serde::Deserialize<'de> for Nullability { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + Ok(match bool::deserialize(deserializer)? { + true => Nullability::Nullable, + false => Nullability::NonNullable, + }) + } +} diff --git a/vortex-dtype/src/serde.rs b/vortex-dtype/src/serde.rs index a5bbd242df..a69a734941 100644 --- a/vortex-dtype/src/serde.rs +++ b/vortex-dtype/src/serde.rs @@ -1,59 +1,84 @@ #![cfg(feature = "serde")] - -use flatbuffers::root; -use serde::de::{DeserializeSeed, Visitor}; +/// We hand-write the serde implementation for DType so we can retain more ergonomic tuple variants. +use serde::ser::SerializeMap; use serde::{Deserialize, Deserializer, Serialize, Serializer}; -use vortex_flatbuffers::{FlatBufferToBytes, ReadFlatBuffer}; -use crate::DType; -use crate::{flatbuffers as fb, DTypeSerdeContext}; +use crate::{DType, Nullability}; impl Serialize for DType { fn serialize(&self, serializer: S) -> Result where S: Serializer, { - self.with_flatbuffer_bytes(|bytes| serializer.serialize_bytes(bytes)) + match self { + DType::Null => serializer + .serialize_map(Some(1))? + .serialize_entry("type", "null"), + DType::Bool(n) => serializer + .serialize_map(Some(2))? + .serialize_entry("type", "null")? + .serialize_entry("n", n), + DType::Primitive(ptype, n) => serializer + .serialize_map(Some(3))? + .serialize_entry("type", "primitive")? + .serialize_entry("ptype", *ptype)? + .serialize_entry("n", n), + DType::Utf8(n) => serializer + .serialize_map(Some(2))? + .serialize_entry("type", "utf8")? + .serialize_entry("n", n), + DType::Binary(n) => serializer + .serialize_map(Some(2))? + .serialize_entry("type", "binary")? + .serialize_entry("n", n), + DType::Struct { names, dtypes } => serializer + .serialize_map(Some(3))? + .serialize_entry("type", "struct")? + .serialize_entry("names", names)? + .serialize_entry("dtypes", dtypes), + DType::List(element, n) => serializer + .serialize_map(Some(3))? + .serialize_entry("type", "primitive")? + .serialize_entry("element", element)? + .serialize_entry("n", n), + DType::Composite(id, n) => serializer + .serialize_map(Some(3))? + .serialize_entry("type", "composite")? + .serialize_entry("id", id)? + .serialize_entry("n", n), + } } } -struct DTypeDeserializer(DTypeSerdeContext); - -impl<'de> Visitor<'de> for DTypeDeserializer { - type Value = DType; - - fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { - formatter.write_str("a vortex dtype") - } - - fn visit_bytes(self, v: &[u8]) -> Result +impl<'de> Deserialize<'de> for DType { + fn deserialize(deserializer: D) -> Result where - E: serde::de::Error, + D: Deserializer<'de>, { - let fb = root::(v).map_err(E::custom)?; - DType::read_flatbuffer(&self.0, &fb).map_err(E::custom) + todo!() } } -impl<'de> DeserializeSeed<'de> for DTypeSerdeContext { - type Value = DType; - - fn deserialize(self, deserializer: D) -> Result +impl Serialize for Nullability { + fn serialize(&self, serializer: S) -> Result where - D: Deserializer<'de>, + S: Serializer, { - deserializer.deserialize_bytes(DTypeDeserializer(self)) + match self { + Nullability::NonNullable => serializer.serialize_bool(false), + Nullability::Nullable => serializer.serialize_bool(true), + } } } -// TODO(ngates): Remove this trait in favour of storing e.g. IdxType which doesn't require -// the context for composite types. -impl<'de> Deserialize<'de> for DType { +impl<'de> Deserialize<'de> for Nullability { fn deserialize(deserializer: D) -> Result where D: Deserializer<'de>, { - let ctx = DTypeSerdeContext::new(vec![]); - deserializer.deserialize_bytes(DTypeDeserializer(ctx)) + Ok(match bool::deserialize(deserializer)? { + true => Nullability::Nullable, + false => Nullability::NonNullable, + }) } } diff --git a/vortex-dtype/src/serialize.rs b/vortex-dtype/src/serialize.rs index 1aa485807a..d5ad6d8220 100644 --- a/vortex-dtype/src/serialize.rs +++ b/vortex-dtype/src/serialize.rs @@ -31,15 +31,6 @@ impl WriteFlatBuffer for DType { }, ) .as_union_value(), - DType::Decimal(p, s, n) => fb::Decimal::create( - fbb, - &fb::DecimalArgs { - precision: *p, - scale: *s, - nullability: n.into(), - }, - ) - .as_union_value(), DType::Utf8(n) => fb::Utf8::create( fbb, &fb::Utf8Args { @@ -54,7 +45,7 @@ impl WriteFlatBuffer for DType { }, ) .as_union_value(), - DType::Struct(names, dtypes) => { + DType::Struct { names, dtypes } => { let names = names .iter() .map(|n| fbb.create_string(n.as_str())) @@ -81,7 +72,7 @@ impl WriteFlatBuffer for DType { .as_union_value() } DType::Composite(id, n) => { - let id = Some(fbb.create_string(id.0)); + let id = Some(fbb.create_string(id.as_ref())); fb::Composite::create( fbb, &fb::CompositeArgs { @@ -97,10 +88,9 @@ impl WriteFlatBuffer for DType { DType::Null => fb::Type::Null, DType::Bool(_) => fb::Type::Bool, DType::Primitive(..) => fb::Type::Primitive, - DType::Decimal(..) => fb::Type::Decimal, DType::Utf8(_) => fb::Type::Utf8, DType::Binary(_) => fb::Type::Binary, - DType::Struct(..) => fb::Type::Struct_, + DType::Struct { .. } => fb::Type::Struct_, DType::List(..) => fb::Type::List, DType::Composite(..) => fb::Type::Composite, }; @@ -180,15 +170,12 @@ mod test { use vortex_flatbuffers::{FlatBufferToBytes, ReadFlatBuffer}; use crate::{flatbuffers as fb, PType}; - use crate::{DType, DTypeSerdeContext, Nullability}; + use crate::{DType, Nullability}; fn roundtrip_dtype(dtype: DType) { let bytes = dtype.with_flatbuffer_bytes(|bytes| bytes.to_vec()); - let deserialized = DType::read_flatbuffer( - &DTypeSerdeContext::new(vec![]), - &root::(&bytes).unwrap(), - ) - .unwrap(); + let deserialized = + DType::read_flatbuffer(&(), &root::(&bytes).unwrap()).unwrap(); assert_eq!(dtype, deserialized); } @@ -197,19 +184,18 @@ mod test { roundtrip_dtype(DType::Null); roundtrip_dtype(DType::Bool(Nullability::NonNullable)); roundtrip_dtype(DType::Primitive(PType::U64, Nullability::NonNullable)); - roundtrip_dtype(DType::Decimal(18, 9, Nullability::NonNullable)); roundtrip_dtype(DType::Binary(Nullability::NonNullable)); roundtrip_dtype(DType::Utf8(Nullability::NonNullable)); roundtrip_dtype(DType::List( Box::new(DType::Primitive(PType::F32, Nullability::Nullable)), Nullability::NonNullable, )); - roundtrip_dtype(DType::Struct( - vec![Arc::new("strings".into()), Arc::new("ints".into())], - vec![ + roundtrip_dtype(DType::Struct { + names: vec![Arc::new("strings".into()), Arc::new("ints".into())], + dtypes: vec![ DType::Utf8(Nullability::NonNullable), DType::Primitive(PType::U16, Nullability::Nullable), ], - )) + }) } } diff --git a/vortex-ipc/src/reader.rs b/vortex-ipc/src/reader.rs index 4cba0d85e7..c6239894d0 100644 --- a/vortex-ipc/src/reader.rs +++ b/vortex-ipc/src/reader.rs @@ -7,7 +7,6 @@ use flatbuffers::{root, root_unchecked}; use itertools::Itertools; use nougat::gat; use vortex::array::chunked::ChunkedArray; -use vortex::array::composite::VORTEX_COMPOSITE_EXTENSIONS; use vortex::array::primitive::PrimitiveArray; use vortex::buffer::Buffer; use vortex::compute::search_sorted::{search_sorted, SearchSortedSide}; @@ -17,7 +16,7 @@ use vortex::stats::{ArrayStatistics, Stat}; use vortex::{ Array, ArrayDType, ArrayView, IntoArray, OwnedArray, SerdeContext, ToArray, ToStatic, }; -use vortex_dtype::{match_each_integer_ptype, DType, DTypeSerdeContext}; +use vortex_dtype::{match_each_integer_ptype, DType}; use vortex_error::{vortex_bail, vortex_err, VortexError, VortexResult}; use vortex_flatbuffers::ReadFlatBuffer; @@ -103,11 +102,8 @@ impl FallibleLendingIterator for StreamReader { .header_as_schema() .unwrap(); - // TODO(ngates): construct this from the SerdeContext. - let dtype_ctx = - DTypeSerdeContext::new(VORTEX_COMPOSITE_EXTENSIONS.iter().map(|e| e.id()).collect()); let dtype = DType::read_flatbuffer( - &dtype_ctx, + &(), &schema_msg .dtype() .ok_or_else(|| vortex_err!(InvalidSerde: "Schema missing DType"))?, diff --git a/vortex-scalar/src/lib.rs b/vortex-scalar/src/lib.rs index 30d4be5d8f..b34437bd02 100644 --- a/vortex-scalar/src/lib.rs +++ b/vortex-scalar/src/lib.rs @@ -126,10 +126,9 @@ impl Scalar { DType::Null => NullScalar::new().into(), DType::Bool(_) => BoolScalar::none().into(), DType::Primitive(p, _) => PrimitiveScalar::none_from_ptype(*p).into(), - DType::Decimal(..) => unimplemented!("DecimalScalar"), DType::Utf8(_) => Utf8Scalar::none().into(), DType::Binary(_) => BinaryScalar::none().into(), - DType::Struct(..) => StructScalar::new(dtype.clone(), vec![]).into(), + DType::Struct { .. } => StructScalar::new(dtype.clone(), vec![]).into(), DType::List(..) => ListScalar::new(dtype.clone(), None).into(), DType::Composite(..) => unimplemented!("CompositeScalar"), } diff --git a/vortex-scalar/src/serde.rs b/vortex-scalar/src/serde.rs index 1d77288a83..18d25af7e7 100644 --- a/vortex-scalar/src/serde.rs +++ b/vortex-scalar/src/serde.rs @@ -4,7 +4,7 @@ use flatbuffers::{root, FlatBufferBuilder, WIPOffset}; use serde::de::Visitor; use serde::{Deserialize, Deserializer, Serialize, Serializer}; use vortex_dtype::match_each_native_ptype; -use vortex_dtype::{DTypeSerdeContext, Nullability}; +use vortex_dtype::Nullability; use vortex_error::{vortex_bail, VortexError}; use vortex_flatbuffers::{FlatBufferRoot, FlatBufferToBytes, ReadFlatBuffer, WriteFlatBuffer}; @@ -90,14 +90,11 @@ impl WriteFlatBuffer for Scalar { } } -impl ReadFlatBuffer for Scalar { +impl ReadFlatBuffer<()> for Scalar { type Source<'a> = fb::Scalar<'a>; type Error = VortexError; - fn read_flatbuffer( - _ctx: &DTypeSerdeContext, - fb: &Self::Source<'_>, - ) -> Result { + fn read_flatbuffer(_ctx: &(), fb: &Self::Source<'_>) -> Result { let nullability = Nullability::from(fb.nullability()); match fb.type_type() { fb::Type::Binary => { @@ -153,7 +150,7 @@ impl Serialize for Scalar { } } -struct ScalarDeserializer(DTypeSerdeContext); +struct ScalarDeserializer; impl<'de> Visitor<'de> for ScalarDeserializer { type Value = Scalar; @@ -167,7 +164,7 @@ impl<'de> Visitor<'de> for ScalarDeserializer { E: serde::de::Error, { let fb = root::(v).map_err(E::custom)?; - Scalar::read_flatbuffer(&self.0, &fb).map_err(E::custom) + Scalar::read_flatbuffer(&(), &fb).map_err(E::custom) } } @@ -177,8 +174,7 @@ impl<'de> Deserialize<'de> for Scalar { where D: Deserializer<'de>, { - let ctx = DTypeSerdeContext::new(vec![]); - deserializer.deserialize_bytes(ScalarDeserializer(ctx)) + deserializer.deserialize_bytes(ScalarDeserializer) } } diff --git a/vortex-scalar/src/struct_.rs b/vortex-scalar/src/struct_.rs index 50abe21673..655510b415 100644 --- a/vortex-scalar/src/struct_.rs +++ b/vortex-scalar/src/struct_.rs @@ -31,16 +31,16 @@ impl StructScalar { } pub fn names(&self) -> &[Arc] { - let DType::Struct(ns, _) = self.dtype() else { + let DType::Struct { names, .. } = self.dtype() else { unreachable!("Not a scalar dtype"); }; - ns.as_slice() + names.as_slice() } pub fn cast(&self, dtype: &DType) -> VortexResult { match dtype { - DType::Struct(names, field_dtypes) => { - if field_dtypes.len() != self.values.len() { + DType::Struct { names, dtypes } => { + if dtypes.len() != self.values.len() { vortex_bail!( MismatchedTypes: format!("Struct with {} fields", self.values.len()), dtype @@ -50,14 +50,14 @@ impl StructScalar { let new_fields: Vec = self .values .iter() - .zip_eq(field_dtypes.iter()) + .zip_eq(dtypes.iter()) .map(|(field, field_dtype)| field.cast(field_dtype)) .try_collect()?; - let new_type = DType::Struct( - names.clone(), - new_fields.iter().map(|x| x.dtype().clone()).collect(), - ); + let new_type = DType::Struct { + names: names.clone(), + dtypes: new_fields.iter().map(|x| x.dtype().clone()).collect(), + }; Ok(StructScalar::new(new_type, new_fields).into()) } _ => Err(vortex_err!(MismatchedTypes: "struct", dtype)), @@ -81,7 +81,7 @@ impl PartialOrd for StructScalar { impl Display for StructScalar { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - let DType::Struct(names, _) = self.dtype() else { + let DType::Struct { names, .. } = self.dtype() else { unreachable!() }; for (n, v) in names.iter().zip(self.values.iter()) { From 5c9268bd353c603a78c42d5151ba6ae36f05885c Mon Sep 17 00:00:00 2001 From: Nicholas Gates Date: Tue, 30 Apr 2024 13:46:09 +0100 Subject: [PATCH 2/4] DType Structs --- vortex-dtype/src/composite.rs | 2 +- vortex-dtype/src/deserialize.rs | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/vortex-dtype/src/composite.rs b/vortex-dtype/src/composite.rs index 118a0ed3a8..49bf96acc9 100644 --- a/vortex-dtype/src/composite.rs +++ b/vortex-dtype/src/composite.rs @@ -26,7 +26,7 @@ impl<'a> TryFrom<&'a str> for CompositeID { impl AsRef for CompositeID { fn as_ref(&self) -> &str { - self.0.as_ref() + self.0 } } diff --git a/vortex-dtype/src/deserialize.rs b/vortex-dtype/src/deserialize.rs index f3fb86d5d2..9332909b66 100644 --- a/vortex-dtype/src/deserialize.rs +++ b/vortex-dtype/src/deserialize.rs @@ -10,7 +10,7 @@ impl ReadFlatBuffer<()> for DType { type Source<'a> = fb::DType<'a>; type Error = VortexError; - fn read_flatbuffer(ctx: &(), fb: &Self::Source<'_>) -> Result { + fn read_flatbuffer(_ctx: &(), fb: &Self::Source<'_>) -> Result { match fb.type_type() { fb::Type::Null => Ok(DType::Null), fb::Type::Bool => Ok(DType::Bool( @@ -31,7 +31,7 @@ impl ReadFlatBuffer<()> for DType { )), fb::Type::List => { let fb_list = fb.type__as_list().unwrap(); - let element_dtype = DType::read_flatbuffer(ctx, &fb_list.element_type().unwrap())?; + let element_dtype = DType::read_flatbuffer(&(), &fb_list.element_type().unwrap())?; Ok(DType::List( Box::new(element_dtype), fb_list.nullability().try_into()?, @@ -49,7 +49,7 @@ impl ReadFlatBuffer<()> for DType { .fields() .unwrap() .iter() - .map(|f| DType::read_flatbuffer(ctx, &f)) + .map(|f| DType::read_flatbuffer(&(), &f)) .collect::>>()?; Ok(DType::Struct { names, dtypes }) } From cd9cb925cbb2a957ce108a127dd7f0f7a8f0b20e Mon Sep 17 00:00:00 2001 From: Nicholas Gates Date: Tue, 30 Apr 2024 13:47:58 +0100 Subject: [PATCH 3/4] DType Structs --- vortex-dtype/src/composite.rs | 12 ------ vortex-dtype/src/lib.rs | 6 +-- vortex-dtype/src/nullability.rs | 26 ------------- vortex-dtype/src/serde.rs | 69 ++++++--------------------------- 4 files changed, 15 insertions(+), 98 deletions(-) diff --git a/vortex-dtype/src/composite.rs b/vortex-dtype/src/composite.rs index 49bf96acc9..0ce701e7d3 100644 --- a/vortex-dtype/src/composite.rs +++ b/vortex-dtype/src/composite.rs @@ -1,7 +1,6 @@ use std::fmt::{Display, Formatter}; use linkme::distributed_slice; -use serde::{Deserialize, Deserializer}; use vortex_error::{vortex_err, VortexError}; #[derive(Debug, Clone, PartialEq, Eq, Ord, PartialOrd, Hash)] @@ -36,17 +35,6 @@ impl Display for CompositeID { } } -#[cfg(feature = "serde")] -impl<'de> Deserialize<'de> for CompositeID { - fn deserialize(deserializer: D) -> Result - where - D: Deserializer<'de>, - { - CompositeID::try_from(<&'de str>::deserialize(deserializer)?) - .map_err(serde::de::Error::custom) - } -} - pub trait CompositeDType { fn id(&self) -> &'static str; } diff --git a/vortex-dtype/src/lib.rs b/vortex-dtype/src/lib.rs index 7ab5709623..80d7427c78 100644 --- a/vortex-dtype/src/lib.rs +++ b/vortex-dtype/src/lib.rs @@ -3,11 +3,11 @@ pub use half; pub use ptype::*; mod deserialize; pub use composite::*; -mod dtype; -mod ptype; -// mod serde; mod composite; +mod dtype; mod nullability; +mod ptype; +mod serde; mod serialize; pub use nullability::*; diff --git a/vortex-dtype/src/nullability.rs b/vortex-dtype/src/nullability.rs index b2fa7d55cb..e8416215bd 100644 --- a/vortex-dtype/src/nullability.rs +++ b/vortex-dtype/src/nullability.rs @@ -34,29 +34,3 @@ impl Display for Nullability { } } } - -#[cfg(feature = "serde")] -impl serde::Serialize for Nullability { - fn serialize(&self, serializer: S) -> Result - where - S: serde::Serializer, - { - match self { - Nullability::NonNullable => serializer.serialize_bool(false), - Nullability::Nullable => serializer.serialize_bool(true), - } - } -} - -#[cfg(feature = "serde")] -impl<'de> serde::Deserialize<'de> for Nullability { - fn deserialize(deserializer: D) -> Result - where - D: serde::Deserializer<'de>, - { - Ok(match bool::deserialize(deserializer)? { - true => Nullability::Nullable, - false => Nullability::NonNullable, - }) - } -} diff --git a/vortex-dtype/src/serde.rs b/vortex-dtype/src/serde.rs index a69a734941..55557cbb8d 100644 --- a/vortex-dtype/src/serde.rs +++ b/vortex-dtype/src/serde.rs @@ -1,63 +1,7 @@ #![cfg(feature = "serde")] -/// We hand-write the serde implementation for DType so we can retain more ergonomic tuple variants. -use serde::ser::SerializeMap; use serde::{Deserialize, Deserializer, Serialize, Serializer}; -use crate::{DType, Nullability}; - -impl Serialize for DType { - fn serialize(&self, serializer: S) -> Result - where - S: Serializer, - { - match self { - DType::Null => serializer - .serialize_map(Some(1))? - .serialize_entry("type", "null"), - DType::Bool(n) => serializer - .serialize_map(Some(2))? - .serialize_entry("type", "null")? - .serialize_entry("n", n), - DType::Primitive(ptype, n) => serializer - .serialize_map(Some(3))? - .serialize_entry("type", "primitive")? - .serialize_entry("ptype", *ptype)? - .serialize_entry("n", n), - DType::Utf8(n) => serializer - .serialize_map(Some(2))? - .serialize_entry("type", "utf8")? - .serialize_entry("n", n), - DType::Binary(n) => serializer - .serialize_map(Some(2))? - .serialize_entry("type", "binary")? - .serialize_entry("n", n), - DType::Struct { names, dtypes } => serializer - .serialize_map(Some(3))? - .serialize_entry("type", "struct")? - .serialize_entry("names", names)? - .serialize_entry("dtypes", dtypes), - DType::List(element, n) => serializer - .serialize_map(Some(3))? - .serialize_entry("type", "primitive")? - .serialize_entry("element", element)? - .serialize_entry("n", n), - DType::Composite(id, n) => serializer - .serialize_map(Some(3))? - .serialize_entry("type", "composite")? - .serialize_entry("id", id)? - .serialize_entry("n", n), - } - } -} - -impl<'de> Deserialize<'de> for DType { - fn deserialize(deserializer: D) -> Result - where - D: Deserializer<'de>, - { - todo!() - } -} +use crate::{CompositeID, Nullability}; impl Serialize for Nullability { fn serialize(&self, serializer: S) -> Result @@ -82,3 +26,14 @@ impl<'de> Deserialize<'de> for Nullability { }) } } + +#[cfg(feature = "serde")] +impl<'de> Deserialize<'de> for CompositeID { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + CompositeID::try_from(<&'de str>::deserialize(deserializer)?) + .map_err(serde::de::Error::custom) + } +} From 93b45f78b63fc004ac46aba29899f7094aae3ace Mon Sep 17 00:00:00 2001 From: Nicholas Gates Date: Tue, 30 Apr 2024 14:02:39 +0100 Subject: [PATCH 4/4] DType Structs --- vortex-dtype/src/dtype.rs | 34 +--------------------------------- vortex-dtype/src/serde.rs | 33 +++++++++++++++++++++++++++++++++ 2 files changed, 34 insertions(+), 33 deletions(-) diff --git a/vortex-dtype/src/dtype.rs b/vortex-dtype/src/dtype.rs index f094ef0ce1..c7314e7ab5 100644 --- a/vortex-dtype/src/dtype.rs +++ b/vortex-dtype/src/dtype.rs @@ -16,7 +16,7 @@ pub type Metadata = Vec; pub enum DType { Null, Bool(Nullability), - #[serde(with = "primitive_serde")] + #[serde(with = "crate::serde::dtype_primitive")] Primitive(PType, Nullability), Utf8(Nullability), Binary(Nullability), @@ -28,38 +28,6 @@ pub enum DType { Composite(CompositeID, Nullability), } -#[cfg(feature = "serde")] -mod primitive_serde { - use serde::{Deserialize, Deserializer, Serialize, Serializer}; - - use crate::{Nullability, PType}; - - #[derive(Serialize, Deserialize)] - struct PrimitiveSerde { - ptype: PType, - n: Nullability, - } - - pub fn serialize(ptype: &PType, n: &Nullability, serializer: S) -> Result - where - S: Serializer, - { - PrimitiveSerde { - ptype: *ptype, - n: *n, - } - .serialize(serializer) - } - - pub fn deserialize<'de, D>(deserializer: D) -> Result<(PType, Nullability), D::Error> - where - D: Deserializer<'de>, - { - let PrimitiveSerde { ptype, n } = PrimitiveSerde::deserialize(deserializer)?; - Ok((ptype, n)) - } -} - impl DType { pub const BYTES: DType = Primitive(PType::U8, Nullability::NonNullable); diff --git a/vortex-dtype/src/serde.rs b/vortex-dtype/src/serde.rs index 55557cbb8d..1580b525b3 100644 --- a/vortex-dtype/src/serde.rs +++ b/vortex-dtype/src/serde.rs @@ -37,3 +37,36 @@ impl<'de> Deserialize<'de> for CompositeID { .map_err(serde::de::Error::custom) } } + +/// Implement custom serde to retain the ergonomics of a tuple enum variant. +/// Essentially, we use this wrapper to name the fields of the DType::Primitive enum variant. +pub mod dtype_primitive { + use serde::{Deserialize, Deserializer, Serialize, Serializer}; + + use crate::{Nullability, PType}; + + #[derive(Serialize, Deserialize)] + struct PrimitiveSerde { + ptype: PType, + n: Nullability, + } + + pub fn serialize(ptype: &PType, n: &Nullability, serializer: S) -> Result + where + S: Serializer, + { + PrimitiveSerde { + ptype: *ptype, + n: *n, + } + .serialize(serializer) + } + + pub fn deserialize<'de, D>(deserializer: D) -> Result<(PType, Nullability), D::Error> + where + D: Deserializer<'de>, + { + let PrimitiveSerde { ptype, n } = PrimitiveSerde::deserialize(deserializer)?; + Ok((ptype, n)) + } +}