Skip to content

Commit

Permalink
Refactor aggregate with count.
Browse files Browse the repository at this point in the history
  • Loading branch information
KiterLuc committed Aug 30, 2023
1 parent 447a97f commit daae965
Show file tree
Hide file tree
Showing 10 changed files with 186 additions and 86 deletions.
2 changes: 1 addition & 1 deletion tiledb/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -252,12 +252,12 @@ set(TILEDB_CORE_SOURCES
${TILEDB_CORE_INCLUDE_DIR}/tiledb/sm/query/query.cc
${TILEDB_CORE_INCLUDE_DIR}/tiledb/sm/query/query_condition.cc
${TILEDB_CORE_INCLUDE_DIR}/tiledb/sm/query/query_remote_buffer_storage.cc
${TILEDB_CORE_INCLUDE_DIR}/tiledb/sm/query/readers/aggregators/aggregate_sum.cc
${TILEDB_CORE_INCLUDE_DIR}/tiledb/sm/query/readers/aggregators/count_aggregator.cc
${TILEDB_CORE_INCLUDE_DIR}/tiledb/sm/query/readers/aggregators/mean_aggregator.cc
${TILEDB_CORE_INCLUDE_DIR}/tiledb/sm/query/readers/aggregators/min_max_aggregator.cc
${TILEDB_CORE_INCLUDE_DIR}/tiledb/sm/query/readers/aggregators/null_count_aggregator.cc
${TILEDB_CORE_INCLUDE_DIR}/tiledb/sm/query/readers/aggregators/output_buffer_validator.cc
${TILEDB_CORE_INCLUDE_DIR}/tiledb/sm/query/readers/aggregators/safe_sum.cc
${TILEDB_CORE_INCLUDE_DIR}/tiledb/sm/query/readers/aggregators/sum_aggregator.cc
${TILEDB_CORE_INCLUDE_DIR}/tiledb/sm/query/readers/dense_reader.cc
${TILEDB_CORE_INCLUDE_DIR}/tiledb/sm/query/readers/ordered_dim_label_reader.cc
Expand Down
2 changes: 1 addition & 1 deletion tiledb/sm/query/readers/aggregators/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ include(object_library)
# `aggregators` object library
#
commence(object_library aggregators)
this_target_sources(aggregate_sum.cc count_aggregator.cc mean_aggregator.cc min_max_aggregator.cc null_count_aggregator.cc output_buffer_validator.cc sum_aggregator.cc)
this_target_sources(count_aggregator.cc mean_aggregator.cc min_max_aggregator.cc null_count_aggregator.cc output_buffer_validator.cc safe_sum.cc sum_aggregator.cc)
this_target_object_libraries(baseline array_schema)
conclude(object_library)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,65 +35,19 @@

#include "tiledb/sm/query/readers/aggregators/aggregate_buffer.h"
#include "tiledb/sm/query/readers/aggregators/field_info.h"
#include "tiledb/sm/query/readers/aggregators/safe_sum.h"

namespace tiledb {
namespace sm {

#define SUM_TYPE_DATA(T, SUM_T) \
template <> \
struct sum_type_data<T> { \
using type = T; \
typedef SUM_T sum_type; \
};

/** Convert basic type to a sum type. **/
template <typename T>
struct sum_type_data;

SUM_TYPE_DATA(int8_t, int64_t);
SUM_TYPE_DATA(uint8_t, uint64_t);
SUM_TYPE_DATA(int16_t, int64_t);
SUM_TYPE_DATA(uint16_t, uint64_t);
SUM_TYPE_DATA(int32_t, int64_t);
SUM_TYPE_DATA(uint32_t, uint64_t);
SUM_TYPE_DATA(int64_t, int64_t);
SUM_TYPE_DATA(uint64_t, uint64_t);
SUM_TYPE_DATA(float, double);
SUM_TYPE_DATA(double, double);

/**
* Sum function that prevent wrap arounds on overflow.
*
* @param value Value to add to the sum.
* @param sum Computed sum.
*/
template <typename SUM_T>
void safe_sum(SUM_T value, SUM_T& sum);

/**
* Sum function for atomics that prevent wrap arounds on overflow.
*
* @param value Value to add to the sum.
* @param sum Computed sum.
*/
template <typename SUM_T>
void safe_sum(SUM_T value, std::atomic<SUM_T>& sum) {
SUM_T cur_sum = sum;
SUM_T new_sum;
do {
new_sum = cur_sum;
safe_sum(value, new_sum);
} while (!std::atomic_compare_exchange_weak(&sum, &cur_sum, new_sum));
}

template <typename T>
class AggregateSum {
class AggregateWithCount {
public:
/* ********************************* */
/* CONSTRUCTORS & DESTRUCTORS */
/* ********************************* */

AggregateSum(const FieldInfo field_info)
AggregateWithCount(const FieldInfo field_info)
: field_info_(field_info) {
}

Expand All @@ -110,8 +64,9 @@ class AggregateSum {
*
* @return {Sum for the cells, number of cells, optional validity value}.
*/
template <typename SUM_T, typename BITMAP_T>
tuple<SUM_T, uint64_t, optional<uint8_t>> sum(AggregateBuffer& input_data) {
template <typename SUM_T, typename BITMAP_T, class AggPolicy>
tuple<SUM_T, uint64_t, optional<uint8_t>> aggregate(
AggregateBuffer& input_data) {
SUM_T sum{0};
uint64_t count{0};
optional<uint8_t> validity{nullopt};
Expand All @@ -136,7 +91,7 @@ class AggregateSum {
auto value = static_cast<SUM_T>(values[c]);
for (BITMAP_T i = 0; i < bitmap_values[c]; i++) {
count++;
safe_sum(value, sum);
AggPolicy::op(value, sum);
}
}
}
Expand All @@ -148,7 +103,7 @@ class AggregateSum {

for (BITMAP_T i = 0; i < bitmap_values[c]; i++) {
count++;
safe_sum(value, sum);
AggPolicy::op(value, sum);
}
}
}
Expand All @@ -165,7 +120,7 @@ class AggregateSum {

auto value = static_cast<SUM_T>(values[c]);
count++;
safe_sum(value, sum);
AggPolicy::op(value, sum);
}
}
} else {
Expand All @@ -174,7 +129,7 @@ class AggregateSum {
c++) {
auto value = static_cast<SUM_T>(values[c]);
count++;
safe_sum(value, sum);
AggPolicy::op(value, sum);
}
}
}
Expand Down
16 changes: 9 additions & 7 deletions tiledb/sm/query/readers/aggregators/mean_aggregator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -94,16 +94,18 @@ void MeanAggregator<T>::aggregate_data(AggregateBuffer& input_data) {
0, 0, nullopt};

if (input_data.is_count_bitmap()) {
res =
summator_.template sum<typename sum_type_data<T>::sum_type, uint64_t>(
input_data);
res = summator_.template aggregate<
typename sum_type_data<T>::sum_type,
uint64_t,
SafeSum>(input_data);
} else {
res =
summator_.template sum<typename sum_type_data<T>::sum_type, uint8_t>(
input_data);
res = summator_.template aggregate<
typename sum_type_data<T>::sum_type,
uint8_t,
SafeSum>(input_data);
}

safe_sum(std::get<0>(res), sum_);
SafeSum::safe_sum(std::get<0>(res), sum_);
count_ += std::get<1>(res);
if (field_info_.is_nullable_ && std::get<2>(res).value() == 1) {
validity_value_ = 1;
Expand Down
7 changes: 4 additions & 3 deletions tiledb/sm/query/readers/aggregators/mean_aggregator.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,10 @@
#define TILEDB_MEAN_AGGREGATOR_H

#include "tiledb/common/common.h"
#include "tiledb/sm/query/readers/aggregators/aggregate_sum.h"
#include "tiledb/sm/query/readers/aggregators/aggregate_with_count.h"
#include "tiledb/sm/query/readers/aggregators/field_info.h"
#include "tiledb/sm/query/readers/aggregators/iaggregator.h"
#include "tiledb/sm/query/readers/aggregators/sum_type.h"

namespace tiledb {
namespace sm {
Expand Down Expand Up @@ -116,8 +117,8 @@ class MeanAggregator : public OutputBufferValidator, public IAggregator {
/** Field information. */
const FieldInfo field_info_;

/** AggregateSum to do summation of AggregateBuffer data. */
AggregateSum<T> summator_;
/** AggregateWithCount to do summation of AggregateBuffer data. */
AggregateWithCount<T> summator_;

/** Computed sum. */
std::atomic<typename sum_type_data<T>::sum_type> sum_;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* @file aggregate_sum.cc
* @file safe_sum.cc
*
* @section LICENSE
*
Expand Down Expand Up @@ -27,17 +27,19 @@
*
* @section DESCRIPTION
*
* This file implements class AggregateSum.
* This file implements class SafeSum.
*/

#include "tiledb/sm/query/readers/aggregators/aggregate_sum.h"
#include "tiledb/sm/query/readers/aggregators/safe_sum.h"

#include <limits>

namespace tiledb {
namespace sm {

/** Specialization of safe_sum for int64_t sums. */
/** Specialization of op for int64_t sums. */
template <>
void safe_sum<int64_t>(int64_t value, int64_t& sum) {
void SafeSum::op<int64_t>(int64_t value, int64_t& sum) {
if (sum > 0 && value > 0 &&
(sum > (std::numeric_limits<int64_t>::max() - value))) {
throw std::overflow_error("overflow on sum");
Expand All @@ -51,19 +53,19 @@ void safe_sum<int64_t>(int64_t value, int64_t& sum) {
sum += value;
}

/** Specialization of safe_sum for uint64_t sums. */
/** Specialization of op for uint64_t sums. */
template <>
void safe_sum<uint64_t>(uint64_t value, uint64_t& sum) {
void SafeSum::op<uint64_t>(uint64_t value, uint64_t& sum) {
if (sum > (std::numeric_limits<uint64_t>::max() - value)) {
throw std::overflow_error("overflow on sum");
}

sum += value;
}

/** Specialization of safe_sum for double sums. */
/** Specialization of op for double sums. */
template <>
void safe_sum<double>(double value, double& sum) {
void SafeSum::op<double>(double value, double& sum) {
if ((sum < 0.0) == (value < 0.0) &&
(std::abs(sum) >
(std::numeric_limits<double>::max() - std::abs(value)))) {
Expand Down
73 changes: 73 additions & 0 deletions tiledb/sm/query/readers/aggregators/safe_sum.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
/**
* @file safe_sum.h
*
* @section LICENSE
*
* The MIT License
*
* @copyright Copyright (c) 2023 TileDB, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*
* @section DESCRIPTION
*
* This file defines class SafeSum.
*/

#ifndef TILEDB_SAFE_SUM_H
#define TILEDB_SAFE_SUM_H

#include <atomic>
#include <stdexcept>

namespace tiledb {
namespace sm {

class SafeSum {
public:
/**
* Sum function that prevent wrap arounds on overflow.
*
* @param value Value to add to the sum.
* @param sum Computed sum.
*/
template <typename SUM_T>
static void op(SUM_T value, SUM_T& sum);

/**
* Sum function for atomics that prevent wrap arounds on overflow.
*
* @param value Value to add to the sum.
* @param sum Computed sum.
*/
template <typename SUM_T>
static void safe_sum(SUM_T value, std::atomic<SUM_T>& sum) {
SUM_T cur_sum = sum;
SUM_T new_sum;
do {
new_sum = cur_sum;
op(value, new_sum);
} while (!std::atomic_compare_exchange_weak(&sum, &cur_sum, new_sum));
}
};

} // namespace sm
} // namespace tiledb

#endif // TILEDB_SAFE_SUM_H
16 changes: 9 additions & 7 deletions tiledb/sm/query/readers/aggregators/sum_aggregator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -89,16 +89,18 @@ void SumAggregator<T>::aggregate_data(AggregateBuffer& input_data) {
0, 0, nullopt};

if (input_data.is_count_bitmap()) {
res =
summator_.template sum<typename sum_type_data<T>::sum_type, uint64_t>(
input_data);
res = summator_.template aggregate<
typename sum_type_data<T>::sum_type,
uint64_t,
SafeSum>(input_data);
} else {
res =
summator_.template sum<typename sum_type_data<T>::sum_type, uint8_t>(
input_data);
res = summator_.template aggregate<
typename sum_type_data<T>::sum_type,
uint8_t,
SafeSum>(input_data);
}

safe_sum(std::get<0>(res), sum_);
SafeSum::safe_sum(std::get<0>(res), sum_);
if (field_info_.is_nullable_ && std::get<2>(res).value() == 1) {
validity_value_ = 1;
}
Expand Down
7 changes: 4 additions & 3 deletions tiledb/sm/query/readers/aggregators/sum_aggregator.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,10 @@
#ifndef TILEDB_SUM_AGGREGATOR_H
#define TILEDB_SUM_AGGREGATOR_H

#include "tiledb/sm/query/readers/aggregators/aggregate_sum.h"
#include "tiledb/sm/query/readers/aggregators/aggregate_with_count.h"
#include "tiledb/sm/query/readers/aggregators/field_info.h"
#include "tiledb/sm/query/readers/aggregators/iaggregator.h"
#include "tiledb/sm/query/readers/aggregators/sum_type.h"

namespace tiledb {
namespace sm {
Expand Down Expand Up @@ -115,8 +116,8 @@ class SumAggregator : public OutputBufferValidator, public IAggregator {
/** Field information. */
const FieldInfo field_info_;

/** AggregateSum to do summation of AggregateBuffer data. */
AggregateSum<T> summator_;
/** AggregateWithCount to do summation of AggregateBuffer data. */
AggregateWithCount<T> summator_;

/** Computed sum. */
std::atomic<typename sum_type_data<T>::sum_type> sum_;
Expand Down
Loading

0 comments on commit daae965

Please sign in to comment.