Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move counts computation to aggregate_with_count. #4401

Merged
merged 3 commits into from
Oct 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion test/src/test-cppapi-aggregates.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
#include "tiledb/sm/query/readers/aggregators/count_aggregator.h"
#include "tiledb/sm/query/readers/aggregators/mean_aggregator.h"
#include "tiledb/sm/query/readers/aggregators/min_max_aggregator.h"
#include "tiledb/sm/query/readers/aggregators/null_count_aggregator.h"
#include "tiledb/sm/query/readers/aggregators/sum_aggregator.h"

#include <test/support/tdb_catch.h>
Expand Down
1 change: 0 additions & 1 deletion tiledb/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,6 @@ set(TILEDB_CORE_SOURCES
${TILEDB_CORE_INCLUDE_DIR}/tiledb/sm/query/readers/aggregators/count_aggregator.cc
${TILEDB_CORE_INCLUDE_DIR}/tiledb/sm/query/readers/aggregators/mean_aggregator.cc
${TILEDB_CORE_INCLUDE_DIR}/tiledb/sm/query/readers/aggregators/min_max_aggregator.cc
${TILEDB_CORE_INCLUDE_DIR}/tiledb/sm/query/readers/aggregators/null_count_aggregator.cc
${TILEDB_CORE_INCLUDE_DIR}/tiledb/sm/query/readers/aggregators/output_buffer_validator.cc
${TILEDB_CORE_INCLUDE_DIR}/tiledb/sm/query/readers/aggregators/safe_sum.cc
${TILEDB_CORE_INCLUDE_DIR}/tiledb/sm/query/readers/aggregators/sum_aggregator.cc
Expand Down
2 changes: 1 addition & 1 deletion tiledb/sm/query/readers/aggregators/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ include(object_library)
# `aggregators` object library
#
commence(object_library aggregators)
this_target_sources(count_aggregator.cc mean_aggregator.cc min_max_aggregator.cc null_count_aggregator.cc output_buffer_validator.cc safe_sum.cc sum_aggregator.cc)
this_target_sources(count_aggregator.cc mean_aggregator.cc min_max_aggregator.cc output_buffer_validator.cc safe_sum.cc sum_aggregator.cc)
this_target_object_libraries(baseline array_schema)
conclude(object_library)

Expand Down
34 changes: 23 additions & 11 deletions tiledb/sm/query/readers/aggregators/aggregate_buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,11 +81,6 @@ class AggregateBuffer {
/* API */
/* ********************************* */

/**
 * Accessor for the validity buffer.
 *
 * @return Pointer obtained from `validity_data_.value()`.
 */
uint8_t* validity_data() const {
  uint8_t* const validity = validity_data_.value();
  return validity;
}

/** Returns if the bitmap is a count bitmap. */
bool is_count_bitmap() const {
return count_bitmap_;
Expand All @@ -96,12 +91,6 @@ class AggregateBuffer {
return bitmap_data_.has_value();
}

/**
 * Returns the bitmap data, typed as a pointer to `BitmapType`.
 *
 * @tparam BitmapType Element type the bitmap buffer is cast to.
 *
 * @return `bitmap_data_.value()` cast to `BitmapType*`.
 */
template <class BitmapType>
BitmapType* bitmap_data_as() const {
  auto* const typed_bitmap = static_cast<BitmapType*>(bitmap_data_.value());
  return typed_bitmap;
}

/** Returns the min cell position to aggregate. */
uint64_t min_cell() const {
return min_cell_;
Expand Down Expand Up @@ -141,6 +130,29 @@ class AggregateBuffer {
}
}

/**
 * Reads the validity value of a single cell.
 *
 * @param cell_idx Index of the cell to look up in the validity buffer.
 *
 * @return Validity value stored at `cell_idx`.
 */
inline uint8_t validity_at(const uint64_t cell_idx) const {
  const auto validity = validity_data_.value();
  return validity[cell_idx];
}

/**
 * Reads the bitmap value of a single cell.
 *
 * @tparam BitmapType Element type the bitmap buffer is interpreted as.
 * @param cell_idx Index of the cell to look up in the bitmap buffer.
 *
 * @return Bitmap value stored at `cell_idx`, returned by value.
 */
template <class BitmapType>
inline BitmapType bitmap_at(const uint64_t cell_idx) const {
  const auto* const typed_bitmap =
      static_cast<BitmapType*>(bitmap_data_.value());
  return typed_bitmap[cell_idx];
}

private:
/* ********************************* */
/* PRIVATE ATTRIBUTES */
Expand Down
54 changes: 34 additions & 20 deletions tiledb/sm/query/readers/aggregators/aggregate_with_count.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@

#include "tiledb/sm/query/readers/aggregators/aggregate_buffer.h"
#include "tiledb/sm/query/readers/aggregators/field_info.h"
#include "tiledb/sm/query/readers/aggregators/no_op.h"

namespace tiledb::sm {

Expand All @@ -53,7 +54,7 @@ struct type_data<std::string> {
typedef std::string_view value_type;
};

template <typename T>
template <typename T, typename AGG_T, class AggPolicy, class ValidityPolicy>
class AggregateWithCount {
public:
/* ********************************* */
Expand All @@ -71,19 +72,17 @@ class AggregateWithCount {
/**
* Aggregate the input data.
*
* @tparam AGG_T Aggregate value type.
* @tparam BITMAP_T Bitmap type.
* @tparam AggPolicy Aggregation policy.
* @param input_data Input data for the aggregation.
*
* NOTE: Count of cells returned is used to infer the validity from the
* caller.
* @return {Aggregate value, count of cells}.
*/
template <typename AGG_T, typename BITMAP_T, class AggPolicy>
template <typename BITMAP_T>
tuple<AGG_T, uint64_t> aggregate(AggregateBuffer& input_data) {
typedef typename type_data<T>::value_type VALUE_T;
AggPolicy agg_policy;
ValidityPolicy val_policy;
AGG_T res;
if constexpr (std::is_same<AGG_T, std::string_view>::value) {
res = "";
Expand All @@ -96,17 +95,14 @@ class AggregateWithCount {
// nullable. The bitmap tells us which cells were already filtered out by
// ranges or query conditions.
if (input_data.has_bitmap()) {
auto bitmap_values = input_data.bitmap_data_as<BITMAP_T>();

if (field_info_.is_nullable_) {
auto validity_values = input_data.validity_data();

// Process for nullable values with bitmap.
for (uint64_t c = input_data.min_cell(); c < input_data.max_cell();
c++) {
if (validity_values[c] != 0 && bitmap_values[c] != 0) {
AGG_T value = input_data.value_at<VALUE_T>(c);
for (BITMAP_T i = 0; i < bitmap_values[c]; i++) {
auto bitmap_val = input_data.bitmap_at<BITMAP_T>(c);
if (val_policy.op(input_data.validity_at(c)) && bitmap_val != 0) {
auto value = value_at(input_data, c);
for (BITMAP_T i = 0; i < bitmap_val; i++) {
agg_policy.op(value, res, count);
count++;
}
Expand All @@ -116,23 +112,21 @@ class AggregateWithCount {
// Process for non nullable values with bitmap.
for (uint64_t c = input_data.min_cell(); c < input_data.max_cell();
c++) {
AGG_T value = input_data.value_at<VALUE_T>(c);

for (BITMAP_T i = 0; i < bitmap_values[c]; i++) {
auto bitmap_val = input_data.bitmap_at<BITMAP_T>(c);
auto value = value_at(input_data, c);
for (BITMAP_T i = 0; i < bitmap_val; i++) {
agg_policy.op(value, res, count);
count++;
}
}
}
} else {
if (field_info_.is_nullable_) {
auto validity_values = input_data.validity_data();

// Process for nullable values with no bitmap.
for (uint64_t c = input_data.min_cell(); c < input_data.max_cell();
c++) {
if (validity_values[c] != 0) {
AGG_T value = input_data.value_at<VALUE_T>(c);
if (val_policy.op(input_data.validity_at(c))) {
auto value = value_at(input_data, c);
agg_policy.op(value, res, count);
count++;
}
Expand All @@ -141,7 +135,7 @@ class AggregateWithCount {
// Process for non nullable values with no bitmap.
for (uint64_t c = input_data.min_cell(); c < input_data.max_cell();
c++) {
AGG_T value = input_data.value_at<VALUE_T>(c);
auto value = value_at(input_data, c);
agg_policy.op(value, res, count);
count++;
}
Expand All @@ -158,6 +152,26 @@ class AggregateWithCount {

/** Field information. */
const FieldInfo field_info_;

/* ********************************* */
/* PRIVATE METHODS */
/* ********************************* */

/**
 * Fetches the value of a cell, skipping the read when the aggregation policy
 * is `NoOp`.
 *
 * @param input_data Input data.
 * @param c Cell index.
 * @return The value read from `input_data` at cell `c`, or a
 *     value-initialized `AGG_T` when `AggPolicy` is `NoOp`.
 */
inline AGG_T value_at(AggregateBuffer& input_data, uint64_t c) {
  using VALUE_T = typename type_data<T>::value_type;
  if constexpr (std::is_same_v<AggPolicy, NoOp>) {
    // The NoOp policy never reads the value, so the buffer is not touched.
    return AGG_T();
  } else {
    return input_data.value_at<VALUE_T>(c);
  }
}
};

} // namespace tiledb::sm
Expand Down
56 changes: 31 additions & 25 deletions tiledb/sm/query/readers/aggregators/count_aggregator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,15 @@ class CountAggregatorStatusException : public StatusException {
}
};

CountAggregator::CountAggregator()
: OutputBufferValidator(FieldInfo("", false, false, 1))
template <class ValidityPolicy>
CountAggregatorBase<ValidityPolicy>::CountAggregatorBase(FieldInfo field_info)
: OutputBufferValidator(field_info)
, aggregate_with_count_(field_info)
, count_(0) {
}

void CountAggregator::validate_output_buffer(
template <class ValidityPolicy>
void CountAggregatorBase<ValidityPolicy>::validate_output_buffer(
std::string output_field_name,
std::unordered_map<std::string, QueryBuffer>& buffers) {
if (buffers.count(output_field_name) == 0) {
Expand All @@ -59,33 +62,22 @@ void CountAggregator::validate_output_buffer(
ensure_output_buffer_count(buffers[output_field_name]);
}

template <class BitmapType>
uint64_t count_cells(AggregateBuffer& input_data) {
uint64_t ret = 0;
template <class ValidityPolicy>
void CountAggregatorBase<ValidityPolicy>::aggregate_data(
AggregateBuffer& input_data) {
tuple<uint8_t, uint64_t> res;

auto bitmap_data = input_data.bitmap_data_as<BitmapType>();
for (uint64_t c = input_data.min_cell(); c < input_data.max_cell(); c++) {
ret += bitmap_data[c];
}

return ret;
}

void CountAggregator::aggregate_data(AggregateBuffer& input_data) {
// Run different loops for bitmap versus no bitmap. The bitmap tells us which
// cells were already filtered out by ranges or query conditions.
if (input_data.has_bitmap()) {
if (input_data.is_count_bitmap()) {
count_ += count_cells<uint64_t>(input_data);
} else {
count_ += count_cells<uint8_t>(input_data);
}
if (input_data.is_count_bitmap()) {
res = aggregate_with_count_.template aggregate<uint64_t>(input_data);
} else {
count_ += input_data.max_cell() - input_data.min_cell();
res = aggregate_with_count_.template aggregate<uint8_t>(input_data);
}

count_ += std::get<1>(res);
}

void CountAggregator::copy_to_user_buffer(
template <class ValidityPolicy>
void CountAggregatorBase<ValidityPolicy>::copy_to_user_buffer(
std::string output_field_name,
std::unordered_map<std::string, QueryBuffer>& buffers) {
auto& result_buffer = buffers[output_field_name];
Expand All @@ -96,4 +88,18 @@ void CountAggregator::copy_to_user_buffer(
}
}

NullCountAggregator::NullCountAggregator(FieldInfo field_info)
: CountAggregatorBase(field_info)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The architecture here is slightly off: NullCountAggregator should be the derived class that has a constructor based on FieldInfo (which also performs the checks), while CountAggregatorBase and CountAggregator should have default constructors. This will avoid the weird FieldInfo("","",false), which tries to replace a default-constructed FieldInfo that we explicitly deleted.

, field_info_(field_info) {
if (!field_info_.is_nullable_) {
throw CountAggregatorStatusException(
"NullCount aggregates must only be requested for nullable "
"attributes.");
}
}

// Explicit template instantiations
template CountAggregatorBase<Null>::CountAggregatorBase(FieldInfo);
template CountAggregatorBase<NonNull>::CountAggregatorBase(FieldInfo);

} // namespace tiledb::sm
67 changes: 58 additions & 9 deletions tiledb/sm/query/readers/aggregators/count_aggregator.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,33 +33,32 @@
#ifndef TILEDB_COUNT_AGGREGATOR_H
#define TILEDB_COUNT_AGGREGATOR_H

#include "tiledb/sm/query/readers/aggregators/aggregate_with_count.h"
#include "tiledb/sm/query/readers/aggregators/iaggregator.h"
#include "tiledb/sm/query/readers/aggregators/no_op.h"
#include "tiledb/sm/query/readers/aggregators/validity_policies.h"

namespace tiledb::sm {

class QueryBuffer;

class CountAggregator : public OutputBufferValidator, public IAggregator {
template <class ValidityPolicy>
class CountAggregatorBase : public OutputBufferValidator, public IAggregator {
public:
/* ********************************* */
/* CONSTRUCTORS & DESTRUCTORS */
/* ********************************* */

/** Constructor. */
CountAggregator();
CountAggregatorBase(FieldInfo field_info = FieldInfo());

DISABLE_COPY_AND_COPY_ASSIGN(CountAggregator);
DISABLE_MOVE_AND_MOVE_ASSIGN(CountAggregator);
DISABLE_COPY_AND_COPY_ASSIGN(CountAggregatorBase);
DISABLE_MOVE_AND_MOVE_ASSIGN(CountAggregatorBase);

/* ********************************* */
/* API */
/* ********************************* */

/** Returns the field name for the aggregator. */
std::string field_name() override {
return constants::count_of_rows;
}

/** Returns if the aggregation is var sized or not. */
bool var_sized() override {
return false;
Expand Down Expand Up @@ -102,10 +101,60 @@ class CountAggregator : public OutputBufferValidator, public IAggregator {
/* PRIVATE ATTRIBUTES */
/* ********************************* */

/** AggregateWithCount used to count the cells of AggregateBuffer data. */
AggregateWithCount<uint8_t, uint64_t, NoOp, ValidityPolicy>
aggregate_with_count_;

/** Cell count. */
std::atomic<uint64_t> count_;
};

/**
 * Count aggregator built on `CountAggregatorBase` with the `NonNull` validity
 * policy. Its field name is `constants::count_of_rows`.
 */
class CountAggregator : public CountAggregatorBase<NonNull> {
 public:
  /* ********************************* */
  /*     CONSTRUCTORS & DESTRUCTORS    */
  /* ********************************* */

  /** Constructor. Builds the base with its defaulted `FieldInfo` argument. */
  CountAggregator()
      : CountAggregatorBase<NonNull>() {
  }

  /* ********************************* */
  /*                API                */
  /* ********************************* */

  /** Field name reported by this aggregator. */
  std::string field_name() override {
    return constants::count_of_rows;
  }
};

/**
 * Count aggregator built on `CountAggregatorBase` with the `Null` validity
 * policy. It keeps its own copy of the field information and reports the
 * field's name as the aggregator field name.
 */
class NullCountAggregator : public CountAggregatorBase<Null> {
 public:
  /* ********************************* */
  /*     CONSTRUCTORS & DESTRUCTORS    */
  /* ********************************* */

  /**
   * Constructor.
   *
   * @param field_info Field information for the aggregated field.
   */
  NullCountAggregator(FieldInfo field_info);

  /* ********************************* */
  /*                API                */
  /* ********************************* */

  /** Field name reported by this aggregator: the name stored in the field
   * information. */
  std::string field_name() override {
    return field_info_.name_;
  }

 private:
  /* ********************************* */
  /*         PRIVATE ATTRIBUTES        */
  /* ********************************* */

  /** Field information for the aggregated field. */
  const FieldInfo field_info_;
};

} // namespace tiledb::sm

#endif // TILEDB_COUNT_AGGREGATOR_H
Loading
Loading