diff --git a/tiledb/CMakeLists.txt b/tiledb/CMakeLists.txt index 10899fa617ba..8c4522bacb4e 100644 --- a/tiledb/CMakeLists.txt +++ b/tiledb/CMakeLists.txt @@ -252,12 +252,12 @@ set(TILEDB_CORE_SOURCES ${TILEDB_CORE_INCLUDE_DIR}/tiledb/sm/query/query.cc ${TILEDB_CORE_INCLUDE_DIR}/tiledb/sm/query/query_condition.cc ${TILEDB_CORE_INCLUDE_DIR}/tiledb/sm/query/query_remote_buffer_storage.cc - ${TILEDB_CORE_INCLUDE_DIR}/tiledb/sm/query/readers/aggregators/aggregate_sum.cc ${TILEDB_CORE_INCLUDE_DIR}/tiledb/sm/query/readers/aggregators/count_aggregator.cc ${TILEDB_CORE_INCLUDE_DIR}/tiledb/sm/query/readers/aggregators/mean_aggregator.cc ${TILEDB_CORE_INCLUDE_DIR}/tiledb/sm/query/readers/aggregators/min_max_aggregator.cc ${TILEDB_CORE_INCLUDE_DIR}/tiledb/sm/query/readers/aggregators/null_count_aggregator.cc ${TILEDB_CORE_INCLUDE_DIR}/tiledb/sm/query/readers/aggregators/output_buffer_validator.cc + ${TILEDB_CORE_INCLUDE_DIR}/tiledb/sm/query/readers/aggregators/safe_sum.cc ${TILEDB_CORE_INCLUDE_DIR}/tiledb/sm/query/readers/aggregators/sum_aggregator.cc ${TILEDB_CORE_INCLUDE_DIR}/tiledb/sm/query/readers/dense_reader.cc ${TILEDB_CORE_INCLUDE_DIR}/tiledb/sm/query/readers/ordered_dim_label_reader.cc diff --git a/tiledb/sm/query/readers/aggregators/CMakeLists.txt b/tiledb/sm/query/readers/aggregators/CMakeLists.txt index 72536954053e..3b8b8d729254 100644 --- a/tiledb/sm/query/readers/aggregators/CMakeLists.txt +++ b/tiledb/sm/query/readers/aggregators/CMakeLists.txt @@ -31,7 +31,7 @@ include(object_library) # `aggregators` object library # commence(object_library aggregators) - this_target_sources(aggregate_sum.cc count_aggregator.cc mean_aggregator.cc min_max_aggregator.cc null_count_aggregator.cc output_buffer_validator.cc sum_aggregator.cc) + this_target_sources(count_aggregator.cc mean_aggregator.cc min_max_aggregator.cc null_count_aggregator.cc output_buffer_validator.cc safe_sum.cc sum_aggregator.cc) this_target_object_libraries(baseline array_schema) conclude(object_library) diff --git a/tiledb/sm/query/readers/aggregators/aggregate_sum.h b/tiledb/sm/query/readers/aggregators/aggregate_with_count.h similarity index 74% rename from tiledb/sm/query/readers/aggregators/aggregate_sum.h rename to tiledb/sm/query/readers/aggregators/aggregate_with_count.h index e22be884b940..8a2be8fd9d83 100644 --- a/tiledb/sm/query/readers/aggregators/aggregate_sum.h +++ b/tiledb/sm/query/readers/aggregators/aggregate_with_count.h @@ -35,65 +35,19 @@ #include "tiledb/sm/query/readers/aggregators/aggregate_buffer.h" #include "tiledb/sm/query/readers/aggregators/field_info.h" +#include "tiledb/sm/query/readers/aggregators/safe_sum.h" namespace tiledb { namespace sm { -#define SUM_TYPE_DATA(T, SUM_T) \ - template <> \ - struct sum_type_data { \ - using type = T; \ - typedef SUM_T sum_type; \ - }; - -/** Convert basic type to a sum type. **/ -template -struct sum_type_data; - -SUM_TYPE_DATA(int8_t, int64_t); -SUM_TYPE_DATA(uint8_t, uint64_t); -SUM_TYPE_DATA(int16_t, int64_t); -SUM_TYPE_DATA(uint16_t, uint64_t); -SUM_TYPE_DATA(int32_t, int64_t); -SUM_TYPE_DATA(uint32_t, uint64_t); -SUM_TYPE_DATA(int64_t, int64_t); -SUM_TYPE_DATA(uint64_t, uint64_t); -SUM_TYPE_DATA(float, double); -SUM_TYPE_DATA(double, double); - -/** - * Sum function that prevent wrap arounds on overflow. - * - * @param value Value to add to the sum. - * @param sum Computed sum. - */ -template -void safe_sum(SUM_T value, SUM_T& sum); - -/** - * Sum function for atomics that prevent wrap arounds on overflow. - * - * @param value Value to add to the sum. - * @param sum Computed sum. - */ -template -void safe_sum(SUM_T value, std::atomic& sum) { - SUM_T cur_sum = sum; - SUM_T new_sum; - do { - new_sum = cur_sum; - safe_sum(value, new_sum); - } while (!std::atomic_compare_exchange_weak(&sum, &cur_sum, new_sum)); -} - template -class AggregateSum { +class AggregateWithCount { public: /* ********************************* */ /* CONSTRUCTORS & DESTRUCTORS */ /* ********************************* */ - AggregateSum(const FieldInfo field_info) + AggregateWithCount(const FieldInfo field_info) : field_info_(field_info) { } @@ -110,8 +64,9 @@ class AggregateSum { * * @return {Sum for the cells, number of cells, optional validity value}. */ - template - tuple> sum(AggregateBuffer& input_data) { + template + tuple> aggregate( + AggregateBuffer& input_data) { SUM_T sum{0}; uint64_t count{0}; optional validity{nullopt}; @@ -136,7 +91,7 @@ class AggregateSum { auto value = static_cast(values[c]); for (BITMAP_T i = 0; i < bitmap_values[c]; i++) { count++; - safe_sum(value, sum); + AggPolicy::op(value, sum); } } } @@ -148,7 +103,7 @@ class AggregateSum { for (BITMAP_T i = 0; i < bitmap_values[c]; i++) { count++; - safe_sum(value, sum); + AggPolicy::op(value, sum); } } } @@ -165,7 +120,7 @@ class AggregateSum { auto value = static_cast(values[c]); count++; - safe_sum(value, sum); + AggPolicy::op(value, sum); } } } else { @@ -174,7 +129,7 @@ class AggregateSum { c++) { auto value = static_cast(values[c]); count++; - safe_sum(value, sum); + AggPolicy::op(value, sum); } } } diff --git a/tiledb/sm/query/readers/aggregators/mean_aggregator.cc b/tiledb/sm/query/readers/aggregators/mean_aggregator.cc index 6057aed91c3a..d946ae932d51 100644 --- a/tiledb/sm/query/readers/aggregators/mean_aggregator.cc +++ b/tiledb/sm/query/readers/aggregators/mean_aggregator.cc @@ -94,16 +94,18 @@ void MeanAggregator::aggregate_data(AggregateBuffer& input_data) { 0, 0, nullopt}; if (input_data.is_count_bitmap()) { - res = - summator_.template sum::sum_type, uint64_t>( - input_data); + res = summator_.template aggregate< + typename sum_type_data::sum_type, + uint64_t, + SafeSum>(input_data); } else { - res = - summator_.template sum::sum_type, uint8_t>( - input_data); + res = summator_.template aggregate< + typename sum_type_data::sum_type, + uint8_t, + SafeSum>(input_data); } - safe_sum(std::get<0>(res), sum_); + SafeSum::safe_sum(std::get<0>(res), sum_); count_ += std::get<1>(res); if (field_info_.is_nullable_ && std::get<2>(res).value() == 1) { validity_value_ = 1; diff --git a/tiledb/sm/query/readers/aggregators/mean_aggregator.h b/tiledb/sm/query/readers/aggregators/mean_aggregator.h index 64ea7142ac81..5b6d459ad470 100644 --- a/tiledb/sm/query/readers/aggregators/mean_aggregator.h +++ b/tiledb/sm/query/readers/aggregators/mean_aggregator.h @@ -34,9 +34,10 @@ #define TILEDB_MEAN_AGGREGATOR_H #include "tiledb/common/common.h" -#include "tiledb/sm/query/readers/aggregators/aggregate_sum.h" +#include "tiledb/sm/query/readers/aggregators/aggregate_with_count.h" #include "tiledb/sm/query/readers/aggregators/field_info.h" #include "tiledb/sm/query/readers/aggregators/iaggregator.h" +#include "tiledb/sm/query/readers/aggregators/sum_type.h" namespace tiledb { namespace sm { @@ -116,8 +117,8 @@ class MeanAggregator : public OutputBufferValidator, public IAggregator { /** Field information. */ const FieldInfo field_info_; - /** AggregateSum to do summation of AggregateBuffer data. */ - AggregateSum summator_; + /** AggregateWithCount to do summation of AggregateBuffer data. */ + AggregateWithCount summator_; /** Computed sum. */ std::atomic::sum_type> sum_; diff --git a/tiledb/sm/query/readers/aggregators/aggregate_sum.cc b/tiledb/sm/query/readers/aggregators/safe_sum.cc similarity index 81% rename from tiledb/sm/query/readers/aggregators/aggregate_sum.cc rename to tiledb/sm/query/readers/aggregators/safe_sum.cc index 839c84581633..ca4b5b9a2f0d 100644 --- a/tiledb/sm/query/readers/aggregators/aggregate_sum.cc +++ b/tiledb/sm/query/readers/aggregators/safe_sum.cc @@ -1,5 +1,5 @@ /** - * @file aggregate_sum.cc + * @file safe_sum.cc * * @section LICENSE * @@ -27,17 +27,19 @@ * * @section DESCRIPTION * - * This file implements class AggregateSum. + * This file implements class SafeSum. */ -#include "tiledb/sm/query/readers/aggregators/aggregate_sum.h" +#include "tiledb/sm/query/readers/aggregators/safe_sum.h" + +#include namespace tiledb { namespace sm { -/** Specialization of safe_sum for int64_t sums. */ +/** Specialization of op for int64_t sums. */ template <> -void safe_sum(int64_t value, int64_t& sum) { +void SafeSum::op(int64_t value, int64_t& sum) { if (sum > 0 && value > 0 && (sum > (std::numeric_limits::max() - value))) { throw std::overflow_error("overflow on sum"); @@ -51,9 +53,9 @@ void safe_sum(int64_t value, int64_t& sum) { sum += value; } -/** Specialization of safe_sum for uint64_t sums. */ +/** Specialization of op for uint64_t sums. */ template <> -void safe_sum(uint64_t value, uint64_t& sum) { +void SafeSum::op(uint64_t value, uint64_t& sum) { if (sum > (std::numeric_limits::max() - value)) { throw std::overflow_error("overflow on sum"); } @@ -61,9 +63,9 @@ void safe_sum(uint64_t value, uint64_t& sum) { sum += value; } -/** Specialization of safe_sum for double sums. */ +/** Specialization of op for double sums. */ template <> -void safe_sum(double value, double& sum) { +void SafeSum::op(double value, double& sum) { if ((sum < 0.0) == (value < 0.0) && (std::abs(sum) > (std::numeric_limits::max() - std::abs(value)))) { diff --git a/tiledb/sm/query/readers/aggregators/safe_sum.h b/tiledb/sm/query/readers/aggregators/safe_sum.h new file mode 100644 index 000000000000..ed92662f73cf --- /dev/null +++ b/tiledb/sm/query/readers/aggregators/safe_sum.h @@ -0,0 +1,73 @@ +/** + * @file safe_sum.h + * + * @section LICENSE + * + * The MIT License + * + * @copyright Copyright (c) 2023 TileDB, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + * @section DESCRIPTION + * + * This file defines class SafeSum. + */ + +#ifndef TILEDB_SAFE_SUM_H +#define TILEDB_SAFE_SUM_H + +#include +#include + +namespace tiledb { +namespace sm { + +class SafeSum { + public: + /** + * Sum function that prevent wrap arounds on overflow. + * + * @param value Value to add to the sum. + * @param sum Computed sum. + */ + template + static void op(SUM_T value, SUM_T& sum); + + /** + * Sum function for atomics that prevent wrap arounds on overflow. + * + * @param value Value to add to the sum. + * @param sum Computed sum. + */ + template + static void safe_sum(SUM_T value, std::atomic& sum) { + SUM_T cur_sum = sum; + SUM_T new_sum; + do { + new_sum = cur_sum; + op(value, new_sum); + } while (!std::atomic_compare_exchange_weak(&sum, &cur_sum, new_sum)); + } +}; + +} // namespace sm +} // namespace tiledb + +#endif // TILEDB_SAFE_SUM_H diff --git a/tiledb/sm/query/readers/aggregators/sum_aggregator.cc b/tiledb/sm/query/readers/aggregators/sum_aggregator.cc index a9fba1e1b801..baf39806c202 100644 --- a/tiledb/sm/query/readers/aggregators/sum_aggregator.cc +++ b/tiledb/sm/query/readers/aggregators/sum_aggregator.cc @@ -89,16 +89,18 @@ void SumAggregator::aggregate_data(AggregateBuffer& input_data) { 0, 0, nullopt}; if (input_data.is_count_bitmap()) { - res = - summator_.template sum::sum_type, uint64_t>( - input_data); + res = summator_.template aggregate< + typename sum_type_data::sum_type, + uint64_t, + SafeSum>(input_data); } else { - res = - summator_.template sum::sum_type, uint8_t>( - input_data); + res = summator_.template aggregate< + typename sum_type_data::sum_type, + uint8_t, + SafeSum>(input_data); } - safe_sum(std::get<0>(res), sum_); + SafeSum::safe_sum(std::get<0>(res), sum_); if (field_info_.is_nullable_ && std::get<2>(res).value() == 1) { validity_value_ = 1; } diff --git a/tiledb/sm/query/readers/aggregators/sum_aggregator.h b/tiledb/sm/query/readers/aggregators/sum_aggregator.h index d50c16b9e259..c76c4ed37533 100644 --- a/tiledb/sm/query/readers/aggregators/sum_aggregator.h +++ b/tiledb/sm/query/readers/aggregators/sum_aggregator.h @@ -33,9 +33,10 @@ #ifndef TILEDB_SUM_AGGREGATOR_H #define TILEDB_SUM_AGGREGATOR_H -#include "tiledb/sm/query/readers/aggregators/aggregate_sum.h" +#include "tiledb/sm/query/readers/aggregators/aggregate_with_count.h" #include "tiledb/sm/query/readers/aggregators/field_info.h" #include "tiledb/sm/query/readers/aggregators/iaggregator.h" +#include "tiledb/sm/query/readers/aggregators/sum_type.h" namespace tiledb { namespace sm { @@ -115,8 +116,8 @@ class SumAggregator : public OutputBufferValidator, public IAggregator { /** Field information. */ const FieldInfo field_info_; - /** AggregateSum to do summation of AggregateBuffer data. */ - AggregateSum summator_; + /** AggregateWithCount to do summation of AggregateBuffer data. */ + AggregateWithCount summator_; /** Computed sum. */ std::atomic::sum_type> sum_; diff --git a/tiledb/sm/query/readers/aggregators/sum_type.h b/tiledb/sm/query/readers/aggregators/sum_type.h new file mode 100644 index 000000000000..b48c6238f581 --- /dev/null +++ b/tiledb/sm/query/readers/aggregators/sum_type.h @@ -0,0 +1,64 @@ +/** + * @file sum_type.h + * + * @section LICENSE + * + * The MIT License + * + * @copyright Copyright (c) 2023 TileDB, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + * @section DESCRIPTION + * + * This file defines sum types in relation to basic types. + */ + +#ifndef TILEDB_SUM_TYPE_H +#define TILEDB_SUM_TYPE_H + +namespace tiledb { +namespace sm { + +#define SUM_TYPE_DATA(T, SUM_T) \ + template <> \ + struct sum_type_data { \ + using type = T; \ + typedef SUM_T sum_type; \ + }; + +/** Convert basic type to a sum type. **/ +template +struct sum_type_data; + +SUM_TYPE_DATA(int8_t, int64_t); +SUM_TYPE_DATA(uint8_t, uint64_t); +SUM_TYPE_DATA(int16_t, int64_t); +SUM_TYPE_DATA(uint16_t, uint64_t); +SUM_TYPE_DATA(int32_t, int64_t); +SUM_TYPE_DATA(uint32_t, uint64_t); +SUM_TYPE_DATA(int64_t, int64_t); +SUM_TYPE_DATA(uint64_t, uint64_t); +SUM_TYPE_DATA(float, double); +SUM_TYPE_DATA(double, double); + +} // namespace sm +} // namespace tiledb + +#endif // TILEDB_SUM_TYPE_H