Skip to content

Commit

Permalink
Nullable scalars (#152)
Browse files Browse the repository at this point in the history
Fixes #144
  • Loading branch information
gatesn authored Mar 26, 2024
1 parent 6c5fae7 commit 0001d8d
Show file tree
Hide file tree
Showing 15 changed files with 237 additions and 222 deletions.
11 changes: 6 additions & 5 deletions vortex-array/src/array/bool/compute/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,11 +87,12 @@ impl FlattenFn for BoolArray {

impl ScalarAtFn for BoolArray {
fn scalar_at(&self, index: usize) -> VortexResult<Scalar> {
if self.is_valid(index) {
Ok(self.buffer.value(index).into())
} else {
Ok(BoolScalar::new(None).into())
}
Ok(BoolScalar::try_new(
self.is_valid(index).then(|| self.buffer.value(index)),
self.nullability(),
)
.unwrap()
.into())
}
}

Expand Down
39 changes: 20 additions & 19 deletions vortex-array/src/array/constant/compute.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use itertools::Itertools;

use vortex_error::VortexResult;
use vortex_schema::Nullability;

use crate::array::bool::BoolArray;
use crate::array::constant::ConstantArray;
Expand All @@ -14,6 +15,7 @@ use crate::compute::take::TakeFn;
use crate::compute::ArrayCompute;
use crate::match_each_native_ptype;
use crate::scalar::Scalar;
use crate::validity::{ArrayValidity, Validity};

impl ArrayCompute for ConstantArray {
fn as_contiguous(&self) -> Option<&dyn AsContiguousFn> {
Expand Down Expand Up @@ -51,27 +53,26 @@ impl AsContiguousFn for ConstantArray {

impl FlattenFn for ConstantArray {
fn flatten(&self) -> VortexResult<FlattenedArray> {
let validity = match self.nullability() {
Nullability::NonNullable => None,
Nullability::Nullable => Some(match self.scalar().is_null() {
true => Validity::Invalid(self.len()),
false => Validity::Valid(self.len()),
}),
};

Ok(match self.scalar() {
Scalar::Bool(b) => {
if let Some(bv) = b.value() {
FlattenedArray::Bool(BoolArray::from(vec![bv; self.len()]))
} else {
FlattenedArray::Bool(BoolArray::null(self.len()))
}
}
Scalar::Bool(b) => FlattenedArray::Bool(BoolArray::from_nullable(
vec![b.value().copied().unwrap_or_default(); self.len()],
validity,
)),
Scalar::Primitive(p) => {
if let Some(ps) = p.value() {
match_each_native_ptype!(ps.ptype(), |$P| {
FlattenedArray::Primitive(PrimitiveArray::from_value::<$P>(
$P::try_from(self.scalar())?,
self.len(),
))
})
} else {
match_each_native_ptype!(p.ptype(), |$P| {
FlattenedArray::Primitive(PrimitiveArray::null::<$P>(self.len()))
})
}
match_each_native_ptype!(p.ptype(), |$P| {
FlattenedArray::Primitive(PrimitiveArray::from_nullable::<$P>(
vec![$P::try_from(self.scalar())?; self.len()],
validity,
))
})
}
_ => panic!("Unsupported scalar type {}", self.dtype()),
})
Expand Down
2 changes: 1 addition & 1 deletion vortex-array/src/array/constant/stats.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ impl StatsCompute for ConstantArray {
return Ok(StatsSet::from(
[(
Stat::TrueCount,
(self.len() as u64 * b.value().map(|v| v as u64).unwrap_or(0)).into(),
(self.len() as u64 * b.value().cloned().map(|v| v as u64).unwrap_or(0)).into(),
)]
.into(),
));
Expand Down
11 changes: 6 additions & 5 deletions vortex-array/src/array/primitive/compute/scalar_at.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@ use crate::validity::ArrayValidity;

impl ScalarAtFn for PrimitiveArray {
fn scalar_at(&self, index: usize) -> VortexResult<Scalar> {
if self.is_valid(index) {
Ok(match_each_native_ptype!(self.ptype, |$T| self.typed_data::<$T>()[index].into()))
} else {
Ok(PrimitiveScalar::none(self.ptype).into())
}
match_each_native_ptype!(self.ptype, |$T| {
Ok(PrimitiveScalar::try_new(
self.is_valid(index).then(|| self.typed_data::<$T>()[index]),
self.nullability(),
)?.into())
})
}
}
5 changes: 3 additions & 2 deletions vortex-array/src/array/varbin/compute/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -160,10 +160,11 @@ impl ScalarAtFn for VarBinArray {
bytes.into()
}
})
// FIXME(ngates): there's something weird about this.
} else if matches!(self.dtype, DType::Utf8(_)) {
Ok(Utf8Scalar::new(None).into())
Ok(Utf8Scalar::none().into())
} else {
Ok(BinaryScalar::new(None).into())
Ok(BinaryScalar::none().into())
}
}
}
36 changes: 12 additions & 24 deletions vortex-array/src/scalar/binary.rs
Original file line number Diff line number Diff line change
@@ -1,35 +1,21 @@
use std::fmt::{Display, Formatter};

use vortex_error::{VortexError, VortexResult};
use vortex_schema::{DType, Nullability};
use vortex_schema::DType;
use vortex_schema::Nullability::{NonNullable, Nullable};

use crate::scalar::value::ScalarValue;
use crate::scalar::Scalar;

#[derive(Debug, Clone, PartialEq, PartialOrd)]
pub struct BinaryScalar {
value: Option<Vec<u8>>,
}
pub type BinaryScalar = ScalarValue<Vec<u8>>;

impl BinaryScalar {
pub fn new(value: Option<Vec<u8>>) -> Self {
Self { value }
}

pub fn none() -> Self {
Self { value: None }
}

pub fn some(value: Vec<u8>) -> Self {
Self { value: Some(value) }
}

pub fn value(&self) -> Option<&[u8]> {
self.value.as_deref()
}

#[inline]
pub fn dtype(&self) -> &DType {
&DType::Binary(Nullability::NonNullable)
match self.nullability() {
NonNullable => &DType::Binary(NonNullable),
Nullable => &DType::Binary(Nullable),
}
}

pub fn cast(&self, _dtype: &DType) -> VortexResult<Scalar> {
Expand All @@ -43,7 +29,7 @@ impl BinaryScalar {

impl From<Vec<u8>> for Scalar {
fn from(value: Vec<u8>) -> Self {
BinaryScalar::new(Some(value)).into()
BinaryScalar::some(value).into()
}
}

Expand All @@ -55,7 +41,9 @@ impl TryFrom<Scalar> for Vec<u8> {
return Err(VortexError::InvalidDType(value.dtype().clone()));
};
let dtype = b.dtype().clone();
b.value.ok_or_else(|| VortexError::InvalidDType(dtype))
b.value()
.cloned()
.ok_or_else(|| VortexError::InvalidDType(dtype))
}
}

Expand Down
31 changes: 8 additions & 23 deletions vortex-array/src/scalar/bool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,33 +3,18 @@ use std::fmt::{Display, Formatter};
use vortex_error::{VortexError, VortexResult};
use vortex_schema::{DType, Nullability};

use crate::scalar::value::ScalarValue;
use crate::scalar::Scalar;

#[derive(Debug, Clone, PartialEq, PartialOrd)]
pub struct BoolScalar {
value: Option<bool>,
}
pub type BoolScalar = ScalarValue<bool>;

impl BoolScalar {
pub fn new(value: Option<bool>) -> Self {
Self { value }
}

pub fn none() -> Self {
Self { value: None }
}

pub fn some(value: bool) -> Self {
Self { value: Some(value) }
}

pub fn value(&self) -> Option<bool> {
self.value
}

#[inline]
pub fn dtype(&self) -> &DType {
&DType::Bool(Nullability::NonNullable)
match self.nullability() {
Nullability::NonNullable => &DType::Bool(Nullability::NonNullable),
Nullability::Nullable => &DType::Bool(Nullability::Nullable),
}
}

pub fn cast(&self, dtype: &DType) -> VortexResult<Scalar> {
Expand All @@ -47,7 +32,7 @@ impl BoolScalar {
impl From<bool> for Scalar {
#[inline]
fn from(value: bool) -> Self {
BoolScalar::new(Some(value)).into()
BoolScalar::some(value).into()
}
}

Expand All @@ -58,8 +43,8 @@ impl TryFrom<Scalar> for bool {
let Scalar::Bool(b) = value else {
return Err(VortexError::InvalidDType(value.dtype().clone()));
};

b.value()
.cloned()
.ok_or_else(|| VortexError::InvalidDType(b.dtype().clone()))
}
}
Expand Down
47 changes: 26 additions & 21 deletions vortex-array/src/scalar/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use half::f16;
use std::fmt::{Debug, Display, Formatter};

pub use binary::*;
Expand All @@ -10,9 +11,9 @@ pub use serde::*;
pub use struct_::*;
pub use utf8::*;
use vortex_error::VortexResult;
use vortex_schema::{DType, FloatWidth, IntWidth, Signedness};
use vortex_schema::{DType, FloatWidth, IntWidth, Nullability, Signedness};

use crate::ptype::{NativePType, PType};
use crate::ptype::NativePType;

mod binary;
mod bool;
Expand All @@ -23,6 +24,7 @@ mod primitive;
mod serde;
mod struct_;
mod utf8;
mod value;

#[derive(Debug, Clone, PartialEq, PartialOrd)]
pub enum Scalar {
Expand Down Expand Up @@ -84,6 +86,10 @@ impl Scalar {
match_each_scalar! { self, |$s| $s.nbytes() }
}

pub fn nullability(&self) -> Nullability {
self.dtype().nullability()
}

pub fn is_null(&self) -> bool {
match self {
Scalar::Binary(b) => b.value().is_none(),
Expand All @@ -99,42 +105,41 @@ impl Scalar {
}

pub fn null(dtype: &DType) -> Self {
assert!(dtype.is_nullable());
match dtype {
DType::Null => NullScalar::new().into(),
DType::Bool(_) => BoolScalar::new(None).into(),
DType::Bool(_) => BoolScalar::none().into(),
DType::Int(w, s, _) => match (w, s) {
(IntWidth::Unknown, Signedness::Unknown | Signedness::Signed) => {
PrimitiveScalar::none(PType::I64).into()
PrimitiveScalar::none::<i64>().into()
}
(IntWidth::_8, Signedness::Unknown | Signedness::Signed) => {
PrimitiveScalar::none(PType::I8).into()
PrimitiveScalar::none::<i8>().into()
}
(IntWidth::_16, Signedness::Unknown | Signedness::Signed) => {
PrimitiveScalar::none(PType::I16).into()
PrimitiveScalar::none::<i16>().into()
}
(IntWidth::_32, Signedness::Unknown | Signedness::Signed) => {
PrimitiveScalar::none(PType::I32).into()
PrimitiveScalar::none::<i32>().into()
}
(IntWidth::_64, Signedness::Unknown | Signedness::Signed) => {
PrimitiveScalar::none(PType::I64).into()
}
(IntWidth::Unknown, Signedness::Unsigned) => {
PrimitiveScalar::none(PType::U64).into()
PrimitiveScalar::none::<i64>().into()
}
(IntWidth::_8, Signedness::Unsigned) => PrimitiveScalar::none(PType::U8).into(),
(IntWidth::_16, Signedness::Unsigned) => PrimitiveScalar::none(PType::U16).into(),
(IntWidth::_32, Signedness::Unsigned) => PrimitiveScalar::none(PType::U32).into(),
(IntWidth::_64, Signedness::Unsigned) => PrimitiveScalar::none(PType::U64).into(),
(IntWidth::Unknown, Signedness::Unsigned) => PrimitiveScalar::none::<u64>().into(),
(IntWidth::_8, Signedness::Unsigned) => PrimitiveScalar::none::<u8>().into(),
(IntWidth::_16, Signedness::Unsigned) => PrimitiveScalar::none::<u16>().into(),
(IntWidth::_32, Signedness::Unsigned) => PrimitiveScalar::none::<u32>().into(),
(IntWidth::_64, Signedness::Unsigned) => PrimitiveScalar::none::<u64>().into(),
},
DType::Decimal(_, _, _) => unimplemented!("DecimalScalar"),
DType::Float(w, _) => match w {
FloatWidth::Unknown => PrimitiveScalar::none(PType::F64).into(),
FloatWidth::_16 => PrimitiveScalar::none(PType::F16).into(),
FloatWidth::_32 => PrimitiveScalar::none(PType::F32).into(),
FloatWidth::_64 => PrimitiveScalar::none(PType::F64).into(),
FloatWidth::Unknown => PrimitiveScalar::none::<f64>().into(),
FloatWidth::_16 => PrimitiveScalar::none::<f16>().into(),
FloatWidth::_32 => PrimitiveScalar::none::<f32>().into(),
FloatWidth::_64 => PrimitiveScalar::none::<f64>().into(),
},
DType::Utf8(_) => Utf8Scalar::new(None).into(),
DType::Binary(_) => BinaryScalar::new(None).into(),
DType::Utf8(_) => Utf8Scalar::none().into(),
DType::Binary(_) => BinaryScalar::none().into(),
DType::Struct(_, _) => StructScalar::new(dtype.clone(), vec![]).into(),
DType::List(_, _) => ListScalar::new(dtype.clone(), None).into(),
DType::Composite(_, _) => unimplemented!("CompositeScalar"),
Expand Down
Loading

0 comments on commit 0001d8d

Please sign in to comment.