From 1a53f95221f8ee8ce44ff2a452efe9ae14feea01 Mon Sep 17 00:00:00 2001 From: Josh Casale Date: Mon, 29 Apr 2024 17:59:34 -0400 Subject: [PATCH] recursively apply --- vortex-array/Cargo.toml | 1 + vortex-array/src/array/chunked/compute/mod.rs | 5 + vortex-array/src/array/chunked/mod.rs | 39 ++++- vortex-array/src/array/primitive/mod.rs | 135 ++++++++++++++-- vortex-array/src/compute/scalar_subtract.rs | 145 ++++-------------- vortex-array/src/ptype.rs | 16 ++ vortex-ipc/src/reader.rs | 2 +- 7 files changed, 213 insertions(+), 130 deletions(-) diff --git a/vortex-array/Cargo.toml b/vortex-array/Cargo.toml index 979e671fb3..332ff8aa9a 100644 --- a/vortex-array/Cargo.toml +++ b/vortex-array/Cargo.toml @@ -49,6 +49,7 @@ walkdir = { workspace = true } [dev-dependencies] criterion = { workspace = true } + [[bench]] name = "search_sorted" harness = false diff --git a/vortex-array/src/array/chunked/compute/mod.rs b/vortex-array/src/array/chunked/compute/mod.rs index 13ecc1889e..f5cb19351a 100644 --- a/vortex-array/src/array/chunked/compute/mod.rs +++ b/vortex-array/src/array/chunked/compute/mod.rs @@ -3,6 +3,7 @@ use vortex_error::VortexResult; use crate::array::chunked::ChunkedArray; use crate::compute::as_contiguous::{as_contiguous, AsContiguousFn}; use crate::compute::scalar_at::{scalar_at, ScalarAtFn}; +use crate::compute::scalar_subtract::ScalarSubtractFn; use crate::compute::take::TakeFn; use crate::compute::ArrayCompute; use crate::scalar::Scalar; @@ -22,6 +23,10 @@ impl ArrayCompute for ChunkedArray<'_> { fn take(&self) -> Option<&dyn TakeFn> { Some(self) } + + fn scalar_subtract(&self) -> Option<&dyn ScalarSubtractFn> { + Some(self) + } } impl AsContiguousFn for ChunkedArray<'_> { diff --git a/vortex-array/src/array/chunked/mod.rs b/vortex-array/src/array/chunked/mod.rs index a9c3e9d43a..a9b46e0e3d 100644 --- a/vortex-array/src/array/chunked/mod.rs +++ b/vortex-array/src/array/chunked/mod.rs @@ -4,7 +4,9 @@ use vortex_dtype::{IntWidth, Nullability, Signedness}; use vortex_error::{vortex_bail, VortexResult}; use crate::array::primitive::PrimitiveArray; +use crate::compute::as_contiguous::as_contiguous; use crate::compute::scalar_at::scalar_at; +use crate::compute::scalar_subtract::{scalar_subtract, ScalarSubtractFn}; use crate::compute::search_sorted::{search_sorted, SearchSortedSide}; use crate::validity::Validity::NonNullable; use crate::validity::{ArrayValidity, LogicalValidity}; @@ -143,13 +145,25 @@ impl ArrayValidity for ChunkedArray<'_> { impl EncodingCompression for ChunkedEncoding {} +impl ScalarSubtractFn for ChunkedArray<'_> { + fn scalar_subtract(&self, to_subtract: &Scalar) -> VortexResult { + as_contiguous( + &self + .chunks() + .map(|c| scalar_subtract(&c, to_subtract.clone()).unwrap()) + .collect_vec(), + ) + } +} + #[cfg(test)] mod test { use vortex_dtype::{DType, IntWidth, Nullability, Signedness}; use crate::array::chunked::{ChunkedArray, OwnedChunkedArray}; + use crate::compute::scalar_subtract::scalar_subtract; use crate::ptype::NativePType; - use crate::{Array, IntoArray}; + use crate::{Array, IntoArray, ToArray, ToStatic}; #[allow(dead_code)] fn chunked_array() -> OwnedChunkedArray { @@ -179,6 +193,29 @@ mod test { assert_eq!(values, slice); } + #[test] + fn test_scalar_subtract() { + let chunk1 = vec![1.0f64, 2.0, 3.0].into_array(); + let chunk2 = vec![4.0f64, 5.0, 6.0].into_array(); + let to_subtract = -1f64; + + let chunked = ChunkedArray::try_new( + vec![chunk1, chunk2], + DType::Float(64.into(), Nullability::NonNullable), + ) + .unwrap() + .to_array() + .to_static(); + + let array = scalar_subtract(&chunked, to_subtract).unwrap(); + let results = array + .flatten_primitive() + .unwrap() + .typed_data::() + .to_vec(); + assert_eq!(results, &[2.0f64, 3.0, 4.0, 5.0, 6.0, 7.0]); + } + // FIXME(ngates): bring back when slicing is a compute function. // #[test] // pub fn slice_middle() { diff --git a/vortex-array/src/array/primitive/mod.rs b/vortex-array/src/array/primitive/mod.rs index 79aa04fbef..00409f2f12 100644 --- a/vortex-array/src/array/primitive/mod.rs +++ b/vortex-array/src/array/primitive/mod.rs @@ -6,12 +6,12 @@ use vortex_error::{vortex_bail, VortexResult}; use crate::buffer::Buffer; use crate::compute::scalar_subtract::ScalarSubtractFn; +use crate::match_each_integer_ptype; use crate::ptype::{NativePType, PType}; use crate::stats::ArrayStatistics; use crate::validity::{ArrayValidity, LogicalValidity, Validity, ValidityMetadata}; use crate::visitor::{AcceptArrayVisitor, ArrayVisitor}; -use crate::{impl_encoding, ArrayDType, OwnedArray}; -use crate::{match_each_integer_ptype, scalar}; +use crate::{impl_encoding, match_each_float_ptype, ArrayDType, OwnedArray}; use crate::{match_each_native_ptype, ArrayFlatten}; mod accessor; @@ -196,7 +196,7 @@ impl<'a> Array<'a> { impl EncodingCompression for PrimitiveEncoding {} impl ScalarSubtractFn for PrimitiveArray<'_> { - fn scalar_subtract(&self, to_subtract: Scalar) -> VortexResult { + fn scalar_subtract(&self, to_subtract: &Scalar) -> VortexResult { if self.dtype() != to_subtract.dtype() { vortex_bail!(MismatchedTypes: self.dtype(), to_subtract.dtype()) } @@ -204,17 +204,19 @@ impl ScalarSubtractFn for PrimitiveArray<'_> { let result = match to_subtract.dtype() { DType::Int(..) => { match_each_integer_ptype!(self.ptype(), |$T| { - let to_subtract = >::try_into(to_subtract)?; - let maybe_min = self.statistics().compute_as_cast(Stat::Min);//.unwrap_or($T::MAX); + let to_subtract = $T::try_from(to_subtract)?; + let maybe_min = self.statistics().compute_as_cast(Stat::Min); - if maybe_min.is_some() { - let min: $T = maybe_min.unwrap(); - let max: $T = self.statistics().compute_as_cast(Stat::Max).unwrap(); + if let Some(min) = maybe_min { + let min: $T = min; if let (min, true) = min.overflowing_sub(to_subtract) { vortex_bail!("Integer subtraction over/underflow: {}, {}", min, to_subtract) } - if let (max, true) = max.overflowing_sub(to_subtract) { + if let Some(max) = self.statistics().compute_as_cast(Stat::Max) { + let max: $T = max; + if let (max, true) = max.overflowing_sub(to_subtract) { vortex_bail!("Integer subtraction over/underflow: {}, {}", max, to_subtract) + } } } let sub_vec : Vec<$T> = self.typed_data::<$T>().iter().map(|&v| v - to_subtract).collect_vec(); @@ -222,8 +224,8 @@ impl ScalarSubtractFn for PrimitiveArray<'_> { }) } DType::Decimal(..) | DType::Float(..) => { - match_each_native_ptype!(self.ptype(), |$T| { - let to_subtract = >::try_into(to_subtract)?; + match_each_float_ptype!(self.ptype(), |$T| { + let to_subtract = $T::try_from(to_subtract)?; let sub_vec : Vec<$T> = self.typed_data::<$T>().iter().map(|&v| v - to_subtract).collect_vec(); PrimitiveArray::from(sub_vec) }) @@ -234,3 +236,114 @@ impl ScalarSubtractFn for PrimitiveArray<'_> { Ok(result.into_array()) } } + +#[cfg(test)] +mod test { + use crate::compute::scalar_subtract::scalar_subtract; + use crate::IntoArray; + + #[test] + fn test_scalar_subtract_unsigned() { + let values = vec![1u16, 2, 3].into_array(); + let results = scalar_subtract(&values, 1u16) + .unwrap() + .flatten_primitive() + .unwrap() + .typed_data::() + .to_vec(); + assert_eq!(results, &[0u16, 1, 2]); + } + + #[test] + fn test_scalar_subtract_signed() { + let values = vec![1i64, 2, 3].into_array(); + let results = scalar_subtract(&values, -1i64) + .unwrap() + .flatten_primitive() + .unwrap() + .typed_data::() + .to_vec(); + assert_eq!(results, &[2i64, 3, 4]); + } + + #[test] + fn test_scalar_subtract_float() { + let values = vec![1.0f64, 2.0, 3.0].into_array(); + let to_subtract = -1f64; + let results = scalar_subtract(&values, to_subtract) + .unwrap() + .flatten_primitive() + .unwrap() + .typed_data::() + .to_vec(); + assert_eq!(results, &[2.0f64, 3.0, 4.0]); + } + + #[test] + fn test_scalar_subtract_int_from_float() { + let values = vec![3.0f64, 4.0, 5.0].into_array(); + // Ints can be cast to floats, so there's no problem here + let results = scalar_subtract(&values, 1u64) + .unwrap() + .flatten_primitive() + .unwrap() + .typed_data::() + .to_vec(); + assert_eq!(results, &[2.0f64, 3.0, 4.0]); + } + + #[test] + fn test_scalar_subtract_unsigned_underflow() { + let values = vec![u8::MIN, 2, 3].into_array(); + let _results = scalar_subtract(&values, 1u8).expect_err("should fail with underflow"); + let values = vec![u16::MIN, 2, 3].into_array(); + let _results = scalar_subtract(&values, 1u16).expect_err("should fail with underflow"); + let values = vec![u32::MIN, 2, 3].into_array(); + let _results = scalar_subtract(&values, 1u32).expect_err("should fail with underflow"); + let values = vec![u64::MIN, 2, 3].into_array(); + let _results = scalar_subtract(&values, 1u64).expect_err("should fail with underflow"); + } + + #[test] + fn test_scalar_subtract_signed_overflow() { + let values = vec![i8::MAX, 2, 3].into_array(); + let to_subtract = -1i8; + let _results = + scalar_subtract(&values, to_subtract).expect_err("should fail with overflow"); + let values = vec![i16::MAX, 2, 3].into_array(); + let _results = + scalar_subtract(&values, to_subtract).expect_err("should fail with overflow"); + let values = vec![i32::MAX, 2, 3].into_array(); + let _results = + scalar_subtract(&values, to_subtract).expect_err("should fail with overflow"); + let values = vec![i64::MAX, 2, 3].into_array(); + let _results = + scalar_subtract(&values, to_subtract).expect_err("should fail with overflow"); + } + + #[test] + fn test_scalar_subtract_signed_underflow() { + let values = vec![i8::MIN, 2, 3].into_array(); + let _results = scalar_subtract(&values, 1i8).expect_err("should fail with underflow"); + let values = vec![i16::MIN, 2, 3].into_array(); + let _results = scalar_subtract(&values, 1i16).expect_err("should fail with underflow"); + let values = vec![i32::MIN, 2, 3].into_array(); + let _results = scalar_subtract(&values, 1i32).expect_err("should fail with underflow"); + let values = vec![i64::MIN, 2, 3].into_array(); + let _results = scalar_subtract(&values, 1i64).expect_err("should fail with underflow"); + } + + #[test] + fn test_scalar_subtract_float_underflow_is_ok() { + let values = vec![f32::MIN, 2.0, 3.0].into_array(); + let _results = scalar_subtract(&values, 1.0f32).unwrap(); + let _results = scalar_subtract(&values, f32::MAX).unwrap(); + } + + #[test] + fn test_scalar_subtract_type_mismatch_fails() { + let values = vec![1u64, 2, 3].into_array(); + // Subtracting incompatible dtypes should fail + let _results = scalar_subtract(&values, 1.5f64).expect_err("Expected type mismatch error"); + } +} diff --git a/vortex-array/src/compute/scalar_subtract.rs b/vortex-array/src/compute/scalar_subtract.rs index 5aa1729af0..41b0c3239f 100644 --- a/vortex-array/src/compute/scalar_subtract.rs +++ b/vortex-array/src/compute/scalar_subtract.rs @@ -1,131 +1,42 @@ use vortex_error::{vortex_err, VortexResult}; +use vortex_schema::DType; use crate::scalar::Scalar; -use crate::{Array, OwnedArray}; +use crate::{Array, ArrayDType, Flattened, OwnedArray}; pub trait ScalarSubtractFn { - fn scalar_subtract(&self, to_subtract: Scalar) -> VortexResult; + fn scalar_subtract(&self, to_subtract: &Scalar) -> VortexResult; } -pub fn scalar_subtract(array: &Array, to_subtract: Scalar) -> VortexResult { - array.with_dyn(|c| { - let option = c - .scalar_subtract() - .map(|t| t.scalar_subtract(to_subtract.clone())); - option.unwrap_or_else(|| { +pub fn scalar_subtract>(array: &Array, to_subtract: T) -> VortexResult { + let to_subtract = to_subtract.into().cast(array.dtype())?; + + if let Some(subtraction_result) = + array.with_dyn(|c| c.scalar_subtract().map(|t| t.scalar_subtract(&to_subtract))) + { + subtraction_result + } else { + // if subtraction is not implemented for the given array type, but the array has a numeric + // DType, we can flatten the array and apply subtraction to the flattened primitive array + let result = match array.dtype() { + DType::Int(..) | DType::Decimal(..) | DType::Float(..) => { + let array = array.clone(); + let result = array.flatten()?; + + if let Flattened::Primitive(flat) = result { + Some(flat.scalar_subtract(&to_subtract)) + } else { + None + } + } + _ => None, + }; + + result.unwrap_or_else(|| { Err(vortex_err!( NotImplemented: "scalar_subtract", array.encoding().id().name() )) }) - }) -} - -#[cfg(test)] -mod test { - use super::*; - use crate::IntoArray; - - #[test] - fn test_scalar_subtract_unsigned() { - let values = vec![1u16, 2, 3].into_array(); - let results = scalar_subtract(&values, 1u16.into()) - .unwrap() - .flatten_primitive() - .unwrap() - .typed_data::() - .to_vec(); - assert_eq!(results, &[0u16, 1, 2]); - } - - #[test] - fn test_scalar_subtract_signed() { - let values = vec![1i64, 2, 3].into_array(); - let to_subtract = -1i64; - let results = scalar_subtract(&values, to_subtract.into()) - .unwrap() - .flatten_primitive() - .unwrap() - .typed_data::() - .to_vec(); - assert_eq!(results, &[2i64, 3, 4]); - } - - #[test] - fn test_scalar_subtract_float() { - let values = vec![1.0f64, 2.0, 3.0].into_array(); - let to_subtract = -1f64; - let results = scalar_subtract(&values, to_subtract.into()) - .unwrap() - .flatten_primitive() - .unwrap() - .typed_data::() - .to_vec(); - assert_eq!(results, &[2.0f64, 3.0, 4.0]); - } - - #[test] - fn test_scalar_subtract_unsigned_underflow() { - let values = vec![u8::MIN, 2, 3].into_array(); - let _results = - scalar_subtract(&values, 1u8.into()).expect_err("should fail with underflow"); - let values = vec![u16::MIN, 2, 3].into_array(); - let _results = - scalar_subtract(&values, 1u16.into()).expect_err("should fail with underflow"); - let values = vec![u32::MIN, 2, 3].into_array(); - let _results = - scalar_subtract(&values, 1u32.into()).expect_err("should fail with underflow"); - let values = vec![u64::MIN, 2, 3].into_array(); - let _results = - scalar_subtract(&values, 1u64.into()).expect_err("should fail with underflow"); - } - - #[test] - fn test_scalar_subtract_signed_overflow() { - let values = vec![i8::MAX, 2, 3].into_array(); - let to_subtract = -1i8; - let _results = - scalar_subtract(&values, to_subtract.into()).expect_err("should fail with overflow"); - let values = vec![i16::MAX, 2, 3].into_array(); - let _results = - scalar_subtract(&values, to_subtract.into()).expect_err("should fail with overflow"); - let values = vec![i32::MAX, 2, 3].into_array(); - let _results = - scalar_subtract(&values, to_subtract.into()).expect_err("should fail with overflow"); - let values = vec![i64::MAX, 2, 3].into_array(); - let _results = - scalar_subtract(&values, to_subtract.into()).expect_err("should fail with overflow"); - } - - #[test] - fn test_scalar_subtract_signed_underflow() { - let values = vec![i8::MIN, 2, 3].into_array(); - let _results = - scalar_subtract(&values, 1i8.into()).expect_err("should fail with underflow"); - let values = vec![i16::MIN, 2, 3].into_array(); - let _results = - scalar_subtract(&values, 1i16.into()).expect_err("should fail with underflow"); - let values = vec![i32::MIN, 2, 3].into_array(); - let _results = - scalar_subtract(&values, 1i32.into()).expect_err("should fail with underflow"); - let values = vec![i64::MIN, 2, 3].into_array(); - let _results = - scalar_subtract(&values, 1i64.into()).expect_err("should fail with underflow"); - } - - #[test] - fn test_scalar_subtract_float_underflow_is_ok() { - let values = vec![f32::MIN, 2.0, 3.0].into_array(); - let _results = scalar_subtract(&values, 1.0f32.into()).unwrap(); - let _results = scalar_subtract(&values, f32::MAX.into()).unwrap(); - } - - #[test] - fn test_scalar_subtract_type_mismatch_fails() { - let values = vec![1.0f64, 2.0, 3.0].into_array(); - // Subtracting non-equivalent dtypes should fail - let to_subtract = 1u64; - let _results = - scalar_subtract(&values, to_subtract.into()).expect_err("Expected type mismatch error"); } } diff --git a/vortex-array/src/ptype.rs b/vortex-array/src/ptype.rs index 71c02de223..97b1aa07c8 100644 --- a/vortex-array/src/ptype.rs +++ b/vortex-array/src/ptype.rs @@ -139,6 +139,22 @@ macro_rules! match_each_integer_ptype { } pub use match_each_integer_ptype; +#[macro_export] +macro_rules! match_each_float_ptype { + ($self:expr, | $_:tt $enc:ident | $($body:tt)*) => ({ + macro_rules! __with__ {( $_ $enc:ident ) => ( $($body)* )} + use $crate::ptype::PType; + use half::f16; + match $self { + PType::F16 => __with__! { f16 }, + PType::F32 => __with__! { f32 }, + PType::F64 => __with__! { f64 }, + _ => panic!("Unsupported ptype {}", $self), + } + }) +} +pub use match_each_float_ptype; + impl PType { pub const fn is_unsigned_int(self) -> bool { matches!(self, PType::U8 | PType::U16 | PType::U32 | PType::U64) diff --git a/vortex-ipc/src/reader.rs b/vortex-ipc/src/reader.rs index 339da108aa..9b2e5870ff 100644 --- a/vortex-ipc/src/reader.rs +++ b/vortex-ipc/src/reader.rs @@ -204,7 +204,7 @@ impl<'a, R: Read> StreamArrayReader<'a, R> { let indices_for_batch = slice(indices, left, right)?.flatten_primitive()?; let shifted_arr = match_each_integer_ptype!(indices_for_batch.ptype(), |$T| { - indices_for_batch.scalar_subtract(Scalar::from(row_offset as $T))? + indices_for_batch.scalar_subtract(&Scalar::from(row_offset as $T))? }); let from_current_batch = take(&batch, &shifted_arr)?;