diff --git a/bench-vortex/src/lib.rs b/bench-vortex/src/lib.rs index d8621fb82b..0f7baad33e 100644 --- a/bench-vortex/src/lib.rs +++ b/bench-vortex/src/lib.rs @@ -184,13 +184,13 @@ pub struct CompressionRunStats { impl CompressionRunStats { pub fn to_results(&self, dataset_name: String) -> Vec { - let DType::Struct(ns, fs) = &self.schema else { + let DType::Struct { names, dtypes } = &self.schema else { unreachable!() }; self.compressed_sizes .iter() - .zip_eq(ns.iter().zip_eq(fs)) + .zip_eq(names.iter().zip_eq(dtypes)) .map( |(&size, (column_name, column_type))| CompressionRunResults { dataset_name: dataset_name.clone(), diff --git a/bench-vortex/src/vortex_utils.rs b/bench-vortex/src/vortex_utils.rs index b266ac07af..9ac215f9eb 100644 --- a/bench-vortex/src/vortex_utils.rs +++ b/bench-vortex/src/vortex_utils.rs @@ -16,11 +16,11 @@ pub fn vortex_chunk_sizes(path: &Path) -> VortexResult { let file = File::open(path)?; let total_compressed_size = file.metadata()?.size(); let vortex = open_vortex(path)?; - let DType::Struct(ns, _) = vortex.dtype() else { + let DType::Struct { names, .. } = vortex.dtype() else { unreachable!() }; - let mut compressed_sizes = vec![0; ns.len()]; + let mut compressed_sizes = vec![0; names.len()]; let chunked_array = ChunkedArray::try_from(vortex).unwrap(); for chunk in chunked_array.chunks() { let struct_arr = StructArray::try_from(chunk).unwrap(); diff --git a/vortex-array/src/array/composite/array.rs b/vortex-array/src/array/composite/array.rs index f3a1897cad..2db8f73d8a 100644 --- a/vortex-array/src/array/composite/array.rs +++ b/vortex-array/src/array/composite/array.rs @@ -1,6 +1,6 @@ use flatbuffers::root; use vortex_dtype::flatbuffers as fb; -use vortex_dtype::{CompositeID, DTypeSerdeContext}; +use vortex_dtype::CompositeID; use vortex_error::{vortex_err, VortexResult}; use vortex_flatbuffers::{FlatBufferToBytes, ReadFlatBuffer}; @@ -34,7 +34,7 @@ impl TrySerializeArrayMetadata for CompositeMetadata { let mut fb = flexbuffers::Builder::default(); { let mut elems = fb.start_vector(); - elems.push(self.ext.id().0); + elems.push(self.ext.id().as_ref()); self.underlying_dtype .with_flatbuffer_bytes(|b| elems.push(flexbuffers::Blob(b))); elems.push(flexbuffers::Blob(self.underlying_metadata.as_ref())); @@ -53,9 +53,8 @@ impl TryDeserializeArrayMetadata<'_> for CompositeMetadata { .ok_or_else(|| vortex_err!("Unrecognized composite extension: {}", ext_id))?; let dtype_blob = elems.index(1).expect("missing dtype").as_blob(); - let ctx = DTypeSerdeContext::new(vec![]); // FIXME: composite_ids let underlying_dtype = DType::read_flatbuffer( - &ctx, + &(), &root::(dtype_blob.0).expect("invalid dtype"), )?; @@ -77,8 +76,8 @@ impl TryDeserializeArrayMetadata<'_> for CompositeMetadata { impl<'a> CompositeArray<'a> { pub fn new(id: CompositeID, metadata: Arc<[u8]>, underlying: Array<'a>) -> Self { - let dtype = DType::Composite(id, underlying.dtype().is_nullable().into()); - let ext = find_extension(id.0).expect("Unrecognized composite extension"); + let dtype = DType::Composite(id.clone(), underlying.dtype().is_nullable().into()); + let ext = find_extension(id.as_ref()).expect("Unrecognized composite extension"); Self::try_from_parts( dtype, CompositeMetadata { @@ -101,7 +100,7 @@ impl CompositeArray<'_> { #[inline] pub fn extension(&self) -> CompositeExtensionRef { - find_extension(self.id().0).expect("Unrecognized composite extension") + find_extension(self.id().as_ref()).expect("Unrecognized composite extension") } pub fn underlying_metadata(&self) -> &Arc<[u8]> { diff --git a/vortex-array/src/array/composite/mod.rs b/vortex-array/src/array/composite/mod.rs index 09f6bb5a5f..3cfa6de678 100644 --- a/vortex-array/src/array/composite/mod.rs +++ b/vortex-array/src/array/composite/mod.rs @@ -13,7 +13,7 @@ pub static VORTEX_COMPOSITE_EXTENSIONS: [&'static dyn CompositeExtension] = [..] pub fn find_extension(id: &str) -> Option<&'static dyn CompositeExtension> { VORTEX_COMPOSITE_EXTENSIONS .iter() - .find(|ext| ext.id().0 == id) + .find(|ext| ext.id().as_ref() == id) .copied() } diff --git a/vortex-array/src/array/composite/typed.rs b/vortex-array/src/array/composite/typed.rs index 7b4a158e5e..33983a3e5b 100644 --- a/vortex-array/src/array/composite/typed.rs +++ b/vortex-array/src/array/composite/typed.rs @@ -88,7 +88,7 @@ macro_rules! impl_composite { pub struct [<$T Extension>]; impl [<$T Extension>] { - pub const ID: CompositeID = CompositeID($id); + pub const ID: CompositeID = CompositeID::new($id); pub fn dtype(nullability: Nullability) -> DType { DType::Composite(Self::ID, nullability) diff --git a/vortex-array/src/array/struct/mod.rs b/vortex-array/src/array/struct/mod.rs index b07e07b0c6..e405772d92 100644 --- a/vortex-array/src/array/struct/mod.rs +++ b/vortex-array/src/array/struct/mod.rs @@ -19,25 +19,25 @@ pub struct StructMetadata { impl StructArray<'_> { pub fn child(&self, idx: usize) -> Option { - let DType::Struct(_, fields) = self.dtype() else { + let DType::Struct { dtypes, .. } = self.dtype() else { unreachable!() }; - let dtype = fields.get(idx)?; + let dtype = dtypes.get(idx)?; self.array().child(idx, dtype) } pub fn names(&self) -> &FieldNames { - let DType::Struct(names, _fields) = self.dtype() else { + let DType::Struct { names, .. } = self.dtype() else { unreachable!() }; names } pub fn fields(&self) -> &[DType] { - let DType::Struct(_names, fields) = self.dtype() else { + let DType::Struct { dtypes, .. } = self.dtype() else { unreachable!() }; - fields.as_slice() + dtypes.as_slice() } pub fn nfields(&self) -> usize { @@ -61,9 +61,9 @@ impl StructArray<'_> { vortex_bail!("Expected all struct fields to have length {}", length); } - let field_dtypes: Vec<_> = fields.iter().map(|d| d.dtype()).cloned().collect(); + let dtypes: Vec<_> = fields.iter().map(|d| d.dtype()).cloned().collect(); Self::try_from_parts( - DType::Struct(names, field_dtypes), + DType::Struct { names, dtypes }, StructMetadata { length }, fields.into_iter().map(|a| a.into_array_data()).collect(), HashMap::default(), diff --git a/vortex-array/src/arrow/dtype.rs b/vortex-array/src/arrow/dtype.rs index 5a06ea18d2..f2f7018c7e 100644 --- a/vortex-array/src/arrow/dtype.rs +++ b/vortex-array/src/arrow/dtype.rs @@ -40,18 +40,18 @@ impl TryFromArrowType<&DataType> for PType { impl FromArrowType for DType { fn from_arrow(value: SchemaRef) -> Self { - DType::Struct( - value + DType::Struct { + names: value .fields() .iter() .map(|f| Arc::new(f.name().clone())) .collect(), - value + dtypes: value .fields() .iter() .map(|f| DType::from_arrow(f.as_ref())) .collect_vec(), - ) + } } } @@ -81,13 +81,13 @@ impl FromArrowType<&Field> for DType { DataType::List(e) | DataType::LargeList(e) => { List(Box::new(DType::from_arrow(e.as_ref())), nullability) } - DataType::Struct(f) => Struct( - f.iter().map(|f| Arc::new(f.name().clone())).collect(), - f.iter() + DataType::Struct(f) => Struct { + names: f.iter().map(|f| Arc::new(f.name().clone())).collect(), + dtypes: f + .iter() .map(|f| DType::from_arrow(f.as_ref())) .collect_vec(), - ), - DataType::Decimal128(p, s) | DataType::Decimal256(p, s) => Decimal(*p, *s, nullability), + }, _ => unimplemented!("Arrow data type not yet supported: {:?}", field.data_type()), } } diff --git a/vortex-dtype/Cargo.toml b/vortex-dtype/Cargo.toml index e129426438..21effb90f6 100644 --- a/vortex-dtype/Cargo.toml +++ b/vortex-dtype/Cargo.toml @@ -21,7 +21,7 @@ half = { workspace = true } itertools = { workspace = true } linkme = { workspace = true } num-traits = { workspace = true } -serde = { workspace = true, optional = true } +serde = { workspace = true, optional = true, features = ["rc"] } thiserror = { workspace = true } vortex-error = { path = "../vortex-error" } vortex-flatbuffers = { path = "../vortex-flatbuffers" } diff --git a/vortex-dtype/src/composite.rs b/vortex-dtype/src/composite.rs new file mode 100644 index 0000000000..0ce701e7d3 --- /dev/null +++ b/vortex-dtype/src/composite.rs @@ -0,0 +1,50 @@ +use std::fmt::{Display, Formatter}; + +use linkme::distributed_slice; +use vortex_error::{vortex_err, VortexError}; + +#[derive(Debug, Clone, PartialEq, Eq, Ord, PartialOrd, Hash)] +#[cfg_attr(feature = "serde", derive(::serde::Serialize))] +pub struct CompositeID(&'static str); + +impl CompositeID { + pub const fn new(id: &'static str) -> Self { + Self(id) + } +} + +impl<'a> TryFrom<&'a str> for CompositeID { + type Error = VortexError; + + fn try_from(value: &'a str) -> Result { + find_composite_dtype(value) + .map(|cdt| CompositeID(cdt.id())) + .ok_or_else(|| vortex_err!("CompositeID not found for the given id: {}", value)) + } +} + +impl AsRef for CompositeID { + fn as_ref(&self) -> &str { + self.0 + } +} + +impl Display for CompositeID { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0) + } +} + +pub trait CompositeDType { + fn id(&self) -> &'static str; +} + +#[distributed_slice] +pub static VORTEX_COMPOSITE_DTYPES: [&'static dyn CompositeDType] = [..]; + +pub fn find_composite_dtype(id: &str) -> Option<&'static dyn CompositeDType> { + VORTEX_COMPOSITE_DTYPES + .iter() + .find(|ext| ext.id() == id) + .copied() +} diff --git a/vortex-dtype/src/deserialize.rs b/vortex-dtype/src/deserialize.rs index 4ef3b2a406..9332909b66 100644 --- a/vortex-dtype/src/deserialize.rs +++ b/vortex-dtype/src/deserialize.rs @@ -6,31 +6,11 @@ use vortex_flatbuffers::ReadFlatBuffer; use crate::{flatbuffers as fb, Nullability}; use crate::{CompositeID, DType}; -#[allow(dead_code)] -pub struct DTypeSerdeContext { - composite_ids: Arc<[CompositeID]>, -} - -impl DTypeSerdeContext { - pub fn new(composite_ids: Vec) -> Self { - Self { - composite_ids: composite_ids.into(), - } - } - - pub fn find_composite_id(&self, id: &str) -> Option { - self.composite_ids.iter().find(|c| c.0 == id).copied() - } -} - -impl ReadFlatBuffer for DType { +impl ReadFlatBuffer<()> for DType { type Source<'a> = fb::DType<'a>; type Error = VortexError; - fn read_flatbuffer( - ctx: &DTypeSerdeContext, - fb: &Self::Source<'_>, - ) -> Result { + fn read_flatbuffer(_ctx: &(), fb: &Self::Source<'_>) -> Result { match fb.type_type() { fb::Type::Null => Ok(DType::Null), fb::Type::Bool => Ok(DType::Bool( @@ -43,14 +23,6 @@ impl ReadFlatBuffer for DType { fb_primitive.nullability().try_into()?, )) } - fb::Type::Decimal => { - let fb_decimal = fb.type__as_decimal().unwrap(); - Ok(DType::Decimal( - fb_decimal.precision(), - fb_decimal.scale(), - fb_decimal.nullability().try_into()?, - )) - } fb::Type::Binary => Ok(DType::Binary( fb.type__as_binary().unwrap().nullability().try_into()?, )), @@ -59,7 +31,7 @@ impl ReadFlatBuffer for DType { )), fb::Type::List => { let fb_list = fb.type__as_list().unwrap(); - let element_dtype = DType::read_flatbuffer(ctx, &fb_list.element_type().unwrap())?; + let element_dtype = DType::read_flatbuffer(&(), &fb_list.element_type().unwrap())?; Ok(DType::List( Box::new(element_dtype), fb_list.nullability().try_into()?, @@ -73,19 +45,17 @@ impl ReadFlatBuffer for DType { .iter() .map(|n| Arc::new(n.to_string())) .collect::>(); - let fields: Vec = fb_struct + let dtypes: Vec = fb_struct .fields() .unwrap() .iter() - .map(|f| DType::read_flatbuffer(ctx, &f)) + .map(|f| DType::read_flatbuffer(&(), &f)) .collect::>>()?; - Ok(DType::Struct(names, fields)) + Ok(DType::Struct { names, dtypes }) } fb::Type::Composite => { let fb_composite = fb.type__as_composite().unwrap(); - let id = ctx - .find_composite_id(fb_composite.id().unwrap()) - .ok_or_else(|| vortex_err!("Couldn't find composite id"))?; + let id = CompositeID::try_from(fb_composite.id().unwrap())?; Ok(DType::Composite(id, fb_composite.nullability().try_into()?)) } _ => Err(vortex_err!("Unknown DType variant")), diff --git a/vortex-dtype/src/dtype.rs b/vortex-dtype/src/dtype.rs index 0c9411812f..c7314e7ab5 100644 --- a/vortex-dtype/src/dtype.rs +++ b/vortex-dtype/src/dtype.rs @@ -3,59 +3,27 @@ use std::hash::Hash; use std::sync::Arc; use itertools::Itertools; +use serde::{Deserialize, Serialize}; use DType::*; -use crate::{CompositeID, PType}; - -#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash, Ord, PartialOrd)] -#[cfg_attr(feature = "serde", derive(::serde::Serialize, ::serde::Deserialize))] -pub enum Nullability { - #[default] - NonNullable, - Nullable, -} - -impl From for Nullability { - fn from(value: bool) -> Self { - if value { - Nullability::Nullable - } else { - Nullability::NonNullable - } - } -} - -impl From for bool { - fn from(value: Nullability) -> Self { - match value { - Nullability::NonNullable => false, - Nullability::Nullable => true, - } - } -} - -impl Display for Nullability { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - match self { - Nullability::NonNullable => write!(f, ""), - Nullability::Nullable => write!(f, "?"), - } - } -} +use crate::{CompositeID, Nullability, PType}; pub type FieldNames = Vec>; pub type Metadata = Vec; -#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] pub enum DType { Null, Bool(Nullability), + #[serde(with = "crate::serde::dtype_primitive")] Primitive(PType, Nullability), - Decimal(u8, i8, Nullability), Utf8(Nullability), Binary(Nullability), - Struct(FieldNames, Vec), + Struct { + names: FieldNames, + dtypes: Vec, + }, List(Box, Nullability), Composite(CompositeID, Nullability), } @@ -77,10 +45,9 @@ impl DType { Null => true, Bool(n) => matches!(n, Nullable), Primitive(_, n) => matches!(n, Nullable), - Decimal(_, _, n) => matches!(n, Nullable), Utf8(n) => matches!(n, Nullable), Binary(n) => matches!(n, Nullable), - Struct(_, fs) => fs.iter().all(|f| f.is_nullable()), + Struct { dtypes, .. } => dtypes.iter().all(|dt| dt.is_nullable()), List(_, n) => matches!(n, Nullable), Composite(_, n) => matches!(n, Nullable), } @@ -99,15 +66,17 @@ impl DType { Null => Null, Bool(_) => Bool(nullability), Primitive(p, _) => Primitive(*p, nullability), - Decimal(s, p, _) => Decimal(*s, *p, nullability), Utf8(_) => Utf8(nullability), Binary(_) => Binary(nullability), - Struct(n, fs) => Struct( - n.clone(), - fs.iter().map(|f| f.with_nullability(nullability)).collect(), - ), + Struct { names, dtypes } => Struct { + names: names.clone(), + dtypes: dtypes + .iter() + .map(|dt| dt.with_nullability(nullability)) + .collect(), + }, List(c, _) => List(c.clone(), nullability), - Composite(id, _) => Composite(*id, nullability), + Composite(id, _) => Composite(id.clone(), nullability), } } @@ -122,14 +91,14 @@ impl Display for DType { Null => write!(f, "null"), Bool(n) => write!(f, "bool{}", n), Primitive(p, n) => write!(f, "{}{}", p, n), - Decimal(p, s, n) => write!(f, "decimal({}, {}){}", p, s, n), Utf8(n) => write!(f, "utf8{}", n), Binary(n) => write!(f, "binary{}", n), - Struct(n, dt) => write!( + Struct { names, dtypes } => write!( f, "{{{}}}", - n.iter() - .zip(dt.iter()) + names + .iter() + .zip(dtypes.iter()) .map(|(n, dt)| format!("{}={}", n, dt)) .join(", ") ), diff --git a/vortex-dtype/src/lib.rs b/vortex-dtype/src/lib.rs index eb293e0223..80d7427c78 100644 --- a/vortex-dtype/src/lib.rs +++ b/vortex-dtype/src/lib.rs @@ -1,25 +1,16 @@ -use std::fmt::{Display, Formatter}; - pub use dtype::*; pub use half; pub use ptype::*; mod deserialize; +pub use composite::*; +mod composite; mod dtype; +mod nullability; mod ptype; mod serde; mod serialize; -pub use deserialize::*; - -#[derive(Debug, Clone, Copy, PartialEq, Eq, Ord, PartialOrd, Hash)] -#[cfg_attr(feature = "serde", derive(::serde::Serialize, ::serde::Deserialize))] -pub struct CompositeID(pub &'static str); - -impl Display for CompositeID { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self.0) - } -} +pub use nullability::*; pub mod flatbuffers { #[allow(unused_imports)] diff --git a/vortex-dtype/src/nullability.rs b/vortex-dtype/src/nullability.rs new file mode 100644 index 0000000000..e8416215bd --- /dev/null +++ b/vortex-dtype/src/nullability.rs @@ -0,0 +1,36 @@ +use std::fmt::{Display, Formatter}; + +#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash, Ord, PartialOrd)] +pub enum Nullability { + #[default] + NonNullable, + Nullable, +} + +impl From for Nullability { + fn from(value: bool) -> Self { + if value { + Nullability::Nullable + } else { + Nullability::NonNullable + } + } +} + +impl From for bool { + fn from(value: Nullability) -> Self { + match value { + Nullability::NonNullable => false, + Nullability::Nullable => true, + } + } +} + +impl Display for Nullability { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + Nullability::NonNullable => write!(f, ""), + Nullability::Nullable => write!(f, "?"), + } + } +} diff --git a/vortex-dtype/src/serde.rs b/vortex-dtype/src/serde.rs index a5bbd242df..1580b525b3 100644 --- a/vortex-dtype/src/serde.rs +++ b/vortex-dtype/src/serde.rs @@ -1,59 +1,72 @@ #![cfg(feature = "serde")] - -use flatbuffers::root; -use serde::de::{DeserializeSeed, Visitor}; use serde::{Deserialize, Deserializer, Serialize, Serializer}; -use vortex_flatbuffers::{FlatBufferToBytes, ReadFlatBuffer}; -use crate::DType; -use crate::{flatbuffers as fb, DTypeSerdeContext}; +use crate::{CompositeID, Nullability}; -impl Serialize for DType { +impl Serialize for Nullability { fn serialize(&self, serializer: S) -> Result where S: Serializer, { - self.with_flatbuffer_bytes(|bytes| serializer.serialize_bytes(bytes)) + match self { + Nullability::NonNullable => serializer.serialize_bool(false), + Nullability::Nullable => serializer.serialize_bool(true), + } } } -struct DTypeDeserializer(DTypeSerdeContext); - -impl<'de> Visitor<'de> for DTypeDeserializer { - type Value = DType; - - fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { - formatter.write_str("a vortex dtype") +impl<'de> Deserialize<'de> for Nullability { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + Ok(match bool::deserialize(deserializer)? { + true => Nullability::Nullable, + false => Nullability::NonNullable, + }) } +} - fn visit_bytes(self, v: &[u8]) -> Result +#[cfg(feature = "serde")] +impl<'de> Deserialize<'de> for CompositeID { + fn deserialize(deserializer: D) -> Result where - E: serde::de::Error, + D: Deserializer<'de>, { - let fb = root::(v).map_err(E::custom)?; - DType::read_flatbuffer(&self.0, &fb).map_err(E::custom) + CompositeID::try_from(<&'de str>::deserialize(deserializer)?) + .map_err(serde::de::Error::custom) } } -impl<'de> DeserializeSeed<'de> for DTypeSerdeContext { - type Value = DType; +/// Implement custom serde to retain the ergonomics of a tuple enum variant. +/// Essentially, we use this wrapper to name the fields of the DType::Primitive enum variant. +pub mod dtype_primitive { + use serde::{Deserialize, Deserializer, Serialize, Serializer}; + + use crate::{Nullability, PType}; + + #[derive(Serialize, Deserialize)] + struct PrimitiveSerde { + ptype: PType, + n: Nullability, + } - fn deserialize(self, deserializer: D) -> Result + pub fn serialize(ptype: &PType, n: &Nullability, serializer: S) -> Result where - D: Deserializer<'de>, + S: Serializer, { - deserializer.deserialize_bytes(DTypeDeserializer(self)) + PrimitiveSerde { + ptype: *ptype, + n: *n, + } + .serialize(serializer) } -} -// TODO(ngates): Remove this trait in favour of storing e.g. IdxType which doesn't require -// the context for composite types. -impl<'de> Deserialize<'de> for DType { - fn deserialize(deserializer: D) -> Result + pub fn deserialize<'de, D>(deserializer: D) -> Result<(PType, Nullability), D::Error> where D: Deserializer<'de>, { - let ctx = DTypeSerdeContext::new(vec![]); - deserializer.deserialize_bytes(DTypeDeserializer(ctx)) + let PrimitiveSerde { ptype, n } = PrimitiveSerde::deserialize(deserializer)?; + Ok((ptype, n)) } } diff --git a/vortex-dtype/src/serialize.rs b/vortex-dtype/src/serialize.rs index 1aa485807a..d5ad6d8220 100644 --- a/vortex-dtype/src/serialize.rs +++ b/vortex-dtype/src/serialize.rs @@ -31,15 +31,6 @@ impl WriteFlatBuffer for DType { }, ) .as_union_value(), - DType::Decimal(p, s, n) => fb::Decimal::create( - fbb, - &fb::DecimalArgs { - precision: *p, - scale: *s, - nullability: n.into(), - }, - ) - .as_union_value(), DType::Utf8(n) => fb::Utf8::create( fbb, &fb::Utf8Args { @@ -54,7 +45,7 @@ impl WriteFlatBuffer for DType { }, ) .as_union_value(), - DType::Struct(names, dtypes) => { + DType::Struct { names, dtypes } => { let names = names .iter() .map(|n| fbb.create_string(n.as_str())) @@ -81,7 +72,7 @@ impl WriteFlatBuffer for DType { .as_union_value() } DType::Composite(id, n) => { - let id = Some(fbb.create_string(id.0)); + let id = Some(fbb.create_string(id.as_ref())); fb::Composite::create( fbb, &fb::CompositeArgs { @@ -97,10 +88,9 @@ impl WriteFlatBuffer for DType { DType::Null => fb::Type::Null, DType::Bool(_) => fb::Type::Bool, DType::Primitive(..) => fb::Type::Primitive, - DType::Decimal(..) => fb::Type::Decimal, DType::Utf8(_) => fb::Type::Utf8, DType::Binary(_) => fb::Type::Binary, - DType::Struct(..) => fb::Type::Struct_, + DType::Struct { .. } => fb::Type::Struct_, DType::List(..) => fb::Type::List, DType::Composite(..) => fb::Type::Composite, }; @@ -180,15 +170,12 @@ mod test { use vortex_flatbuffers::{FlatBufferToBytes, ReadFlatBuffer}; use crate::{flatbuffers as fb, PType}; - use crate::{DType, DTypeSerdeContext, Nullability}; + use crate::{DType, Nullability}; fn roundtrip_dtype(dtype: DType) { let bytes = dtype.with_flatbuffer_bytes(|bytes| bytes.to_vec()); - let deserialized = DType::read_flatbuffer( - &DTypeSerdeContext::new(vec![]), - &root::(&bytes).unwrap(), - ) - .unwrap(); + let deserialized = + DType::read_flatbuffer(&(), &root::(&bytes).unwrap()).unwrap(); assert_eq!(dtype, deserialized); } @@ -197,19 +184,18 @@ mod test { roundtrip_dtype(DType::Null); roundtrip_dtype(DType::Bool(Nullability::NonNullable)); roundtrip_dtype(DType::Primitive(PType::U64, Nullability::NonNullable)); - roundtrip_dtype(DType::Decimal(18, 9, Nullability::NonNullable)); roundtrip_dtype(DType::Binary(Nullability::NonNullable)); roundtrip_dtype(DType::Utf8(Nullability::NonNullable)); roundtrip_dtype(DType::List( Box::new(DType::Primitive(PType::F32, Nullability::Nullable)), Nullability::NonNullable, )); - roundtrip_dtype(DType::Struct( - vec![Arc::new("strings".into()), Arc::new("ints".into())], - vec![ + roundtrip_dtype(DType::Struct { + names: vec![Arc::new("strings".into()), Arc::new("ints".into())], + dtypes: vec![ DType::Utf8(Nullability::NonNullable), DType::Primitive(PType::U16, Nullability::Nullable), ], - )) + }) } } diff --git a/vortex-ipc/src/reader.rs b/vortex-ipc/src/reader.rs index 4cba0d85e7..c6239894d0 100644 --- a/vortex-ipc/src/reader.rs +++ b/vortex-ipc/src/reader.rs @@ -7,7 +7,6 @@ use flatbuffers::{root, root_unchecked}; use itertools::Itertools; use nougat::gat; use vortex::array::chunked::ChunkedArray; -use vortex::array::composite::VORTEX_COMPOSITE_EXTENSIONS; use vortex::array::primitive::PrimitiveArray; use vortex::buffer::Buffer; use vortex::compute::search_sorted::{search_sorted, SearchSortedSide}; @@ -17,7 +16,7 @@ use vortex::stats::{ArrayStatistics, Stat}; use vortex::{ Array, ArrayDType, ArrayView, IntoArray, OwnedArray, SerdeContext, ToArray, ToStatic, }; -use vortex_dtype::{match_each_integer_ptype, DType, DTypeSerdeContext}; +use vortex_dtype::{match_each_integer_ptype, DType}; use vortex_error::{vortex_bail, vortex_err, VortexError, VortexResult}; use vortex_flatbuffers::ReadFlatBuffer; @@ -103,11 +102,8 @@ impl FallibleLendingIterator for StreamReader { .header_as_schema() .unwrap(); - // TODO(ngates): construct this from the SerdeContext. - let dtype_ctx = - DTypeSerdeContext::new(VORTEX_COMPOSITE_EXTENSIONS.iter().map(|e| e.id()).collect()); let dtype = DType::read_flatbuffer( - &dtype_ctx, + &(), &schema_msg .dtype() .ok_or_else(|| vortex_err!(InvalidSerde: "Schema missing DType"))?, diff --git a/vortex-scalar/src/lib.rs b/vortex-scalar/src/lib.rs index 30d4be5d8f..b34437bd02 100644 --- a/vortex-scalar/src/lib.rs +++ b/vortex-scalar/src/lib.rs @@ -126,10 +126,9 @@ impl Scalar { DType::Null => NullScalar::new().into(), DType::Bool(_) => BoolScalar::none().into(), DType::Primitive(p, _) => PrimitiveScalar::none_from_ptype(*p).into(), - DType::Decimal(..) => unimplemented!("DecimalScalar"), DType::Utf8(_) => Utf8Scalar::none().into(), DType::Binary(_) => BinaryScalar::none().into(), - DType::Struct(..) => StructScalar::new(dtype.clone(), vec![]).into(), + DType::Struct { .. } => StructScalar::new(dtype.clone(), vec![]).into(), DType::List(..) => ListScalar::new(dtype.clone(), None).into(), DType::Composite(..) => unimplemented!("CompositeScalar"), } diff --git a/vortex-scalar/src/serde.rs b/vortex-scalar/src/serde.rs index 1d77288a83..18d25af7e7 100644 --- a/vortex-scalar/src/serde.rs +++ b/vortex-scalar/src/serde.rs @@ -4,7 +4,7 @@ use flatbuffers::{root, FlatBufferBuilder, WIPOffset}; use serde::de::Visitor; use serde::{Deserialize, Deserializer, Serialize, Serializer}; use vortex_dtype::match_each_native_ptype; -use vortex_dtype::{DTypeSerdeContext, Nullability}; +use vortex_dtype::Nullability; use vortex_error::{vortex_bail, VortexError}; use vortex_flatbuffers::{FlatBufferRoot, FlatBufferToBytes, ReadFlatBuffer, WriteFlatBuffer}; @@ -90,14 +90,11 @@ impl WriteFlatBuffer for Scalar { } } -impl ReadFlatBuffer for Scalar { +impl ReadFlatBuffer<()> for Scalar { type Source<'a> = fb::Scalar<'a>; type Error = VortexError; - fn read_flatbuffer( - _ctx: &DTypeSerdeContext, - fb: &Self::Source<'_>, - ) -> Result { + fn read_flatbuffer(_ctx: &(), fb: &Self::Source<'_>) -> Result { let nullability = Nullability::from(fb.nullability()); match fb.type_type() { fb::Type::Binary => { @@ -153,7 +150,7 @@ impl Serialize for Scalar { } } -struct ScalarDeserializer(DTypeSerdeContext); +struct ScalarDeserializer; impl<'de> Visitor<'de> for ScalarDeserializer { type Value = Scalar; @@ -167,7 +164,7 @@ impl<'de> Visitor<'de> for ScalarDeserializer { E: serde::de::Error, { let fb = root::(v).map_err(E::custom)?; - Scalar::read_flatbuffer(&self.0, &fb).map_err(E::custom) + Scalar::read_flatbuffer(&(), &fb).map_err(E::custom) } } @@ -177,8 +174,7 @@ impl<'de> Deserialize<'de> for Scalar { where D: Deserializer<'de>, { - let ctx = DTypeSerdeContext::new(vec![]); - deserializer.deserialize_bytes(ScalarDeserializer(ctx)) + deserializer.deserialize_bytes(ScalarDeserializer) } } diff --git a/vortex-scalar/src/struct_.rs b/vortex-scalar/src/struct_.rs index 50abe21673..655510b415 100644 --- a/vortex-scalar/src/struct_.rs +++ b/vortex-scalar/src/struct_.rs @@ -31,16 +31,16 @@ impl StructScalar { } pub fn names(&self) -> &[Arc] { - let DType::Struct(ns, _) = self.dtype() else { + let DType::Struct { names, .. } = self.dtype() else { unreachable!("Not a scalar dtype"); }; - ns.as_slice() + names.as_slice() } pub fn cast(&self, dtype: &DType) -> VortexResult { match dtype { - DType::Struct(names, field_dtypes) => { - if field_dtypes.len() != self.values.len() { + DType::Struct { names, dtypes } => { + if dtypes.len() != self.values.len() { vortex_bail!( MismatchedTypes: format!("Struct with {} fields", self.values.len()), dtype @@ -50,14 +50,14 @@ impl StructScalar { let new_fields: Vec = self .values .iter() - .zip_eq(field_dtypes.iter()) + .zip_eq(dtypes.iter()) .map(|(field, field_dtype)| field.cast(field_dtype)) .try_collect()?; - let new_type = DType::Struct( - names.clone(), - new_fields.iter().map(|x| x.dtype().clone()).collect(), - ); + let new_type = DType::Struct { + names: names.clone(), + dtypes: new_fields.iter().map(|x| x.dtype().clone()).collect(), + }; Ok(StructScalar::new(new_type, new_fields).into()) } _ => Err(vortex_err!(MismatchedTypes: "struct", dtype)), @@ -81,7 +81,7 @@ impl PartialOrd for StructScalar { impl Display for StructScalar { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - let DType::Struct(names, _) = self.dtype() else { + let DType::Struct { names, .. } = self.dtype() else { unreachable!() }; for (n, v) in names.iter().zip(self.values.iter()) {