Skip to content

Commit

Permalink
Aggregates CPP api support (#4438)
Browse files Browse the repository at this point in the history
All C APIs have a cpp api counterpart.

---
TYPE: NO_HISTORY | FEATURE | CPP_API
DESC: Aggregates CPP api support

---------

Co-authored-by: Luc Rancourt <[email protected]>
  • Loading branch information
robertbindar and KiterLuc authored Oct 20, 2023
1 parent c57974e commit fa5ef51
Show file tree
Hide file tree
Showing 20 changed files with 691 additions and 133 deletions.
192 changes: 119 additions & 73 deletions test/src/test-cppapi-aggregates.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
#include "test/support/src/helpers.h"
#include "tiledb/sm/c_api/tiledb_struct_def.h"
#include "tiledb/sm/cpp_api/tiledb"
#include "tiledb/sm/cpp_api/tiledb_experimental"
#include "tiledb/sm/query/readers/aggregators/count_aggregator.h"
#include "tiledb/sm/query/readers/aggregators/mean_aggregator.h"
#include "tiledb/sm/query/readers/aggregators/min_max_aggregator.h"
Expand Down Expand Up @@ -137,7 +138,9 @@ struct CppAggregatesFx {
template <class T>
CppAggregatesFx<T>::CppAggregatesFx()
: vfs_(ctx_) {
ctx_ = Context();
Config cfg;
cfg["sm.allow_aggregates_experimental"] = "true";
ctx_ = Context(cfg);
vfs_ = VFS(ctx_);

remove_array();
Expand Down Expand Up @@ -1172,9 +1175,10 @@ TEST_CASE_METHOD(
layout_ = layout;
Query query(ctx_, array, TILEDB_READ);

// TODO: Change to real CPPAPI. Add a count aggregator to the query.
query.ptr()->query_->add_aggregator_to_default_channel(
"Count", std::make_shared<tiledb::sm::CountAggregator>());
// Add a count aggregator to the query.
QueryChannel default_channel =
QueryExperimental::get_default_channel(query);
default_channel.apply_aggregate("Count", CountOperation());

set_ranges_and_condition_if_needed(array, query, false);

Expand Down Expand Up @@ -1264,12 +1268,13 @@ TEMPLATE_LIST_TEST_CASE_METHOD(
CppAggregatesFx<T>::layout_ = layout;
Query query(CppAggregatesFx<T>::ctx_, array, TILEDB_READ);

// TODO: Change to real CPPAPI. Add a count aggregator to the query.
query.ptr()->query_->add_aggregator_to_default_channel(
"Sum",
std::make_shared<tiledb::sm::SumAggregator<T>>(
tiledb::sm::FieldInfo(
"a1", false, CppAggregatesFx<T>::nullable_, 1)));
// Add a sum aggregator to the query.
QueryChannel default_channel =
QueryExperimental::get_default_channel(query);
ChannelOperation operation =
QueryExperimental::create_unary_aggregate<SumOperator>(
query, "a1");
default_channel.apply_aggregate("Sum", operation);

CppAggregatesFx<T>::set_ranges_and_condition_if_needed(
array, query, false);
Expand Down Expand Up @@ -1385,7 +1390,7 @@ TEMPLATE_LIST_TEST_CASE_METHOD(
CppAggregatesFx<T>::layout_ = layout;
Query query(CppAggregatesFx<T>::ctx_, array, TILEDB_READ);

// TODO: Change to real CPPAPI. Add a count aggregator to the query.
// TODO: Change to real CPPAPI. Add a mean aggregator to the query.
query.ptr()->query_->add_aggregator_to_default_channel(
"Mean",
std::make_shared<tiledb::sm::MeanAggregator<T>>(
Expand Down Expand Up @@ -1476,28 +1481,28 @@ TEMPLATE_LIST_TEST_CASE_METHOD(
}

typedef tuple<
std::pair<uint8_t, tiledb::sm::MinAggregator<uint8_t>>,
std::pair<uint16_t, tiledb::sm::MinAggregator<uint16_t>>,
std::pair<uint32_t, tiledb::sm::MinAggregator<uint32_t>>,
std::pair<uint64_t, tiledb::sm::MinAggregator<uint64_t>>,
std::pair<int8_t, tiledb::sm::MinAggregator<int8_t>>,
std::pair<int16_t, tiledb::sm::MinAggregator<int16_t>>,
std::pair<int32_t, tiledb::sm::MinAggregator<int32_t>>,
std::pair<int64_t, tiledb::sm::MinAggregator<int64_t>>,
std::pair<float, tiledb::sm::MinAggregator<float>>,
std::pair<double, tiledb::sm::MinAggregator<double>>,
std::pair<std::string, tiledb::sm::MinAggregator<std::string>>,
std::pair<uint8_t, tiledb::sm::MaxAggregator<uint8_t>>,
std::pair<uint16_t, tiledb::sm::MaxAggregator<uint16_t>>,
std::pair<uint32_t, tiledb::sm::MaxAggregator<uint32_t>>,
std::pair<uint64_t, tiledb::sm::MaxAggregator<uint64_t>>,
std::pair<int8_t, tiledb::sm::MaxAggregator<int8_t>>,
std::pair<int16_t, tiledb::sm::MaxAggregator<int16_t>>,
std::pair<int32_t, tiledb::sm::MaxAggregator<int32_t>>,
std::pair<int64_t, tiledb::sm::MaxAggregator<int64_t>>,
std::pair<float, tiledb::sm::MaxAggregator<float>>,
std::pair<double, tiledb::sm::MaxAggregator<double>>,
std::pair<std::string, tiledb::sm::MaxAggregator<std::string>>>
std::pair<uint8_t, MinOperator>,
std::pair<uint16_t, MinOperator>,
std::pair<uint32_t, MinOperator>,
std::pair<uint64_t, MinOperator>,
std::pair<int8_t, MinOperator>,
std::pair<int16_t, MinOperator>,
std::pair<int32_t, MinOperator>,
std::pair<int64_t, MinOperator>,
std::pair<float, MinOperator>,
std::pair<double, MinOperator>,
std::pair<std::string, MinOperator>,
std::pair<uint8_t, MaxOperator>,
std::pair<uint16_t, MaxOperator>,
std::pair<uint32_t, MaxOperator>,
std::pair<uint64_t, MaxOperator>,
std::pair<int8_t, MaxOperator>,
std::pair<int16_t, MaxOperator>,
std::pair<int32_t, MaxOperator>,
std::pair<int64_t, MaxOperator>,
std::pair<float, MaxOperator>,
std::pair<double, MaxOperator>,
std::pair<std::string, MaxOperator>>
MinMaxFixedTypesUnderTest;
TEMPLATE_LIST_TEST_CASE(
"C++ API: Aggregates basic min/max",
Expand All @@ -1506,7 +1511,7 @@ TEMPLATE_LIST_TEST_CASE(
typedef decltype(TestType::first) T;
typedef decltype(TestType::second) AGG;
CppAggregatesFx<T> fx;
bool min = std::is_same<AGG, tiledb::sm::MinAggregator<T>>::value;
bool min = std::is_same<AGG, MinOperator>::value;
fx.generate_test_params();
fx.create_array_and_write_fragments();

Expand All @@ -1522,13 +1527,12 @@ TEMPLATE_LIST_TEST_CASE(
fx.layout_ = layout;
Query query(fx.ctx_, array, TILEDB_READ);

// TODO: Change to real CPPAPI. Add a count aggregator to the query.
uint64_t cell_val_num =
std::is_same<T, std::string>::value ? fx.STRING_CELL_VAL_NUM : 1;
query.ptr()->query_->add_aggregator_to_default_channel(
"MinMax",
std::make_shared<AGG>(tiledb::sm::FieldInfo(
"a1", false, fx.nullable_, cell_val_num)));
// Add a min/max aggregator to the query.
QueryChannel default_channel =
QueryExperimental::get_default_channel(query);
ChannelOperation operation =
QueryExperimental::create_unary_aggregate<AGG>(query, "a1");
default_channel.apply_aggregate("MinMax", operation);

fx.set_ranges_and_condition_if_needed(array, query, false);

Expand All @@ -1543,7 +1547,17 @@ TEMPLATE_LIST_TEST_CASE(
std::vector<uint8_t> a1(100 * cell_size);
std::vector<uint8_t> a1_validity(100);
query.set_layout(layout);
query.set_data_buffer("MinMax", min_max.data(), min_max.size());
if constexpr (std::is_same<T, std::string>::value) {
query.set_data_buffer(
"MinMax",
static_cast<char*>(static_cast<void*>(min_max.data())),
min_max.size());
} else {
query.set_data_buffer(
"MinMax",
static_cast<T*>(static_cast<void*>(min_max.data())),
min_max.size() / cell_size);
}
if (fx.nullable_) {
query.set_validity_buffer("MinMax", min_max_validity);
}
Expand Down Expand Up @@ -1597,7 +1611,9 @@ TEMPLATE_LIST_TEST_CASE(
}

auto result_el = query.result_buffer_elements_nullable();
CHECK(std::get<1>(result_el["MinMax"]) == min_max.size());
CHECK(
std::get<1>(result_el["MinMax"]) ==
(std::is_same<T, std::string>::value ? 2 : 1));
CHECK(min_max == expected_min_max);

if (request_data) {
Expand All @@ -1612,17 +1628,14 @@ TEMPLATE_LIST_TEST_CASE(
array.close();
}

typedef tuple<
tiledb::sm::MinAggregator<std::string>,
tiledb::sm::MaxAggregator<std::string>>
AggUnderTest;
typedef tuple<MinOperator, MaxOperator> AggUnderTest;
TEMPLATE_LIST_TEST_CASE(
"C++ API: Aggregates basic min/max var",
"[cppapi][aggregates][basic][min-max][var]",
AggUnderTest) {
CppAggregatesFx<std::string> fx;
typedef TestType AGG;
bool min = std::is_same<AGG, tiledb::sm::MinAggregator<std::string>>::value;
bool min = std::is_same<AGG, MinOperator>::value;
fx.generate_test_params();
fx.create_var_array_and_write_fragments();

Expand All @@ -1638,15 +1651,12 @@ TEMPLATE_LIST_TEST_CASE(
fx.layout_ = layout;
Query query(fx.ctx_, array, TILEDB_READ);

// TODO: Change to real CPPAPI. Add a count aggregator to the query.
query.ptr()->query_->add_aggregator_to_default_channel(
"MinMax",
std::make_shared<AGG>(tiledb::sm::FieldInfo(
"a1",
true,
fx.nullable_,
TILEDB_VAR_NUM,
tiledb::sm::Datatype::STRING_ASCII)));
// Add a min/max aggregator to the query.
QueryChannel default_channel =
QueryExperimental::get_default_channel(query);
ChannelOperation operation =
QueryExperimental::create_unary_aggregate<AGG>(query, "a1");
default_channel.apply_aggregate("MinMax", operation);

fx.set_ranges_and_condition_if_needed(array, query, true);

Expand Down Expand Up @@ -1776,7 +1786,8 @@ TEMPLATE_LIST_TEST_CASE_METHOD(
CppAggregatesFx<T>::layout_ = layout;
Query query(CppAggregatesFx<T>::ctx_, array, TILEDB_READ);

// TODO: Change to real CPPAPI. Add a count aggregator to the query.
// TODO: Change to real CPPAPI. Add a null count aggregator to the
// query.
uint64_t cell_val_num = std::is_same<T, std::string>::value ?
CppAggregatesFx<T>::STRING_CELL_VAL_NUM :
1;
Expand Down Expand Up @@ -1875,7 +1886,8 @@ TEST_CASE_METHOD(
layout_ = layout;
Query query(ctx_, array, TILEDB_READ);

// TODO: Change to real CPPAPI. Add a count aggregator to the query.
// TODO: Change to real CPPAPI. Add a null count aggregator to the
// query.
query.ptr()->query_->add_aggregator_to_default_channel(
"NullCount",
std::make_shared<tiledb::sm::NullCountAggregator>(
Expand Down Expand Up @@ -1970,7 +1982,8 @@ TEST_CASE_METHOD(
layout_ = layout;
Query query(ctx_, array, TILEDB_READ);

// TODO: Change to real CPPAPI. Add a count aggregator to the query.
// TODO: Change to real CPPAPI. Add a null count aggregator to the
// query.
query.ptr()->query_->add_aggregator_to_default_channel(
"NullCount",
std::make_shared<tiledb::sm::NullCountAggregator>(
Expand Down Expand Up @@ -2107,7 +2120,8 @@ TEST_CASE_METHOD(
layout_ = layout;
Query query(ctx_, array, TILEDB_READ);

// TODO: Change to real CPPAPI. Add a count aggregator to the query.
// TODO: Change to real CPPAPI. Add a null count aggregator to the
// query.
query.ptr()->query_->add_aggregator_to_default_channel(
"NullCount",
std::make_shared<tiledb::sm::NullCountAggregator>(
Expand Down Expand Up @@ -2184,18 +2198,14 @@ TEMPLATE_LIST_TEST_CASE_METHOD(
for (tiledb_layout_t layout : CppAggregatesFx<T>::layout_values_) {
CppAggregatesFx<T>::layout_ = layout;
Query query(CppAggregatesFx<T>::ctx_, array, TILEDB_READ);

// TODO: Change to real CPPAPI. Add a count aggregator to the query.
// We add both sum and count as they are processed separately in the
// dense case.
query.ptr()->query_->add_aggregator_to_default_channel(
"Count", std::make_shared<tiledb::sm::CountAggregator>());

query.ptr()->query_->add_aggregator_to_default_channel(
"Sum",
std::make_shared<tiledb::sm::SumAggregator<T>>(
tiledb::sm::FieldInfo(
"a1", false, CppAggregatesFx<T>::nullable_, 1)));
// Add a count aggregator to the query. We add both sum and count as
// they are processed separately in the dense case.
QueryChannel default_channel =
QueryExperimental::get_default_channel(query);
default_channel.apply_aggregate("Count", CountOperation());
ChannelOperation operation2 =
QueryExperimental::create_unary_aggregate<SumOperator>(query, "a1");
default_channel.apply_aggregate("Sum", operation2);

CppAggregatesFx<T>::set_ranges_and_condition_if_needed(
array, query, false);
Expand Down Expand Up @@ -2317,4 +2327,40 @@ TEMPLATE_LIST_TEST_CASE_METHOD(

// Close array.
array.close();
}

// Argument-validation test for the experimental C++ aggregates API.
//
// Exercises the error paths of QueryExperimental::create_unary_aggregate and
// QueryChannel::apply_aggregate:
//   - creating an aggregate on a field name that does not exist,
//   - applying two aggregates under the same output field name,
//   - calling either API after the query has been submitted.
TEST_CASE_METHOD(
CppAggregatesFx<int32_t>,
"CPP: Aggregates - Basic",
"[aggregates][cpp_api][args]") {
// Fixture settings chosen before the array is created.
// NOTE(review): presumably these select a sparse, non-nullable array without
// duplicate coordinates — confirm against the fixture definition.
dense_ = false;
nullable_ = false;
allow_dups_ = false;
create_array_and_write_fragments();

Array array{ctx_, ARRAY_NAME, TILEDB_READ};
Query query(ctx_, array);
query.set_layout(TILEDB_UNORDERED);

// This throws for an attribute that doesn't exist
CHECK_THROWS(QueryExperimental::create_unary_aggregate<SumOperator>(
query, "nonexistent"));

// Valid path: build a sum operation on "a1" and register it on the default
// channel under the output field name "Sum".
QueryChannel default_channel = QueryExperimental::get_default_channel(query);
ChannelOperation operation =
QueryExperimental::create_unary_aggregate<SumOperator>(query, "a1");
default_channel.apply_aggregate("Sum", operation);

// Duplicated output fields are not allowed: "Sum" was already registered
// above, so applying it a second time must throw.
CHECK_THROWS(default_channel.apply_aggregate("Sum", operation));

// Transition the query state: attach a one-element result buffer for the
// "Sum" output field and submit. REQUIRE (not CHECK) is used because the
// checks below are only meaningful once the submit has completed.
int64_t sum = 0;
query.set_data_buffer("Sum", &sum, 1);
REQUIRE(query.submit() == Query::Status::COMPLETE);

// Check api throws if the query state is already >= initialized: after the
// submit above, neither creating a new aggregate nor applying one under a
// fresh output name ("Something") is expected to succeed.
CHECK_THROWS(
QueryExperimental::create_unary_aggregate<SumOperator>(query, "a1"));
CHECK_THROWS(default_channel.apply_aggregate("Something", operation));
}
2 changes: 1 addition & 1 deletion test/support/src/vfs_helpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -675,10 +675,10 @@ struct TemporaryDirectoryFixture {
/** Name of the temporary directory to use for this test */
std::string temp_dir_;

private:
/** Virtual file system */
tiledb_vfs_t* vfs_;

private:
/** Vector of supported filesystems. Used to initialize ``vfs_``. */
const std::vector<std::unique_ptr<SupportedFs>> supported_filesystems_;
};
Expand Down
2 changes: 1 addition & 1 deletion tiledb/api/c_api/data_order/test/unit_capi_data_order.cc
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* @file tiledeb/api/c_api/data_order/test/unit_capi_data_order.cc
* @file tiledb/api/c_api/data_order/test/unit_capi_data_order.cc
*
* @section LICENSE
*
Expand Down
2 changes: 1 addition & 1 deletion tiledb/api/c_api/datatype/test/unit_capi_datatype.cc
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* @file tiledeb/api/c_api/datatype/test/unit_capi_datatype.cc
* @file tiledb/api/c_api/datatype/test/unit_capi_datatype.cc
*
* @section LICENSE
*
Expand Down
2 changes: 1 addition & 1 deletion tiledb/api/c_api/filesystem/test/unit_capi_filesystem.cc
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* @file tiledeb/api/c_api/filesystem/test/unit_capi_filesystem.cc
* @file tiledb/api/c_api/filesystem/test/unit_capi_filesystem.cc
*
* @section LICENSE
*
Expand Down
2 changes: 1 addition & 1 deletion tiledb/api/c_api/object/test/unit_capi_object_type.cc
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* @file tiledeb/api/c_api/object/test/unit_capi_object_type.cc
* @file tiledb/api/c_api/object/test/unit_capi_object_type.cc
*
* @section LICENSE
*
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* @file tiledeb/api/c_api/object/test/unit_capi_object.cc
* @file tiledb/api/c_api/object/test/unit_capi_object.cc
*
* @section LICENSE
*
Expand Down
Loading

0 comments on commit fa5ef51

Please sign in to comment.