Skip to content

Commit

Permalink
Clarifying comments. @robertbindar.
Browse files Browse the repository at this point in the history
  • Loading branch information
KiterLuc committed Oct 4, 2023
1 parent 0841db2 commit 4f4ff49
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 12 deletions.
8 changes: 6 additions & 2 deletions tiledb/sm/query/readers/aggregators/aggregate_with_count.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,14 +76,18 @@ class AggregateWithCount {
* @tparam AggPolicy Aggregation policy.
* @param input_data Input data for the aggregation.
*
* @return {Aggregate value, number of cells, optional validity value}.
* NOTE: Count of cells returned is used to infer the validity from the
* caller.
* @return {Aggregate value, count of cells}.
*/
template <typename AGG_T, typename BITMAP_T, class AggPolicy>
tuple<AGG_T, uint64_t> aggregate(AggregateBuffer& input_data) {
typedef typename type_data<T>::value_type VALUE_T;
AggPolicy agg_policy;
AGG_T res;
if constexpr (!std::is_same<AGG_T, std::string_view>::value) {
if constexpr (std::is_same<AGG_T, std::string_view>::value) {
res = "";
} else {
res = 0;
}
uint64_t count{0};
Expand Down
11 changes: 8 additions & 3 deletions tiledb/sm/query/readers/aggregators/mean_aggregator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -110,9 +110,14 @@ void MeanAggregator<T>::aggregate_data(AggregateBuffer& input_data) {
SafeSum>(input_data);
}

SafeSum().safe_sum(std::get<0>(res), sum_);
count_ += std::get<1>(res);
if (field_info_.is_nullable_ && std::get<1>(res) > 0) {
const auto value = std::get<0>(res);
const auto count = std::get<1>(res);
SafeSum().safe_sum(value, sum_);
count_ += count;

// Here we know that if the count is greater than 0, it means at least one
// valid item was found, which means the result is valid.
if (field_info_.is_nullable_ && count) {
validity_value_ = 1;
}
} catch (std::overflow_error& e) {
Expand Down
13 changes: 8 additions & 5 deletions tiledb/sm/query/readers/aggregators/min_max_aggregator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -162,14 +162,17 @@ void ComparatorAggregator<T, Op>::aggregate_data(AggregateBuffer& input_data) {
// This might be called on multiple threads, the final result stored in
// value_ should be computed in a thread safe manner.
std::unique_lock lock(value_mtx_);
if (std::get<1>(res) > 0 &&
const auto value = std::get<0>(res);
const auto count = std::get<1>(res);
if (count > 0 &&
(ComparatorAggregatorBase<T>::value_ == std::nullopt ||
op_(std::get<0>(res), ComparatorAggregatorBase<T>::value_.value()))) {
ComparatorAggregatorBase<T>::value_ = std::get<0>(res);
op_(value, ComparatorAggregatorBase<T>::value_.value()))) {
ComparatorAggregatorBase<T>::value_ = value;
}

if (ComparatorAggregatorBase<T>::field_info_.is_nullable_ &&
std::get<1>(res) > 0) {
// Here we know that if the count is greater than 0, it means at least one
// valid item was found, which means the result is valid.
if (ComparatorAggregatorBase<T>::field_info_.is_nullable_ && count > 0) {
ComparatorAggregatorBase<T>::validity_value_ = 1;
}
}
Expand Down
9 changes: 7 additions & 2 deletions tiledb/sm/query/readers/aggregators/sum_aggregator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,13 @@ void SumAggregator<T>::aggregate_data(AggregateBuffer& input_data) {
SafeSum>(input_data);
}

SafeSum().safe_sum(std::get<0>(res), sum_);
if (field_info_.is_nullable_ && std::get<1>(res) > 0) {
const auto value = std::get<0>(res);
const auto count = std::get<1>(res);
SafeSum().safe_sum(value, sum_);

// Here we know that if the count is greater than 0, it means at least one
// valid item was found, which means the result is valid.
if (field_info_.is_nullable_ && count) {
validity_value_ = 1;
}
} catch (std::overflow_error& e) {
Expand Down

0 comments on commit 4f4ff49

Please sign in to comment.