Skip to content

Commit

Permalink
Refactor AVG function conversion in project binder
Browse files Browse the repository at this point in the history
The conversion of AVG function to SUM / COUNT has been refactored for efficiency and readability. A helper function, ConvertAvgToSumDivideCount, has been introduced to modularize this logic. The updated code significantly shrinks the BuildExpression method, improving its maintainability.
  • Loading branch information
loloxwg committed Jan 2, 2024
1 parent d799967 commit ac0699c
Showing 1 changed file with 33 additions and 42 deletions.
75 changes: 33 additions & 42 deletions src/planner/binder/project_binder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,60 +27,51 @@ import infinity_exception;

module project_binder;

namespace {

using namespace infinity;

void ConvertAvgToSumDivideCount(FunctionExpr &func_expression, const Vector<String> &column_names) {
func_expression.func_name_ = "/";
func_expression.arguments_->clear();
auto createFunctionWithColumnArg = [&column_names](const String &func_name) {
auto function_expression = MakeUnique<FunctionExpr>();
function_expression->func_name_ = func_name;
function_expression->arguments_ = new Vector<ParsedExpr *>();
auto column_expr = MakeUnique<ColumnExpr>();
column_expr->names_.push_back(column_names[0]);
function_expression->arguments_->push_back(column_expr.release());
return function_expression.release();
};
func_expression.arguments_->push_back(createFunctionWithColumnArg("sum"));
func_expression.arguments_->push_back(createFunctionWithColumnArg("count"));
}

} // namespace

namespace infinity {

SharedPtr<BaseExpression> ProjectBinder::BuildExpression(const ParsedExpr &expr, BindContext *bind_context_ptr, i64 depth, bool root) {
String expr_name = expr.GetName();

// Covert avg function expr to (sum / count) function expr
if (expr.type_ == ParsedExprType::kFunction) {
auto special_function = TryBuildSpecialFuncExpr((FunctionExpr &)expr, bind_context_ptr, depth);

auto &function_expression = (FunctionExpr &)expr;
auto special_function = TryBuildSpecialFuncExpr(function_expression, bind_context_ptr, depth);
if (special_function.has_value()) {
return ExpressionBinder::BuildExpression(expr, bind_context_ptr, depth, root);
}

SharedPtr<FunctionSet> function_set_ptr = FunctionSet::GetFunctionSet(query_context_->storage()->catalog(), (FunctionExpr &)expr);

CheckFuncType(function_set_ptr->type_);
if (IsEqual(function_set_ptr->name(), String("AVG"))) {
auto &fun_expr = (FunctionExpr &)expr;
Vector<String> col_names{};

if (fun_expr.arguments_->size() == 1) {
if ((*fun_expr.arguments_)[0]->type_ == ParsedExprType::kColumn) {
ColumnExpr *col_expr = (ColumnExpr *)(*fun_expr.arguments_)[0];
col_names = Move(col_expr->names_);

fun_expr.func_name_ = "/";
delete col_expr;
delete fun_expr.arguments_;
fun_expr.arguments_ = nullptr;
fun_expr.arguments_ = new Vector<ParsedExpr *>();

FunctionExpr *sum_function_expr = new FunctionExpr();
sum_function_expr->func_name_ = String("sum");
sum_function_expr->arguments_ = new Vector<ParsedExpr *>();
ColumnExpr *column_expr_for_sum = new ColumnExpr();
column_expr_for_sum->names_.emplace_back(col_names[0]);
sum_function_expr->arguments_->emplace_back(column_expr_for_sum);

FunctionExpr *count_function_expr = new FunctionExpr();
count_function_expr->func_name_ = String("count");
count_function_expr->arguments_ = new Vector<ParsedExpr *>();
ColumnExpr *column_expr_for_count = new ColumnExpr();
column_expr_for_count->names_.emplace_back(col_names[0]);
count_function_expr->arguments_->emplace_back(column_expr_for_count);

fun_expr.arguments_->emplace_back(sum_function_expr);
fun_expr.arguments_->emplace_back(count_function_expr);

return ExpressionBinder::BuildExpression(expr, bind_context_ptr, depth, root);
}
}
auto function_set_ptr = FunctionSet::GetFunctionSet(query_context_->storage()->catalog(), function_expression);

if (IsEqual(function_set_ptr->name(), String("AVG")) && function_expression.arguments_->size() == 1 &&
(*function_expression.arguments_)[0]->type_ == ParsedExprType::kColumn) {
auto column_expr = (ColumnExpr *)(*function_expression.arguments_)[0];
Vector<String> column_names(Move(column_expr->names_));
delete column_expr;
ConvertAvgToSumDivideCount(function_expression, column_names);
return ExpressionBinder::BuildExpression(expr, bind_context_ptr, depth, root);
}
}

// If the expr isn't from aggregate function and coming from group by lists.
if (!this->binding_agg_func_ && bind_context_ptr->group_index_by_name_.contains(expr_name)) {
i64 groupby_index = bind_context_ptr->group_index_by_name_[expr_name];
Expand Down

0 comments on commit ac0699c

Please sign in to comment.