Skip to content

Commit

Permalink
handle nulls
Browse files Browse the repository at this point in the history
  • Loading branch information
jdcasale committed Apr 29, 2024
1 parent 70af4fe commit 292481e
Showing 1 changed file with 32 additions and 4 deletions.
36 changes: 32 additions & 4 deletions vortex-array/src/array/primitive/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@ use vortex_error::{vortex_bail, VortexResult};

use crate::buffer::Buffer;
use crate::compute::scalar_subtract::ScalarSubtractFn;
use crate::match_each_integer_ptype;
use crate::ptype::{NativePType, PType};
use crate::stats::ArrayStatistics;
use crate::validity::{ArrayValidity, LogicalValidity, Validity, ValidityMetadata};
use crate::visitor::{AcceptArrayVisitor, ArrayVisitor};
use crate::{impl_encoding, match_each_float_ptype, ArrayDType, OwnedArray};
use crate::{match_each_integer_ptype, ToStatic};
use crate::{match_each_native_ptype, ArrayFlatten};

mod accessor;
Expand Down Expand Up @@ -197,10 +197,18 @@ impl EncodingCompression for PrimitiveEncoding {}

impl ScalarSubtractFn for PrimitiveArray<'_> {
fn scalar_subtract(&self, to_subtract: &Scalar) -> VortexResult<OwnedArray> {
if self.dtype() != to_subtract.dtype() {
if !self.dtype().eq_ignore_nullability(to_subtract.dtype()) {
vortex_bail!(MismatchedTypes: self.dtype(), to_subtract.dtype())
}

let should_wrap = match self.validity() {
Validity::AllInvalid => {
return Ok(self.clone().into_array().to_static());
}
Validity::NonNullable | Validity::AllValid => false,
Validity::Array(_) => true,
};

let result = match to_subtract.dtype() {
DType::Int(..) => {
match_each_integer_ptype!(self.ptype(), |$T| {
Expand All @@ -219,7 +227,14 @@ impl ScalarSubtractFn for PrimitiveArray<'_> {
}
}
}
let sub_vec : Vec<$T> = self.typed_data::<$T>().iter().map(|&v| v - to_subtract).collect_vec();

let sub_vec : Vec<$T> = if should_wrap {
self.typed_data::<$T>().iter().map(|&v| $T::checked_sub(v, to_subtract))
.map(|v| v.unwrap_or_default())
.collect_vec()
} else {
self.typed_data::<$T>().iter().map(|&v| v - to_subtract).collect_vec()
};
PrimitiveArray::from(sub_vec)
})
}
Expand All @@ -232,13 +247,13 @@ impl ScalarSubtractFn for PrimitiveArray<'_> {
}
_ => vortex_bail!(InvalidArgument: "Can only subtract numeric types"),
};

Ok(result.into_array())
}
}

#[cfg(test)]
mod test {
use crate::array::primitive::PrimitiveArray;
use crate::compute::scalar_subtract::scalar_subtract;
use crate::IntoArray;

Expand Down Expand Up @@ -266,6 +281,19 @@ mod test {
assert_eq!(results, &[2i64, 3, 4]);
}

#[test]
fn test_scalar_subtract_nullable() {
let values = PrimitiveArray::from_nullable_vec(vec![Some(1u16), Some(2), None, Some(3)])
.into_array();
let results = scalar_subtract(&values, 1u16)
.unwrap()
.flatten_primitive()
.unwrap()
.typed_data::<u16>()
.to_vec();
assert_eq!(results, &[0u16, 1, 0, 2]);
}

#[test]
fn test_scalar_subtract_float() {
let values = vec![1.0f64, 2.0, 3.0].into_array();
Expand Down

0 comments on commit 292481e

Please sign in to comment.