Skip to content

Commit

Permalink
Move counts computation to aggregate_with_count. (#4401)
Browse files Browse the repository at this point in the history
* Move counts computation to aggregate_with_count.

This moves the computation for the count aggregators to aggregate_with_count. It also merges the unit test with the existing ones for other aggregators.

TYPE: IMPROVEMENT
DESC: Move counts computation to aggregate_with_count.

* Adding inline.

* Address feedback from @robertbindar.
  • Loading branch information
KiterLuc authored Oct 13, 2023
1 parent 916b5a0 commit eee582d
Show file tree
Hide file tree
Showing 25 changed files with 664 additions and 1,107 deletions.
1 change: 0 additions & 1 deletion test/src/test-cppapi-aggregates.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
#include "tiledb/sm/query/readers/aggregators/count_aggregator.h"
#include "tiledb/sm/query/readers/aggregators/mean_aggregator.h"
#include "tiledb/sm/query/readers/aggregators/min_max_aggregator.h"
#include "tiledb/sm/query/readers/aggregators/null_count_aggregator.h"
#include "tiledb/sm/query/readers/aggregators/sum_aggregator.h"

#include <test/support/tdb_catch.h>
Expand Down
1 change: 0 additions & 1 deletion tiledb/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,6 @@ set(TILEDB_CORE_SOURCES
${TILEDB_CORE_INCLUDE_DIR}/tiledb/sm/query/readers/aggregators/count_aggregator.cc
${TILEDB_CORE_INCLUDE_DIR}/tiledb/sm/query/readers/aggregators/mean_aggregator.cc
${TILEDB_CORE_INCLUDE_DIR}/tiledb/sm/query/readers/aggregators/min_max_aggregator.cc
${TILEDB_CORE_INCLUDE_DIR}/tiledb/sm/query/readers/aggregators/null_count_aggregator.cc
${TILEDB_CORE_INCLUDE_DIR}/tiledb/sm/query/readers/aggregators/output_buffer_validator.cc
${TILEDB_CORE_INCLUDE_DIR}/tiledb/sm/query/readers/aggregators/safe_sum.cc
${TILEDB_CORE_INCLUDE_DIR}/tiledb/sm/query/readers/aggregators/sum_aggregator.cc
Expand Down
2 changes: 1 addition & 1 deletion tiledb/sm/query/readers/aggregators/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ include(object_library)
# `aggregators` object library
#
commence(object_library aggregators)
this_target_sources(count_aggregator.cc mean_aggregator.cc min_max_aggregator.cc null_count_aggregator.cc output_buffer_validator.cc safe_sum.cc sum_aggregator.cc)
this_target_sources(count_aggregator.cc mean_aggregator.cc min_max_aggregator.cc output_buffer_validator.cc safe_sum.cc sum_aggregator.cc)
this_target_object_libraries(baseline array_schema)
conclude(object_library)

Expand Down
34 changes: 23 additions & 11 deletions tiledb/sm/query/readers/aggregators/aggregate_buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,11 +81,6 @@ class AggregateBuffer {
/* API */
/* ********************************* */

/** Returns the validity buffer. */
uint8_t* validity_data() const {
return validity_data_.value();
}

/** Returns if the bitmap is a count bitmap. */
bool is_count_bitmap() const {
return count_bitmap_;
Expand All @@ -96,12 +91,6 @@ class AggregateBuffer {
return bitmap_data_.has_value();
}

/** Returns types bitmap data. */
template <class BitmapType>
BitmapType* bitmap_data_as() const {
return static_cast<BitmapType*>(bitmap_data_.value());
}

/** Returns the min cell position to aggregate. */
uint64_t min_cell() const {
return min_cell_;
Expand Down Expand Up @@ -141,6 +130,29 @@ class AggregateBuffer {
}
}

/**
* Get the validity value at a certain cell index.
*
* @param cell_idx Cell index.
*
* @return Validity value.
*/
inline uint8_t validity_at(const uint64_t cell_idx) const {
return validity_data_.value()[cell_idx];
}

/**
* Get the bitmap value at a certain cell index.
*
* @param cell_idx Cell index.
*
* @return Bitmap value.
*/
template <class BitmapType>
inline BitmapType bitmap_at(const uint64_t cell_idx) const {
return static_cast<BitmapType*>(bitmap_data_.value())[cell_idx];
}

private:
/* ********************************* */
/* PRIVATE ATTRIBUTES */
Expand Down
54 changes: 34 additions & 20 deletions tiledb/sm/query/readers/aggregators/aggregate_with_count.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@

#include "tiledb/sm/query/readers/aggregators/aggregate_buffer.h"
#include "tiledb/sm/query/readers/aggregators/field_info.h"
#include "tiledb/sm/query/readers/aggregators/no_op.h"

namespace tiledb::sm {

Expand All @@ -53,7 +54,7 @@ struct type_data<std::string> {
typedef std::string_view value_type;
};

template <typename T>
template <typename T, typename AGG_T, class AggPolicy, class ValidityPolicy>
class AggregateWithCount {
public:
/* ********************************* */
Expand All @@ -71,19 +72,17 @@ class AggregateWithCount {
/**
* Aggregate the input data.
*
* @tparam AGG_T Aggregate value type.
* @tparam BITMAP_T Bitmap type.
* @tparam AggPolicy Aggregation policy.
* @param input_data Input data for the aggregation.
*
* NOTE: Count of cells returned is used to infer the validity from the
* caller.
* @return {Aggregate value, count of cells}.
*/
template <typename AGG_T, typename BITMAP_T, class AggPolicy>
template <typename BITMAP_T>
tuple<AGG_T, uint64_t> aggregate(AggregateBuffer& input_data) {
typedef typename type_data<T>::value_type VALUE_T;
AggPolicy agg_policy;
ValidityPolicy val_policy;
AGG_T res;
if constexpr (std::is_same<AGG_T, std::string_view>::value) {
res = "";
Expand All @@ -96,17 +95,14 @@ class AggregateWithCount {
// nullable. The bitmap tells us which cells was already filtered out by
// ranges or query conditions.
if (input_data.has_bitmap()) {
auto bitmap_values = input_data.bitmap_data_as<BITMAP_T>();

if (field_info_.is_nullable_) {
auto validity_values = input_data.validity_data();

// Process for nullable values with bitmap.
for (uint64_t c = input_data.min_cell(); c < input_data.max_cell();
c++) {
if (validity_values[c] != 0 && bitmap_values[c] != 0) {
AGG_T value = input_data.value_at<VALUE_T>(c);
for (BITMAP_T i = 0; i < bitmap_values[c]; i++) {
auto bitmap_val = input_data.bitmap_at<BITMAP_T>(c);
if (val_policy.op(input_data.validity_at(c)) && bitmap_val != 0) {
auto value = value_at(input_data, c);
for (BITMAP_T i = 0; i < bitmap_val; i++) {
agg_policy.op(value, res, count);
count++;
}
Expand All @@ -116,23 +112,21 @@ class AggregateWithCount {
// Process for non nullable values with bitmap.
for (uint64_t c = input_data.min_cell(); c < input_data.max_cell();
c++) {
AGG_T value = input_data.value_at<VALUE_T>(c);

for (BITMAP_T i = 0; i < bitmap_values[c]; i++) {
auto bitmap_val = input_data.bitmap_at<BITMAP_T>(c);
auto value = value_at(input_data, c);
for (BITMAP_T i = 0; i < bitmap_val; i++) {
agg_policy.op(value, res, count);
count++;
}
}
}
} else {
if (field_info_.is_nullable_) {
auto validity_values = input_data.validity_data();

// Process for nullable values with no bitmap.
for (uint64_t c = input_data.min_cell(); c < input_data.max_cell();
c++) {
if (validity_values[c] != 0) {
AGG_T value = input_data.value_at<VALUE_T>(c);
if (val_policy.op(input_data.validity_at(c))) {
auto value = value_at(input_data, c);
agg_policy.op(value, res, count);
count++;
}
Expand All @@ -141,7 +135,7 @@ class AggregateWithCount {
// Process for non nullable values with no bitmap.
for (uint64_t c = input_data.min_cell(); c < input_data.max_cell();
c++) {
AGG_T value = input_data.value_at<VALUE_T>(c);
auto value = value_at(input_data, c);
agg_policy.op(value, res, count);
count++;
}
Expand All @@ -158,6 +152,26 @@ class AggregateWithCount {

/** Field information. */
const FieldInfo field_info_;

/* ********************************* */
/* PRIVATE METHODS */
/* ********************************* */

/**
* Returns the value at the specified cell if needed.
*
* @param input_data Input data.
* @param c Cell index.
* @return Value.
*/
inline AGG_T value_at(AggregateBuffer& input_data, uint64_t c) {
typedef typename type_data<T>::value_type VALUE_T;
if constexpr (!std::is_same<AggPolicy, NoOp>::value) {
return input_data.value_at<VALUE_T>(c);
}

return AGG_T();
}
};

} // namespace tiledb::sm
Expand Down
56 changes: 31 additions & 25 deletions tiledb/sm/query/readers/aggregators/count_aggregator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,15 @@ class CountAggregatorStatusException : public StatusException {
}
};

CountAggregator::CountAggregator()
: OutputBufferValidator(FieldInfo("", false, false, 1))
template <class ValidityPolicy>
CountAggregatorBase<ValidityPolicy>::CountAggregatorBase(FieldInfo field_info)
: OutputBufferValidator(field_info)
, aggregate_with_count_(field_info)
, count_(0) {
}

void CountAggregator::validate_output_buffer(
template <class ValidityPolicy>
void CountAggregatorBase<ValidityPolicy>::validate_output_buffer(
std::string output_field_name,
std::unordered_map<std::string, QueryBuffer>& buffers) {
if (buffers.count(output_field_name) == 0) {
Expand All @@ -59,33 +62,22 @@ void CountAggregator::validate_output_buffer(
ensure_output_buffer_count(buffers[output_field_name]);
}

template <class BitmapType>
uint64_t count_cells(AggregateBuffer& input_data) {
uint64_t ret = 0;
template <class ValidityPolicy>
void CountAggregatorBase<ValidityPolicy>::aggregate_data(
AggregateBuffer& input_data) {
tuple<uint8_t, uint64_t> res;

auto bitmap_data = input_data.bitmap_data_as<BitmapType>();
for (uint64_t c = input_data.min_cell(); c < input_data.max_cell(); c++) {
ret += bitmap_data[c];
}

return ret;
}

void CountAggregator::aggregate_data(AggregateBuffer& input_data) {
// Run different loops for bitmap versus no bitmap. The bitmap tells us which
// cells was already filtered out by ranges or query conditions.
if (input_data.has_bitmap()) {
if (input_data.is_count_bitmap()) {
count_ += count_cells<uint64_t>(input_data);
} else {
count_ += count_cells<uint8_t>(input_data);
}
if (input_data.is_count_bitmap()) {
res = aggregate_with_count_.template aggregate<uint64_t>(input_data);
} else {
count_ += input_data.max_cell() - input_data.min_cell();
res = aggregate_with_count_.template aggregate<uint8_t>(input_data);
}

count_ += std::get<1>(res);
}

void CountAggregator::copy_to_user_buffer(
template <class ValidityPolicy>
void CountAggregatorBase<ValidityPolicy>::copy_to_user_buffer(
std::string output_field_name,
std::unordered_map<std::string, QueryBuffer>& buffers) {
auto& result_buffer = buffers[output_field_name];
Expand All @@ -96,4 +88,18 @@ void CountAggregator::copy_to_user_buffer(
}
}

NullCountAggregator::NullCountAggregator(FieldInfo field_info)
: CountAggregatorBase(field_info)
, field_info_(field_info) {
if (!field_info_.is_nullable_) {
throw CountAggregatorStatusException(
"NullCount aggregates must only be requested for nullable "
"attributes.");
}
}

// Explicit template instantiations
template CountAggregatorBase<Null>::CountAggregatorBase(FieldInfo);
template CountAggregatorBase<NonNull>::CountAggregatorBase(FieldInfo);

} // namespace tiledb::sm
67 changes: 58 additions & 9 deletions tiledb/sm/query/readers/aggregators/count_aggregator.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,33 +33,32 @@
#ifndef TILEDB_COUNT_AGGREGATOR_H
#define TILEDB_COUNT_AGGREGATOR_H

#include "tiledb/sm/query/readers/aggregators/aggregate_with_count.h"
#include "tiledb/sm/query/readers/aggregators/iaggregator.h"
#include "tiledb/sm/query/readers/aggregators/no_op.h"
#include "tiledb/sm/query/readers/aggregators/validity_policies.h"

namespace tiledb::sm {

class QueryBuffer;

class CountAggregator : public OutputBufferValidator, public IAggregator {
template <class ValidityPolicy>
class CountAggregatorBase : public OutputBufferValidator, public IAggregator {
public:
/* ********************************* */
/* CONSTRUCTORS & DESTRUCTORS */
/* ********************************* */

/** Constructor. */
CountAggregator();
CountAggregatorBase(FieldInfo field_info = FieldInfo());

DISABLE_COPY_AND_COPY_ASSIGN(CountAggregator);
DISABLE_MOVE_AND_MOVE_ASSIGN(CountAggregator);
DISABLE_COPY_AND_COPY_ASSIGN(CountAggregatorBase);
DISABLE_MOVE_AND_MOVE_ASSIGN(CountAggregatorBase);

/* ********************************* */
/* API */
/* ********************************* */

/** Returns the field name for the aggregator. */
std::string field_name() override {
return constants::count_of_rows;
}

/** Returns if the aggregation is var sized or not. */
bool var_sized() override {
return false;
Expand Down Expand Up @@ -102,10 +101,60 @@ class CountAggregator : public OutputBufferValidator, public IAggregator {
/* PRIVATE ATTRIBUTES */
/* ********************************* */

/** AggregateWithCount to do summation of AggregateBuffer data. */
AggregateWithCount<uint8_t, uint64_t, NoOp, ValidityPolicy>
aggregate_with_count_;

/** Cell count. */
std::atomic<uint64_t> count_;
};

class CountAggregator : public CountAggregatorBase<NonNull> {
public:
/* ********************************* */
/* CONSTRUCTORS & DESTRUCTORS */
/* ********************************* */

CountAggregator()
: CountAggregatorBase() {
}

/* ********************************* */
/* API */
/* ********************************* */

/** Returns the field name for the aggregator. */
std::string field_name() override {
return constants::count_of_rows;
}
};

class NullCountAggregator : public CountAggregatorBase<Null> {
public:
/* ********************************* */
/* CONSTRUCTORS & DESTRUCTORS */
/* ********************************* */

NullCountAggregator(FieldInfo field_info);

/* ********************************* */
/* API */
/* ********************************* */

/** Returns the field name for the aggregator. */
std::string field_name() override {
return field_info_.name_;
}

private:
/* ********************************* */
/* PRIVATE ATTRIBUTES */
/* ********************************* */

/** Field information. */
const FieldInfo field_info_;
};

} // namespace tiledb::sm

#endif // TILEDB_COUNT_AGGREGATOR_H
Loading

0 comments on commit eee582d

Please sign in to comment.