Skip to content

Commit

Permalink
1. checked_numeric_operator returns a PrimitiveScalar (avoiding a D…
Browse files Browse the repository at this point in the history
…Type allocation)

2. `checked_sub` delegates to `checked_numeric_operator`

As a consequence of (2), `checked_sub` now returns an `None` when there is an underflow/overflow
rather than returning a Null PrimitiveScalar (the return type is now
`VortexResult<Option<PrimitiveScalar<'a>>>`).

Moreover, `std::ops::Sub` now returns an error on underflow/overflow rather than a Null
PrimitiveScalar.
  • Loading branch information
danking committed Dec 19, 2024
1 parent 18eef43 commit 530c2c0
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 68 deletions.
3 changes: 2 additions & 1 deletion vortex-array/src/array/constant/compute/binary_numeric.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ impl BinaryNumericFn<ConstantArray> for ConstantEncoding {
.scalar()
.as_primitive()
.checked_numeric_operator(rhs.as_primitive(), op)?
.ok_or_else(|| vortex_err!("numeric overflow"))?,
.ok_or_else(|| vortex_err!("numeric overflow"))?
.to_scalar(),
array.len(),
)
.into_array(),
Expand Down
3 changes: 2 additions & 1 deletion vortex-array/src/array/sparse/compute/binary_numeric.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ impl BinaryNumericFn<SparseArray> for SparseEncoding {
.fill_scalar()
.as_primitive()
.checked_numeric_operator(rhs_scalar.as_primitive(), op)?
.ok_or_else(|| vortex_err!("numeric overflow"))?;
.ok_or_else(|| vortex_err!("numeric overflow"))?
.to_scalar();
SparseArray::try_new_from_patches(
new_patches,
array.len(),
Expand Down
152 changes: 86 additions & 66 deletions vortex-scalar/src/primitive.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::any::type_name;

use num_traits::{FromPrimitive, NumCast};
use num_traits::{CheckedAdd, CheckedDiv, CheckedMul, CheckedSub, FromPrimitive, NumCast};
use vortex_dtype::half::f16;
use vortex_dtype::{match_each_native_ptype, DType, NativePType, Nullability, PType};
use vortex_error::{
Expand All @@ -11,7 +11,7 @@ use crate::pvalue::PValue;
use crate::value::ScalarValue;
use crate::{InnerScalarValue, Scalar};

#[derive(Debug, Clone)]
#[derive(Debug, Clone, Copy)]
pub struct PrimitiveScalar<'a> {
dtype: &'a DType,
ptype: PType,
Expand Down Expand Up @@ -106,34 +106,17 @@ impl<'a> PrimitiveScalar<'a> {
}
}

pub fn checked_sub(&self, other: &PrimitiveScalar) -> VortexResult<Self> {
if self.ptype != other.ptype {
vortex_bail!("Failed to subtract {} and {}", self.ptype, other.ptype)
}
let result_pvalue: Option<PValue> = match_each_native_ptype!(
self.ptype,
integral: |$T| {
let lhs = self.as_::<$T>()?;
let rhs = other.as_::<$T>()?;
match (lhs, rhs) {
(Some(lv), Some(rv)) => lv.checked_sub(rv).map(PValue::from),
_ => None
}
}
floating_point: |$T| {
let lhs = self.as_::<$T>()?;
let rhs = other.as_::<$T>()?;
match (lhs, rhs) {
(Some(lv), Some(rv)) => Some((lv - rv).into()),
_ => None
}
}
);
Ok(Self {
dtype: self.dtype,
ptype: self.ptype,
pvalue: result_pvalue,
})
pub fn checked_sub(self, other: PrimitiveScalar<'a>) -> VortexResult<Option<Self>> {
self.checked_numeric_operator(other, BinaryNumericOperator::Sub)
}

pub fn to_scalar(self) -> Scalar {
Scalar::new(
self.dtype().clone(),
self.pvalue
.map(|pvalue| ScalarValue(InnerScalarValue::Primitive(pvalue)))
.unwrap_or_else(|| ScalarValue(InnerScalarValue::Null)),
)
}
}

Expand Down Expand Up @@ -190,7 +173,8 @@ impl std::ops::Sub for PrimitiveScalar<'_> {
type Output = VortexResult<Self>;

fn sub(self, rhs: Self) -> Self::Output {
self.checked_sub(&rhs)
self.checked_sub(rhs)?
.ok_or_else(|| vortex_err!("PrimitiveScalar subtract: overflow or underflow"))
}
}

Expand Down Expand Up @@ -323,7 +307,7 @@ pub enum BinaryNumericOperator {
// Pow,
}

impl PrimitiveScalar<'_> {
impl<'a> PrimitiveScalar<'a> {
/// Apply the (checked) operator to self and other using SQL-style null semantics.
///
/// If the operation overflows, Ok(None) is returned.
Expand All @@ -333,55 +317,82 @@ impl PrimitiveScalar<'_> {
/// If either value is null, the result is null.
pub fn checked_numeric_operator(
self,
other: PrimitiveScalar<'_>,
other: PrimitiveScalar<'a>,
op: BinaryNumericOperator,
) -> VortexResult<Option<Scalar>> {
) -> VortexResult<Option<PrimitiveScalar<'a>>> {
if !self.dtype().eq_ignore_nullability(other.dtype()) {
vortex_bail!("types must match: {} {}", self.dtype(), other.dtype());
}

let nullability =
Nullability::from(self.dtype().is_nullable() || other.dtype().is_nullable());
let result_dtype = if self.dtype().is_nullable() {
self.dtype()
} else {
other.dtype()
};
let ptype = self.ptype();

Ok(match_each_native_ptype!(
self.ptype(),
integral: |$P| {
let lhs = self.typed_value::<$P>();
let rhs = other.typed_value::<$P>();
match (lhs, rhs) {
(_, None) | (None, _) => Some(Scalar::null(self.dtype().with_nullability(nullability))),
(Some(lhs), Some(rhs)) => match op {
BinaryNumericOperator::Add =>
lhs.checked_add(rhs).map(|result| Scalar::primitive(result, nullability)),
BinaryNumericOperator::Sub =>
lhs.checked_sub(rhs).map(|result| Scalar::primitive(result, nullability)),
BinaryNumericOperator::Mul =>
lhs.checked_mul(rhs).map(|result| Scalar::primitive(result, nullability)),
BinaryNumericOperator::Div =>
lhs.checked_div(rhs).map(|result| Scalar::primitive(result, nullability)),
}
}
self.checked_integeral_numeric_operator::<$P>(other, result_dtype, ptype, op)
}
floating_point: |$P| {
let lhs = self.typed_value::<$P>();
let rhs = other.typed_value::<$P>();
Some(match (lhs, rhs) {
(_, None) | (None, _) => Scalar::null(self.dtype().with_nullability(nullability)),
(Some(lhs), Some(rhs)) => match op {
BinaryNumericOperator::Add => Scalar::primitive(lhs + rhs, nullability),
BinaryNumericOperator::Sub => Scalar::primitive(lhs - rhs, nullability),
BinaryNumericOperator::Mul => Scalar::primitive(lhs - rhs, nullability),
BinaryNumericOperator::Div => Scalar::primitive(lhs - rhs, nullability),
let value_or_null = match (lhs, rhs) {
(_, None) | (None, _) => None,
(Some(lhs), Some(rhs)) => match op {
BinaryNumericOperator::Add => Some(lhs + rhs),
BinaryNumericOperator::Sub => Some(lhs - rhs),
BinaryNumericOperator::Mul => Some(lhs * rhs),
BinaryNumericOperator::Div => Some(lhs / rhs),
}
})
};
Some(Self { dtype: result_dtype, ptype: ptype, pvalue: value_or_null.map(PValue::from) })
}
))
}

fn checked_integeral_numeric_operator<
P: NativePType
+ TryFrom<PValue, Error = VortexError>
+ CheckedSub
+ CheckedAdd
+ CheckedMul
+ CheckedDiv,
>(
self,
other: PrimitiveScalar<'a>,
result_dtype: &'a DType,
ptype: PType,
op: BinaryNumericOperator,
) -> Option<PrimitiveScalar<'a>>
where
PValue: From<P>,
{
let lhs = self.typed_value::<P>();
let rhs = other.typed_value::<P>();
let value_or_null_or_overflow = match (lhs, rhs) {
(_, None) | (None, _) => Some(None),
(Some(lhs), Some(rhs)) => match op {
BinaryNumericOperator::Add => lhs.checked_add(&rhs).map(Some),
BinaryNumericOperator::Sub => lhs.checked_sub(&rhs).map(Some),
BinaryNumericOperator::Mul => lhs.checked_mul(&rhs).map(Some),
BinaryNumericOperator::Div => lhs.checked_div(&rhs).map(Some),
},
};

value_or_null_or_overflow.map(|value_or_null| Self {
dtype: result_dtype,
ptype,
pvalue: value_or_null.map(PValue::from),
})
}
}

#[cfg(test)]
mod tests {
use vortex_dtype::{DType, Nullability, PType};
use vortex_error::VortexError;

use crate::value::InnerScalarValue;
use crate::{PValue, PrimitiveScalar, ScalarValue};
Expand All @@ -399,8 +410,9 @@ mod tests {
&ScalarValue(InnerScalarValue::Primitive(PValue::I32(4))),
)
.unwrap();
let res = p_scalar1.checked_sub(&p_scalar2).unwrap();
assert_eq!(res.as_::<i32>().unwrap().unwrap(), 1);
let pscalar_or_overflow = p_scalar1.checked_sub(p_scalar2).unwrap();
let value_or_null_or_type_error = pscalar_or_overflow.unwrap().as_::<i32>();
assert_eq!(value_or_null_or_type_error.unwrap().unwrap(), 1);

assert_eq!(
(p_scalar1 - p_scalar2)
Expand All @@ -413,6 +425,7 @@ mod tests {
}

#[test]
#[allow(clippy::assertions_on_constants)]
fn test_integer_subtract_overflow() {
let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
let p_scalar1 = PrimitiveScalar::try_new(
Expand All @@ -425,8 +438,14 @@ mod tests {
&ScalarValue(InnerScalarValue::Primitive(PValue::I32(i32::MAX))),
)
.unwrap();
let res = p_scalar1 - p_scalar2;
assert!(res.unwrap().pvalue.is_none());
let pscalar_or_error = p_scalar1 - p_scalar2;
match pscalar_or_error {
Err(VortexError::InvalidArgument(message, _)) => assert_eq!(
message.as_ref(),
"PrimitiveScalar subtract: overflow or underflow"
),
res => assert!(false, "expected overflow error but got: {:?}", res),
}
}

#[test]
Expand All @@ -442,8 +461,9 @@ mod tests {
&ScalarValue(InnerScalarValue::Primitive(PValue::F32(1.0f32))),
)
.unwrap();
let res = p_scalar1.checked_sub(&p_scalar2).unwrap();
assert_eq!(res.as_::<f32>().unwrap().unwrap(), 0.99f32);
let pscalar_or_overflow = p_scalar1.checked_sub(p_scalar2).unwrap();
let value_or_null_or_type_error = pscalar_or_overflow.unwrap().as_::<f32>();
assert_eq!(value_or_null_or_type_error.unwrap().unwrap(), 0.99f32);

assert_eq!(
(p_scalar1 - p_scalar2)
Expand Down

0 comments on commit 530c2c0

Please sign in to comment.