From 530c2c0f5e08c66fe9a8cde9a85167e22ebe746c Mon Sep 17 00:00:00 2001 From: Daniel King Date: Wed, 18 Dec 2024 17:56:03 -0500 Subject: [PATCH] 1. `checked_numeric_operator` returns a PrimitiveScalar (avoiding a DType allocation) 2. `checked_sub` delegates to `checked_numeric_operator` As a consequence of (2), `checked_sub` now returns an `None` when there is an underflow/overflow rather than returning a Null PrimitiveScalar (the return type is now `VortexResult>>`). Moreover, `std::ops::Sub` now returns an error on underflow/overflow rather than a Null PrimitiveScalar. --- .../array/constant/compute/binary_numeric.rs | 3 +- .../array/sparse/compute/binary_numeric.rs | 3 +- vortex-scalar/src/primitive.rs | 152 ++++++++++-------- 3 files changed, 90 insertions(+), 68 deletions(-) diff --git a/vortex-array/src/array/constant/compute/binary_numeric.rs b/vortex-array/src/array/constant/compute/binary_numeric.rs index b7994b216..10cc6dadf 100644 --- a/vortex-array/src/array/constant/compute/binary_numeric.rs +++ b/vortex-array/src/array/constant/compute/binary_numeric.rs @@ -22,7 +22,8 @@ impl BinaryNumericFn for ConstantEncoding { .scalar() .as_primitive() .checked_numeric_operator(rhs.as_primitive(), op)? - .ok_or_else(|| vortex_err!("numeric overflow"))?, + .ok_or_else(|| vortex_err!("numeric overflow"))? + .to_scalar(), array.len(), ) .into_array(), diff --git a/vortex-array/src/array/sparse/compute/binary_numeric.rs b/vortex-array/src/array/sparse/compute/binary_numeric.rs index 50d1e686e..b58225728 100644 --- a/vortex-array/src/array/sparse/compute/binary_numeric.rs +++ b/vortex-array/src/array/sparse/compute/binary_numeric.rs @@ -23,7 +23,8 @@ impl BinaryNumericFn for SparseEncoding { .fill_scalar() .as_primitive() .checked_numeric_operator(rhs_scalar.as_primitive(), op)? - .ok_or_else(|| vortex_err!("numeric overflow"))?; + .ok_or_else(|| vortex_err!("numeric overflow"))? + .to_scalar(); SparseArray::try_new_from_patches( new_patches, array.len(), diff --git a/vortex-scalar/src/primitive.rs b/vortex-scalar/src/primitive.rs index 5a79ff0e1..f14fc8e94 100644 --- a/vortex-scalar/src/primitive.rs +++ b/vortex-scalar/src/primitive.rs @@ -1,6 +1,6 @@ use std::any::type_name; -use num_traits::{FromPrimitive, NumCast}; +use num_traits::{CheckedAdd, CheckedDiv, CheckedMul, CheckedSub, FromPrimitive, NumCast}; use vortex_dtype::half::f16; use vortex_dtype::{match_each_native_ptype, DType, NativePType, Nullability, PType}; use vortex_error::{ @@ -11,7 +11,7 @@ use crate::pvalue::PValue; use crate::value::ScalarValue; use crate::{InnerScalarValue, Scalar}; -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Copy)] pub struct PrimitiveScalar<'a> { dtype: &'a DType, ptype: PType, @@ -106,34 +106,17 @@ impl<'a> PrimitiveScalar<'a> { } } - pub fn checked_sub(&self, other: &PrimitiveScalar) -> VortexResult { - if self.ptype != other.ptype { - vortex_bail!("Failed to subtract {} and {}", self.ptype, other.ptype) - } - let result_pvalue: Option = match_each_native_ptype!( - self.ptype, - integral: |$T| { - let lhs = self.as_::<$T>()?; - let rhs = other.as_::<$T>()?; - match (lhs, rhs) { - (Some(lv), Some(rv)) => lv.checked_sub(rv).map(PValue::from), - _ => None - } - } - floating_point: |$T| { - let lhs = self.as_::<$T>()?; - let rhs = other.as_::<$T>()?; - match (lhs, rhs) { - (Some(lv), Some(rv)) => Some((lv - rv).into()), - _ => None - } - } - ); - Ok(Self { - dtype: self.dtype, - ptype: self.ptype, - pvalue: result_pvalue, - }) + pub fn checked_sub(self, other: PrimitiveScalar<'a>) -> VortexResult> { + self.checked_numeric_operator(other, BinaryNumericOperator::Sub) + } + + pub fn to_scalar(self) -> Scalar { + Scalar::new( + self.dtype().clone(), + self.pvalue + .map(|pvalue| ScalarValue(InnerScalarValue::Primitive(pvalue))) + .unwrap_or_else(|| ScalarValue(InnerScalarValue::Null)), + ) } } @@ -190,7 +173,8 @@ impl std::ops::Sub for PrimitiveScalar<'_> { type Output = VortexResult; fn sub(self, rhs: Self) -> Self::Output { - self.checked_sub(&rhs) + self.checked_sub(rhs)? + .ok_or_else(|| vortex_err!("PrimitiveScalar subtract: overflow or underflow")) } } @@ -323,7 +307,7 @@ pub enum BinaryNumericOperator { // Pow, } -impl PrimitiveScalar<'_> { +impl<'a> PrimitiveScalar<'a> { /// Apply the (checked) operator to self and other using SQL-style null semantics. /// /// If the operation overflows, Ok(None) is returned. @@ -333,55 +317,82 @@ impl PrimitiveScalar<'_> { /// If either value is null, the result is null. pub fn checked_numeric_operator( self, - other: PrimitiveScalar<'_>, + other: PrimitiveScalar<'a>, op: BinaryNumericOperator, - ) -> VortexResult> { + ) -> VortexResult>> { if !self.dtype().eq_ignore_nullability(other.dtype()) { vortex_bail!("types must match: {} {}", self.dtype(), other.dtype()); } - - let nullability = - Nullability::from(self.dtype().is_nullable() || other.dtype().is_nullable()); + let result_dtype = if self.dtype().is_nullable() { + self.dtype() + } else { + other.dtype() + }; + let ptype = self.ptype(); Ok(match_each_native_ptype!( self.ptype(), integral: |$P| { - let lhs = self.typed_value::<$P>(); - let rhs = other.typed_value::<$P>(); - match (lhs, rhs) { - (_, None) | (None, _) => Some(Scalar::null(self.dtype().with_nullability(nullability))), - (Some(lhs), Some(rhs)) => match op { - BinaryNumericOperator::Add => - lhs.checked_add(rhs).map(|result| Scalar::primitive(result, nullability)), - BinaryNumericOperator::Sub => - lhs.checked_sub(rhs).map(|result| Scalar::primitive(result, nullability)), - BinaryNumericOperator::Mul => - lhs.checked_mul(rhs).map(|result| Scalar::primitive(result, nullability)), - BinaryNumericOperator::Div => - lhs.checked_div(rhs).map(|result| Scalar::primitive(result, nullability)), - } - } + self.checked_integeral_numeric_operator::<$P>(other, result_dtype, ptype, op) } floating_point: |$P| { let lhs = self.typed_value::<$P>(); let rhs = other.typed_value::<$P>(); - Some(match (lhs, rhs) { - (_, None) | (None, _) => Scalar::null(self.dtype().with_nullability(nullability)), - (Some(lhs), Some(rhs)) => match op { - BinaryNumericOperator::Add => Scalar::primitive(lhs + rhs, nullability), - BinaryNumericOperator::Sub => Scalar::primitive(lhs - rhs, nullability), - BinaryNumericOperator::Mul => Scalar::primitive(lhs - rhs, nullability), - BinaryNumericOperator::Div => Scalar::primitive(lhs - rhs, nullability), + let value_or_null = match (lhs, rhs) { + (_, None) | (None, _) => None, + (Some(lhs), Some(rhs)) => match op { + BinaryNumericOperator::Add => Some(lhs + rhs), + BinaryNumericOperator::Sub => Some(lhs - rhs), + BinaryNumericOperator::Mul => Some(lhs * rhs), + BinaryNumericOperator::Div => Some(lhs / rhs), } - }) + }; + Some(Self { dtype: result_dtype, ptype: ptype, pvalue: value_or_null.map(PValue::from) }) } )) } + + fn checked_integeral_numeric_operator< + P: NativePType + + TryFrom + + CheckedSub + + CheckedAdd + + CheckedMul + + CheckedDiv, + >( + self, + other: PrimitiveScalar<'a>, + result_dtype: &'a DType, + ptype: PType, + op: BinaryNumericOperator, + ) -> Option> + where + PValue: From

, + { + let lhs = self.typed_value::

(); + let rhs = other.typed_value::

(); + let value_or_null_or_overflow = match (lhs, rhs) { + (_, None) | (None, _) => Some(None), + (Some(lhs), Some(rhs)) => match op { + BinaryNumericOperator::Add => lhs.checked_add(&rhs).map(Some), + BinaryNumericOperator::Sub => lhs.checked_sub(&rhs).map(Some), + BinaryNumericOperator::Mul => lhs.checked_mul(&rhs).map(Some), + BinaryNumericOperator::Div => lhs.checked_div(&rhs).map(Some), + }, + }; + + value_or_null_or_overflow.map(|value_or_null| Self { + dtype: result_dtype, + ptype, + pvalue: value_or_null.map(PValue::from), + }) + } } #[cfg(test)] mod tests { use vortex_dtype::{DType, Nullability, PType}; + use vortex_error::VortexError; use crate::value::InnerScalarValue; use crate::{PValue, PrimitiveScalar, ScalarValue}; @@ -399,8 +410,9 @@ mod tests { &ScalarValue(InnerScalarValue::Primitive(PValue::I32(4))), ) .unwrap(); - let res = p_scalar1.checked_sub(&p_scalar2).unwrap(); - assert_eq!(res.as_::().unwrap().unwrap(), 1); + let pscalar_or_overflow = p_scalar1.checked_sub(p_scalar2).unwrap(); + let value_or_null_or_type_error = pscalar_or_overflow.unwrap().as_::(); + assert_eq!(value_or_null_or_type_error.unwrap().unwrap(), 1); assert_eq!( (p_scalar1 - p_scalar2) @@ -413,6 +425,7 @@ mod tests { } #[test] + #[allow(clippy::assertions_on_constants)] fn test_integer_subtract_overflow() { let dtype = DType::Primitive(PType::I32, Nullability::NonNullable); let p_scalar1 = PrimitiveScalar::try_new( @@ -425,8 +438,14 @@ mod tests { &ScalarValue(InnerScalarValue::Primitive(PValue::I32(i32::MAX))), ) .unwrap(); - let res = p_scalar1 - p_scalar2; - assert!(res.unwrap().pvalue.is_none()); + let pscalar_or_error = p_scalar1 - p_scalar2; + match pscalar_or_error { + Err(VortexError::InvalidArgument(message, _)) => assert_eq!( + message.as_ref(), + "PrimitiveScalar subtract: overflow or underflow" + ), + res => assert!(false, "expected overflow error but got: {:?}", res), + } } #[test] @@ -442,8 +461,9 @@ mod tests { &ScalarValue(InnerScalarValue::Primitive(PValue::F32(1.0f32))), ) .unwrap(); - let res = p_scalar1.checked_sub(&p_scalar2).unwrap(); - assert_eq!(res.as_::().unwrap().unwrap(), 0.99f32); + let pscalar_or_overflow = p_scalar1.checked_sub(p_scalar2).unwrap(); + let value_or_null_or_type_error = pscalar_or_overflow.unwrap().as_::(); + assert_eq!(value_or_null_or_type_error.unwrap().unwrap(), 0.99f32); assert_eq!( (p_scalar1 - p_scalar2)