From 9d955e654cb5d28eb67699205edf6d59ac3f46ac Mon Sep 17 00:00:00 2001 From: Nicholas Gates Date: Tue, 30 Apr 2024 12:36:45 +0100 Subject: [PATCH] Refactor for DType::Primitive (#276) Fixes #154 --- pyvortex/src/lib.rs | 51 ++++++--- pyvortex/test/test_dtype.py | 10 +- vortex-alp/src/array.rs | 10 +- vortex-array/src/array/chunked/mod.rs | 18 +-- vortex-array/src/array/sparse/compute/mod.rs | 4 +- vortex-array/src/array/sparse/mod.rs | 5 +- vortex-array/src/array/varbin/mod.rs | 9 +- vortex-array/src/array/varbinview/mod.rs | 12 +- vortex-array/src/arrow/dtype.rs | 18 +-- vortex-datetime-parts/src/array.rs | 6 +- vortex-dict/src/dict.rs | 4 +- vortex-dtype/flatbuffers/dtype.fbs | 41 +++---- vortex-dtype/src/deserialize.rs | 59 +--------- vortex-dtype/src/dtype.rs | 112 ++----------------- vortex-dtype/src/lib.rs | 2 +- vortex-dtype/src/ptype.rs | 89 +++++++-------- vortex-dtype/src/serialize.rs | 91 +++++++-------- vortex-fastlanes/src/bitpacking/mod.rs | 10 +- vortex-ipc/benches/ipc_array_reader_take.rs | 8 +- vortex-ipc/src/reader.rs | 29 +++-- vortex-ree/src/ree.rs | 6 +- vortex-roaring/src/integer/compress.rs | 7 +- vortex-scalar/flatbuffers/scalar.fbs | 18 +-- vortex-scalar/src/composite.rs | 8 +- vortex-scalar/src/lib.rs | 19 +--- vortex-scalar/src/list.rs | 9 +- vortex-scalar/src/primitive.rs | 30 +++-- vortex-scalar/src/serde.rs | 40 ------- vortex-zigzag/src/zigzag.rs | 12 +- 29 files changed, 255 insertions(+), 482 deletions(-) diff --git a/pyvortex/src/lib.rs b/pyvortex/src/lib.rs index 14305e84ba..439706fab3 100644 --- a/pyvortex/src/lib.rs +++ b/pyvortex/src/lib.rs @@ -1,9 +1,9 @@ use dtype::PyDType; use log::debug; +use pyo3::exceptions::PyValueError; use pyo3::prelude::*; use vortex::encoding::VORTEX_ENCODINGS; -use vortex_dtype::DType; -use vortex_dtype::Signedness::{Signed, Unsigned}; +use vortex_dtype::{DType, PType}; use crate::array::*; @@ -69,28 +69,51 @@ fn dtype_bool(py: Python<'_>, nullable: bool) -> PyResult> { #[pyfunction(name = "int")] #[pyo3(signature = (width = None, nullable = false))] fn dtype_int(py: Python<'_>, width: Option, nullable: bool) -> PyResult> { - PyDType::wrap( - py, - DType::Int(width.unwrap_or(64).into(), Signed, nullable.into()), - ) + let dtype = if let Some(width) = width { + match width { + 8 => DType::Primitive(PType::I8, nullable.into()), + 16 => DType::Primitive(PType::I16, nullable.into()), + 32 => DType::Primitive(PType::I32, nullable.into()), + 64 => DType::Primitive(PType::I64, nullable.into()), + _ => return Err(PyValueError::new_err("Invalid int width")), + } + } else { + DType::Primitive(PType::I64, nullable.into()) + }; + PyDType::wrap(py, dtype) } #[pyfunction(name = "uint")] #[pyo3(signature = (width = None, nullable = false))] fn dtype_uint(py: Python<'_>, width: Option, nullable: bool) -> PyResult> { - PyDType::wrap( - py, - DType::Int(width.unwrap_or(64).into(), Unsigned, nullable.into()), - ) + let dtype = if let Some(width) = width { + match width { + 8 => DType::Primitive(PType::U8, nullable.into()), + 16 => DType::Primitive(PType::U16, nullable.into()), + 32 => DType::Primitive(PType::U32, nullable.into()), + 64 => DType::Primitive(PType::U64, nullable.into()), + _ => return Err(PyValueError::new_err("Invalid uint width")), + } + } else { + DType::Primitive(PType::U64, nullable.into()) + }; + PyDType::wrap(py, dtype) } #[pyfunction(name = "float")] #[pyo3(signature = (width = None, nullable = false))] fn dtype_float(py: Python<'_>, width: Option, nullable: bool) -> PyResult> { - PyDType::wrap( - py, - DType::Float(width.unwrap_or(64).into(), nullable.into()), - ) + let dtype = if let Some(width) = width { + match width { + 16 => DType::Primitive(PType::F16, nullable.into()), + 32 => DType::Primitive(PType::F32, nullable.into()), + 64 => DType::Primitive(PType::F64, nullable.into()), + _ => return Err(PyValueError::new_err("Invalid float width")), + } + } else { + DType::Primitive(PType::F64, nullable.into()) + }; + PyDType::wrap(py, dtype) } #[pyfunction(name = "utf8")] diff --git a/pyvortex/test/test_dtype.py b/pyvortex/test/test_dtype.py index c7758983aa..207a8ebc53 100644 --- a/pyvortex/test/test_dtype.py +++ b/pyvortex/test/test_dtype.py @@ -2,9 +2,9 @@ def test_int(): - assert str(vortex.int()) == "int(64)" - assert str(vortex.int(32)) == "int(32)" - assert str(vortex.int(32, nullable=True)) == "int(32)?" - assert str(vortex.uint(32)) == "uint(32)" - assert str(vortex.float(16)) == "float(16)" + assert str(vortex.int()) == "i64" + assert str(vortex.int(32)) == "i32" + assert str(vortex.int(32, nullable=True)) == "i32?" + assert str(vortex.uint(32)) == "u32" + assert str(vortex.float(16)) == "f16" assert str(vortex.bool(nullable=True)) == "bool?" diff --git a/vortex-alp/src/array.rs b/vortex-alp/src/array.rs index af823d7476..198e2555f9 100644 --- a/vortex-alp/src/array.rs +++ b/vortex-alp/src/array.rs @@ -4,7 +4,7 @@ use vortex::stats::ArrayStatisticsCompute; use vortex::validity::{ArrayValidity, LogicalValidity}; use vortex::visitor::{AcceptArrayVisitor, ArrayVisitor}; use vortex::{impl_encoding, ArrayDType, ArrayFlatten, IntoArrayData, OwnedArray, ToArrayData}; -use vortex_dtype::{IntWidth, Signedness}; +use vortex_dtype::PType; use vortex_error::{vortex_bail, VortexResult}; use crate::alp::Exponents; @@ -27,12 +27,8 @@ impl ALPArray<'_> { ) -> VortexResult { let encoded_dtype = encoded.dtype().clone(); let dtype = match encoded.dtype() { - DType::Int(IntWidth::_32, Signedness::Signed, nullability) => { - DType::Float(32.into(), *nullability) - } - DType::Int(IntWidth::_64, Signedness::Signed, nullability) => { - DType::Float(64.into(), *nullability) - } + DType::Primitive(PType::I32, nullability) => DType::Primitive(PType::F32, *nullability), + DType::Primitive(PType::I64, nullability) => DType::Primitive(PType::F64, *nullability), d => vortex_bail!(MismatchedTypes: "int32 or int64", d), }; diff --git a/vortex-array/src/array/chunked/mod.rs b/vortex-array/src/array/chunked/mod.rs index b1acacff23..a0f185fd40 100644 --- a/vortex-array/src/array/chunked/mod.rs +++ b/vortex-array/src/array/chunked/mod.rs @@ -1,6 +1,6 @@ use itertools::Itertools; use serde::{Deserialize, Serialize}; -use vortex_dtype::{IntWidth, Nullability, Signedness}; +use vortex_dtype::{Nullability, PType}; use vortex_error::{vortex_bail, VortexResult}; use crate::array::primitive::PrimitiveArray; @@ -20,11 +20,7 @@ impl_encoding!("vortex.chunked", Chunked); pub struct ChunkedMetadata; impl ChunkedArray<'_> { - const ENDS_DTYPE: DType = DType::Int( - IntWidth::_64, - Signedness::Unsigned, - Nullability::NonNullable, - ); + const ENDS_DTYPE: DType = DType::Primitive(PType::U64, Nullability::NonNullable); pub fn try_new(chunks: Vec, dtype: DType) -> VortexResult { for chunk in &chunks { @@ -145,8 +141,8 @@ impl EncodingCompression for ChunkedEncoding {} #[cfg(test)] mod test { - use vortex_dtype::NativePType; - use vortex_dtype::{DType, IntWidth, Nullability, Signedness}; + use vortex_dtype::{DType, Nullability}; + use vortex_dtype::{NativePType, PType}; use crate::array::chunked::{ChunkedArray, OwnedChunkedArray}; use crate::{Array, IntoArray}; @@ -159,11 +155,7 @@ mod test { vec![4u64, 5, 6].into_array(), vec![7u64, 8, 9].into_array(), ], - DType::Int( - IntWidth::_64, - Signedness::Unsigned, - Nullability::NonNullable, - ), + DType::Primitive(PType::U64, Nullability::NonNullable), ) .unwrap() } diff --git a/vortex-array/src/array/sparse/compute/mod.rs b/vortex-array/src/array/sparse/compute/mod.rs index ab97fb7a9d..ff1cba531c 100644 --- a/vortex-array/src/array/sparse/compute/mod.rs +++ b/vortex-array/src/array/sparse/compute/mod.rs @@ -138,7 +138,7 @@ fn take_search_sorted( #[cfg(test)] mod test { use itertools::Itertools; - use vortex_dtype::{DType, FloatWidth, Nullability}; + use vortex_dtype::{DType, Nullability, PType}; use vortex_scalar::Scalar; use crate::array::primitive::PrimitiveArray; @@ -156,7 +156,7 @@ mod test { PrimitiveArray::from_vec(vec![1.23f64, 0.47, 9.99, 3.5], Validity::AllValid) .into_array(), 100, - Scalar::null(&DType::Float(FloatWidth::_64, Nullability::Nullable)), + Scalar::null(&DType::Primitive(PType::F64, Nullability::Nullable)), ) .into_array() } diff --git a/vortex-array/src/array/sparse/mod.rs b/vortex-array/src/array/sparse/mod.rs index 164bca4bba..3bd3a7f55a 100644 --- a/vortex-array/src/array/sparse/mod.rs +++ b/vortex-array/src/array/sparse/mod.rs @@ -174,8 +174,7 @@ impl ArrayValidity for SparseArray<'_> { mod test { use itertools::Itertools; use vortex_dtype::Nullability::Nullable; - use vortex_dtype::Signedness::Signed; - use vortex_dtype::{DType, IntWidth}; + use vortex_dtype::{DType, PType}; use vortex_error::VortexError; use vortex_scalar::Scalar; @@ -187,7 +186,7 @@ mod test { use crate::{Array, IntoArray, OwnedArray}; fn nullable_fill() -> Scalar { - Scalar::null(&DType::Int(IntWidth::_32, Signed, Nullable)) + Scalar::null(&DType::Primitive(PType::I32, Nullable)) } #[allow(dead_code)] diff --git a/vortex-array/src/array/varbin/mod.rs b/vortex-array/src/array/varbin/mod.rs index 536eed04f3..4227d1383b 100644 --- a/vortex-array/src/array/varbin/mod.rs +++ b/vortex-array/src/array/varbin/mod.rs @@ -1,7 +1,7 @@ use num_traits::AsPrimitive; use serde::{Deserialize, Serialize}; +use vortex_dtype::Nullability; use vortex_dtype::{match_each_native_ptype, NativePType}; -use vortex_dtype::{IntWidth, Nullability, Signedness}; use vortex_error::{vortex_bail, VortexResult}; use vortex_scalar::{BinaryScalar, Utf8Scalar}; @@ -37,13 +37,10 @@ impl VarBinArray<'_> { dtype: DType, validity: Validity, ) -> VortexResult { - if !matches!(offsets.dtype(), DType::Int(_, _, Nullability::NonNullable)) { + if !offsets.dtype().is_int() || offsets.dtype().is_nullable() { vortex_bail!(MismatchedTypes: "non nullable int", offsets.dtype()); } - if !matches!( - bytes.dtype(), - DType::Int(IntWidth::_8, Signedness::Unsigned, Nullability::NonNullable) - ) { + if !matches!(bytes.dtype(), &DType::BYTES,) { vortex_bail!(MismatchedTypes: "u8", bytes.dtype()); } if !matches!(dtype, DType::Binary(_) | DType::Utf8(_)) { diff --git a/vortex-array/src/array/varbinview/mod.rs b/vortex-array/src/array/varbinview/mod.rs index 55c993df07..333764b17f 100644 --- a/vortex-array/src/array/varbinview/mod.rs +++ b/vortex-array/src/array/varbinview/mod.rs @@ -2,7 +2,7 @@ use std::fmt::Formatter; use std::{mem, slice}; use ::serde::{Deserialize, Serialize}; -use vortex_dtype::{IntWidth, Nullability, Signedness}; +use vortex_dtype::Nullability; use vortex_error::{vortex_bail, VortexResult}; use crate::array::primitive::PrimitiveArray; @@ -111,18 +111,12 @@ impl VarBinViewArray<'_> { dtype: DType, validity: Validity, ) -> VortexResult { - if !matches!( - views.dtype(), - DType::Int(IntWidth::_8, Signedness::Unsigned, Nullability::NonNullable) - ) { + if !matches!(views.dtype(), &DType::BYTES) { vortex_bail!(MismatchedTypes: "u8", views.dtype()); } for d in data.iter() { - if !matches!( - d.dtype(), - DType::Int(IntWidth::_8, Signedness::Unsigned, Nullability::NonNullable) - ) { + if !matches!(d.dtype(), &DType::BYTES) { vortex_bail!(MismatchedTypes: "u8", d.dtype()); } } diff --git a/vortex-array/src/arrow/dtype.rs b/vortex-array/src/arrow/dtype.rs index 133e0afb98..5a06ea18d2 100644 --- a/vortex-array/src/arrow/dtype.rs +++ b/vortex-array/src/arrow/dtype.rs @@ -4,7 +4,7 @@ use arrow_schema::TimeUnit as ArrowTimeUnit; use arrow_schema::{DataType, Field, SchemaRef}; use itertools::Itertools; use vortex_dtype::PType; -use vortex_dtype::{DType, FloatWidth, IntWidth, Nullability}; +use vortex_dtype::{DType, Nullability}; use vortex_error::{vortex_err, VortexResult}; use crate::array::datetime::{LocalDateTimeExtension, TimeUnit}; @@ -58,24 +58,16 @@ impl FromArrowType for DType { impl FromArrowType<&Field> for DType { fn from_arrow(field: &Field) -> Self { use vortex_dtype::DType::*; - use vortex_dtype::Signedness::*; let nullability: Nullability = field.is_nullable().into(); + if let Ok(ptype) = PType::try_from_arrow(field.data_type()) { + return Primitive(ptype, nullability); + } + match field.data_type() { DataType::Null => Null, DataType::Boolean => Bool(nullability), - DataType::Int8 => Int(IntWidth::_8, Signed, nullability), - DataType::Int16 => Int(IntWidth::_16, Signed, nullability), - DataType::Int32 => Int(IntWidth::_32, Signed, nullability), - DataType::Int64 => Int(IntWidth::_64, Signed, nullability), - DataType::UInt8 => Int(IntWidth::_8, Unsigned, nullability), - DataType::UInt16 => Int(IntWidth::_16, Unsigned, nullability), - DataType::UInt32 => Int(IntWidth::_32, Unsigned, nullability), - DataType::UInt64 => Int(IntWidth::_64, Unsigned, nullability), - DataType::Float16 => Float(FloatWidth::_16, nullability), - DataType::Float32 => Float(FloatWidth::_32, nullability), - DataType::Float64 => Float(FloatWidth::_64, nullability), DataType::Utf8 | DataType::LargeUtf8 => Utf8(nullability), DataType::Binary | DataType::LargeBinary => Binary(nullability), DataType::Timestamp(_u, tz) => match tz { diff --git a/vortex-datetime-parts/src/array.rs b/vortex-datetime-parts/src/array.rs index 8042eea392..3180138141 100644 --- a/vortex-datetime-parts/src/array.rs +++ b/vortex-datetime-parts/src/array.rs @@ -23,13 +23,13 @@ impl DateTimePartsArray<'_> { subsecond: Array, validity: Validity, ) -> VortexResult { - if !matches!(days.dtype(), DType::Int(_, _, _)) { + if !days.dtype().is_int() { vortex_bail!(MismatchedTypes: "any integer", days.dtype()); } - if !matches!(seconds.dtype(), DType::Int(_, _, _)) { + if !seconds.dtype().is_int() { vortex_bail!(MismatchedTypes: "any integer", seconds.dtype()); } - if !matches!(subsecond.dtype(), DType::Int(_, _, _)) { + if !subsecond.dtype().is_int() { vortex_bail!(MismatchedTypes: "any integer", subsecond.dtype()); } diff --git a/vortex-dict/src/dict.rs b/vortex-dict/src/dict.rs index ec85706e8b..9cf8d6c579 100644 --- a/vortex-dict/src/dict.rs +++ b/vortex-dict/src/dict.rs @@ -7,7 +7,7 @@ use vortex::validity::{ArrayValidity, LogicalValidity}; use vortex::visitor::{AcceptArrayVisitor, ArrayVisitor}; use vortex::IntoArrayData; use vortex::{impl_encoding, ArrayDType, ArrayFlatten, ToArrayData}; -use vortex_dtype::{match_each_integer_ptype, Signedness}; +use vortex_dtype::match_each_integer_ptype; use vortex_error::{vortex_bail, VortexResult}; impl_encoding!("vortex.dict", Dict); @@ -19,7 +19,7 @@ pub struct DictMetadata { impl DictArray<'_> { pub fn try_new(codes: Array, values: Array) -> VortexResult { - if !matches!(codes.dtype(), DType::Int(_, Signedness::Unsigned, _)) { + if !codes.dtype().is_unsigned_int() { vortex_bail!(MismatchedTypes: "unsigned int", codes.dtype()); } Self::try_from_parts( diff --git a/vortex-dtype/flatbuffers/dtype.fbs b/vortex-dtype/flatbuffers/dtype.fbs index ab7e9825e4..041d61b934 100644 --- a/vortex-dtype/flatbuffers/dtype.fbs +++ b/vortex-dtype/flatbuffers/dtype.fbs @@ -5,22 +5,18 @@ enum Nullability: byte { Nullable, } -enum Signedness: byte { - Signed, - Unsigned, -} - -enum IntWidth: byte { - _8, - _16, - _32, - _64, -} - -enum FloatWidth: byte { - _16, - _32, - _64, +enum PType: uint8 { + U8, + U16, + U32, + U64, + I8, + I16, + I32, + I64, + F16, + F32, + F64, } table Null {} @@ -29,9 +25,8 @@ table Bool { nullability: Nullability; } -table Int { - width: IntWidth; - signedness: Signedness; +table Primitive { + ptype: PType; nullability: Nullability; } @@ -44,11 +39,6 @@ table Decimal { nullability: Nullability; } -table Float { - width: FloatWidth; - nullability: Nullability; -} - table Utf8 { nullability: Nullability; } @@ -75,9 +65,8 @@ table Composite { union Type { Null, Bool, - Int, + Primitive, Decimal, - Float, Utf8, Binary, Struct_, diff --git a/vortex-dtype/src/deserialize.rs b/vortex-dtype/src/deserialize.rs index efaf1acad2..4ef3b2a406 100644 --- a/vortex-dtype/src/deserialize.rs +++ b/vortex-dtype/src/deserialize.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use vortex_error::{vortex_err, VortexError, VortexResult}; use vortex_flatbuffers::ReadFlatBuffer; -use crate::{flatbuffers as fb, FloatWidth, IntWidth, Nullability, Signedness}; +use crate::{flatbuffers as fb, Nullability}; use crate::{CompositeID, DType}; #[allow(dead_code)] @@ -36,19 +36,11 @@ impl ReadFlatBuffer for DType { fb::Type::Bool => Ok(DType::Bool( fb.type__as_bool().unwrap().nullability().try_into()?, )), - fb::Type::Int => { - let fb_int = fb.type__as_int().unwrap(); - Ok(DType::Int( - fb_int.width().try_into()?, - fb_int.signedness().try_into()?, - fb_int.nullability().try_into()?, - )) - } - fb::Type::Float => { - let fb_float = fb.type__as_float().unwrap(); - Ok(DType::Float( - fb_float.width().try_into()?, - fb_float.nullability().try_into()?, + fb::Type::Primitive => { + let fb_primitive = fb.type__as_primitive().unwrap(); + Ok(DType::Primitive( + fb_primitive.ptype().try_into()?, + fb_primitive.nullability().try_into()?, )) } fb::Type::Decimal => { @@ -112,42 +104,3 @@ impl TryFrom for Nullability { } } } - -impl TryFrom for IntWidth { - type Error = VortexError; - - fn try_from(value: fb::IntWidth) -> VortexResult { - match value { - fb::IntWidth::_8 => Ok(IntWidth::_8), - fb::IntWidth::_16 => Ok(IntWidth::_16), - fb::IntWidth::_32 => Ok(IntWidth::_32), - fb::IntWidth::_64 => Ok(IntWidth::_64), - _ => Err(vortex_err!("Unknown IntWidth value")), - } - } -} - -impl TryFrom for Signedness { - type Error = VortexError; - - fn try_from(value: fb::Signedness) -> VortexResult { - match value { - fb::Signedness::Unsigned => Ok(Signedness::Unsigned), - fb::Signedness::Signed => Ok(Signedness::Signed), - _ => Err(vortex_err!("Unknown Signedness value")), - } - } -} - -impl TryFrom for FloatWidth { - type Error = VortexError; - - fn try_from(value: fb::FloatWidth) -> VortexResult { - match value { - fb::FloatWidth::_16 => Ok(FloatWidth::_16), - fb::FloatWidth::_32 => Ok(FloatWidth::_32), - fb::FloatWidth::_64 => Ok(FloatWidth::_64), - _ => Err(vortex_err!("Unknown IntWidth value")), - } - } -} diff --git a/vortex-dtype/src/dtype.rs b/vortex-dtype/src/dtype.rs index e1094c6195..0c9411812f 100644 --- a/vortex-dtype/src/dtype.rs +++ b/vortex-dtype/src/dtype.rs @@ -5,7 +5,7 @@ use std::sync::Arc; use itertools::Itertools; use DType::*; -use crate::CompositeID; +use crate::{CompositeID, PType}; #[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash, Ord, PartialOrd)] #[cfg_attr(feature = "serde", derive(::serde::Serialize, ::serde::Deserialize))] @@ -43,101 +43,16 @@ impl Display for Nullability { } } -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Ord, PartialOrd)] -pub enum Signedness { - Unsigned, - Signed, -} - -impl From for Signedness { - fn from(value: bool) -> Self { - if value { - Signedness::Signed - } else { - Signedness::Unsigned - } - } -} - -impl Display for Signedness { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - match self { - Signedness::Unsigned => write!(f, "unsigned"), - Signedness::Signed => write!(f, "signed"), - } - } -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Ord, PartialOrd)] -pub enum IntWidth { - _8, - _16, - _32, - _64, -} - -impl From for IntWidth { - fn from(item: u16) -> Self { - match item { - 8 => IntWidth::_8, - 16 => IntWidth::_16, - 32 => IntWidth::_32, - 64 => IntWidth::_64, - _ => panic!("Invalid int width: {}", item), - } - } -} - -impl Display for IntWidth { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - match self { - IntWidth::_8 => write!(f, "8"), - IntWidth::_16 => write!(f, "16"), - IntWidth::_32 => write!(f, "32"), - IntWidth::_64 => write!(f, "64"), - } - } -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Ord, PartialOrd)] -pub enum FloatWidth { - _16, - _32, - _64, -} - -impl From for FloatWidth { - fn from(item: i8) -> Self { - match item { - 16 => FloatWidth::_16, - 32 => FloatWidth::_32, - 64 => FloatWidth::_64, - _ => panic!("Invalid float width: {}", item), - } - } -} - -impl Display for FloatWidth { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - match self { - FloatWidth::_16 => write!(f, "16"), - FloatWidth::_32 => write!(f, "32"), - FloatWidth::_64 => write!(f, "64"), - } - } -} - pub type FieldNames = Vec>; pub type Metadata = Vec; -#[derive(Debug, Clone, PartialEq, Eq, Ord, PartialOrd)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub enum DType { Null, Bool(Nullability), - Int(IntWidth, Signedness, Nullability), + Primitive(PType, Nullability), Decimal(u8, i8, Nullability), - Float(FloatWidth, Nullability), Utf8(Nullability), Binary(Nullability), Struct(FieldNames, Vec), @@ -146,14 +61,10 @@ pub enum DType { } impl DType { - pub const BYTES: DType = Int(IntWidth::_8, Signedness::Unsigned, Nullability::NonNullable); + pub const BYTES: DType = Primitive(PType::U8, Nullability::NonNullable); /// The default DType for indices - pub const IDX: DType = Int( - IntWidth::_64, - Signedness::Unsigned, - Nullability::NonNullable, - ); + pub const IDX: DType = Primitive(PType::U64, Nullability::NonNullable); pub fn nullability(&self) -> Nullability { self.is_nullable().into() @@ -165,9 +76,8 @@ impl DType { match self { Null => true, Bool(n) => matches!(n, Nullable), - Int(_, _, n) => matches!(n, Nullable), + Primitive(_, n) => matches!(n, Nullable), Decimal(_, _, n) => matches!(n, Nullable), - Float(_, n) => matches!(n, Nullable), Utf8(n) => matches!(n, Nullable), Binary(n) => matches!(n, Nullable), Struct(_, fs) => fs.iter().all(|f| f.is_nullable()), @@ -188,9 +98,8 @@ impl DType { match self { Null => Null, Bool(_) => Bool(nullability), - Int(w, s, _) => Int(*w, *s, nullability), + Primitive(p, _) => Primitive(*p, nullability), Decimal(s, p, _) => Decimal(*s, *p, nullability), - Float(w, _) => Float(*w, nullability), Utf8(_) => Utf8(nullability), Binary(_) => Binary(nullability), Struct(n, fs) => Struct( @@ -209,16 +118,11 @@ impl DType { impl Display for DType { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - use Signedness::*; match self { Null => write!(f, "null"), Bool(n) => write!(f, "bool{}", n), - Int(w, s, n) => match s { - Unsigned => write!(f, "uint({}){}", w, n), - Signed => write!(f, "int({}){}", w, n), - }, + Primitive(p, n) => write!(f, "{}{}", p, n), Decimal(p, s, n) => write!(f, "decimal({}, {}){}", p, s, n), - Float(w, n) => write!(f, "float({}){}", w, n), Utf8(n) => write!(f, "utf8{}", n), Binary(n) => write!(f, "binary{}", n), Struct(n, dt) => write!( diff --git a/vortex-dtype/src/lib.rs b/vortex-dtype/src/lib.rs index b7909b6ed3..eb293e0223 100644 --- a/vortex-dtype/src/lib.rs +++ b/vortex-dtype/src/lib.rs @@ -11,7 +11,7 @@ mod serialize; pub use deserialize::*; -#[derive(Debug, Clone, Copy, PartialEq, Eq, Ord, PartialOrd)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Ord, PartialOrd, Hash)] #[cfg_attr(feature = "serde", derive(::serde::Serialize, ::serde::Deserialize))] pub struct CompositeID(pub &'static str); diff --git a/vortex-dtype/src/ptype.rs b/vortex-dtype/src/ptype.rs index 59a67d7c22..e1715dd710 100644 --- a/vortex-dtype/src/ptype.rs +++ b/vortex-dtype/src/ptype.rs @@ -5,8 +5,9 @@ use num_traits::{FromPrimitive, Num, NumCast}; use vortex_error::{vortex_err, VortexError, VortexResult}; use crate::half::f16; +use crate::DType; use crate::DType::*; -use crate::{DType, FloatWidth, IntWidth}; +use crate::Nullability::NonNullable; #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] #[derive(Debug, Clone, Copy, PartialEq, PartialOrd, Eq, Hash)] @@ -150,6 +151,32 @@ impl PType { } } +impl DType { + pub fn is_unsigned_int(&self) -> bool { + PType::try_from(self) + .map(|ptype| ptype.is_unsigned_int()) + .unwrap_or_default() + } + + pub fn is_signed_int(&self) -> bool { + PType::try_from(self) + .map(|ptype| ptype.is_signed_int()) + .unwrap_or_default() + } + + pub fn is_int(&self) -> bool { + PType::try_from(self) + .map(|ptype| ptype.is_int()) + .unwrap_or_default() + } + + pub fn is_float(&self) -> bool { + PType::try_from(self) + .map(|ptype| ptype.is_float()) + .unwrap_or_default() + } +} + impl Display for PType { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { match self { @@ -172,23 +199,8 @@ impl TryFrom<&DType> for PType { type Error = VortexError; fn try_from(value: &DType) -> VortexResult { - use crate::Signedness::*; match value { - Int(w, s, _) => match (w, s) { - (IntWidth::_8, Signed) => Ok(PType::I8), - (IntWidth::_16, Signed) => Ok(PType::I16), - (IntWidth::_32, Signed) => Ok(PType::I32), - (IntWidth::_64, Signed) => Ok(PType::I64), - (IntWidth::_8, Unsigned) => Ok(PType::U8), - (IntWidth::_16, Unsigned) => Ok(PType::U16), - (IntWidth::_32, Unsigned) => Ok(PType::U32), - (IntWidth::_64, Unsigned) => Ok(PType::U64), - }, - Float(f, _) => match f { - FloatWidth::_16 => Ok(PType::F16), - FloatWidth::_32 => Ok(PType::F32), - FloatWidth::_64 => Ok(PType::F64), - }, + Primitive(p, _) => Ok(*p), _ => Err(vortex_err!("Cannot convert DType {} into PType", value)), } } @@ -196,42 +208,25 @@ impl TryFrom<&DType> for PType { impl From for &DType { fn from(item: PType) -> Self { - use crate::Nullability::*; - use crate::Signedness::*; - + // We expand this match statement so that we can return a static reference. match item { - PType::I8 => &Int(IntWidth::_8, Signed, NonNullable), - PType::I16 => &Int(IntWidth::_16, Signed, NonNullable), - PType::I32 => &Int(IntWidth::_32, Signed, NonNullable), - PType::I64 => &Int(IntWidth::_64, Signed, NonNullable), - PType::U8 => &Int(IntWidth::_8, Unsigned, NonNullable), - PType::U16 => &Int(IntWidth::_16, Unsigned, NonNullable), - PType::U32 => &Int(IntWidth::_32, Unsigned, NonNullable), - PType::U64 => &Int(IntWidth::_64, Unsigned, NonNullable), - PType::F16 => &Float(FloatWidth::_16, NonNullable), - PType::F32 => &Float(FloatWidth::_32, NonNullable), - PType::F64 => &Float(FloatWidth::_64, NonNullable), + PType::I8 => &Primitive(PType::I8, NonNullable), + PType::I16 => &Primitive(PType::I16, NonNullable), + PType::I32 => &Primitive(PType::I32, NonNullable), + PType::I64 => &Primitive(PType::I64, NonNullable), + PType::U8 => &Primitive(PType::U8, NonNullable), + PType::U16 => &Primitive(PType::U16, NonNullable), + PType::U32 => &Primitive(PType::U32, NonNullable), + PType::U64 => &Primitive(PType::U64, NonNullable), + PType::F16 => &Primitive(PType::F16, NonNullable), + PType::F32 => &Primitive(PType::F32, NonNullable), + PType::F64 => &Primitive(PType::F64, NonNullable), } } } impl From for DType { fn from(item: PType) -> Self { - use crate::Nullability::*; - use crate::Signedness::*; - - match item { - PType::I8 => Int(IntWidth::_8, Signed, NonNullable), - PType::I16 => Int(IntWidth::_16, Signed, NonNullable), - PType::I32 => Int(IntWidth::_32, Signed, NonNullable), - PType::I64 => Int(IntWidth::_64, Signed, NonNullable), - PType::U8 => Int(IntWidth::_8, Unsigned, NonNullable), - PType::U16 => Int(IntWidth::_16, Unsigned, NonNullable), - PType::U32 => Int(IntWidth::_32, Unsigned, NonNullable), - PType::U64 => Int(IntWidth::_64, Unsigned, NonNullable), - PType::F16 => Float(FloatWidth::_16, NonNullable), - PType::F32 => Float(FloatWidth::_32, NonNullable), - PType::F64 => Float(FloatWidth::_64, NonNullable), - } + Primitive(item, NonNullable) } } diff --git a/vortex-dtype/src/serialize.rs b/vortex-dtype/src/serialize.rs index d7c2bd234b..1aa485807a 100644 --- a/vortex-dtype/src/serialize.rs +++ b/vortex-dtype/src/serialize.rs @@ -1,9 +1,10 @@ use flatbuffers::{FlatBufferBuilder, WIPOffset}; use itertools::Itertools; +use vortex_error::{vortex_bail, VortexError}; use vortex_flatbuffers::{FlatBufferRoot, WriteFlatBuffer}; -use crate::flatbuffers as fb; -use crate::{DType, FloatWidth, IntWidth, Nullability, Signedness}; +use crate::{flatbuffers as fb, PType}; +use crate::{DType, Nullability}; impl FlatBufferRoot for DType {} impl WriteFlatBuffer for DType { @@ -22,11 +23,10 @@ impl WriteFlatBuffer for DType { }, ) .as_union_value(), - DType::Int(width, signednedss, n) => fb::Int::create( + DType::Primitive(ptype, n) => fb::Primitive::create( fbb, - &fb::IntArgs { - width: width.into(), - signedness: signednedss.into(), + &fb::PrimitiveArgs { + ptype: (*ptype).into(), nullability: n.into(), }, ) @@ -40,14 +40,6 @@ impl WriteFlatBuffer for DType { }, ) .as_union_value(), - DType::Float(width, n) => fb::Float::create( - fbb, - &fb::FloatArgs { - width: width.into(), - nullability: n.into(), - }, - ) - .as_union_value(), DType::Utf8(n) => fb::Utf8::create( fbb, &fb::Utf8Args { @@ -104,9 +96,8 @@ impl WriteFlatBuffer for DType { let dtype_type = match self { DType::Null => fb::Type::Null, DType::Bool(_) => fb::Type::Bool, - DType::Int(..) => fb::Type::Int, + DType::Primitive(..) => fb::Type::Primitive, DType::Decimal(..) => fb::Type::Decimal, - DType::Float(..) => fb::Type::Float, DType::Utf8(_) => fb::Type::Utf8, DType::Binary(_) => fb::Type::Binary, DType::Struct(..) => fb::Type::Struct_, @@ -142,33 +133,42 @@ impl From<&Nullability> for fb::Nullability { } } -impl From<&IntWidth> for fb::IntWidth { - fn from(value: &IntWidth) -> Self { - match value { - IntWidth::_8 => fb::IntWidth::_8, - IntWidth::_16 => fb::IntWidth::_16, - IntWidth::_32 => fb::IntWidth::_32, - IntWidth::_64 => fb::IntWidth::_64, - } - } -} - -impl From<&Signedness> for fb::Signedness { - fn from(value: &Signedness) -> Self { +impl From for fb::PType { + fn from(value: PType) -> Self { match value { - Signedness::Unsigned => fb::Signedness::Unsigned, - Signedness::Signed => fb::Signedness::Signed, + PType::U8 => fb::PType::U8, + PType::U16 => fb::PType::U16, + PType::U32 => fb::PType::U32, + PType::U64 => fb::PType::U64, + PType::I8 => fb::PType::I8, + PType::I16 => fb::PType::I16, + PType::I32 => fb::PType::I32, + PType::I64 => fb::PType::I64, + PType::F16 => fb::PType::F16, + PType::F32 => fb::PType::F32, + PType::F64 => fb::PType::F64, } } } -impl From<&FloatWidth> for fb::FloatWidth { - fn from(value: &FloatWidth) -> Self { - match value { - FloatWidth::_16 => fb::FloatWidth::_16, - FloatWidth::_32 => fb::FloatWidth::_32, - FloatWidth::_64 => fb::FloatWidth::_64, - } +impl TryFrom for PType { + type Error = VortexError; + + fn try_from(value: fb::PType) -> Result { + Ok(match value { + fb::PType::U8 => PType::U8, + fb::PType::U16 => PType::U16, + fb::PType::U32 => PType::U32, + fb::PType::U64 => PType::U64, + fb::PType::I8 => PType::I8, + fb::PType::I16 => PType::I16, + fb::PType::I32 => PType::I32, + fb::PType::I64 => PType::I64, + fb::PType::F16 => PType::F16, + fb::PType::F32 => PType::F32, + fb::PType::F64 => PType::F64, + _ => vortex_bail!(InvalidSerde: "Unknown PType variant"), + }) } } @@ -179,8 +179,8 @@ mod test { use flatbuffers::root; use vortex_flatbuffers::{FlatBufferToBytes, ReadFlatBuffer}; - use crate::flatbuffers as fb; - use crate::{DType, DTypeSerdeContext, FloatWidth, IntWidth, Nullability, Signedness}; + use crate::{flatbuffers as fb, PType}; + use crate::{DType, DTypeSerdeContext, Nullability}; fn roundtrip_dtype(dtype: DType) { let bytes = dtype.with_flatbuffer_bytes(|bytes| bytes.to_vec()); @@ -196,24 +196,19 @@ mod test { fn roundtrip() { roundtrip_dtype(DType::Null); roundtrip_dtype(DType::Bool(Nullability::NonNullable)); - roundtrip_dtype(DType::Int( - IntWidth::_64, - Signedness::Unsigned, - Nullability::NonNullable, - )); + roundtrip_dtype(DType::Primitive(PType::U64, Nullability::NonNullable)); roundtrip_dtype(DType::Decimal(18, 9, Nullability::NonNullable)); - roundtrip_dtype(DType::Float(FloatWidth::_64, Nullability::NonNullable)); roundtrip_dtype(DType::Binary(Nullability::NonNullable)); roundtrip_dtype(DType::Utf8(Nullability::NonNullable)); roundtrip_dtype(DType::List( - Box::new(DType::Float(FloatWidth::_32, Nullability::Nullable)), + Box::new(DType::Primitive(PType::F32, Nullability::Nullable)), Nullability::NonNullable, )); roundtrip_dtype(DType::Struct( vec![Arc::new("strings".into()), Arc::new("ints".into())], vec![ DType::Utf8(Nullability::NonNullable), - DType::Int(IntWidth::_16, Signedness::Unsigned, Nullability::Nullable), + DType::Primitive(PType::U16, Nullability::Nullable), ], )) } diff --git a/vortex-fastlanes/src/bitpacking/mod.rs b/vortex-fastlanes/src/bitpacking/mod.rs index a362670bb3..070501e6d3 100644 --- a/vortex-fastlanes/src/bitpacking/mod.rs +++ b/vortex-fastlanes/src/bitpacking/mod.rs @@ -5,7 +5,6 @@ use vortex::stats::ArrayStatisticsCompute; use vortex::validity::{ArrayValidity, LogicalValidity, Validity, ValidityMetadata}; use vortex::visitor::{AcceptArrayVisitor, ArrayVisitor}; use vortex::{impl_encoding, ArrayDType, ArrayFlatten, IntoArrayData}; -use vortex_dtype::{IntWidth, Nullability, Signedness}; use vortex_error::{vortex_bail, vortex_err, VortexResult}; mod compress; @@ -24,9 +23,6 @@ pub struct BitPackedMetadata { /// NB: All non-null values in the patches array are considered patches impl BitPackedArray<'_> { - const ENCODED_DTYPE: DType = - DType::Int(IntWidth::_8, Signedness::Unsigned, Nullability::NonNullable); - pub fn try_new( packed: Array, validity: Validity, @@ -47,13 +43,13 @@ impl BitPackedArray<'_> { length: usize, offset: usize, ) -> VortexResult { - if packed.dtype() != &Self::ENCODED_DTYPE { - vortex_bail!(MismatchedTypes: Self::ENCODED_DTYPE, packed.dtype()); + if packed.dtype() != &DType::BYTES { + vortex_bail!(MismatchedTypes: DType::BYTES, packed.dtype()); } if bit_width > 64 { vortex_bail!("Unsupported bit width {}", bit_width); } - if !matches!(dtype, DType::Int(_, _, _)) { + if !dtype.is_int() { vortex_bail!(MismatchedTypes: "int", dtype); } diff --git a/vortex-ipc/benches/ipc_array_reader_take.rs b/vortex-ipc/benches/ipc_array_reader_take.rs index 2a815d2600..9823f4de59 100644 --- a/vortex-ipc/benches/ipc_array_reader_take.rs +++ b/vortex-ipc/benches/ipc_array_reader_take.rs @@ -4,7 +4,7 @@ use criterion::{black_box, criterion_group, criterion_main, Criterion}; use itertools::Itertools; use vortex::array::primitive::PrimitiveArray; use vortex::{IntoArray, SerdeContext}; -use vortex_dtype::{DType, Nullability, Signedness}; +use vortex_dtype::{DType, Nullability, PType}; use vortex_ipc::iter::FallibleLendingIterator; use vortex_ipc::reader::StreamReader; use vortex_ipc::writer::StreamWriter; @@ -25,11 +25,7 @@ fn ipc_array_reader_take(c: &mut Criterion) { let mut cursor = Cursor::new(&mut buffer); let mut writer = StreamWriter::try_new(&mut cursor, SerdeContext::default()).unwrap(); writer - .write_schema(&DType::Int( - 32.into(), - Signedness::Signed, - Nullability::Nullable, - )) + .write_schema(&DType::Primitive(PType::I32, Nullability::Nullable)) .unwrap(); (0..100i32).for_each(|i| { let data = PrimitiveArray::from(vec![i; 100_000]).into_array(); diff --git a/vortex-ipc/src/reader.rs b/vortex-ipc/src/reader.rs index b8379613a1..4cba0d85e7 100644 --- a/vortex-ipc/src/reader.rs +++ b/vortex-ipc/src/reader.rs @@ -17,7 +17,7 @@ use vortex::stats::{ArrayStatistics, Stat}; use vortex::{ Array, ArrayDType, ArrayView, IntoArray, OwnedArray, SerdeContext, ToArray, ToStatic, }; -use vortex_dtype::{match_each_integer_ptype, DType, DTypeSerdeContext, Signedness}; +use vortex_dtype::{match_each_integer_ptype, DType, DTypeSerdeContext}; use vortex_error::{vortex_bail, vortex_err, VortexError, VortexResult}; use vortex_flatbuffers::ReadFlatBuffer; @@ -159,21 +159,18 @@ impl<'a, R: Read> StreamArrayReader<'a, R> { vortex_bail!("Indices must not contain nulls") } - match indices.dtype() { - DType::Int(_, signedness, _) => { - // indices must be positive integers - if signedness == &Signedness::Signed - && indices - .statistics() - // min cast should be safe - .compute_as_cast::(Stat::Min) - .unwrap() - < 0 - { - vortex_bail!("Indices must be positive") - } - } - _ => vortex_bail!("Indices must be integers"), + if !indices.dtype().is_int() { + vortex_bail!("Indices must be integers") + } + if indices.dtype().is_signed_int() + && indices + .statistics() + // min cast should be safe + .compute_as_cast::(Stat::Min) + .unwrap() + < 0 + { + vortex_bail!("Indices must be positive") } if self.row_offset != 0 { diff --git a/vortex-ree/src/ree.rs b/vortex-ree/src/ree.rs index 2186b335b9..ef82ade23e 100644 --- a/vortex-ree/src/ree.rs +++ b/vortex-ree/src/ree.rs @@ -145,7 +145,7 @@ mod test { use vortex::compute::slice::slice; use vortex::validity::Validity; use vortex::{ArrayDType, ArrayTrait, IntoArray}; - use vortex_dtype::{DType, IntWidth, Nullability, Signedness}; + use vortex_dtype::{DType, Nullability, PType}; use crate::REEArray; @@ -160,7 +160,7 @@ mod test { assert_eq!(arr.len(), 10); assert_eq!( arr.dtype(), - &DType::Int(IntWidth::_32, Signedness::Signed, Nullability::NonNullable) + &DType::Primitive(PType::I32, Nullability::NonNullable) ); // 0, 1 => 1 @@ -188,7 +188,7 @@ mod test { .unwrap(); assert_eq!( arr.dtype(), - &DType::Int(IntWidth::_32, Signedness::Signed, Nullability::NonNullable) + &DType::Primitive(PType::I32, Nullability::NonNullable) ); assert_eq!(arr.len(), 5); diff --git a/vortex-roaring/src/integer/compress.rs b/vortex-roaring/src/integer/compress.rs index 02f15ebace..bb20b5ee2a 100644 --- a/vortex-roaring/src/integer/compress.rs +++ b/vortex-roaring/src/integer/compress.rs @@ -5,9 +5,6 @@ use vortex::array::primitive::PrimitiveArray; use vortex::compress::{CompressConfig, CompressCtx, EncodingCompression}; use vortex::stats::{ArrayStatistics, Stat}; use vortex::{Array, ArrayDType, ArrayDef, IntoArray, OwnedArray, ToStatic}; -use vortex_dtype::DType; -use vortex_dtype::Nullability::NonNullable; -use vortex_dtype::Signedness::Unsigned; use vortex_dtype::{NativePType, PType}; use vortex_error::VortexResult; @@ -25,8 +22,8 @@ impl EncodingCompression for RoaringIntEncoding { } // Only support non-nullable uint arrays - if !matches!(array.dtype(), DType::Int(_, Unsigned, NonNullable)) { - debug!("Skipping roaring int, not non-nullable"); + if !array.dtype().is_unsigned_int() || array.dtype().is_nullable() { + debug!("Skipping roaring int, not a uint"); return None; } diff --git a/vortex-scalar/flatbuffers/scalar.fbs b/vortex-scalar/flatbuffers/scalar.fbs index c6372dae81..c333c04e7e 100644 --- a/vortex-scalar/flatbuffers/scalar.fbs +++ b/vortex-scalar/flatbuffers/scalar.fbs @@ -17,24 +17,8 @@ table List { table Null { } -// Since Rust doesn't support structs in a union, it would be very inefficient to wrap each primitive type in a table. -// So instead we store a PType and a byte vector. -enum PType: uint8 { - U8, - U16, - U32, - U64, - I8, - I16, - I32, - I64, - F16, - F32, - F64, -} - table Primitive { - ptype: PType; + ptype: dtype.PType; // TODO(ngates): this isn't an ideal way to store the bytes. bytes: [ubyte]; } diff --git a/vortex-scalar/src/composite.rs b/vortex-scalar/src/composite.rs index a51f353abc..710d6a1893 100644 --- a/vortex-scalar/src/composite.rs +++ b/vortex-scalar/src/composite.rs @@ -5,7 +5,7 @@ use vortex_error::VortexResult; use crate::Scalar; -#[derive(Debug, Clone, PartialEq, PartialOrd)] +#[derive(Debug, Clone, PartialEq)] pub struct CompositeScalar { dtype: DType, scalar: Box, @@ -34,6 +34,12 @@ impl CompositeScalar { } } +impl PartialOrd for CompositeScalar { + fn partial_cmp(&self, other: &Self) -> Option { + self.scalar.as_ref().partial_cmp(other.scalar.as_ref()) + } +} + impl Display for CompositeScalar { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { write!(f, "{} ({})", self.scalar, self.dtype) diff --git a/vortex-scalar/src/lib.rs b/vortex-scalar/src/lib.rs index dd12c39982..30d4be5d8f 100644 --- a/vortex-scalar/src/lib.rs +++ b/vortex-scalar/src/lib.rs @@ -8,9 +8,8 @@ pub use null::*; pub use primitive::*; pub use struct_::*; pub use utf8::*; -use vortex_dtype::half::f16; use vortex_dtype::NativePType; -use vortex_dtype::{DType, FloatWidth, IntWidth, Nullability, Signedness}; +use vortex_dtype::{DType, Nullability}; use vortex_error::VortexResult; mod binary; @@ -126,22 +125,8 @@ impl Scalar { match dtype { DType::Null => NullScalar::new().into(), DType::Bool(_) => BoolScalar::none().into(), - DType::Int(w, s, _) => match (w, s) { - (IntWidth::_8, Signedness::Signed) => PrimitiveScalar::none::().into(), - (IntWidth::_16, Signedness::Signed) => PrimitiveScalar::none::().into(), - (IntWidth::_32, Signedness::Signed) => PrimitiveScalar::none::().into(), - (IntWidth::_64, Signedness::Signed) => PrimitiveScalar::none::().into(), - (IntWidth::_8, Signedness::Unsigned) => PrimitiveScalar::none::().into(), - (IntWidth::_16, Signedness::Unsigned) => PrimitiveScalar::none::().into(), - (IntWidth::_32, Signedness::Unsigned) => PrimitiveScalar::none::().into(), - (IntWidth::_64, Signedness::Unsigned) => PrimitiveScalar::none::().into(), - }, + DType::Primitive(p, _) => PrimitiveScalar::none_from_ptype(*p).into(), DType::Decimal(..) => unimplemented!("DecimalScalar"), - DType::Float(w, _) => match w { - FloatWidth::_16 => PrimitiveScalar::none::().into(), - FloatWidth::_32 => PrimitiveScalar::none::().into(), - FloatWidth::_64 => PrimitiveScalar::none::().into(), - }, DType::Utf8(_) => Utf8Scalar::none().into(), DType::Binary(_) => BinaryScalar::none().into(), DType::Struct(..) => StructScalar::new(dtype.clone(), vec![]).into(), diff --git a/vortex-scalar/src/list.rs b/vortex-scalar/src/list.rs index 77f6b69c3a..063eaecb71 100644 --- a/vortex-scalar/src/list.rs +++ b/vortex-scalar/src/list.rs @@ -1,3 +1,4 @@ +use std::cmp::Ordering; use std::fmt::{Display, Formatter}; use itertools::Itertools; @@ -6,7 +7,7 @@ use vortex_error::{vortex_err, VortexError, VortexResult}; use crate::Scalar; -#[derive(Debug, Clone, PartialEq, PartialOrd)] +#[derive(Debug, Clone, PartialEq)] pub struct ListScalar { dtype: DType, values: Option>, @@ -58,6 +59,12 @@ impl ListScalar { } } +impl PartialOrd for ListScalar { + fn partial_cmp(&self, _other: &Self) -> Option { + todo!() + } +} + #[derive(Debug, Clone, PartialEq)] pub struct ListScalarVec(pub Vec); diff --git a/vortex-scalar/src/primitive.rs b/vortex-scalar/src/primitive.rs index 552100a29d..cff0ef2ec2 100644 --- a/vortex-scalar/src/primitive.rs +++ b/vortex-scalar/src/primitive.rs @@ -1,4 +1,5 @@ use std::any; +use std::cmp::Ordering; use std::fmt::{Display, Formatter}; use std::mem::size_of; @@ -13,7 +14,7 @@ use crate::Scalar; pub trait PScalarType: NativePType + Into + TryFrom {} impl + TryFrom> PScalarType for T {} -#[derive(Debug, Clone, PartialEq, PartialOrd)] +#[derive(Debug, Clone, PartialEq)] pub struct PrimitiveScalar { ptype: PType, dtype: DType, @@ -37,6 +38,15 @@ impl PrimitiveScalar { }) } + pub fn none_from_ptype(ptype: PType) -> Self { + Self { + ptype, + dtype: DType::from(ptype).with_nullability(Nullability::Nullable), + nullability: Nullability::Nullable, + value: None, + } + } + pub fn nullable(value: Option) -> Self { Self::try_new(value, Nullability::Nullable).unwrap() } @@ -92,6 +102,16 @@ impl PrimitiveScalar { } } +impl PartialOrd for PrimitiveScalar { + fn partial_cmp(&self, other: &Self) -> Option { + if let (Some(s), Some(o)) = (self.value, other.value) { + s.partial_cmp(&o) + } else { + None + } + } +} + impl Display for PrimitiveScalar { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { match self.value() { @@ -341,7 +361,7 @@ impl Display for PScalar { #[cfg(test)] mod test { use vortex_dtype::PType; - use vortex_dtype::{DType, IntWidth, Nullability, Signedness}; + use vortex_dtype::{DType, Nullability}; use vortex_error::VortexError; use crate::Scalar; @@ -365,11 +385,7 @@ mod test { fn cast() { let scalar: Scalar = 10u16.into(); let u32_scalar = scalar - .cast(&DType::Int( - IntWidth::_32, - Signedness::Unsigned, - Nullability::NonNullable, - )) + .cast(&DType::Primitive(PType::U32, Nullability::NonNullable)) .unwrap(); let u32_scalar_ptype: PType = u32_scalar.dtype().try_into().unwrap(); assert_eq!(u32_scalar_ptype, PType::U32); diff --git a/vortex-scalar/src/serde.rs b/vortex-scalar/src/serde.rs index 8d953cc878..1d77288a83 100644 --- a/vortex-scalar/src/serde.rs +++ b/vortex-scalar/src/serde.rs @@ -4,7 +4,6 @@ use flatbuffers::{root, FlatBufferBuilder, WIPOffset}; use serde::de::Visitor; use serde::{Deserialize, Deserializer, Serialize, Serializer}; use vortex_dtype::match_each_native_ptype; -use vortex_dtype::PType; use vortex_dtype::{DTypeSerdeContext, Nullability}; use vortex_error::{vortex_bail, VortexError}; use vortex_flatbuffers::{FlatBufferRoot, FlatBufferToBytes, ReadFlatBuffer, WriteFlatBuffer}; @@ -145,45 +144,6 @@ impl ReadFlatBuffer for Scalar { } } -impl From for fb::PType { - fn from(value: PType) -> Self { - match value { - PType::U8 => fb::PType::U8, - PType::U16 => fb::PType::U16, - PType::U32 => fb::PType::U32, - PType::U64 => fb::PType::U64, - PType::I8 => fb::PType::I8, - PType::I16 => fb::PType::I16, - PType::I32 => fb::PType::I32, - PType::I64 => fb::PType::I64, - PType::F16 => fb::PType::F16, - PType::F32 => fb::PType::F32, - PType::F64 => fb::PType::F64, - } - } -} - -impl TryFrom for PType { - type Error = VortexError; - - fn try_from(value: fb::PType) -> Result { - Ok(match value { - fb::PType::U8 => PType::U8, - fb::PType::U16 => PType::U16, - fb::PType::U32 => PType::U32, - fb::PType::U64 => PType::U64, - fb::PType::I8 => PType::I8, - fb::PType::I16 => PType::I16, - fb::PType::I32 => PType::I32, - fb::PType::I64 => PType::I64, - fb::PType::F16 => PType::F16, - fb::PType::F32 => PType::F32, - fb::PType::F64 => PType::F64, - _ => vortex_bail!(InvalidSerde: "Unrecognized PType"), - }) - } -} - impl Serialize for Scalar { fn serialize(&self, serializer: S) -> Result where diff --git a/vortex-zigzag/src/zigzag.rs b/vortex-zigzag/src/zigzag.rs index 99a061f136..ce531d2552 100644 --- a/vortex-zigzag/src/zigzag.rs +++ b/vortex-zigzag/src/zigzag.rs @@ -4,7 +4,7 @@ use vortex::stats::ArrayStatisticsCompute; use vortex::validity::{ArrayValidity, LogicalValidity}; use vortex::visitor::{AcceptArrayVisitor, ArrayVisitor}; use vortex::{impl_encoding, ArrayDType, ArrayFlatten, IntoArrayData}; -use vortex_dtype::Signedness; +use vortex_dtype::PType; use vortex_error::{vortex_bail, vortex_err, VortexResult}; use crate::compress::zigzag_encode; @@ -23,11 +23,11 @@ impl ZigZagArray<'_> { pub fn try_new(encoded: Array) -> VortexResult { let encoded_dtype = encoded.dtype().clone(); - let dtype = match encoded_dtype { - DType::Int(width, Signedness::Unsigned, nullability) => { - DType::Int(width, Signedness::Signed, nullability) - } - d => vortex_bail!(MismatchedTypes: "unsigned int", d), + let dtype = if encoded_dtype.is_unsigned_int() { + DType::from(PType::try_from(&encoded_dtype).unwrap().to_signed()) + .with_nullability(encoded_dtype.nullability()) + } else { + vortex_bail!(MismatchedTypes: "unsigned int", encoded_dtype) }; let children = vec![encoded.into_array_data()];