Skip to content

Commit

Permalink
Refactor for DType::Primitive (#276)
Browse files Browse the repository at this point in the history
Fixes #154
  • Loading branch information
gatesn authored Apr 30, 2024
1 parent 7c45dbb commit 9d955e6
Show file tree
Hide file tree
Showing 29 changed files with 255 additions and 482 deletions.
51 changes: 37 additions & 14 deletions pyvortex/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
use dtype::PyDType;
use log::debug;
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use vortex::encoding::VORTEX_ENCODINGS;
use vortex_dtype::DType;
use vortex_dtype::Signedness::{Signed, Unsigned};
use vortex_dtype::{DType, PType};

use crate::array::*;

Expand Down Expand Up @@ -69,28 +69,51 @@ fn dtype_bool(py: Python<'_>, nullable: bool) -> PyResult<Py<PyDType>> {
#[pyfunction(name = "int")]
#[pyo3(signature = (width = None, nullable = false))]
fn dtype_int(py: Python<'_>, width: Option<u16>, nullable: bool) -> PyResult<Py<PyDType>> {
PyDType::wrap(
py,
DType::Int(width.unwrap_or(64).into(), Signed, nullable.into()),
)
let dtype = if let Some(width) = width {
match width {
8 => DType::Primitive(PType::I8, nullable.into()),
16 => DType::Primitive(PType::I16, nullable.into()),
32 => DType::Primitive(PType::I32, nullable.into()),
64 => DType::Primitive(PType::I64, nullable.into()),
_ => return Err(PyValueError::new_err("Invalid int width")),
}
} else {
DType::Primitive(PType::I64, nullable.into())
};
PyDType::wrap(py, dtype)
}

#[pyfunction(name = "uint")]
#[pyo3(signature = (width = None, nullable = false))]
fn dtype_uint(py: Python<'_>, width: Option<u16>, nullable: bool) -> PyResult<Py<PyDType>> {
PyDType::wrap(
py,
DType::Int(width.unwrap_or(64).into(), Unsigned, nullable.into()),
)
let dtype = if let Some(width) = width {
match width {
8 => DType::Primitive(PType::U8, nullable.into()),
16 => DType::Primitive(PType::U16, nullable.into()),
32 => DType::Primitive(PType::U32, nullable.into()),
64 => DType::Primitive(PType::U64, nullable.into()),
_ => return Err(PyValueError::new_err("Invalid uint width")),
}
} else {
DType::Primitive(PType::U64, nullable.into())
};
PyDType::wrap(py, dtype)
}

#[pyfunction(name = "float")]
#[pyo3(signature = (width = None, nullable = false))]
fn dtype_float(py: Python<'_>, width: Option<i8>, nullable: bool) -> PyResult<Py<PyDType>> {
PyDType::wrap(
py,
DType::Float(width.unwrap_or(64).into(), nullable.into()),
)
let dtype = if let Some(width) = width {
match width {
16 => DType::Primitive(PType::F16, nullable.into()),
32 => DType::Primitive(PType::F32, nullable.into()),
64 => DType::Primitive(PType::F64, nullable.into()),
_ => return Err(PyValueError::new_err("Invalid float width")),
}
} else {
DType::Primitive(PType::F64, nullable.into())
};
PyDType::wrap(py, dtype)
}

#[pyfunction(name = "utf8")]
Expand Down
10 changes: 5 additions & 5 deletions pyvortex/test/test_dtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@


def test_int():
assert str(vortex.int()) == "int(64)"
assert str(vortex.int(32)) == "int(32)"
assert str(vortex.int(32, nullable=True)) == "int(32)?"
assert str(vortex.uint(32)) == "uint(32)"
assert str(vortex.float(16)) == "float(16)"
assert str(vortex.int()) == "i64"
assert str(vortex.int(32)) == "i32"
assert str(vortex.int(32, nullable=True)) == "i32?"
assert str(vortex.uint(32)) == "u32"
assert str(vortex.float(16)) == "f16"
assert str(vortex.bool(nullable=True)) == "bool?"
10 changes: 3 additions & 7 deletions vortex-alp/src/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use vortex::stats::ArrayStatisticsCompute;
use vortex::validity::{ArrayValidity, LogicalValidity};
use vortex::visitor::{AcceptArrayVisitor, ArrayVisitor};
use vortex::{impl_encoding, ArrayDType, ArrayFlatten, IntoArrayData, OwnedArray, ToArrayData};
use vortex_dtype::{IntWidth, Signedness};
use vortex_dtype::PType;
use vortex_error::{vortex_bail, VortexResult};

use crate::alp::Exponents;
Expand All @@ -27,12 +27,8 @@ impl ALPArray<'_> {
) -> VortexResult<Self> {
let encoded_dtype = encoded.dtype().clone();
let dtype = match encoded.dtype() {
DType::Int(IntWidth::_32, Signedness::Signed, nullability) => {
DType::Float(32.into(), *nullability)
}
DType::Int(IntWidth::_64, Signedness::Signed, nullability) => {
DType::Float(64.into(), *nullability)
}
DType::Primitive(PType::I32, nullability) => DType::Primitive(PType::F32, *nullability),
DType::Primitive(PType::I64, nullability) => DType::Primitive(PType::F64, *nullability),
d => vortex_bail!(MismatchedTypes: "int32 or int64", d),
};

Expand Down
18 changes: 5 additions & 13 deletions vortex-array/src/array/chunked/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use itertools::Itertools;
use serde::{Deserialize, Serialize};
use vortex_dtype::{IntWidth, Nullability, Signedness};
use vortex_dtype::{Nullability, PType};
use vortex_error::{vortex_bail, VortexResult};

use crate::array::primitive::PrimitiveArray;
Expand All @@ -20,11 +20,7 @@ impl_encoding!("vortex.chunked", Chunked);
pub struct ChunkedMetadata;

impl ChunkedArray<'_> {
const ENDS_DTYPE: DType = DType::Int(
IntWidth::_64,
Signedness::Unsigned,
Nullability::NonNullable,
);
const ENDS_DTYPE: DType = DType::Primitive(PType::U64, Nullability::NonNullable);

pub fn try_new(chunks: Vec<Array>, dtype: DType) -> VortexResult<Self> {
for chunk in &chunks {
Expand Down Expand Up @@ -145,8 +141,8 @@ impl EncodingCompression for ChunkedEncoding {}

#[cfg(test)]
mod test {
use vortex_dtype::NativePType;
use vortex_dtype::{DType, IntWidth, Nullability, Signedness};
use vortex_dtype::{DType, Nullability};
use vortex_dtype::{NativePType, PType};

use crate::array::chunked::{ChunkedArray, OwnedChunkedArray};
use crate::{Array, IntoArray};
Expand All @@ -159,11 +155,7 @@ mod test {
vec![4u64, 5, 6].into_array(),
vec![7u64, 8, 9].into_array(),
],
DType::Int(
IntWidth::_64,
Signedness::Unsigned,
Nullability::NonNullable,
),
DType::Primitive(PType::U64, Nullability::NonNullable),
)
.unwrap()
}
Expand Down
4 changes: 2 additions & 2 deletions vortex-array/src/array/sparse/compute/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ fn take_search_sorted(
#[cfg(test)]
mod test {
use itertools::Itertools;
use vortex_dtype::{DType, FloatWidth, Nullability};
use vortex_dtype::{DType, Nullability, PType};
use vortex_scalar::Scalar;

use crate::array::primitive::PrimitiveArray;
Expand All @@ -156,7 +156,7 @@ mod test {
PrimitiveArray::from_vec(vec![1.23f64, 0.47, 9.99, 3.5], Validity::AllValid)
.into_array(),
100,
Scalar::null(&DType::Float(FloatWidth::_64, Nullability::Nullable)),
Scalar::null(&DType::Primitive(PType::F64, Nullability::Nullable)),
)
.into_array()
}
Expand Down
5 changes: 2 additions & 3 deletions vortex-array/src/array/sparse/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -174,8 +174,7 @@ impl ArrayValidity for SparseArray<'_> {
mod test {
use itertools::Itertools;
use vortex_dtype::Nullability::Nullable;
use vortex_dtype::Signedness::Signed;
use vortex_dtype::{DType, IntWidth};
use vortex_dtype::{DType, PType};
use vortex_error::VortexError;
use vortex_scalar::Scalar;

Expand All @@ -187,7 +186,7 @@ mod test {
use crate::{Array, IntoArray, OwnedArray};

fn nullable_fill() -> Scalar {
Scalar::null(&DType::Int(IntWidth::_32, Signed, Nullable))
Scalar::null(&DType::Primitive(PType::I32, Nullable))
}

#[allow(dead_code)]
Expand Down
9 changes: 3 additions & 6 deletions vortex-array/src/array/varbin/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use num_traits::AsPrimitive;
use serde::{Deserialize, Serialize};
use vortex_dtype::Nullability;
use vortex_dtype::{match_each_native_ptype, NativePType};
use vortex_dtype::{IntWidth, Nullability, Signedness};
use vortex_error::{vortex_bail, VortexResult};
use vortex_scalar::{BinaryScalar, Utf8Scalar};

Expand Down Expand Up @@ -37,13 +37,10 @@ impl VarBinArray<'_> {
dtype: DType,
validity: Validity,
) -> VortexResult<Self> {
if !matches!(offsets.dtype(), DType::Int(_, _, Nullability::NonNullable)) {
if !offsets.dtype().is_int() || offsets.dtype().is_nullable() {
vortex_bail!(MismatchedTypes: "non nullable int", offsets.dtype());
}
if !matches!(
bytes.dtype(),
DType::Int(IntWidth::_8, Signedness::Unsigned, Nullability::NonNullable)
) {
if !matches!(bytes.dtype(), &DType::BYTES,) {
vortex_bail!(MismatchedTypes: "u8", bytes.dtype());
}
if !matches!(dtype, DType::Binary(_) | DType::Utf8(_)) {
Expand Down
12 changes: 3 additions & 9 deletions vortex-array/src/array/varbinview/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::fmt::Formatter;
use std::{mem, slice};

use ::serde::{Deserialize, Serialize};
use vortex_dtype::{IntWidth, Nullability, Signedness};
use vortex_dtype::Nullability;
use vortex_error::{vortex_bail, VortexResult};

use crate::array::primitive::PrimitiveArray;
Expand Down Expand Up @@ -111,18 +111,12 @@ impl VarBinViewArray<'_> {
dtype: DType,
validity: Validity,
) -> VortexResult<Self> {
if !matches!(
views.dtype(),
DType::Int(IntWidth::_8, Signedness::Unsigned, Nullability::NonNullable)
) {
if !matches!(views.dtype(), &DType::BYTES) {
vortex_bail!(MismatchedTypes: "u8", views.dtype());
}

for d in data.iter() {
if !matches!(
d.dtype(),
DType::Int(IntWidth::_8, Signedness::Unsigned, Nullability::NonNullable)
) {
if !matches!(d.dtype(), &DType::BYTES) {
vortex_bail!(MismatchedTypes: "u8", d.dtype());
}
}
Expand Down
18 changes: 5 additions & 13 deletions vortex-array/src/arrow/dtype.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use arrow_schema::TimeUnit as ArrowTimeUnit;
use arrow_schema::{DataType, Field, SchemaRef};
use itertools::Itertools;
use vortex_dtype::PType;
use vortex_dtype::{DType, FloatWidth, IntWidth, Nullability};
use vortex_dtype::{DType, Nullability};
use vortex_error::{vortex_err, VortexResult};

use crate::array::datetime::{LocalDateTimeExtension, TimeUnit};
Expand Down Expand Up @@ -58,24 +58,16 @@ impl FromArrowType<SchemaRef> for DType {
impl FromArrowType<&Field> for DType {
fn from_arrow(field: &Field) -> Self {
use vortex_dtype::DType::*;
use vortex_dtype::Signedness::*;

let nullability: Nullability = field.is_nullable().into();

if let Ok(ptype) = PType::try_from_arrow(field.data_type()) {
return Primitive(ptype, nullability);
}

match field.data_type() {
DataType::Null => Null,
DataType::Boolean => Bool(nullability),
DataType::Int8 => Int(IntWidth::_8, Signed, nullability),
DataType::Int16 => Int(IntWidth::_16, Signed, nullability),
DataType::Int32 => Int(IntWidth::_32, Signed, nullability),
DataType::Int64 => Int(IntWidth::_64, Signed, nullability),
DataType::UInt8 => Int(IntWidth::_8, Unsigned, nullability),
DataType::UInt16 => Int(IntWidth::_16, Unsigned, nullability),
DataType::UInt32 => Int(IntWidth::_32, Unsigned, nullability),
DataType::UInt64 => Int(IntWidth::_64, Unsigned, nullability),
DataType::Float16 => Float(FloatWidth::_16, nullability),
DataType::Float32 => Float(FloatWidth::_32, nullability),
DataType::Float64 => Float(FloatWidth::_64, nullability),
DataType::Utf8 | DataType::LargeUtf8 => Utf8(nullability),
DataType::Binary | DataType::LargeBinary => Binary(nullability),
DataType::Timestamp(_u, tz) => match tz {
Expand Down
6 changes: 3 additions & 3 deletions vortex-datetime-parts/src/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,13 @@ impl DateTimePartsArray<'_> {
subsecond: Array,
validity: Validity,
) -> VortexResult<Self> {
if !matches!(days.dtype(), DType::Int(_, _, _)) {
if !days.dtype().is_int() {
vortex_bail!(MismatchedTypes: "any integer", days.dtype());
}
if !matches!(seconds.dtype(), DType::Int(_, _, _)) {
if !seconds.dtype().is_int() {
vortex_bail!(MismatchedTypes: "any integer", seconds.dtype());
}
if !matches!(subsecond.dtype(), DType::Int(_, _, _)) {
if !subsecond.dtype().is_int() {
vortex_bail!(MismatchedTypes: "any integer", subsecond.dtype());
}

Expand Down
4 changes: 2 additions & 2 deletions vortex-dict/src/dict.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use vortex::validity::{ArrayValidity, LogicalValidity};
use vortex::visitor::{AcceptArrayVisitor, ArrayVisitor};
use vortex::IntoArrayData;
use vortex::{impl_encoding, ArrayDType, ArrayFlatten, ToArrayData};
use vortex_dtype::{match_each_integer_ptype, Signedness};
use vortex_dtype::match_each_integer_ptype;
use vortex_error::{vortex_bail, VortexResult};

impl_encoding!("vortex.dict", Dict);
Expand All @@ -19,7 +19,7 @@ pub struct DictMetadata {

impl DictArray<'_> {
pub fn try_new(codes: Array, values: Array) -> VortexResult<Self> {
if !matches!(codes.dtype(), DType::Int(_, Signedness::Unsigned, _)) {
if !codes.dtype().is_unsigned_int() {
vortex_bail!(MismatchedTypes: "unsigned int", codes.dtype());
}
Self::try_from_parts(
Expand Down
41 changes: 15 additions & 26 deletions vortex-dtype/flatbuffers/dtype.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,18 @@ enum Nullability: byte {
Nullable,
}

enum Signedness: byte {
Signed,
Unsigned,
}

enum IntWidth: byte {
_8,
_16,
_32,
_64,
}

enum FloatWidth: byte {
_16,
_32,
_64,
enum PType: uint8 {
U8,
U16,
U32,
U64,
I8,
I16,
I32,
I64,
F16,
F32,
F64,
}

table Null {}
Expand All @@ -29,9 +25,8 @@ table Bool {
nullability: Nullability;
}

table Int {
width: IntWidth;
signedness: Signedness;
table Primitive {
ptype: PType;
nullability: Nullability;
}

Expand All @@ -44,11 +39,6 @@ table Decimal {
nullability: Nullability;
}

table Float {
width: FloatWidth;
nullability: Nullability;
}

table Utf8 {
nullability: Nullability;
}
Expand All @@ -75,9 +65,8 @@ table Composite {
union Type {
Null,
Bool,
Int,
Primitive,
Decimal,
Float,
Utf8,
Binary,
Struct_,
Expand Down
Loading

0 comments on commit 9d955e6

Please sign in to comment.