From 6bdfbc18bad18d282526d220849ccb75298fca2d Mon Sep 17 00:00:00 2001 From: Xwg Date: Fri, 29 Dec 2023 01:59:32 +0800 Subject: [PATCH] Refactor merge aggregate and expression binder code The PhysicalMergeAggregate class was refactored by introducing a new method, SimpleMergeAggregateExecute, to encapsulate some of the implementation details. The newly defined method simplifies the complexity of the code and makes the execute method easier to understand. ExpressionBinder class was also modified to improve clarity of the division function expression building process, simplifying it and making code more maintainable. Minor tweaks were also made to the fragment_context.cpp file to further enhance readability. --- .../operator/physical_merge_aggregate.cpp | 85 ++++++++++--------- .../operator/physical_merge_aggregate.cppm | 2 + src/planner/expression_binder.cpp | 38 ++++++--- src/scheduler/fragment_context.cpp | 4 +- 4 files changed, 77 insertions(+), 52 deletions(-) diff --git a/src/executor/operator/physical_merge_aggregate.cpp b/src/executor/operator/physical_merge_aggregate.cpp index a260d207ff..1fd86a4ab2 100644 --- a/src/executor/operator/physical_merge_aggregate.cpp +++ b/src/executor/operator/physical_merge_aggregate.cpp @@ -37,54 +37,57 @@ bool PhysicalMergeAggregate::Execute(QueryContext *query_context, OperatorState LOG_TRACE("PhysicalMergeAggregate::Execute:: mark"); auto merge_aggregate_op_state = static_cast(operator_state); - if (!merge_aggregate_op_state->input_complete_) { + SimpleMergeAggregateExecute(merge_aggregate_op_state); - // auto result = SimpleMergeAggregateExecute(merge_aggregate_op_state->input_data_blocks_, merge_aggregate_op_state->data_block_array_); + if (merge_aggregate_op_state->input_complete_) { - if (merge_aggregate_op_state->data_block_array_.size() == 0) { - merge_aggregate_op_state->data_block_array_.emplace_back(Move(merge_aggregate_op_state->input_data_block_)); - } else { + LOG_TRACE("PhysicalMergeAggregate::Input is complete"); + // for (auto &output_block : merge_aggregate_op_state->data_block_array_) { + // output_block->Finalize(); + // } - auto agg_op = dynamic_cast(this->left()); - - for (SizeT col_idx = 0; auto expr : agg_op->aggregates_) { - auto agg_expression = static_cast(expr.get()); - - auto function_name = agg_expression->aggregate_function_.GetFuncName(); - - auto function_return_type = agg_expression->aggregate_function_.return_type_; - - if (String(function_name) == String("COUNT")) { - LOG_TRACE("PhysicalAggregate::Execute:: COUNT"); - UpdateBlockData(merge_aggregate_op_state, col_idx); - } else if (String(function_name) == String("MIN")) { - LOG_TRACE("PhysicalAggregate::Execute:: MIN"); - IntegerT input_int = GetDataFromValueAtInputBlockPosition(merge_aggregate_op_state, 0, col_idx, 0); - IntegerT out_int = GetDataFromValueAtOutputBlockPosition(merge_aggregate_op_state, 0, col_idx, 0); - IntegerT new_int = MinValue(input_int, out_int); - WriteIntegerAtPosition(merge_aggregate_op_state, 0, col_idx, 0, new_int); - } else if (String(function_name) == String("MAX")) { - LOG_TRACE("PhysicalAggregate::Execute:: MAX"); + merge_aggregate_op_state->SetComplete(); + return true; + } - IntegerT input_int = GetDataFromValueAtInputBlockPosition(merge_aggregate_op_state, 0, col_idx, 0); - IntegerT out_int = GetDataFromValueAtOutputBlockPosition(merge_aggregate_op_state, 0, col_idx, 0); - IntegerT new_int = MaxValue(input_int, out_int); - WriteIntegerAtPosition(merge_aggregate_op_state, 0, col_idx, 0, new_int); - } else if (String(function_name) == String("SUM")) { - UpdateBlockData(merge_aggregate_op_state, col_idx); - } + return false; +} - ++col_idx; - } - } - return false; +void PhysicalMergeAggregate::SimpleMergeAggregateExecute(MergeAggregateOperatorState *merge_aggregate_op_state) { + if (merge_aggregate_op_state->data_block_array_.size() == 0) { + merge_aggregate_op_state->data_block_array_.emplace_back(Move(merge_aggregate_op_state->input_data_block_)); } else { - LOG_TRACE("PhysicalAggregate::Input is complete"); - for (auto &output_block : merge_aggregate_op_state->data_block_array_) { - output_block->Finalize(); + auto agg_op = dynamic_cast(this->left()); + + for (SizeT col_idx = 0; auto expr : agg_op->aggregates_) { + auto agg_expression = static_cast(expr.get()); + + auto function_name = agg_expression->aggregate_function_.GetFuncName(); + + auto function_return_type = agg_expression->aggregate_function_.return_type_; + + if (String(function_name) == String("COUNT")) { + LOG_TRACE("PhysicalAggregate::Execute:: COUNT"); + UpdateBlockData(merge_aggregate_op_state, col_idx); + } else if (String(function_name) == String("MIN")) { + LOG_TRACE("PhysicalAggregate::Execute:: MIN"); + IntegerT input_int = GetDataFromValueAtInputBlockPosition(merge_aggregate_op_state, 0, col_idx, 0); + IntegerT out_int = GetDataFromValueAtOutputBlockPosition(merge_aggregate_op_state, 0, col_idx, 0); + IntegerT new_int = MinValue(input_int, out_int); + WriteIntegerAtPosition(merge_aggregate_op_state, 0, col_idx, 0, new_int); + } else if (String(function_name) == String("MAX")) { + LOG_TRACE("PhysicalAggregate::Execute:: MAX"); + + IntegerT input_int = GetDataFromValueAtInputBlockPosition(merge_aggregate_op_state, 0, col_idx, 0); + IntegerT out_int = GetDataFromValueAtOutputBlockPosition(merge_aggregate_op_state, 0, col_idx, 0); + IntegerT new_int = MaxValue(input_int, out_int); + WriteIntegerAtPosition(merge_aggregate_op_state, 0, col_idx, 0, new_int); + } else if (String(function_name) == String("SUM")) { + UpdateBlockData(merge_aggregate_op_state, col_idx); + } + + ++col_idx; } - merge_aggregate_op_state->SetComplete(); - return true; } } diff --git a/src/executor/operator/physical_merge_aggregate.cppm b/src/executor/operator/physical_merge_aggregate.cppm index ab99d0f1b6..38c25fa8cd 100644 --- a/src/executor/operator/physical_merge_aggregate.cppm +++ b/src/executor/operator/physical_merge_aggregate.cppm @@ -81,6 +81,8 @@ public: IntegerT MaxValue(IntegerT a, IntegerT b) { return (a > b) ? a : b; } + void SimpleMergeAggregateExecute(MergeAggregateOperatorState *merge_aggregate_op_state) ; + private: SharedPtr> output_names_{}; SharedPtr>> output_types_{}; diff --git a/src/planner/expression_binder.cpp b/src/planner/expression_binder.cpp index b7fa1a727d..7d767dd875 100644 --- a/src/planner/expression_binder.cpp +++ b/src/planner/expression_binder.cpp @@ -64,6 +64,7 @@ namespace infinity { SharedPtr ExpressionBinder::Bind(const ParsedExpr &expr, BindContext *bind_context_ptr, i64 depth, bool root) { // Call implemented BuildExpression + SharedPtr result = BuildExpression(expr, bind_context_ptr, depth, root); if (result.get() == nullptr) { if (result.get() == nullptr) { @@ -286,23 +287,40 @@ SharedPtr ExpressionBinder::BuildFuncExpr(const FunctionExpr &ex Vector col_names{}; ColumnExpr *col_expr = nullptr; FunctionExpr *div_function_expr = new FunctionExpr(); - FunctionExpr *sum_function_expr = new FunctionExpr(); - FunctionExpr *count_function_expr = new FunctionExpr(); - Vector *arguments = new Vector(); + div_function_expr->func_name_ = String("div"); + div_function_expr->arguments_ = new Vector(); + if (expr.arguments_->size() == 1) { if ((*expr.arguments_)[0]->type_ == ParsedExprType::kColumn) { col_expr = (ColumnExpr *)(*expr.arguments_)[0]; col_names = col_expr->names_; } + FunctionExpr *sum_function_expr = new FunctionExpr(); sum_function_expr->func_name_ = String("sum"); - sum_function_expr->arguments_->reserve(1); + sum_function_expr->arguments_ = new Vector(); + ColumnExpr* column_expr_for_sum = new ColumnExpr(); + column_expr_for_sum->names_.emplace_back(col_names[0]); + sum_function_expr->arguments_->emplace_back(column_expr_for_sum); + + FunctionExpr *count_function_expr = new FunctionExpr(); count_function_expr->func_name_= String("count"); - sum_function_expr->arguments_->reserve(1); - *count_function_expr->arguments_->emplace_back(col_expr); - arguments->emplace_back(sum_function_expr); - arguments->emplace_back(count_function_expr); - div_function_expr->arguments_=arguments; - } + count_function_expr->arguments_ = new Vector(); + ColumnExpr* column_expr_for_count = new ColumnExpr(); + column_expr_for_sum->names_.emplace_back(col_names[0]); + count_function_expr->arguments_->emplace_back(column_expr_for_count); + + div_function_expr->arguments_->emplace_back(sum_function_expr); + div_function_expr->arguments_->emplace_back(count_function_expr); + } + + // Vector> arguments; + // arguments.reserve(div_function_expr->arguments_->size()); + // for (const auto *arg_expr : *div_function_expr->arguments_) { + // // The argument expression isn't root expression. + // // SharedPtr expr_ptr + // auto expr_ptr = BuildExpression(*arg_expr, bind_context_ptr, depth, false); + // arguments.emplace_back(expr_ptr); + // } } Vector> arguments; diff --git a/src/scheduler/fragment_context.cpp b/src/scheduler/fragment_context.cpp index 5061ab4a49..21d1fafb8e 100644 --- a/src/scheduler/fragment_context.cpp +++ b/src/scheduler/fragment_context.cpp @@ -695,7 +695,9 @@ void FragmentContext::MakeSinkState(i64 parallel_count) { } for (u64 task_id = 0; (i64)task_id < parallel_count; ++task_id) { - tasks_[task_id]->sink_state_ = MakeUnique(fragment_ptr_->FragmentID(), task_id); + auto sink_state = MakeUnique(fragment_ptr_->FragmentID(), task_id); + + tasks_[task_id]->sink_state_ = Move(sink_state); } break; }