Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ExtDType #281

Merged
merged 2 commits into from
May 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 3 additions & 6 deletions vortex-array/src/array/composite/array.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use flatbuffers::root;
use vortex_dtype::flatbuffers as fb;
use vortex_dtype::{CompositeID, DTypeSerdeContext};
use vortex_dtype::CompositeID;
use vortex_error::{vortex_err, VortexResult};
use vortex_flatbuffers::{FlatBufferToBytes, ReadFlatBuffer};

Expand Down Expand Up @@ -53,11 +53,8 @@ impl TryDeserializeArrayMetadata<'_> for CompositeMetadata {
.ok_or_else(|| vortex_err!("Unrecognized composite extension: {}", ext_id))?;

let dtype_blob = elems.index(1).expect("missing dtype").as_blob();
let ctx = DTypeSerdeContext::new(vec![]); // FIXME: composite_ids
let underlying_dtype = DType::read_flatbuffer(
&ctx,
&root::<fb::DType>(dtype_blob.0).expect("invalid dtype"),
)?;
let underlying_dtype =
DType::read_flatbuffer(&root::<fb::DType>(dtype_blob.0).expect("invalid dtype"))?;

let underlying_metadata: Arc<[u8]> = elems
.index(2)
Expand Down
2 changes: 1 addition & 1 deletion vortex-dtype/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ half = { workspace = true }
itertools = { workspace = true }
linkme = { workspace = true }
num-traits = { workspace = true }
serde = { workspace = true, optional = true }
serde = { workspace = true, optional = true, features = ["rc"] }
thiserror = { workspace = true }
vortex-error = { path = "../vortex-error" }
vortex-flatbuffers = { path = "../vortex-flatbuffers" }
Expand Down
11 changes: 6 additions & 5 deletions vortex-dtype/flatbuffers/dtype.fbs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
namespace vortex.dtype;

enum Nullability: byte {
enum Nullability: uint8 {
NonNullable,
Nullable,
}
Expand Down Expand Up @@ -32,10 +32,10 @@ table Primitive {

table Decimal {
/// Total number of decimal digits
precision: ubyte;
precision: uint8;

/// Number of digits after the decimal point "."
scale: byte;
scale: int8;
nullability: Nullability;
}

Expand All @@ -57,8 +57,9 @@ table List {
nullability: Nullability;
}

table Composite {
table Extension {
id: string;
metadata: [ubyte];
nullability: Nullability;
}

Expand All @@ -71,7 +72,7 @@ union Type {
Binary,
Struct_,
List,
Composite,
Extension,
}

table DType {
Expand Down
46 changes: 14 additions & 32 deletions vortex-dtype/src/deserialize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,34 +3,14 @@ use std::sync::Arc;
use vortex_error::{vortex_err, VortexError, VortexResult};
use vortex_flatbuffers::ReadFlatBuffer;

use crate::{flatbuffers as fb, Nullability};
use crate::{CompositeID, DType};
use crate::DType;
use crate::{flatbuffers as fb, ExtDType, ExtID, ExtMetadata, Nullability};

#[allow(dead_code)]
pub struct DTypeSerdeContext {
composite_ids: Arc<[CompositeID]>,
}

impl DTypeSerdeContext {
pub fn new(composite_ids: Vec<CompositeID>) -> Self {
Self {
composite_ids: composite_ids.into(),
}
}

pub fn find_composite_id(&self, id: &str) -> Option<CompositeID> {
self.composite_ids.iter().find(|c| c.0 == id).copied()
}
}

impl ReadFlatBuffer<DTypeSerdeContext> for DType {
impl ReadFlatBuffer for DType {
type Source<'a> = fb::DType<'a>;
type Error = VortexError;

fn read_flatbuffer(
ctx: &DTypeSerdeContext,
fb: &Self::Source<'_>,
) -> Result<Self, Self::Error> {
fn read_flatbuffer(fb: &Self::Source<'_>) -> Result<Self, Self::Error> {
match fb.type_type() {
fb::Type::Null => Ok(DType::Null),
fb::Type::Bool => Ok(DType::Bool(
Expand Down Expand Up @@ -59,7 +39,7 @@ impl ReadFlatBuffer<DTypeSerdeContext> for DType {
)),
fb::Type::List => {
let fb_list = fb.type__as_list().unwrap();
let element_dtype = DType::read_flatbuffer(ctx, &fb_list.element_type().unwrap())?;
let element_dtype = DType::read_flatbuffer(&fb_list.element_type().unwrap())?;
Ok(DType::List(
Box::new(element_dtype),
fb_list.nullability().try_into()?,
Expand All @@ -77,16 +57,18 @@ impl ReadFlatBuffer<DTypeSerdeContext> for DType {
.fields()
.unwrap()
.iter()
.map(|f| DType::read_flatbuffer(ctx, &f))
.map(|f| DType::read_flatbuffer(&f))
.collect::<VortexResult<Vec<_>>>()?;
Ok(DType::Struct(names, fields))
}
fb::Type::Composite => {
let fb_composite = fb.type__as_composite().unwrap();
let id = ctx
.find_composite_id(fb_composite.id().unwrap())
.ok_or_else(|| vortex_err!("Couldn't find composite id"))?;
Ok(DType::Composite(id, fb_composite.nullability().try_into()?))
fb::Type::Extension => {
let fb_ext = fb.type__as_extension().unwrap();
let id = ExtID::from(fb_ext.id().unwrap());
let metadata = fb_ext.metadata().map(|m| ExtMetadata::from(m.bytes()));
Ok(DType::Extension(
ExtDType::new(id, metadata),
fb_ext.nullability().try_into()?,
))
}
_ => Err(vortex_err!("Unknown DType variant")),
}
Expand Down
18 changes: 16 additions & 2 deletions vortex-dtype/src/dtype.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use std::sync::Arc;
use itertools::Itertools;
use DType::*;

use crate::{CompositeID, PType};
use crate::{CompositeID, ExtDType, PType};

#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash, Ord, PartialOrd)]
#[cfg_attr(feature = "serde", derive(::serde::Serialize, ::serde::Deserialize))]
Expand Down Expand Up @@ -48,6 +48,7 @@ pub type FieldNames = Vec<Arc<String>>;
pub type Metadata = Vec<u8>;

#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum DType {
Null,
Bool(Nullability),
Expand All @@ -57,6 +58,8 @@ pub enum DType {
Binary(Nullability),
Struct(FieldNames, Vec<DType>),
List(Box<DType>, Nullability),
Extension(ExtDType, Nullability),
#[serde(skip)]
Composite(CompositeID, Nullability),
}

Expand All @@ -82,6 +85,7 @@ impl DType {
Binary(n) => matches!(n, Nullable),
Struct(_, fs) => fs.iter().all(|f| f.is_nullable()),
List(_, n) => matches!(n, Nullable),
Extension(_, n) => matches!(n, Nullable),
Composite(_, n) => matches!(n, Nullable),
}
}
Expand All @@ -107,6 +111,7 @@ impl DType {
fs.iter().map(|f| f.with_nullability(nullability)).collect(),
),
List(c, _) => List(c.clone(), nullability),
Extension(ext, _) => Extension(ext.clone(), nullability),
Composite(id, _) => Composite(*id, nullability),
}
}
Expand Down Expand Up @@ -134,6 +139,15 @@ impl Display for DType {
.join(", ")
),
List(c, n) => write!(f, "list({}){}", c, n),
Extension(ext, n) => write!(
f,
"ext({}{}){}",
ext.id(),
ext.metadata()
.map(|m| format!(", {:?}", m))
.unwrap_or_else(|| "".to_string()),
n
),
Composite(id, n) => write!(f, "<{}>{}", id, n),
}
}
Expand All @@ -147,6 +161,6 @@ mod test {

#[test]
fn size_of() {
assert_eq!(mem::size_of::<DType>(), 48);
assert_eq!(mem::size_of::<DType>(), 56);
}
}
69 changes: 69 additions & 0 deletions vortex-dtype/src/extension.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
use std::fmt::{Display, Formatter};
use std::sync::Arc;

#[derive(Debug, Clone, PartialEq, Eq, Ord, PartialOrd, Hash)]
#[cfg_attr(feature = "serde", derive(::serde::Serialize, ::serde::Deserialize))]
pub struct ExtID(Arc<str>);

impl Display for ExtID {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.0)
}
}

impl AsRef<str> for ExtID {
fn as_ref(&self) -> &str {
self.0.as_ref()
}
}

impl From<&str> for ExtID {
fn from(value: &str) -> Self {
ExtID(value.into())
}
}

#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ExtMetadata(Arc<[u8]>);

impl AsRef<[u8]> for ExtMetadata {
fn as_ref(&self) -> &[u8] {
self.0.as_ref()
}
}

impl From<Arc<[u8]>> for ExtMetadata {
fn from(value: Arc<[u8]>) -> Self {
ExtMetadata(value)
}
}

impl From<&[u8]> for ExtMetadata {
fn from(value: &[u8]) -> Self {
ExtMetadata(value.into())
}
}

#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ExtDType {
id: ExtID,
metadata: Option<ExtMetadata>,
}

impl ExtDType {
pub fn new(id: ExtID, metadata: Option<ExtMetadata>) -> Self {
Self { id, metadata }
}

#[inline]
pub fn id(&self) -> &ExtID {
&self.id
}

#[inline]
pub fn metadata(&self) -> Option<&ExtMetadata> {
self.metadata.as_ref()
}
}
4 changes: 2 additions & 2 deletions vortex-dtype/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
use std::fmt::{Display, Formatter};

pub use dtype::*;
pub use extension::*;
pub use half;
pub use ptype::*;
mod deserialize;
mod dtype;
mod extension;
mod ptype;
mod serde;
mod serialize;

pub use deserialize::*;

#[derive(Debug, Clone, Copy, PartialEq, Eq, Ord, PartialOrd, Hash)]
#[cfg_attr(feature = "serde", derive(::serde::Serialize, ::serde::Deserialize))]
pub struct CompositeID(pub &'static str);
Expand Down
2 changes: 1 addition & 1 deletion vortex-dtype/src/ptype.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ use crate::DType;
use crate::DType::*;
use crate::Nullability::NonNullable;

#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, PartialEq, PartialOrd, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum PType {
U8,
U16,
Expand Down
58 changes: 0 additions & 58 deletions vortex-dtype/src/serde.rs
Original file line number Diff line number Diff line change
@@ -1,59 +1 @@
#![cfg(feature = "serde")]

use flatbuffers::root;
use serde::de::{DeserializeSeed, Visitor};
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use vortex_flatbuffers::{FlatBufferToBytes, ReadFlatBuffer};

use crate::DType;
use crate::{flatbuffers as fb, DTypeSerdeContext};

impl Serialize for DType {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
self.with_flatbuffer_bytes(|bytes| serializer.serialize_bytes(bytes))
}
}

struct DTypeDeserializer(DTypeSerdeContext);

impl<'de> Visitor<'de> for DTypeDeserializer {
type Value = DType;

fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
formatter.write_str("a vortex dtype")
}

fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
let fb = root::<fb::DType>(v).map_err(E::custom)?;
DType::read_flatbuffer(&self.0, &fb).map_err(E::custom)
}
}

impl<'de> DeserializeSeed<'de> for DTypeSerdeContext {
type Value = DType;

fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
where
D: Deserializer<'de>,
{
deserializer.deserialize_bytes(DTypeDeserializer(self))
}
}

// TODO(ngates): Remove this trait in favour of storing e.g. IdxType which doesn't require
// the context for composite types.
impl<'de> Deserialize<'de> for DType {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
let ctx = DTypeSerdeContext::new(vec![]);
deserializer.deserialize_bytes(DTypeDeserializer(ctx))
}
}
Loading