Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix defects with query channels #4786

Merged
merged 4 commits into from
Mar 7, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions tiledb/api/c_api/query/query_api_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,8 @@ inline void ensure_query_is_valid(tiledb_query_t* query) {
*
* @param query A sm::Query pointer
*/
inline void ensure_query_is_not_initialized(const tiledb::sm::Query* query) {
if (query->status() != sm::QueryStatus::UNINITIALIZED) {
inline void ensure_query_is_not_initialized(const tiledb::sm::Query& query) {
if (query.status() != sm::QueryStatus::UNINITIALIZED) {
throw CAPIStatusException(
"argument `query` is at a too late state of its lifetime");
}
Expand All @@ -72,7 +72,8 @@ inline void ensure_query_is_not_initialized(const tiledb::sm::Query* query) {
*/
inline void ensure_query_is_not_initialized(tiledb_query_t* query) {
ensure_query_is_valid(query);
ensure_query_is_not_initialized(query->query_);
// Indirection safe because previous statement will throw otherwise
ensure_query_is_not_initialized(*query->query_);
}

} // namespace tiledb::api
Expand Down
9 changes: 3 additions & 6 deletions tiledb/api/c_api/query_aggregate/query_aggregate_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -161,11 +161,8 @@ capi_return_t tiledb_query_get_default_channel(
tiledb_ctx_t*, tiledb_query_t* query, tiledb_query_channel_t** channel) {
ensure_query_is_valid(query);
ensure_output_pointer_is_valid(channel);

// We don't have an internal representation of a channel,
// the default channel is currently just a hashmap, so only pass the query
// to the channel constructor to be carried until next the api call.
*channel = tiledb_query_channel_handle_t::make_handle(query);
*channel = tiledb_query_channel_handle_t::make_handle(
query->query_->actual_default_channel());

return TILEDB_OK;
}
Expand Down Expand Up @@ -209,7 +206,7 @@ capi_return_t tiledb_channel_apply_aggregate(
const char* output_field_name,
const tiledb_channel_operation_t* operation) {
ensure_query_channel_is_valid(channel);
ensure_query_is_not_initialized(channel->query_);
ensure_query_is_not_initialized(channel->query());
ensure_output_field_is_valid(output_field_name);
ensure_operation_is_valid(operation);
channel->add_aggregate(output_field_name, operation);
Expand Down
25 changes: 19 additions & 6 deletions tiledb/api/c_api/query_aggregate/query_aggregate_api_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
#include "tiledb/sm/c_api/tiledb_struct_def.h"
#include "tiledb/sm/query/query.h"
#include "tiledb/sm/query/readers/aggregators/operation.h"
#include "tiledb/sm/query/readers/aggregators/query_channel.h"

struct tiledb_channel_operation_handle_t
: public tiledb::api::CAPIHandle<tiledb_channel_operation_handle_t> {
Expand Down Expand Up @@ -71,16 +72,22 @@ struct tiledb_channel_operation_handle_t
}
};

/* Forward declaration */
namespace tiledb::sm {
class Query;
}

struct tiledb_query_channel_handle_t
: public tiledb::api::CAPIHandle<tiledb_query_channel_handle_t> {
/**
* Type name
*/
static constexpr std::string_view object_type_name{"tiledb_query_channel_t"};

public:
tiledb::sm::Query* query_;
private:
std::shared_ptr<class tiledb::sm::QueryChannelActual> channel_;

public:
/**
* Default constructor doesn't make sense
*/
Expand All @@ -90,24 +97,30 @@ struct tiledb_query_channel_handle_t
* Ordinary constructor.
* @param query The query object that owns the channel
*/
tiledb_query_channel_handle_t(tiledb_query_t* query)
: query_(query->query_) {
tiledb_query_channel_handle_t(
std::shared_ptr<class tiledb::sm::QueryChannelActual> channel)
: channel_(channel) {
}

inline void add_aggregate(
const char* output_field,
const tiledb_channel_operation_handle_t* operation) {
if (query_->is_aggregate(output_field)) {
auto& query{channel_->query()};
if (query.is_aggregate(output_field)) {
throw tiledb::api::CAPIStatusException(
"An aggregate operation for output field: " +
std::string(output_field) + " already exists.");
}

// Add the aggregator the the default channel as this is the only channel
// type we currently support
query_->add_aggregator_to_default_channel(
query.add_aggregator_to_default_channel(
output_field, operation->aggregator());
}

inline tiledb::sm::Query& query() {
return channel_->query();
}
};

struct tiledb_channel_operator_handle_t
Expand Down
14 changes: 12 additions & 2 deletions tiledb/api/c_api/query_field/query_field_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ tiledb_query_field_handle_t::tiledb_query_field_handle_t(
tiledb_query_t* query, const char* field_name)
: query_(query->query_)
, field_name_(field_name) {
bool is_aggregate{false};
if (field_name_ == tiledb::sm::constants::coords) {
field_origin_ = std::make_shared<FieldFromDimension>();
type_ = query_->array_schema().domain().dimension_ptr(0)->type();
Expand All @@ -72,6 +73,7 @@ tiledb_query_field_handle_t::tiledb_query_field_handle_t(
cell_val_num_ =
query_->array_schema().dimension_ptr(field_name_)->cell_val_num();
} else if (query_->is_aggregate(field_name_)) {
is_aggregate = true;
field_origin_ = std::make_shared<FieldFromAggregate>();
auto aggregate = query_->get_aggregate(field_name_).value();
type_ = aggregate->output_datatype();
Expand All @@ -80,8 +82,16 @@ tiledb_query_field_handle_t::tiledb_query_field_handle_t(
} else {
throw tiledb::api::CAPIStatusException("There is no field " + field_name_);
}

channel_ = tiledb_query_channel_handle_t::make_handle(query);
/*
* We have no `class QueryField` that would already know its own aggregate,
* so we mirror the channel selection process that `class Query` has
* responsibility for.
*/
if (is_aggregate) {
channel_ = query_->actual_aggegate_channel();
} else {
channel_ = query_->actual_default_channel();
}
}

namespace tiledb::api {
Expand Down
8 changes: 2 additions & 6 deletions tiledb/api/c_api/query_field/query_field_api_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ struct tiledb_query_field_handle_t
std::shared_ptr<FieldOrigin> field_origin_;
tiledb::sm::Datatype type_;
uint32_t cell_val_num_;
tiledb_query_channel_handle_t* channel_;
std::shared_ptr<tiledb::sm::QueryChannelActual> channel_;

public:
/**
Expand All @@ -87,10 +87,6 @@ struct tiledb_query_field_handle_t
*/
tiledb_query_field_handle_t(tiledb_query_t* query, const char* field_name);

~tiledb_query_field_handle_t() {
tiledb_query_channel_handle_t::break_handle(channel_);
}

tiledb_field_origin_t origin() {
return field_origin_->origin();
}
Expand All @@ -101,7 +97,7 @@ struct tiledb_query_field_handle_t
return cell_val_num_;
}
tiledb_query_channel_handle_t* channel() {
return channel_;
return tiledb_query_channel_handle_t::make_handle(channel_);
}
};

Expand Down
Loading
Loading