diff --git a/be/src/exec/aggregator.cpp b/be/src/exec/aggregator.cpp
index a14f33672f018..ca5bfd43db37c 100644
--- a/be/src/exec/aggregator.cpp
+++ b/be/src/exec/aggregator.cpp
@@ -808,6 +808,7 @@ Status Aggregator::compute_single_agg_state(Chunk* chunk, size_t chunk_size) {
     bool use_intermediate = _use_intermediate_as_input();
     auto& agg_expr_ctxs = use_intermediate ? _intermediate_agg_expr_ctxs : _agg_expr_ctxs;
 
+    TRY_CATCH_ALLOC_SCOPE_START()
     for (size_t i = 0; i < _agg_fn_ctxs.size(); i++) {
         // evaluate arguments at i-th agg function
         RETURN_IF_ERROR(evaluate_agg_input_column(chunk, agg_expr_ctxs[i], i));
@@ -822,6 +823,7 @@ Status Aggregator::compute_single_agg_state(Chunk* chunk, size_t chunk_size) {
                                                           _agg_input_columns[i][0].get(), 0, chunk_size);
         }
     }
+    TRY_CATCH_ALLOC_SCOPE_END();
     RETURN_IF_ERROR(check_has_error());
     return Status::OK();
 }
@@ -831,6 +833,7 @@ Status Aggregator::compute_batch_agg_states(Chunk* chunk, size_t chunk_size) {
     bool use_intermediate = _use_intermediate_as_input();
     auto& agg_expr_ctxs = use_intermediate ? _intermediate_agg_expr_ctxs : _agg_expr_ctxs;
 
+    TRY_CATCH_ALLOC_SCOPE_START()
     for (size_t i = 0; i < _agg_fn_ctxs.size(); i++) {
         // evaluate arguments at i-th agg function
         RETURN_IF_ERROR(evaluate_agg_input_column(chunk, agg_expr_ctxs[i], i));
@@ -845,6 +848,7 @@ Status Aggregator::compute_batch_agg_states(Chunk* chunk, size_t chunk_size) {
                                                       _agg_input_columns[i][0].get(), _tmp_agg_states.data());
         }
     }
+    TRY_CATCH_ALLOC_SCOPE_END();
     RETURN_IF_ERROR(check_has_error());
     return Status::OK();
 }
@@ -854,6 +858,7 @@ Status Aggregator::compute_batch_agg_states_with_selection(Chunk* chunk, size_t
     bool use_intermediate = _use_intermediate_as_input();
     auto& agg_expr_ctxs = use_intermediate ? _intermediate_agg_expr_ctxs : _agg_expr_ctxs;
 
+    TRY_CATCH_ALLOC_SCOPE_START()
     for (size_t i = 0; i < _agg_fn_ctxs.size(); i++) {
         RETURN_IF_ERROR(evaluate_agg_input_column(chunk, agg_expr_ctxs[i], i));
         SCOPED_THREAD_LOCAL_STATE_ALLOCATOR_SETTER(_allocator.get());
@@ -868,6 +873,7 @@ Status Aggregator::compute_batch_agg_states_with_selection(Chunk* chunk, size_t
                                                                     _tmp_agg_states.data(), _streaming_selection);
         }
     }
+    TRY_CATCH_ALLOC_SCOPE_END();
     RETURN_IF_ERROR(check_has_error());
     return Status::OK();
 }
diff --git a/be/src/exprs/agg/distinct.h b/be/src/exprs/agg/distinct.h
index a00d6652da205..86dc8459f3f56 100644
--- a/be/src/exprs/agg/distinct.h
+++ b/be/src/exprs/agg/distinct.h
@@ -75,7 +75,7 @@ struct DistinctAggregateState<LT, SumLT, FixedLengthLTGuard<LT>> {
         DCHECK(output.length() == set.dump_bound());
     }
 
-    void deserialize_and_merge(const uint8_t* src, size_t len) {
+    void deserialize_and_merge(FunctionContext* ctx, const uint8_t* src, size_t len) {
         phmap::InMemoryInput input(reinterpret_cast<const char*>(src));
         auto old_size = set.size();
         if (old_size == 0) {
@@ -83,6 +83,7 @@ struct DistinctAggregateState<LT, SumLT, FixedLengthLTGuard<LT>> {
         } else {
             MyHashSet set_src;
             set_src.load(input);
+            // TODO: modify merge
             set.merge(set_src);
         }
     }
@@ -158,8 +159,9 @@ struct DistinctAggregateState<LT, SumLT, StringLTGuard<LT>> {
         }
     }
 
-    void deserialize_and_merge(MemPool* mem_pool, const uint8_t* src, size_t len) {
+    void deserialize_and_merge(FunctionContext* ctx, MemPool* mem_pool, const uint8_t* src, size_t len) {
         const uint8_t* end = src + len;
+        int64_t i = 0;
         while (src < end) {
             uint32_t size = 0;
             memcpy(&size, src, sizeof(uint32_t));
@@ -178,6 +180,12 @@ struct DistinctAggregateState<LT, SumLT, StringLTGuard<LT>> {
                 ctor(pos, key.size, key.hash);
             });
             src += size;
+            i++;
+            if (i % 4096 == 0) {
+                if (ctx->has_error()) {
+                    return;
+                }
+            }
         }
         DCHECK(src == end);
     }
@@ -219,7 +227,7 @@ struct DistinctAggregateStateV2<LT, SumLT, FixedLengthLTGuard<LT>> {
         }
     }
 
-    void deserialize_and_merge(const uint8_t* src, size_t len) {
+    void deserialize_and_merge(FunctionContext* ctx, const uint8_t* src, size_t len) {
         size_t size = 0;
         memcpy(&size, src, sizeof(size));
         set.rehash(set.size() + size);
@@ -350,13 +358,13 @@ class TDistinctAggregateFunction : public AggregateFunctionBatchHelper<
         const auto* input_column = down_cast<const BinaryColumn*>(column);
         Slice slice = input_column->get_slice(row_num);
         if constexpr (IsSlice) {
-            this->data(state).deserialize_and_merge(ctx->mem_pool(), (const uint8_t*)slice.data, slice.size);
+            this->data(state).deserialize_and_merge(ctx, ctx->mem_pool(), (const uint8_t*)slice.data, slice.size);
         } else {
             // a slice of at least `MIN_SIZE_OF_HASH_SET_SERIALIZED_DATA` bytes is a serialized hash set,
             // since hash-set serialization data is always at least `MIN_SIZE_OF_HASH_SET_SERIALIZED_DATA` bytes;
             // otherwise the slice is reinterpreted as a single value to be inserted into the hash set.
             if (slice.size >= MIN_SIZE_OF_HASH_SET_SERIALIZED_DATA) {
-                this->data(state).deserialize_and_merge((const uint8_t*)slice.data, slice.size);
+                this->data(state).deserialize_and_merge(ctx, (const uint8_t*)slice.data, slice.size);
             } else {
                 T key;
                 memcpy(&key, slice.data, sizeof(T));
@@ -512,7 +520,7 @@ class DictMergeAggregateFunction final
 
         const auto* input_column = down_cast<const BinaryColumn*>(column);
         Slice slice = input_column->get_slice(row_num);
 
-        this->data(state).deserialize_and_merge(ctx->mem_pool(), (const uint8_t*)slice.data, slice.size);
+        this->data(state).deserialize_and_merge(ctx, ctx->mem_pool(), (const uint8_t*)slice.data, slice.size);
         agg_state.over_limit = agg_state.set.size() > DICT_DECODE_MAX_SIZE;
     }
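
Taken together, the two changes above make allocation failure during aggregation observable instead of fatal: the aggregator brackets its allocation-heavy update/merge loops with TRY_CATCH_ALLOC_SCOPE_START()/TRY_CATCH_ALLOC_SCOPE_END(), so a std::bad_alloc thrown inside the scope is caught at the boundary and the existing RETURN_IF_ERROR(check_has_error()) can report it as a Status, while deserialize_and_merge() now receives the FunctionContext so its per-key merge loop can poll ctx->has_error() every 4096 keys and return early once an error has been recorded. The sketch below illustrates this mechanism in isolation; the thread-local flag, the FunctionContext stub, and the macro bodies are simplified assumptions for illustration, not the actual StarRocks definitions.

    #include <cstdio>
    #include <new>
    #include <vector>

    // Assumption: a per-thread flag stands in for the real error bookkeeping.
    thread_local bool tls_alloc_failed = false;

    // Sketch: the scope macros wrap an allocation-heavy region in try/catch so
    // a std::bad_alloc becomes a recorded error instead of an uncaught
    // exception escaping the operator.
    #define TRY_CATCH_ALLOC_SCOPE_START() try {
    #define TRY_CATCH_ALLOC_SCOPE_END() \
        }                               \
        catch (const std::bad_alloc&) { \
            tls_alloc_failed = true;    \
        }

    // Assumption: minimal stub exposing only the has_error() query used above.
    struct FunctionContext {
        bool has_error() const { return tls_alloc_failed; }
    };

    // The periodic-cancellation pattern added to deserialize_and_merge():
    // poll has_error() every 4096 keys so a long merge stops promptly once an
    // error is recorded, without paying for a check on every single element.
    void merge_keys(FunctionContext* ctx, const std::vector<int>& keys) {
        int64_t i = 0;
        for (int key : keys) {
            (void)key; // insert `key` into the hash set here
            if (++i % 4096 == 0 && ctx->has_error()) {
                return;
            }
        }
    }

    int main() {
        FunctionContext ctx;
        TRY_CATCH_ALLOC_SCOPE_START()
        merge_keys(&ctx, std::vector<int>(10000, 7));
        TRY_CATCH_ALLOC_SCOPE_END();
        std::printf("alloc error: %s\n", ctx.has_error() ? "yes" : "no");
        return 0;
    }

Polling every 4096 keys amortizes the error check across the merge loop while still bounding how much work a doomed merge performs after an allocation has already failed elsewhere.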