diff --git a/test/src/test-cppapi-aggregates.cc b/test/src/test-cppapi-aggregates.cc index 27b0b3a3030a..2ca3d1c00464 100644 --- a/test/src/test-cppapi-aggregates.cc +++ b/test/src/test-cppapi-aggregates.cc @@ -34,6 +34,7 @@ #include "tiledb/sm/c_api/tiledb_struct_def.h" #include "tiledb/sm/cpp_api/tiledb" #include "tiledb/sm/query/readers/aggregators/count_aggregator.h" +#include "tiledb/sm/query/readers/aggregators/mean_aggregator.h" #include "tiledb/sm/query/readers/aggregators/min_max_aggregator.h" #include "tiledb/sm/query/readers/aggregators/sum_aggregator.h" @@ -1334,6 +1335,147 @@ TEMPLATE_LIST_TEST_CASE_METHOD( array.close(); } +typedef tuple< + uint8_t, + uint16_t, + uint32_t, + uint64_t, + int8_t, + int16_t, + int32_t, + int64_t, + float, + double> + MeanFixedTypesUnderTest; +TEMPLATE_LIST_TEST_CASE_METHOD( + CppAggregatesFx, + "C++ API: Aggregates basic mean", + "[cppapi][aggregates][basic][mean]", + MeanFixedTypesUnderTest) { + typedef TestType T; + CppAggregatesFx::generate_test_params(); + CppAggregatesFx::create_array_and_write_fragments(); + + Array array{ + CppAggregatesFx::ctx_, CppAggregatesFx::ARRAY_NAME, TILEDB_READ}; + + for (bool set_ranges : {true, false}) { + CppAggregatesFx::set_ranges_ = set_ranges; + for (bool request_data : {true, false}) { + CppAggregatesFx::request_data_ = request_data; + for (bool set_qc : CppAggregatesFx::set_qc_values_) { + CppAggregatesFx::set_qc_ = set_qc; + for (tiledb_layout_t layout : CppAggregatesFx::layout_values_) { + CppAggregatesFx::layout_ = layout; + Query query(CppAggregatesFx::ctx_, array, TILEDB_READ); + + // TODO: Change to real CPPAPI. Add a count aggregator to the query. + query.ptr()->query_->add_aggregator_to_default_channel( + "Mean", + std::make_shared>( + tiledb::sm::FieldInfo( + "a1", false, CppAggregatesFx::nullable_, 1))); + + CppAggregatesFx::set_ranges_and_condition_if_needed( + array, query, false); + + // Set the data buffer for the aggregator. + uint64_t cell_size = sizeof(T); + std::vector mean(1); + std::vector mean_validity(1); + std::vector dim1(100); + std::vector dim2(100); + std::vector a1(100 * cell_size); + std::vector a1_validity(100); + query.set_layout(layout); + + // TODO: Change to real CPPAPI. Use set_data_buffer from the internal + // query directly because the CPPAPI doesn't know what is an aggregate + // and what the size of an aggregate should be. + uint64_t returned_mean_size = 8; + CHECK(query.ptr() + ->query_ + ->set_data_buffer("Mean", &mean[0], &returned_mean_size) + .ok()); + uint64_t returned_validity_size = 1; + if (CppAggregatesFx::nullable_) { + // TODO: Change to real CPPAPI. Use set_validity_buffer from the + // internal query directly because the CPPAPI doesn't know what is + // an aggregate and what the size of an aggregate should be. + CHECK(query.ptr() + ->query_ + ->set_validity_buffer( + "Mean", mean_validity.data(), &returned_validity_size) + .ok()); + } + + if (request_data) { + query.set_data_buffer("d1", dim1); + query.set_data_buffer("d2", dim2); + query.set_data_buffer( + "a1", static_cast(a1.data()), a1.size() / cell_size); + + if (CppAggregatesFx::nullable_) { + query.set_validity_buffer("a1", a1_validity); + } + } + + // Submit the query. + query.submit(); + + // Check the results. + double expected_mean; + if (CppAggregatesFx::dense_) { + if (CppAggregatesFx::nullable_) { + if (set_ranges) { + expected_mean = set_qc ? (197.0 / 11.0) : (201.0 / 12.0); + } else { + expected_mean = set_qc ? (315.0 / 18.0) : (319.0 / 19.0); + } + } else { + if (set_ranges) { + expected_mean = set_qc ? (398.0 / 23.0) : (402.0 / 24.0); + } else { + expected_mean = set_qc ? (591.0 / 34.0) : (630.0 / 36.0); + } + } + } else { + if (CppAggregatesFx::nullable_) { + if (set_ranges) { + expected_mean = (42.0 / 4.0); + } else { + expected_mean = (56.0 / 8.0); + } + } else { + if (set_ranges) { + expected_mean = CppAggregatesFx::allow_dups_ ? (88.0 / 8.0) : + (81.0 / 7.0); + } else { + expected_mean = CppAggregatesFx::allow_dups_ ? + (120.0 / 16.0) : + (113.0 / 15.0); + } + } + } + + // TODO: use 'std::get<1>(result_el["Mean"]) == 1' once we use + // the set_data_buffer api. + CHECK(returned_mean_size == 8); + CHECK(mean[0] == expected_mean); + + if (request_data) { + CppAggregatesFx::validate_data( + query, dim1, dim2, a1, a1_validity); + } + } + } + } + } + + // Close array. + array.close(); +} + typedef tuple< std::pair>, std::pair>, diff --git a/tiledb/CMakeLists.txt b/tiledb/CMakeLists.txt index daa14b33a1f7..d4de37f944a7 100644 --- a/tiledb/CMakeLists.txt +++ b/tiledb/CMakeLists.txt @@ -253,6 +253,7 @@ set(TILEDB_CORE_SOURCES ${TILEDB_CORE_INCLUDE_DIR}/tiledb/sm/query/query_condition.cc ${TILEDB_CORE_INCLUDE_DIR}/tiledb/sm/query/query_remote_buffer_storage.cc ${TILEDB_CORE_INCLUDE_DIR}/tiledb/sm/query/readers/aggregators/count_aggregator.cc + ${TILEDB_CORE_INCLUDE_DIR}/tiledb/sm/query/readers/aggregators/mean_aggregator.cc ${TILEDB_CORE_INCLUDE_DIR}/tiledb/sm/query/readers/aggregators/min_max_aggregator.cc ${TILEDB_CORE_INCLUDE_DIR}/tiledb/sm/query/readers/aggregators/sum_aggregator.cc ${TILEDB_CORE_INCLUDE_DIR}/tiledb/sm/query/readers/dense_reader.cc diff --git a/tiledb/sm/query/readers/CMakeLists.txt b/tiledb/sm/query/readers/CMakeLists.txt index a7462d5b2858..c8c4e1e6ae6e 100644 --- a/tiledb/sm/query/readers/CMakeLists.txt +++ b/tiledb/sm/query/readers/CMakeLists.txt @@ -27,12 +27,4 @@ include(common NO_POLICY_SCOPE) include(object_library) -# -# `result_tile` object library -# -commence(object_library result_tile) - this_target_sources(result_tile.cc) - this_target_object_libraries(array_schema baseline tile) -conclude(object_library) - add_subdirectory(aggregators) diff --git a/tiledb/sm/query/readers/aggregators/CMakeLists.txt b/tiledb/sm/query/readers/aggregators/CMakeLists.txt index 951cf0f3c7e3..8839c4d6e528 100644 --- a/tiledb/sm/query/readers/aggregators/CMakeLists.txt +++ b/tiledb/sm/query/readers/aggregators/CMakeLists.txt @@ -31,8 +31,8 @@ include(object_library) # `aggregators` object library # commence(object_library aggregators) - this_target_sources(count_aggregator.cc min_max_aggregator.cc sum_aggregator.cc) - this_target_object_libraries(baseline array_schema result_tile) + this_target_sources(count_aggregator.cc mean_aggregator.cc min_max_aggregator.cc sum_aggregator.cc) + this_target_object_libraries(baseline array_schema) conclude(object_library) add_test_subdirectory() diff --git a/tiledb/sm/query/readers/aggregators/mean_aggregator.cc b/tiledb/sm/query/readers/aggregators/mean_aggregator.cc new file mode 100644 index 000000000000..e6ea98ffcc22 --- /dev/null +++ b/tiledb/sm/query/readers/aggregators/mean_aggregator.cc @@ -0,0 +1,269 @@ +/** + * @file mean_aggregator.cc + * + * @section LICENSE + * + * The MIT License + * + * @copyright Copyright (c) 2023 TileDB, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + * @section DESCRIPTION + * + * This file implements class MeanAggregator. + */ + +#include "tiledb/sm/query/readers/aggregators/mean_aggregator.h" + +#include "tiledb/sm/query/query_buffer.h" +#include "tiledb/sm/query/readers/aggregators/aggregate_buffer.h" +#include "tiledb/sm/query/readers/aggregators/sum_aggregator.h" + +using namespace tiledb::common; + +namespace tiledb { +namespace sm { + +class MeanAggregatorStatusException : public StatusException { + public: + explicit MeanAggregatorStatusException(const std::string& message) + : StatusException("MeanAggregator", message) { + } +}; + +template +MeanAggregator::MeanAggregator(const FieldInfo field_info) + : field_info_(field_info) + , sum_(0) + , count_(0) + , validity_value_( + field_info_.is_nullable_ ? std::make_optional(0) : nullopt) + , sum_overflowed_(false) { + if (field_info_.var_sized_) { + throw MeanAggregatorStatusException( + "Mean aggregates must not be requested for var sized attributes."); + } + + if (field_info_.cell_val_num_ != 1) { + throw MeanAggregatorStatusException( + "Mean aggregates must not be requested for attributes with more than " + "one value."); + } +} + +template +void MeanAggregator::validate_output_buffer( + std::string output_field_name, + std::unordered_map& buffers) { + if (buffers.count(output_field_name) == 0) { + throw MeanAggregatorStatusException("Result buffer doesn't exist."); + } + + auto& result_buffer = buffers[output_field_name]; + if (result_buffer.buffer_ == nullptr) { + throw MeanAggregatorStatusException( + "Mean aggregates must have a fixed size buffer."); + } + + if (result_buffer.buffer_var_ != nullptr) { + throw MeanAggregatorStatusException( + "Mean aggregates must not have a var buffer."); + } + + if (result_buffer.original_buffer_size_ != 8) { + throw MeanAggregatorStatusException( + "Mean aggregates fixed size buffer should be for one element only."); + } + + bool exists_validity = result_buffer.validity_vector_.buffer(); + if (field_info_.is_nullable_) { + if (!exists_validity) { + throw MeanAggregatorStatusException( + "Mean aggregates for nullable attributes must have a validity " + "buffer."); + } + + if (*result_buffer.validity_vector_.buffer_size() != 1) { + throw MeanAggregatorStatusException( + "Mean aggregates validity vector should be for one element only."); + } + } else { + if (exists_validity) { + throw MeanAggregatorStatusException( + "Mean aggregates for non nullable attributes must not have a " + "validity " + "buffer."); + } + } +} + +template +void MeanAggregator::aggregate_data(AggregateBuffer& input_data) { + tuple> res{0, 0, nullopt}; + + bool overflow = false; + try { + if (input_data.is_count_bitmap()) { + res = mean(input_data); + } else { + res = mean(input_data); + } + } catch (std::overflow_error&) { + overflow = true; + } + + { + // This might be called on multiple threads, the final result stored in sum_ + // should be computed in a thread safe manner. The mutex also protects + // sum_overflowed_ which indicates when the sum has overflowed. + std::unique_lock lock(mean_mtx_); + + // A previous operation already overflowed the sum, return. + if (sum_overflowed_) { + return; + } + + // If we have an overflow, signal it, else it's business as usual. + if (overflow) { + sum_overflowed_ = true; + sum_ = std::get<0>(res); + count_ = std::numeric_limits::max(); + return; + } else { + // This sum might overflow as well. + try { + safe_sum(std::get<0>(res), sum_); + safe_sum(std::get<1>(res), count_); + } catch (std::overflow_error&) { + sum_overflowed_ = true; + } + } + } + + if (field_info_.is_nullable_ && std::get<2>(res).value() == 1) { + validity_value_ = 1; + } +} + +template +void MeanAggregator::copy_to_user_buffer( + std::string output_field_name, + std::unordered_map& buffers) { + auto& result_buffer = buffers[output_field_name]; + *static_cast(result_buffer.buffer_) = sum_ / count_; + + if (result_buffer.buffer_size_) { + *result_buffer.buffer_size_ = sizeof(double); + } + + if (field_info_.is_nullable_) { + *static_cast(result_buffer.validity_vector_.buffer()) = + validity_value_.value(); + + if (result_buffer.validity_vector_.buffer_size()) { + *result_buffer.validity_vector_.buffer_size() = 1; + } + } +} + +template +template +tuple> MeanAggregator::mean( + AggregateBuffer& input_data) { + double sum{0}; + uint64_t count{0}; + optional validity{nullopt}; + auto values = input_data.fixed_data_as(); + + // Run different loops for bitmap versus no bitmap and nullable versus non + // nullable. The bitmap tells us which cells was already filtered out by + // ranges or query conditions. + if (input_data.has_bitmap()) { + auto bitmap_values = input_data.bitmap_data_as(); + + if (field_info_.is_nullable_) { + validity = 0; + auto validity_values = input_data.validity_data(); + + // Process for nullable sums with bitmap. + for (uint64_t c = input_data.min_cell(); c < input_data.max_cell(); c++) { + if (validity_values[c] != 0 && bitmap_values[c] != 0) { + validity = 1; + + auto value = static_cast(values[c]); + for (BITMAP_T i = 0; i < bitmap_values[c]; i++) { + count++; + safe_sum(value, sum); + } + } + } + } else { + // Process for non nullable sums with bitmap. + for (uint64_t c = input_data.min_cell(); c < input_data.max_cell(); c++) { + auto value = static_cast(values[c]); + + for (BITMAP_T i = 0; i < bitmap_values[c]; i++) { + count++; + safe_sum(value, sum); + } + } + } + } else { + if (field_info_.is_nullable_) { + validity = 0; + auto validity_values = input_data.validity_data(); + + // Process for nullable sums with no bitmap. + for (uint64_t c = input_data.min_cell(); c < input_data.max_cell(); c++) { + if (validity_values[c] != 0) { + validity = 1; + + auto value = static_cast(values[c]); + count++; + safe_sum(value, sum); + } + } + } else { + // Process for non nullable sums with no bitmap. + for (uint64_t c = input_data.min_cell(); c < input_data.max_cell(); c++) { + auto value = static_cast(values[c]); + count++; + safe_sum(value, sum); + } + } + } + + return {sum, count, validity}; +} + +// Explicit template instantiations +template MeanAggregator::MeanAggregator(const FieldInfo); +template MeanAggregator::MeanAggregator(const FieldInfo); +template MeanAggregator::MeanAggregator(const FieldInfo); +template MeanAggregator::MeanAggregator(const FieldInfo); +template MeanAggregator::MeanAggregator(const FieldInfo); +template MeanAggregator::MeanAggregator(const FieldInfo); +template MeanAggregator::MeanAggregator(const FieldInfo); +template MeanAggregator::MeanAggregator(const FieldInfo); +template MeanAggregator::MeanAggregator(const FieldInfo); +template MeanAggregator::MeanAggregator(const FieldInfo); + +} // namespace sm +} // namespace tiledb diff --git a/tiledb/sm/query/readers/aggregators/mean_aggregator.h b/tiledb/sm/query/readers/aggregators/mean_aggregator.h new file mode 100644 index 000000000000..7f8e731747a0 --- /dev/null +++ b/tiledb/sm/query/readers/aggregators/mean_aggregator.h @@ -0,0 +1,155 @@ +/** + * @file mean_aggregator.h + * + * @section LICENSE + * + * The MIT License + * + * @copyright Copyright (c) 2023 TileDB, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + * @section DESCRIPTION + * + * This file defines class MeanAggregator. + */ + +#ifndef TILEDB_MEAN_AGGREGATOR_H +#define TILEDB_MEAN_AGGREGATOR_H + +#include "tiledb/common/status.h" +#include "tiledb/sm/enums/layout.h" +#include "tiledb/sm/query/readers/aggregators/field_info.h" +#include "tiledb/sm/query/readers/aggregators/iaggregator.h" + +using namespace tiledb::common; + +namespace tiledb { +namespace sm { + +class QueryBuffer; + +template +class MeanAggregator : public IAggregator { + public: + /* ********************************* */ + /* CONSTRUCTORS & DESTRUCTORS */ + /* ********************************* */ + + MeanAggregator() = delete; + + /** + * Constructor. + * + * @param field_info Field info. + */ + MeanAggregator(FieldInfo field_info); + + DISABLE_COPY_AND_COPY_ASSIGN(MeanAggregator); + DISABLE_MOVE_AND_MOVE_ASSIGN(MeanAggregator); + + /* ********************************* */ + /* API */ + /* ********************************* */ + + /** Returns the field name for the aggregator. */ + std::string field_name() override { + return field_info_.name_; + } + + /** Returns if the aggregation is var sized or not. */ + bool var_sized() override { + return false; + }; + + /** Returns if the aggregate needs to be recomputed on overflow. */ + bool need_recompute_on_overflow() override { + return true; + } + + /** + * Validate the result buffer. + * + * @param output_field_name Name for the output buffer. + * @param buffers Query buffers. + */ + void validate_output_buffer( + std::string output_field_name, + std::unordered_map& buffers) override; + + /** + * Aggregate data using the aggregator. + * + * @param input_data Input data for aggregation. + */ + void aggregate_data(AggregateBuffer& input_data) override; + + /** + * Copy final data to the user buffer. + * + * @param output_field_name Name for the output buffer. + * @param buffers Query buffers. + */ + void copy_to_user_buffer( + std::string output_field_name, + std::unordered_map& buffers) override; + + private: + /* ********************************* */ + /* PRIVATE ATTRIBUTES */ + /* ********************************* */ + + /** Field information. */ + const FieldInfo field_info_; + + /** Mutex protecting `sum_`, `sum_overflowed_` and count_. */ + std::mutex mean_mtx_; + + /** Computed sum. */ + double sum_; + + /** Count of values. */ + uint64_t count_; + + /** Computed validity value. */ + optional validity_value_; + + /** Has the sum overflowed. */ + bool sum_overflowed_; + + /* ********************************* */ + /* PRIVATE METHODS */ + /* ********************************* */ + + /** + * Add the sum/count of cells for the input data. + * + * @tparam BITMAP_T Bitmap type. + * @param input_data Input data for the mean. + * + * @return {Computed sum, count of cells, optional validity value}. + */ + template + tuple> mean(AggregateBuffer& input_data); +}; + +} // namespace sm +} // namespace tiledb + +#endif // TILEDB_MEAN_AGGREGATOR_H diff --git a/tiledb/sm/query/readers/aggregators/min_max_aggregator.cc b/tiledb/sm/query/readers/aggregators/min_max_aggregator.cc index 8e876acf6ed8..02cba5f5cee3 100644 --- a/tiledb/sm/query/readers/aggregators/min_max_aggregator.cc +++ b/tiledb/sm/query/readers/aggregators/min_max_aggregator.cc @@ -32,7 +32,6 @@ #include "tiledb/sm/query/readers/aggregators/min_max_aggregator.h" -#include "tiledb/sm/array_schema/array_schema.h" #include "tiledb/sm/query/query_buffer.h" #include "tiledb/sm/query/readers/aggregators/aggregate_buffer.h" diff --git a/tiledb/sm/query/readers/aggregators/sum_aggregator.cc b/tiledb/sm/query/readers/aggregators/sum_aggregator.cc index b39b9187f560..aec3411b25e4 100644 --- a/tiledb/sm/query/readers/aggregators/sum_aggregator.cc +++ b/tiledb/sm/query/readers/aggregators/sum_aggregator.cc @@ -32,7 +32,6 @@ #include "tiledb/sm/query/readers/aggregators/sum_aggregator.h" -#include "tiledb/sm/array_schema/array_schema.h" #include "tiledb/sm/query/query_buffer.h" #include "tiledb/sm/query/readers/aggregators/aggregate_buffer.h" diff --git a/tiledb/sm/query/readers/aggregators/test/CMakeLists.txt b/tiledb/sm/query/readers/aggregators/test/CMakeLists.txt index 5676902f707d..169fe3bf06be 100644 --- a/tiledb/sm/query/readers/aggregators/test/CMakeLists.txt +++ b/tiledb/sm/query/readers/aggregators/test/CMakeLists.txt @@ -27,6 +27,6 @@ include(unit_test) commence(unit_test aggregators) - this_target_sources(main.cc unit_count.cc unit_min_max.cc unit_sum.cc) + this_target_sources(main.cc unit_count.cc unit_mean.cc unit_min_max.cc unit_sum.cc) this_target_object_libraries(aggregators) conclude(unit_test) diff --git a/tiledb/sm/query/readers/aggregators/test/compile_aggregators_main.cc b/tiledb/sm/query/readers/aggregators/test/compile_aggregators_main.cc index d3d4930f412e..fdea2106afef 100644 --- a/tiledb/sm/query/readers/aggregators/test/compile_aggregators_main.cc +++ b/tiledb/sm/query/readers/aggregators/test/compile_aggregators_main.cc @@ -28,12 +28,34 @@ #include "../count_aggregator.h" #include "../field_info.h" +#include "../mean_aggregator.h" #include "../min_max_aggregator.h" #include "../sum_aggregator.h" int main() { tiledb::sm::CountAggregator(); + tiledb::sm::MeanAggregator( + tiledb::sm::FieldInfo("Mean", false, false, 1)); + tiledb::sm::MeanAggregator( + tiledb::sm::FieldInfo("Mean", false, false, 1)); + tiledb::sm::MeanAggregator( + tiledb::sm::FieldInfo("Mean", false, false, 1)); + tiledb::sm::MeanAggregator( + tiledb::sm::FieldInfo("Mean", false, false, 1)); + tiledb::sm::MeanAggregator( + tiledb::sm::FieldInfo("Mean", false, false, 1)); + tiledb::sm::MeanAggregator( + tiledb::sm::FieldInfo("Mean", false, false, 1)); + tiledb::sm::MeanAggregator( + tiledb::sm::FieldInfo("Mean", false, false, 1)); + tiledb::sm::MeanAggregator( + tiledb::sm::FieldInfo("Mean", false, false, 1)); + tiledb::sm::MeanAggregator( + tiledb::sm::FieldInfo("Mean", false, false, 1)); + tiledb::sm::MeanAggregator( + tiledb::sm::FieldInfo("Mean", false, false, 1)); + tiledb::sm::MinAggregator( tiledb::sm::FieldInfo("Sum", false, false, 1)); tiledb::sm::MinAggregator( diff --git a/tiledb/sm/query/readers/aggregators/test/unit_mean.cc b/tiledb/sm/query/readers/aggregators/test/unit_mean.cc new file mode 100644 index 000000000000..f04d3001d4d6 --- /dev/null +++ b/tiledb/sm/query/readers/aggregators/test/unit_mean.cc @@ -0,0 +1,410 @@ +/** + * @file unit_mean.cc + * + * @section LICENSE + * + * The MIT License + * + * @copyright Copyright (c) 2023 TileDB, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + * @section DESCRIPTION + * + * Tests the `MeanAggregator` class. + */ + +#include "tiledb/common/common.h" +#include "tiledb/sm/query/query_buffer.h" +#include "tiledb/sm/query/readers/aggregators/aggregate_buffer.h" +#include "tiledb/sm/query/readers/aggregators/mean_aggregator.h" + +#include +#include + +using namespace tiledb::sm; + +TEST_CASE("Mean aggregator: constructor", "[mean-aggregator][constructor]") { + SECTION("Var size") { + CHECK_THROWS_WITH( + MeanAggregator(FieldInfo("a1", true, false, 1)), + "MeanAggregator: Mean aggregates must not be requested for var sized " + "attributes."); + } + + SECTION("Invalid cell val num") { + CHECK_THROWS_WITH( + MeanAggregator(FieldInfo("a1", false, false, 2)), + "MeanAggregator: Mean aggregates must not be requested for attributes " + "with more than one value."); + } +} + +TEST_CASE("Mean aggregator: var sized", "[mean-aggregator][var-sized]") { + MeanAggregator aggregator(FieldInfo("a1", false, false, 1)); + CHECK(aggregator.var_sized() == false); +} + +TEST_CASE( + "Mean aggregator: need recompute", "[mean-aggregator][need-recompute]") { + MeanAggregator aggregator(FieldInfo("a1", false, false, 1)); + CHECK(aggregator.need_recompute_on_overflow() == true); +} + +TEST_CASE("Mean aggregator: field name", "[mean-aggregator][field-name]") { + MeanAggregator aggregator(FieldInfo("a1", false, false, 1)); + CHECK(aggregator.field_name() == "a1"); +} + +TEST_CASE( + "Mean aggregator: Validate buffer", "[mean-aggregator][validate-buffer]") { + MeanAggregator aggregator(FieldInfo("a1", false, false, 1)); + MeanAggregator aggregator_nullable(FieldInfo("a2", false, true, 1)); + + std::unordered_map buffers; + + SECTION("Doesn't exist") { + CHECK_THROWS_WITH( + aggregator.validate_output_buffer("Mean", buffers), + "MeanAggregator: Result buffer doesn't exist."); + } + + SECTION("Null data buffer") { + buffers["Mean"].buffer_ = nullptr; + CHECK_THROWS_WITH( + aggregator.validate_output_buffer("Mean", buffers), + "MeanAggregator: Mean aggregates must have a fixed size buffer."); + } + + SECTION("Wrong size") { + double mean = 0; + buffers["Mean"].buffer_ = &mean; + buffers["Mean"].original_buffer_size_ = 1; + CHECK_THROWS_WITH( + aggregator.validate_output_buffer("Mean", buffers), + "MeanAggregator: Mean aggregates fixed size buffer should be for one " + "element only."); + } + + SECTION("With var buffer") { + double mean = 0; + buffers["Mean"].buffer_ = &mean; + buffers["Mean"].original_buffer_size_ = 8; + buffers["Mean"].buffer_var_ = &mean; + + CHECK_THROWS_WITH( + aggregator.validate_output_buffer("Mean", buffers), + "MeanAggregator: Mean aggregates must not have a var buffer."); + } + + SECTION("With validity") { + double mean = 0; + buffers["Mean"].buffer_ = &mean; + buffers["Mean"].original_buffer_size_ = 8; + + uint8_t validity = 0; + uint64_t validity_size = 1; + buffers["Mean"].validity_vector_ = + ValidityVector(&validity, &validity_size); + CHECK_THROWS_WITH( + aggregator.validate_output_buffer("Mean", buffers), + "MeanAggregator: Mean aggregates for non nullable attributes must not " + "have a validity buffer."); + } + + SECTION("With no validity") { + double mean = 0; + buffers["Mean"].buffer_ = &mean; + buffers["Mean"].original_buffer_size_ = 8; + + CHECK_THROWS_WITH( + aggregator_nullable.validate_output_buffer("Mean", buffers), + "MeanAggregator: Mean aggregates for nullable attributes must have a " + "validity buffer."); + } + + SECTION("Wrong validity size") { + double mean = 0; + buffers["Mean"].buffer_ = &mean; + buffers["Mean"].original_buffer_size_ = 8; + + uint8_t validity = 0; + uint64_t validity_size = 2; + buffers["Mean"].validity_vector_ = + ValidityVector(&validity, &validity_size); + CHECK_THROWS_WITH( + aggregator_nullable.validate_output_buffer("Mean", buffers), + "MeanAggregator: Mean aggregates validity vector should " + "be for one element only."); + } + + SECTION("Success") { + double mean = 0; + buffers["Mean"].buffer_ = &mean; + buffers["Mean"].original_buffer_size_ = 8; + aggregator.validate_output_buffer("Mean", buffers); + } + + SECTION("Success nullable") { + double mean = 0; + buffers["Mean"].buffer_ = &mean; + buffers["Mean"].original_buffer_size_ = 8; + + uint8_t validity = 0; + uint64_t validity_size = 1; + buffers["Mean"].validity_vector_ = + ValidityVector(&validity, &validity_size); + + aggregator_nullable.validate_output_buffer("Mean", buffers); + } +} + +typedef tuple< + uint8_t, + uint16_t, + uint32_t, + uint64_t, + int8_t, + int16_t, + int32_t, + int64_t, + float, + double> + FixedTypesUnderTest; +TEMPLATE_LIST_TEST_CASE( + "Mean aggregator: Basic aggregation", + "[mean-aggregator][basic-aggregation]", + FixedTypesUnderTest) { + typedef TestType T; + MeanAggregator aggregator(FieldInfo("a1", false, false, 1)); + MeanAggregator aggregator_nullable(FieldInfo("a2", false, true, 1)); + + std::unordered_map buffers; + + double mean = 0; + buffers["Mean"].buffer_ = &mean; + buffers["Mean"].original_buffer_size_ = 8; + + double mean2 = 0; + uint8_t validity = 0; + uint64_t validity_size = 1; + buffers["Mean2"].buffer_ = &mean2; + buffers["Mean2"].original_buffer_size_ = 8; + buffers["Mean2"].validity_vector_ = ValidityVector(&validity, &validity_size); + + std::vector fixed_data = {1, 2, 3, 4, 5, 5, 4, 3, 2, 1}; + std::vector validity_data = {0, 0, 1, 0, 1, 0, 1, 0, 1, 0}; + + SECTION("No bitmap") { + // Regular attribute. + AggregateBuffer input_data{ + 2, 10, 10, fixed_data.data(), nullopt, 0, nullopt, false, nullopt}; + aggregator.aggregate_data(input_data); + aggregator.copy_to_user_buffer("Mean", buffers); + CHECK(mean == (27.0 / 8.0)); + + // Nullable attribute. + AggregateBuffer input_data2{ + 2, + 10, + 10, + fixed_data.data(), + nullopt, + 0, + validity_data.data(), + false, + nullopt}; + aggregator_nullable.aggregate_data(input_data2); + aggregator_nullable.copy_to_user_buffer("Mean2", buffers); + CHECK(mean2 == (14.0 / 4.0)); + CHECK(validity == 1); + } + + SECTION("Regular bitmap") { + // Regular attribute. + std::vector bitmap = {1, 1, 0, 0, 0, 1, 1, 0, 1, 0}; + AggregateBuffer input_data{ + 2, + 10, + 10, + fixed_data.data(), + nullopt, + 0, + nullopt, + false, + bitmap.data()}; + aggregator.aggregate_data(input_data); + aggregator.copy_to_user_buffer("Mean", buffers); + CHECK(mean == (11.0 / 3.0)); + + AggregateBuffer input_data2{ + 0, 2, 10, fixed_data.data(), nullopt, 0, nullopt, false, bitmap.data()}; + aggregator.aggregate_data(input_data2); + aggregator.copy_to_user_buffer("Mean", buffers); + CHECK(mean == (14.0 / 5.0)); + + // Nullable attribute. + AggregateBuffer input_data3{ + 0, + 2, + 10, + fixed_data.data(), + nullopt, + 0, + validity_data.data(), + false, + nullopt}; + aggregator_nullable.aggregate_data(input_data3); + aggregator_nullable.copy_to_user_buffer("Mean2", buffers); + CHECK(std::isnan(mean2)); + CHECK(validity == 0); + + AggregateBuffer input_data4{ + 2, + 10, + 10, + fixed_data.data(), + nullopt, + 0, + validity_data.data(), + false, + bitmap.data()}; + aggregator_nullable.aggregate_data(input_data4); + aggregator_nullable.copy_to_user_buffer("Mean2", buffers); + CHECK(mean2 == (6.0 / 2.0)); + CHECK(validity == 1); + } + + SECTION("Count bitmap") { + // Regular attribute. + std::vector bitmap_count = {1, 2, 4, 0, 0, 1, 2, 0, 1, 2}; + AggregateBuffer input_data{ + 2, + 10, + 10, + fixed_data.data(), + nullopt, + 0, + nullopt, + true, + bitmap_count.data()}; + aggregator.aggregate_data(input_data); + aggregator.copy_to_user_buffer("Mean", buffers); + CHECK(mean == (29.0 / 10.0)); + + AggregateBuffer input_data2{ + 0, + 2, + 10, + fixed_data.data(), + nullopt, + 0, + nullopt, + true, + bitmap_count.data()}; + aggregator.aggregate_data(input_data2); + aggregator.copy_to_user_buffer("Mean", buffers); + CHECK(mean == (34.0 / 13.0)); + + // Nullable attribute. + AggregateBuffer input_data3{ + 2, + 10, + 10, + fixed_data.data(), + nullopt, + 0, + validity_data.data(), + true, + bitmap_count.data()}; + aggregator_nullable.aggregate_data(input_data3); + aggregator_nullable.copy_to_user_buffer("Mean2", buffers); + CHECK(mean2 == (22.0 / 7.0)); + CHECK(validity == 1); + + AggregateBuffer input_data4{ + 0, + 2, + 10, + fixed_data.data(), + nullopt, + 0, + validity_data.data(), + true, + bitmap_count.data()}; + aggregator_nullable.aggregate_data(input_data4); + aggregator_nullable.copy_to_user_buffer("Mean2", buffers); + CHECK(mean2 == (22.0 / 7.0)); + CHECK(validity == 1); + } +} + +TEST_CASE("Mean aggregator: overflow", "[mean-aggregator][overflow]") { + MeanAggregator aggregator(FieldInfo("a1", false, false, 1)); + + std::unordered_map buffers; + + double mean = 0; + buffers["Mean"].buffer_ = &mean; + buffers["Mean"].original_buffer_size_ = 8; + + std::vector fixed_data = { + std::numeric_limits::max(), + std::numeric_limits::lowest()}; + + AggregateBuffer input_data_max{ + 0, 1, 10, fixed_data.data(), nullopt, 0, nullopt, false, nullopt}; + + AggregateBuffer input_data_lowest{ + 1, 2, 10, fixed_data.data(), nullopt, 0, nullopt, false, nullopt}; + + SECTION("Overflow") { + // First mean doesn't overflow. + aggregator.aggregate_data(input_data_max); + aggregator.copy_to_user_buffer("Mean", buffers); + CHECK(mean == std::numeric_limits::max()); + + // Now create an overflow + aggregator.aggregate_data(input_data_max); + aggregator.copy_to_user_buffer("Mean", buffers); + CHECK(mean == std::numeric_limits::max()); + + // Once we overflow, the value doesn't change. + aggregator.aggregate_data(input_data_lowest); + aggregator.copy_to_user_buffer("Mean", buffers); + CHECK(mean == std::numeric_limits::max()); + } + + SECTION("Underflow") { + // First mean doesn't underflow. + aggregator.aggregate_data(input_data_lowest); + aggregator.copy_to_user_buffer("Mean", buffers); + CHECK(mean == std::numeric_limits::lowest()); + + // Now cause an underflow. + aggregator.aggregate_data(input_data_lowest); + aggregator.copy_to_user_buffer("Mean", buffers); + CHECK(mean == std::numeric_limits::lowest()); + + // Once we underflow, the value doesn't change. + aggregator.aggregate_data(input_data_max); + aggregator.copy_to_user_buffer("Mean", buffers); + CHECK(mean == std::numeric_limits::lowest()); + } +}