diff --git a/bench-vortex/src/bin/notimplemented.rs b/bench-vortex/src/bin/notimplemented.rs index 96062ed275..c2892f206c 100644 --- a/bench-vortex/src/bin/notimplemented.rs +++ b/bench-vortex/src/bin/notimplemented.rs @@ -174,7 +174,7 @@ fn compute_funcs(encodings: &[ArrayData]) { "fill_forward", "filter", "scalar_at", - "subtract_scalar", + "binary_numeric", "search_sorted", "slice", "take", @@ -190,7 +190,7 @@ fn compute_funcs(encodings: &[ArrayData]) { impls.push(bool_to_cell(arr.encoding().fill_forward_fn().is_some())); impls.push(bool_to_cell(arr.encoding().filter_fn().is_some())); impls.push(bool_to_cell(arr.encoding().scalar_at_fn().is_some())); - impls.push(bool_to_cell(arr.encoding().subtract_scalar_fn().is_some())); + impls.push(bool_to_cell(arr.encoding().binary_numeric_fn().is_some())); impls.push(bool_to_cell(arr.encoding().search_sorted_fn().is_some())); impls.push(bool_to_cell(arr.encoding().slice_fn().is_some())); impls.push(bool_to_cell(arr.encoding().take_fn().is_some())); diff --git a/encodings/dict/src/compute/mod.rs b/encodings/dict/src/compute/mod.rs index 51490e9ce2..09ee3e9eea 100644 --- a/encodings/dict/src/compute/mod.rs +++ b/encodings/dict/src/compute/mod.rs @@ -2,16 +2,20 @@ mod compare; mod like; use vortex_array::compute::{ - filter, scalar_at, slice, take, CompareFn, ComputeVTable, FilterFn, FilterMask, LikeFn, - ScalarAtFn, SliceFn, TakeFn, + binary_numeric, filter, scalar_at, slice, take, BinaryNumericFn, CompareFn, ComputeVTable, + FilterFn, FilterMask, LikeFn, ScalarAtFn, SliceFn, TakeFn, }; use vortex_array::{ArrayData, IntoArrayData}; use vortex_error::VortexResult; -use vortex_scalar::Scalar; +use vortex_scalar::{BinaryNumericOperator, Scalar}; use crate::{DictArray, DictEncoding}; impl ComputeVTable for DictEncoding { + fn binary_numeric_fn(&self) -> Option<&dyn BinaryNumericFn> { + Some(self) + } + fn compare_fn(&self) -> Option<&dyn CompareFn> { Some(self) } @@ -37,6 +41,23 @@ impl ComputeVTable for DictEncoding { } } +impl BinaryNumericFn for DictEncoding { + fn binary_numeric( + &self, + array: &DictArray, + rhs: &ArrayData, + op: BinaryNumericOperator, + ) -> VortexResult> { + if !rhs.is_constant() { + return Ok(None); + } + + DictArray::try_new(array.codes(), binary_numeric(&array.values(), rhs, op)?) + .map(IntoArrayData::into_array) + .map(Some) + } +} + impl ScalarAtFn for DictEncoding { fn scalar_at(&self, array: &DictArray, index: usize) -> VortexResult { let dict_index: usize = scalar_at(array.codes(), index)?.as_ref().try_into()?; diff --git a/encodings/runend/src/compute/mod.rs b/encodings/runend/src/compute/mod.rs index a44d2af547..974bc82520 100644 --- a/encodings/runend/src/compute/mod.rs +++ b/encodings/runend/src/compute/mod.rs @@ -9,18 +9,22 @@ use std::ops::AddAssign; use num_traits::AsPrimitive; use vortex_array::array::{BooleanBuffer, PrimitiveArray}; use vortex_array::compute::{ - filter, scalar_at, slice, CompareFn, ComputeVTable, FillNullFn, FilterFn, FilterMask, InvertFn, - ScalarAtFn, SliceFn, TakeFn, + binary_numeric, filter, scalar_at, slice, BinaryNumericFn, CompareFn, ComputeVTable, + FillNullFn, FilterFn, FilterMask, InvertFn, ScalarAtFn, SliceFn, TakeFn, }; use vortex_array::variants::PrimitiveArrayTrait; use vortex_array::{ArrayData, ArrayLen, IntoArrayData, IntoArrayVariant}; use vortex_dtype::{match_each_unsigned_integer_ptype, NativePType}; use vortex_error::{VortexResult, VortexUnwrap}; -use vortex_scalar::Scalar; +use vortex_scalar::{BinaryNumericOperator, Scalar}; use crate::{RunEndArray, RunEndEncoding}; impl ComputeVTable for RunEndEncoding { + fn binary_numeric_fn(&self) -> Option<&dyn BinaryNumericFn> { + Some(self) + } + fn compare_fn(&self) -> Option<&dyn CompareFn> { Some(self) } @@ -50,6 +54,28 @@ impl ComputeVTable for RunEndEncoding { } } +impl BinaryNumericFn for RunEndEncoding { + fn binary_numeric( + &self, + array: &RunEndArray, + rhs: &ArrayData, + op: BinaryNumericOperator, + ) -> VortexResult> { + if !rhs.is_constant() { + return Ok(None); + } + + RunEndArray::with_offset_and_length( + array.ends(), + binary_numeric(&array.values(), rhs, op)?, + array.offset(), + array.len(), + ) + .map(IntoArrayData::into_array) + .map(Some) + } +} + impl ScalarAtFn for RunEndEncoding { fn scalar_at(&self, array: &RunEndArray, index: usize) -> VortexResult { scalar_at(array.values(), array.find_physical_index(index)?) diff --git a/vortex-array/benches/scalar_subtract.rs b/vortex-array/benches/scalar_subtract.rs index 1efed008c8..58233c9e76 100644 --- a/vortex-array/benches/scalar_subtract.rs +++ b/vortex-array/benches/scalar_subtract.rs @@ -28,8 +28,7 @@ fn scalar_subtract(c: &mut Criterion) { group.bench_function("vortex", |b| { b.iter(|| { - let array = - vortex_array::compute::subtract_scalar(&chunked, &to_subtract.into()).unwrap(); + let array = vortex_array::compute::sub_scalar(&chunked, to_subtract.into()).unwrap(); let chunked = ChunkedArray::try_from(array).unwrap(); black_box(chunked); diff --git a/vortex-array/src/array/chunked/compute/mod.rs b/vortex-array/src/array/chunked/compute/mod.rs index 8b63279719..242082ae03 100644 --- a/vortex-array/src/array/chunked/compute/mod.rs +++ b/vortex-array/src/array/chunked/compute/mod.rs @@ -4,8 +4,8 @@ use vortex_error::VortexResult; use crate::array::chunked::ChunkedArray; use crate::array::ChunkedEncoding; use crate::compute::{ - try_cast, BinaryBooleanFn, CastFn, CompareFn, ComputeVTable, FillNullFn, FilterFn, InvertFn, - ScalarAtFn, SliceFn, SubtractScalarFn, TakeFn, + try_cast, BinaryBooleanFn, BinaryNumericFn, CastFn, CompareFn, ComputeVTable, FillNullFn, + FilterFn, InvertFn, ScalarAtFn, SliceFn, TakeFn, }; use crate::{ArrayData, IntoArrayData}; @@ -23,6 +23,10 @@ impl ComputeVTable for ChunkedEncoding { Some(self) } + fn binary_numeric_fn(&self) -> Option<&dyn BinaryNumericFn> { + Some(self) + } + fn cast_fn(&self) -> Option<&dyn CastFn> { Some(self) } @@ -51,10 +55,6 @@ impl ComputeVTable for ChunkedEncoding { Some(self) } - fn subtract_scalar_fn(&self) -> Option<&dyn SubtractScalarFn> { - Some(self) - } - fn take_fn(&self) -> Option<&dyn TakeFn> { Some(self) } diff --git a/vortex-array/src/array/chunked/compute/take.rs b/vortex-array/src/array/chunked/compute/take.rs index 5069b648a2..ef54f8046e 100644 --- a/vortex-array/src/array/chunked/compute/take.rs +++ b/vortex-array/src/array/chunked/compute/take.rs @@ -6,8 +6,7 @@ use vortex_scalar::Scalar; use crate::array::chunked::ChunkedArray; use crate::array::ChunkedEncoding; use crate::compute::{ - scalar_at, search_sorted_usize, slice, subtract_scalar, take, try_cast, SearchSortedSide, - TakeFn, + scalar_at, search_sorted_usize, slice, sub_scalar, take, try_cast, SearchSortedSide, TakeFn, }; use crate::stats::ArrayStatistics; use crate::{ArrayDType, ArrayData, ArrayLen, IntoArrayData, IntoArrayVariant, ToArrayData}; @@ -93,15 +92,15 @@ fn take_strict_sorted(chunked: &ChunkedArray, indices: &ArrayData) -> VortexResu .max_value_as_u64() .try_into()? { - subtract_scalar( + sub_scalar( &chunk_indices, - &Scalar::from(chunk_begin).cast(chunk_indices.dtype())?, + Scalar::from(chunk_begin).cast(chunk_indices.dtype())?, )? } else { // Note. this try_cast (memory copy) is unnecessary, could instead upcast in the subtract fn. // and avoid an extra let u64_chunk_indices = try_cast(&chunk_indices, PType::U64.into())?; - subtract_scalar(&u64_chunk_indices, &chunk_begin.into())? + sub_scalar(&u64_chunk_indices, chunk_begin.into())? }; indices_by_chunk[chunk_idx] = Some(chunk_indices); diff --git a/vortex-array/src/array/chunked/mod.rs b/vortex-array/src/array/chunked/mod.rs index 1aea504a7c..32ed452964 100644 --- a/vortex-array/src/array/chunked/mod.rs +++ b/vortex-array/src/array/chunked/mod.rs @@ -9,11 +9,11 @@ use itertools::Itertools; use serde::{Deserialize, Serialize}; use vortex_dtype::{DType, Nullability, PType}; use vortex_error::{vortex_bail, vortex_panic, VortexExpect as _, VortexResult, VortexUnwrap}; -use vortex_scalar::Scalar; +use vortex_scalar::BinaryNumericOperator; use crate::array::primitive::PrimitiveArray; use crate::compute::{ - scalar_at, search_sorted_usize, subtract_scalar, SearchSortedSide, SubtractScalarFn, + binary_numeric, scalar_at, search_sorted_usize, slice, BinaryNumericFn, SearchSortedSide, }; use crate::encoding::ids; use crate::iter::{ArrayIterator, ArrayIteratorAdapter}; @@ -234,17 +234,25 @@ impl ValidityVTable for ChunkedEncoding { } } -impl SubtractScalarFn for ChunkedEncoding { - fn subtract_scalar( +impl BinaryNumericFn for ChunkedEncoding { + fn binary_numeric( &self, array: &ChunkedArray, - to_subtract: &Scalar, - ) -> VortexResult { - let chunks = array - .chunks() - .map(|chunk| subtract_scalar(&chunk, to_subtract)) - .collect::>>()?; - Ok(ChunkedArray::try_new(chunks, array.dtype().clone())?.into_array()) + rhs: &ArrayData, + op: BinaryNumericOperator, + ) -> VortexResult> { + let mut start = 0; + + let mut new_chunks = Vec::with_capacity(array.nchunks()); + for chunk in array.chunks() { + let end = start + chunk.len(); + new_chunks.push(binary_numeric(&chunk, &slice(rhs, start, end)?, op)?); + start = end; + } + + ChunkedArray::try_new(new_chunks, array.dtype().clone()) + .map(IntoArrayData::into_array) + .map(Some) } } @@ -254,7 +262,7 @@ mod test { use vortex_error::VortexResult; use crate::array::chunked::ChunkedArray; - use crate::compute::{scalar_at, subtract_scalar}; + use crate::compute::{scalar_at, sub_scalar}; use crate::{assert_arrays_eq, ArrayDType, IntoArrayData, IntoArrayVariant}; fn chunked_array() -> ChunkedArray { @@ -271,9 +279,9 @@ mod test { #[test] fn test_scalar_subtract() { - let chunked = chunked_array(); + let chunked = chunked_array().into_array(); let to_subtract = 1u64; - let array = subtract_scalar(&chunked, &to_subtract.into()).unwrap(); + let array = sub_scalar(&chunked, to_subtract.into()).unwrap(); let chunked = ChunkedArray::try_from(array).unwrap(); let mut chunks_out = chunked.chunks(); diff --git a/vortex-array/src/array/constant/compute/binary_numeric.rs b/vortex-array/src/array/constant/compute/binary_numeric.rs new file mode 100644 index 0000000000..b7994b2167 --- /dev/null +++ b/vortex-array/src/array/constant/compute/binary_numeric.rs @@ -0,0 +1,31 @@ +use vortex_error::{vortex_err, VortexResult}; +use vortex_scalar::BinaryNumericOperator; + +use crate::array::{ConstantArray, ConstantEncoding}; +use crate::compute::BinaryNumericFn; +use crate::{ArrayData, ArrayLen as _, IntoArrayData as _}; + +impl BinaryNumericFn for ConstantEncoding { + fn binary_numeric( + &self, + array: &ConstantArray, + rhs: &ArrayData, + op: BinaryNumericOperator, + ) -> VortexResult> { + let Some(rhs) = rhs.as_constant() else { + return Ok(None); + }; + + Ok(Some( + ConstantArray::new( + array + .scalar() + .as_primitive() + .checked_numeric_operator(rhs.as_primitive(), op)? + .ok_or_else(|| vortex_err!("numeric overflow"))?, + array.len(), + ) + .into_array(), + )) + } +} diff --git a/vortex-array/src/array/constant/compute/mod.rs b/vortex-array/src/array/constant/compute/mod.rs index 56cdaf0ce3..dbb7dfbc7b 100644 --- a/vortex-array/src/array/constant/compute/mod.rs +++ b/vortex-array/src/array/constant/compute/mod.rs @@ -1,3 +1,4 @@ +mod binary_numeric; mod boolean; mod compare; mod invert; @@ -9,8 +10,8 @@ use vortex_scalar::Scalar; use crate::array::constant::ConstantArray; use crate::array::ConstantEncoding; use crate::compute::{ - BinaryBooleanFn, CompareFn, ComputeVTable, FilterFn, FilterMask, InvertFn, ScalarAtFn, - SearchSortedFn, SliceFn, TakeFn, + BinaryBooleanFn, BinaryNumericFn, CompareFn, ComputeVTable, FilterFn, FilterMask, InvertFn, + ScalarAtFn, SearchSortedFn, SliceFn, TakeFn, }; use crate::{ArrayData, IntoArrayData}; @@ -19,6 +20,10 @@ impl ComputeVTable for ConstantEncoding { Some(self) } + fn binary_numeric_fn(&self) -> Option<&dyn BinaryNumericFn> { + Some(self) + } + fn compare_fn(&self) -> Option<&dyn CompareFn> { Some(self) } diff --git a/vortex-array/src/array/null/compute.rs b/vortex-array/src/array/null/compute.rs index 40938d8560..dbe3ca459f 100644 --- a/vortex-array/src/array/null/compute.rs +++ b/vortex-array/src/array/null/compute.rs @@ -1,10 +1,10 @@ use vortex_dtype::{match_each_integer_ptype, DType}; use vortex_error::{vortex_bail, VortexResult}; -use vortex_scalar::Scalar; +use vortex_scalar::{BinaryNumericOperator, Scalar}; use crate::array::null::NullArray; use crate::array::NullEncoding; -use crate::compute::{ComputeVTable, ScalarAtFn, SliceFn, TakeFn}; +use crate::compute::{BinaryNumericFn, ComputeVTable, ScalarAtFn, SliceFn, TakeFn}; use crate::variants::PrimitiveArrayTrait; use crate::{ArrayData, ArrayLen, IntoArrayData, IntoArrayVariant}; @@ -13,6 +13,10 @@ impl ComputeVTable for NullEncoding { Some(self) } + fn binary_numeric_fn(&self) -> Option<&dyn BinaryNumericFn> { + Some(self) + } + fn slice_fn(&self) -> Option<&dyn SliceFn> { Some(self) } @@ -22,6 +26,18 @@ impl ComputeVTable for NullEncoding { } } +impl BinaryNumericFn for NullEncoding { + fn binary_numeric( + &self, + array: &NullArray, + _rhs: &ArrayData, + _op: BinaryNumericOperator, + ) -> VortexResult> { + // for any arithmetic operation, forall X. NULL op X = NULL + Ok(Some(NullArray::new(array.len()).into_array())) + } +} + impl SliceFn for NullEncoding { fn slice(&self, _array: &NullArray, start: usize, stop: usize) -> VortexResult { Ok(NullArray::new(stop - start).into_array()) diff --git a/vortex-array/src/array/primitive/compute/mod.rs b/vortex-array/src/array/primitive/compute/mod.rs index 2385d54112..bac23e2d7f 100644 --- a/vortex-array/src/array/primitive/compute/mod.rs +++ b/vortex-array/src/array/primitive/compute/mod.rs @@ -1,7 +1,7 @@ use crate::array::PrimitiveEncoding; use crate::compute::{ CastFn, ComputeVTable, FillForwardFn, FilterFn, ScalarAtFn, SearchSortedFn, - SearchSortedUsizeFn, SliceFn, SubtractScalarFn, TakeFn, + SearchSortedUsizeFn, SliceFn, TakeFn, }; use crate::ArrayData; @@ -11,7 +11,6 @@ mod filter; mod scalar_at; mod search_sorted; mod slice; -mod subtract_scalar; mod take; impl ComputeVTable for PrimitiveEncoding { @@ -43,10 +42,6 @@ impl ComputeVTable for PrimitiveEncoding { Some(self) } - fn subtract_scalar_fn(&self) -> Option<&dyn SubtractScalarFn> { - Some(self) - } - fn take_fn(&self) -> Option<&dyn TakeFn> { Some(self) } diff --git a/vortex-array/src/array/primitive/compute/subtract_scalar.rs b/vortex-array/src/array/primitive/compute/subtract_scalar.rs deleted file mode 100644 index 6b71b59a83..0000000000 --- a/vortex-array/src/array/primitive/compute/subtract_scalar.rs +++ /dev/null @@ -1,156 +0,0 @@ -use itertools::Itertools; -use num_traits::WrappingSub; -use vortex_dtype::{match_each_float_ptype, match_each_integer_ptype, NativePType}; -use vortex_error::{vortex_bail, vortex_err, VortexResult}; -use vortex_scalar::{PrimitiveScalar, Scalar}; - -use crate::array::constant::ConstantArray; -use crate::array::primitive::PrimitiveArray; -use crate::array::PrimitiveEncoding; -use crate::compute::SubtractScalarFn; -use crate::validity::ArrayValidity; -use crate::variants::PrimitiveArrayTrait; -use crate::{ArrayDType, ArrayData, ArrayLen, IntoArrayData}; - -impl SubtractScalarFn for PrimitiveEncoding { - fn subtract_scalar( - &self, - array: &PrimitiveArray, - to_subtract: &Scalar, - ) -> VortexResult { - if array.dtype() != to_subtract.dtype() { - vortex_bail!(MismatchedTypes: array.dtype(), to_subtract.dtype()) - } - - let validity = array.validity().to_logical(array.len()); - if validity.all_invalid() { - return Ok( - ConstantArray::new(Scalar::null(array.dtype().clone()), array.len()).into_array(), - ); - } - - let result = if to_subtract.dtype().is_int() { - match_each_integer_ptype!(array.ptype(), |$T| { - let to_subtract: $T = PrimitiveScalar::try_from(to_subtract)? - .typed_value::<$T>() - .ok_or_else(|| vortex_err!("expected primitive"))?; - subtract_scalar_integer::<$T>(array, to_subtract)? - }) - } else { - match_each_float_ptype!(array.ptype(), |$T| { - let to_subtract: $T = PrimitiveScalar::try_from(to_subtract)? - .typed_value::<$T>() - .ok_or_else(|| vortex_err!("expected primitive"))?; - let sub_vec : Vec<$T> = array.maybe_null_slice::<$T>() - .iter() - .map(|&v| v - to_subtract).collect_vec(); - PrimitiveArray::from(sub_vec) - }) - }; - Ok(result.into_array()) - } -} - -fn subtract_scalar_integer( - subtract_from: &PrimitiveArray, - to_subtract: T, -) -> VortexResult { - if to_subtract.is_zero() { - // if to_subtract is zero, skip operation - return Ok(subtract_from.clone()); - } - - let contains_nulls = !subtract_from.logical_validity().all_valid(); - let subtraction_result = if contains_nulls { - let sub_vec = subtract_from - .maybe_null_slice() - .iter() - .map(|&v: &T| v.wrapping_sub(&to_subtract)) - .collect_vec(); - PrimitiveArray::from_vec(sub_vec, subtract_from.validity()) - } else { - PrimitiveArray::from( - subtract_from - .maybe_null_slice::() - .iter() - .map(|&v| v - to_subtract) - .collect_vec(), - ) - }; - Ok(subtraction_result) -} - -#[cfg(test)] -mod test { - use itertools::Itertools; - - use crate::array::primitive::PrimitiveArray; - use crate::compute::subtract_scalar; - use crate::{ArrayLen, IntoArrayData, IntoArrayVariant}; - - #[test] - fn test_scalar_subtract_unsigned() { - let values = vec![1u16, 2, 3].into_array(); - let results = subtract_scalar(&values, &1u16.into()) - .unwrap() - .into_primitive() - .unwrap() - .maybe_null_slice::() - .to_vec(); - assert_eq!(results, &[0u16, 1, 2]); - } - - #[test] - fn test_scalar_subtract_signed() { - let values = vec![1i64, 2, 3].into_array(); - let results = subtract_scalar(&values, &(-1i64).into()) - .unwrap() - .into_primitive() - .unwrap() - .maybe_null_slice::() - .to_vec(); - assert_eq!(results, &[2i64, 3, 4]); - } - - #[test] - fn test_scalar_subtract_nullable() { - let values = PrimitiveArray::from_nullable_vec(vec![Some(1u16), Some(2), None, Some(3)]) - .into_array(); - let flattened = subtract_scalar(&values, &Some(1u16).into()) - .unwrap() - .into_primitive() - .unwrap(); - - let results = flattened.maybe_null_slice::().to_vec(); - assert_eq!(results, &[0u16, 1, 65535, 2]); - let valid_indices = flattened - .validity() - .to_logical(flattened.len()) - .to_null_buffer() - .unwrap() - .unwrap() - .valid_indices() - .collect_vec(); - assert_eq!(valid_indices, &[0, 1, 3]); - } - - #[test] - fn test_scalar_subtract_float() { - let values = vec![1.0f64, 2.0, 3.0].into_array(); - let to_subtract = -1f64; - let results = subtract_scalar(&values, &to_subtract.into()) - .unwrap() - .into_primitive() - .unwrap() - .maybe_null_slice::() - .to_vec(); - assert_eq!(results, &[2.0f64, 3.0, 4.0]); - } - - #[test] - fn test_scalar_subtract_float_underflow_is_ok() { - let values = vec![f32::MIN, 2.0, 3.0].into_array(); - let _results = subtract_scalar(&values, &1.0f32.into()).unwrap(); - let _results = subtract_scalar(&values, &f32::MAX.into()).unwrap(); - } -} diff --git a/vortex-array/src/array/sparse/compute/binary_numeric.rs b/vortex-array/src/array/sparse/compute/binary_numeric.rs new file mode 100644 index 0000000000..50d1e686e8 --- /dev/null +++ b/vortex-array/src/array/sparse/compute/binary_numeric.rs @@ -0,0 +1,36 @@ +use vortex_error::{vortex_err, VortexResult}; +use vortex_scalar::BinaryNumericOperator; + +use crate::array::{SparseArray, SparseEncoding}; +use crate::compute::{binary_numeric, BinaryNumericFn}; +use crate::{ArrayData, ArrayLen as _, IntoArrayData}; + +impl BinaryNumericFn for SparseEncoding { + fn binary_numeric( + &self, + array: &SparseArray, + rhs: &ArrayData, + op: BinaryNumericOperator, + ) -> VortexResult> { + let Some(rhs_scalar) = rhs.as_constant() else { + return Ok(None); + }; + + let new_patches = array + .patches() + .map_values(|values| binary_numeric(&values, rhs, op))?; + let new_fill_value = array + .fill_scalar() + .as_primitive() + .checked_numeric_operator(rhs_scalar.as_primitive(), op)? + .ok_or_else(|| vortex_err!("numeric overflow"))?; + SparseArray::try_new_from_patches( + new_patches, + array.len(), + array.indices_offset(), + new_fill_value, + ) + .map(IntoArrayData::into_array) + .map(Some) + } +} diff --git a/vortex-array/src/array/sparse/compute/mod.rs b/vortex-array/src/array/sparse/compute/mod.rs index 8809e14958..fa4765632a 100644 --- a/vortex-array/src/array/sparse/compute/mod.rs +++ b/vortex-array/src/array/sparse/compute/mod.rs @@ -4,16 +4,21 @@ use vortex_scalar::Scalar; use crate::array::sparse::SparseArray; use crate::array::{ConstantArray, SparseEncoding}; use crate::compute::{ - ComputeVTable, FilterFn, FilterMask, InvertFn, ScalarAtFn, SearchResult, SearchSortedFn, - SearchSortedSide, SearchSortedUsizeFn, SliceFn, TakeFn, + BinaryNumericFn, ComputeVTable, FilterFn, FilterMask, InvertFn, ScalarAtFn, SearchResult, + SearchSortedFn, SearchSortedSide, SearchSortedUsizeFn, SliceFn, TakeFn, }; use crate::{ArrayDType, ArrayData, ArrayLen, IntoArrayData}; +mod binary_numeric; mod invert; mod slice; mod take; impl ComputeVTable for SparseEncoding { + fn binary_numeric_fn(&self) -> Option<&dyn BinaryNumericFn> { + Some(self) + } + fn filter_fn(&self) -> Option<&dyn FilterFn> { Some(self) } diff --git a/vortex-array/src/array/sparse/mod.rs b/vortex-array/src/array/sparse/mod.rs index b646c50ad8..8568034743 100644 --- a/vortex-array/src/array/sparse/mod.rs +++ b/vortex-array/src/array/sparse/mod.rs @@ -5,7 +5,7 @@ use vortex_error::{vortex_bail, vortex_panic, VortexExpect as _, VortexResult}; use vortex_scalar::{Scalar, ScalarValue}; use crate::array::constant::ConstantArray; -use crate::compute::{scalar_at, subtract_scalar}; +use crate::compute::{scalar_at, sub_scalar}; use crate::encoding::ids; use crate::patches::{Patches, PatchesMetadata}; use crate::stats::{ArrayStatistics, Stat, StatisticsVTable, StatsSet}; @@ -125,7 +125,8 @@ impl SparseArray { #[inline] pub fn resolved_patches(&self) -> VortexResult { let (len, indices, values) = self.patches().into_parts(); - let indices = subtract_scalar(indices, &Scalar::from(self.indices_offset()))?; + let indices_offset = Scalar::from(self.indices_offset()).cast(indices.dtype())?; + let indices = sub_scalar(indices, indices_offset)?; Ok(Patches::new(len, indices, values)) } diff --git a/vortex-array/src/compute/binary_numeric.rs b/vortex-array/src/compute/binary_numeric.rs new file mode 100644 index 0000000000..220608166d --- /dev/null +++ b/vortex-array/src/compute/binary_numeric.rs @@ -0,0 +1,256 @@ +use std::sync::Arc; + +use arrow_array::ArrayRef; +use vortex_dtype::DType; +use vortex_error::{vortex_bail, VortexError, VortexResult}; +use vortex_scalar::{BinaryNumericOperator, Scalar}; + +use crate::array::ConstantArray; +use crate::arrow::{Datum, FromArrowArray}; +use crate::encoding::{downcast_array_ref, Encoding}; +use crate::{ArrayDType, ArrayData, IntoArrayData as _}; + +pub trait BinaryNumericFn { + fn binary_numeric( + &self, + array: &Array, + other: &ArrayData, + op: BinaryNumericOperator, + ) -> VortexResult>; +} + +impl BinaryNumericFn for E +where + E: BinaryNumericFn, + for<'a> &'a E::Array: TryFrom<&'a ArrayData, Error = VortexError>, +{ + fn binary_numeric( + &self, + lhs: &ArrayData, + rhs: &ArrayData, + op: BinaryNumericOperator, + ) -> VortexResult> { + let (array_ref, encoding) = downcast_array_ref::(lhs)?; + BinaryNumericFn::binary_numeric(encoding, array_ref, rhs, op) + } +} + +/// Point-wise add two numeric arrays. +pub fn add(lhs: impl AsRef, rhs: impl AsRef) -> VortexResult { + binary_numeric(lhs.as_ref(), rhs.as_ref(), BinaryNumericOperator::Add) +} + +/// Point-wise add a scalar value to this array on the right-hand-side. +pub fn add_scalar(lhs: impl AsRef, rhs: Scalar) -> VortexResult { + let lhs = lhs.as_ref(); + binary_numeric( + lhs, + &ConstantArray::new(rhs, lhs.len()).into_array(), + BinaryNumericOperator::Add, + ) +} + +/// Point-wise subtract two numeric arrays. +pub fn sub(lhs: impl AsRef, rhs: impl AsRef) -> VortexResult { + binary_numeric(lhs.as_ref(), rhs.as_ref(), BinaryNumericOperator::Sub) +} + +/// Point-wise subtract a scalar value from this array on the right-hand-side. +pub fn sub_scalar(lhs: impl AsRef, rhs: Scalar) -> VortexResult { + let lhs = lhs.as_ref(); + binary_numeric( + lhs, + &ConstantArray::new(rhs, lhs.len()).into_array(), + BinaryNumericOperator::Sub, + ) +} + +/// Point-wise multiply two numeric arrays. +pub fn mul(lhs: impl AsRef, rhs: impl AsRef) -> VortexResult { + binary_numeric(lhs.as_ref(), rhs.as_ref(), BinaryNumericOperator::Mul) +} + +/// Point-wise multiply a scalar value into this array on the right-hand-side. +pub fn mul_scalar(lhs: impl AsRef, rhs: Scalar) -> VortexResult { + let lhs = lhs.as_ref(); + binary_numeric( + lhs, + &ConstantArray::new(rhs, lhs.len()).into_array(), + BinaryNumericOperator::Mul, + ) +} + +/// Point-wise divide two numeric arrays. +pub fn div(lhs: impl AsRef, rhs: impl AsRef) -> VortexResult { + binary_numeric(lhs.as_ref(), rhs.as_ref(), BinaryNumericOperator::Div) +} + +/// Point-wise divide a scalar value into this array on the right-hand-side. +pub fn div_scalar(lhs: impl AsRef, rhs: Scalar) -> VortexResult { + let lhs = lhs.as_ref(); + binary_numeric( + lhs, + &ConstantArray::new(rhs, lhs.len()).into_array(), + BinaryNumericOperator::Mul, + ) +} + +pub fn binary_numeric( + lhs: &ArrayData, + rhs: &ArrayData, + op: BinaryNumericOperator, +) -> VortexResult { + if lhs.len() != rhs.len() { + vortex_bail!("Numeric operations aren't supported on arrays of different lengths") + } + if !matches!(lhs.dtype(), DType::Primitive(_, _)) + || !matches!(rhs.dtype(), DType::Primitive(_, _)) + || lhs.dtype() != rhs.dtype() + { + vortex_bail!( + "Numeric operations are only supported on two arrays sharing the same primitive-type: {} {}", + lhs.dtype(), + rhs.dtype() + ) + } + + // Check if LHS supports the operation directly. + if let Some(fun) = lhs.encoding().binary_numeric_fn() { + if let Some(result) = fun.binary_numeric(lhs, rhs, op)? { + return Ok(result); + } + } + + // Check if RHS supports the operation directly. + if let Some(fun) = rhs.encoding().binary_numeric_fn() { + if let Some(result) = fun.binary_numeric(rhs, lhs, op)? { + return Ok(result); + } + } + + log::debug!( + "No numeric implementation found for LHS {}, RHS {}, and operator {:?}", + lhs.encoding().id(), + rhs.encoding().id(), + op, + ); + + // If neither side implements the trait, then we delegate to Arrow compute. + arrow_numeric(lhs.clone(), rhs.clone(), op) +} + +/// Implementation of `BinaryBooleanFn` using the Arrow crate. +/// +/// Note that other encodings should handle a constant RHS value, so we can assume here that +/// the RHS is not constant and expand to a full array. +fn arrow_numeric( + lhs: ArrayData, + rhs: ArrayData, + operator: BinaryNumericOperator, +) -> VortexResult { + let nullable = lhs.dtype().is_nullable() || rhs.dtype().is_nullable(); + + let lhs = Datum::try_from(lhs)?; + let rhs = Datum::try_from(rhs)?; + + let array = match operator { + BinaryNumericOperator::Add => arrow_arith::numeric::add(&lhs, &rhs)?, + BinaryNumericOperator::Sub => arrow_arith::numeric::sub(&lhs, &rhs)?, + BinaryNumericOperator::Div => arrow_arith::numeric::div(&lhs, &rhs)?, + BinaryNumericOperator::Mul => arrow_arith::numeric::mul(&lhs, &rhs)?, + }; + + Ok(ArrayData::from_arrow(Arc::new(array) as ArrayRef, nullable)) +} + +#[cfg(test)] +mod test { + use vortex_scalar::Scalar; + + use crate::array::PrimitiveArray; + use crate::compute::{scalar_at, sub_scalar}; + use crate::{ArrayLen as _, IntoArrayData, IntoCanonical}; + + #[test] + fn test_scalar_subtract_unsigned() { + let values = vec![1u16, 2, 3].into_array(); + let results = sub_scalar(&values, 1u16.into()) + .unwrap() + .into_canonical() + .unwrap() + .into_primitive() + .unwrap() + .maybe_null_slice::() + .to_vec(); + assert_eq!(results, &[0u16, 1, 2]); + } + + #[test] + fn test_scalar_subtract_signed() { + let values = vec![1i64, 2, 3].into_array(); + let results = sub_scalar(&values, (-1i64).into()) + .unwrap() + .into_canonical() + .unwrap() + .into_primitive() + .unwrap() + .maybe_null_slice::() + .to_vec(); + assert_eq!(results, &[2i64, 3, 4]); + } + + #[test] + fn test_scalar_subtract_nullable() { + let values = PrimitiveArray::from_nullable_vec(vec![Some(1u16), Some(2), None, Some(3)]) + .into_array(); + let result = sub_scalar(&values, Some(1u16).into()) + .unwrap() + .into_canonical() + .unwrap() + .into_primitive() + .unwrap(); + + let actual = (0..result.len()) + .map(|index| scalar_at(&result, index).unwrap()) + .collect::>(); + assert_eq!( + actual, + vec![ + Scalar::from(Some(0u16)), + Scalar::from(Some(1u16)), + Scalar::from(None::), + Scalar::from(Some(2u16)) + ] + ); + } + + #[test] + fn test_scalar_subtract_float() { + let values = vec![1.0f64, 2.0, 3.0].into_array(); + let to_subtract = -1f64; + let results = sub_scalar(&values, to_subtract.into()) + .unwrap() + .into_canonical() + .unwrap() + .into_primitive() + .unwrap() + .maybe_null_slice::() + .to_vec(); + assert_eq!(results, &[2.0f64, 3.0, 4.0]); + } + + #[test] + fn test_scalar_subtract_float_underflow_is_ok() { + let values = vec![f32::MIN, 2.0, 3.0].into_array(); + let _results = sub_scalar(&values, 1.0f32.into()).unwrap(); + let _results = sub_scalar(&values, f32::MAX.into()).unwrap(); + } + + #[test] + fn test_scalar_subtract_type_mismatch_fails() { + let values = vec![1u64, 2, 3].into_array(); + // Subtracting incompatible dtypes should fail + let _results = + sub_scalar(&values, 1.5f64.into()).expect_err("Expected type mismatch error"); + } +} diff --git a/vortex-array/src/compute/mod.rs b/vortex-array/src/compute/mod.rs index 54e99a489e..92f39021bf 100644 --- a/vortex-array/src/compute/mod.rs +++ b/vortex-array/src/compute/mod.rs @@ -7,6 +7,7 @@ //! implementations of these operators, else we will decode, and perform the equivalent operator //! from Arrow. +pub use binary_numeric::*; pub use boolean::{ and, and_kleene, binary_boolean, or, or_kleene, BinaryBooleanFn, BinaryOperator, }; @@ -18,13 +19,13 @@ pub use filter::{filter, FilterFn, FilterIter, FilterMask}; pub use invert::{invert, InvertFn}; pub use like::{like, LikeFn, LikeOptions}; pub use scalar_at::{scalar_at, ScalarAtFn}; -pub use scalar_subtract::{subtract_scalar, SubtractScalarFn}; pub use search_sorted::*; pub use slice::{slice, SliceFn}; pub use take::{take, TakeFn}; use crate::ArrayData; +mod binary_numeric; mod boolean; mod cast; mod compare; @@ -34,7 +35,6 @@ mod filter; mod invert; mod like; mod scalar_at; -mod scalar_subtract; mod search_sorted; mod slice; mod take; @@ -48,6 +48,14 @@ pub trait ComputeVTable { None } + /// Implementation of binary numeric operations. + /// + /// See: [BinaryNumericFn]. + fn binary_numeric_fn(&self) -> Option<&dyn BinaryNumericFn> { + None + } + + /// Implemented for arrays that can be casted to different types. /// Implemented for arrays that can be casted to different types. /// /// See: [CastFn]. @@ -125,13 +133,6 @@ pub trait ComputeVTable { None } - /// Broadcast subtraction of scalar from Vortex array. - /// - /// See: [SubtractScalarFn]. - fn subtract_scalar_fn(&self) -> Option<&dyn SubtractScalarFn> { - None - } - /// Take a set of indices from an array. This often forces allocations and decoding of /// the receiver. /// diff --git a/vortex-array/src/compute/scalar_subtract.rs b/vortex-array/src/compute/scalar_subtract.rs deleted file mode 100644 index ba505107c1..0000000000 --- a/vortex-array/src/compute/scalar_subtract.rs +++ /dev/null @@ -1,48 +0,0 @@ -use vortex_dtype::DType; -use vortex_error::{vortex_err, VortexError, VortexResult}; -use vortex_scalar::Scalar; - -use crate::encoding::Encoding; -use crate::{ArrayDType, ArrayData, IntoArrayVariant}; - -pub trait SubtractScalarFn { - fn subtract_scalar(&self, array: &Array, to_subtract: &Scalar) -> VortexResult; -} - -impl SubtractScalarFn for E -where - E: SubtractScalarFn, - for<'a> &'a E::Array: TryFrom<&'a ArrayData, Error = VortexError>, -{ - fn subtract_scalar(&self, array: &ArrayData, to_subtract: &Scalar) -> VortexResult { - let array_ref = <&E::Array>::try_from(array)?; - let encoding = array - .encoding() - .as_any() - .downcast_ref::() - .ok_or_else(|| vortex_err!("Mismatched encoding"))?; - SubtractScalarFn::subtract_scalar(encoding, array_ref, to_subtract) - } -} - -pub fn subtract_scalar( - array: impl AsRef, - to_subtract: &Scalar, -) -> VortexResult { - let array = array.as_ref(); - let to_subtract = to_subtract.cast(array.dtype())?; - - if let Some(f) = array.encoding().subtract_scalar_fn() { - return f.subtract_scalar(array, &to_subtract); - } - - // if subtraction is not implemented for the given array type, but the array has a numeric - // DType, we can flatten the array and apply subtraction to the flattened primitive array - match array.dtype() { - DType::Primitive(..) => subtract_scalar(array.clone().into_primitive()?, &to_subtract), - _ => Err(vortex_err!( - NotImplemented: "scalar_subtract", - array.encoding().id() - )), - } -} diff --git a/vortex-array/src/encoding/mod.rs b/vortex-array/src/encoding/mod.rs index a6c56fd1c1..004bbd520e 100644 --- a/vortex-array/src/encoding/mod.rs +++ b/vortex-array/src/encoding/mod.rs @@ -4,6 +4,8 @@ use std::any::Any; use std::fmt::{Debug, Display, Formatter}; use std::hash::{Hash, Hasher}; +use vortex_error::{vortex_err, VortexError, VortexResult}; + use crate::compute::ComputeVTable; use crate::stats::StatisticsVTable; use crate::validity::ValidityVTable; @@ -67,6 +69,19 @@ pub trait Encoding: 'static { type Metadata: ArrayMetadata; } +pub fn downcast_array_ref(array: &ArrayData) -> VortexResult<(&E::Array, &E)> +where + for<'a> &'a E::Array: TryFrom<&'a ArrayData, Error = VortexError>, +{ + let array_ref = <&E::Array>::try_from(array)?; + let encoding = array + .encoding() + .as_any() + .downcast_ref::() + .ok_or_else(|| vortex_err!("Mismatched encoding"))?; + Ok((array_ref, encoding)) +} + pub type EncodingRef = &'static dyn EncodingVTable; /// Object-safe encoding trait for an array. diff --git a/vortex-array/src/patches.rs b/vortex-array/src/patches.rs index 3b9132bfde..e67e635085 100644 --- a/vortex-array/src/patches.rs +++ b/vortex-array/src/patches.rs @@ -10,8 +10,8 @@ use vortex_scalar::Scalar; use crate::aliases::hash_map::HashMap; use crate::array::PrimitiveArray; use crate::compute::{ - scalar_at, search_sorted, search_sorted_usize, search_sorted_usize_many, slice, - subtract_scalar, take, FilterMask, SearchResult, SearchSortedSide, + scalar_at, search_sorted, search_sorted_usize, search_sorted_usize_many, slice, sub_scalar, + take, FilterMask, SearchResult, SearchSortedSide, }; use crate::stats::{ArrayStatistics, Stat}; use crate::validity::Validity; @@ -237,7 +237,7 @@ impl Patches { // Subtract the start value from the indices let indices = slice(self.indices(), patch_start, patch_stop)?; - let indices = subtract_scalar(&indices, &Scalar::from(start).cast(indices.dtype())?)?; + let indices = sub_scalar(&indices, Scalar::from(start).cast(indices.dtype())?)?; Ok(Some(Self::new(stop - start, indices, values))) } diff --git a/vortex-array/src/stream/take_rows.rs b/vortex-array/src/stream/take_rows.rs index 9ddd86e847..c0f6c1f0f9 100644 --- a/vortex-array/src/stream/take_rows.rs +++ b/vortex-array/src/stream/take_rows.rs @@ -7,7 +7,7 @@ use vortex_dtype::match_each_integer_ptype; use vortex_error::{vortex_bail, VortexResult}; use vortex_scalar::Scalar; -use crate::compute::{search_sorted_usize, slice, subtract_scalar, take, SearchSortedSide}; +use crate::compute::{search_sorted_usize, slice, sub_scalar, take, SearchSortedSide}; use crate::stats::{ArrayStatistics, Stat}; use crate::stream::ArrayStream; use crate::variants::PrimitiveArrayTrait; @@ -91,7 +91,7 @@ impl Stream for TakeRows { // onto a worker pool. let indices_for_batch = slice(this.indices, left, right)?.into_primitive()?; let shifted_arr = match_each_integer_ptype!(indices_for_batch.ptype(), |$T| { - subtract_scalar(&indices_for_batch.into_array(), &Scalar::from(curr_offset as $T))? + sub_scalar(&indices_for_batch.into_array(), Scalar::from(curr_offset as $T))? }); return Poll::Ready(take(&batch, &shifted_arr).map(Some).transpose()); } diff --git a/vortex-dtype/src/ptype.rs b/vortex-dtype/src/ptype.rs index a7a6648c89..0b406b6adf 100644 --- a/vortex-dtype/src/ptype.rs +++ b/vortex-dtype/src/ptype.rs @@ -147,6 +147,28 @@ macro_rules! match_each_native_ptype { PType::F32 => __with__! { f32 }, PType::F64 => __with__! { f64 }, } + }); + ($self:expr, + integral: | $_:tt $integral_enc:ident | { $($integral_body:tt)* } + floating_point: | $_2:tt $floating_point_enc:ident | { $($floating_point_body:tt)* } + ) => ({ + macro_rules! __with_integer__ {( $_ $integral_enc:ident ) => ( { $($integral_body)* } )} + macro_rules! __with_floating_point__ {( $_ $floating_point_enc:ident ) => ( { $($floating_point_body)* } )} + use $crate::PType; + use $crate::half::f16; + match $self { + PType::I8 => __with_integer__! { i8 }, + PType::I16 => __with_integer__! { i16 }, + PType::I32 => __with_integer__! { i32 }, + PType::I64 => __with_integer__! { i64 }, + PType::U8 => __with_integer__! { u8 }, + PType::U16 => __with_integer__! { u16 }, + PType::U32 => __with_integer__! { u32 }, + PType::U64 => __with_integer__! { u64 }, + PType::F16 => __with_floating_point__! { f16 }, + PType::F32 => __with_floating_point__! { f32 }, + PType::F64 => __with_floating_point__! { f64 }, + } }) } diff --git a/vortex-scalar/src/primitive.rs b/vortex-scalar/src/primitive.rs index 3042a88209..feebc731a7 100644 --- a/vortex-scalar/src/primitive.rs +++ b/vortex-scalar/src/primitive.rs @@ -3,7 +3,9 @@ use std::any::type_name; use num_traits::{FromPrimitive, NumCast}; use vortex_dtype::half::f16; use vortex_dtype::{match_each_native_ptype, DType, NativePType, Nullability, PType}; -use vortex_error::{vortex_err, vortex_panic, VortexError, VortexResult, VortexUnwrap}; +use vortex_error::{ + vortex_bail, vortex_err, vortex_panic, VortexError, VortexResult, VortexUnwrap, +}; use crate::pvalue::PValue; use crate::value::ScalarValue; @@ -270,3 +272,71 @@ impl From for Scalar { Scalar::primitive(value as u64, Nullability::NonNullable) } } + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum BinaryNumericOperator { + Add, + Sub, + Mul, + Div, + // Missing from arrow-rs: + // Min, + // Max, + // Pow, +} + +impl PrimitiveScalar<'_> { + /// Apply the (checked) operator to self and other using SQL-style null semantics. + /// + /// If the operation overflows, Ok(None) is returned. + /// + /// If the types are incompatible (ignoring nullability), an error is returned. + /// + /// If either value is null, the result is null. + pub fn checked_numeric_operator( + self, + other: PrimitiveScalar<'_>, + op: BinaryNumericOperator, + ) -> VortexResult> { + if !self.dtype().eq_ignore_nullability(other.dtype()) { + vortex_bail!("types must match: {} {}", self.dtype(), other.dtype()); + } + + let nullability = + Nullability::from(self.dtype().is_nullable() || other.dtype().is_nullable()); + + Ok(match_each_native_ptype!( + self.ptype(), + integral: |$P| { + let lhs = self.typed_value::<$P>(); + let rhs = other.typed_value::<$P>(); + match (lhs, rhs) { + (_, None) | (None, _) => Some(Scalar::null(self.dtype().with_nullability(nullability))), + (Some(lhs), Some(rhs)) => match op { + BinaryNumericOperator::Add => + lhs.checked_add(rhs).map(|result| Scalar::primitive(result, nullability)), + BinaryNumericOperator::Sub => + lhs.checked_sub(rhs).map(|result| Scalar::primitive(result, nullability)), + BinaryNumericOperator::Mul => + lhs.checked_mul(rhs).map(|result| Scalar::primitive(result, nullability)), + BinaryNumericOperator::Div => + lhs.checked_div(rhs).map(|result| Scalar::primitive(result, nullability)), + } + } + } + floating_point: |$P| { + let lhs = self.typed_value::<$P>(); + let rhs = other.typed_value::<$P>(); + Some(match (lhs, rhs) { + (_, None) | (None, _) => Scalar::null(self.dtype().with_nullability(nullability)), + (Some(lhs), Some(rhs)) => match op { + BinaryNumericOperator::Add => Scalar::primitive(lhs + rhs, nullability), + BinaryNumericOperator::Sub => Scalar::primitive(lhs - rhs, nullability), + BinaryNumericOperator::Mul => Scalar::primitive(lhs - rhs, nullability), + BinaryNumericOperator::Div => Scalar::primitive(lhs - rhs, nullability), + } + }) + } + )) + } +}