From 292481e5eb1b35b1cf021ce7b1be7185520370ec Mon Sep 17 00:00:00 2001 From: Josh Casale Date: Mon, 29 Apr 2024 19:24:04 -0400 Subject: [PATCH] handle nulls --- vortex-array/src/array/primitive/mod.rs | 36 ++++++++++++++++++++++--- 1 file changed, 32 insertions(+), 4 deletions(-) diff --git a/vortex-array/src/array/primitive/mod.rs b/vortex-array/src/array/primitive/mod.rs index 00409f2f12..77eebdd0ea 100644 --- a/vortex-array/src/array/primitive/mod.rs +++ b/vortex-array/src/array/primitive/mod.rs @@ -6,12 +6,12 @@ use vortex_error::{vortex_bail, VortexResult}; use crate::buffer::Buffer; use crate::compute::scalar_subtract::ScalarSubtractFn; -use crate::match_each_integer_ptype; use crate::ptype::{NativePType, PType}; use crate::stats::ArrayStatistics; use crate::validity::{ArrayValidity, LogicalValidity, Validity, ValidityMetadata}; use crate::visitor::{AcceptArrayVisitor, ArrayVisitor}; use crate::{impl_encoding, match_each_float_ptype, ArrayDType, OwnedArray}; +use crate::{match_each_integer_ptype, ToStatic}; use crate::{match_each_native_ptype, ArrayFlatten}; mod accessor; @@ -197,10 +197,18 @@ impl EncodingCompression for PrimitiveEncoding {} impl ScalarSubtractFn for PrimitiveArray<'_> { fn scalar_subtract(&self, to_subtract: &Scalar) -> VortexResult { - if self.dtype() != to_subtract.dtype() { + if !self.dtype().eq_ignore_nullability(to_subtract.dtype()) { vortex_bail!(MismatchedTypes: self.dtype(), to_subtract.dtype()) } + let should_wrap = match self.validity() { + Validity::AllInvalid => { + return Ok(self.clone().into_array().to_static()); + } + Validity::NonNullable | Validity::AllValid => false, + Validity::Array(_) => true, + }; + let result = match to_subtract.dtype() { DType::Int(..) => { match_each_integer_ptype!(self.ptype(), |$T| { @@ -219,7 +227,14 @@ impl ScalarSubtractFn for PrimitiveArray<'_> { } } } - let sub_vec : Vec<$T> = self.typed_data::<$T>().iter().map(|&v| v - to_subtract).collect_vec(); + + let sub_vec : Vec<$T> = if should_wrap { + self.typed_data::<$T>().iter().map(|&v| $T::checked_sub(v, to_subtract)) + .map(|v| v.unwrap_or_default()) + .collect_vec() + } else { + self.typed_data::<$T>().iter().map(|&v| v - to_subtract).collect_vec() + }; PrimitiveArray::from(sub_vec) }) } @@ -232,13 +247,13 @@ impl ScalarSubtractFn for PrimitiveArray<'_> { } _ => vortex_bail!(InvalidArgument: "Can only subtract numeric types"), }; - Ok(result.into_array()) } } #[cfg(test)] mod test { + use crate::array::primitive::PrimitiveArray; use crate::compute::scalar_subtract::scalar_subtract; use crate::IntoArray; @@ -266,6 +281,19 @@ mod test { assert_eq!(results, &[2i64, 3, 4]); } + #[test] + fn test_scalar_subtract_nullable() { + let values = PrimitiveArray::from_nullable_vec(vec![Some(1u16), Some(2), None, Some(3)]) + .into_array(); + let results = scalar_subtract(&values, 1u16) + .unwrap() + .flatten_primitive() + .unwrap() + .typed_data::() + .to_vec(); + assert_eq!(results, &[0u16, 1, 0, 2]); + } + #[test] fn test_scalar_subtract_float() { let values = vec![1.0f64, 2.0, 3.0].into_array();