Skip to content

Commit

Permalink
Query aggregates REST support (#4415)
Browse files Browse the repository at this point in the history
Add REST support for query aggregates.

Notes to reviewers:
- At this stage we should pay special attention to the capnp
specification and make sure it's as generic as possible according to the
aggregates design document.
- I tested this manually with a local REST server, and the serialization
paths seem to work as expected. However, the feature as a whole doesn't
work yet because we need Go support on the REST side, which I'll file as
a follow-up. (When the REST server deserializes the query, the query
buffers are set to `nullptr` and their sizes are set to the sizes passed
by the user. The Go code is supposed to allocate memory and set the
buffers again before the query is submitted "locally". The tests work
because the serialization wrappers emulate this allocation with
some stack allocated buffers; see the changes in `helpers.cc`.)

---
TYPE: NO_HISTORY | FEATURE
DESC: Query aggregates REST support

---------

Co-authored-by: Luc Rancourt <[email protected]>
Co-authored-by: KiterLuc <[email protected]>
  • Loading branch information
3 people authored Oct 20, 2023
1 parent fa5ef51 commit 831b5b0
Show file tree
Hide file tree
Showing 31 changed files with 1,924 additions and 328 deletions.
28 changes: 24 additions & 4 deletions test/support/src/helpers.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1850,7 +1850,9 @@ int submit_query_wrapper(
// 4. Server -> Client : Send query response
std::vector<uint8_t> serialized2;
rc = serialize_query(server_ctx, server_deser_query, &serialized2, 0);
REQUIRE(rc == TILEDB_OK);
if (rc != TILEDB_OK) {
return rc;
}

if (!refactored_query_v2) {
// Close array and clean up
Expand Down Expand Up @@ -1950,13 +1952,31 @@ void allocate_query_buffers_server_side(
tiledb_query_t* query,
ServerQueryBuffers& query_buffers) {
int rc = 0;
const auto buffer_names = query->query_->buffer_names();
auto buffer_names = query->query_->buffer_names();
const auto aggregate_names = query->query_->aggregate_buffer_names();
buffer_names.insert(
buffer_names.end(), aggregate_names.begin(), aggregate_names.end());

for (uint64_t i = 0; i < buffer_names.size(); i++) {
const auto& name = buffer_names[i];
const auto& buff = query->query_->buffer(name);
const auto& schema = query->query_->array_schema();
auto var_size = schema.var_size(name);
auto nullable = schema.is_nullable(name);

// TODO: This is yet another instance where there needs to be a common
// mechanism for reporting the common properties of a field.
// Refactor to use query_field_t.
bool var_size = false;
bool nullable = false;
if (query->query_->is_aggregate(name)) {
var_size =
query->query_->get_aggregate(name).value()->aggregation_var_sized();
nullable =
query->query_->get_aggregate(name).value()->aggregation_nullable();
} else {
var_size = schema.var_size(name);
nullable = schema.is_nullable(name);
}

if (var_size && buff.buffer_var_ == nullptr) {
// Variable-sized buffer
query_buffers.attr_or_dim_data.emplace_back(*buff.buffer_var_size_);
Expand Down
4 changes: 4 additions & 0 deletions tiledb/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,7 @@ set(TILEDB_CORE_SOURCES
${TILEDB_CORE_INCLUDE_DIR}/tiledb/sm/query/readers/aggregators/count_aggregator.cc
${TILEDB_CORE_INCLUDE_DIR}/tiledb/sm/query/readers/aggregators/mean_aggregator.cc
${TILEDB_CORE_INCLUDE_DIR}/tiledb/sm/query/readers/aggregators/min_max_aggregator.cc
${TILEDB_CORE_INCLUDE_DIR}/tiledb/sm/query/readers/aggregators/operation.cc
${TILEDB_CORE_INCLUDE_DIR}/tiledb/sm/query/readers/aggregators/output_buffer_validator.cc
${TILEDB_CORE_INCLUDE_DIR}/tiledb/sm/query/readers/aggregators/safe_sum.cc
${TILEDB_CORE_INCLUDE_DIR}/tiledb/sm/query/readers/aggregators/sum_aggregator.cc
Expand Down Expand Up @@ -286,6 +287,7 @@ set(TILEDB_CORE_SOURCES
${TILEDB_CORE_INCLUDE_DIR}/tiledb/sm/serialization/fragments.cc
${TILEDB_CORE_INCLUDE_DIR}/tiledb/sm/serialization/group.cc
${TILEDB_CORE_INCLUDE_DIR}/tiledb/sm/serialization/query.cc
${TILEDB_CORE_INCLUDE_DIR}/tiledb/sm/serialization/query_aggregates.cc
${TILEDB_CORE_INCLUDE_DIR}/tiledb/sm/serialization/consolidation.cc
${TILEDB_CORE_INCLUDE_DIR}/tiledb/sm/serialization/vacuum.cc
${TILEDB_CORE_INCLUDE_DIR}/tiledb/sm/stats/global_stats.cc
Expand Down Expand Up @@ -338,6 +340,7 @@ if (TILEDB_SERIALIZATION)
${TILEDB_CORE_INCLUDE_DIR}/tiledb/sm/serialization/fragments.cc
${TILEDB_CORE_INCLUDE_DIR}/tiledb/sm/serialization/group.cc
${TILEDB_CORE_INCLUDE_DIR}/tiledb/sm/serialization/query.cc
${TILEDB_CORE_INCLUDE_DIR}/tiledb/sm/serialization/query_aggregates.cc
${TILEDB_CORE_INCLUDE_DIR}/tiledb/sm/serialization/consolidation.cc
${TILEDB_CORE_INCLUDE_DIR}/tiledb/sm/serialization/vacuum.cc
${TILEDB_CORE_INCLUDE_DIR}/tiledb/sm/serialization/tiledb-rest.capnp.c++
Expand All @@ -359,6 +362,7 @@ if (TILEDB_SERIALIZATION)
${TILEDB_CORE_INCLUDE_DIR}/tiledb/sm/serialization/fragments.cc
${TILEDB_CORE_INCLUDE_DIR}/tiledb/sm/serialization/group.cc
${TILEDB_CORE_INCLUDE_DIR}/tiledb/sm/serialization/query.cc
${TILEDB_CORE_INCLUDE_DIR}/tiledb/sm/serialization/query_aggregates.cc
${TILEDB_CORE_INCLUDE_DIR}/tiledb/sm/serialization/consolidation.cc
${TILEDB_CORE_INCLUDE_DIR}/tiledb/sm/serialization/vacuum.cc
${TILEDB_CORE_INCLUDE_DIR}/tiledb/sm/serialization/tiledb-rest.capnp.c++
Expand Down
106 changes: 11 additions & 95 deletions tiledb/api/c_api/query_aggregate/query_aggregate_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,41 +36,21 @@
#include "tiledb/api/c_api/query/query_api_internal.h"
#include "tiledb/api/c_api/query_field/query_field_api_external_experimental.h"
#include "tiledb/api/c_api_support/c_api_support.h"
#include "tiledb/type/apply_with_type.h"

//
const tiledb_channel_operator_handle_t* tiledb_channel_operator_sum =
tiledb_channel_operator_handle_t::make_handle(
TILEDB_QUERY_CHANNEL_OPERATOR_SUM, "SUM");
tiledb::sm::constants::aggregate_sum_str);
const tiledb_channel_operator_handle_t* tiledb_channel_operator_min =
tiledb_channel_operator_handle_t::make_handle(
TILEDB_QUERY_CHANNEL_OPERATOR_MIN, "MIN");
tiledb::sm::constants::aggregate_min_str);
const tiledb_channel_operator_handle_t* tiledb_channel_operator_max =
tiledb_channel_operator_handle_t::make_handle(
TILEDB_QUERY_CHANNEL_OPERATOR_MAX, "MAX");
tiledb::sm::constants::aggregate_max_str);

const tiledb_channel_operation_handle_t* tiledb_aggregate_count =
tiledb_channel_operation_handle_t::make_handle(
std::make_shared<CountOperation>());

/**
 * Builds the aggregate operation that corresponds to this operator handle.
 *
 * @param fi Information about the field the aggregate is requested on.
 * @return A newly created SUM, MIN or MAX operation.
 * @throws tiledb::api::CAPIStatusException if the stored operator value is
 *     not one of SUM, MIN or MAX.
 */
shared_ptr<Operation> tiledb_channel_operator_handle_t::make_operation(
    const tiledb::sm::FieldInfo& fi) const {
  const auto op_value = this->value();

  if (op_value == TILEDB_QUERY_CHANNEL_OPERATOR_SUM) {
    return std::make_shared<SumOperation>(fi, this);
  }
  if (op_value == TILEDB_QUERY_CHANNEL_OPERATOR_MIN) {
    return std::make_shared<MinOperation>(fi, this);
  }
  if (op_value == TILEDB_QUERY_CHANNEL_OPERATOR_MAX) {
    return std::make_shared<MaxOperation>(fi, this);
  }

  // Any other value has no operation class associated with it here.
  throw tiledb::api::CAPIStatusException(
      "operator has unsupported value: " +
      std::to_string(static_cast<uint8_t>(this->value())));
}
std::make_shared<tiledb::sm::CountOperation>());

namespace tiledb::api {

Expand Down Expand Up @@ -260,6 +240,12 @@ capi_return_t tiledb_query_channel_free(

} // namespace tiledb::api

// Builds the operation for this operator by delegating to
// tiledb::sm::Operation::make_operation, which receives this operator's name
// and the field information `fi`.
shared_ptr<tiledb::sm::Operation>
tiledb_channel_operator_handle_t::make_operation(
const tiledb::sm::FieldInfo& fi) const {
return tiledb::sm::Operation::make_operation(this->name(), fi);
}

using tiledb::api::api_entry_with_context;

capi_return_t tiledb_channel_operator_sum_get(
Expand Down Expand Up @@ -324,73 +310,3 @@ capi_return_t tiledb_query_channel_free(
return tiledb::api::api_entry_with_context<
tiledb::api::tiledb_query_channel_free>(ctx, channel);
}

/**
 * Creates a MAX operation over the field described by `fi`.
 *
 * The field datatype `fi.type_` is dispatched through `apply_with_type`, and
 * the aggregator is instantiated for the resulting C++ type.
 *
 * @param fi Field information; also forwarded to the aggregator.
 * @param op Operator handle, forwarded to the numeric-field validation.
 * @throws std::logic_error if the dispatched type is not a TileDB
 *     fundamental type.
 */
MaxOperation::MaxOperation(
    const tiledb::sm::FieldInfo& fi,
    const tiledb_channel_operator_handle_t* op) {
  auto build_aggregator = [&](auto tag) {
    using FieldType = decltype(tag);

    if constexpr (!tiledb::type::TileDBFundamental<FieldType>) {
      throw std::logic_error(
          "MAX aggregates can only be requested on numeric and string "
          "types");
    } else {
      // Numeric fields are validated before an aggregator is created.
      if constexpr (tiledb::type::TileDBNumeric<FieldType>) {
        ensure_aggregate_numeric_field(op, fi);
      }

      // `char` stands for string data here, so the aggregator is built for
      // std::string. To be refactored once the (STRING_ASCII,CHAR) mapping
      // in apply_with_type changes.
      if constexpr (std::is_same_v<char, FieldType>) {
        aggregator_ =
            std::make_shared<tiledb::sm::MaxAggregator<std::string>>(fi);
      } else {
        aggregator_ =
            std::make_shared<tiledb::sm::MaxAggregator<FieldType>>(fi);
      }
    }
  };

  apply_with_type(build_aggregator, fi.type_);
}

/**
 * Creates a MIN operation over the field described by `fi`.
 *
 * The field datatype `fi.type_` is dispatched through `apply_with_type`, and
 * the aggregator is instantiated for the resulting C++ type.
 *
 * @param fi Field information; also forwarded to the aggregator.
 * @param op Operator handle, forwarded to the numeric-field validation.
 * @throws std::logic_error if the dispatched type is not a TileDB
 *     fundamental type.
 */
MinOperation::MinOperation(
    const tiledb::sm::FieldInfo& fi,
    const tiledb_channel_operator_handle_t* op) {
  auto build_aggregator = [&](auto tag) {
    using FieldType = decltype(tag);

    if constexpr (!tiledb::type::TileDBFundamental<FieldType>) {
      throw std::logic_error(
          "MIN aggregates can only be requested on numeric and string "
          "types");
    } else {
      // Numeric fields are validated before an aggregator is created.
      if constexpr (tiledb::type::TileDBNumeric<FieldType>) {
        ensure_aggregate_numeric_field(op, fi);
      }

      // `char` stands for string data here, so the aggregator is built for
      // std::string. To be refactored once the (STRING_ASCII,CHAR) mapping
      // in apply_with_type changes.
      if constexpr (std::is_same_v<char, FieldType>) {
        aggregator_ =
            std::make_shared<tiledb::sm::MinAggregator<std::string>>(fi);
      } else {
        aggregator_ =
            std::make_shared<tiledb::sm::MinAggregator<FieldType>>(fi);
      }
    }
  };

  apply_with_type(build_aggregator, fi.type_);
}

/**
 * Creates a SUM operation over the field described by `fi`.
 *
 * The field datatype `fi.type_` is dispatched through `apply_with_type`, and
 * the aggregator is instantiated for the resulting C++ type.
 *
 * @param fi Field information; also forwarded to the aggregator.
 * @param op Operator handle, forwarded to the numeric-field validation.
 * @throws std::logic_error if the dispatched type is not a TileDB numeric
 *     type.
 */
SumOperation::SumOperation(
    const tiledb::sm::FieldInfo& fi,
    const tiledb_channel_operator_handle_t* op) {
  auto build_aggregator = [&](auto tag) {
    using FieldType = decltype(tag);

    if constexpr (!tiledb::type::TileDBNumeric<FieldType>) {
      throw std::logic_error(
          "SUM aggregates can only be requested on numeric types");
    } else {
      // Validate the field before the aggregator is created.
      ensure_aggregate_numeric_field(op, fi);
      aggregator_ =
          std::make_shared<tiledb::sm::SumAggregator<FieldType>>(fi);
    }
  };

  apply_with_type(build_aggregator, fi.type_);
}
90 changes: 6 additions & 84 deletions tiledb/api/c_api/query_aggregate/query_aggregate_api_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,64 +38,7 @@
#include "tiledb/api/c_api_support/handle/handle.h"
#include "tiledb/sm/c_api/tiledb_struct_def.h"
#include "tiledb/sm/query/query.h"
#include "tiledb/sm/query/readers/aggregators/count_aggregator.h"
#include "tiledb/sm/query/readers/aggregators/min_max_aggregator.h"
#include "tiledb/sm/query/readers/aggregators/sum_aggregator.h"

// Identifies an aggregate operator (COUNT, SUM, MIN or MAX).
// Enumerators are numbered consecutively, starting with COUNT at 0.
enum QueryChannelOperator {
TILEDB_QUERY_CHANNEL_OPERATOR_COUNT = 0,
TILEDB_QUERY_CHANNEL_OPERATOR_SUM,
TILEDB_QUERY_CHANNEL_OPERATOR_MIN,
TILEDB_QUERY_CHANNEL_OPERATOR_MAX
};

/**
 * Base class for aggregate operations.
 *
 * Holds a shared pointer to an aggregator, which is null until a derived
 * class assigns it, and exposes it through a virtual accessor that derived
 * classes may override.
 */
class Operation {
 public:
  /** Virtual destructor so derived operations can be destroyed via the base. */
  virtual ~Operation() = default;

  /** Returns the aggregator currently held by this operation. */
  [[nodiscard]] virtual shared_ptr<tiledb::sm::IAggregator> aggregator() const {
    return aggregator_;
  }

 protected:
  /** The aggregator backing this operation; null by default. */
  shared_ptr<tiledb::sm::IAggregator> aggregator_ = nullptr;
};

// MIN aggregate operation. It cannot be default constructed; it is built
// from the field information and the operator handle (constructor defined
// out of line).
class MinOperation : public Operation {
public:
MinOperation() = delete;
// fi: information about the field to aggregate; op: the operator handle.
explicit MinOperation(
const tiledb::sm::FieldInfo& fi,
const tiledb_channel_operator_handle_t* op);
};

// MAX aggregate operation. It cannot be default constructed; it is built
// from the field information and the operator handle (constructor defined
// out of line).
class MaxOperation : public Operation {
public:
MaxOperation() = delete;
// fi: information about the field to aggregate; op: the operator handle.
explicit MaxOperation(
const tiledb::sm::FieldInfo& fi,
const tiledb_channel_operator_handle_t* op);
};

// SUM aggregate operation. It cannot be default constructed; it is built
// from the field information and the operator handle (constructor defined
// out of line).
class SumOperation : public Operation {
public:
SumOperation() = delete;
// fi: information about the field to aggregate; op: the operator handle.
explicit SumOperation(
const tiledb::sm::FieldInfo& fi,
const tiledb_channel_operator_handle_t* op);
};

// COUNT aggregate operation. Unlike the other operations it needs no field
// information, so it is default constructible.
class CountOperation : public Operation {
public:
CountOperation() = default;

// For count operations we have a constant handle, create the aggregator when
// requested so that we get a different object for each query.
// Overrides the base accessor: every call returns a freshly created
// CountAggregator instead of the stored `aggregator_` member.
[[nodiscard]] shared_ptr<tiledb::sm::IAggregator> aggregator()
const override {
return std::make_shared<tiledb::sm::CountAggregator>();
}
};
#include "tiledb/sm/query/readers/aggregators/operation.h"

struct tiledb_channel_operation_handle_t
: public tiledb::api::CAPIHandle<tiledb_channel_operation_handle_t> {
Expand All @@ -106,7 +49,7 @@ struct tiledb_channel_operation_handle_t
"tiledb_channel_operation_t"};

private:
std::shared_ptr<Operation> operation_;
std::shared_ptr<tiledb::sm::Operation> operation_;

public:
/**
Expand All @@ -119,7 +62,7 @@ struct tiledb_channel_operation_handle_t
* @param operation An internal operation object
*/
explicit tiledb_channel_operation_handle_t(
const shared_ptr<Operation>& operation)
const shared_ptr<tiledb::sm::Operation>& operation)
: operation_{operation} {
}

Expand Down Expand Up @@ -176,7 +119,6 @@ struct tiledb_channel_operator_handle_t
"tiledb_channel_operator_handle_t"};

private:
QueryChannelOperator value_;
std::string name_;

public:
Expand All @@ -190,36 +132,16 @@ struct tiledb_channel_operator_handle_t
* @param op An enum specifying the type of operator
* @param name A string representation of the operator
*/
explicit tiledb_channel_operator_handle_t(
QueryChannelOperator op, const std::string& name)
: value_{op}
, name_{name} {
}

[[nodiscard]] inline QueryChannelOperator value() const {
return value_;
explicit tiledb_channel_operator_handle_t(const std::string& name)
: name_{name} {
}

[[nodiscard]] inline std::string name() const {
return name_;
}

[[nodiscard]] std::shared_ptr<Operation> make_operation(
[[nodiscard]] std::shared_ptr<tiledb::sm::Operation> make_operation(
const tiledb::sm::FieldInfo& fi) const;
};

/**
 * Checks that a field can be used with a numeric aggregate.
 *
 * @param op Operator handle; its name is used as the prefix of the error
 *     message.
 * @param fi Information about the field to validate.
 * @throws tiledb::api::CAPIStatusException if the field is var sized, or if
 *     its cell_val_num is different from 1.
 */
inline void ensure_aggregate_numeric_field(
    const tiledb_channel_operator_t* op, const tiledb::sm::FieldInfo& fi) {
  // The var-sized check runs first, so it decides which error is reported
  // when both conditions fail.
  if (fi.var_sized_) {
    const std::string reason{
        " aggregates are not supported for var sized attributes."};
    throw tiledb::api::CAPIStatusException(op->name() + reason);
  }

  const bool single_value_cells = (fi.cell_val_num_ == 1);
  if (!single_value_cells) {
    const std::string reason{
        " aggregates are not supported for attributes with cell_val_num "
        "greater than one."};
    throw tiledb::api::CAPIStatusException(op->name() + reason);
  }
}

#endif // TILEDB_CAPI_QUERY_AGGREGATE_INTERNAL_H
7 changes: 6 additions & 1 deletion tiledb/api/c_api/query_aggregate/test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,12 @@ commence(unit_test capi_query_aggregate)
this_target_sources(unit_capi_query_aggregate.cc)
this_target_object_libraries(capi_query_aggregate)
if (NOT MSVC)
target_compile_options(unit_capi_query_aggregate PRIVATE -Wno-deprecated-declarations)
target_compile_options(unit_capi_query_aggregate PRIVATE -Wno-deprecated-declarations)
endif()
this_target_link_libraries(tiledb_test_support_lib)

if (TILEDB_SERIALIZATION)
add_definitions(-DTILEDB_SERIALIZATION)
endif()

conclude(unit_test)
Loading

0 comments on commit 831b5b0

Please sign in to comment.