Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make DType::Struct a struct #278

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions bench-vortex/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -184,13 +184,13 @@ pub struct CompressionRunStats {

impl CompressionRunStats {
pub fn to_results(&self, dataset_name: String) -> Vec<CompressionRunResults> {
let DType::Struct(ns, fs) = &self.schema else {
let DType::Struct { names, dtypes } = &self.schema else {
unreachable!()
};

self.compressed_sizes
.iter()
.zip_eq(ns.iter().zip_eq(fs))
.zip_eq(names.iter().zip_eq(dtypes))
.map(
|(&size, (column_name, column_type))| CompressionRunResults {
dataset_name: dataset_name.clone(),
Expand Down
4 changes: 2 additions & 2 deletions bench-vortex/src/vortex_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@ pub fn vortex_chunk_sizes(path: &Path) -> VortexResult<CompressionRunStats> {
let file = File::open(path)?;
let total_compressed_size = file.metadata()?.size();
let vortex = open_vortex(path)?;
let DType::Struct(ns, _) = vortex.dtype() else {
let DType::Struct { names, .. } = vortex.dtype() else {
unreachable!()
};

let mut compressed_sizes = vec![0; ns.len()];
let mut compressed_sizes = vec![0; names.len()];
let chunked_array = ChunkedArray::try_from(vortex).unwrap();
for chunk in chunked_array.chunks() {
let struct_arr = StructArray::try_from(chunk).unwrap();
Expand Down
13 changes: 6 additions & 7 deletions vortex-array/src/array/composite/array.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use flatbuffers::root;
use vortex_dtype::flatbuffers as fb;
use vortex_dtype::{CompositeID, DTypeSerdeContext};
use vortex_dtype::CompositeID;
use vortex_error::{vortex_err, VortexResult};
use vortex_flatbuffers::{FlatBufferToBytes, ReadFlatBuffer};

Expand Down Expand Up @@ -34,7 +34,7 @@ impl TrySerializeArrayMetadata for CompositeMetadata {
let mut fb = flexbuffers::Builder::default();
{
let mut elems = fb.start_vector();
elems.push(self.ext.id().0);
elems.push(self.ext.id().as_ref());
self.underlying_dtype
.with_flatbuffer_bytes(|b| elems.push(flexbuffers::Blob(b)));
elems.push(flexbuffers::Blob(self.underlying_metadata.as_ref()));
Expand All @@ -53,9 +53,8 @@ impl TryDeserializeArrayMetadata<'_> for CompositeMetadata {
.ok_or_else(|| vortex_err!("Unrecognized composite extension: {}", ext_id))?;

let dtype_blob = elems.index(1).expect("missing dtype").as_blob();
let ctx = DTypeSerdeContext::new(vec![]); // FIXME: composite_ids
let underlying_dtype = DType::read_flatbuffer(
&ctx,
&(),
&root::<fb::DType>(dtype_blob.0).expect("invalid dtype"),
)?;

Expand All @@ -77,8 +76,8 @@ impl TryDeserializeArrayMetadata<'_> for CompositeMetadata {

impl<'a> CompositeArray<'a> {
pub fn new(id: CompositeID, metadata: Arc<[u8]>, underlying: Array<'a>) -> Self {
let dtype = DType::Composite(id, underlying.dtype().is_nullable().into());
let ext = find_extension(id.0).expect("Unrecognized composite extension");
let dtype = DType::Composite(id.clone(), underlying.dtype().is_nullable().into());
let ext = find_extension(id.as_ref()).expect("Unrecognized composite extension");
Self::try_from_parts(
dtype,
CompositeMetadata {
Expand All @@ -101,7 +100,7 @@ impl CompositeArray<'_> {

#[inline]
pub fn extension(&self) -> CompositeExtensionRef {
find_extension(self.id().0).expect("Unrecognized composite extension")
find_extension(self.id().as_ref()).expect("Unrecognized composite extension")
}

pub fn underlying_metadata(&self) -> &Arc<[u8]> {
Expand Down
2 changes: 1 addition & 1 deletion vortex-array/src/array/composite/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ pub static VORTEX_COMPOSITE_EXTENSIONS: [&'static dyn CompositeExtension] = [..]
pub fn find_extension(id: &str) -> Option<&'static dyn CompositeExtension> {
VORTEX_COMPOSITE_EXTENSIONS
.iter()
.find(|ext| ext.id().0 == id)
.find(|ext| ext.id().as_ref() == id)
.copied()
}

Expand Down
2 changes: 1 addition & 1 deletion vortex-array/src/array/composite/typed.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ macro_rules! impl_composite {
pub struct [<$T Extension>];

impl [<$T Extension>] {
pub const ID: CompositeID = CompositeID($id);
pub const ID: CompositeID = CompositeID::new($id);

pub fn dtype(nullability: Nullability) -> DType {
DType::Composite(Self::ID, nullability)
Expand Down
14 changes: 7 additions & 7 deletions vortex-array/src/array/struct/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,25 +19,25 @@ pub struct StructMetadata {

impl StructArray<'_> {
pub fn child(&self, idx: usize) -> Option<Array> {
let DType::Struct(_, fields) = self.dtype() else {
let DType::Struct { dtypes, .. } = self.dtype() else {
unreachable!()
};
let dtype = fields.get(idx)?;
let dtype = dtypes.get(idx)?;
self.array().child(idx, dtype)
}

pub fn names(&self) -> &FieldNames {
let DType::Struct(names, _fields) = self.dtype() else {
let DType::Struct { names, .. } = self.dtype() else {
unreachable!()
};
names
}

pub fn fields(&self) -> &[DType] {
let DType::Struct(_names, fields) = self.dtype() else {
let DType::Struct { dtypes, .. } = self.dtype() else {
unreachable!()
};
fields.as_slice()
dtypes.as_slice()
}

pub fn nfields(&self) -> usize {
Expand All @@ -61,9 +61,9 @@ impl StructArray<'_> {
vortex_bail!("Expected all struct fields to have length {}", length);
}

let field_dtypes: Vec<_> = fields.iter().map(|d| d.dtype()).cloned().collect();
let dtypes: Vec<_> = fields.iter().map(|d| d.dtype()).cloned().collect();
Self::try_from_parts(
DType::Struct(names, field_dtypes),
DType::Struct { names, dtypes },
StructMetadata { length },
fields.into_iter().map(|a| a.into_array_data()).collect(),
HashMap::default(),
Expand Down
18 changes: 9 additions & 9 deletions vortex-array/src/arrow/dtype.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,18 +40,18 @@ impl TryFromArrowType<&DataType> for PType {

impl FromArrowType<SchemaRef> for DType {
fn from_arrow(value: SchemaRef) -> Self {
DType::Struct(
value
DType::Struct {
names: value
.fields()
.iter()
.map(|f| Arc::new(f.name().clone()))
.collect(),
value
dtypes: value
.fields()
.iter()
.map(|f| DType::from_arrow(f.as_ref()))
.collect_vec(),
)
}
}
}

Expand Down Expand Up @@ -81,13 +81,13 @@ impl FromArrowType<&Field> for DType {
DataType::List(e) | DataType::LargeList(e) => {
List(Box::new(DType::from_arrow(e.as_ref())), nullability)
}
DataType::Struct(f) => Struct(
f.iter().map(|f| Arc::new(f.name().clone())).collect(),
f.iter()
DataType::Struct(f) => Struct {
names: f.iter().map(|f| Arc::new(f.name().clone())).collect(),
dtypes: f
.iter()
.map(|f| DType::from_arrow(f.as_ref()))
.collect_vec(),
),
DataType::Decimal128(p, s) | DataType::Decimal256(p, s) => Decimal(*p, *s, nullability),
},
_ => unimplemented!("Arrow data type not yet supported: {:?}", field.data_type()),
}
}
Expand Down
2 changes: 1 addition & 1 deletion vortex-dtype/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ half = { workspace = true }
itertools = { workspace = true }
linkme = { workspace = true }
num-traits = { workspace = true }
serde = { workspace = true, optional = true }
serde = { workspace = true, optional = true, features = ["rc"] }
thiserror = { workspace = true }
vortex-error = { path = "../vortex-error" }
vortex-flatbuffers = { path = "../vortex-flatbuffers" }
Expand Down
50 changes: 50 additions & 0 deletions vortex-dtype/src/composite.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
use std::fmt::{Display, Formatter};

use linkme::distributed_slice;
use vortex_error::{vortex_err, VortexError};

#[derive(Debug, Clone, PartialEq, Eq, Ord, PartialOrd, Hash)]
#[cfg_attr(feature = "serde", derive(::serde::Serialize))]
pub struct CompositeID(&'static str);

impl CompositeID {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Given pr title I am surprised to see these changes here. Can we change the pr title or make separate pr?

pub const fn new(id: &'static str) -> Self {
Self(id)
}
}

impl<'a> TryFrom<&'a str> for CompositeID {
type Error = VortexError;

fn try_from(value: &'a str) -> Result<Self, Self::Error> {
find_composite_dtype(value)
.map(|cdt| CompositeID(cdt.id()))
.ok_or_else(|| vortex_err!("CompositeID not found for the given id: {}", value))
}
}

impl AsRef<str> for CompositeID {
fn as_ref(&self) -> &str {
self.0
}
}

impl Display for CompositeID {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.0)
}
}

pub trait CompositeDType {
fn id(&self) -> &'static str;
}

#[distributed_slice]
pub static VORTEX_COMPOSITE_DTYPES: [&'static dyn CompositeDType] = [..];

pub fn find_composite_dtype(id: &str) -> Option<&'static dyn CompositeDType> {
VORTEX_COMPOSITE_DTYPES
.iter()
.find(|ext| ext.id() == id)
.copied()
}
44 changes: 7 additions & 37 deletions vortex-dtype/src/deserialize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,31 +6,11 @@ use vortex_flatbuffers::ReadFlatBuffer;
use crate::{flatbuffers as fb, Nullability};
use crate::{CompositeID, DType};

#[allow(dead_code)]
pub struct DTypeSerdeContext {
composite_ids: Arc<[CompositeID]>,
}

impl DTypeSerdeContext {
pub fn new(composite_ids: Vec<CompositeID>) -> Self {
Self {
composite_ids: composite_ids.into(),
}
}

pub fn find_composite_id(&self, id: &str) -> Option<CompositeID> {
self.composite_ids.iter().find(|c| c.0 == id).copied()
}
}

impl ReadFlatBuffer<DTypeSerdeContext> for DType {
impl ReadFlatBuffer<()> for DType {
type Source<'a> = fb::DType<'a>;
type Error = VortexError;

fn read_flatbuffer(
ctx: &DTypeSerdeContext,
fb: &Self::Source<'_>,
) -> Result<Self, Self::Error> {
fn read_flatbuffer(_ctx: &(), fb: &Self::Source<'_>) -> Result<Self, Self::Error> {
match fb.type_type() {
fb::Type::Null => Ok(DType::Null),
fb::Type::Bool => Ok(DType::Bool(
Expand All @@ -43,14 +23,6 @@ impl ReadFlatBuffer<DTypeSerdeContext> for DType {
fb_primitive.nullability().try_into()?,
))
}
fb::Type::Decimal => {
let fb_decimal = fb.type__as_decimal().unwrap();
Ok(DType::Decimal(
fb_decimal.precision(),
fb_decimal.scale(),
fb_decimal.nullability().try_into()?,
))
}
fb::Type::Binary => Ok(DType::Binary(
fb.type__as_binary().unwrap().nullability().try_into()?,
)),
Expand All @@ -59,7 +31,7 @@ impl ReadFlatBuffer<DTypeSerdeContext> for DType {
)),
fb::Type::List => {
let fb_list = fb.type__as_list().unwrap();
let element_dtype = DType::read_flatbuffer(ctx, &fb_list.element_type().unwrap())?;
let element_dtype = DType::read_flatbuffer(&(), &fb_list.element_type().unwrap())?;
Ok(DType::List(
Box::new(element_dtype),
fb_list.nullability().try_into()?,
Expand All @@ -73,19 +45,17 @@ impl ReadFlatBuffer<DTypeSerdeContext> for DType {
.iter()
.map(|n| Arc::new(n.to_string()))
.collect::<Vec<_>>();
let fields: Vec<DType> = fb_struct
let dtypes: Vec<DType> = fb_struct
.fields()
.unwrap()
.iter()
.map(|f| DType::read_flatbuffer(ctx, &f))
.map(|f| DType::read_flatbuffer(&(), &f))
.collect::<VortexResult<Vec<_>>>()?;
Ok(DType::Struct(names, fields))
Ok(DType::Struct { names, dtypes })
}
fb::Type::Composite => {
let fb_composite = fb.type__as_composite().unwrap();
let id = ctx
.find_composite_id(fb_composite.id().unwrap())
.ok_or_else(|| vortex_err!("Couldn't find composite id"))?;
let id = CompositeID::try_from(fb_composite.id().unwrap())?;
Ok(DType::Composite(id, fb_composite.nullability().try_into()?))
}
_ => Err(vortex_err!("Unknown DType variant")),
Expand Down
Loading