diff --git a/src/planner/binder/project_binder.cpp b/src/planner/binder/project_binder.cpp index 924047e673..944e4f3787 100644 --- a/src/planner/binder/project_binder.cpp +++ b/src/planner/binder/project_binder.cpp @@ -27,6 +27,28 @@ import infinity_exception; module project_binder; +namespace { + +using namespace infinity; + +void ConvertAvgToSumDivideCount(FunctionExpr &func_expression, const Vector &column_names) { + func_expression.func_name_ = "/"; + func_expression.arguments_->clear(); + auto createFunctionWithColumnArg = [&column_names](const String &func_name) { + auto function_expression = MakeUnique(); + function_expression->func_name_ = func_name; + function_expression->arguments_ = new Vector(); + auto column_expr = MakeUnique(); + column_expr->names_.push_back(column_names[0]); + function_expression->arguments_->push_back(column_expr.release()); + return function_expression.release(); + }; + func_expression.arguments_->push_back(createFunctionWithColumnArg("sum")); + func_expression.arguments_->push_back(createFunctionWithColumnArg("count")); +} + +} // namespace + namespace infinity { SharedPtr ProjectBinder::BuildExpression(const ParsedExpr &expr, BindContext *bind_context_ptr, i64 depth, bool root) { @@ -34,53 +56,22 @@ SharedPtr ProjectBinder::BuildExpression(const ParsedExpr &expr, // Covert avg function expr to (sum / count) function expr if (expr.type_ == ParsedExprType::kFunction) { - auto special_function = TryBuildSpecialFuncExpr((FunctionExpr &)expr, bind_context_ptr, depth); - + auto &function_expression = (FunctionExpr &)expr; + auto special_function = TryBuildSpecialFuncExpr(function_expression, bind_context_ptr, depth); if (special_function.has_value()) { return ExpressionBinder::BuildExpression(expr, bind_context_ptr, depth, root); } - - SharedPtr function_set_ptr = FunctionSet::GetFunctionSet(query_context_->storage()->catalog(), (FunctionExpr &)expr); - - CheckFuncType(function_set_ptr->type_); - if (IsEqual(function_set_ptr->name(), String("AVG"))) { - auto &fun_expr = (FunctionExpr &)expr; - Vector col_names{}; - - if (fun_expr.arguments_->size() == 1) { - if ((*fun_expr.arguments_)[0]->type_ == ParsedExprType::kColumn) { - ColumnExpr *col_expr = (ColumnExpr *)(*fun_expr.arguments_)[0]; - col_names = Move(col_expr->names_); - - fun_expr.func_name_ = "/"; - delete col_expr; - delete fun_expr.arguments_; - fun_expr.arguments_ = nullptr; - fun_expr.arguments_ = new Vector(); - - FunctionExpr *sum_function_expr = new FunctionExpr(); - sum_function_expr->func_name_ = String("sum"); - sum_function_expr->arguments_ = new Vector(); - ColumnExpr *column_expr_for_sum = new ColumnExpr(); - column_expr_for_sum->names_.emplace_back(col_names[0]); - sum_function_expr->arguments_->emplace_back(column_expr_for_sum); - - FunctionExpr *count_function_expr = new FunctionExpr(); - count_function_expr->func_name_ = String("count"); - count_function_expr->arguments_ = new Vector(); - ColumnExpr *column_expr_for_count = new ColumnExpr(); - column_expr_for_count->names_.emplace_back(col_names[0]); - count_function_expr->arguments_->emplace_back(column_expr_for_count); - - fun_expr.arguments_->emplace_back(sum_function_expr); - fun_expr.arguments_->emplace_back(count_function_expr); - - return ExpressionBinder::BuildExpression(expr, bind_context_ptr, depth, root); - } - } + auto function_set_ptr = FunctionSet::GetFunctionSet(query_context_->storage()->catalog(), function_expression); + + if (IsEqual(function_set_ptr->name(), String("AVG")) && function_expression.arguments_->size() == 1 && + (*function_expression.arguments_)[0]->type_ == ParsedExprType::kColumn) { + auto column_expr = (ColumnExpr *)(*function_expression.arguments_)[0]; + Vector column_names(Move(column_expr->names_)); + delete column_expr; + ConvertAvgToSumDivideCount(function_expression, column_names); + return ExpressionBinder::BuildExpression(expr, bind_context_ptr, depth, root); } } - // If the expr isn't from aggregate function and coming from group by lists. if (!this->binding_agg_func_ && bind_context_ptr->group_index_by_name_.contains(expr_name)) { i64 groupby_index = bind_context_ptr->group_index_by_name_[expr_name];