Skip to content

Commit

Permalink
Aggregators: adding mean aggregate.
Browse files Browse the repository at this point in the history
This adds a mean aggregator, which is very simple and can be improved upon. For now it just sums everything into a double value for all types and divides by the count before returning the value to the user.

---
TYPE: IMPROVEMENT
DESC: Aggregators: adding mean aggregate.
  • Loading branch information
KiterLuc committed Aug 22, 2023
1 parent a15c06b commit fa1a45f
Show file tree
Hide file tree
Showing 11 changed files with 1,002 additions and 13 deletions.
142 changes: 142 additions & 0 deletions test/src/test-cppapi-aggregates.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
#include "tiledb/sm/c_api/tiledb_struct_def.h"
#include "tiledb/sm/cpp_api/tiledb"
#include "tiledb/sm/query/readers/aggregators/count_aggregator.h"
#include "tiledb/sm/query/readers/aggregators/mean_aggregator.h"
#include "tiledb/sm/query/readers/aggregators/min_max_aggregator.h"
#include "tiledb/sm/query/readers/aggregators/sum_aggregator.h"

Expand Down Expand Up @@ -1334,6 +1335,147 @@ TEMPLATE_LIST_TEST_CASE_METHOD(
array.close();
}

/**
 * Fixed-size numeric types the mean aggregate test is instantiated for:
 * unsigned and signed integers of 8, 16, 32 and 64 bits, then float and
 * double.
 */
using MeanFixedTypesUnderTest = tuple<
    uint8_t,
    uint16_t,
    uint32_t,
    uint64_t,
    int8_t,
    int16_t,
    int32_t,
    int64_t,
    float,
    double>;
/**
 * Checks the mean aggregator on attribute "a1" for every type listed in
 * MeanFixedTypesUnderTest.
 *
 * The query is run for every combination of:
 *   - ranges set / not set,
 *   - data buffers ("d1", "d2", "a1") requested / not requested,
 *   - query condition set / not set (values from `set_qc_values_`),
 *   - every layout in `layout_values_`.
 *
 * Whatever the attribute type is, the result buffer for the aggregate is a
 * single double.
 */
TEMPLATE_LIST_TEST_CASE_METHOD(
    CppAggregatesFx,
    "C++ API: Aggregates basic mean",
    "[cppapi][aggregates][basic][mean]",
    MeanFixedTypesUnderTest) {
  typedef TestType T;
  CppAggregatesFx<T>::generate_test_params();
  CppAggregatesFx<T>::create_array_and_write_fragments();

  Array array{
      CppAggregatesFx<T>::ctx_, CppAggregatesFx<T>::ARRAY_NAME, TILEDB_READ};

  // Size in bytes of the single double value the mean result is read into.
  constexpr uint64_t mean_result_size = sizeof(double);

  for (bool set_ranges : {true, false}) {
    CppAggregatesFx<T>::set_ranges_ = set_ranges;
    for (bool request_data : {true, false}) {
      CppAggregatesFx<T>::request_data_ = request_data;
      for (bool set_qc : CppAggregatesFx<T>::set_qc_values_) {
        CppAggregatesFx<T>::set_qc_ = set_qc;
        for (tiledb_layout_t layout : CppAggregatesFx<T>::layout_values_) {
          CppAggregatesFx<T>::layout_ = layout;
          Query query(CppAggregatesFx<T>::ctx_, array, TILEDB_READ);

          // TODO: Change to real CPPAPI. Add a mean aggregator to the query.
          query.ptr()->query_->add_aggregator_to_default_channel(
              "Mean",
              std::make_shared<tiledb::sm::MeanAggregator<T>>(
                  tiledb::sm::FieldInfo(
                      "a1", false, CppAggregatesFx<T>::nullable_, 1)));

          CppAggregatesFx<T>::set_ranges_and_condition_if_needed(
              array, query, false);

          // Set the data buffer for the aggregator.
          uint64_t cell_size = sizeof(T);
          std::vector<double> mean(1);
          std::vector<uint8_t> mean_validity(1);
          std::vector<uint64_t> dim1(100);
          std::vector<uint64_t> dim2(100);
          std::vector<uint8_t> a1(100 * cell_size);
          std::vector<uint8_t> a1_validity(100);
          query.set_layout(layout);

          // TODO: Change to real CPPAPI. Use set_data_buffer from the internal
          // query directly because the CPPAPI doesn't know what is an aggregate
          // and what the size of an aggregate should be.
          uint64_t returned_mean_size = mean_result_size;
          CHECK(query.ptr()
                    ->query_
                    ->set_data_buffer("Mean", mean.data(), &returned_mean_size)
                    .ok());
          uint64_t returned_validity_size = 1;
          if (CppAggregatesFx<T>::nullable_) {
            // TODO: Change to real CPPAPI. Use set_validity_buffer from the
            // internal query directly because the CPPAPI doesn't know what is
            // an aggregate and what the size of an aggregate should be.
            CHECK(query.ptr()
                      ->query_
                      ->set_validity_buffer(
                          "Mean", mean_validity.data(), &returned_validity_size)
                      .ok());
          }

          if (request_data) {
            query.set_data_buffer("d1", dim1);
            query.set_data_buffer("d2", dim2);
            // `a1` is a raw byte buffer, so the element count is passed
            // explicitly.
            query.set_data_buffer(
                "a1", static_cast<void*>(a1.data()), a1.size() / cell_size);

            if (CppAggregatesFx<T>::nullable_) {
              query.set_validity_buffer("a1", a1_validity);
            }
          }

          // Submit the query.
          query.submit();

          // Check the results. Each expected value is written as a
          // (sum / count) quotient.
          double expected_mean;
          if (CppAggregatesFx<T>::dense_) {
            if (CppAggregatesFx<T>::nullable_) {
              if (set_ranges) {
                expected_mean = set_qc ? (197.0 / 11.0) : (201.0 / 12.0);
              } else {
                expected_mean = set_qc ? (315.0 / 18.0) : (319.0 / 19.0);
              }
            } else {
              if (set_ranges) {
                expected_mean = set_qc ? (398.0 / 23.0) : (402.0 / 24.0);
              } else {
                expected_mean = set_qc ? (591.0 / 34.0) : (630.0 / 36.0);
              }
            }
          } else {
            if (CppAggregatesFx<T>::nullable_) {
              if (set_ranges) {
                expected_mean = (42.0 / 4.0);
              } else {
                expected_mean = (56.0 / 8.0);
              }
            } else {
              if (set_ranges) {
                expected_mean = CppAggregatesFx<T>::allow_dups_ ? (88.0 / 8.0) :
                                                                  (81.0 / 7.0);
              } else {
                expected_mean = CppAggregatesFx<T>::allow_dups_ ?
                                    (120.0 / 16.0) :
                                    (113.0 / 15.0);
              }
            }
          }

          // TODO: use 'std::get<1>(result_el["Mean"]) == 1' once we use
          // the set_data_buffer api.
          CHECK(returned_mean_size == mean_result_size);
          CHECK(mean[0] == expected_mean);

          if (request_data) {
            CppAggregatesFx<T>::validate_data(
                query, dim1, dim2, a1, a1_validity);
          }
        }
      }
    }
  }

  // Close array.
  array.close();
}

typedef tuple<
std::pair<uint8_t, tiledb::sm::MinAggregator<uint8_t>>,
std::pair<uint16_t, tiledb::sm::MinAggregator<uint16_t>>,
Expand Down
1 change: 1 addition & 0 deletions tiledb/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,7 @@ set(TILEDB_CORE_SOURCES
${TILEDB_CORE_INCLUDE_DIR}/tiledb/sm/query/query_condition.cc
${TILEDB_CORE_INCLUDE_DIR}/tiledb/sm/query/query_remote_buffer_storage.cc
${TILEDB_CORE_INCLUDE_DIR}/tiledb/sm/query/readers/aggregators/count_aggregator.cc
${TILEDB_CORE_INCLUDE_DIR}/tiledb/sm/query/readers/aggregators/mean_aggregator.cc
${TILEDB_CORE_INCLUDE_DIR}/tiledb/sm/query/readers/aggregators/min_max_aggregator.cc
${TILEDB_CORE_INCLUDE_DIR}/tiledb/sm/query/readers/aggregators/sum_aggregator.cc
${TILEDB_CORE_INCLUDE_DIR}/tiledb/sm/query/readers/dense_reader.cc
Expand Down
8 changes: 0 additions & 8 deletions tiledb/sm/query/readers/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,4 @@
include(common NO_POLICY_SCOPE)
include(object_library)

#
# `result_tile` object library
#
commence(object_library result_tile)
this_target_sources(result_tile.cc)
this_target_object_libraries(array_schema baseline tile)
conclude(object_library)

add_subdirectory(aggregators)
4 changes: 2 additions & 2 deletions tiledb/sm/query/readers/aggregators/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ include(object_library)
# `aggregators` object library
#
commence(object_library aggregators)
this_target_sources(count_aggregator.cc min_max_aggregator.cc sum_aggregator.cc)
this_target_object_libraries(baseline array_schema result_tile)
this_target_sources(count_aggregator.cc mean_aggregator.cc min_max_aggregator.cc sum_aggregator.cc)
this_target_object_libraries(baseline array_schema)
conclude(object_library)

add_test_subdirectory()
Loading

0 comments on commit fa1a45f

Please sign in to comment.