From d6e6487fff01bcb4135b45c6156cab18aaf97088 Mon Sep 17 00:00:00 2001 From: Luc Rancourt Date: Tue, 29 Aug 2023 12:15:54 +0200 Subject: [PATCH] Making common sum implementation. --- tiledb/CMakeLists.txt | 1 + .../query/readers/aggregators/CMakeLists.txt | 2 +- .../readers/aggregators/aggregate_buffer.h | 5 +- .../readers/aggregators/aggregate_sum.cc | 77 +++++++ .../query/readers/aggregators/aggregate_sum.h | 197 ++++++++++++++++++ .../readers/aggregators/count_aggregator.cc | 2 - .../readers/aggregators/count_aggregator.h | 5 +- .../sm/query/readers/aggregators/field_info.h | 5 +- .../query/readers/aggregators/iaggregator.h | 6 +- .../readers/aggregators/mean_aggregator.cc | 145 +++---------- .../readers/aggregators/mean_aggregator.h | 31 +-- .../readers/aggregators/min_max_aggregator.cc | 2 - .../readers/aggregators/min_max_aggregator.h | 5 +- .../readers/aggregators/sum_aggregator.cc | 180 +++------------- .../readers/aggregators/sum_aggregator.h | 61 +----- .../readers/aggregators/test/unit_mean.cc | 4 +- .../readers/aggregators/test/unit_sum.cc | 8 +- 17 files changed, 366 insertions(+), 370 deletions(-) create mode 100644 tiledb/sm/query/readers/aggregators/aggregate_sum.cc create mode 100644 tiledb/sm/query/readers/aggregators/aggregate_sum.h diff --git a/tiledb/CMakeLists.txt b/tiledb/CMakeLists.txt index d4de37f944a7..abae3245e191 100644 --- a/tiledb/CMakeLists.txt +++ b/tiledb/CMakeLists.txt @@ -252,6 +252,7 @@ set(TILEDB_CORE_SOURCES ${TILEDB_CORE_INCLUDE_DIR}/tiledb/sm/query/query.cc ${TILEDB_CORE_INCLUDE_DIR}/tiledb/sm/query/query_condition.cc ${TILEDB_CORE_INCLUDE_DIR}/tiledb/sm/query/query_remote_buffer_storage.cc + ${TILEDB_CORE_INCLUDE_DIR}/tiledb/sm/query/readers/aggregators/aggregate_sum.cc ${TILEDB_CORE_INCLUDE_DIR}/tiledb/sm/query/readers/aggregators/count_aggregator.cc ${TILEDB_CORE_INCLUDE_DIR}/tiledb/sm/query/readers/aggregators/mean_aggregator.cc ${TILEDB_CORE_INCLUDE_DIR}/tiledb/sm/query/readers/aggregators/min_max_aggregator.cc diff --git a/tiledb/sm/query/readers/aggregators/CMakeLists.txt b/tiledb/sm/query/readers/aggregators/CMakeLists.txt index 8839c4d6e528..c7c25dee26e4 100644 --- a/tiledb/sm/query/readers/aggregators/CMakeLists.txt +++ b/tiledb/sm/query/readers/aggregators/CMakeLists.txt @@ -31,7 +31,7 @@ include(object_library) # `aggregators` object library # commence(object_library aggregators) - this_target_sources(count_aggregator.cc mean_aggregator.cc min_max_aggregator.cc sum_aggregator.cc) + this_target_sources(aggregate_sum.cc count_aggregator.cc mean_aggregator.cc min_max_aggregator.cc sum_aggregator.cc) this_target_object_libraries(baseline array_schema) conclude(object_library) diff --git a/tiledb/sm/query/readers/aggregators/aggregate_buffer.h b/tiledb/sm/query/readers/aggregators/aggregate_buffer.h index 514e3045b218..fdba7827b2c8 100644 --- a/tiledb/sm/query/readers/aggregators/aggregate_buffer.h +++ b/tiledb/sm/query/readers/aggregators/aggregate_buffer.h @@ -33,10 +33,7 @@ #ifndef TILEDB_AGGREGATE_BUFFER_H #define TILEDB_AGGREGATE_BUFFER_H -#include "tiledb/common/status.h" -#include "tiledb/sm/enums/layout.h" - -using namespace tiledb::common; +#include "tiledb/common/common.h" namespace tiledb { namespace sm { diff --git a/tiledb/sm/query/readers/aggregators/aggregate_sum.cc b/tiledb/sm/query/readers/aggregators/aggregate_sum.cc new file mode 100644 index 000000000000..839c84581633 --- /dev/null +++ b/tiledb/sm/query/readers/aggregators/aggregate_sum.cc @@ -0,0 +1,77 @@ +/** + * @file aggregate_sum.cc + * + * @section LICENSE + * + * The MIT License + * + * @copyright Copyright (c) 2023 TileDB, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + * @section DESCRIPTION + * + * This file implements class AggregateSum. + */ + +#include "tiledb/sm/query/readers/aggregators/aggregate_sum.h" + +namespace tiledb { +namespace sm { + +/** Specialization of safe_sum for int64_t sums. */ +template <> +void safe_sum(int64_t value, int64_t& sum) { + if (sum > 0 && value > 0 && + (sum > (std::numeric_limits::max() - value))) { + throw std::overflow_error("overflow on sum"); + } + + if (sum < 0 && value < 0 && + (sum < (std::numeric_limits::min() - value))) { + throw std::overflow_error("overflow on sum"); + } + + sum += value; +} + +/** Specialization of safe_sum for uint64_t sums. */ +template <> +void safe_sum(uint64_t value, uint64_t& sum) { + if (sum > (std::numeric_limits::max() - value)) { + throw std::overflow_error("overflow on sum"); + } + + sum += value; +} + +/** Specialization of safe_sum for double sums. */ +template <> +void safe_sum(double value, double& sum) { + if ((sum < 0.0) == (value < 0.0) && + (std::abs(sum) > + (std::numeric_limits::max() - std::abs(value)))) { + throw std::overflow_error("overflow on sum"); + } + + sum += value; +} + +} // namespace sm +} // namespace tiledb diff --git a/tiledb/sm/query/readers/aggregators/aggregate_sum.h b/tiledb/sm/query/readers/aggregators/aggregate_sum.h new file mode 100644 index 000000000000..e22be884b940 --- /dev/null +++ b/tiledb/sm/query/readers/aggregators/aggregate_sum.h @@ -0,0 +1,197 @@ +/** + * @file aggregate_sum.h + * + * @section LICENSE + * + * The MIT License + * + * @copyright Copyright (c) 2023 TileDB, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + * @section DESCRIPTION + * + * This file defines class AggregateSum. + */ + +#ifndef TILEDB_AGGREGATE_SUM_H +#define TILEDB_AGGREGATE_SUM_H + +#include "tiledb/sm/query/readers/aggregators/aggregate_buffer.h" +#include "tiledb/sm/query/readers/aggregators/field_info.h" + +namespace tiledb { +namespace sm { + +#define SUM_TYPE_DATA(T, SUM_T) \ + template <> \ + struct sum_type_data { \ + using type = T; \ + typedef SUM_T sum_type; \ + }; + +/** Convert basic type to a sum type. **/ +template +struct sum_type_data; + +SUM_TYPE_DATA(int8_t, int64_t); +SUM_TYPE_DATA(uint8_t, uint64_t); +SUM_TYPE_DATA(int16_t, int64_t); +SUM_TYPE_DATA(uint16_t, uint64_t); +SUM_TYPE_DATA(int32_t, int64_t); +SUM_TYPE_DATA(uint32_t, uint64_t); +SUM_TYPE_DATA(int64_t, int64_t); +SUM_TYPE_DATA(uint64_t, uint64_t); +SUM_TYPE_DATA(float, double); +SUM_TYPE_DATA(double, double); + +/** + * Sum function that prevent wrap arounds on overflow. + * + * @param value Value to add to the sum. + * @param sum Computed sum. + */ +template +void safe_sum(SUM_T value, SUM_T& sum); + +/** + * Sum function for atomics that prevent wrap arounds on overflow. + * + * @param value Value to add to the sum. + * @param sum Computed sum. + */ +template +void safe_sum(SUM_T value, std::atomic& sum) { + SUM_T cur_sum = sum; + SUM_T new_sum; + do { + new_sum = cur_sum; + safe_sum(value, new_sum); + } while (!std::atomic_compare_exchange_weak(&sum, &cur_sum, new_sum)); +} + +template +class AggregateSum { + public: + /* ********************************* */ + /* CONSTRUCTORS & DESTRUCTORS */ + /* ********************************* */ + + AggregateSum(const FieldInfo field_info) + : field_info_(field_info) { + } + + /* ********************************* */ + /* API */ + /* ********************************* */ + + /** + * Add the sum of cells for the input data. + * + * @tparam SUM_T Sum type. + * @tparam BITMAP_T Bitmap type. + * @param input_data Input data for the sum. + * + * @return {Sum for the cells, number of cells, optional validity value}. + */ + template + tuple> sum(AggregateBuffer& input_data) { + SUM_T sum{0}; + uint64_t count{0}; + optional validity{nullopt}; + auto values = input_data.fixed_data_as(); + + // Run different loops for bitmap versus no bitmap and nullable versus non + // nullable. The bitmap tells us which cells was already filtered out by + // ranges or query conditions. + if (input_data.has_bitmap()) { + auto bitmap_values = input_data.bitmap_data_as(); + + if (field_info_.is_nullable_) { + validity = 0; + auto validity_values = input_data.validity_data(); + + // Process for nullable sums with bitmap. + for (uint64_t c = input_data.min_cell(); c < input_data.max_cell(); + c++) { + if (validity_values[c] != 0 && bitmap_values[c] != 0) { + validity = 1; + + auto value = static_cast(values[c]); + for (BITMAP_T i = 0; i < bitmap_values[c]; i++) { + count++; + safe_sum(value, sum); + } + } + } + } else { + // Process for non nullable sums with bitmap. + for (uint64_t c = input_data.min_cell(); c < input_data.max_cell(); + c++) { + auto value = static_cast(values[c]); + + for (BITMAP_T i = 0; i < bitmap_values[c]; i++) { + count++; + safe_sum(value, sum); + } + } + } + } else { + if (field_info_.is_nullable_) { + validity = 0; + auto validity_values = input_data.validity_data(); + + // Process for nullable sums with no bitmap. + for (uint64_t c = input_data.min_cell(); c < input_data.max_cell(); + c++) { + if (validity_values[c] != 0) { + validity = 1; + + auto value = static_cast(values[c]); + count++; + safe_sum(value, sum); + } + } + } else { + // Process for non nullable sums with no bitmap. + for (uint64_t c = input_data.min_cell(); c < input_data.max_cell(); + c++) { + auto value = static_cast(values[c]); + count++; + safe_sum(value, sum); + } + } + } + + return {sum, count, validity}; + } + + private: + /* ********************************* */ + /* PRIVATE ATTRIBUTES */ + /* ********************************* */ + + /** Field information. */ + const FieldInfo field_info_; +}; + +} // namespace sm +} // namespace tiledb + +#endif // TILEDB_AGGREGATE_SUM_H diff --git a/tiledb/sm/query/readers/aggregators/count_aggregator.cc b/tiledb/sm/query/readers/aggregators/count_aggregator.cc index ebdff4c4df90..8e6181e0ab85 100644 --- a/tiledb/sm/query/readers/aggregators/count_aggregator.cc +++ b/tiledb/sm/query/readers/aggregators/count_aggregator.cc @@ -35,8 +35,6 @@ #include "tiledb/sm/query/query_buffer.h" #include "tiledb/sm/query/readers/aggregators/aggregate_buffer.h" -using namespace tiledb::common; - namespace tiledb { namespace sm { diff --git a/tiledb/sm/query/readers/aggregators/count_aggregator.h b/tiledb/sm/query/readers/aggregators/count_aggregator.h index 9a1cf1386905..8e6595b518ff 100644 --- a/tiledb/sm/query/readers/aggregators/count_aggregator.h +++ b/tiledb/sm/query/readers/aggregators/count_aggregator.h @@ -33,12 +33,9 @@ #ifndef TILEDB_COUNT_AGGREGATOR_H #define TILEDB_COUNT_AGGREGATOR_H -#include "tiledb/common/status.h" -#include "tiledb/sm/enums/layout.h" +#include "tiledb/common/common.h" #include "tiledb/sm/query/readers/aggregators/iaggregator.h" -using namespace tiledb::common; - namespace tiledb { namespace sm { diff --git a/tiledb/sm/query/readers/aggregators/field_info.h b/tiledb/sm/query/readers/aggregators/field_info.h index 29cd707cd8e0..5dd322a976a8 100644 --- a/tiledb/sm/query/readers/aggregators/field_info.h +++ b/tiledb/sm/query/readers/aggregators/field_info.h @@ -33,10 +33,7 @@ #ifndef TILEDB_FIELD_INFO_H #define TILEDB_FIELD_INFO_H -#include "tiledb/common/status.h" -#include "tiledb/sm/enums/layout.h" - -using namespace tiledb::common; +#include "tiledb/common/common.h" namespace tiledb { namespace sm { diff --git a/tiledb/sm/query/readers/aggregators/iaggregator.h b/tiledb/sm/query/readers/aggregators/iaggregator.h index 821de4c83697..b5d13699c9f3 100644 --- a/tiledb/sm/query/readers/aggregators/iaggregator.h +++ b/tiledb/sm/query/readers/aggregators/iaggregator.h @@ -33,10 +33,8 @@ #ifndef TILEDB_IAGGREGATOR_H #define TILEDB_IAGGREGATOR_H -#include "tiledb/common/status.h" -#include "tiledb/sm/enums/layout.h" - -using namespace tiledb::common; +#include "tiledb/common/common.h" +#include "tiledb/sm/misc/constants.h" namespace tiledb { namespace sm { diff --git a/tiledb/sm/query/readers/aggregators/mean_aggregator.cc b/tiledb/sm/query/readers/aggregators/mean_aggregator.cc index e6ea98ffcc22..e95bb4609c2b 100644 --- a/tiledb/sm/query/readers/aggregators/mean_aggregator.cc +++ b/tiledb/sm/query/readers/aggregators/mean_aggregator.cc @@ -34,9 +34,6 @@ #include "tiledb/sm/query/query_buffer.h" #include "tiledb/sm/query/readers/aggregators/aggregate_buffer.h" -#include "tiledb/sm/query/readers/aggregators/sum_aggregator.h" - -using namespace tiledb::common; namespace tiledb { namespace sm { @@ -51,6 +48,7 @@ class MeanAggregatorStatusException : public StatusException { template MeanAggregator::MeanAggregator(const FieldInfo field_info) : field_info_(field_info) + , summator_(field_info) , sum_(0) , count_(0) , validity_value_( @@ -116,49 +114,32 @@ void MeanAggregator::validate_output_buffer( template void MeanAggregator::aggregate_data(AggregateBuffer& input_data) { - tuple> res{0, 0, nullopt}; + // Return if a previous aggregation has overflowed. + if (sum_overflowed_) { + return; + } - bool overflow = false; try { + tuple::sum_type, uint64_t, optional> res{ + 0, 0, nullopt}; + if (input_data.is_count_bitmap()) { - res = mean(input_data); + res = + summator_.template sum::sum_type, uint64_t>( + input_data); } else { - res = mean(input_data); + res = + summator_.template sum::sum_type, uint8_t>( + input_data); } - } catch (std::overflow_error&) { - overflow = true; - } - - { - // This might be called on multiple threads, the final result stored in sum_ - // should be computed in a thread safe manner. The mutex also protects - // sum_overflowed_ which indicates when the sum has overflowed. - std::unique_lock lock(mean_mtx_); - // A previous operation already overflowed the sum, return. - if (sum_overflowed_) { - return; - } - - // If we have an overflow, signal it, else it's business as usual. - if (overflow) { - sum_overflowed_ = true; - sum_ = std::get<0>(res); - count_ = std::numeric_limits::max(); - return; - } else { - // This sum might overflow as well. - try { - safe_sum(std::get<0>(res), sum_); - safe_sum(std::get<1>(res), count_); - } catch (std::overflow_error&) { - sum_overflowed_ = true; - } + safe_sum(std::get<0>(res), sum_); + count_ += std::get<1>(res); + if (field_info_.is_nullable_ && std::get<2>(res).value() == 1) { + validity_value_ = 1; } - } - - if (field_info_.is_nullable_ && std::get<2>(res).value() == 1) { - validity_value_ = 1; + } catch (std::overflow_error& e) { + sum_overflowed_ = true; } } @@ -167,90 +148,30 @@ void MeanAggregator::copy_to_user_buffer( std::string output_field_name, std::unordered_map& buffers) { auto& result_buffer = buffers[output_field_name]; - *static_cast(result_buffer.buffer_) = sum_ / count_; + auto s = static_cast(result_buffer.buffer_); + + if (sum_overflowed_) { + *s = std::numeric_limits::max(); + } else { + *s = static_cast(sum_) / count_; + } if (result_buffer.buffer_size_) { *result_buffer.buffer_size_ = sizeof(double); } if (field_info_.is_nullable_) { - *static_cast(result_buffer.validity_vector_.buffer()) = - validity_value_.value(); - - if (result_buffer.validity_vector_.buffer_size()) { - *result_buffer.validity_vector_.buffer_size() = 1; - } - } -} - -template -template -tuple> MeanAggregator::mean( - AggregateBuffer& input_data) { - double sum{0}; - uint64_t count{0}; - optional validity{nullopt}; - auto values = input_data.fixed_data_as(); - - // Run different loops for bitmap versus no bitmap and nullable versus non - // nullable. The bitmap tells us which cells was already filtered out by - // ranges or query conditions. - if (input_data.has_bitmap()) { - auto bitmap_values = input_data.bitmap_data_as(); - - if (field_info_.is_nullable_) { - validity = 0; - auto validity_values = input_data.validity_data(); - - // Process for nullable sums with bitmap. - for (uint64_t c = input_data.min_cell(); c < input_data.max_cell(); c++) { - if (validity_values[c] != 0 && bitmap_values[c] != 0) { - validity = 1; - - auto value = static_cast(values[c]); - for (BITMAP_T i = 0; i < bitmap_values[c]; i++) { - count++; - safe_sum(value, sum); - } - } - } + auto v = static_cast(result_buffer.validity_vector_.buffer()); + if (sum_overflowed_) { + *v = 0; } else { - // Process for non nullable sums with bitmap. - for (uint64_t c = input_data.min_cell(); c < input_data.max_cell(); c++) { - auto value = static_cast(values[c]); - - for (BITMAP_T i = 0; i < bitmap_values[c]; i++) { - count++; - safe_sum(value, sum); - } - } + *v = validity_value_.value(); } - } else { - if (field_info_.is_nullable_) { - validity = 0; - auto validity_values = input_data.validity_data(); - - // Process for nullable sums with no bitmap. - for (uint64_t c = input_data.min_cell(); c < input_data.max_cell(); c++) { - if (validity_values[c] != 0) { - validity = 1; - auto value = static_cast(values[c]); - count++; - safe_sum(value, sum); - } - } - } else { - // Process for non nullable sums with no bitmap. - for (uint64_t c = input_data.min_cell(); c < input_data.max_cell(); c++) { - auto value = static_cast(values[c]); - count++; - safe_sum(value, sum); - } + if (result_buffer.validity_vector_.buffer_size()) { + *result_buffer.validity_vector_.buffer_size() = 1; } } - - return {sum, count, validity}; } // Explicit template instantiations diff --git a/tiledb/sm/query/readers/aggregators/mean_aggregator.h b/tiledb/sm/query/readers/aggregators/mean_aggregator.h index 7f8e731747a0..1c8e719461c5 100644 --- a/tiledb/sm/query/readers/aggregators/mean_aggregator.h +++ b/tiledb/sm/query/readers/aggregators/mean_aggregator.h @@ -33,13 +33,11 @@ #ifndef TILEDB_MEAN_AGGREGATOR_H #define TILEDB_MEAN_AGGREGATOR_H -#include "tiledb/common/status.h" -#include "tiledb/sm/enums/layout.h" +#include "tiledb/common/common.h" +#include "tiledb/sm/query/readers/aggregators/aggregate_sum.h" #include "tiledb/sm/query/readers/aggregators/field_info.h" #include "tiledb/sm/query/readers/aggregators/iaggregator.h" -using namespace tiledb::common; - namespace tiledb { namespace sm { @@ -118,35 +116,20 @@ class MeanAggregator : public IAggregator { /** Field information. */ const FieldInfo field_info_; - /** Mutex protecting `sum_`, `sum_overflowed_` and count_. */ - std::mutex mean_mtx_; + /** AggregateSum to do summation of AggregateBuffer data. */ + AggregateSum summator_; /** Computed sum. */ - double sum_; + std::atomic::sum_type> sum_; /** Count of values. */ - uint64_t count_; + std::atomic count_; /** Computed validity value. */ optional validity_value_; /** Has the sum overflowed. */ - bool sum_overflowed_; - - /* ********************************* */ - /* PRIVATE METHODS */ - /* ********************************* */ - - /** - * Add the sum/count of cells for the input data. - * - * @tparam BITMAP_T Bitmap type. - * @param input_data Input data for the mean. - * - * @return {Computed sum, count of cells, optional validity value}. - */ - template - tuple> mean(AggregateBuffer& input_data); + std::atomic sum_overflowed_; }; } // namespace sm diff --git a/tiledb/sm/query/readers/aggregators/min_max_aggregator.cc b/tiledb/sm/query/readers/aggregators/min_max_aggregator.cc index 02cba5f5cee3..5554bf9d42a6 100644 --- a/tiledb/sm/query/readers/aggregators/min_max_aggregator.cc +++ b/tiledb/sm/query/readers/aggregators/min_max_aggregator.cc @@ -35,8 +35,6 @@ #include "tiledb/sm/query/query_buffer.h" #include "tiledb/sm/query/readers/aggregators/aggregate_buffer.h" -using namespace tiledb::common; - namespace tiledb { namespace sm { diff --git a/tiledb/sm/query/readers/aggregators/min_max_aggregator.h b/tiledb/sm/query/readers/aggregators/min_max_aggregator.h index 1db5799645ae..9cb4eba1d637 100644 --- a/tiledb/sm/query/readers/aggregators/min_max_aggregator.h +++ b/tiledb/sm/query/readers/aggregators/min_max_aggregator.h @@ -33,15 +33,12 @@ #ifndef TILEDB_MIN_MAX_AGGREGATOR_H #define TILEDB_MIN_MAX_AGGREGATOR_H -#include "tiledb/common/status.h" -#include "tiledb/sm/enums/layout.h" +#include "tiledb/common/common.h" #include "tiledb/sm/query/readers/aggregators/field_info.h" #include "tiledb/sm/query/readers/aggregators/iaggregator.h" #include -using namespace tiledb::common; - namespace tiledb { namespace sm { diff --git a/tiledb/sm/query/readers/aggregators/sum_aggregator.cc b/tiledb/sm/query/readers/aggregators/sum_aggregator.cc index aec3411b25e4..b898952f4b4e 100644 --- a/tiledb/sm/query/readers/aggregators/sum_aggregator.cc +++ b/tiledb/sm/query/readers/aggregators/sum_aggregator.cc @@ -35,8 +35,6 @@ #include "tiledb/sm/query/query_buffer.h" #include "tiledb/sm/query/readers/aggregators/aggregate_buffer.h" -using namespace tiledb::common; - namespace tiledb { namespace sm { @@ -47,52 +45,10 @@ class SumAggregatorStatusException : public StatusException { } }; -/** Specialization of safe_sum for int64_t sums. */ -template <> -void safe_sum(int64_t value, int64_t& sum) { - if (sum > 0 && value > 0 && - (sum > (std::numeric_limits::max() - value))) { - sum = std::numeric_limits::max(); - throw std::overflow_error("overflow on sum"); - } - - if (sum < 0 && value < 0 && - (sum < (std::numeric_limits::min() - value))) { - sum = std::numeric_limits::min(); - throw std::overflow_error("overflow on sum"); - } - - sum += value; -} - -/** Specialization of safe_sum for uint64_t sums. */ -template <> -void safe_sum(uint64_t value, uint64_t& sum) { - if (sum > (std::numeric_limits::max() - value)) { - sum = std::numeric_limits::max(); - throw std::overflow_error("overflow on sum"); - } - - sum += value; -} - -/** Specialization of safe_sum for double sums. */ -template <> -void safe_sum(double value, double& sum) { - if ((sum < 0.0) == (value < 0.0) && - (std::abs(sum) > - (std::numeric_limits::max() - std::abs(value)))) { - sum = sum < 0.0 ? std::numeric_limits::lowest() : - std::numeric_limits::max(); - throw std::overflow_error("overflow on sum"); - } - - sum += value; -} - template SumAggregator::SumAggregator(const FieldInfo field_info) : field_info_(field_info) + , summator_(field_info) , sum_(0) , validity_value_( field_info_.is_nullable_ ? std::make_optional(0) : nullopt) @@ -156,47 +112,31 @@ void SumAggregator::validate_output_buffer( template void SumAggregator::aggregate_data(AggregateBuffer& input_data) { - tuple::sum_type, optional> res{0, nullopt}; + // Return if a previous aggregation has overflowed. + if (sum_overflowed_) { + return; + } - bool overflow = false; try { + tuple::sum_type, uint64_t, optional> res{ + 0, 0, nullopt}; + if (input_data.is_count_bitmap()) { - res = sum::sum_type, uint64_t>(input_data); + res = + summator_.template sum::sum_type, uint64_t>( + input_data); } else { - res = sum::sum_type, uint8_t>(input_data); - } - } catch (std::overflow_error&) { - overflow = true; - } - - { - // This might be called on multiple threads, the final result stored in sum_ - // should be computed in a thread safe manner. The mutex also protects - // sum_overflowed_ which indicates when the sum has overflowed. - std::unique_lock lock(sum_mtx_); - - // A previous operation already overflowed the sum, return. - if (sum_overflowed_) { - return; + res = + summator_.template sum::sum_type, uint8_t>( + input_data); } - // If we have an overflow, signal it, else it's business as usual. - if (overflow) { - sum_overflowed_ = true; - sum_ = std::get<0>(res); - return; - } else { - // This sum might overflow as well. - try { - safe_sum(std::get<0>(res), sum_); - } catch (std::overflow_error&) { - sum_overflowed_ = true; - } + safe_sum(std::get<0>(res), sum_); + if (field_info_.is_nullable_ && std::get<2>(res).value() == 1) { + validity_value_ = 1; } - } - - if (field_info_.is_nullable_ && std::get<1>(res).value() == 1) { - validity_value_ = 1; + } catch (std::overflow_error& e) { + sum_overflowed_ = true; } } @@ -205,86 +145,30 @@ void SumAggregator::copy_to_user_buffer( std::string output_field_name, std::unordered_map& buffers) { auto& result_buffer = buffers[output_field_name]; - *static_cast::sum_type*>(result_buffer.buffer_) = - sum_; + auto s = + static_cast::sum_type*>(result_buffer.buffer_); + if (sum_overflowed_) { + *s = std::numeric_limits::sum_type>::max(); + } else { + *s = sum_; + } if (result_buffer.buffer_size_) { *result_buffer.buffer_size_ = sizeof(typename sum_type_data::sum_type); } if (field_info_.is_nullable_) { - *static_cast(result_buffer.validity_vector_.buffer()) = - validity_value_.value(); - - if (result_buffer.validity_vector_.buffer_size()) { - *result_buffer.validity_vector_.buffer_size() = 1; - } - } -} - -template -template -tuple> SumAggregator::sum( - AggregateBuffer& input_data) { - SUM_T sum{0}; - optional validity{nullopt}; - auto values = input_data.fixed_data_as(); - - // Run different loops for bitmap versus no bitmap and nullable versus non - // nullable. The bitmap tells us which cells was already filtered out by - // ranges or query conditions. - if (input_data.has_bitmap()) { - auto bitmap_values = input_data.bitmap_data_as(); - - if (field_info_.is_nullable_) { - validity = 0; - auto validity_values = input_data.validity_data(); - - // Process for nullable sums with bitmap. - for (uint64_t c = input_data.min_cell(); c < input_data.max_cell(); c++) { - if (validity_values[c] != 0 && bitmap_values[c] != 0) { - validity = 1; - - auto value = static_cast(values[c]); - for (BITMAP_T i = 0; i < bitmap_values[c]; i++) { - safe_sum(value, sum); - } - } - } + auto v = static_cast(result_buffer.validity_vector_.buffer()); + if (sum_overflowed_) { + *v = 0; } else { - // Process for non nullable sums with bitmap. - for (uint64_t c = input_data.min_cell(); c < input_data.max_cell(); c++) { - auto value = static_cast(values[c]); - - for (BITMAP_T i = 0; i < bitmap_values[c]; i++) { - safe_sum(value, sum); - } - } + *v = validity_value_.value(); } - } else { - if (field_info_.is_nullable_) { - validity = 0; - auto validity_values = input_data.validity_data(); - // Process for nullable sums with no bitmap. - for (uint64_t c = input_data.min_cell(); c < input_data.max_cell(); c++) { - if (validity_values[c] != 0) { - validity = 1; - - auto value = static_cast(values[c]); - safe_sum(value, sum); - } - } - } else { - // Process for non nullable sums with no bitmap. - for (uint64_t c = input_data.min_cell(); c < input_data.max_cell(); c++) { - auto value = static_cast(values[c]); - safe_sum(value, sum); - } + if (result_buffer.validity_vector_.buffer_size()) { + *result_buffer.validity_vector_.buffer_size() = 1; } } - - return {sum, validity}; } // Explicit template instantiations diff --git a/tiledb/sm/query/readers/aggregators/sum_aggregator.h b/tiledb/sm/query/readers/aggregators/sum_aggregator.h index b083a5bb8fd2..f3064b06049b 100644 --- a/tiledb/sm/query/readers/aggregators/sum_aggregator.h +++ b/tiledb/sm/query/readers/aggregators/sum_aggregator.h @@ -33,49 +33,16 @@ #ifndef TILEDB_SUM_AGGREGATOR_H #define TILEDB_SUM_AGGREGATOR_H -#include "tiledb/common/status.h" -#include "tiledb/sm/enums/layout.h" +#include "tiledb/common/common.h" +#include "tiledb/sm/query/readers/aggregators/aggregate_sum.h" #include "tiledb/sm/query/readers/aggregators/field_info.h" #include "tiledb/sm/query/readers/aggregators/iaggregator.h" -using namespace tiledb::common; - namespace tiledb { namespace sm { -#define SUM_TYPE_DATA(T, SUM_T) \ - template <> \ - struct sum_type_data { \ - using type = T; \ - typedef SUM_T sum_type; \ - }; - -/** Convert basic type to a sum type. **/ -template -struct sum_type_data; - -SUM_TYPE_DATA(int8_t, int64_t); -SUM_TYPE_DATA(uint8_t, uint64_t); -SUM_TYPE_DATA(int16_t, int64_t); -SUM_TYPE_DATA(uint16_t, uint64_t); -SUM_TYPE_DATA(int32_t, int64_t); -SUM_TYPE_DATA(uint32_t, uint64_t); -SUM_TYPE_DATA(int64_t, int64_t); -SUM_TYPE_DATA(uint64_t, uint64_t); -SUM_TYPE_DATA(float, double); -SUM_TYPE_DATA(double, double); - class QueryBuffer; -/** - * Sum function that prevent wrap arounds on overflow. - * - * @param value Value to add to the sum. - * @param sum Computed sum. - */ -template -void safe_sum(SUM_T value, SUM_T& sum); - template class SumAggregator : public IAggregator { public: @@ -149,33 +116,17 @@ class SumAggregator : public IAggregator { /** Field information. */ const FieldInfo field_info_; - /** Mutex protecting `sum_` and `sum_overflowed_`. */ - std::mutex sum_mtx_; + /** AggregateSum to do summation of AggregateBuffer data. */ + AggregateSum summator_; /** Computed sum. */ - typename sum_type_data::sum_type sum_; + std::atomic::sum_type> sum_; /** Computed validity value. */ optional validity_value_; /** Has the sum overflowed. */ - bool sum_overflowed_; - - /* ********************************* */ - /* PRIVATE METHODS */ - /* ********************************* */ - - /** - * Add the sum of cells for the input data. - * - * @tparam SUM_T Sum type. - * @tparam BITMAP_T Bitmap type. - * @param input_data Input data for the sum. - * - * @return {Computed sum for the cells, optional validity value}. - */ - template - tuple> sum(AggregateBuffer& input_data); + std::atomic sum_overflowed_; }; } // namespace sm diff --git a/tiledb/sm/query/readers/aggregators/test/unit_mean.cc b/tiledb/sm/query/readers/aggregators/test/unit_mean.cc index f04d3001d4d6..e74257d7f0b6 100644 --- a/tiledb/sm/query/readers/aggregators/test/unit_mean.cc +++ b/tiledb/sm/query/readers/aggregators/test/unit_mean.cc @@ -400,11 +400,11 @@ TEST_CASE("Mean aggregator: overflow", "[mean-aggregator][overflow]") { // Now cause an underflow. aggregator.aggregate_data(input_data_lowest); aggregator.copy_to_user_buffer("Mean", buffers); - CHECK(mean == std::numeric_limits::lowest()); + CHECK(mean == std::numeric_limits::max()); // Once we underflow, the value doesn't change. aggregator.aggregate_data(input_data_max); aggregator.copy_to_user_buffer("Mean", buffers); - CHECK(mean == std::numeric_limits::lowest()); + CHECK(mean == std::numeric_limits::max()); } } diff --git a/tiledb/sm/query/readers/aggregators/test/unit_sum.cc b/tiledb/sm/query/readers/aggregators/test/unit_sum.cc index f0b1fec115d1..9749fff1a358 100644 --- a/tiledb/sm/query/readers/aggregators/test/unit_sum.cc +++ b/tiledb/sm/query/readers/aggregators/test/unit_sum.cc @@ -427,12 +427,12 @@ TEST_CASE( aggregator.aggregate_data(input_data_minus_one); aggregator.aggregate_data(input_data_minus_one); aggregator.copy_to_user_buffer("Sum", buffers); - CHECK(sum == std::numeric_limits::min()); + CHECK(sum == std::numeric_limits::max()); // Once we underflow, the value doesn't change. aggregator.aggregate_data(input_data_plus_one); aggregator.copy_to_user_buffer("Sum", buffers); - CHECK(sum == std::numeric_limits::min()); + CHECK(sum == std::numeric_limits::max()); } } @@ -518,11 +518,11 @@ TEST_CASE( // Now cause an underflow. aggregator.aggregate_data(input_data_lowest); aggregator.copy_to_user_buffer("Sum", buffers); - CHECK(sum == std::numeric_limits::lowest()); + CHECK(sum == std::numeric_limits::max()); // Once we underflow, the value doesn't change. aggregator.aggregate_data(input_data_max); aggregator.copy_to_user_buffer("Sum", buffers); - CHECK(sum == std::numeric_limits::lowest()); + CHECK(sum == std::numeric_limits::max()); } }