Skip to content

Commit

Permalink
Refactor merge aggregate and expression binder code
Browse files Browse the repository at this point in the history
The PhysicalMergeAggregate class was refactored by introducing a new method, SimpleMergeAggregateExecute, to encapsulate some of the implementation details. The new method reduces the complexity of the code and makes the Execute method easier to understand. The ExpressionBinder class was also modified to clarify how the division function expression is built, simplifying it and making the code more maintainable. Minor tweaks were also made to the fragment_context.cpp file to further enhance readability.
  • Loading branch information
loloxwg authored and yuzhichang committed Dec 30, 2023
1 parent 99b6802 commit 6bdfbc1
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 52 deletions.
85 changes: 44 additions & 41 deletions src/executor/operator/physical_merge_aggregate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,54 +37,57 @@ bool PhysicalMergeAggregate::Execute(QueryContext *query_context, OperatorState
LOG_TRACE("PhysicalMergeAggregate::Execute:: mark");
auto merge_aggregate_op_state = static_cast<MergeAggregateOperatorState *>(operator_state);

if (!merge_aggregate_op_state->input_complete_) {
SimpleMergeAggregateExecute(merge_aggregate_op_state);

// auto result = SimpleMergeAggregateExecute(merge_aggregate_op_state->input_data_blocks_, merge_aggregate_op_state->data_block_array_);
if (merge_aggregate_op_state->input_complete_) {

if (merge_aggregate_op_state->data_block_array_.size() == 0) {
merge_aggregate_op_state->data_block_array_.emplace_back(Move(merge_aggregate_op_state->input_data_block_));
} else {
LOG_TRACE("PhysicalMergeAggregate::Input is complete");
// for (auto &output_block : merge_aggregate_op_state->data_block_array_) {
// output_block->Finalize();
// }

auto agg_op = dynamic_cast<PhysicalAggregate *>(this->left());

for (SizeT col_idx = 0; auto expr : agg_op->aggregates_) {
auto agg_expression = static_cast<AggregateExpression *>(expr.get());

auto function_name = agg_expression->aggregate_function_.GetFuncName();

auto function_return_type = agg_expression->aggregate_function_.return_type_;

if (String(function_name) == String("COUNT")) {
LOG_TRACE("PhysicalAggregate::Execute:: COUNT");
UpdateBlockData(merge_aggregate_op_state, col_idx);
} else if (String(function_name) == String("MIN")) {
LOG_TRACE("PhysicalAggregate::Execute:: MIN");
IntegerT input_int = GetDataFromValueAtInputBlockPosition<IntegerT>(merge_aggregate_op_state, 0, col_idx, 0);
IntegerT out_int = GetDataFromValueAtOutputBlockPosition<IntegerT>(merge_aggregate_op_state, 0, col_idx, 0);
IntegerT new_int = MinValue(input_int, out_int);
WriteIntegerAtPosition(merge_aggregate_op_state, 0, col_idx, 0, new_int);
} else if (String(function_name) == String("MAX")) {
LOG_TRACE("PhysicalAggregate::Execute:: MAX");
merge_aggregate_op_state->SetComplete();
return true;
}

IntegerT input_int = GetDataFromValueAtInputBlockPosition<IntegerT>(merge_aggregate_op_state, 0, col_idx, 0);
IntegerT out_int = GetDataFromValueAtOutputBlockPosition<IntegerT>(merge_aggregate_op_state, 0, col_idx, 0);
IntegerT new_int = MaxValue(input_int, out_int);
WriteIntegerAtPosition(merge_aggregate_op_state, 0, col_idx, 0, new_int);
} else if (String(function_name) == String("SUM")) {
UpdateBlockData(merge_aggregate_op_state, col_idx);
}
return false;
}

++col_idx;
}
}
return false;
void PhysicalMergeAggregate::SimpleMergeAggregateExecute(MergeAggregateOperatorState *merge_aggregate_op_state) {
if (merge_aggregate_op_state->data_block_array_.size() == 0) {
merge_aggregate_op_state->data_block_array_.emplace_back(Move(merge_aggregate_op_state->input_data_block_));
} else {
LOG_TRACE("PhysicalAggregate::Input is complete");
for (auto &output_block : merge_aggregate_op_state->data_block_array_) {
output_block->Finalize();
auto agg_op = dynamic_cast<PhysicalAggregate *>(this->left());

for (SizeT col_idx = 0; auto expr : agg_op->aggregates_) {
auto agg_expression = static_cast<AggregateExpression *>(expr.get());

auto function_name = agg_expression->aggregate_function_.GetFuncName();

auto function_return_type = agg_expression->aggregate_function_.return_type_;

if (String(function_name) == String("COUNT")) {
LOG_TRACE("PhysicalAggregate::Execute:: COUNT");
UpdateBlockData(merge_aggregate_op_state, col_idx);
} else if (String(function_name) == String("MIN")) {
LOG_TRACE("PhysicalAggregate::Execute:: MIN");
IntegerT input_int = GetDataFromValueAtInputBlockPosition<IntegerT>(merge_aggregate_op_state, 0, col_idx, 0);
IntegerT out_int = GetDataFromValueAtOutputBlockPosition<IntegerT>(merge_aggregate_op_state, 0, col_idx, 0);
IntegerT new_int = MinValue(input_int, out_int);
WriteIntegerAtPosition(merge_aggregate_op_state, 0, col_idx, 0, new_int);
} else if (String(function_name) == String("MAX")) {
LOG_TRACE("PhysicalAggregate::Execute:: MAX");

IntegerT input_int = GetDataFromValueAtInputBlockPosition<IntegerT>(merge_aggregate_op_state, 0, col_idx, 0);
IntegerT out_int = GetDataFromValueAtOutputBlockPosition<IntegerT>(merge_aggregate_op_state, 0, col_idx, 0);
IntegerT new_int = MaxValue(input_int, out_int);
WriteIntegerAtPosition(merge_aggregate_op_state, 0, col_idx, 0, new_int);
} else if (String(function_name) == String("SUM")) {
UpdateBlockData(merge_aggregate_op_state, col_idx);
}

++col_idx;
}
merge_aggregate_op_state->SetComplete();
return true;
}
}

Expand Down
2 changes: 2 additions & 0 deletions src/executor/operator/physical_merge_aggregate.cppm
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,8 @@ public:

// Returns the larger of the two values; when neither is greater, `b` is returned.
IntegerT MaxValue(IntegerT a, IntegerT b) {
    if (a > b) {
        return a;
    }
    return b;
}

void SimpleMergeAggregateExecute(MergeAggregateOperatorState *merge_aggregate_op_state) ;

private:
SharedPtr<Vector<String>> output_names_{};
SharedPtr<Vector<SharedPtr<DataType>>> output_types_{};
Expand Down
38 changes: 28 additions & 10 deletions src/planner/expression_binder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ namespace infinity {

SharedPtr<BaseExpression> ExpressionBinder::Bind(const ParsedExpr &expr, BindContext *bind_context_ptr, i64 depth, bool root) {
// Call implemented BuildExpression

SharedPtr<BaseExpression> result = BuildExpression(expr, bind_context_ptr, depth, root);
if (result.get() == nullptr) {
if (result.get() == nullptr) {
Expand Down Expand Up @@ -286,23 +287,40 @@ SharedPtr<BaseExpression> ExpressionBinder::BuildFuncExpr(const FunctionExpr &ex
Vector<String> col_names{};
ColumnExpr *col_expr = nullptr;
FunctionExpr *div_function_expr = new FunctionExpr();
FunctionExpr *sum_function_expr = new FunctionExpr();
FunctionExpr *count_function_expr = new FunctionExpr();
Vector<ParsedExpr *> *arguments = new Vector<ParsedExpr *>();
div_function_expr->func_name_ = String("div");
div_function_expr->arguments_ = new Vector<ParsedExpr*>();

if (expr.arguments_->size() == 1) {
if ((*expr.arguments_)[0]->type_ == ParsedExprType::kColumn) {
col_expr = (ColumnExpr *)(*expr.arguments_)[0];
col_names = col_expr->names_;
}
FunctionExpr *sum_function_expr = new FunctionExpr();
sum_function_expr->func_name_ = String("sum");
sum_function_expr->arguments_->reserve(1);
sum_function_expr->arguments_ = new Vector<ParsedExpr*>();
ColumnExpr* column_expr_for_sum = new ColumnExpr();
column_expr_for_sum->names_.emplace_back(col_names[0]);
sum_function_expr->arguments_->emplace_back(column_expr_for_sum);

FunctionExpr *count_function_expr = new FunctionExpr();
count_function_expr->func_name_= String("count");
sum_function_expr->arguments_->reserve(1);
*count_function_expr->arguments_->emplace_back(col_expr);
arguments->emplace_back(sum_function_expr);
arguments->emplace_back(count_function_expr);
div_function_expr->arguments_=arguments;
}
count_function_expr->arguments_ = new Vector<ParsedExpr*>();
ColumnExpr* column_expr_for_count = new ColumnExpr();
column_expr_for_sum->names_.emplace_back(col_names[0]);
count_function_expr->arguments_->emplace_back(column_expr_for_count);

div_function_expr->arguments_->emplace_back(sum_function_expr);
div_function_expr->arguments_->emplace_back(count_function_expr);
}

// Vector<SharedPtr<BaseExpression>> arguments;
// arguments.reserve(div_function_expr->arguments_->size());
// for (const auto *arg_expr : *div_function_expr->arguments_) {
// // The argument expression isn't root expression.
// // SharedPtr<BaseExpression> expr_ptr
// auto expr_ptr = BuildExpression(*arg_expr, bind_context_ptr, depth, false);
// arguments.emplace_back(expr_ptr);
// }
}

Vector<SharedPtr<BaseExpression>> arguments;
Expand Down
4 changes: 3 additions & 1 deletion src/scheduler/fragment_context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -695,7 +695,9 @@ void FragmentContext::MakeSinkState(i64 parallel_count) {
}

for (u64 task_id = 0; (i64)task_id < parallel_count; ++task_id) {
tasks_[task_id]->sink_state_ = MakeUnique<QueueSinkState>(fragment_ptr_->FragmentID(), task_id);
auto sink_state = MakeUnique<QueueSinkState>(fragment_ptr_->FragmentID(), task_id);

tasks_[task_id]->sink_state_ = Move(sink_state);
}
break;
}
Expand Down

0 comments on commit 6bdfbc1

Please sign in to comment.