From 447f5c8d8426cb9046fedf9beceb5f188ff62ad9 Mon Sep 17 00:00:00 2001 From: Luc Rancourt Date: Thu, 5 Oct 2023 12:32:32 +0200 Subject: [PATCH] Move counts computation to aggregate_with_count. This moves the computation for the count aggregators to aggregate_with_count. It also merges the unit test with the existing ones for other aggregators. TYPE: IMPROVEMENT DESC: Move counts computation to aggregate_with_count. --- test/src/test-cppapi-aggregates.cc | 1 - tiledb/CMakeLists.txt | 1 - .../query/readers/aggregators/CMakeLists.txt | 2 +- .../readers/aggregators/aggregate_buffer.h | 34 +- .../aggregators/aggregate_with_count.h | 54 +- .../readers/aggregators/count_aggregator.cc | 56 +- .../readers/aggregators/count_aggregator.h | 67 ++- .../readers/aggregators/mean_aggregator.cc | 11 +- .../readers/aggregators/mean_aggregator.h | 5 +- tiledb/sm/query/readers/aggregators/min_max.h | 1 - .../readers/aggregators/min_max_aggregator.cc | 9 +- .../readers/aggregators/min_max_aggregator.h | 4 +- tiledb/sm/query/readers/aggregators/no_op.h | 50 ++ .../aggregators/null_count_aggregator.cc | 116 ---- .../aggregators/null_count_aggregator.h | 136 ----- .../readers/aggregators/sum_aggregator.cc | 11 +- .../readers/aggregators/sum_aggregator.h | 5 +- .../readers/aggregators/test/CMakeLists.txt | 2 +- .../test/compile_aggregators_main.cc | 1 - .../test/unit_aggregate_with_count.cc | 62 +- .../aggregators/test/unit_aggregators.cc | 549 ++++++++++++------ .../readers/aggregators/test/unit_count.cc | 170 ------ .../aggregators/test/unit_null_count.cc | 351 ----------- .../readers/aggregators/validity_policies.h | 64 ++ 24 files changed, 656 insertions(+), 1106 deletions(-) create mode 100644 tiledb/sm/query/readers/aggregators/no_op.h delete mode 100644 tiledb/sm/query/readers/aggregators/null_count_aggregator.cc delete mode 100644 tiledb/sm/query/readers/aggregators/null_count_aggregator.h delete mode 100644 tiledb/sm/query/readers/aggregators/test/unit_count.cc delete mode 100644 
tiledb/sm/query/readers/aggregators/test/unit_null_count.cc create mode 100644 tiledb/sm/query/readers/aggregators/validity_policies.h diff --git a/test/src/test-cppapi-aggregates.cc b/test/src/test-cppapi-aggregates.cc index 68969d8d951..91a1aa557b4 100644 --- a/test/src/test-cppapi-aggregates.cc +++ b/test/src/test-cppapi-aggregates.cc @@ -36,7 +36,6 @@ #include "tiledb/sm/query/readers/aggregators/count_aggregator.h" #include "tiledb/sm/query/readers/aggregators/mean_aggregator.h" #include "tiledb/sm/query/readers/aggregators/min_max_aggregator.h" -#include "tiledb/sm/query/readers/aggregators/null_count_aggregator.h" #include "tiledb/sm/query/readers/aggregators/sum_aggregator.h" #include diff --git a/tiledb/CMakeLists.txt b/tiledb/CMakeLists.txt index 17b986ca750..6811175b0a1 100644 --- a/tiledb/CMakeLists.txt +++ b/tiledb/CMakeLists.txt @@ -255,7 +255,6 @@ set(TILEDB_CORE_SOURCES ${TILEDB_CORE_INCLUDE_DIR}/tiledb/sm/query/readers/aggregators/count_aggregator.cc ${TILEDB_CORE_INCLUDE_DIR}/tiledb/sm/query/readers/aggregators/mean_aggregator.cc ${TILEDB_CORE_INCLUDE_DIR}/tiledb/sm/query/readers/aggregators/min_max_aggregator.cc - ${TILEDB_CORE_INCLUDE_DIR}/tiledb/sm/query/readers/aggregators/null_count_aggregator.cc ${TILEDB_CORE_INCLUDE_DIR}/tiledb/sm/query/readers/aggregators/output_buffer_validator.cc ${TILEDB_CORE_INCLUDE_DIR}/tiledb/sm/query/readers/aggregators/safe_sum.cc ${TILEDB_CORE_INCLUDE_DIR}/tiledb/sm/query/readers/aggregators/sum_aggregator.cc diff --git a/tiledb/sm/query/readers/aggregators/CMakeLists.txt b/tiledb/sm/query/readers/aggregators/CMakeLists.txt index 3b8b8d72925..0ed74a6a736 100644 --- a/tiledb/sm/query/readers/aggregators/CMakeLists.txt +++ b/tiledb/sm/query/readers/aggregators/CMakeLists.txt @@ -31,7 +31,7 @@ include(object_library) # `aggregators` object library # commence(object_library aggregators) - this_target_sources(count_aggregator.cc mean_aggregator.cc min_max_aggregator.cc null_count_aggregator.cc 
output_buffer_validator.cc safe_sum.cc sum_aggregator.cc) + this_target_sources(count_aggregator.cc mean_aggregator.cc min_max_aggregator.cc output_buffer_validator.cc safe_sum.cc sum_aggregator.cc) this_target_object_libraries(baseline array_schema) conclude(object_library) diff --git a/tiledb/sm/query/readers/aggregators/aggregate_buffer.h b/tiledb/sm/query/readers/aggregators/aggregate_buffer.h index 8e8a6cdb861..2ea411f7321 100644 --- a/tiledb/sm/query/readers/aggregators/aggregate_buffer.h +++ b/tiledb/sm/query/readers/aggregators/aggregate_buffer.h @@ -81,11 +81,6 @@ class AggregateBuffer { /* API */ /* ********************************* */ - /** Returns the validity buffer. */ - uint8_t* validity_data() const { - return validity_data_.value(); - } - /** Returns if the bitmap is a count bitmap. */ bool is_count_bitmap() const { return count_bitmap_; @@ -96,12 +91,6 @@ class AggregateBuffer { return bitmap_data_.has_value(); } - /** Returns types bitmap data. */ - template - BitmapType* bitmap_data_as() const { - return static_cast(bitmap_data_.value()); - } - /** Returns the min cell position to aggregate. */ uint64_t min_cell() const { return min_cell_; @@ -141,6 +130,29 @@ class AggregateBuffer { } } + /** + * Get the validity value at a certain cell index. + * + * @param cell_idx Cell index. + * + * @return Validity value. + */ + inline uint8_t validity_at(const uint64_t cell_idx) const { + return validity_data_.value()[cell_idx]; + } + + /** + * Get the bitmap value at a certain cell index. + * + * @param cell_idx Cell index. + * + * @return Bitmap value. 
+ */ + template + BitmapType bitmap_at(const uint64_t cell_idx) const { + return static_cast(bitmap_data_.value())[cell_idx]; + } + private: /* ********************************* */ /* PRIVATE ATTRIBUTES */ diff --git a/tiledb/sm/query/readers/aggregators/aggregate_with_count.h b/tiledb/sm/query/readers/aggregators/aggregate_with_count.h index e18bd81b266..e22186f9cd8 100644 --- a/tiledb/sm/query/readers/aggregators/aggregate_with_count.h +++ b/tiledb/sm/query/readers/aggregators/aggregate_with_count.h @@ -37,6 +37,7 @@ #include "tiledb/sm/query/readers/aggregators/aggregate_buffer.h" #include "tiledb/sm/query/readers/aggregators/field_info.h" +#include "tiledb/sm/query/readers/aggregators/no_op.h" namespace tiledb::sm { @@ -53,7 +54,7 @@ struct type_data { typedef std::string_view value_type; }; -template +template class AggregateWithCount { public: /* ********************************* */ @@ -71,19 +72,17 @@ class AggregateWithCount { /** * Aggregate the input data. * - * @tparam AGG_T Aggregate value type. * @tparam BITMAP_T Bitmap type. - * @tparam AggPolicy Aggregation policy. * @param input_data Input data for the aggregation. * * NOTE: Count of cells returned is used to infer the validity from the * caller. * @return {Aggregate value, count of cells}. */ - template + template tuple aggregate(AggregateBuffer& input_data) { - typedef typename type_data::value_type VALUE_T; AggPolicy agg_policy; + ValidityPolicy val_policy; AGG_T res; if constexpr (std::is_same::value) { res = ""; @@ -96,17 +95,14 @@ class AggregateWithCount { // nullable. The bitmap tells us which cells was already filtered out by // ranges or query conditions. if (input_data.has_bitmap()) { - auto bitmap_values = input_data.bitmap_data_as(); - if (field_info_.is_nullable_) { - auto validity_values = input_data.validity_data(); - // Process for nullable values with bitmap. 
for (uint64_t c = input_data.min_cell(); c < input_data.max_cell(); c++) { - if (validity_values[c] != 0 && bitmap_values[c] != 0) { - AGG_T value = input_data.value_at(c); - for (BITMAP_T i = 0; i < bitmap_values[c]; i++) { + auto bitmap_val = input_data.bitmap_at(c); + if (val_policy.op(input_data.validity_at(c)) && bitmap_val != 0) { + auto value = value_at(input_data, c); + for (BITMAP_T i = 0; i < bitmap_val; i++) { agg_policy.op(value, res, count); count++; } @@ -116,9 +112,9 @@ class AggregateWithCount { // Process for non nullable values with bitmap. for (uint64_t c = input_data.min_cell(); c < input_data.max_cell(); c++) { - AGG_T value = input_data.value_at(c); - - for (BITMAP_T i = 0; i < bitmap_values[c]; i++) { + auto bitmap_val = input_data.bitmap_at(c); + auto value = value_at(input_data, c); + for (BITMAP_T i = 0; i < bitmap_val; i++) { agg_policy.op(value, res, count); count++; } @@ -126,13 +122,11 @@ class AggregateWithCount { } } else { if (field_info_.is_nullable_) { - auto validity_values = input_data.validity_data(); - // Process for nullable values with no bitmap. for (uint64_t c = input_data.min_cell(); c < input_data.max_cell(); c++) { - if (validity_values[c] != 0) { - AGG_T value = input_data.value_at(c); + if (val_policy.op(input_data.validity_at(c))) { + auto value = value_at(input_data, c); agg_policy.op(value, res, count); count++; } @@ -141,7 +135,7 @@ class AggregateWithCount { // Process for non nullable values with no bitmap. for (uint64_t c = input_data.min_cell(); c < input_data.max_cell(); c++) { - AGG_T value = input_data.value_at(c); + auto value = value_at(input_data, c); agg_policy.op(value, res, count); count++; } @@ -158,6 +152,26 @@ class AggregateWithCount { /** Field information. */ const FieldInfo field_info_; + + /* ********************************* */ + /* PRIVATE METHODS */ + /* ********************************* */ + + /** + * Returns the value at the specified cell if needed. + * + * @param input_data Input data. 
+ * @param c Cell index. + * @return Value. + */ + inline AGG_T value_at(AggregateBuffer& input_data, uint64_t c) { + typedef typename type_data::value_type VALUE_T; + if constexpr (!std::is_same::value) { + return input_data.value_at(c); + } + + return AGG_T(); + } }; } // namespace tiledb::sm diff --git a/tiledb/sm/query/readers/aggregators/count_aggregator.cc b/tiledb/sm/query/readers/aggregators/count_aggregator.cc index 2b3c2289f28..1dcf824218f 100644 --- a/tiledb/sm/query/readers/aggregators/count_aggregator.cc +++ b/tiledb/sm/query/readers/aggregators/count_aggregator.cc @@ -44,12 +44,15 @@ class CountAggregatorStatusException : public StatusException { } }; -CountAggregator::CountAggregator() - : OutputBufferValidator(FieldInfo("", false, false, 1)) +template +CountAggregatorBase::CountAggregatorBase(FieldInfo field_info) + : OutputBufferValidator(field_info) + , aggregate_with_count_(field_info) , count_(0) { } -void CountAggregator::validate_output_buffer( +template +void CountAggregatorBase::validate_output_buffer( std::string output_field_name, std::unordered_map& buffers) { if (buffers.count(output_field_name) == 0) { @@ -59,33 +62,22 @@ void CountAggregator::validate_output_buffer( ensure_output_buffer_count(buffers[output_field_name]); } -template -uint64_t count_cells(AggregateBuffer& input_data) { - uint64_t ret = 0; +template +void CountAggregatorBase::aggregate_data( + AggregateBuffer& input_data) { + tuple res; - auto bitmap_data = input_data.bitmap_data_as(); - for (uint64_t c = input_data.min_cell(); c < input_data.max_cell(); c++) { - ret += bitmap_data[c]; - } - - return ret; -} - -void CountAggregator::aggregate_data(AggregateBuffer& input_data) { - // Run different loops for bitmap versus no bitmap. The bitmap tells us which - // cells was already filtered out by ranges or query conditions. 
- if (input_data.has_bitmap()) { - if (input_data.is_count_bitmap()) { - count_ += count_cells(input_data); - } else { - count_ += count_cells(input_data); - } + if (input_data.is_count_bitmap()) { + res = aggregate_with_count_.template aggregate(input_data); } else { - count_ += input_data.max_cell() - input_data.min_cell(); + res = aggregate_with_count_.template aggregate(input_data); } + + count_ += std::get<1>(res); } -void CountAggregator::copy_to_user_buffer( +template +void CountAggregatorBase::copy_to_user_buffer( std::string output_field_name, std::unordered_map& buffers) { auto& result_buffer = buffers[output_field_name]; @@ -96,4 +88,18 @@ void CountAggregator::copy_to_user_buffer( } } +NullCountAggregator::NullCountAggregator(FieldInfo field_info) + : CountAggregatorBase(field_info) + , field_info_(field_info) { + if (!field_info_.is_nullable_) { + throw CountAggregatorStatusException( + "NullCount aggregates must only be requested for nullable " + "attributes."); + } +} + +// Explicit template instantiations +template CountAggregatorBase::CountAggregatorBase(FieldInfo); +template CountAggregatorBase::CountAggregatorBase(FieldInfo); + } // namespace tiledb::sm diff --git a/tiledb/sm/query/readers/aggregators/count_aggregator.h b/tiledb/sm/query/readers/aggregators/count_aggregator.h index b891b43af58..d012c6d5b86 100644 --- a/tiledb/sm/query/readers/aggregators/count_aggregator.h +++ b/tiledb/sm/query/readers/aggregators/count_aggregator.h @@ -33,33 +33,32 @@ #ifndef TILEDB_COUNT_AGGREGATOR_H #define TILEDB_COUNT_AGGREGATOR_H +#include "tiledb/sm/query/readers/aggregators/aggregate_with_count.h" #include "tiledb/sm/query/readers/aggregators/iaggregator.h" +#include "tiledb/sm/query/readers/aggregators/no_op.h" +#include "tiledb/sm/query/readers/aggregators/validity_policies.h" namespace tiledb::sm { class QueryBuffer; -class CountAggregator : public OutputBufferValidator, public IAggregator { +template +class CountAggregatorBase : public 
OutputBufferValidator, public IAggregator { public: /* ********************************* */ /* CONSTRUCTORS & DESTRUCTORS */ /* ********************************* */ /** Constructor. */ - CountAggregator(); + CountAggregatorBase(FieldInfo field_info); - DISABLE_COPY_AND_COPY_ASSIGN(CountAggregator); - DISABLE_MOVE_AND_MOVE_ASSIGN(CountAggregator); + DISABLE_COPY_AND_COPY_ASSIGN(CountAggregatorBase); + DISABLE_MOVE_AND_MOVE_ASSIGN(CountAggregatorBase); /* ********************************* */ /* API */ /* ********************************* */ - /** Returns the field name for the aggregator. */ - std::string field_name() override { - return constants::count_of_rows; - } - /** Returns if the aggregation is var sized or not. */ bool var_sized() override { return false; @@ -102,10 +101,60 @@ class CountAggregator : public OutputBufferValidator, public IAggregator { /* PRIVATE ATTRIBUTES */ /* ********************************* */ + /** AggregateWithCount to do summation of AggregateBuffer data. */ + AggregateWithCount + aggregate_with_count_; + /** Cell count. */ std::atomic count_; }; +class CountAggregator : public CountAggregatorBase { + public: + /* ********************************* */ + /* CONSTRUCTORS & DESTRUCTORS */ + /* ********************************* */ + + CountAggregator() + : CountAggregatorBase(FieldInfo("", false, false, 1)) { + } + + /* ********************************* */ + /* API */ + /* ********************************* */ + + /** Returns the field name for the aggregator. */ + std::string field_name() override { + return constants::count_of_rows; + } +}; + +class NullCountAggregator : public CountAggregatorBase { + public: + /* ********************************* */ + /* CONSTRUCTORS & DESTRUCTORS */ + /* ********************************* */ + + NullCountAggregator(FieldInfo field_info); + + /* ********************************* */ + /* API */ + /* ********************************* */ + + /** Returns the field name for the aggregator. 
*/ + std::string field_name() override { + return field_info_.name_; + } + + private: + /* ********************************* */ + /* PRIVATE ATTRIBUTES */ + /* ********************************* */ + + /** Field information. */ + const FieldInfo field_info_; +}; + } // namespace tiledb::sm #endif // TILEDB_COUNT_AGGREGATOR_H diff --git a/tiledb/sm/query/readers/aggregators/mean_aggregator.cc b/tiledb/sm/query/readers/aggregators/mean_aggregator.cc index ae944140d24..44ae0b329b0 100644 --- a/tiledb/sm/query/readers/aggregators/mean_aggregator.cc +++ b/tiledb/sm/query/readers/aggregators/mean_aggregator.cc @@ -34,7 +34,6 @@ #include "tiledb/sm/query/query_buffer.h" #include "tiledb/sm/query/readers/aggregators/aggregate_buffer.h" -#include "tiledb/sm/query/readers/aggregators/safe_sum.h" namespace tiledb::sm { @@ -99,15 +98,9 @@ void MeanAggregator::aggregate_data(AggregateBuffer& input_data) { // TODO: This is duplicated across aggregates but will go away with // sc-33104. if (input_data.is_count_bitmap()) { - res = aggregate_with_count_.template aggregate< - typename sum_type_data::sum_type, - uint64_t, - SafeSum>(input_data); + res = aggregate_with_count_.template aggregate(input_data); } else { - res = aggregate_with_count_.template aggregate< - typename sum_type_data::sum_type, - uint8_t, - SafeSum>(input_data); + res = aggregate_with_count_.template aggregate(input_data); } const auto value = std::get<0>(res); diff --git a/tiledb/sm/query/readers/aggregators/mean_aggregator.h b/tiledb/sm/query/readers/aggregators/mean_aggregator.h index 036e94a2ce8..b2ef6527b2a 100644 --- a/tiledb/sm/query/readers/aggregators/mean_aggregator.h +++ b/tiledb/sm/query/readers/aggregators/mean_aggregator.h @@ -37,7 +37,9 @@ #include "tiledb/sm/query/readers/aggregators/aggregate_with_count.h" #include "tiledb/sm/query/readers/aggregators/field_info.h" #include "tiledb/sm/query/readers/aggregators/iaggregator.h" +#include "tiledb/sm/query/readers/aggregators/safe_sum.h" #include 
"tiledb/sm/query/readers/aggregators/sum_type.h" +#include "tiledb/sm/query/readers/aggregators/validity_policies.h" namespace tiledb::sm { @@ -117,7 +119,8 @@ class MeanAggregator : public OutputBufferValidator, public IAggregator { const FieldInfo field_info_; /** AggregateWithCount to do summation of AggregateBuffer data. */ - AggregateWithCount aggregate_with_count_; + AggregateWithCount::sum_type, SafeSum, NonNull> + aggregate_with_count_; /** Computed sum. */ std::atomic::sum_type> sum_; diff --git a/tiledb/sm/query/readers/aggregators/min_max.h b/tiledb/sm/query/readers/aggregators/min_max.h index 08989f81413..8f2afdbb85f 100644 --- a/tiledb/sm/query/readers/aggregators/min_max.h +++ b/tiledb/sm/query/readers/aggregators/min_max.h @@ -44,7 +44,6 @@ struct MinMax { * @param value Value to compare against. * @param sum Computed min/max. * @param count Current count of values. - * @param */ template void op(MIN_MAX_T value, MIN_MAX_T& min_max, uint64_t count) { diff --git a/tiledb/sm/query/readers/aggregators/min_max_aggregator.cc b/tiledb/sm/query/readers/aggregators/min_max_aggregator.cc index df924af0c9d..972e740e68b 100644 --- a/tiledb/sm/query/readers/aggregators/min_max_aggregator.cc +++ b/tiledb/sm/query/readers/aggregators/min_max_aggregator.cc @@ -34,7 +34,6 @@ #include "tiledb/sm/query/query_buffer.h" #include "tiledb/sm/query/readers/aggregators/aggregate_buffer.h" -#include "tiledb/sm/query/readers/aggregators/min_max.h" namespace tiledb::sm { @@ -149,13 +148,9 @@ template void ComparatorAggregator::aggregate_data(AggregateBuffer& input_data) { tuple res; if (input_data.is_count_bitmap()) { - res = - aggregate_with_count_.template aggregate>( - input_data); + res = aggregate_with_count_.template aggregate(input_data); } else { - res = - aggregate_with_count_.template aggregate>( - input_data); + res = aggregate_with_count_.template aggregate(input_data); } { diff --git a/tiledb/sm/query/readers/aggregators/min_max_aggregator.h 
b/tiledb/sm/query/readers/aggregators/min_max_aggregator.h index e2a787e13ae..b1fe1874e11 100644 --- a/tiledb/sm/query/readers/aggregators/min_max_aggregator.h +++ b/tiledb/sm/query/readers/aggregators/min_max_aggregator.h @@ -35,6 +35,8 @@ #include "tiledb/sm/query/readers/aggregators/aggregate_with_count.h" #include "tiledb/sm/query/readers/aggregators/iaggregator.h" +#include "tiledb/sm/query/readers/aggregators/min_max.h" +#include "tiledb/sm/query/readers/aggregators/validity_policies.h" #include @@ -177,7 +179,7 @@ class ComparatorAggregator : public ComparatorAggregatorBase, /* ********************************* */ /** AggregateWithCount to do summation of AggregateBuffer data. */ - AggregateWithCount aggregate_with_count_; + AggregateWithCount, NonNull> aggregate_with_count_; /** Mutex protecting `value_`. */ std::mutex value_mtx_; diff --git a/tiledb/sm/query/readers/aggregators/no_op.h b/tiledb/sm/query/readers/aggregators/no_op.h new file mode 100644 index 00000000000..eb50929c9e1 --- /dev/null +++ b/tiledb/sm/query/readers/aggregators/no_op.h @@ -0,0 +1,50 @@ +/** + * @file no_op.h + * + * @section LICENSE + * + * The MIT License + * + * @copyright Copyright (c) 2023 TileDB, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. 
+ * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + * @section DESCRIPTION + * + * This file defines classes NoOp. + */ + +#ifndef TILEDB_NO_OP_H +#define TILEDB_NO_OP_H + +namespace tiledb::sm { + +struct NoOp { + public: + /** + * No op function. + */ + template + void op(NO_OP_T, uint64_t, uint64_t) { + } +}; + +} // namespace tiledb::sm + +#endif // TILEDB_NO_OP_H diff --git a/tiledb/sm/query/readers/aggregators/null_count_aggregator.cc b/tiledb/sm/query/readers/aggregators/null_count_aggregator.cc deleted file mode 100644 index 3865713b3a9..00000000000 --- a/tiledb/sm/query/readers/aggregators/null_count_aggregator.cc +++ /dev/null @@ -1,116 +0,0 @@ -/** - * @file null_count_aggregator.cc - * - * @section LICENSE - * - * The MIT License - * - * @copyright Copyright (c) 2023 TileDB, Inc. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in - * all copies or substantial portions of the Software. 
- * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN - * THE SOFTWARE. - * - * @section DESCRIPTION - * - * This file implements class NullCountAggregator. - */ - -#include "tiledb/sm/query/readers/aggregators/null_count_aggregator.h" - -#include "tiledb/sm/query/query_buffer.h" -#include "tiledb/sm/query/readers/aggregators/aggregate_buffer.h" - -namespace tiledb::sm { - -class NullCountAggregatorStatusException : public StatusException { - public: - explicit NullCountAggregatorStatusException(const std::string& message) - : StatusException("NullCountAggregator", message) { - } -}; - -NullCountAggregator::NullCountAggregator(const FieldInfo field_info) - : OutputBufferValidator(field_info) - , field_info_(field_info) - , null_count_(0) { - if (!field_info_.is_nullable_) { - throw NullCountAggregatorStatusException( - "NullCount aggregates must only be requested for nullable attributes."); - } -} - -void NullCountAggregator::validate_output_buffer( - std::string output_field_name, - std::unordered_map& buffers) { - if (buffers.count(output_field_name) == 0) { - throw NullCountAggregatorStatusException("Result buffer doesn't exist."); - } - - ensure_output_buffer_count(buffers[output_field_name]); -} - -void NullCountAggregator::aggregate_data(AggregateBuffer& input_data) { - if (input_data.is_count_bitmap()) { - null_count_ += null_count(input_data); - } else { - null_count_ += null_count(input_data); - } -} - -void NullCountAggregator::copy_to_user_buffer( - std::string output_field_name, - std::unordered_map& buffers) { - auto& result_buffer = 
buffers[output_field_name]; - *static_cast(result_buffer.buffer_) = null_count_; - - if (result_buffer.buffer_size_) { - *result_buffer.buffer_size_ = sizeof(uint64_t); - } -} - -template -uint64_t NullCountAggregator::null_count(AggregateBuffer& input_data) { - uint64_t null_count{0}; - - // Run different loops for bitmap versus no bitmap. The bitmap tells us which - // cells was already filtered out by ranges or query conditions. - if (input_data.has_bitmap()) { - auto bitmap_values = input_data.bitmap_data_as(); - auto validity_values = input_data.validity_data(); - - // Process with bitmap. - for (uint64_t c = input_data.min_cell(); c < input_data.max_cell(); c++) { - if (validity_values[c] == 0) { - null_count += bitmap_values[c]; - } - } - } else { - auto validity_values = input_data.validity_data(); - - // Process with no bitmap. - for (uint64_t c = input_data.min_cell(); c < input_data.max_cell(); c++) { - if (validity_values[c] == 0) { - null_count++; - } - } - } - - return null_count; -} - -} // namespace tiledb::sm diff --git a/tiledb/sm/query/readers/aggregators/null_count_aggregator.h b/tiledb/sm/query/readers/aggregators/null_count_aggregator.h deleted file mode 100644 index 0535bd60c3e..00000000000 --- a/tiledb/sm/query/readers/aggregators/null_count_aggregator.h +++ /dev/null @@ -1,136 +0,0 @@ -/** - * @file null_count_aggregator.h - * - * @section LICENSE - * - * The MIT License - * - * @copyright Copyright (c) 2023 TileDB, Inc. 
- * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in - * all copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN - * THE SOFTWARE. - * - * @section DESCRIPTION - * - * This file defines class NullCountAggregator. - */ - -#ifndef TILEDB_NULL_COUNT_AGGREGATOR_H -#define TILEDB_NULL_COUNT_AGGREGATOR_H - -#include "tiledb/sm/query/readers/aggregators/field_info.h" -#include "tiledb/sm/query/readers/aggregators/iaggregator.h" - -namespace tiledb::sm { - -class QueryBuffer; - -class NullCountAggregator : public OutputBufferValidator, public IAggregator { - public: - /* ********************************* */ - /* CONSTRUCTORS & DESTRUCTORS */ - /* ********************************* */ - - NullCountAggregator() = delete; - - /** - * Constructor. - * - * @param field_info Field info. 
- */ - NullCountAggregator(FieldInfo field_info); - - DISABLE_COPY_AND_COPY_ASSIGN(NullCountAggregator); - DISABLE_MOVE_AND_MOVE_ASSIGN(NullCountAggregator); - - /* ********************************* */ - /* API */ - /* ********************************* */ - - /** Returns the field name for the aggregator. */ - std::string field_name() override { - return field_info_.name_; - } - - /** Returns if the aggregation is var sized or not. */ - bool var_sized() override { - return false; - }; - - /** Returns if the aggregate needs to be recomputed on overflow. */ - bool need_recompute_on_overflow() override { - return true; - } - - /** - * Validate the result buffer. - * - * @param output_field_name Name for the output buffer. - * @param buffers Query buffers. - */ - void validate_output_buffer( - std::string output_field_name, - std::unordered_map& buffers) override; - - /** - * Aggregate data using the aggregator. - * - * @param input_data Input data for aggregation. - */ - void aggregate_data(AggregateBuffer& input_data) override; - - /** - * Copy final data to the user buffer. - * - * @param output_field_name Name for the output buffer. - * @param buffers Query buffers. - */ - void copy_to_user_buffer( - std::string output_field_name, - std::unordered_map& buffers) override; - - private: - /* ********************************* */ - /* PRIVATE ATTRIBUTES */ - /* ********************************* */ - - /** Field information. */ - const FieldInfo field_info_; - - /** Computed null count. */ - std::atomic null_count_; - - /* ********************************* */ - /* PRIVATE METHODS */ - /* ********************************* */ - - /** - * Add the null count of cells for the input data. - * - * @tparam BITMAP_T Bitmap type. - * @param input_data Input data for the null count. - * - * @return {Computed null count for the cells}. 
- */ - template - uint64_t null_count(AggregateBuffer& input_data); -}; - -} // namespace tiledb::sm - -#endif // TILEDB_NULL_COUNT_AGGREGATOR_H diff --git a/tiledb/sm/query/readers/aggregators/sum_aggregator.cc b/tiledb/sm/query/readers/aggregators/sum_aggregator.cc index 50033cf3cb4..8d349ab5771 100644 --- a/tiledb/sm/query/readers/aggregators/sum_aggregator.cc +++ b/tiledb/sm/query/readers/aggregators/sum_aggregator.cc @@ -34,7 +34,6 @@ #include "tiledb/sm/query/query_buffer.h" #include "tiledb/sm/query/readers/aggregators/aggregate_buffer.h" -#include "tiledb/sm/query/readers/aggregators/safe_sum.h" namespace tiledb::sm { @@ -92,15 +91,9 @@ void SumAggregator::aggregate_data(AggregateBuffer& input_data) { // TODO: This is duplicated across aggregates but will go away with // sc-33104. if (input_data.is_count_bitmap()) { - res = aggregate_with_count_.template aggregate< - typename sum_type_data::sum_type, - uint64_t, - SafeSum>(input_data); + res = aggregate_with_count_.template aggregate(input_data); } else { - res = aggregate_with_count_.template aggregate< - typename sum_type_data::sum_type, - uint8_t, - SafeSum>(input_data); + res = aggregate_with_count_.template aggregate(input_data); } const auto value = std::get<0>(res); diff --git a/tiledb/sm/query/readers/aggregators/sum_aggregator.h b/tiledb/sm/query/readers/aggregators/sum_aggregator.h index 06b783d10cd..24a9643146c 100644 --- a/tiledb/sm/query/readers/aggregators/sum_aggregator.h +++ b/tiledb/sm/query/readers/aggregators/sum_aggregator.h @@ -36,7 +36,9 @@ #include "tiledb/sm/query/readers/aggregators/aggregate_with_count.h" #include "tiledb/sm/query/readers/aggregators/field_info.h" #include "tiledb/sm/query/readers/aggregators/iaggregator.h" +#include "tiledb/sm/query/readers/aggregators/safe_sum.h" #include "tiledb/sm/query/readers/aggregators/sum_type.h" +#include "tiledb/sm/query/readers/aggregators/validity_policies.h" namespace tiledb::sm { @@ -116,7 +118,8 @@ class SumAggregator : public 
OutputBufferValidator, public IAggregator { const FieldInfo field_info_; /** AggregateWithCount to do summation of AggregateBuffer data. */ - AggregateWithCount aggregate_with_count_; + AggregateWithCount::sum_type, SafeSum, NonNull> + aggregate_with_count_; /** Computed sum. */ std::atomic::sum_type> sum_; diff --git a/tiledb/sm/query/readers/aggregators/test/CMakeLists.txt b/tiledb/sm/query/readers/aggregators/test/CMakeLists.txt index 16713a14d46..f16a39e0c62 100644 --- a/tiledb/sm/query/readers/aggregators/test/CMakeLists.txt +++ b/tiledb/sm/query/readers/aggregators/test/CMakeLists.txt @@ -27,6 +27,6 @@ include(unit_test) commence(unit_test aggregators) - this_target_sources(main.cc unit_aggregate_with_count.cc unit_aggregators.cc unit_count.cc unit_null_count.cc) + this_target_sources(main.cc unit_aggregate_with_count.cc unit_aggregators.cc) this_target_object_libraries(aggregators) conclude(unit_test) diff --git a/tiledb/sm/query/readers/aggregators/test/compile_aggregators_main.cc b/tiledb/sm/query/readers/aggregators/test/compile_aggregators_main.cc index dbf4a6c2d2d..90f7e7b8cee 100644 --- a/tiledb/sm/query/readers/aggregators/test/compile_aggregators_main.cc +++ b/tiledb/sm/query/readers/aggregators/test/compile_aggregators_main.cc @@ -30,7 +30,6 @@ #include "../field_info.h" #include "../mean_aggregator.h" #include "../min_max_aggregator.h" -#include "../null_count_aggregator.h" #include "../sum_aggregator.h" int main() { diff --git a/tiledb/sm/query/readers/aggregators/test/unit_aggregate_with_count.cc b/tiledb/sm/query/readers/aggregators/test/unit_aggregate_with_count.cc index 2f31dd4aba7..c4474928ebe 100644 --- a/tiledb/sm/query/readers/aggregators/test/unit_aggregate_with_count.cc +++ b/tiledb/sm/query/readers/aggregators/test/unit_aggregate_with_count.cc @@ -35,6 +35,7 @@ #include "tiledb/sm/query/readers/aggregators/aggregate_with_count.h" #include "tiledb/sm/query/readers/aggregators/safe_sum.h" #include 
"tiledb/sm/query/readers/aggregators/sum_type.h" +#include "tiledb/sm/query/readers/aggregators/validity_policies.h" #include @@ -57,8 +58,10 @@ TEMPLATE_LIST_TEST_CASE( "[aggregate-with-count][safe-sum]", FixedTypesUnderTest) { typedef TestType T; - AggregateWithCount aggregator(FieldInfo("a1", false, false, 1)); - AggregateWithCount aggregator_nullable(FieldInfo("a2", false, true, 1)); + AggregateWithCount::sum_type, SafeSum, NonNull> + aggregator(FieldInfo("a1", false, false, 1)); + AggregateWithCount::sum_type, SafeSum, NonNull> + aggregator_nullable(FieldInfo("a2", false, true, 1)); std::vector fixed_data = {1, 2, 3, 4, 5, 5, 4, 3, 2, 1}; std::vector validity_data = {0, 0, 1, 0, 1, 0, 1, 0, 1, 0}; @@ -67,10 +70,7 @@ TEMPLATE_LIST_TEST_CASE( // Regular attribute. AggregateBuffer input_data{ 2, 10, fixed_data.data(), nullopt, nullopt, false, nullopt, 0}; - auto res = aggregator.template aggregate< - typename sum_type_data::sum_type, - uint8_t, - SafeSum>(input_data); + auto res = aggregator.template aggregate(input_data); CHECK(std::get<0>(res) == 27); CHECK(std::get<1>(res) == 8); @@ -84,10 +84,8 @@ TEMPLATE_LIST_TEST_CASE( false, nullopt, 0}; - auto res_nullable = aggregator_nullable.template aggregate< - typename sum_type_data::sum_type, - uint8_t, - SafeSum>(input_data2); + auto res_nullable = + aggregator_nullable.template aggregate(input_data2); CHECK(std::get<0>(res_nullable) == 14); CHECK(std::get<1>(res_nullable) == 4); } @@ -97,19 +95,13 @@ TEMPLATE_LIST_TEST_CASE( std::vector bitmap = {1, 1, 0, 0, 0, 1, 1, 0, 1, 0}; AggregateBuffer input_data{ 2, 10, fixed_data.data(), nullopt, nullopt, false, bitmap.data(), 0}; - auto res = aggregator.template aggregate< - typename sum_type_data::sum_type, - uint8_t, - SafeSum>(input_data); + auto res = aggregator.template aggregate(input_data); CHECK(std::get<0>(res) == 11); CHECK(std::get<1>(res) == 3); AggregateBuffer input_data2{ 0, 2, fixed_data.data(), nullopt, nullopt, false, bitmap.data(), 0}; - auto res2 = 
aggregator.template aggregate< - typename sum_type_data::sum_type, - uint8_t, - SafeSum>(input_data2); + auto res2 = aggregator.template aggregate(input_data2); CHECK(std::get<0>(res2) == 3); CHECK(std::get<1>(res2) == 2); @@ -123,10 +115,8 @@ TEMPLATE_LIST_TEST_CASE( false, nullopt, 0}; - auto res_nullable = aggregator_nullable.template aggregate< - typename sum_type_data::sum_type, - uint8_t, - SafeSum>(input_data3); + auto res_nullable = + aggregator_nullable.template aggregate(input_data3); CHECK(std::get<0>(res_nullable) == 0); CHECK(std::get<1>(res_nullable) == 0); @@ -139,10 +129,8 @@ TEMPLATE_LIST_TEST_CASE( false, bitmap.data(), 0}; - auto res_nullable2 = aggregator_nullable.template aggregate< - typename sum_type_data::sum_type, - uint8_t, - SafeSum>(input_data4); + auto res_nullable2 = + aggregator_nullable.template aggregate(input_data4); CHECK(std::get<0>(res_nullable2) == 6); CHECK(std::get<1>(res_nullable2) == 2); } @@ -159,10 +147,7 @@ TEMPLATE_LIST_TEST_CASE( true, bitmap_count.data(), 0}; - auto res = aggregator.template aggregate< - typename sum_type_data::sum_type, - uint64_t, - SafeSum>(input_data); + auto res = aggregator.template aggregate(input_data); CHECK(std::get<0>(res) == 29); CHECK(std::get<1>(res) == 10); @@ -175,10 +160,7 @@ TEMPLATE_LIST_TEST_CASE( true, bitmap_count.data(), 0}; - auto res2 = aggregator.template aggregate< - typename sum_type_data::sum_type, - uint64_t, - SafeSum>(input_data2); + auto res2 = aggregator.template aggregate(input_data2); CHECK(std::get<0>(res2) == 5); CHECK(std::get<1>(res2) == 3); @@ -192,10 +174,8 @@ TEMPLATE_LIST_TEST_CASE( true, bitmap_count.data(), 0}; - auto res_nullable = aggregator_nullable.template aggregate< - typename sum_type_data::sum_type, - uint64_t, - SafeSum>(input_data3); + auto res_nullable = + aggregator_nullable.template aggregate(input_data3); CHECK(std::get<0>(res_nullable) == 22); CHECK(std::get<1>(res_nullable) == 7); @@ -208,10 +188,8 @@ TEMPLATE_LIST_TEST_CASE( true, 
bitmap_count.data(), 0}; - auto res_nullable2 = aggregator_nullable.template aggregate< - typename sum_type_data::sum_type, - uint64_t, - SafeSum>(input_data4); + auto res_nullable2 = + aggregator_nullable.template aggregate(input_data4); CHECK(std::get<0>(res_nullable2) == 0); CHECK(std::get<1>(res_nullable2) == 0); } diff --git a/tiledb/sm/query/readers/aggregators/test/unit_aggregators.cc b/tiledb/sm/query/readers/aggregators/test/unit_aggregators.cc index 473f2c0e9fd..5062388c943 100644 --- a/tiledb/sm/query/readers/aggregators/test/unit_aggregators.cc +++ b/tiledb/sm/query/readers/aggregators/test/unit_aggregators.cc @@ -33,6 +33,7 @@ #include "tiledb/common/common.h" #include "tiledb/sm/query/query_buffer.h" #include "tiledb/sm/query/readers/aggregators/aggregate_buffer.h" +#include "tiledb/sm/query/readers/aggregators/count_aggregator.h" #include "tiledb/sm/query/readers/aggregators/mean_aggregator.h" #include "tiledb/sm/query/readers/aggregators/min_max_aggregator.h" #include "tiledb/sm/query/readers/aggregators/sum_aggregator.h" @@ -45,9 +46,11 @@ typedef tuple< SumAggregator, MeanAggregator, MinAggregator> - AggUnderTest; + AggUnderTestConstructor; TEMPLATE_LIST_TEST_CASE( - "Aggregator: constructor", "[aggregators][constructor]", AggUnderTest) { + "Aggregator: constructor", + "[aggregators][constructor]", + AggUnderTestConstructor) { SECTION("Var size") { CHECK_THROWS_WITH( TestType(FieldInfo("a1", true, false, 1)), @@ -65,15 +68,42 @@ TEMPLATE_LIST_TEST_CASE( } } +TEST_CASE( + "NullCount aggregator: constructor", + "[null-count-aggregator][constructor]") { + SECTION("Non nullable") { + CHECK_THROWS_WITH( + NullCountAggregator(FieldInfo("a1", false, false, 1)), + "CountAggregator: NullCount aggregates must only be requested for " + "nullable attributes."); + } +} + +template +AGGREGATOR make_aggregator(FieldInfo field_info) { + return AGGREGATOR(field_info); +} + +template <> +CountAggregator make_aggregator(FieldInfo) { + return CountAggregator(); +} + 
+typedef tuple< + SumAggregator, + MeanAggregator, + MinAggregator, + NullCountAggregator, + CountAggregator> + AggUnderTest; TEMPLATE_LIST_TEST_CASE( "Aggregator: var sized", "[aggregator][var-sized]", AggUnderTest) { - TestType aggregator(FieldInfo("a1", false, false, 1)); + auto aggregator{make_aggregator(FieldInfo("a1", false, true, 1))}; CHECK(aggregator.var_sized() == false); if constexpr (std::is_same>::value) { - MinAggregator aggregator_nullable( - FieldInfo("a1", true, false, 1)); - CHECK(aggregator_nullable.var_sized() == true); + MinAggregator aggregator2(FieldInfo("a1", true, false, 1)); + CHECK(aggregator2.var_sized() == true); } } @@ -81,7 +111,7 @@ TEMPLATE_LIST_TEST_CASE( "Aggregators: need recompute", "[aggregators][need-recompute]", AggUnderTest) { - TestType aggregator(FieldInfo("a1", false, false, 1)); + auto aggregator{make_aggregator(FieldInfo("a1", false, true, 1))}; bool need_recompute = true; if constexpr (std::is_same>::value) { need_recompute = false; @@ -91,14 +121,23 @@ TEMPLATE_LIST_TEST_CASE( TEMPLATE_LIST_TEST_CASE( "Aggregators: field name", "[aggregator][field-name]", AggUnderTest) { - TestType aggregator(FieldInfo("a1", false, false, 1)); - CHECK(aggregator.field_name() == "a1"); + auto aggregator{make_aggregator(FieldInfo("a1", false, true, 1))}; + if (std::is_same::value) { + CHECK(aggregator.field_name() == constants::count_of_rows); + } else { + CHECK(aggregator.field_name() == "a1"); + } } +typedef tuple< + SumAggregator, + MeanAggregator, + MinAggregator> + AggUnderTestValidateBuffer; TEMPLATE_LIST_TEST_CASE( "Aggregators: Validate buffer", "[aggregators][validate-buffer]", - AggUnderTest) { + AggUnderTestValidateBuffer) { TestType aggregator(FieldInfo("a1", false, false, 1)); TestType aggregator_nullable(FieldInfo("a2", false, true, 1)); MinAggregator aggregator_var( @@ -275,6 +314,72 @@ TEMPLATE_LIST_TEST_CASE( } } +typedef tuple + AggUnderTestValidateBufferCount; +TEMPLATE_LIST_TEST_CASE( + "Aggregators: Validate buffer 
count", + "[aggregators][validate-buffer]", + AggUnderTestValidateBufferCount) { + auto aggregator{make_aggregator(FieldInfo("a1", false, true, 1))}; + std::unordered_map buffers; + + SECTION("Doesn't exist") { + CHECK_THROWS_WITH( + aggregator.validate_output_buffer("Count", buffers), + "CountAggregator: Result buffer doesn't exist."); + } + + SECTION("Null data buffer") { + buffers["Count"].buffer_ = nullptr; + CHECK_THROWS_WITH( + aggregator.validate_output_buffer("Count", buffers), + "OutputBufferValidator: Aggregate must have a fixed size buffer."); + } + + SECTION("Wrong size") { + uint64_t count = 0; + buffers["Count"].buffer_ = &count; + buffers["Count"].original_buffer_size_ = 1; + CHECK_THROWS_WITH( + aggregator.validate_output_buffer("Count", buffers), + "OutputBufferValidator: Aggregate fixed size buffer should be for one " + "element only."); + } + + SECTION("With var buffer") { + uint64_t count = 0; + buffers["Count"].buffer_ = &count; + buffers["Count"].original_buffer_size_ = 8; + buffers["Count"].buffer_var_ = &count; + + CHECK_THROWS_WITH( + aggregator.validate_output_buffer("Count", buffers), + "OutputBufferValidator: Aggregate must not have a var buffer."); + } + + SECTION("With validity") { + uint64_t count = 0; + buffers["Count"].buffer_ = &count; + buffers["Count"].original_buffer_size_ = 8; + + uint8_t validity = 0; + uint64_t validity_size = 1; + buffers["Count"].validity_vector_ = + ValidityVector(&validity, &validity_size); + CHECK_THROWS_WITH( + aggregator.validate_output_buffer("Count", buffers), + "OutputBufferValidator: Count aggregates must not have a validity " + "buffer."); + } + + SECTION("Success") { + uint64_t count = 0; + buffers["Count"].buffer_ = &count; + buffers["Count"].original_buffer_size_ = 8; + aggregator.validate_output_buffer("Count", buffers); + } +} + template bool is_nan(T) { return false; @@ -317,10 +422,29 @@ struct fixed_data_type { typedef char value_type; }; +template +void check_validity(uint8_t validity, 
uint8_t expected) { + if constexpr ( + !std::is_same::value && + !std::is_same::value) { + CHECK(validity == expected); + } +} + template void basic_aggregation_test(std::vector expected_results) { - AGGREGATOR aggregator(FieldInfo("a1", false, false, 1)); - AGGREGATOR aggregator_nullable(FieldInfo("a2", false, true, 1)); + // Optionally make the aggregator for non nullable values. + optional aggregator; + if constexpr (!std::is_same::value) { + if constexpr (std::is_same::value) { + aggregator.emplace(); + } else { + aggregator.emplace(FieldInfo("a1", false, false, 1)); + } + } + + auto aggregator_nullable{ + make_aggregator(FieldInfo("a1", false, true, 1))}; std::unordered_map buffers; @@ -340,12 +464,14 @@ void basic_aggregation_test(std::vector expected_results) { std::vector validity_data = {0, 0, 1, 0, 1, 0, 1, 0, 1, 0}; SECTION("No bitmap") { - // Regular attribute. - AggregateBuffer input_data{ - 2, 10, fixed_data.data(), nullopt, nullopt, false, nullopt, 1}; - aggregator.aggregate_data(input_data); - aggregator.copy_to_user_buffer("Agg", buffers); - check_value(RES(), res, expected_results[0]); + if (aggregator.has_value()) { + // Regular attribute. + AggregateBuffer input_data{ + 2, 10, fixed_data.data(), nullopt, nullopt, false, nullopt, 1}; + aggregator->aggregate_data(input_data); + aggregator->copy_to_user_buffer("Agg", buffers); + check_value(RES(), res, expected_results[0]); + } // Nullable attribute. AggregateBuffer input_data2{ @@ -360,23 +486,25 @@ void basic_aggregation_test(std::vector expected_results) { aggregator_nullable.aggregate_data(input_data2); aggregator_nullable.copy_to_user_buffer("Agg2", buffers); check_value(RES(), res2, expected_results[1]); - CHECK(validity == 1); + check_validity(validity, 1); } SECTION("Regular bitmap") { - // Regular attribute. 
std::vector bitmap = {1, 1, 0, 0, 0, 1, 1, 0, 1, 0}; - AggregateBuffer input_data{ - 2, 10, fixed_data.data(), nullopt, nullopt, false, bitmap.data(), 1}; - aggregator.aggregate_data(input_data); - aggregator.copy_to_user_buffer("Agg", buffers); - check_value(RES(), res, expected_results[2]); - - AggregateBuffer input_data2{ - 0, 2, fixed_data.data(), nullopt, nullopt, false, bitmap.data(), 1}; - aggregator.aggregate_data(input_data2); - aggregator.copy_to_user_buffer("Agg", buffers); - check_value(RES(), res, expected_results[3]); + if (aggregator.has_value()) { + // Regular attribute. + AggregateBuffer input_data{ + 2, 10, fixed_data.data(), nullopt, nullopt, false, bitmap.data(), 1}; + aggregator->aggregate_data(input_data); + aggregator->copy_to_user_buffer("Agg", buffers); + check_value(RES(), res, expected_results[2]); + + AggregateBuffer input_data2{ + 0, 2, fixed_data.data(), nullopt, nullopt, false, bitmap.data(), 1}; + aggregator->aggregate_data(input_data2); + aggregator->copy_to_user_buffer("Agg", buffers); + check_value(RES(), res, expected_results[3]); + } // Nullable attribute. AggregateBuffer input_data3{ @@ -401,7 +529,7 @@ void basic_aggregation_test(std::vector expected_results) { } else { check_value(RES(), res2, expected_results[4]); } - CHECK(validity == 0); + check_validity(validity, 0); AggregateBuffer input_data4{ 2, @@ -415,37 +543,39 @@ void basic_aggregation_test(std::vector expected_results) { aggregator_nullable.aggregate_data(input_data4); aggregator_nullable.copy_to_user_buffer("Agg2", buffers); check_value(RES(), res2, expected_results[5]); - CHECK(validity == 1); + check_validity(validity, 1); } SECTION("Count bitmap") { - // Regular attribute. 
std::vector bitmap_count = {1, 2, 4, 0, 0, 1, 2, 0, 1, 2}; - AggregateBuffer input_data{ - 2, - 10, - fixed_data.data(), - nullopt, - nullopt, - true, - bitmap_count.data(), - 1}; - aggregator.aggregate_data(input_data); - aggregator.copy_to_user_buffer("Agg", buffers); - check_value(RES(), res, expected_results[6]); - - AggregateBuffer input_data2{ - 0, - 2, - fixed_data.data(), - nullopt, - nullopt, - true, - bitmap_count.data(), - 1}; - aggregator.aggregate_data(input_data2); - aggregator.copy_to_user_buffer("Agg", buffers); - check_value(RES(), res, expected_results[7]); + if (aggregator.has_value()) { + // Regular attribute. + AggregateBuffer input_data{ + 2, + 10, + fixed_data.data(), + nullopt, + nullopt, + true, + bitmap_count.data(), + 1}; + aggregator->aggregate_data(input_data); + aggregator->copy_to_user_buffer("Agg", buffers); + check_value(RES(), res, expected_results[6]); + + AggregateBuffer input_data2{ + 0, + 2, + fixed_data.data(), + nullopt, + nullopt, + true, + bitmap_count.data(), + 1}; + aggregator->aggregate_data(input_data2); + aggregator->copy_to_user_buffer("Agg", buffers); + check_value(RES(), res, expected_results[7]); + } // Nullable attribute. 
AggregateBuffer input_data3{ @@ -460,7 +590,7 @@ void basic_aggregation_test(std::vector expected_results) { aggregator_nullable.aggregate_data(input_data3); aggregator_nullable.copy_to_user_buffer("Agg2", buffers); check_value(RES(), res2, expected_results[8]); - CHECK(validity == 1); + check_validity(validity, 1); AggregateBuffer input_data4{ 0, @@ -474,7 +604,7 @@ void basic_aggregation_test(std::vector expected_results) { aggregator_nullable.aggregate_data(input_data4); aggregator_nullable.copy_to_user_buffer("Agg2", buffers); check_value(RES(), res2, expected_results[9]); - CHECK(validity == 1); + check_validity(validity, 1); } } @@ -501,18 +631,6 @@ TEMPLATE_LIST_TEST_CASE( SumAggregator>({27, 14, 11, 14, 0, 6, 29, 34, 22, 22}); } -typedef tuple< - uint8_t, - uint16_t, - uint32_t, - uint64_t, - int8_t, - int16_t, - int32_t, - int64_t, - float, - double> - FixedTypesUnderTest; TEMPLATE_LIST_TEST_CASE( "Mean aggregator: Basic aggregation", "[mean-aggregator][basic-aggregation]", @@ -554,11 +672,11 @@ typedef tuple< std::pair>, std::pair>, std::pair>> - FixedTypesUnderTestMinMax; + AggUnderTestMinMax; TEMPLATE_LIST_TEST_CASE( "Min max aggregator: Basic aggregation", "[min-max-aggregator][basic-aggregation]", - FixedTypesUnderTestMinMax) { + AggUnderTestMinMax) { typedef decltype(TestType::first) T; typedef decltype(TestType::second) AGG; std::vector res = {1, 2, 2, 1, 0, 2, 1, 1, 2, 2}; @@ -568,6 +686,22 @@ TEMPLATE_LIST_TEST_CASE( basic_aggregation_test(res); } +TEST_CASE( + "Count aggregator: Basic aggregation", + "[count-aggregator][basic-aggregation]") { + std::vector res = {8, 8, 3, 5, 2, 5, 10, 13, 10, 13}; + basic_aggregation_test(res); +} + +TEMPLATE_LIST_TEST_CASE( + "NullCount aggregator: Basic aggregation", + "[null-count-aggregator][basic-aggregation]", + FixedTypesUnderTest) { + typedef TestType T; + std::vector res = {0, 4, 0, 0, 2, 3, 0, 0, 3, 6}; + basic_aggregation_test(res); +} + TEST_CASE( "Sum aggregator: signed overflow", 
"[sum-aggregator][signed-overflow]") { SumAggregator aggregator(FieldInfo("a1", false, false, 1)); @@ -745,70 +879,74 @@ TEMPLATE_LIST_TEST_CASE( } } -void check_value_var( - uint64_t offset, - uint64_t min_max_size, - std::vector& min_max, - bool min, - std::string min_val, - std::string max_val) { - if (min) { - CHECK(min_max_size == min_val.length()); - CHECK(0 == memcmp(min_max.data(), min_val.data(), min_val.length())); - } else { - CHECK(min_max_size == max_val.length()); - CHECK(0 == memcmp(min_max.data(), max_val.data(), max_val.length())); - } +template +void check_value_string( + uint64_t fixed_data, uint64_t, std::vector&, RES expected) { + CHECK(fixed_data == expected); +} - CHECK(offset == 0); +template <> +void check_value_string( + uint64_t fixed_data, + uint64_t value_size, + std::vector& value, + std::string expected) { + CHECK(value_size == expected.length()); + CHECK(0 == memcmp(value.data(), expected.data(), expected.length())); + + CHECK(fixed_data == 0); } -typedef tuple, MinAggregator> - MinMaxAggUnderTest; -TEMPLATE_LIST_TEST_CASE( - "Min max aggregator: Basic string aggregation", - "[min-max-aggregator][basic-string-aggregation]", - MinMaxAggUnderTest) { - typedef TestType AGG; - bool min = std::is_same>::value; - AGG aggregator(FieldInfo("a1", true, false, constants::var_num)); - AGG aggregator_nullable(FieldInfo("a2", true, true, constants::var_num)); +template +void basic_string_aggregation_test(std::vector expected_results) { + // Optionally make the aggregator for non nullable values. 
+ optional aggregator; + if constexpr (!std::is_same::value) { + aggregator.emplace(FieldInfo("a1", true, false, constants::var_num)); + } + + AGGREGATOR aggregator_nullable( + FieldInfo("a2", true, true, constants::var_num)); std::unordered_map buffers; - uint64_t offset = 11; - std::vector min_max(10, 0); - uint64_t min_max_size = 10; - buffers["MinMax"].buffer_ = &offset; - buffers["MinMax"].original_buffer_size_ = 8; - buffers["MinMax"].buffer_var_ = min_max.data(); - buffers["MinMax"].original_buffer_var_size_ = 10; - buffers["MinMax"].buffer_var_size_ = &min_max_size; - - uint64_t offset2 = 12; - std::vector min_max2(10, 0); - uint64_t min_max_size2 = 10; + uint64_t fixed_data = 11; + std::vector value(10, 0); + uint64_t value_size = 10; + buffers["Agg"].buffer_ = &fixed_data; + buffers["Agg"].original_buffer_size_ = 8; + buffers["Agg"].buffer_var_ = value.data(); + buffers["Agg"].original_buffer_var_size_ = 10; + buffers["Agg"].buffer_var_size_ = &value_size; + + uint64_t fixed_data2 = 12; + std::vector value2(10, 0); + uint64_t value_size2 = 10; uint8_t validity = 0; uint64_t validity_size = 1; - buffers["MinMax2"].buffer_ = &offset2; - buffers["MinMax2"].original_buffer_size_ = 8; - buffers["MinMax2"].buffer_var_ = min_max2.data(); - buffers["MinMax2"].original_buffer_var_size_ = 10; - buffers["MinMax2"].buffer_var_size_ = &min_max_size2; - buffers["MinMax2"].validity_vector_ = - ValidityVector(&validity, &validity_size); + buffers["Agg2"].buffer_ = &fixed_data2; + buffers["Agg2"].original_buffer_size_ = 8; + + if constexpr (!std::is_same::value) { + buffers["Agg2"].buffer_var_ = value2.data(); + buffers["Agg2"].original_buffer_var_size_ = 10; + buffers["Agg2"].buffer_var_size_ = &value_size2; + } + buffers["Agg2"].validity_vector_ = ValidityVector(&validity, &validity_size); std::vector offsets = {0, 2, 3, 6, 8, 11, 15, 16, 18, 22, 23}; std::string var_data = "11233344555555543322221"; std::vector validity_data = {0, 0, 1, 0, 1, 0, 1, 0, 1, 0}; 
SECTION("No bitmap") { - // Regular attribute. - AggregateBuffer input_data{ - 2, 10, offsets.data(), var_data.data(), nullopt, false, nullopt, 1}; - aggregator.aggregate_data(input_data); - aggregator.copy_to_user_buffer("MinMax", buffers); - check_value_var(offset, min_max_size, min_max, min, "1", "5555"); + if (aggregator.has_value()) { + // Regular attribute. + AggregateBuffer input_data{ + 2, 10, offsets.data(), var_data.data(), nullopt, false, nullopt, 1}; + aggregator->aggregate_data(input_data); + aggregator->copy_to_user_buffer("Agg", buffers); + check_value_string(fixed_data, value_size, value, expected_results[0]); + } // Nullable attribute. AggregateBuffer input_data2{ @@ -821,39 +959,41 @@ TEMPLATE_LIST_TEST_CASE( nullopt, 1}; aggregator_nullable.aggregate_data(input_data2); - aggregator_nullable.copy_to_user_buffer("MinMax2", buffers); - check_value_var(offset2, min_max_size2, min_max2, min, "2222", "555"); - CHECK(validity == 1); + aggregator_nullable.copy_to_user_buffer("Agg2", buffers); + check_value_string(fixed_data2, value_size2, value2, expected_results[1]); + check_validity(validity, 1); } SECTION("Regular bitmap") { - // Regular attribute. std::vector bitmap = {1, 1, 0, 0, 0, 1, 1, 0, 1, 0}; - AggregateBuffer input_data{ - 2, - 10, - offsets.data(), - var_data.data(), - nullopt, - false, - bitmap.data(), - 1}; - aggregator.aggregate_data(input_data); - aggregator.copy_to_user_buffer("MinMax", buffers); - check_value_var(offset, min_max_size, min_max, min, "2222", "5555"); - - AggregateBuffer input_data2{ - 0, - 2, - offsets.data(), - var_data.data(), - nullopt, - false, - bitmap.data(), - 1}; - aggregator.aggregate_data(input_data2); - aggregator.copy_to_user_buffer("MinMax", buffers); - check_value_var(offset, min_max_size, min_max, min, "11", "5555"); + if (aggregator.has_value()) { + // Regular attribute. 
+ AggregateBuffer input_data{ + 2, + 10, + offsets.data(), + var_data.data(), + nullopt, + false, + bitmap.data(), + 1}; + aggregator->aggregate_data(input_data); + aggregator->copy_to_user_buffer("Agg", buffers); + check_value_string(fixed_data, value_size, value, expected_results[2]); + + AggregateBuffer input_data2{ + 0, + 2, + offsets.data(), + var_data.data(), + nullopt, + false, + bitmap.data(), + 1}; + aggregator->aggregate_data(input_data2); + aggregator->copy_to_user_buffer("Agg", buffers); + check_value_string(fixed_data, value_size, value, expected_results[3]); + } // Nullable attribute. AggregateBuffer input_data3{ @@ -866,9 +1006,9 @@ TEMPLATE_LIST_TEST_CASE( nullopt, 1}; aggregator_nullable.aggregate_data(input_data3); - aggregator_nullable.copy_to_user_buffer("MinMax2", buffers); - check_value_var(offset2, min_max_size2, min_max2, min, "", ""); - CHECK(validity == 0); + aggregator_nullable.copy_to_user_buffer("Agg2", buffers); + check_value_string(fixed_data2, value_size2, value2, expected_results[4]); + check_validity(validity, 0); AggregateBuffer input_data4{ 2, @@ -880,39 +1020,41 @@ TEMPLATE_LIST_TEST_CASE( bitmap.data(), 1}; aggregator_nullable.aggregate_data(input_data4); - aggregator_nullable.copy_to_user_buffer("MinMax2", buffers); - check_value_var(offset2, min_max_size2, min_max2, min, "2222", "4"); - CHECK(validity == 1); + aggregator_nullable.copy_to_user_buffer("Agg2", buffers); + check_value_string(fixed_data2, value_size2, value2, expected_results[5]); + check_validity(validity, 1); } SECTION("Count bitmap") { - // Regular attribute. 
std::vector bitmap_count = {1, 2, 4, 0, 0, 1, 2, 0, 1, 2}; - AggregateBuffer input_data{ - 2, - 10, - offsets.data(), - var_data.data(), - nullopt, - true, - bitmap_count.data(), - 1}; - aggregator.aggregate_data(input_data); - aggregator.copy_to_user_buffer("MinMax", buffers); - check_value_var(offset, min_max_size, min_max, min, "1", "5555"); - - AggregateBuffer input_data2{ - 0, - 2, - offsets.data(), - var_data.data(), - nullopt, - true, - bitmap_count.data(), - 1}; - aggregator.aggregate_data(input_data2); - aggregator.copy_to_user_buffer("MinMax", buffers); - check_value_var(offset, min_max_size, min_max, min, "1", "5555"); + if (aggregator.has_value()) { + // Regular attribute. + AggregateBuffer input_data{ + 2, + 10, + offsets.data(), + var_data.data(), + nullopt, + true, + bitmap_count.data(), + 1}; + aggregator->aggregate_data(input_data); + aggregator->copy_to_user_buffer("Agg", buffers); + check_value_string(fixed_data, value_size, value, expected_results[6]); + + AggregateBuffer input_data2{ + 0, + 2, + offsets.data(), + var_data.data(), + nullopt, + true, + bitmap_count.data(), + 1}; + aggregator->aggregate_data(input_data2); + aggregator->copy_to_user_buffer("Agg", buffers); + check_value_string(fixed_data, value_size, value, expected_results[7]); + } // Nullable attribute. 
AggregateBuffer input_data3{ @@ -925,9 +1067,9 @@ TEMPLATE_LIST_TEST_CASE( bitmap_count.data(), 1}; aggregator_nullable.aggregate_data(input_data3); - aggregator_nullable.copy_to_user_buffer("MinMax2", buffers); - check_value_var(offset2, min_max_size2, min_max2, min, "2222", "4"); - CHECK(validity == 1); + aggregator_nullable.copy_to_user_buffer("Agg2", buffers); + check_value_string(fixed_data2, value_size2, value2, expected_results[8]); + check_validity(validity, 1); AggregateBuffer input_data4{ 0, @@ -939,8 +1081,31 @@ TEMPLATE_LIST_TEST_CASE( bitmap_count.data(), 1}; aggregator_nullable.aggregate_data(input_data4); - aggregator_nullable.copy_to_user_buffer("MinMax2", buffers); - check_value_var(offset2, min_max_size2, min_max2, min, "2222", "4"); - CHECK(validity == 1); + aggregator_nullable.copy_to_user_buffer("Agg2", buffers); + check_value_string(fixed_data2, value_size2, value2, expected_results[9]); + check_validity(validity, 1); } } + +typedef tuple, MinAggregator> + MinMaxAggUnderTest; +TEMPLATE_LIST_TEST_CASE( + "Min max aggregator: Basic string aggregation", + "[min-max-aggregator][basic-string-aggregation]", + MinMaxAggUnderTest) { + typedef TestType AGGREGATOR; + std::vector res = { + "1", "2222", "2222", "11", "", "2222", "1", "1", "2222", "2222"}; + if (std::is_same>::value) { + res = {"5555", "555", "5555", "5555", "", "4", "5555", "5555", "4", "4"}; + } + + basic_string_aggregation_test(res); +} + +TEST_CASE( + "NullCount aggregator: Basic string aggregation", + "[null-count-aggregator][basic-string-aggregation]") { + std::vector res = {0, 4, 0, 0, 2, 3, 0, 0, 3, 6}; + basic_string_aggregation_test(res); +} diff --git a/tiledb/sm/query/readers/aggregators/test/unit_count.cc b/tiledb/sm/query/readers/aggregators/test/unit_count.cc deleted file mode 100644 index b3ce47a0757..00000000000 --- a/tiledb/sm/query/readers/aggregators/test/unit_count.cc +++ /dev/null @@ -1,170 +0,0 @@ -/** - * @file unit_count.cc - * - * @section LICENSE - * - * The MIT 
License - * - * @copyright Copyright (c) 2023 TileDB, Inc. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in - * all copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN - * THE SOFTWARE. - * - * @section DESCRIPTION - * - * Tests the `CountAggregator` class. 
- */ - -#include "tiledb/common/common.h" -#include "tiledb/sm/query/query_buffer.h" -#include "tiledb/sm/query/readers/aggregators/aggregate_buffer.h" -#include "tiledb/sm/query/readers/aggregators/count_aggregator.h" - -#include - -using namespace tiledb::sm; - -TEST_CASE("Count aggregator: var sized", "[count-aggregator][var-sized]") { - CountAggregator aggregator; - CHECK(aggregator.var_sized() == false); -} - -TEST_CASE( - "Count aggregator: need recompute", "[count-aggregator][need-recompute]") { - CountAggregator aggregator; - CHECK(aggregator.need_recompute_on_overflow() == true); -} - -TEST_CASE("Count aggregator: field name", "[count-aggregator][field-name]") { - CountAggregator aggregator; - CHECK(aggregator.field_name() == constants::count_of_rows); -} - -TEST_CASE( - "Count aggregator: Validate buffer", - "[count-aggregator][validate-buffer]") { - CountAggregator aggregator; - - std::unordered_map buffers; - - SECTION("Doesn't exist") { - CHECK_THROWS_WITH( - aggregator.validate_output_buffer("Count", buffers), - "CountAggregator: Result buffer doesn't exist."); - } - - SECTION("Null data buffer") { - buffers["Count"].buffer_ = nullptr; - CHECK_THROWS_WITH( - aggregator.validate_output_buffer("Count", buffers), - "OutputBufferValidator: Aggregate must have a fixed size buffer."); - } - - SECTION("Wrong size") { - uint64_t count = 0; - buffers["Count"].buffer_ = &count; - buffers["Count"].original_buffer_size_ = 1; - CHECK_THROWS_WITH( - aggregator.validate_output_buffer("Count", buffers), - "OutputBufferValidator: Aggregate fixed size buffer should be for one " - "element only."); - } - - SECTION("With var buffer") { - uint64_t count = 0; - buffers["Count"].buffer_ = &count; - buffers["Count"].original_buffer_size_ = 8; - buffers["Count"].buffer_var_ = &count; - - CHECK_THROWS_WITH( - aggregator.validate_output_buffer("Count", buffers), - "OutputBufferValidator: Aggregate must not have a var buffer."); - } - - SECTION("With validity") { - uint64_t 
count = 0; - buffers["Count"].buffer_ = &count; - buffers["Count"].original_buffer_size_ = 8; - - uint8_t validity = 0; - uint64_t validity_size = 1; - buffers["Count"].validity_vector_ = - ValidityVector(&validity, &validity_size); - CHECK_THROWS_WITH( - aggregator.validate_output_buffer("Count", buffers), - "OutputBufferValidator: Count aggregates must not have a validity " - "buffer."); - } - - SECTION("Success") { - uint64_t count = 0; - buffers["Count"].buffer_ = &count; - buffers["Count"].original_buffer_size_ = 8; - aggregator.validate_output_buffer("Count", buffers); - } -} - -TEST_CASE( - "Count aggregator: Basic aggregation", - "[count-aggregator][basic-aggregation]") { - CountAggregator aggregator; - - std::unordered_map buffers; - - uint64_t count = 0; - buffers["Count"].buffer_ = &count; - buffers["Count"].original_buffer_size_ = 1; - - SECTION("No bitmap") { - AggregateBuffer input_data{ - 2, 10, nullptr, nullopt, nullopt, false, nullopt, 0}; - aggregator.aggregate_data(input_data); - aggregator.copy_to_user_buffer("Count", buffers); - CHECK(count == 8); - } - - SECTION("Regular bitmap") { - std::vector bitmap = {1, 1, 0, 0, 0, 1, 1, 0, 1, 0}; - AggregateBuffer input_data{ - 2, 10, nullptr, nullopt, nullopt, false, bitmap.data(), 0}; - aggregator.aggregate_data(input_data); - aggregator.copy_to_user_buffer("Count", buffers); - CHECK(count == 3); - - AggregateBuffer input_data2{ - 0, 2, nullptr, nullopt, nullopt, false, bitmap.data(), 0}; - aggregator.aggregate_data(input_data2); - aggregator.copy_to_user_buffer("Count", buffers); - CHECK(count == 5); - } - - SECTION("Count bitmap") { - std::vector bitmap_count = {1, 2, 4, 0, 0, 1, 2, 0, 1, 2}; - AggregateBuffer input_data{ - 2, 10, nullptr, nullopt, nullopt, true, bitmap_count.data(), 0}; - aggregator.aggregate_data(input_data); - aggregator.copy_to_user_buffer("Count", buffers); - CHECK(count == 10); - - AggregateBuffer input_data2{ - 0, 2, nullptr, nullopt, nullopt, true, bitmap_count.data(), 0}; - 
aggregator.aggregate_data(input_data2); - aggregator.copy_to_user_buffer("Count", buffers); - CHECK(count == 13); - } -} diff --git a/tiledb/sm/query/readers/aggregators/test/unit_null_count.cc b/tiledb/sm/query/readers/aggregators/test/unit_null_count.cc deleted file mode 100644 index 549796d2a91..00000000000 --- a/tiledb/sm/query/readers/aggregators/test/unit_null_count.cc +++ /dev/null @@ -1,351 +0,0 @@ -/** - * @file unit_null_count.cc - * - * @section LICENSE - * - * The MIT License - * - * @copyright Copyright (c) 2023 TileDB, Inc. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in - * all copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN - * THE SOFTWARE. - * - * @section DESCRIPTION - * - * Tests the `NullCountAggregator` class. 
- */ - -#include "tiledb/common/common.h" -#include "tiledb/sm/query/query_buffer.h" -#include "tiledb/sm/query/readers/aggregators/aggregate_buffer.h" -#include "tiledb/sm/query/readers/aggregators/null_count_aggregator.h" - -#include - -using namespace tiledb::sm; - -TEST_CASE( - "NullCount aggregator: constructor", - "[null-count-aggregator][constructor]") { - SECTION("Non nullable") { - CHECK_THROWS_WITH( - NullCountAggregator(FieldInfo("a1", false, false, 1)), - "NullCountAggregator: NullCount aggregates must only be requested for " - "nullable attributes."); - } -} - -TEST_CASE( - "NullCount aggregator: var sized", "[null-count-aggregator][var-sized]") { - bool var_sized = GENERATE(true, false); - NullCountAggregator aggregator(FieldInfo("a1", var_sized, true, 1)); - CHECK(aggregator.var_sized() == false); -} - -TEST_CASE( - "NullCount aggregator: need recompute", - "[null-count-aggregator][need-recompute]") { - NullCountAggregator aggregator(FieldInfo("a1", false, true, 1)); - CHECK(aggregator.need_recompute_on_overflow() == true); -} - -TEST_CASE( - "NullCount aggregator: field name", "[null-count-aggregator][field-name]") { - NullCountAggregator aggregator(FieldInfo("a1", false, true, 1)); - CHECK(aggregator.field_name() == "a1"); -} - -TEST_CASE( - "NullCount aggregator: Validate buffer", - "[null-count-aggregator][validate-buffer]") { - NullCountAggregator aggregator(FieldInfo("a1", false, true, 1)); - - std::unordered_map buffers; - - SECTION("Doesn't exist") { - CHECK_THROWS_WITH( - aggregator.validate_output_buffer("NullCount", buffers), - "NullCountAggregator: Result buffer doesn't exist."); - } - - SECTION("Null data buffer") { - buffers["NullCount"].buffer_ = nullptr; - CHECK_THROWS_WITH( - aggregator.validate_output_buffer("NullCount", buffers), - "OutputBufferValidator: Aggregate must have a fixed size buffer."); - } - - SECTION("Wrong size") { - uint64_t count = 0; - buffers["NullCount"].buffer_ = &count; - 
buffers["NullCount"].original_buffer_size_ = 1; - CHECK_THROWS_WITH( - aggregator.validate_output_buffer("NullCount", buffers), - "OutputBufferValidator: Aggregate fixed size buffer should be for one " - "element only."); - } - - SECTION("With var buffer") { - uint64_t null_count = 0; - buffers["NullCount"].buffer_ = &null_count; - buffers["NullCount"].original_buffer_size_ = 8; - buffers["NullCount"].buffer_var_ = &null_count; - - CHECK_THROWS_WITH( - aggregator.validate_output_buffer("NullCount", buffers), - "OutputBufferValidator: Aggregate must not have a var buffer."); - } - - SECTION("With validity") { - uint64_t null_count = 0; - buffers["NullCount"].buffer_ = &null_count; - buffers["NullCount"].original_buffer_size_ = 8; - - uint8_t validity = 0; - uint64_t validity_size = 1; - buffers["NullCount"].validity_vector_ = - ValidityVector(&validity, &validity_size); - CHECK_THROWS_WITH( - aggregator.validate_output_buffer("NullCount", buffers), - "OutputBufferValidator: Count aggregates must not have a validity " - "buffer."); - } - - SECTION("Success") { - uint64_t null_count = 0; - buffers["NullCount"].buffer_ = &null_count; - buffers["NullCount"].original_buffer_size_ = 8; - aggregator.validate_output_buffer("NullCount", buffers); - } -} - -template -std::vector make_fixed_data_nc(T) { - return {1, 2, 3, 4, 5, 5, 4, 3, 2, 1}; -} - -template <> -std::vector make_fixed_data_nc(std::string) { - return {'1', '2', '3', '4', '5', '5', '4', '3', '2', '1'}; -} - -template -struct fixed_data_type { - using type = T; - typedef T value_type; -}; - -template <> -struct fixed_data_type { - using type = std::string; - typedef char value_type; -}; - -typedef tuple< - uint8_t, - uint16_t, - uint32_t, - uint64_t, - int8_t, - int16_t, - int32_t, - int64_t, - float, - double, - std::string> - FixedTypesUnderTest; -TEMPLATE_LIST_TEST_CASE( - "NullCount aggregator: Basic aggregation", - "[null-count-aggregator][basic-aggregation]", - FixedTypesUnderTest) { - typedef TestType T; - 
NullCountAggregator aggregator(FieldInfo("a1", false, true, 1)); - - std::unordered_map buffers; - - uint64_t null_count = 0; - buffers["NullCount"].buffer_ = &null_count; - buffers["NullCount"].original_buffer_size_ = 8; - - auto fixed_data = - make_fixed_data_nc::value_type>(T()); - std::vector validity_data = {0, 0, 1, 0, 1, 0, 1, 0, 1, 0}; - - SECTION("No bitmap") { - AggregateBuffer input_data{ - 2, - 10, - fixed_data.data(), - nullopt, - validity_data.data(), - false, - nullopt, - 0}; - aggregator.aggregate_data(input_data); - aggregator.copy_to_user_buffer("NullCount", buffers); - CHECK(null_count == 4); - } - - SECTION("Regular bitmap") { - std::vector bitmap = {1, 1, 0, 0, 0, 1, 1, 0, 1, 0}; - AggregateBuffer input_data{ - 2, - 10, - fixed_data.data(), - nullopt, - validity_data.data(), - false, - bitmap.data(), - 0}; - aggregator.aggregate_data(input_data); - aggregator.copy_to_user_buffer("NullCount", buffers); - CHECK(null_count == 1); - - AggregateBuffer input_data2{ - 0, - 2, - fixed_data.data(), - nullopt, - validity_data.data(), - false, - bitmap.data(), - 0}; - aggregator.aggregate_data(input_data2); - aggregator.copy_to_user_buffer("NullCount", buffers); - CHECK(null_count == 3); - } - - SECTION("Count bitmap") { - std::vector bitmap_count = {1, 2, 4, 0, 0, 1, 2, 0, 1, 2}; - AggregateBuffer input_data{ - 2, - 10, - fixed_data.data(), - nullopt, - validity_data.data(), - true, - bitmap_count.data(), - 0}; - aggregator.aggregate_data(input_data); - aggregator.copy_to_user_buffer("NullCount", buffers); - CHECK(null_count == 3); - - AggregateBuffer input_data2{ - 0, - 2, - fixed_data.data(), - nullopt, - validity_data.data(), - true, - bitmap_count.data(), - 0}; - aggregator.aggregate_data(input_data2); - aggregator.copy_to_user_buffer("NullCount", buffers); - CHECK(null_count == 6); - } -} - -TEST_CASE( - "NullCount aggregator: Basic string aggregation", - "[null-count-aggregator][basic-string-aggregation]") { - NullCountAggregator aggregator( - 
FieldInfo("a1", true, true, constants::var_num)); - - std::unordered_map buffers; - - uint64_t null_count = 0; - buffers["NullCount"].buffer_ = &null_count; - buffers["NullCount"].original_buffer_size_ = 8; - - std::vector offsets = {0, 2, 3, 6, 8, 11, 15, 16, 18, 22}; - std::string var_data = "11233344555555543322221"; - std::vector validity_data = {0, 0, 1, 0, 1, 0, 1, 0, 1, 0}; - - SECTION("No bitmap") { - AggregateBuffer input_data{ - 2, - 10, - offsets.data(), - var_data.data(), - validity_data.data(), - false, - nullopt, - 0}; - aggregator.aggregate_data(input_data); - aggregator.copy_to_user_buffer("NullCount", buffers); - CHECK(null_count == 4); - } - - SECTION("Regular bitmap") { - std::vector bitmap = {1, 1, 0, 0, 0, 1, 1, 0, 1, 0}; - AggregateBuffer input_data{ - 2, - 10, - offsets.data(), - var_data.data(), - validity_data.data(), - false, - bitmap.data(), - 0}; - aggregator.aggregate_data(input_data); - aggregator.copy_to_user_buffer("NullCount", buffers); - CHECK(null_count == 1); - - AggregateBuffer input_data2{ - 0, - 2, - offsets.data(), - var_data.data(), - validity_data.data(), - false, - bitmap.data(), - 0}; - aggregator.aggregate_data(input_data2); - aggregator.copy_to_user_buffer("NullCount", buffers); - CHECK(null_count == 3); - } - - SECTION("Count bitmap") { - std::vector bitmap_count = {1, 2, 4, 0, 0, 1, 2, 0, 1, 2}; - AggregateBuffer input_data{ - 2, - 10, - offsets.data(), - var_data.data(), - validity_data.data(), - true, - bitmap_count.data(), - 0}; - aggregator.aggregate_data(input_data); - aggregator.copy_to_user_buffer("NullCount", buffers); - CHECK(null_count == 3); - - AggregateBuffer input_data2{ - 0, - 2, - offsets.data(), - var_data.data(), - validity_data.data(), - true, - bitmap_count.data(), - 0}; - aggregator.aggregate_data(input_data2); - aggregator.copy_to_user_buffer("NullCount", buffers); - CHECK(null_count == 6); - } -} \ No newline at end of file diff --git a/tiledb/sm/query/readers/aggregators/validity_policies.h 
b/tiledb/sm/query/readers/aggregators/validity_policies.h
new file mode 100644
index 00000000000..2f0a3474ea1
--- /dev/null
+++ b/tiledb/sm/query/readers/aggregators/validity_policies.h
@@ -0,0 +1,64 @@
+/**
+ * @file validity_policies.h
+ *
+ * @section LICENSE
+ *
+ * The MIT License
+ *
+ * @copyright Copyright (c) 2023 TileDB, Inc.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+ * THE SOFTWARE.
+ *
+ * @section DESCRIPTION
+ *
+ * This file defines the validity policies Null and NonNull.
+ */
+
+#ifndef TILEDB_VALIDITY_POLICIES_H
+#define TILEDB_VALIDITY_POLICIES_H
+
+#include <cstdint>
+
+namespace tiledb::sm {
+
+struct Null {
+  /**
+   * Validity policy for null cells.
+   * @param validity_value Validity value.
+   * @return true if `validity_value` is 0, false otherwise.
+   */
+  bool op(uint8_t validity_value) const {
+    return validity_value == 0;
+  }
+};
+
+struct NonNull {
+  /**
+   * Validity policy for non null cells.
+   * @param validity_value Validity value.
+   * @return true if `validity_value` is not 0, false otherwise.
+   */
+  bool op(uint8_t validity_value) const {
+    return validity_value != 0;
+  }
+};
+
+}  // namespace tiledb::sm
+
+#endif  // TILEDB_VALIDITY_POLICIES_H