diff --git a/vortex-array/src/array/composite/array.rs b/vortex-array/src/array/composite/array.rs index f3a1897cad..0faabba7e3 100644 --- a/vortex-array/src/array/composite/array.rs +++ b/vortex-array/src/array/composite/array.rs @@ -1,6 +1,6 @@ use flatbuffers::root; use vortex_dtype::flatbuffers as fb; -use vortex_dtype::{CompositeID, DTypeSerdeContext}; +use vortex_dtype::CompositeID; use vortex_error::{vortex_err, VortexResult}; use vortex_flatbuffers::{FlatBufferToBytes, ReadFlatBuffer}; @@ -53,11 +53,8 @@ impl TryDeserializeArrayMetadata<'_> for CompositeMetadata { .ok_or_else(|| vortex_err!("Unrecognized composite extension: {}", ext_id))?; let dtype_blob = elems.index(1).expect("missing dtype").as_blob(); - let ctx = DTypeSerdeContext::new(vec![]); // FIXME: composite_ids - let underlying_dtype = DType::read_flatbuffer( - &ctx, - &root::(dtype_blob.0).expect("invalid dtype"), - )?; + let underlying_dtype = + DType::read_flatbuffer(&root::(dtype_blob.0).expect("invalid dtype"))?; let underlying_metadata: Arc<[u8]> = elems .index(2) diff --git a/vortex-dtype/Cargo.toml b/vortex-dtype/Cargo.toml index e129426438..21effb90f6 100644 --- a/vortex-dtype/Cargo.toml +++ b/vortex-dtype/Cargo.toml @@ -21,7 +21,7 @@ half = { workspace = true } itertools = { workspace = true } linkme = { workspace = true } num-traits = { workspace = true } -serde = { workspace = true, optional = true } +serde = { workspace = true, optional = true, features = ["rc"] } thiserror = { workspace = true } vortex-error = { path = "../vortex-error" } vortex-flatbuffers = { path = "../vortex-flatbuffers" } diff --git a/vortex-dtype/flatbuffers/dtype.fbs b/vortex-dtype/flatbuffers/dtype.fbs index 041d61b934..61cc64a026 100644 --- a/vortex-dtype/flatbuffers/dtype.fbs +++ b/vortex-dtype/flatbuffers/dtype.fbs @@ -1,6 +1,6 @@ namespace vortex.dtype; -enum Nullability: byte { +enum Nullability: uint8 { NonNullable, Nullable, } @@ -32,10 +32,10 @@ table Primitive { table Decimal { /// Total number of decimal digits - precision: ubyte; + precision: uint8; /// Number of digits after the decimal point "." - scale: byte; + scale: int8; nullability: Nullability; } @@ -57,8 +57,9 @@ table List { nullability: Nullability; } -table Composite { +table Extension { id: string; + metadata: [ubyte]; nullability: Nullability; } @@ -71,7 +72,7 @@ union Type { Binary, Struct_, List, - Composite, + Extension, } table DType { diff --git a/vortex-dtype/src/deserialize.rs b/vortex-dtype/src/deserialize.rs index 4ef3b2a406..b5bc4d06f4 100644 --- a/vortex-dtype/src/deserialize.rs +++ b/vortex-dtype/src/deserialize.rs @@ -3,34 +3,14 @@ use std::sync::Arc; use vortex_error::{vortex_err, VortexError, VortexResult}; use vortex_flatbuffers::ReadFlatBuffer; -use crate::{flatbuffers as fb, Nullability}; -use crate::{CompositeID, DType}; +use crate::DType; +use crate::{flatbuffers as fb, ExtDType, ExtID, ExtMetadata, Nullability}; -#[allow(dead_code)] -pub struct DTypeSerdeContext { - composite_ids: Arc<[CompositeID]>, -} - -impl DTypeSerdeContext { - pub fn new(composite_ids: Vec) -> Self { - Self { - composite_ids: composite_ids.into(), - } - } - - pub fn find_composite_id(&self, id: &str) -> Option { - self.composite_ids.iter().find(|c| c.0 == id).copied() - } -} - -impl ReadFlatBuffer for DType { +impl ReadFlatBuffer for DType { type Source<'a> = fb::DType<'a>; type Error = VortexError; - fn read_flatbuffer( - ctx: &DTypeSerdeContext, - fb: &Self::Source<'_>, - ) -> Result { + fn read_flatbuffer(fb: &Self::Source<'_>) -> Result { match fb.type_type() { fb::Type::Null => Ok(DType::Null), fb::Type::Bool => Ok(DType::Bool( @@ -59,7 +39,7 @@ impl ReadFlatBuffer for DType { )), fb::Type::List => { let fb_list = fb.type__as_list().unwrap(); - let element_dtype = DType::read_flatbuffer(ctx, &fb_list.element_type().unwrap())?; + let element_dtype = DType::read_flatbuffer(&fb_list.element_type().unwrap())?; Ok(DType::List( Box::new(element_dtype), fb_list.nullability().try_into()?, @@ -77,16 +57,18 @@ impl ReadFlatBuffer for DType { .fields() .unwrap() .iter() - .map(|f| DType::read_flatbuffer(ctx, &f)) + .map(|f| DType::read_flatbuffer(&f)) .collect::>>()?; Ok(DType::Struct(names, fields)) } - fb::Type::Composite => { - let fb_composite = fb.type__as_composite().unwrap(); - let id = ctx - .find_composite_id(fb_composite.id().unwrap()) - .ok_or_else(|| vortex_err!("Couldn't find composite id"))?; - Ok(DType::Composite(id, fb_composite.nullability().try_into()?)) + fb::Type::Extension => { + let fb_ext = fb.type__as_extension().unwrap(); + let id = ExtID::from(fb_ext.id().unwrap()); + let metadata = fb_ext.metadata().map(|m| ExtMetadata::from(m.bytes())); + Ok(DType::Extension( + ExtDType::new(id, metadata), + fb_ext.nullability().try_into()?, + )) } _ => Err(vortex_err!("Unknown DType variant")), } diff --git a/vortex-dtype/src/dtype.rs b/vortex-dtype/src/dtype.rs index 0c9411812f..e46fccb506 100644 --- a/vortex-dtype/src/dtype.rs +++ b/vortex-dtype/src/dtype.rs @@ -5,7 +5,7 @@ use std::sync::Arc; use itertools::Itertools; use DType::*; -use crate::{CompositeID, PType}; +use crate::{CompositeID, ExtDType, PType}; #[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash, Ord, PartialOrd)] #[cfg_attr(feature = "serde", derive(::serde::Serialize, ::serde::Deserialize))] @@ -48,6 +48,7 @@ pub type FieldNames = Vec>; pub type Metadata = Vec; #[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] pub enum DType { Null, Bool(Nullability), @@ -57,6 +58,8 @@ pub enum DType { Binary(Nullability), Struct(FieldNames, Vec), List(Box, Nullability), + Extension(ExtDType, Nullability), + #[serde(skip)] Composite(CompositeID, Nullability), } @@ -82,6 +85,7 @@ impl DType { Binary(n) => matches!(n, Nullable), Struct(_, fs) => fs.iter().all(|f| f.is_nullable()), List(_, n) => matches!(n, Nullable), + Extension(_, n) => matches!(n, Nullable), Composite(_, n) => matches!(n, Nullable), } } @@ -107,6 +111,7 @@ impl DType { fs.iter().map(|f| f.with_nullability(nullability)).collect(), ), List(c, _) => List(c.clone(), nullability), + Extension(ext, _) => Extension(ext.clone(), nullability), Composite(id, _) => Composite(*id, nullability), } } @@ -134,6 +139,15 @@ impl Display for DType { .join(", ") ), List(c, n) => write!(f, "list({}){}", c, n), + Extension(ext, n) => write!( + f, + "ext({}{}){}", + ext.id(), + ext.metadata() + .map(|m| format!(", {:?}", m)) + .unwrap_or_else(|| "".to_string()), + n + ), Composite(id, n) => write!(f, "<{}>{}", id, n), } } @@ -147,6 +161,6 @@ mod test { #[test] fn size_of() { - assert_eq!(mem::size_of::(), 48); + assert_eq!(mem::size_of::(), 56); } } diff --git a/vortex-dtype/src/extension.rs b/vortex-dtype/src/extension.rs new file mode 100644 index 0000000000..1eea49bc77 --- /dev/null +++ b/vortex-dtype/src/extension.rs @@ -0,0 +1,69 @@ +use std::fmt::{Display, Formatter}; +use std::sync::Arc; + +#[derive(Debug, Clone, PartialEq, Eq, Ord, PartialOrd, Hash)] +#[cfg_attr(feature = "serde", derive(::serde::Serialize, ::serde::Deserialize))] +pub struct ExtID(Arc); + +impl Display for ExtID { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0) + } +} + +impl AsRef for ExtID { + fn as_ref(&self) -> &str { + self.0.as_ref() + } +} + +impl From<&str> for ExtID { + fn from(value: &str) -> Self { + ExtID(value.into()) + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +pub struct ExtMetadata(Arc<[u8]>); + +impl AsRef<[u8]> for ExtMetadata { + fn as_ref(&self) -> &[u8] { + self.0.as_ref() + } +} + +impl From> for ExtMetadata { + fn from(value: Arc<[u8]>) -> Self { + ExtMetadata(value) + } +} + +impl From<&[u8]> for ExtMetadata { + fn from(value: &[u8]) -> Self { + ExtMetadata(value.into()) + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +pub struct ExtDType { + id: ExtID, + metadata: Option, +} + +impl ExtDType { + pub fn new(id: ExtID, metadata: Option) -> Self { + Self { id, metadata } + } + + #[inline] + pub fn id(&self) -> &ExtID { + &self.id + } + + #[inline] + pub fn metadata(&self) -> Option<&ExtMetadata> { + self.metadata.as_ref() + } +} diff --git a/vortex-dtype/src/lib.rs b/vortex-dtype/src/lib.rs index eb293e0223..3ed41be2b6 100644 --- a/vortex-dtype/src/lib.rs +++ b/vortex-dtype/src/lib.rs @@ -1,16 +1,16 @@ use std::fmt::{Display, Formatter}; pub use dtype::*; +pub use extension::*; pub use half; pub use ptype::*; mod deserialize; mod dtype; +mod extension; mod ptype; mod serde; mod serialize; -pub use deserialize::*; - #[derive(Debug, Clone, Copy, PartialEq, Eq, Ord, PartialOrd, Hash)] #[cfg_attr(feature = "serde", derive(::serde::Serialize, ::serde::Deserialize))] pub struct CompositeID(pub &'static str); diff --git a/vortex-dtype/src/ptype.rs b/vortex-dtype/src/ptype.rs index e1715dd710..559988ad7c 100644 --- a/vortex-dtype/src/ptype.rs +++ b/vortex-dtype/src/ptype.rs @@ -9,8 +9,8 @@ use crate::DType; use crate::DType::*; use crate::Nullability::NonNullable; -#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] #[derive(Debug, Clone, Copy, PartialEq, PartialOrd, Eq, Hash)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] pub enum PType { U8, U16, diff --git a/vortex-dtype/src/serde.rs b/vortex-dtype/src/serde.rs index a5bbd242df..29df755186 100644 --- a/vortex-dtype/src/serde.rs +++ b/vortex-dtype/src/serde.rs @@ -1,59 +1 @@ #![cfg(feature = "serde")] - -use flatbuffers::root; -use serde::de::{DeserializeSeed, Visitor}; -use serde::{Deserialize, Deserializer, Serialize, Serializer}; -use vortex_flatbuffers::{FlatBufferToBytes, ReadFlatBuffer}; - -use crate::DType; -use crate::{flatbuffers as fb, DTypeSerdeContext}; - -impl Serialize for DType { - fn serialize(&self, serializer: S) -> Result - where - S: Serializer, - { - self.with_flatbuffer_bytes(|bytes| serializer.serialize_bytes(bytes)) - } -} - -struct DTypeDeserializer(DTypeSerdeContext); - -impl<'de> Visitor<'de> for DTypeDeserializer { - type Value = DType; - - fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { - formatter.write_str("a vortex dtype") - } - - fn visit_bytes(self, v: &[u8]) -> Result - where - E: serde::de::Error, - { - let fb = root::(v).map_err(E::custom)?; - DType::read_flatbuffer(&self.0, &fb).map_err(E::custom) - } -} - -impl<'de> DeserializeSeed<'de> for DTypeSerdeContext { - type Value = DType; - - fn deserialize(self, deserializer: D) -> Result - where - D: Deserializer<'de>, - { - deserializer.deserialize_bytes(DTypeDeserializer(self)) - } -} - -// TODO(ngates): Remove this trait in favour of storing e.g. IdxType which doesn't require -// the context for composite types. -impl<'de> Deserialize<'de> for DType { - fn deserialize(deserializer: D) -> Result - where - D: Deserializer<'de>, - { - let ctx = DTypeSerdeContext::new(vec![]); - deserializer.deserialize_bytes(DTypeDeserializer(ctx)) - } -} diff --git a/vortex-dtype/src/serialize.rs b/vortex-dtype/src/serialize.rs index 1aa485807a..3bb47e22d0 100644 --- a/vortex-dtype/src/serialize.rs +++ b/vortex-dtype/src/serialize.rs @@ -80,17 +80,20 @@ impl WriteFlatBuffer for DType { ) .as_union_value() } - DType::Composite(id, n) => { - let id = Some(fbb.create_string(id.0)); - fb::Composite::create( + DType::Extension(ext, n) => { + let id = Some(fbb.create_string(ext.id().as_ref())); + let metadata = ext.metadata().map(|m| fbb.create_vector(m.as_ref())); + fb::Extension::create( fbb, - &fb::CompositeArgs { + &fb::ExtensionArgs { id, + metadata, nullability: n.into(), }, ) .as_union_value() } + DType::Composite(..) => todo!(), }; let dtype_type = match self { @@ -102,7 +105,8 @@ impl WriteFlatBuffer for DType { DType::Binary(_) => fb::Type::Binary, DType::Struct(..) => fb::Type::Struct_, DType::List(..) => fb::Type::List, - DType::Composite(..) => fb::Type::Composite, + DType::Extension { .. } => fb::Type::Extension, + DType::Composite(..) => unreachable!(), }; fb::DType::create( @@ -180,15 +184,11 @@ mod test { use vortex_flatbuffers::{FlatBufferToBytes, ReadFlatBuffer}; use crate::{flatbuffers as fb, PType}; - use crate::{DType, DTypeSerdeContext, Nullability}; + use crate::{DType, Nullability}; fn roundtrip_dtype(dtype: DType) { let bytes = dtype.with_flatbuffer_bytes(|bytes| bytes.to_vec()); - let deserialized = DType::read_flatbuffer( - &DTypeSerdeContext::new(vec![]), - &root::(&bytes).unwrap(), - ) - .unwrap(); + let deserialized = DType::read_flatbuffer(&root::(&bytes).unwrap()).unwrap(); assert_eq!(dtype, deserialized); } diff --git a/vortex-flatbuffers/src/lib.rs b/vortex-flatbuffers/src/lib.rs index 5c6b214356..e120c5248a 100644 --- a/vortex-flatbuffers/src/lib.rs +++ b/vortex-flatbuffers/src/lib.rs @@ -5,11 +5,11 @@ use flatbuffers::{FlatBufferBuilder, WIPOffset}; pub trait FlatBufferRoot {} -pub trait ReadFlatBuffer: Sized { +pub trait ReadFlatBuffer: Sized { type Source<'a>; type Error; - fn read_flatbuffer(ctx: &Ctx, fb: &Self::Source<'_>) -> Result; + fn read_flatbuffer(fb: &Self::Source<'_>) -> Result; } pub trait WriteFlatBuffer { diff --git a/vortex-ipc/src/reader.rs b/vortex-ipc/src/reader.rs index 4cba0d85e7..8c44d9814c 100644 --- a/vortex-ipc/src/reader.rs +++ b/vortex-ipc/src/reader.rs @@ -7,7 +7,6 @@ use flatbuffers::{root, root_unchecked}; use itertools::Itertools; use nougat::gat; use vortex::array::chunked::ChunkedArray; -use vortex::array::composite::VORTEX_COMPOSITE_EXTENSIONS; use vortex::array::primitive::PrimitiveArray; use vortex::buffer::Buffer; use vortex::compute::search_sorted::{search_sorted, SearchSortedSide}; @@ -17,7 +16,7 @@ use vortex::stats::{ArrayStatistics, Stat}; use vortex::{ Array, ArrayDType, ArrayView, IntoArray, OwnedArray, SerdeContext, ToArray, ToStatic, }; -use vortex_dtype::{match_each_integer_ptype, DType, DTypeSerdeContext}; +use vortex_dtype::{match_each_integer_ptype, DType}; use vortex_error::{vortex_bail, vortex_err, VortexError, VortexResult}; use vortex_flatbuffers::ReadFlatBuffer; @@ -103,11 +102,7 @@ impl FallibleLendingIterator for StreamReader { .header_as_schema() .unwrap(); - // TODO(ngates): construct this from the SerdeContext. - let dtype_ctx = - DTypeSerdeContext::new(VORTEX_COMPOSITE_EXTENSIONS.iter().map(|e| e.id()).collect()); let dtype = DType::read_flatbuffer( - &dtype_ctx, &schema_msg .dtype() .ok_or_else(|| vortex_err!(InvalidSerde: "Schema missing DType"))?, diff --git a/vortex-scalar/flatbuffers/scalar.fbs b/vortex-scalar/flatbuffers/scalar.fbs index c333c04e7e..eff54ee082 100644 --- a/vortex-scalar/flatbuffers/scalar.fbs +++ b/vortex-scalar/flatbuffers/scalar.fbs @@ -25,14 +25,16 @@ table Primitive { table Struct_ { names: [string]; - value: [Scalar]; + scalars: [Scalar]; } table UTF8 { value: string; } -table Composite { +table Extension { + id: string; + metadata: [ubyte]; value: Scalar; } @@ -44,7 +46,7 @@ union Type { Primitive, Struct_, UTF8, - Composite, + Extension, } // TODO(ngates): separate out ScalarValue from Scalar, even in-memory, so we can avoid duplicating dtype information (e.g. Struct field names). diff --git a/vortex-scalar/src/extension.rs b/vortex-scalar/src/extension.rs new file mode 100644 index 0000000000..be0e2c973f --- /dev/null +++ b/vortex-scalar/src/extension.rs @@ -0,0 +1,91 @@ +use std::fmt::{Display, Formatter}; +use std::sync::Arc; + +use vortex_dtype::{DType, ExtDType, ExtID, ExtMetadata, Nullability}; +use vortex_error::{vortex_bail, VortexResult}; + +use crate::Scalar; + +#[derive(Debug, Clone, PartialEq)] +pub struct ExtScalar { + dtype: DType, + value: Option>, +} + +impl ExtScalar { + pub fn try_new( + ext: ExtDType, + nullability: Nullability, + value: Option>, + ) -> VortexResult { + if value.is_none() && nullability == Nullability::NonNullable { + vortex_bail!("Value cannot be None for NonNullable Scalar"); + } + Ok(Self { + dtype: DType::Extension(ext, nullability), + value, + }) + } + + pub fn null(ext: ExtDType) -> Self { + Self::try_new(ext, Nullability::Nullable, None).expect("Incorrect nullability check") + } + + #[inline] + pub fn id(&self) -> &ExtID { + self.ext_dtype().id() + } + + #[inline] + pub fn metadata(&self) -> Option<&ExtMetadata> { + self.ext_dtype().metadata() + } + + #[inline] + pub fn ext_dtype(&self) -> &ExtDType { + let DType::Extension(ext, _) = &self.dtype else { + unreachable!() + }; + ext + } + + #[inline] + pub fn dtype(&self) -> &DType { + &self.dtype + } + + pub fn value(&self) -> Option<&Arc> { + self.value.as_ref() + } + + pub fn cast(&self, _dtype: &DType) -> VortexResult { + todo!() + } + + pub fn nbytes(&self) -> usize { + todo!() + } +} + +impl PartialOrd for ExtScalar { + fn partial_cmp(&self, other: &Self) -> Option { + if let (Some(s), Some(o)) = (self.value(), other.value()) { + s.partial_cmp(o) + } else { + None + } + } +} + +impl Display for ExtScalar { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!( + f, + "{} ({})", + self.value() + .map(|s| format!("{}", s)) + .unwrap_or_else(|| "".to_string()), + self.dtype + ) + } +} diff --git a/vortex-scalar/src/lib.rs b/vortex-scalar/src/lib.rs index 30d4be5d8f..91a4634a9f 100644 --- a/vortex-scalar/src/lib.rs +++ b/vortex-scalar/src/lib.rs @@ -12,9 +12,12 @@ use vortex_dtype::NativePType; use vortex_dtype::{DType, Nullability}; use vortex_error::VortexResult; +use crate::extension::ExtScalar; + mod binary; mod bool; mod composite; +mod extension; mod list; mod null; mod primitive; @@ -51,6 +54,7 @@ pub enum Scalar { Primitive(PrimitiveScalar), Struct(StructScalar), Utf8(Utf8Scalar), + Extension(ExtScalar), Composite(CompositeScalar), } @@ -71,6 +75,7 @@ impls_for_scalars!(Null, NullScalar); impls_for_scalars!(Primitive, PrimitiveScalar); impls_for_scalars!(Struct, StructScalar); impls_for_scalars!(Utf8, Utf8Scalar); +impls_for_scalars!(Extension, ExtScalar); impls_for_scalars!(Composite, CompositeScalar); macro_rules! match_each_scalar { @@ -84,6 +89,7 @@ macro_rules! match_each_scalar { Scalar::Primitive(s) => __with_scalar__! { s }, Scalar::Struct(s) => __with_scalar__! { s }, Scalar::Utf8(s) => __with_scalar__! { s }, + Scalar::Extension(s) => __with_scalar__! { s }, Scalar::Composite(s) => __with_scalar__! { s }, } }) @@ -116,6 +122,7 @@ impl Scalar { // FIXME(ngates): can't have a null struct? Scalar::Struct(_) => false, Scalar::Utf8(u) => u.value().is_none(), + Scalar::Extension(e) => e.value().is_none(), Scalar::Composite(c) => c.scalar().is_null(), } } @@ -131,7 +138,8 @@ impl Scalar { DType::Binary(_) => BinaryScalar::none().into(), DType::Struct(..) => StructScalar::new(dtype.clone(), vec![]).into(), DType::List(..) => ListScalar::new(dtype.clone(), None).into(), - DType::Composite(..) => unimplemented!("CompositeScalar"), + DType::Extension(ext, _) => ExtScalar::null(ext.clone()).into(), + DType::Composite(..) => unimplemented!(), } } } @@ -164,6 +172,6 @@ mod test { #[test] fn size_of() { - assert_eq!(mem::size_of::(), 80); + assert_eq!(mem::size_of::(), 88); } } diff --git a/vortex-scalar/src/serde.rs b/vortex-scalar/src/serde.rs index 1d77288a83..1327b19695 100644 --- a/vortex-scalar/src/serde.rs +++ b/vortex-scalar/src/serde.rs @@ -4,7 +4,7 @@ use flatbuffers::{root, FlatBufferBuilder, WIPOffset}; use serde::de::Visitor; use serde::{Deserialize, Deserializer, Serialize, Serializer}; use vortex_dtype::match_each_native_ptype; -use vortex_dtype::{DTypeSerdeContext, Nullability}; +use vortex_dtype::Nullability; use vortex_error::{vortex_bail, VortexError}; use vortex_flatbuffers::{FlatBufferRoot, FlatBufferToBytes, ReadFlatBuffer, WriteFlatBuffer}; @@ -83,6 +83,26 @@ impl WriteFlatBuffer for Scalar { nullability: self.nullability().into(), } } + Scalar::Extension(ext) => { + let id = Some(fbb.create_string(ext.id().as_ref())); + let metadata = ext.metadata().map(|m| fbb.create_vector(m.as_ref())); + let value = ext.value().map(|s| s.write_flatbuffer(fbb)); + fb::ScalarArgs { + type_type: fb::Type::Extension, + type_: Some( + fb::Extension::create( + fbb, + &fb::ExtensionArgs { + id, + metadata, + value, + }, + ) + .as_union_value(), + ), + nullability: self.nullability().into(), + } + } Scalar::Composite(_) => panic!(), }; @@ -90,14 +110,11 @@ impl WriteFlatBuffer for Scalar { } } -impl ReadFlatBuffer for Scalar { +impl ReadFlatBuffer for Scalar { type Source<'a> = fb::Scalar<'a>; type Error = VortexError; - fn read_flatbuffer( - _ctx: &DTypeSerdeContext, - fb: &Self::Source<'_>, - ) -> Result { + fn read_flatbuffer(fb: &Self::Source<'_>) -> Result { let nullability = Nullability::from(fb.nullability()); match fb.type_type() { fb::Type::Binary => { @@ -136,9 +153,6 @@ impl ReadFlatBuffer for Scalar { .map(|s| s.to_string()), nullability, )?)), - fb::Type::Composite => { - todo!() - } _ => vortex_bail!(InvalidSerde: "Unrecognized scalar type"), } } @@ -153,7 +167,7 @@ impl Serialize for Scalar { } } -struct ScalarDeserializer(DTypeSerdeContext); +struct ScalarDeserializer; impl<'de> Visitor<'de> for ScalarDeserializer { type Value = Scalar; @@ -167,7 +181,7 @@ impl<'de> Visitor<'de> for ScalarDeserializer { E: serde::de::Error, { let fb = root::(v).map_err(E::custom)?; - Scalar::read_flatbuffer(&self.0, &fb).map_err(E::custom) + Scalar::read_flatbuffer(&fb).map_err(E::custom) } } @@ -177,8 +191,7 @@ impl<'de> Deserialize<'de> for Scalar { where D: Deserializer<'de>, { - let ctx = DTypeSerdeContext::new(vec![]); - deserializer.deserialize_bytes(ScalarDeserializer(ctx)) + deserializer.deserialize_bytes(ScalarDeserializer) } }