Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor for DType::Primitive #276

Merged
merged 4 commits into from
Apr 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 37 additions & 14 deletions pyvortex/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
use dtype::PyDType;
use log::debug;
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use vortex::encoding::VORTEX_ENCODINGS;
use vortex_dtype::DType;
use vortex_dtype::Signedness::{Signed, Unsigned};
use vortex_dtype::{DType, PType};

use crate::array::*;

Expand Down Expand Up @@ -69,28 +69,51 @@ fn dtype_bool(py: Python<'_>, nullable: bool) -> PyResult<Py<PyDType>> {
#[pyfunction(name = "int")]
#[pyo3(signature = (width = None, nullable = false))]
fn dtype_int(py: Python<'_>, width: Option<u16>, nullable: bool) -> PyResult<Py<PyDType>> {
PyDType::wrap(
py,
DType::Int(width.unwrap_or(64).into(), Signed, nullable.into()),
)
let dtype = if let Some(width) = width {
match width {
8 => DType::Primitive(PType::I8, nullable.into()),
16 => DType::Primitive(PType::I16, nullable.into()),
32 => DType::Primitive(PType::I32, nullable.into()),
64 => DType::Primitive(PType::I64, nullable.into()),
_ => return Err(PyValueError::new_err("Invalid int width")),
}
} else {
DType::Primitive(PType::I64, nullable.into())
};
PyDType::wrap(py, dtype)
}

#[pyfunction(name = "uint")]
#[pyo3(signature = (width = None, nullable = false))]
fn dtype_uint(py: Python<'_>, width: Option<u16>, nullable: bool) -> PyResult<Py<PyDType>> {
PyDType::wrap(
py,
DType::Int(width.unwrap_or(64).into(), Unsigned, nullable.into()),
)
let dtype = if let Some(width) = width {
match width {
8 => DType::Primitive(PType::U8, nullable.into()),
16 => DType::Primitive(PType::U16, nullable.into()),
32 => DType::Primitive(PType::U32, nullable.into()),
64 => DType::Primitive(PType::U64, nullable.into()),
_ => return Err(PyValueError::new_err("Invalid uint width")),
}
} else {
DType::Primitive(PType::U64, nullable.into())
};
PyDType::wrap(py, dtype)
}

#[pyfunction(name = "float")]
#[pyo3(signature = (width = None, nullable = false))]
fn dtype_float(py: Python<'_>, width: Option<i8>, nullable: bool) -> PyResult<Py<PyDType>> {
PyDType::wrap(
py,
DType::Float(width.unwrap_or(64).into(), nullable.into()),
)
let dtype = if let Some(width) = width {
match width {
16 => DType::Primitive(PType::F16, nullable.into()),
32 => DType::Primitive(PType::F32, nullable.into()),
64 => DType::Primitive(PType::F64, nullable.into()),
_ => return Err(PyValueError::new_err("Invalid float width")),
}
} else {
DType::Primitive(PType::F64, nullable.into())
};
PyDType::wrap(py, dtype)
}

#[pyfunction(name = "utf8")]
Expand Down
10 changes: 5 additions & 5 deletions pyvortex/test/test_dtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@


def test_int():
assert str(vortex.int()) == "int(64)"
assert str(vortex.int(32)) == "int(32)"
assert str(vortex.int(32, nullable=True)) == "int(32)?"
assert str(vortex.uint(32)) == "uint(32)"
assert str(vortex.float(16)) == "float(16)"
assert str(vortex.int()) == "i64"
assert str(vortex.int(32)) == "i32"
assert str(vortex.int(32, nullable=True)) == "i32?"
assert str(vortex.uint(32)) == "u32"
assert str(vortex.float(16)) == "f16"
assert str(vortex.bool(nullable=True)) == "bool?"
10 changes: 3 additions & 7 deletions vortex-alp/src/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use vortex::stats::ArrayStatisticsCompute;
use vortex::validity::{ArrayValidity, LogicalValidity};
use vortex::visitor::{AcceptArrayVisitor, ArrayVisitor};
use vortex::{impl_encoding, ArrayDType, ArrayFlatten, IntoArrayData, OwnedArray, ToArrayData};
use vortex_dtype::{IntWidth, Signedness};
use vortex_dtype::PType;
use vortex_error::{vortex_bail, VortexResult};

use crate::alp::Exponents;
Expand All @@ -27,12 +27,8 @@ impl ALPArray<'_> {
) -> VortexResult<Self> {
let encoded_dtype = encoded.dtype().clone();
let dtype = match encoded.dtype() {
DType::Int(IntWidth::_32, Signedness::Signed, nullability) => {
DType::Float(32.into(), *nullability)
}
DType::Int(IntWidth::_64, Signedness::Signed, nullability) => {
DType::Float(64.into(), *nullability)
}
DType::Primitive(PType::I32, nullability) => DType::Primitive(PType::F32, *nullability),
DType::Primitive(PType::I64, nullability) => DType::Primitive(PType::F64, *nullability),
d => vortex_bail!(MismatchedTypes: "int32 or int64", d),
};

Expand Down
18 changes: 5 additions & 13 deletions vortex-array/src/array/chunked/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use itertools::Itertools;
use serde::{Deserialize, Serialize};
use vortex_dtype::{IntWidth, Nullability, Signedness};
use vortex_dtype::{Nullability, PType};
use vortex_error::{vortex_bail, VortexResult};

use crate::array::primitive::PrimitiveArray;
Expand All @@ -20,11 +20,7 @@ impl_encoding!("vortex.chunked", Chunked);
pub struct ChunkedMetadata;

impl ChunkedArray<'_> {
const ENDS_DTYPE: DType = DType::Int(
IntWidth::_64,
Signedness::Unsigned,
Nullability::NonNullable,
);
const ENDS_DTYPE: DType = DType::Primitive(PType::U64, Nullability::NonNullable);

pub fn try_new(chunks: Vec<Array>, dtype: DType) -> VortexResult<Self> {
for chunk in &chunks {
Expand Down Expand Up @@ -145,8 +141,8 @@ impl EncodingCompression for ChunkedEncoding {}

#[cfg(test)]
mod test {
use vortex_dtype::NativePType;
use vortex_dtype::{DType, IntWidth, Nullability, Signedness};
use vortex_dtype::{DType, Nullability};
use vortex_dtype::{NativePType, PType};

use crate::array::chunked::{ChunkedArray, OwnedChunkedArray};
use crate::{Array, IntoArray};
Expand All @@ -159,11 +155,7 @@ mod test {
vec![4u64, 5, 6].into_array(),
vec![7u64, 8, 9].into_array(),
],
DType::Int(
IntWidth::_64,
Signedness::Unsigned,
Nullability::NonNullable,
),
DType::Primitive(PType::U64, Nullability::NonNullable),
)
.unwrap()
}
Expand Down
4 changes: 2 additions & 2 deletions vortex-array/src/array/sparse/compute/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ fn take_search_sorted(
#[cfg(test)]
mod test {
use itertools::Itertools;
use vortex_dtype::{DType, FloatWidth, Nullability};
use vortex_dtype::{DType, Nullability, PType};
use vortex_scalar::Scalar;

use crate::array::primitive::PrimitiveArray;
Expand All @@ -156,7 +156,7 @@ mod test {
PrimitiveArray::from_vec(vec![1.23f64, 0.47, 9.99, 3.5], Validity::AllValid)
.into_array(),
100,
Scalar::null(&DType::Float(FloatWidth::_64, Nullability::Nullable)),
Scalar::null(&DType::Primitive(PType::F64, Nullability::Nullable)),
)
.into_array()
}
Expand Down
5 changes: 2 additions & 3 deletions vortex-array/src/array/sparse/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -174,8 +174,7 @@ impl ArrayValidity for SparseArray<'_> {
mod test {
use itertools::Itertools;
use vortex_dtype::Nullability::Nullable;
use vortex_dtype::Signedness::Signed;
use vortex_dtype::{DType, IntWidth};
use vortex_dtype::{DType, PType};
use vortex_error::VortexError;
use vortex_scalar::Scalar;

Expand All @@ -187,7 +186,7 @@ mod test {
use crate::{Array, IntoArray, OwnedArray};

fn nullable_fill() -> Scalar {
Scalar::null(&DType::Int(IntWidth::_32, Signed, Nullable))
Scalar::null(&DType::Primitive(PType::I32, Nullable))
}

#[allow(dead_code)]
Expand Down
9 changes: 3 additions & 6 deletions vortex-array/src/array/varbin/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use num_traits::AsPrimitive;
use serde::{Deserialize, Serialize};
use vortex_dtype::Nullability;
use vortex_dtype::{match_each_native_ptype, NativePType};
use vortex_dtype::{IntWidth, Nullability, Signedness};
use vortex_error::{vortex_bail, VortexResult};
use vortex_scalar::{BinaryScalar, Utf8Scalar};

Expand Down Expand Up @@ -37,13 +37,10 @@ impl VarBinArray<'_> {
dtype: DType,
validity: Validity,
) -> VortexResult<Self> {
if !matches!(offsets.dtype(), DType::Int(_, _, Nullability::NonNullable)) {
if !offsets.dtype().is_int() || offsets.dtype().is_nullable() {
vortex_bail!(MismatchedTypes: "non nullable int", offsets.dtype());
}
if !matches!(
bytes.dtype(),
DType::Int(IntWidth::_8, Signedness::Unsigned, Nullability::NonNullable)
) {
if !matches!(bytes.dtype(), &DType::BYTES,) {
vortex_bail!(MismatchedTypes: "u8", bytes.dtype());
}
if !matches!(dtype, DType::Binary(_) | DType::Utf8(_)) {
Expand Down
12 changes: 3 additions & 9 deletions vortex-array/src/array/varbinview/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::fmt::Formatter;
use std::{mem, slice};

use ::serde::{Deserialize, Serialize};
use vortex_dtype::{IntWidth, Nullability, Signedness};
use vortex_dtype::Nullability;
use vortex_error::{vortex_bail, VortexResult};

use crate::array::primitive::PrimitiveArray;
Expand Down Expand Up @@ -111,18 +111,12 @@ impl VarBinViewArray<'_> {
dtype: DType,
validity: Validity,
) -> VortexResult<Self> {
if !matches!(
views.dtype(),
DType::Int(IntWidth::_8, Signedness::Unsigned, Nullability::NonNullable)
) {
if !matches!(views.dtype(), &DType::BYTES) {
vortex_bail!(MismatchedTypes: "u8", views.dtype());
}

for d in data.iter() {
if !matches!(
d.dtype(),
DType::Int(IntWidth::_8, Signedness::Unsigned, Nullability::NonNullable)
) {
if !matches!(d.dtype(), &DType::BYTES) {
vortex_bail!(MismatchedTypes: "u8", d.dtype());
}
}
Expand Down
18 changes: 5 additions & 13 deletions vortex-array/src/arrow/dtype.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use arrow_schema::TimeUnit as ArrowTimeUnit;
use arrow_schema::{DataType, Field, SchemaRef};
use itertools::Itertools;
use vortex_dtype::PType;
use vortex_dtype::{DType, FloatWidth, IntWidth, Nullability};
use vortex_dtype::{DType, Nullability};
use vortex_error::{vortex_err, VortexResult};

use crate::array::datetime::{LocalDateTimeExtension, TimeUnit};
Expand Down Expand Up @@ -58,24 +58,16 @@ impl FromArrowType<SchemaRef> for DType {
impl FromArrowType<&Field> for DType {
fn from_arrow(field: &Field) -> Self {
use vortex_dtype::DType::*;
use vortex_dtype::Signedness::*;

let nullability: Nullability = field.is_nullable().into();

if let Ok(ptype) = PType::try_from_arrow(field.data_type()) {
return Primitive(ptype, nullability);
}

match field.data_type() {
DataType::Null => Null,
DataType::Boolean => Bool(nullability),
DataType::Int8 => Int(IntWidth::_8, Signed, nullability),
DataType::Int16 => Int(IntWidth::_16, Signed, nullability),
DataType::Int32 => Int(IntWidth::_32, Signed, nullability),
DataType::Int64 => Int(IntWidth::_64, Signed, nullability),
DataType::UInt8 => Int(IntWidth::_8, Unsigned, nullability),
DataType::UInt16 => Int(IntWidth::_16, Unsigned, nullability),
DataType::UInt32 => Int(IntWidth::_32, Unsigned, nullability),
DataType::UInt64 => Int(IntWidth::_64, Unsigned, nullability),
DataType::Float16 => Float(FloatWidth::_16, nullability),
DataType::Float32 => Float(FloatWidth::_32, nullability),
DataType::Float64 => Float(FloatWidth::_64, nullability),
DataType::Utf8 | DataType::LargeUtf8 => Utf8(nullability),
DataType::Binary | DataType::LargeBinary => Binary(nullability),
DataType::Timestamp(_u, tz) => match tz {
Expand Down
6 changes: 3 additions & 3 deletions vortex-datetime-parts/src/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,13 @@ impl DateTimePartsArray<'_> {
subsecond: Array,
validity: Validity,
) -> VortexResult<Self> {
if !matches!(days.dtype(), DType::Int(_, _, _)) {
if !days.dtype().is_int() {
vortex_bail!(MismatchedTypes: "any integer", days.dtype());
}
if !matches!(seconds.dtype(), DType::Int(_, _, _)) {
if !seconds.dtype().is_int() {
vortex_bail!(MismatchedTypes: "any integer", seconds.dtype());
}
if !matches!(subsecond.dtype(), DType::Int(_, _, _)) {
if !subsecond.dtype().is_int() {
vortex_bail!(MismatchedTypes: "any integer", subsecond.dtype());
}

Expand Down
4 changes: 2 additions & 2 deletions vortex-dict/src/dict.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use vortex::validity::{ArrayValidity, LogicalValidity};
use vortex::visitor::{AcceptArrayVisitor, ArrayVisitor};
use vortex::IntoArrayData;
use vortex::{impl_encoding, ArrayDType, ArrayFlatten, ToArrayData};
use vortex_dtype::{match_each_integer_ptype, Signedness};
use vortex_dtype::match_each_integer_ptype;
use vortex_error::{vortex_bail, VortexResult};

impl_encoding!("vortex.dict", Dict);
Expand All @@ -19,7 +19,7 @@ pub struct DictMetadata {

impl DictArray<'_> {
pub fn try_new(codes: Array, values: Array) -> VortexResult<Self> {
if !matches!(codes.dtype(), DType::Int(_, Signedness::Unsigned, _)) {
if !codes.dtype().is_unsigned_int() {
vortex_bail!(MismatchedTypes: "unsigned int", codes.dtype());
}
Self::try_from_parts(
Expand Down
41 changes: 15 additions & 26 deletions vortex-dtype/flatbuffers/dtype.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,18 @@ enum Nullability: byte {
Nullable,
}

enum Signedness: byte {
Signed,
Unsigned,
}

enum IntWidth: byte {
_8,
_16,
_32,
_64,
}

enum FloatWidth: byte {
_16,
_32,
_64,
enum PType: uint8 {
U8,
U16,
U32,
U64,
I8,
I16,
I32,
I64,
F16,
F32,
F64,
}

table Null {}
Expand All @@ -29,9 +25,8 @@ table Bool {
nullability: Nullability;
}

table Int {
width: IntWidth;
signedness: Signedness;
table Primitive {
ptype: PType;
nullability: Nullability;
}

Expand All @@ -44,11 +39,6 @@ table Decimal {
nullability: Nullability;
}

table Float {
width: FloatWidth;
nullability: Nullability;
}

table Utf8 {
nullability: Nullability;
}
Expand All @@ -75,9 +65,8 @@ table Composite {
union Type {
Null,
Bool,
Int,
Primitive,
Decimal,
Float,
Utf8,
Binary,
Struct_,
Expand Down
Loading