Skip to content

Commit

Permalink
Add ExtDType (#281)
Browse files Browse the repository at this point in the history
Start of the work on #174
  • Loading branch information
gatesn authored May 1, 2024
1 parent eabb8e6 commit 2c7d81d
Show file tree
Hide file tree
Showing 16 changed files with 258 additions and 144 deletions.
9 changes: 3 additions & 6 deletions vortex-array/src/array/composite/array.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use flatbuffers::root;
use vortex_dtype::flatbuffers as fb;
use vortex_dtype::{CompositeID, DTypeSerdeContext};
use vortex_dtype::CompositeID;
use vortex_error::{vortex_err, VortexResult};
use vortex_flatbuffers::{FlatBufferToBytes, ReadFlatBuffer};

Expand Down Expand Up @@ -53,11 +53,8 @@ impl TryDeserializeArrayMetadata<'_> for CompositeMetadata {
.ok_or_else(|| vortex_err!("Unrecognized composite extension: {}", ext_id))?;

let dtype_blob = elems.index(1).expect("missing dtype").as_blob();
let ctx = DTypeSerdeContext::new(vec![]); // FIXME: composite_ids
let underlying_dtype = DType::read_flatbuffer(
&ctx,
&root::<fb::DType>(dtype_blob.0).expect("invalid dtype"),
)?;
let underlying_dtype =
DType::read_flatbuffer(&root::<fb::DType>(dtype_blob.0).expect("invalid dtype"))?;

let underlying_metadata: Arc<[u8]> = elems
.index(2)
Expand Down
2 changes: 1 addition & 1 deletion vortex-dtype/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ half = { workspace = true }
itertools = { workspace = true }
linkme = { workspace = true }
num-traits = { workspace = true }
serde = { workspace = true, optional = true }
serde = { workspace = true, optional = true, features = ["rc"] }
thiserror = { workspace = true }
vortex-error = { path = "../vortex-error" }
vortex-flatbuffers = { path = "../vortex-flatbuffers" }
Expand Down
11 changes: 6 additions & 5 deletions vortex-dtype/flatbuffers/dtype.fbs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
namespace vortex.dtype;

enum Nullability: byte {
enum Nullability: uint8 {
NonNullable,
Nullable,
}
Expand Down Expand Up @@ -32,10 +32,10 @@ table Primitive {

table Decimal {
/// Total number of decimal digits
precision: ubyte;
precision: uint8;

/// Number of digits after the decimal point "."
scale: byte;
scale: int8;
nullability: Nullability;
}

Expand All @@ -57,8 +57,9 @@ table List {
nullability: Nullability;
}

table Composite {
table Extension {
id: string;
metadata: [ubyte];
nullability: Nullability;
}

Expand All @@ -71,7 +72,7 @@ union Type {
Binary,
Struct_,
List,
Composite,
Extension,
}

table DType {
Expand Down
46 changes: 14 additions & 32 deletions vortex-dtype/src/deserialize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,34 +3,14 @@ use std::sync::Arc;
use vortex_error::{vortex_err, VortexError, VortexResult};
use vortex_flatbuffers::ReadFlatBuffer;

use crate::{flatbuffers as fb, Nullability};
use crate::{CompositeID, DType};
use crate::DType;
use crate::{flatbuffers as fb, ExtDType, ExtID, ExtMetadata, Nullability};

/// Context carried through `DType` (de)serialization: the set of composite
/// IDs that may be looked up by name while reading a serialized dtype.
#[allow(dead_code)]
pub struct DTypeSerdeContext {
    composite_ids: Arc<[CompositeID]>,
}

impl DTypeSerdeContext {
    /// Builds a context from the given list of known composite IDs.
    pub fn new(composite_ids: Vec<CompositeID>) -> Self {
        let composite_ids = Arc::from(composite_ids);
        Self { composite_ids }
    }

    /// Returns the first known composite ID whose name equals `id`, or `None`
    /// when no such ID is present in this context.
    pub fn find_composite_id(&self, id: &str) -> Option<CompositeID> {
        self.composite_ids
            .iter()
            .copied()
            .find(|candidate| candidate.0 == id)
    }
}

impl ReadFlatBuffer<DTypeSerdeContext> for DType {
impl ReadFlatBuffer for DType {
type Source<'a> = fb::DType<'a>;
type Error = VortexError;

fn read_flatbuffer(
ctx: &DTypeSerdeContext,
fb: &Self::Source<'_>,
) -> Result<Self, Self::Error> {
fn read_flatbuffer(fb: &Self::Source<'_>) -> Result<Self, Self::Error> {
match fb.type_type() {
fb::Type::Null => Ok(DType::Null),
fb::Type::Bool => Ok(DType::Bool(
Expand Down Expand Up @@ -59,7 +39,7 @@ impl ReadFlatBuffer<DTypeSerdeContext> for DType {
)),
fb::Type::List => {
let fb_list = fb.type__as_list().unwrap();
let element_dtype = DType::read_flatbuffer(ctx, &fb_list.element_type().unwrap())?;
let element_dtype = DType::read_flatbuffer(&fb_list.element_type().unwrap())?;
Ok(DType::List(
Box::new(element_dtype),
fb_list.nullability().try_into()?,
Expand All @@ -77,16 +57,18 @@ impl ReadFlatBuffer<DTypeSerdeContext> for DType {
.fields()
.unwrap()
.iter()
.map(|f| DType::read_flatbuffer(ctx, &f))
.map(|f| DType::read_flatbuffer(&f))
.collect::<VortexResult<Vec<_>>>()?;
Ok(DType::Struct(names, fields))
}
fb::Type::Composite => {
let fb_composite = fb.type__as_composite().unwrap();
let id = ctx
.find_composite_id(fb_composite.id().unwrap())
.ok_or_else(|| vortex_err!("Couldn't find composite id"))?;
Ok(DType::Composite(id, fb_composite.nullability().try_into()?))
fb::Type::Extension => {
let fb_ext = fb.type__as_extension().unwrap();
let id = ExtID::from(fb_ext.id().unwrap());
let metadata = fb_ext.metadata().map(|m| ExtMetadata::from(m.bytes()));
Ok(DType::Extension(
ExtDType::new(id, metadata),
fb_ext.nullability().try_into()?,
))
}
_ => Err(vortex_err!("Unknown DType variant")),
}
Expand Down
18 changes: 16 additions & 2 deletions vortex-dtype/src/dtype.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use std::sync::Arc;
use itertools::Itertools;
use DType::*;

use crate::{CompositeID, PType};
use crate::{CompositeID, ExtDType, PType};

#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash, Ord, PartialOrd)]
#[cfg_attr(feature = "serde", derive(::serde::Serialize, ::serde::Deserialize))]
Expand Down Expand Up @@ -48,6 +48,7 @@ pub type FieldNames = Vec<Arc<String>>;
pub type Metadata = Vec<u8>;

#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum DType {
Null,
Bool(Nullability),
Expand All @@ -57,6 +58,8 @@ pub enum DType {
Binary(Nullability),
Struct(FieldNames, Vec<DType>),
List(Box<DType>, Nullability),
Extension(ExtDType, Nullability),
#[serde(skip)]
Composite(CompositeID, Nullability),
}

Expand All @@ -82,6 +85,7 @@ impl DType {
Binary(n) => matches!(n, Nullable),
Struct(_, fs) => fs.iter().all(|f| f.is_nullable()),
List(_, n) => matches!(n, Nullable),
Extension(_, n) => matches!(n, Nullable),
Composite(_, n) => matches!(n, Nullable),
}
}
Expand All @@ -107,6 +111,7 @@ impl DType {
fs.iter().map(|f| f.with_nullability(nullability)).collect(),
),
List(c, _) => List(c.clone(), nullability),
Extension(ext, _) => Extension(ext.clone(), nullability),
Composite(id, _) => Composite(*id, nullability),
}
}
Expand Down Expand Up @@ -134,6 +139,15 @@ impl Display for DType {
.join(", ")
),
List(c, n) => write!(f, "list({}){}", c, n),
Extension(ext, n) => write!(
f,
"ext({}{}){}",
ext.id(),
ext.metadata()
.map(|m| format!(", {:?}", m))
.unwrap_or_else(|| "".to_string()),
n
),
Composite(id, n) => write!(f, "<{}>{}", id, n),
}
}
Expand All @@ -147,6 +161,6 @@ mod test {

#[test]
fn size_of() {
assert_eq!(mem::size_of::<DType>(), 48);
assert_eq!(mem::size_of::<DType>(), 56);
}
}
69 changes: 69 additions & 0 deletions vortex-dtype/src/extension.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
use std::fmt::{Display, Formatter};
use std::sync::Arc;

/// Identifier naming an extension type.
///
/// The text is held in an `Arc<str>`, so cloning an `ExtID` shares the
/// underlying string instead of copying it.
#[derive(Debug, Clone, PartialEq, Eq, Ord, PartialOrd, Hash)]
#[cfg_attr(feature = "serde", derive(::serde::Serialize, ::serde::Deserialize))]
pub struct ExtID(Arc<str>);

impl Display for ExtID {
    /// Writes the identifier text.
    ///
    /// Formatting flags supplied by the caller (width, fill, alignment,
    /// precision) are honored exactly as they would be for a plain `&str`.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        // Delegate to `str`'s own `Display` instead of `write!(f, "{}", ..)`:
        // the macro starts a fresh format spec and would silently drop any
        // padding/alignment the caller asked for (e.g. `{:>12}`).
        <str as Display>::fmt(&self.0, f)
    }
}

impl AsRef<str> for ExtID {
    /// Borrows the identifier as a string slice.
    fn as_ref(&self) -> &str {
        &self.0
    }
}

impl From<&str> for ExtID {
    /// Creates an identifier by copying `value` into shared storage.
    fn from(value: &str) -> Self {
        Self(Arc::from(value))
    }
}

/// Opaque metadata bytes attached to an extension type.
///
/// The bytes live in an `Arc<[u8]>`, so clones share one buffer.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ExtMetadata(Arc<[u8]>);

impl AsRef<[u8]> for ExtMetadata {
    /// Borrows the metadata as a byte slice.
    fn as_ref(&self) -> &[u8] {
        &self.0
    }
}

impl From<Arc<[u8]>> for ExtMetadata {
    /// Wraps an already-shared buffer without copying it.
    fn from(bytes: Arc<[u8]>) -> Self {
        Self(bytes)
    }
}

impl From<&[u8]> for ExtMetadata {
    /// Copies the given bytes into a new shared buffer.
    fn from(bytes: &[u8]) -> Self {
        Self(Arc::from(bytes))
    }
}

/// An extension data type: an identifier plus optional metadata bytes.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ExtDType {
    /// Identifier of the extension type.
    id: ExtID,
    /// Metadata for the extension type, if any was supplied.
    metadata: Option<ExtMetadata>,
}

impl ExtDType {
    /// Creates an extension dtype from its identifier and optional metadata.
    pub fn new(id: ExtID, metadata: Option<ExtMetadata>) -> Self {
        ExtDType { id, metadata }
    }

    /// Returns the extension identifier.
    #[inline]
    pub fn id(&self) -> &ExtID {
        let ExtDType { id, .. } = self;
        id
    }

    /// Returns the metadata, or `None` when this dtype carries none.
    #[inline]
    pub fn metadata(&self) -> Option<&ExtMetadata> {
        Option::as_ref(&self.metadata)
    }
}
4 changes: 2 additions & 2 deletions vortex-dtype/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
use std::fmt::{Display, Formatter};

pub use dtype::*;
pub use extension::*;
pub use half;
pub use ptype::*;
mod deserialize;
mod dtype;
mod extension;
mod ptype;
mod serde;
mod serialize;

pub use deserialize::*;

#[derive(Debug, Clone, Copy, PartialEq, Eq, Ord, PartialOrd, Hash)]
#[cfg_attr(feature = "serde", derive(::serde::Serialize, ::serde::Deserialize))]
pub struct CompositeID(pub &'static str);
Expand Down
2 changes: 1 addition & 1 deletion vortex-dtype/src/ptype.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ use crate::DType;
use crate::DType::*;
use crate::Nullability::NonNullable;

#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, PartialEq, PartialOrd, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum PType {
U8,
U16,
Expand Down
58 changes: 0 additions & 58 deletions vortex-dtype/src/serde.rs
Original file line number Diff line number Diff line change
@@ -1,59 +1 @@
#![cfg(feature = "serde")]

use flatbuffers::root;
use serde::de::{DeserializeSeed, Visitor};
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use vortex_flatbuffers::{FlatBufferToBytes, ReadFlatBuffer};

use crate::DType;
use crate::{flatbuffers as fb, DTypeSerdeContext};

/// Serializes a `DType` as a single bytes value: the buffer handed out by
/// `with_flatbuffer_bytes` is forwarded to `Serializer::serialize_bytes`.
impl Serialize for DType {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let emit = |encoded: &[u8]| serializer.serialize_bytes(encoded);
        self.with_flatbuffer_bytes(emit)
    }
}

/// Serde visitor that decodes a `DType` from flatbuffer bytes, using the
/// wrapped `DTypeSerdeContext` while reading.
struct DTypeDeserializer(DTypeSerdeContext);

impl<'de> Visitor<'de> for DTypeDeserializer {
    type Value = DType;

    /// Describes the expected input for serde error messages.
    fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
        formatter.write_str("a vortex dtype")
    }

    /// Parses `v` as a flatbuffer `DType` root and converts it to a `DType`.
    /// Both a malformed buffer and a failed conversion surface as `E::custom`.
    fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
    where
        E: serde::de::Error,
    {
        let Self(ctx) = self;
        root::<fb::DType>(v)
            .map_err(E::custom)
            .and_then(|fb_dtype| DType::read_flatbuffer(&ctx, &fb_dtype).map_err(E::custom))
    }
}

/// Lets a `DTypeSerdeContext` act as a deserialization seed: the context is
/// moved into a `DTypeDeserializer` visitor, which then reads a bytes value.
impl<'de> DeserializeSeed<'de> for DTypeSerdeContext {
    type Value = DType;

    fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
    where
        D: Deserializer<'de>,
    {
        let visitor = DTypeDeserializer(self);
        deserializer.deserialize_bytes(visitor)
    }
}

// TODO(ngates): Remove this trait in favour of storing e.g. IdxType which doesn't require
// the context for composite types.
/// Context-free deserialization of a `DType`: reads a bytes value using a
/// `DTypeSerdeContext` built from an empty list of composite IDs.
impl<'de> Deserialize<'de> for DType {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        let empty_ctx = DTypeSerdeContext::new(Vec::new());
        let visitor = DTypeDeserializer(empty_ctx);
        deserializer.deserialize_bytes(visitor)
    }
}
Loading

0 comments on commit 2c7d81d

Please sign in to comment.