diff --git a/arrow-array/Cargo.toml b/arrow-array/Cargo.toml index 57b86c1924f0..d993d36b8d74 100644 --- a/arrow-array/Cargo.toml +++ b/arrow-array/Cargo.toml @@ -71,3 +71,7 @@ harness = false [[bench]] name = "fixed_size_list_array" harness = false + +[[bench]] +name = "decimal_overflow" +harness = false diff --git a/arrow-array/benches/decimal_overflow.rs b/arrow-array/benches/decimal_overflow.rs new file mode 100644 index 000000000000..8f22b4b47c31 --- /dev/null +++ b/arrow-array/benches/decimal_overflow.rs @@ -0,0 +1,53 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow_array::builder::{Decimal128Builder, Decimal256Builder}; +use arrow_buffer::i256; +use criterion::*; + +fn criterion_benchmark(c: &mut Criterion) { + let len = 8192; + let mut builder_128 = Decimal128Builder::with_capacity(len); + let mut builder_256 = Decimal256Builder::with_capacity(len); + for i in 0..len { + if i % 10 == 0 { + builder_128.append_value(i128::MAX); + builder_256.append_value(i256::from_i128(i128::MAX)); + } else { + builder_128.append_value(i as i128); + builder_256.append_value(i256::from_i128(i as i128)); + } + } + let array_128 = builder_128.finish(); + let array_256 = builder_256.finish(); + + c.bench_function("validate_decimal_precision_128", |b| { + b.iter(|| black_box(array_128.validate_decimal_precision(8))); + }); + c.bench_function("null_if_overflow_precision_128", |b| { + b.iter(|| black_box(array_128.null_if_overflow_precision(8))); + }); + c.bench_function("validate_decimal_precision_256", |b| { + b.iter(|| black_box(array_256.validate_decimal_precision(8))); + }); + c.bench_function("null_if_overflow_precision_256", |b| { + b.iter(|| black_box(array_256.null_if_overflow_precision(8))); + }); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/arrow-array/src/array/primitive_array.rs b/arrow-array/src/array/primitive_array.rs index 521ef088e361..567fa00e7385 100644 --- a/arrow-array/src/array/primitive_array.rs +++ b/arrow-array/src/array/primitive_array.rs @@ -1570,9 +1570,7 @@ impl PrimitiveArray { /// Validates the Decimal Array, if the value of slot is overflow for the specified precision, and /// will be casted to Null pub fn null_if_overflow_precision(&self, precision: u8) -> Self { - self.unary_opt::<_, T>(|v| { - (T::validate_decimal_precision(v, precision).is_ok()).then_some(v) - }) + self.unary_opt::<_, T>(|v| T::is_valid_decimal_precision(v, precision).then_some(v)) } /// Returns [`Self::value`] formatted as a string diff --git a/arrow-array/src/types.rs b/arrow-array/src/types.rs index b39c9c40311b..92262fc04a57 100644 --- a/arrow-array/src/types.rs +++ b/arrow-array/src/types.rs @@ -24,7 +24,10 @@ use crate::temporal_conversions::as_datetime_with_timezone; use crate::timezone::Tz; use crate::{ArrowNativeTypeOp, OffsetSizeTrait}; use arrow_buffer::{i256, Buffer, OffsetBuffer}; -use arrow_data::decimal::{validate_decimal256_precision, validate_decimal_precision}; +use arrow_data::decimal::{ + is_validate_decimal256_precision, is_validate_decimal_precision, validate_decimal256_precision, + validate_decimal_precision, +}; use arrow_data::{validate_binary_view, validate_string_view}; use arrow_schema::{ ArrowError, DataType, IntervalUnit, TimeUnit, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE, @@ -1194,6 +1197,9 @@ pub trait DecimalType: /// Validates that `value` contains no more than `precision` decimal digits fn validate_decimal_precision(value: Self::Native, precision: u8) -> Result<(), ArrowError>; + + /// Determines whether `value` contains no more than `precision` decimal digits + fn is_valid_decimal_precision(value: Self::Native, precision: u8) -> bool; } /// Validate that `precision` and `scale` are valid for `T` @@ -1256,6 +1262,10 @@ impl DecimalType for Decimal128Type { fn validate_decimal_precision(num: i128, precision: u8) -> Result<(), ArrowError> { validate_decimal_precision(num, precision) } + + fn is_valid_decimal_precision(value: Self::Native, precision: u8) -> bool { + is_validate_decimal_precision(value, precision) + } } impl ArrowPrimitiveType for Decimal128Type { @@ -1286,6 +1296,10 @@ impl DecimalType for Decimal256Type { fn validate_decimal_precision(num: i256, precision: u8) -> Result<(), ArrowError> { validate_decimal256_precision(num, precision) } + + fn is_valid_decimal_precision(value: Self::Native, precision: u8) -> bool { + is_validate_decimal256_precision(value, precision) + } } impl ArrowPrimitiveType for Decimal256Type { diff --git a/arrow-cast/src/cast/decimal.rs b/arrow-cast/src/cast/decimal.rs index 600f868a3e01..637cbc417008 100644 --- a/arrow-cast/src/cast/decimal.rs +++ b/arrow-cast/src/cast/decimal.rs @@ -336,11 +336,7 @@ where if cast_options.safe { let iter = from.iter().map(|v| { v.and_then(|v| parse_string_to_decimal_native::(v, scale as usize).ok()) - .and_then(|v| { - T::validate_decimal_precision(v, precision) - .is_ok() - .then_some(v) - }) + .and_then(|v| T::is_valid_decimal_precision(v, precision).then_some(v)) }); // Benefit: // 20% performance improvement @@ -430,7 +426,7 @@ where (mul * v.as_()) .round() .to_i128() - .filter(|v| Decimal128Type::validate_decimal_precision(*v, precision).is_ok()) + .filter(|v| Decimal128Type::is_valid_decimal_precision(*v, precision)) }) .with_precision_and_scale(precision, scale) .map(|a| Arc::new(a) as ArrayRef) @@ -473,7 +469,7 @@ where array .unary_opt::<_, Decimal256Type>(|v| { i256::from_f64((v.as_() * mul).round()) - .filter(|v| Decimal256Type::validate_decimal_precision(*v, precision).is_ok()) + .filter(|v| Decimal256Type::is_valid_decimal_precision(*v, precision)) }) .with_precision_and_scale(precision, scale) .map(|a| Arc::new(a) as ArrayRef) diff --git a/arrow-cast/src/cast/mod.rs b/arrow-cast/src/cast/mod.rs index e80d497c8cba..25ef243e18e4 100644 --- a/arrow-cast/src/cast/mod.rs +++ b/arrow-cast/src/cast/mod.rs @@ -327,9 +327,10 @@ where let array = if scale < 0 { match cast_options.safe { true => array.unary_opt::<_, D>(|v| { - v.as_().div_checked(scale_factor).ok().and_then(|v| { - (D::validate_decimal_precision(v, precision).is_ok()).then_some(v) - }) + v.as_() + .div_checked(scale_factor) + .ok() + .and_then(|v| (D::is_valid_decimal_precision(v, precision)).then_some(v)) }), false => array.try_unary::<_, D, _>(|v| { v.as_() @@ -340,9 +341,10 @@ where } else { match cast_options.safe { true => array.unary_opt::<_, D>(|v| { - v.as_().mul_checked(scale_factor).ok().and_then(|v| { - (D::validate_decimal_precision(v, precision).is_ok()).then_some(v) - }) + v.as_() + .mul_checked(scale_factor) + .ok() + .and_then(|v| (D::is_valid_decimal_precision(v, precision)).then_some(v)) }), false => array.try_unary::<_, D, _>(|v| { v.as_() diff --git a/arrow-data/src/decimal.rs b/arrow-data/src/decimal.rs index 74279bfb9af1..d9028591aaaa 100644 --- a/arrow-data/src/decimal.rs +++ b/arrow-data/src/decimal.rs @@ -23,10 +23,13 @@ pub use arrow_schema::{ DECIMAL_DEFAULT_SCALE, }; -// MAX decimal256 value of little-endian format for each precision. -// Each element is the max value of signed 256-bit integer for the specified precision which -// is encoded to the 32-byte width format of little-endian. -pub(crate) const MAX_DECIMAL_BYTES_FOR_LARGER_EACH_PRECISION: [i256; 76] = [ +/// MAX decimal256 value of little-endian format for each precision. +/// Each element is the max value of signed 256-bit integer for the specified precision which +/// is encoded to the 32-byte width format of little-endian. +/// The first element is unused and is inserted so that we can look up using +/// precision as the index without the need to subtract 1 first. +pub(crate) const MAX_DECIMAL_BYTES_FOR_LARGER_EACH_PRECISION: [i256; 77] = [ + i256::from_i128(0_i128), // unused first element i256::from_le_bytes([ 9, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, @@ -333,10 +336,13 @@ pub(crate) const MAX_DECIMAL_BYTES_FOR_LARGER_EACH_PRECISION: [i256; 76] = [ ]), ]; -// MIN decimal256 value of little-endian format for each precision. -// Each element is the min value of signed 256-bit integer for the specified precision which -// is encoded to the 76-byte width format of little-endian. -pub(crate) const MIN_DECIMAL_BYTES_FOR_LARGER_EACH_PRECISION: [i256; 76] = [ +/// MIN decimal256 value of little-endian format for each precision. +/// Each element is the min value of signed 256-bit integer for the specified precision which +/// is encoded to the 76-byte width format of little-endian. +/// The first element is unused and is inserted so that we can look up using +/// precision as the index without the need to subtract 1 first. +pub(crate) const MIN_DECIMAL_BYTES_FOR_LARGER_EACH_PRECISION: [i256; 77] = [ + i256::from_i128(0_i128), // unused first element i256::from_le_bytes([ 247, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, @@ -643,8 +649,9 @@ pub(crate) const MIN_DECIMAL_BYTES_FOR_LARGER_EACH_PRECISION: [i256; 76] = [ ]), ]; -/// `MAX_DECIMAL_FOR_EACH_PRECISION[p]` holds the maximum `i128` value that can +/// `MAX_DECIMAL_FOR_EACH_PRECISION[p-1]` holds the maximum `i128` value that can /// be stored in [arrow_schema::DataType::Decimal128] value of precision `p` +#[allow(dead_code)] // no longer used but is part of our public API pub const MAX_DECIMAL_FOR_EACH_PRECISION: [i128; 38] = [ 9, 99, @@ -686,8 +693,9 @@ pub const MAX_DECIMAL_FOR_EACH_PRECISION: [i128; 38] = [ 99999999999999999999999999999999999999, ]; -/// `MIN_DECIMAL_FOR_EACH_PRECISION[p]` holds the minimum `i128` value that can +/// `MIN_DECIMAL_FOR_EACH_PRECISION[p-1]` holds the minimum `i128` value that can /// be stored in a [arrow_schema::DataType::Decimal128] value of precision `p` +#[allow(dead_code)] // no longer used but is part of our public API pub const MIN_DECIMAL_FOR_EACH_PRECISION: [i128; 38] = [ -9, -99, @@ -729,6 +737,98 @@ pub const MIN_DECIMAL_FOR_EACH_PRECISION: [i128; 38] = [ -99999999999999999999999999999999999999, ]; +/// `MAX_DECIMAL_FOR_EACH_PRECISION_ONE_BASED[p]` holds the maximum `i128` value that can +/// be stored in [arrow_schema::DataType::Decimal128] value of precision `p`. +/// The first element is unused and is inserted so that we can look up using +/// precision as the index without the need to subtract 1 first. +pub(crate) const MAX_DECIMAL_FOR_EACH_PRECISION_ONE_BASED: [i128; 39] = [ + 0, // unused first element + 9, + 99, + 999, + 9999, + 99999, + 999999, + 9999999, + 99999999, + 999999999, + 9999999999, + 99999999999, + 999999999999, + 9999999999999, + 99999999999999, + 999999999999999, + 9999999999999999, + 99999999999999999, + 999999999999999999, + 9999999999999999999, + 99999999999999999999, + 999999999999999999999, + 9999999999999999999999, + 99999999999999999999999, + 999999999999999999999999, + 9999999999999999999999999, + 99999999999999999999999999, + 999999999999999999999999999, + 9999999999999999999999999999, + 99999999999999999999999999999, + 999999999999999999999999999999, + 9999999999999999999999999999999, + 99999999999999999999999999999999, + 999999999999999999999999999999999, + 9999999999999999999999999999999999, + 99999999999999999999999999999999999, + 999999999999999999999999999999999999, + 9999999999999999999999999999999999999, + 99999999999999999999999999999999999999, +]; + +/// `MIN_DECIMAL_FOR_EACH_PRECISION[p]` holds the minimum `i128` value that can +/// be stored in a [arrow_schema::DataType::Decimal128] value of precision `p`. +/// The first element is unused and is inserted so that we can look up using +/// precision as the index without the need to subtract 1 first. +pub(crate) const MIN_DECIMAL_FOR_EACH_PRECISION_ONE_BASED: [i128; 39] = [ + 0, // unused first element + -9, + -99, + -999, + -9999, + -99999, + -999999, + -9999999, + -99999999, + -999999999, + -9999999999, + -99999999999, + -999999999999, + -9999999999999, + -99999999999999, + -999999999999999, + -9999999999999999, + -99999999999999999, + -999999999999999999, + -9999999999999999999, + -99999999999999999999, + -999999999999999999999, + -9999999999999999999999, + -99999999999999999999999, + -999999999999999999999999, + -9999999999999999999999999, + -99999999999999999999999999, + -999999999999999999999999999, + -9999999999999999999999999999, + -99999999999999999999999999999, + -999999999999999999999999999999, + -9999999999999999999999999999999, + -99999999999999999999999999999999, + -999999999999999999999999999999999, + -9999999999999999999999999999999999, + -99999999999999999999999999999999999, + -999999999999999999999999999999999999, + -9999999999999999999999999999999999999, + -99999999999999999999999999999999999999, +]; + /// Validates that the specified `i128` value can be properly /// interpreted as a Decimal number with precision `precision` #[inline] @@ -738,23 +838,30 @@ pub fn validate_decimal_precision(value: i128, precision: u8) -> Result<(), Arro "Max precision of a Decimal128 is {DECIMAL128_MAX_PRECISION}, but got {precision}", ))); } - - let max = MAX_DECIMAL_FOR_EACH_PRECISION[usize::from(precision) - 1]; - let min = MIN_DECIMAL_FOR_EACH_PRECISION[usize::from(precision) - 1]; - - if value > max { + if value > MAX_DECIMAL_FOR_EACH_PRECISION_ONE_BASED[precision as usize] { Err(ArrowError::InvalidArgumentError(format!( - "{value} is too large to store in a Decimal128 of precision {precision}. Max is {max}" + "{value} is too large to store in a Decimal128 of precision {precision}. Max is {}", + MAX_DECIMAL_FOR_EACH_PRECISION_ONE_BASED[precision as usize] ))) - } else if value < min { + } else if value < MIN_DECIMAL_FOR_EACH_PRECISION_ONE_BASED[precision as usize] { Err(ArrowError::InvalidArgumentError(format!( - "{value} is too small to store in a Decimal128 of precision {precision}. Min is {min}" + "{value} is too small to store in a Decimal128 of precision {precision}. Min is {}", + MIN_DECIMAL_FOR_EACH_PRECISION_ONE_BASED[precision as usize] ))) } else { Ok(()) } } +/// Determines whether the specified `i128` value can be properly +/// interpreted as a Decimal number with precision `precision` +#[inline] +pub fn is_validate_decimal_precision(value: i128, precision: u8) -> bool { + precision <= DECIMAL128_MAX_PRECISION + && value >= MIN_DECIMAL_FOR_EACH_PRECISION_ONE_BASED[precision as usize] + && value <= MAX_DECIMAL_FOR_EACH_PRECISION_ONE_BASED[precision as usize] +} + /// Validates that the specified `i256` of value can be properly /// interpreted as a Decimal256 number with precision `precision` #[inline] @@ -764,18 +871,26 @@ pub fn validate_decimal256_precision(value: i256, precision: u8) -> Result<(), A "Max precision of a Decimal256 is {DECIMAL256_MAX_PRECISION}, but got {precision}", ))); } - let max = MAX_DECIMAL_BYTES_FOR_LARGER_EACH_PRECISION[usize::from(precision) - 1]; - let min = MIN_DECIMAL_BYTES_FOR_LARGER_EACH_PRECISION[usize::from(precision) - 1]; - - if value > max { + if value > MAX_DECIMAL_BYTES_FOR_LARGER_EACH_PRECISION[precision as usize] { Err(ArrowError::InvalidArgumentError(format!( - "{value:?} is too large to store in a Decimal256 of precision {precision}. Max is {max:?}" + "{value:?} is too large to store in a Decimal256 of precision {precision}. Max is {:?}", + MAX_DECIMAL_BYTES_FOR_LARGER_EACH_PRECISION[precision as usize] ))) - } else if value < min { + } else if value < MIN_DECIMAL_BYTES_FOR_LARGER_EACH_PRECISION[precision as usize] { Err(ArrowError::InvalidArgumentError(format!( - "{value:?} is too small to store in a Decimal256 of precision {precision}. Min is {min:?}" + "{value:?} is too small to store in a Decimal256 of precision {precision}. Min is {:?}", + MIN_DECIMAL_BYTES_FOR_LARGER_EACH_PRECISION[precision as usize] ))) } else { Ok(()) } } + +/// Determines whether the specified `i256` value can be properly +/// interpreted as a Decimal256 number with precision `precision` +#[inline] +pub fn is_validate_decimal256_precision(value: i256, precision: u8) -> bool { + precision <= DECIMAL256_MAX_PRECISION + && value >= MIN_DECIMAL_BYTES_FOR_LARGER_EACH_PRECISION[precision as usize] + && value <= MAX_DECIMAL_BYTES_FOR_LARGER_EACH_PRECISION[precision as usize] +}