diff --git a/tiledb/sm/query/readers/aggregators/aggregate_with_count.h b/tiledb/sm/query/readers/aggregators/aggregate_with_count.h index c1afc6d2db4..e18bd81b266 100644 --- a/tiledb/sm/query/readers/aggregators/aggregate_with_count.h +++ b/tiledb/sm/query/readers/aggregators/aggregate_with_count.h @@ -76,14 +76,18 @@ class AggregateWithCount { * @tparam AggPolicy Aggregation policy. * @param input_data Input data for the aggregation. * - * @return {Aggregate value, number of cells, optional validity value}. + * NOTE: Count of cells returned is used to infer the validity from the + * caller. + * @return {Aggregate value, count of cells}. */ template tuple aggregate(AggregateBuffer& input_data) { typedef typename type_data::value_type VALUE_T; AggPolicy agg_policy; AGG_T res; - if constexpr (!std::is_same::value) { + if constexpr (std::is_same::value) { + res = ""; + } else { res = 0; } uint64_t count{0}; diff --git a/tiledb/sm/query/readers/aggregators/mean_aggregator.cc b/tiledb/sm/query/readers/aggregators/mean_aggregator.cc index 12176ac1764..ae944140d24 100644 --- a/tiledb/sm/query/readers/aggregators/mean_aggregator.cc +++ b/tiledb/sm/query/readers/aggregators/mean_aggregator.cc @@ -110,9 +110,14 @@ void MeanAggregator::aggregate_data(AggregateBuffer& input_data) { SafeSum>(input_data); } - SafeSum().safe_sum(std::get<0>(res), sum_); - count_ += std::get<1>(res); - if (field_info_.is_nullable_ && std::get<1>(res) > 0) { + const auto value = std::get<0>(res); + const auto count = std::get<1>(res); + SafeSum().safe_sum(value, sum_); + count_ += count; + + // Here we know that if the count is greater than 0, it means at least one + // valid item was found, which means the result is valid. + if (field_info_.is_nullable_ && count) { validity_value_ = 1; } } catch (std::overflow_error& e) { diff --git a/tiledb/sm/query/readers/aggregators/min_max_aggregator.cc b/tiledb/sm/query/readers/aggregators/min_max_aggregator.cc index 6c2c06f1930..df924af0c9d 100644 --- a/tiledb/sm/query/readers/aggregators/min_max_aggregator.cc +++ b/tiledb/sm/query/readers/aggregators/min_max_aggregator.cc @@ -162,14 +162,17 @@ void ComparatorAggregator::aggregate_data(AggregateBuffer& input_data) { // This might be called on multiple threads, the final result stored in // value_ should be computed in a thread safe manner. std::unique_lock lock(value_mtx_); - if (std::get<1>(res) > 0 && + const auto value = std::get<0>(res); + const auto count = std::get<1>(res); + if (count > 0 && (ComparatorAggregatorBase::value_ == std::nullopt || - op_(std::get<0>(res), ComparatorAggregatorBase::value_.value()))) { - ComparatorAggregatorBase::value_ = std::get<0>(res); + op_(value, ComparatorAggregatorBase::value_.value()))) { + ComparatorAggregatorBase::value_ = value; } - if (ComparatorAggregatorBase::field_info_.is_nullable_ && - std::get<1>(res) > 0) { + // Here we know that if the count is greater than 0, it means at least one + // valid item was found, which means the result is valid. + if (ComparatorAggregatorBase::field_info_.is_nullable_ && count > 0) { ComparatorAggregatorBase::validity_value_ = 1; } } diff --git a/tiledb/sm/query/readers/aggregators/sum_aggregator.cc b/tiledb/sm/query/readers/aggregators/sum_aggregator.cc index 1eaed2c09da..50033cf3cb4 100644 --- a/tiledb/sm/query/readers/aggregators/sum_aggregator.cc +++ b/tiledb/sm/query/readers/aggregators/sum_aggregator.cc @@ -103,8 +103,13 @@ void SumAggregator::aggregate_data(AggregateBuffer& input_data) { SafeSum>(input_data); } - SafeSum().safe_sum(std::get<0>(res), sum_); - if (field_info_.is_nullable_ && std::get<1>(res) > 0) { + const auto value = std::get<0>(res); + const auto count = std::get<1>(res); + SafeSum().safe_sum(value, sum_); + + // Here we know that if the count is greater than 0, it means at least one + // valid item was found, which means the result is valid. + if (field_info_.is_nullable_ && count) { validity_value_ = 1; } } catch (std::overflow_error& e) {