diff --git a/cpp/src/arrow/acero/tpch_benchmark.cc b/cpp/src/arrow/acero/tpch_benchmark.cc index aa621758b351e..ac3b69c9b706f 100644 --- a/cpp/src/arrow/acero/tpch_benchmark.cc +++ b/cpp/src/arrow/acero/tpch_benchmark.cc @@ -58,7 +58,7 @@ std::shared_ptr Plan_Q1(AsyncGenerator>* sink Expression base_price = field_ref("L_EXTENDEDPRICE"); std::shared_ptr decimal_1 = - std::make_shared(Decimal128{0, 100}, decimal(12, 2)); + std::make_shared(Decimal128{0, 100}, decimal128(12, 2)); Expression discount_multiplier = call("subtract", {literal(decimal_1), field_ref("L_DISCOUNT")}); Expression tax_multiplier = call("add", {literal(decimal_1), field_ref("L_TAX")}); @@ -68,7 +68,7 @@ std::shared_ptr Plan_Q1(AsyncGenerator>* sink call("multiply", {call("cast", {call("multiply", {field_ref("L_EXTENDEDPRICE"), discount_multiplier})}, - compute::CastOptions::Unsafe(decimal(12, 2))), + compute::CastOptions::Unsafe(decimal128(12, 2))), tax_multiplier}); Expression discount = field_ref("L_DISCOUNT"); diff --git a/cpp/src/arrow/acero/tpch_node.cc b/cpp/src/arrow/acero/tpch_node.cc index 137b62ad38a95..abc742f9fa10b 100644 --- a/cpp/src/arrow/acero/tpch_node.cc +++ b/cpp/src/arrow/acero/tpch_node.cc @@ -838,12 +838,12 @@ class PartAndPartSupplierGenerator { const std::vector> kPartTypes = { int32(), utf8(), fixed_size_binary(25), fixed_size_binary(10), - utf8(), int32(), fixed_size_binary(10), decimal(12, 2), + utf8(), int32(), fixed_size_binary(10), decimal128(12, 2), utf8(), }; const std::vector> kPartsuppTypes = { - int32(), int32(), int32(), decimal(12, 2), utf8(), + int32(), int32(), int32(), decimal128(12, 2), utf8(), }; Status AllocatePartBatch(size_t thread_index, int column) { @@ -1527,7 +1527,7 @@ class OrdersAndLineItemGenerator { const std::vector> kOrdersTypes = {int32(), int32(), fixed_size_binary(1), - decimal(12, 2), + decimal128(12, 2), date32(), fixed_size_binary(15), fixed_size_binary(15), @@ -1539,10 +1539,10 @@ class OrdersAndLineItemGenerator { int32(), int32(), int32(), - decimal(12, 2), - decimal(12, 2), - decimal(12, 2), - decimal(12, 2), + decimal128(12, 2), + decimal128(12, 2), + decimal128(12, 2), + decimal128(12, 2), fixed_size_binary(1), fixed_size_binary(1), date32(), @@ -2489,7 +2489,7 @@ class SupplierGenerator : public TpchTableGenerator { std::vector> kTypes = { int32(), fixed_size_binary(25), utf8(), - int32(), fixed_size_binary(15), decimal(12, 2), + int32(), fixed_size_binary(15), decimal128(12, 2), utf8(), }; @@ -2872,7 +2872,7 @@ class CustomerGenerator : public TpchTableGenerator { utf8(), int32(), fixed_size_binary(15), - decimal(12, 2), + decimal128(12, 2), fixed_size_binary(10), utf8(), }; diff --git a/cpp/src/arrow/array/array_base.cc b/cpp/src/arrow/array/array_base.cc index 6927f51283eb7..ce2e66655af3d 100644 --- a/cpp/src/arrow/array/array_base.cc +++ b/cpp/src/arrow/array/array_base.cc @@ -74,6 +74,10 @@ struct ScalarFromArraySlotImpl { return Finish(a.Value(index_)); } + Status Visit(const Decimal32Array& a) { return Finish(Decimal32(a.GetValue(index_))); } + + Status Visit(const Decimal64Array& a) { return Finish(Decimal64(a.GetValue(index_))); } + Status Visit(const Decimal128Array& a) { return Finish(Decimal128(a.GetValue(index_))); } diff --git a/cpp/src/arrow/array/array_decimal.cc b/cpp/src/arrow/array/array_decimal.cc index d65f6ee53564f..a2c9cae3451a1 100644 --- a/cpp/src/arrow/array/array_decimal.cc +++ b/cpp/src/arrow/array/array_decimal.cc @@ -32,6 +32,34 @@ namespace arrow { using internal::checked_cast; +// ---------------------------------------------------------------------- +// Decimal32 + +Decimal32Array::Decimal32Array(const std::shared_ptr& data) + : FixedSizeBinaryArray(data) { + ARROW_CHECK_EQ(data->type->id(), Type::DECIMAL32); +} + +std::string Decimal32Array::FormatValue(int64_t i) const { + const auto& type_ = checked_cast(*type()); + const Decimal32 value(GetValue(i)); + return value.ToString(type_.scale()); +} + +// ---------------------------------------------------------------------- +// Decimal64 + +Decimal64Array::Decimal64Array(const std::shared_ptr& data) + : FixedSizeBinaryArray(data) { + ARROW_CHECK_EQ(data->type->id(), Type::DECIMAL64); +} + +std::string Decimal64Array::FormatValue(int64_t i) const { + const auto& type_ = checked_cast(*type()); + const Decimal64 value(GetValue(i)); + return value.ToString(type_.scale()); +} + // ---------------------------------------------------------------------- // Decimal128 diff --git a/cpp/src/arrow/array/array_decimal.h b/cpp/src/arrow/array/array_decimal.h index f14812549089a..2f10bb8429996 100644 --- a/cpp/src/arrow/array/array_decimal.h +++ b/cpp/src/arrow/array/array_decimal.h @@ -32,6 +32,38 @@ namespace arrow { /// /// @{ +// ---------------------------------------------------------------------- +// Decimal32Array + +/// Concrete Array class for 32-bit decimal data +class ARROW_EXPORT Decimal32Array : public FixedSizeBinaryArray { + public: + using TypeClass = Decimal32Type; + + using FixedSizeBinaryArray::FixedSizeBinaryArray; + + /// \brief Construct Decimal32Array from ArrayData instance + explicit Decimal32Array(const std::shared_ptr& data); + + std::string FormatValue(int64_t i) const; +}; + +// ---------------------------------------------------------------------- +// Decimal64Array + +/// Concrete Array class for 64-bit decimal data +class ARROW_EXPORT Decimal64Array : public FixedSizeBinaryArray { + public: + using TypeClass = Decimal64Type; + + using FixedSizeBinaryArray::FixedSizeBinaryArray; + + /// \brief Construct Decimal64Array from ArrayData instance + explicit Decimal64Array(const std::shared_ptr& data); + + std::string FormatValue(int64_t i) const; +}; + // ---------------------------------------------------------------------- // Decimal128Array diff --git a/cpp/src/arrow/array/array_test.cc b/cpp/src/arrow/array/array_test.cc index 73e0c692432b6..d69e00460dcfc 100644 --- a/cpp/src/arrow/array/array_test.cc +++ b/cpp/src/arrow/array/array_test.cc @@ -442,7 +442,7 @@ static std::vector> TestArrayUtilitiesAgainstTheseType large_binary(), binary_view(), fixed_size_binary(3), - decimal(16, 4), + decimal128(16, 4), utf8(), large_utf8(), utf8_view(), @@ -667,8 +667,10 @@ static ScalarVector GetScalars() { std::make_shared(hello), std::make_shared( hello, fixed_size_binary(static_cast(hello->size()))), - std::make_shared(Decimal128(10), decimal(16, 4)), - std::make_shared(Decimal256(10), decimal(76, 38)), + std::make_shared(Decimal32(10), smallest_decimal(7, 4)), + std::make_shared(Decimal64(10), smallest_decimal(12, 4)), + std::make_shared(Decimal128(10), smallest_decimal(20, 4)), + std::make_shared(Decimal256(10), smallest_decimal(76, 38)), std::make_shared(hello), std::make_shared(hello), std::make_shared(hello), @@ -3092,6 +3094,98 @@ class DecimalTest : public ::testing::TestWithParam { } }; +using Decimal32Test = DecimalTest; + +TEST_P(Decimal32Test, NoNulls) { + int32_t precision = GetParam(); + std::vector draw = {Decimal32(1), Decimal32(-2), Decimal32(2389), + Decimal32(4), Decimal32(-12348)}; + std::vector valid_bytes = {true, true, true, true, true}; + this->TestCreate(precision, draw, valid_bytes, 0); + this->TestCreate(precision, draw, valid_bytes, 2); +} + +TEST_P(Decimal32Test, WithNulls) { + int32_t precision = GetParam(); + std::vector draw = {Decimal32(1), Decimal32(2), Decimal32(-1), Decimal32(4), + Decimal32(-1), Decimal32(1), Decimal32(2)}; + Decimal32 big; + ASSERT_OK_AND_ASSIGN(big, Decimal32::FromString("23034.234")); + draw.push_back(big); + + Decimal32 big_negative; + ASSERT_OK_AND_ASSIGN(big_negative, Decimal32::FromString("-23049.235")); + draw.push_back(big_negative); + + std::vector valid_bytes = {true, true, false, true, false, + true, true, true, true}; + this->TestCreate(precision, draw, valid_bytes, 0); + this->TestCreate(precision, draw, valid_bytes, 2); +} + +TEST_P(Decimal32Test, ValidateFull) { + int32_t precision = GetParam(); + std::vector draw; + Decimal32 val = Decimal32::GetMaxValue(precision) + 1; + + draw = {Decimal32(), val}; + auto arr = this->TestCreate(precision, draw, {true, false}, 0); + ASSERT_OK(arr->ValidateFull()); + + draw = {val, Decimal32()}; + arr = this->TestCreate(precision, draw, {true, false}, 0); + EXPECT_RAISES_WITH_MESSAGE_THAT( + Invalid, ::testing::HasSubstr("does not fit in precision of"), arr->ValidateFull()); +} + +INSTANTIATE_TEST_SUITE_P(Decimal32Test, Decimal32Test, ::testing::Range(1, 9)); + +using Decimal64Test = DecimalTest; + +TEST_P(Decimal64Test, NoNulls) { + int32_t precision = GetParam(); + std::vector draw = {Decimal64(1), Decimal64(-2), Decimal64(2389), + Decimal64(4), Decimal64(-12348)}; + std::vector valid_bytes = {true, true, true, true, true}; + this->TestCreate(precision, draw, valid_bytes, 0); + this->TestCreate(precision, draw, valid_bytes, 2); +} + +TEST_P(Decimal64Test, WithNulls) { + int32_t precision = GetParam(); + std::vector draw = {Decimal64(1), Decimal64(2), Decimal64(-1), Decimal64(4), + Decimal64(-1), Decimal64(1), Decimal64(2)}; + Decimal64 big; + ASSERT_OK_AND_ASSIGN(big, Decimal64::FromString("23034.234234")); + draw.push_back(big); + + Decimal64 big_negative; + ASSERT_OK_AND_ASSIGN(big_negative, Decimal64::FromString("-23049.235234")); + draw.push_back(big_negative); + + std::vector valid_bytes = {true, true, false, true, false, + true, true, true, true}; + this->TestCreate(precision, draw, valid_bytes, 0); + this->TestCreate(precision, draw, valid_bytes, 2); +} + +TEST_P(Decimal64Test, ValidateFull) { + int32_t precision = GetParam(); + std::vector draw; + Decimal64 val = Decimal64::GetMaxValue(precision) + 1; + + draw = {Decimal64(), val}; + auto arr = this->TestCreate(precision, draw, {true, false}, 0); + ASSERT_OK(arr->ValidateFull()); + + draw = {val, Decimal64()}; + arr = this->TestCreate(precision, draw, {true, false}, 0); + EXPECT_RAISES_WITH_MESSAGE_THAT( + Invalid, ::testing::HasSubstr("does not fit in precision of"), arr->ValidateFull()); +} + +INSTANTIATE_TEST_SUITE_P(Decimal64Test, Decimal64Test, ::testing::Range(1, 9)); + using Decimal128Test = DecimalTest; TEST_P(Decimal128Test, NoNulls) { @@ -3315,6 +3409,28 @@ TEST(TestSwapEndianArrayData, PrimitiveType) { expected_data = ArrayData::Make(uint64(), 1, {null_buffer, data_int64_buffer}, 0); AssertArrayDataEqualsWithSwapEndian(data, expected_data); + auto data_4byte_buffer = Buffer::FromString( + "\x01" + "12\x01"); + data = ArrayData::Make(decimal32(9, 8), 1, {null_buffer, data_4byte_buffer}); + auto data_decimal32_buffer = Buffer::FromString( + "\x01" + "21\x01"); + expected_data = + ArrayData::Make(decimal32(9, 8), 1, {null_buffer, data_decimal32_buffer}, 0); + AssertArrayDataEqualsWithSwapEndian(data, expected_data); + + auto data_8byte_buffer = Buffer::FromString( + "\x01" + "123456\x01"); + data = ArrayData::Make(decimal64(18, 8), 1, {null_buffer, data_8byte_buffer}); + auto data_decimal64_buffer = Buffer::FromString( + "\x01" + "654321\x01"); + expected_data = + ArrayData::Make(decimal64(18, 8), 1, {null_buffer, data_decimal64_buffer}, 0); + AssertArrayDataEqualsWithSwapEndian(data, expected_data); + auto data_16byte_buffer = Buffer::FromString( "\x01" "123456789abcde\x01"); @@ -3647,6 +3763,8 @@ DataTypeVector SwappableTypes() { uint16(), uint32(), uint64(), + decimal32(8, 1), + decimal64(16, 2), decimal128(19, 4), decimal256(37, 8), timestamp(TimeUnit::MICRO, ""), diff --git a/cpp/src/arrow/array/array_view_test.cc b/cpp/src/arrow/array/array_view_test.cc index 97110ea97f3fc..a8d6d8ffa3e79 100644 --- a/cpp/src/arrow/array/array_view_test.cc +++ b/cpp/src/arrow/array/array_view_test.cc @@ -385,8 +385,32 @@ TEST(TestArrayView, SparseUnionAsStruct) { CheckView(expected, arr); } -TEST(TestArrayView, DecimalRoundTrip) { - auto ty1 = decimal(10, 4); +TEST(TestArrayView, Decimal32RoundTrip) { + auto ty1 = decimal32(9, 4); + auto arr = ArrayFromJSON(ty1, R"(["123.4567", "-78.9000", null])"); + + auto ty2 = fixed_size_binary(4); + ASSERT_OK_AND_ASSIGN(auto v, arr->View(ty2)); + ASSERT_OK(v->ValidateFull()); + ASSERT_OK_AND_ASSIGN(auto w, v->View(ty1)); + ASSERT_OK(w->ValidateFull()); + AssertArraysEqual(*arr, *w); +} + +TEST(TestArrayView, Decimal64RoundTrip) { + auto ty1 = decimal64(10, 4); + auto arr = ArrayFromJSON(ty1, R"(["123.4567", "-78.9000", null])"); + + auto ty2 = fixed_size_binary(8); + ASSERT_OK_AND_ASSIGN(auto v, arr->View(ty2)); + ASSERT_OK(v->ValidateFull()); + ASSERT_OK_AND_ASSIGN(auto w, v->View(ty1)); + ASSERT_OK(w->ValidateFull()); + AssertArraysEqual(*arr, *w); +} + +TEST(TestArrayView, Decimal128RoundTrip) { + auto ty1 = decimal128(20, 4); auto arr = ArrayFromJSON(ty1, R"(["123.4567", "-78.9000", null])"); auto ty2 = fixed_size_binary(16); @@ -397,6 +421,18 @@ TEST(TestArrayView, DecimalRoundTrip) { AssertArraysEqual(*arr, *w); } +TEST(TestArrayView, Decimal256RoundTrip) { + auto ty1 = decimal256(10, 4); + auto arr = ArrayFromJSON(ty1, R"(["123.4567", "-78.9000", null])"); + + auto ty2 = fixed_size_binary(32); + ASSERT_OK_AND_ASSIGN(auto v, arr->View(ty2)); + ASSERT_OK(v->ValidateFull()); + ASSERT_OK_AND_ASSIGN(auto w, v->View(ty1)); + ASSERT_OK(w->ValidateFull()); + AssertArraysEqual(*arr, *w); +} + TEST(TestArrayView, Dictionaries) { // ARROW-6049 auto ty1 = dictionary(int8(), float32()); diff --git a/cpp/src/arrow/array/builder_base.cc b/cpp/src/arrow/array/builder_base.cc index 40e705aa3e440..2e6e1bfd13032 100644 --- a/cpp/src/arrow/array/builder_base.cc +++ b/cpp/src/arrow/array/builder_base.cc @@ -119,6 +119,8 @@ struct AppendScalarImpl { } Status Visit(const FixedSizeBinaryType& t) { return HandleFixedWidth(t); } + Status Visit(const Decimal32Type& t) { return HandleFixedWidth(t); } + Status Visit(const Decimal64Type& t) { return HandleFixedWidth(t); } Status Visit(const Decimal128Type& t) { return HandleFixedWidth(t); } Status Visit(const Decimal256Type& t) { return HandleFixedWidth(t); } diff --git a/cpp/src/arrow/array/builder_decimal.cc b/cpp/src/arrow/array/builder_decimal.cc index 3b1262819df7f..868183768c1d1 100644 --- a/cpp/src/arrow/array/builder_decimal.cc +++ b/cpp/src/arrow/array/builder_decimal.cc @@ -32,6 +32,76 @@ namespace arrow { class Buffer; class MemoryPool; +// ---------------------------------------------------------------------- +// Decimal32Builder + +Decimal32Builder::Decimal32Builder(const std::shared_ptr& type, + MemoryPool* pool, int64_t alignment) + : FixedSizeBinaryBuilder(type, pool, alignment), + decimal_type_(internal::checked_pointer_cast(type)) {} + +Status Decimal32Builder::Append(Decimal32 value) { + RETURN_NOT_OK(FixedSizeBinaryBuilder::Reserve(1)); + UnsafeAppend(value); + return Status::OK(); +} + +void Decimal32Builder::UnsafeAppend(Decimal32 value) { + value.ToBytes(GetMutableValue(length())); + byte_builder_.UnsafeAdvance(4); + UnsafeAppendToBitmap(true); +} + +void Decimal32Builder::UnsafeAppend(std::string_view value) { + FixedSizeBinaryBuilder::UnsafeAppend(value); +} + +Status Decimal32Builder::FinishInternal(std::shared_ptr* out) { + std::shared_ptr data; + RETURN_NOT_OK(byte_builder_.Finish(&data)); + std::shared_ptr null_bitmap; + RETURN_NOT_OK(null_bitmap_builder_.Finish(&null_bitmap)); + + *out = ArrayData::Make(type(), length_, {null_bitmap, data}, null_count_); + capacity_ = length_ = null_count_ = 0; + return Status::OK(); +} + +// ---------------------------------------------------------------------- +// Decimal64Builder + +Decimal64Builder::Decimal64Builder(const std::shared_ptr& type, + MemoryPool* pool, int64_t alignment) + : FixedSizeBinaryBuilder(type, pool, alignment), + decimal_type_(internal::checked_pointer_cast(type)) {} + +Status Decimal64Builder::Append(Decimal64 value) { + RETURN_NOT_OK(FixedSizeBinaryBuilder::Reserve(1)); + UnsafeAppend(value); + return Status::OK(); +} + +void Decimal64Builder::UnsafeAppend(Decimal64 value) { + value.ToBytes(GetMutableValue(length())); + byte_builder_.UnsafeAdvance(8); + UnsafeAppendToBitmap(true); +} + +void Decimal64Builder::UnsafeAppend(std::string_view value) { + FixedSizeBinaryBuilder::UnsafeAppend(value); +} + +Status Decimal64Builder::FinishInternal(std::shared_ptr* out) { + std::shared_ptr data; + RETURN_NOT_OK(byte_builder_.Finish(&data)); + std::shared_ptr null_bitmap; + RETURN_NOT_OK(null_bitmap_builder_.Finish(&null_bitmap)); + + *out = ArrayData::Make(type(), length_, {null_bitmap, data}, null_count_); + capacity_ = length_ = null_count_ = 0; + return Status::OK(); +} + // ---------------------------------------------------------------------- // Decimal128Builder diff --git a/cpp/src/arrow/array/builder_decimal.h b/cpp/src/arrow/array/builder_decimal.h index 8094250aef8d4..a0bf0a0422084 100644 --- a/cpp/src/arrow/array/builder_decimal.h +++ b/cpp/src/arrow/array/builder_decimal.h @@ -33,6 +33,68 @@ namespace arrow { /// /// @{ +class ARROW_EXPORT Decimal32Builder : public FixedSizeBinaryBuilder { + public: + using TypeClass = Decimal32Type; + using ValueType = Decimal32; + + explicit Decimal32Builder(const std::shared_ptr& type, + MemoryPool* pool = default_memory_pool(), + int64_t alignment = kDefaultBufferAlignment); + + using FixedSizeBinaryBuilder::Append; + using FixedSizeBinaryBuilder::AppendValues; + using FixedSizeBinaryBuilder::Reset; + + Status Append(Decimal32 val); + void UnsafeAppend(Decimal32 val); + void UnsafeAppend(std::string_view val); + + Status FinishInternal(std::shared_ptr* out) override; + + /// \cond FALSE + using ArrayBuilder::Finish; + /// \endcond + + Status Finish(std::shared_ptr* out) { return FinishTyped(out); } + + std::shared_ptr type() const override { return decimal_type_; } + + protected: + std::shared_ptr decimal_type_; +}; + +class ARROW_EXPORT Decimal64Builder : public FixedSizeBinaryBuilder { + public: + using TypeClass = Decimal64Type; + using ValueType = Decimal64; + + explicit Decimal64Builder(const std::shared_ptr& type, + MemoryPool* pool = default_memory_pool(), + int64_t alignment = kDefaultBufferAlignment); + + using FixedSizeBinaryBuilder::Append; + using FixedSizeBinaryBuilder::AppendValues; + using FixedSizeBinaryBuilder::Reset; + + Status Append(Decimal64 val); + void UnsafeAppend(Decimal64 val); + void UnsafeAppend(std::string_view val); + + Status FinishInternal(std::shared_ptr* out) override; + + /// \cond FALSE + using ArrayBuilder::Finish; + /// \endcond + + Status Finish(std::shared_ptr* out) { return FinishTyped(out); } + + std::shared_ptr type() const override { return decimal_type_; } + + protected: + std::shared_ptr decimal_type_; +}; + class ARROW_EXPORT Decimal128Builder : public FixedSizeBinaryBuilder { public: using TypeClass = Decimal128Type; diff --git a/cpp/src/arrow/array/builder_dict.h b/cpp/src/arrow/array/builder_dict.h index 3f0d711dc5bb5..116c82049eea9 100644 --- a/cpp/src/arrow/array/builder_dict.h +++ b/cpp/src/arrow/array/builder_dict.h @@ -298,20 +298,11 @@ class DictionaryBuilderBase : public ArrayBuilder { return Append(std::string_view(value, length)); } - /// \brief Append a decimal (only for Decimal128Type) - template - enable_if_decimal128 Append(const Decimal128& value) { - uint8_t data[16]; - value.ToBytes(data); - return Append(data, 16); - } - - /// \brief Append a decimal (only for Decimal128Type) - template - enable_if_decimal256 Append(const Decimal256& value) { - uint8_t data[32]; - value.ToBytes(data); - return Append(data, 32); + /// \brief Append a decimal (only for Decimal32/64/128/256 Type) + template ::CType> + enable_if_decimal Append(const CType& value) { + auto bytes = value.ToBytes(); + return Append(bytes.data(), static_cast(bytes.size())); } /// \brief Append a scalar null value diff --git a/cpp/src/arrow/array/concatenate.cc b/cpp/src/arrow/array/concatenate.cc index b4638dd6593d8..d8a69868d1543 100644 --- a/cpp/src/arrow/array/concatenate.cc +++ b/cpp/src/arrow/array/concatenate.cc @@ -377,7 +377,7 @@ class ConcatenateImpl { } Status Visit(const FixedWidthType& fixed) { - // Handles numbers, decimal128, decimal256, fixed_size_binary + // Handles numbers, decimal32, decimal64, decimal128, decimal256, fixed_size_binary ARROW_ASSIGN_OR_RAISE(auto buffers, Buffers(1, fixed)); return ConcatenateBuffers(buffers, pool_).Value(&out_->buffers[1]); } diff --git a/cpp/src/arrow/array/diff.cc b/cpp/src/arrow/array/diff.cc index f9714eda34c61..3e36a971578d5 100644 --- a/cpp/src/arrow/array/diff.cc +++ b/cpp/src/arrow/array/diff.cc @@ -707,11 +707,9 @@ class MakeFormatterImpl { template enable_if_decimal Visit(const T&) { impl_ = [](const Array& array, int64_t index, std::ostream* os) { - if constexpr (T::type_id == Type::DECIMAL128) { - *os << checked_cast(array).FormatValue(index); - } else { - *os << checked_cast(array).FormatValue(index); - } + const auto& decimal_array = + checked_cast::ArrayType&>(array); + *os << decimal_array.FormatValue(index); }; return Status::OK(); } diff --git a/cpp/src/arrow/array/diff_test.cc b/cpp/src/arrow/array/diff_test.cc index 145978a91ad54..02bcf5bbb4c5b 100644 --- a/cpp/src/arrow/array/diff_test.cc +++ b/cpp/src/arrow/array/diff_test.cc @@ -707,6 +707,8 @@ TEST_F(DiffTest, UnifiedDiffFormatter) { } for (const auto& type : { + decimal32(8, 4), + decimal64(10, 4), decimal128(10, 4), decimal256(10, 4), }) { diff --git a/cpp/src/arrow/array/util.cc b/cpp/src/arrow/array/util.cc index b56ea25f9e421..51c27b2d9719f 100644 --- a/cpp/src/arrow/array/util.cc +++ b/cpp/src/arrow/array/util.cc @@ -152,57 +152,20 @@ class ArrayDataEndianSwapper { return Status::OK(); } - Status Visit(const Decimal128Type& type) { - auto data = reinterpret_cast(data_->buffers[1]->data()); + template + enable_if_decimal Visit(const T& type) { + using value_type = typename TypeTraits::CType; + auto data = data_->buffers[1]->span_as(); ARROW_ASSIGN_OR_RAISE(auto new_buffer, AllocateBuffer(data_->buffers[1]->size(), pool_)); - auto new_data = reinterpret_cast(new_buffer->mutable_data()); - // NOTE: data_->length not trusted (see warning above) - const int64_t length = data_->buffers[1]->size() / Decimal128Type::kByteWidth; - for (int64_t i = 0; i < length; i++) { - uint64_t tmp; - auto idx = i * 2; -#if ARROW_LITTLE_ENDIAN - tmp = bit_util::FromBigEndian(data[idx]); - new_data[idx] = bit_util::FromBigEndian(data[idx + 1]); - new_data[idx + 1] = tmp; -#else - tmp = bit_util::FromLittleEndian(data[idx]); - new_data[idx] = bit_util::FromLittleEndian(data[idx + 1]); - new_data[idx + 1] = tmp; -#endif - } - out_->buffers[1] = std::move(new_buffer); - return Status::OK(); - } + auto new_data = new_buffer->mutable_data_as(); - Status Visit(const Decimal256Type& type) { - auto data = reinterpret_cast(data_->buffers[1]->data()); - ARROW_ASSIGN_OR_RAISE(auto new_buffer, AllocateBuffer(data_->buffers[1]->size())); - auto new_data = reinterpret_cast(new_buffer->mutable_data()); - // NOTE: data_->length not trusted (see warning above) - const int64_t length = data_->buffers[1]->size() / Decimal256Type::kByteWidth; - for (int64_t i = 0; i < length; i++) { - uint64_t tmp0, tmp1, tmp2; - auto idx = i * 4; -#if ARROW_LITTLE_ENDIAN - tmp0 = bit_util::FromBigEndian(data[idx]); - tmp1 = bit_util::FromBigEndian(data[idx + 1]); - tmp2 = bit_util::FromBigEndian(data[idx + 2]); - new_data[idx] = bit_util::FromBigEndian(data[idx + 3]); - new_data[idx + 1] = tmp2; - new_data[idx + 2] = tmp1; - new_data[idx + 3] = tmp0; -#else - tmp0 = bit_util::FromLittleEndian(data[idx]); - tmp1 = bit_util::FromLittleEndian(data[idx + 1]); - tmp2 = bit_util::FromLittleEndian(data[idx + 2]); - new_data[idx] = bit_util::FromLittleEndian(data[idx + 3]); - new_data[idx + 1] = tmp2; - new_data[idx + 2] = tmp1; - new_data[idx + 3] = tmp0; -#endif + for (const value_type& v : data) { + auto bytes = v.ToBytes(); + std::reverse(bytes.begin(), bytes.end()); + memcpy(new_data++, bytes.data(), bytes.size()); } + out_->buffers[1] = std::move(new_buffer); return Status::OK(); } diff --git a/cpp/src/arrow/array/validate.cc b/cpp/src/arrow/array/validate.cc index 69f1646054f4c..5e466dfa9b2f2 100644 --- a/cpp/src/arrow/array/validate.cc +++ b/cpp/src/arrow/array/validate.cc @@ -144,6 +144,16 @@ struct ValidateArrayImpl { Status Visit(const FixedWidthType&) { return ValidateFixedWidthBuffers(); } + Status Visit(const Decimal32Type& type) { + RETURN_NOT_OK(ValidateFixedWidthBuffers()); + return ValidateDecimals(type); + } + + Status Visit(const Decimal64Type& type) { + RETURN_NOT_OK(ValidateFixedWidthBuffers()); + return ValidateDecimals(type); + } + Status Visit(const Decimal128Type& type) { RETURN_NOT_OK(ValidateFixedWidthBuffers()); return ValidateDecimals(type); diff --git a/cpp/src/arrow/builder.cc b/cpp/src/arrow/builder.cc index 7042d9818c691..46969e73e22ae 100644 --- a/cpp/src/arrow/builder.cc +++ b/cpp/src/arrow/builder.cc @@ -151,6 +151,8 @@ struct DictionaryBuilderCase { Status Visit(const BinaryViewType&) { return CreateFor(); } Status Visit(const StringViewType&) { return CreateFor(); } Status Visit(const FixedSizeBinaryType&) { return CreateFor(); } + Status Visit(const Decimal32Type&) { return CreateFor(); } + Status Visit(const Decimal64Type&) { return CreateFor(); } Status Visit(const Decimal128Type&) { return CreateFor(); } Status Visit(const Decimal256Type&) { return CreateFor(); } diff --git a/cpp/src/arrow/builder_benchmark.cc b/cpp/src/arrow/builder_benchmark.cc index 8ec7373a1de1f..3564f0309b756 100644 --- a/cpp/src/arrow/builder_benchmark.cc +++ b/cpp/src/arrow/builder_benchmark.cc @@ -228,7 +228,7 @@ static void BuildFixedSizeBinaryArray( } static void BuildDecimalArray(benchmark::State& state) { // NOLINT non-const reference - auto type = decimal(10, 5); + auto type = decimal128(10, 5); Decimal128 value; int32_t precision = 0; int32_t scale = 0; diff --git a/cpp/src/arrow/c/bridge.cc b/cpp/src/arrow/c/bridge.cc index eba575f4cf39c..4f9095182f90c 100644 --- a/cpp/src/arrow/c/bridge.cc +++ b/cpp/src/arrow/c/bridge.cc @@ -1249,13 +1249,20 @@ struct SchemaImporter { if (prec_scale[0] <= 0) { return f_parser_.Invalid(); } - if (prec_scale.size() == 2 || prec_scale[2] == 128) { + if (prec_scale.size() == 2) { + type_ = decimal128(prec_scale[0], prec_scale[1]); + } else if (prec_scale[2] == 32) { + type_ = decimal32(prec_scale[0], prec_scale[1]); + } else if (prec_scale[2] == 64) { + type_ = decimal64(prec_scale[0], prec_scale[1]); + } else if (prec_scale[2] == 128) { type_ = decimal128(prec_scale[0], prec_scale[1]); } else if (prec_scale[2] == 256) { type_ = decimal256(prec_scale[0], prec_scale[1]); } else { return f_parser_.Invalid(); } + return Status::OK(); } diff --git a/cpp/src/arrow/c/bridge_benchmark.cc b/cpp/src/arrow/c/bridge_benchmark.cc index 1ae4657fc9c0c..cc8a3cb1829c6 100644 --- a/cpp/src/arrow/c/bridge_benchmark.cc +++ b/cpp/src/arrow/c/bridge_benchmark.cc @@ -39,7 +39,7 @@ std::shared_ptr ExampleSchema() { auto f5 = field("f5", float32()); auto f6 = field("f6", float32()); auto f7 = field("f7", float32()); - auto f8 = field("f8", decimal(19, 10)); + auto f8 = field("f8", decimal128(19, 10)); return schema({f0, f1, f2, f3, f4, f5, f6, f7, f8}); } diff --git a/cpp/src/arrow/c/bridge_test.cc b/cpp/src/arrow/c/bridge_test.cc index 01fd56f631d99..fdcb53ddbcfb5 100644 --- a/cpp/src/arrow/c/bridge_test.cc +++ b/cpp/src/arrow/c/bridge_test.cc @@ -363,13 +363,19 @@ TEST_F(TestSchemaExport, Primitive) { TestPrimitive(binary_view(), "vz"); TestPrimitive(utf8_view(), "vu"); - TestPrimitive(decimal(16, 4), "d:16,4"); + TestPrimitive(smallest_decimal(8, 4), "d:8,4,32"); + TestPrimitive(smallest_decimal(16, 4), "d:16,4,64"); + TestPrimitive(decimal128(16, 4), "d:16,4"); TestPrimitive(decimal256(16, 4), "d:16,4,256"); - TestPrimitive(decimal(15, 0), "d:15,0"); + TestPrimitive(smallest_decimal(8, 0), "d:8,0,32"); + TestPrimitive(smallest_decimal(15, 0), "d:15,0,64"); + TestPrimitive(decimal128(15, 0), "d:15,0"); TestPrimitive(decimal256(15, 0), "d:15,0,256"); - TestPrimitive(decimal(15, -4), "d:15,-4"); + TestPrimitive(smallest_decimal(8, -4), "d:8,-4,32"); + TestPrimitive(smallest_decimal(15, -4), "d:15,-4,64"); + TestPrimitive(decimal128(15, -4), "d:15,-4"); TestPrimitive(decimal256(15, -4), "d:15,-4,256"); } @@ -906,7 +912,9 @@ TEST_F(TestArrayExport, Primitive) { TestPrimitive(binary_view(), R"(["foo", "bar", null])"); TestPrimitive(utf8_view(), R"(["foo", "bar", null])"); - TestPrimitive(decimal(16, 4), R"(["1234.5670", null])"); + TestPrimitive(decimal32(9, 4), R"(["1234.5670", null])"); + TestPrimitive(decimal64(16, 4), R"(["1234.5670", null])"); + TestPrimitive(decimal128(16, 4), R"(["1234.5670", null])"); TestPrimitive(decimal256(16, 4), R"(["1234.5670", null])"); TestPrimitive(month_day_nano_interval(), R"([[-1, 5, 20], null])"); @@ -1501,7 +1509,9 @@ TEST_F(TestDeviceArrayExport, Primitive) { TestPrimitive(mm, utf8(), R"(["foo", "bar", null])"); TestPrimitive(mm, large_utf8(), R"(["foo", "bar", null])"); - TestPrimitive(mm, decimal(16, 4), R"(["1234.5670", null])"); + TestPrimitive(mm, decimal32(9, 4), R"(["1234.5670", null])"); + TestPrimitive(mm, decimal64(16, 4), R"(["1234.5670", null])"); + TestPrimitive(mm, decimal128(16, 4), R"(["1234.5670", null])"); TestPrimitive(mm, decimal256(16, 4), R"(["1234.5670", null])"); TestPrimitive(mm, month_day_nano_interval(), R"([[-1, 5, 20], null])"); @@ -1951,6 +1961,10 @@ TEST_F(TestSchemaImport, Primitive) { CheckImport(field("", decimal128(16, 4))); FillPrimitive("d:16,4,256"); CheckImport(field("", decimal256(16, 4))); + FillPrimitive("d:4,4,32"); + CheckImport(field("", decimal32(4, 4))); + FillPrimitive("d:16,4,64"); + CheckImport(field("", decimal64(16, 4))); FillPrimitive("d:16,0"); CheckImport(field("", decimal128(16, 0))); @@ -1958,6 +1972,10 @@ TEST_F(TestSchemaImport, Primitive) { CheckImport(field("", decimal128(16, 0))); FillPrimitive("d:16,0,256"); CheckImport(field("", decimal256(16, 0))); + FillPrimitive("d:4,0,32"); + CheckImport(field("", decimal32(4, 0))); + FillPrimitive("d:16,0,64"); + CheckImport(field("", decimal64(16, 0))); FillPrimitive("d:16,-4"); CheckImport(field("", decimal128(16, -4))); @@ -1965,6 +1983,10 @@ TEST_F(TestSchemaImport, Primitive) { CheckImport(field("", decimal128(16, -4))); FillPrimitive("d:16,-4,256"); CheckImport(field("", decimal256(16, -4))); + FillPrimitive("d:4,-4,32"); + CheckImport(field("", decimal32(4, -4))); + FillPrimitive("d:16,-4,64"); + CheckImport(field("", decimal64(16, -4))); } TEST_F(TestSchemaImport, Temporal) { @@ -2034,7 +2056,7 @@ TEST_F(TestSchemaImport, String) { FillPrimitive("w:3"); CheckImport(fixed_size_binary(3)); FillPrimitive("d:15,4"); - CheckImport(decimal(15, 4)); + CheckImport(decimal128(15, 4)); } TEST_F(TestSchemaImport, List) { @@ -2950,26 +2972,26 @@ TEST_F(TestArrayImport, FixedSizeBinary) { FillPrimitive(2, 0, 0, primitive_buffers_no_nulls2); CheckImport(ArrayFromJSON(fixed_size_binary(3), R"(["abc", "def"])")); FillPrimitive(2, 0, 0, primitive_buffers_no_nulls3); - CheckImport(ArrayFromJSON(decimal(15, 4), R"(["12345.6789", "98765.4321"])")); + CheckImport(ArrayFromJSON(decimal128(15, 4), R"(["12345.6789", "98765.4321"])")); // Empty array with null data pointers FillPrimitive(0, 0, 0, all_buffers_omitted); CheckImport(ArrayFromJSON(fixed_size_binary(3), "[]")); FillPrimitive(0, 0, 0, all_buffers_omitted); - CheckImport(ArrayFromJSON(decimal(15, 4), "[]")); + CheckImport(ArrayFromJSON(decimal128(15, 4), "[]")); } TEST_F(TestArrayImport, FixedSizeBinaryWithOffset) { FillPrimitive(1, 0, 1, primitive_buffers_no_nulls2); CheckImport(ArrayFromJSON(fixed_size_binary(3), R"(["def"])")); FillPrimitive(1, 0, 1, primitive_buffers_no_nulls3); - CheckImport(ArrayFromJSON(decimal(15, 4), R"(["98765.4321"])")); + CheckImport(ArrayFromJSON(decimal128(15, 4), R"(["98765.4321"])")); // Empty array with null data pointers FillPrimitive(0, 0, 1, all_buffers_omitted); CheckImport(ArrayFromJSON(fixed_size_binary(3), "[]")); FillPrimitive(0, 0, 1, all_buffers_omitted); - CheckImport(ArrayFromJSON(decimal(15, 4), "[]")); + CheckImport(ArrayFromJSON(decimal128(15, 4), "[]")); } TEST_F(TestArrayImport, List) { @@ -3624,10 +3646,16 @@ TEST_F(TestSchemaRoundtrip, Primitive) { TestWithTypeFactory(boolean); TestWithTypeFactory(float16); + TestWithTypeFactory([] { return decimal32(8, 4); }); + TestWithTypeFactory([] { return decimal64(16, 4); }); TestWithTypeFactory([] { return decimal128(19, 4); }); TestWithTypeFactory([] { return decimal256(19, 4); }); + TestWithTypeFactory([] { return decimal32(8, 0); }); + TestWithTypeFactory([] { return decimal64(16, 0); }); TestWithTypeFactory([] { return decimal128(19, 0); }); TestWithTypeFactory([] { return decimal256(19, 0); }); + TestWithTypeFactory([] { return decimal32(8, -5); }); + TestWithTypeFactory([] { return decimal64(16, -5); }); TestWithTypeFactory([] { return decimal128(19, -5); }); TestWithTypeFactory([] { return decimal256(19, -5); }); TestWithTypeFactory([] { return fixed_size_binary(3); }); @@ -3661,7 +3689,7 @@ TEST_F(TestSchemaRoundtrip, ListView) { TEST_F(TestSchemaRoundtrip, Struct) { auto f1 = field("f1", utf8(), /*nullable=*/false); - auto f2 = field("f2", list(decimal(19, 4))); + auto f2 = field("f2", list(decimal128(19, 4))); TestWithTypeFactory([&]() { return struct_({f1, f2}); }); f2 = f2->WithMetadata(key_value_metadata(kMetadataKeys2, kMetadataValues2)); @@ -3671,7 +3699,7 @@ TEST_F(TestSchemaRoundtrip, Struct) { TEST_F(TestSchemaRoundtrip, Union) { auto f1 = field("f1", utf8(), /*nullable=*/false); - auto f2 = field("f2", list(decimal(19, 4))); + auto f2 = field("f2", list(decimal128(19, 4))); auto type_codes = std::vector{42, 43}; TestWithTypeFactory( @@ -3901,6 +3929,8 @@ TEST_F(TestArrayRoundtrip, Primitive) { TestWithJSON(int32(), "[]"); TestWithJSON(int32(), "[4, 5, null]"); + TestWithJSON(decimal32(8, 4), R"(["0.4759", "1234.5670", null])"); + TestWithJSON(decimal64(16, 4), R"(["0.4759", "1234.5670", null])"); TestWithJSON(decimal128(16, 4), R"(["0.4759", "1234.5670", null])"); TestWithJSON(decimal256(16, 4), R"(["0.4759", "1234.5670", null])"); @@ -3908,6 +3938,8 @@ TEST_F(TestArrayRoundtrip, Primitive) { TestWithJSONSliced(int32(), "[4, 5]"); TestWithJSONSliced(int32(), "[4, 5, 6, null]"); + TestWithJSONSliced(decimal32(8, 4), R"(["0.4759", "1234.5670", null])"); + TestWithJSONSliced(decimal64(16, 4), R"(["0.4759", "1234.5670", null])"); TestWithJSONSliced(decimal128(16, 4), R"(["0.4759", "1234.5670", null])"); TestWithJSONSliced(decimal256(16, 4), R"(["0.4759", "1234.5670", null])"); TestWithJSONSliced(month_day_nano_interval(), diff --git a/cpp/src/arrow/compare.cc b/cpp/src/arrow/compare.cc index e983b47e39dc4..23a921cc5a0a4 100644 --- a/cpp/src/arrow/compare.cc +++ b/cpp/src/arrow/compare.cc @@ -750,15 +750,10 @@ class TypeEqualsVisitor { return Status::OK(); } - Status Visit(const Decimal128Type& left) { - const auto& right = checked_cast(right_); - result_ = left.precision() == right.precision() && left.scale() == right.scale(); - return Status::OK(); - } - - Status Visit(const Decimal256Type& left) { - const auto& right = checked_cast(right_); - result_ = left.precision() == right.precision() && left.scale() == right.scale(); + Status Visit(const DecimalType& left) { + const auto& right = checked_cast(right_); + result_ = left.byte_width() == right.byte_width() && + left.precision() == right.precision() && left.scale() == right.scale(); return Status::OK(); } @@ -900,6 +895,18 @@ class ScalarEqualsVisitor { return Status::OK(); } + Status Visit(const Decimal32Scalar& left) { + const auto& right = checked_cast(right_); + result_ = left.value == right.value; + return Status::OK(); + } + + Status Visit(const Decimal64Scalar& left) { + const auto& right = checked_cast(right_); + result_ = left.value == right.value; + return Status::OK(); + } + Status Visit(const Decimal128Scalar& left) { const auto& right = checked_cast(right_); result_ = left.value == right.value; diff --git a/cpp/src/arrow/compute/kernel_test.cc b/cpp/src/arrow/compute/kernel_test.cc index 5daf7d2991d2a..e9664b104d7a6 100644 --- a/cpp/src/arrow/compute/kernel_test.cc +++ b/cpp/src/arrow/compute/kernel_test.cc @@ -36,7 +36,7 @@ namespace compute { TEST(TypeMatcher, SameTypeId) { std::shared_ptr matcher = match::SameTypeId(Type::DECIMAL); - ASSERT_TRUE(matcher->Matches(*decimal(12, 2))); + ASSERT_TRUE(matcher->Matches(*decimal128(20, 2))); ASSERT_FALSE(matcher->Matches(*int8())); ASSERT_EQ("Type::DECIMAL128", matcher->ToString()); @@ -120,7 +120,7 @@ TEST(InputType, Constructors) { InputType ty2(Type::DECIMAL); ASSERT_EQ(InputType::USE_TYPE_MATCHER, ty2.kind()); ASSERT_EQ("Type::DECIMAL128", ty2.ToString()); - ASSERT_TRUE(ty2.type_matcher().Matches(*decimal(12, 2))); + ASSERT_TRUE(ty2.type_matcher().Matches(*decimal128(12, 2))); ASSERT_FALSE(ty2.type_matcher().Matches(*int16())); // Implicit construction in a vector @@ -204,9 +204,9 @@ TEST(InputType, Matches) { ASSERT_FALSE(input1.Matches(*int16())); InputType input2(Type::DECIMAL); - ASSERT_TRUE(input2.Matches(*decimal(12, 2))); + ASSERT_TRUE(input2.Matches(*decimal128(12, 2))); - auto ty2 = decimal(12, 2); + auto ty2 = decimal128(12, 2); auto ty3 = float64(); ASSERT_OK_AND_ASSIGN(std::shared_ptr arr2, MakeArrayOfNull(ty2, 1)); ASSERT_OK_AND_ASSIGN(std::shared_ptr arr3, MakeArrayOfNull(ty3, 1)); @@ -319,7 +319,7 @@ TEST(KernelSignature, Basics) { ASSERT_EQ(2, sig.in_types().size()); ASSERT_TRUE(sig.in_types()[0].type()->Equals(*int8())); ASSERT_TRUE(sig.in_types()[0].Matches(*int8())); - ASSERT_TRUE(sig.in_types()[1].Matches(*decimal(12, 2))); + ASSERT_TRUE(sig.in_types()[1].Matches(*decimal128(12, 2))); } TEST(KernelSignature, Equals) { @@ -381,7 +381,7 @@ TEST(KernelSignature, MatchesInputs) { ASSERT_FALSE(sig2.MatchesInputs({})); ASSERT_FALSE(sig2.MatchesInputs({int8()})); - ASSERT_TRUE(sig2.MatchesInputs({int8(), decimal(12, 2)})); + ASSERT_TRUE(sig2.MatchesInputs({int8(), decimal128(12, 2)})); // (int8, int32) -> boolean KernelSignature sig3({int8(), int32()}, boolean()); diff --git a/cpp/src/arrow/compute/kernels/aggregate_basic.cc b/cpp/src/arrow/compute/kernels/aggregate_basic.cc index b545d8bcc1003..68b1ac7c03ca8 100644 --- a/cpp/src/arrow/compute/kernels/aggregate_basic.cc +++ b/cpp/src/arrow/compute/kernels/aggregate_basic.cc @@ -336,8 +336,8 @@ struct ProductImpl : public ScalarAggregator { internal::VisitArrayValuesInline( data, [&](typename TypeTraits::CType value) { - this->product = - MultiplyTraits::Multiply(*out_type, this->product, value); + this->product = MultiplyTraits::Multiply( + *out_type, this->product, static_cast(value)); }, [] {}); } else { @@ -347,8 +347,8 @@ struct ProductImpl : public ScalarAggregator { if (data.is_valid) { for (int64_t i = 0; i < batch.length; i++) { auto value = internal::UnboxScalar::Unbox(data); - this->product = - MultiplyTraits::Multiply(*out_type, this->product, value); + this->product = MultiplyTraits::Multiply( + *out_type, this->product, static_cast(value)); } } } diff --git a/cpp/src/arrow/compute/kernels/aggregate_basic.inc.cc b/cpp/src/arrow/compute/kernels/aggregate_basic.inc.cc index f2151e0a9e029..49010d182cd6d 100644 --- a/cpp/src/arrow/compute/kernels/aggregate_basic.inc.cc +++ b/cpp/src/arrow/compute/kernels/aggregate_basic.inc.cc @@ -77,7 +77,8 @@ struct SumImpl : public ScalarAggregator { this->count += data.is_valid * batch.length; this->nulls_observed = this->nulls_observed || !data.is_valid; if (data.is_valid) { - this->sum += internal::UnboxScalar::Unbox(data) * batch.length; + this->sum += static_cast(internal::UnboxScalar::Unbox(data) * + batch.length); } } return Status::OK(); diff --git a/cpp/src/arrow/compute/kernels/aggregate_internal.h b/cpp/src/arrow/compute/kernels/aggregate_internal.h index 168f063c770f3..9dab049821d5c 100644 --- a/cpp/src/arrow/compute/kernels/aggregate_internal.h +++ b/cpp/src/arrow/compute/kernels/aggregate_internal.h @@ -52,6 +52,16 @@ struct FindAccumulatorType> { using Type = DoubleType; }; +template +struct FindAccumulatorType> { + using Type = Decimal32Type; +}; + +template +struct FindAccumulatorType> { + using Type = Decimal64Type; +}; + template struct FindAccumulatorType> { using Type = Decimal128Type; diff --git a/cpp/src/arrow/compute/kernels/aggregate_tdigest.cc b/cpp/src/arrow/compute/kernels/aggregate_tdigest.cc index 1dab92632ef2d..83d01091b3c8d 100644 --- a/cpp/src/arrow/compute/kernels/aggregate_tdigest.cc +++ b/cpp/src/arrow/compute/kernels/aggregate_tdigest.cc @@ -51,6 +51,8 @@ struct TDigestImpl : public ScalarAggregator { double ToDouble(T value) const { return static_cast(value); } + double ToDouble(const Decimal32& value) const { return value.ToDouble(decimal_scale); } + double ToDouble(const Decimal64& value) const { return value.ToDouble(decimal_scale); } double ToDouble(const Decimal128& value) const { return value.ToDouble(decimal_scale); } double ToDouble(const Decimal256& value) const { return value.ToDouble(decimal_scale); } diff --git a/cpp/src/arrow/compute/kernels/aggregate_var_std.cc b/cpp/src/arrow/compute/kernels/aggregate_var_std.cc index c2fab48dbe208..e4189f9b62b17 100644 --- a/cpp/src/arrow/compute/kernels/aggregate_var_std.cc +++ b/cpp/src/arrow/compute/kernels/aggregate_var_std.cc @@ -46,6 +46,8 @@ struct VarStdState { double ToDouble(T value) const { return static_cast(value); } + double ToDouble(const Decimal32& value) const { return value.ToDouble(decimal_scale); } + double ToDouble(const Decimal64& value) const { return value.ToDouble(decimal_scale); } double ToDouble(const Decimal128& value) const { return value.ToDouble(decimal_scale); } double ToDouble(const Decimal256& value) const { return value.ToDouble(decimal_scale); } @@ -53,8 +55,9 @@ struct VarStdState { // algorithm` // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Two-pass_algorithm template - enable_if_t::value || (sizeof(CType) > 4)> Consume( - const ArraySpan& array) { + enable_if_t::value || (sizeof(CType) > 4) || + (!is_integer_type::value && sizeof(CType) == 4)> + Consume(const ArraySpan& array) { this->all_valid = array.GetNullCount() == 0; int64_t count = array.length - array.GetNullCount(); if (count == 0 || (!this->all_valid && !options.skip_nulls)) { diff --git a/cpp/src/arrow/compute/kernels/codegen_internal.h b/cpp/src/arrow/compute/kernels/codegen_internal.h index 7f9be92f3a14b..594bd1fce0b84 100644 --- a/cpp/src/arrow/compute/kernels/codegen_internal.h +++ b/cpp/src/arrow/compute/kernels/codegen_internal.h @@ -141,6 +141,30 @@ struct GetViewType::value || static T LogicalValue(PhysicalType value) { return value; } }; +template <> +struct GetViewType { + using T = Decimal32; + using PhysicalType = std::string_view; + + static T LogicalValue(PhysicalType value) { + return Decimal32(reinterpret_cast(value.data())); + } + + static T LogicalValue(T value) { return value; } +}; + +template <> +struct GetViewType { + using T = Decimal64; + using PhysicalType = std::string_view; + + static T LogicalValue(PhysicalType value) { + return Decimal64(reinterpret_cast(value.data())); + } + + static T LogicalValue(T value) { return value; } +}; + template <> struct GetViewType { using T = Decimal128; @@ -178,6 +202,16 @@ struct GetOutputType::value>> { using T = std::string; }; +template <> +struct GetOutputType { + using T = Decimal32; +}; + +template <> +struct GetOutputType { + using T = Decimal64; +}; + template <> struct GetOutputType { using T = Decimal128; @@ -225,7 +259,9 @@ using enable_if_not_floating_value = enable_if_t::val template using enable_if_decimal_value = - enable_if_t::value || std::is_same::value, + enable_if_t::value || std::is_same::value || + std::is_same::value || + std::is_same::value, R>; // ---------------------------------------------------------------------- @@ -354,6 +390,22 @@ struct UnboxScalar> { } }; +template <> +struct UnboxScalar { + using T = Decimal32; + static const T& Unbox(const Scalar& val) { + return checked_cast(val).value; + } +}; + +template <> +struct UnboxScalar { + using T = Decimal64; + static const T& Unbox(const Scalar& val) { + return checked_cast(val).value; + } +}; + template <> struct UnboxScalar { using T = Decimal128; @@ -1117,6 +1169,10 @@ ArrayKernelExec GeneratePhysicalNumeric(detail::GetTypeId get_id) { template