From aa14de8f90b5a17482dbe2428500dcab89899da5 Mon Sep 17 00:00:00 2001 From: Xwg Date: Fri, 29 Dec 2023 18:29:40 +0800 Subject: [PATCH] Replace avg functions with (sum / count) This commit converts avg function expressions into (sum / count) function expressions in the SQL planner and updates the tests accordingly. Tests cover both the regular and exceptional cases. Corresponding changes are made in other parts of the code, such as ProjectBinder and the aggregate operators. --- src/executor/fragment_builder.cpp | 1 - .../operator/physical_merge_aggregate.cpp | 7 +- src/planner/binder/project_binder.cpp | 59 ++++++++- src/planner/expression_binder.cpp | 42 ------ test/sql/basic.slt | 3 + test/sql/dql/aggregate/test_simple_agg.slt | 47 ++++++- test/sql/dql/rbo_rule/column_pruner.slt | 10 +- tools/generate_aggregate.py | 123 ++++++++++++++++++ tools/sqllogictest.py | 2 + 9 files changed, 242 insertions(+), 52 deletions(-) create mode 100644 tools/generate_aggregate.py diff --git a/src/executor/fragment_builder.cpp b/src/executor/fragment_builder.cpp index 4ace7f677d..580e46c6df 100644 --- a/src/executor/fragment_builder.cpp +++ b/src/executor/fragment_builder.cpp @@ -233,7 +233,6 @@ void FragmentBuilder::BuildFragments(PhysicalOperator *phys_op, PlanFragment *cu phys_op->left()->GetOutputTypes()); BuildFragments(phys_op->left(), next_plan_fragment.get()); current_fragment_ptr->AddChild(Move(next_plan_fragment)); - if (phys_op->right() != nullptr) { auto next_plan_fragment = MakeUnique(GetFragmentId()); next_plan_fragment->SetSinkNode(query_context_ptr_, diff --git a/src/executor/operator/physical_merge_aggregate.cpp b/src/executor/operator/physical_merge_aggregate.cpp index 1fd86a4ab2..07afa931b2 100644 --- a/src/executor/operator/physical_merge_aggregate.cpp +++ b/src/executor/operator/physical_merge_aggregate.cpp @@ -42,9 +42,9 @@ bool PhysicalMergeAggregate::Execute(QueryContext *query_context, OperatorState if (merge_aggregate_op_state->input_complete_) { 
LOG_TRACE("PhysicalMergeAggregate::Input is complete"); - // for (auto &output_block : merge_aggregate_op_state->data_block_array_) { - // output_block->Finalize(); - // } + for (auto &output_block : merge_aggregate_op_state->data_block_array_) { + output_block->Finalize(); + } merge_aggregate_op_state->SetComplete(); return true; @@ -83,6 +83,7 @@ void PhysicalMergeAggregate::SimpleMergeAggregateExecute(MergeAggregateOperatorS IntegerT new_int = MaxValue(input_int, out_int); WriteIntegerAtPosition(merge_aggregate_op_state, 0, col_idx, 0, new_int); } else if (String(function_name) == String("SUM")) { + LOG_TRACE("PhysicalAggregate::Execute:: COUNT"); UpdateBlockData(merge_aggregate_op_state, col_idx); } diff --git a/src/planner/binder/project_binder.cpp b/src/planner/binder/project_binder.cpp index 7f2386fce5..c86cf58bad 100644 --- a/src/planner/binder/project_binder.cpp +++ b/src/planner/binder/project_binder.cpp @@ -32,6 +32,58 @@ namespace infinity { SharedPtr ProjectBinder::BuildExpression(const ParsedExpr &expr, BindContext *bind_context_ptr, i64 depth, bool root) { String expr_name = expr.GetName(); + // Covert avg function expr to (sum / count) function expr + if (expr.type_ == ParsedExprType::kFunction) { + auto special_function = TryBuildSpecialFuncExpr((FunctionExpr &)expr, bind_context_ptr, depth); + + if (special_function.has_value()) { + return ExpressionBinder::BuildExpression(expr, bind_context_ptr, depth, root); + } + + SharedPtr function_set_ptr = FunctionSet::GetFunctionSet(query_context_->storage()->catalog(), (FunctionExpr &)expr); + + CheckFuncType(function_set_ptr->type_); + if (IsEqual(function_set_ptr->name(), String("AVG"))) { + auto &fun_expr = (FunctionExpr &)expr; + Vector col_names{}; + FunctionExpr *div_function_expr = new FunctionExpr(); + div_function_expr->func_name_ = String("/"); + div_function_expr->arguments_ = new Vector(); + + if (fun_expr.arguments_->size() == 1) { + if ((*fun_expr.arguments_)[0]->type_ == 
ParsedExprType::kColumn) { + ColumnExpr *col_expr = (ColumnExpr *)(*fun_expr.arguments_)[0]; + col_names = Move(col_expr->names_); + + fun_expr.func_name_ = "/"; + delete fun_expr.arguments_; + fun_expr.arguments_ = nullptr; + } + + FunctionExpr *sum_function_expr = new FunctionExpr(); + sum_function_expr->func_name_ = String("sum"); + sum_function_expr->arguments_ = new Vector(); + ColumnExpr *column_expr_for_sum = new ColumnExpr(); + column_expr_for_sum->names_.emplace_back(col_names[0]); + sum_function_expr->arguments_->emplace_back(column_expr_for_sum); + + FunctionExpr *count_function_expr = new FunctionExpr(); + count_function_expr->func_name_ = String("count"); + count_function_expr->arguments_ = new Vector(); + ColumnExpr *column_expr_for_count = new ColumnExpr(); + column_expr_for_count->names_.emplace_back(col_names[0]); + count_function_expr->arguments_->emplace_back(column_expr_for_count); + + div_function_expr->arguments_->emplace_back(sum_function_expr); + div_function_expr->arguments_->emplace_back(count_function_expr); + + auto div_func_name = div_function_expr->GetName(); + + return ExpressionBinder::BuildExpression(*div_function_expr, bind_context_ptr, depth, root); + } + } + } + // If the expr isn't from aggregate function and coming from group by lists. 
if (!this->binding_agg_func_ && bind_context_ptr->group_index_by_name_.contains(expr_name)) { i64 groupby_index = bind_context_ptr->group_index_by_name_[expr_name]; @@ -94,8 +146,13 @@ SharedPtr ProjectBinder::BuildFuncExpr(const FunctionExpr &expr, SharedPtr function_set_ptr = FunctionSet::GetFunctionSet(query_context_->storage()->catalog(), expr); if (function_set_ptr->type_ == FunctionType::kAggregate) { if (this->binding_agg_func_) { - Error("Aggregate function is called in another aggregate function."); + Error(Format("Aggregate function {} is called in another aggregate function.", function_set_ptr->name())); } else { + // if (IsEqual(function_set_ptr->name(),String("AVG"))) { + // this->binding_agg_func_ = false; + // }else { + // this->binding_agg_func_ = true; + // } this->binding_agg_func_ = true; } } diff --git a/src/planner/expression_binder.cpp b/src/planner/expression_binder.cpp index 7d767dd875..3ea44fd0b9 100644 --- a/src/planner/expression_binder.cpp +++ b/src/planner/expression_binder.cpp @@ -281,48 +281,6 @@ SharedPtr ExpressionBinder::BuildFuncExpr(const FunctionExpr &ex } } - // covert avg function expr to (sum / count) function expr - - if (function_set_ptr->name() == "AVG") { - Vector col_names{}; - ColumnExpr *col_expr = nullptr; - FunctionExpr *div_function_expr = new FunctionExpr(); - div_function_expr->func_name_ = String("div"); - div_function_expr->arguments_ = new Vector(); - - if (expr.arguments_->size() == 1) { - if ((*expr.arguments_)[0]->type_ == ParsedExprType::kColumn) { - col_expr = (ColumnExpr *)(*expr.arguments_)[0]; - col_names = col_expr->names_; - } - FunctionExpr *sum_function_expr = new FunctionExpr(); - sum_function_expr->func_name_ = String("sum"); - sum_function_expr->arguments_ = new Vector(); - ColumnExpr* column_expr_for_sum = new ColumnExpr(); - column_expr_for_sum->names_.emplace_back(col_names[0]); - sum_function_expr->arguments_->emplace_back(column_expr_for_sum); - - FunctionExpr *count_function_expr = new 
FunctionExpr(); - count_function_expr->func_name_= String("count"); - count_function_expr->arguments_ = new Vector(); - ColumnExpr* column_expr_for_count = new ColumnExpr(); - column_expr_for_sum->names_.emplace_back(col_names[0]); - count_function_expr->arguments_->emplace_back(column_expr_for_count); - - div_function_expr->arguments_->emplace_back(sum_function_expr); - div_function_expr->arguments_->emplace_back(count_function_expr); - } - - // Vector> arguments; - // arguments.reserve(div_function_expr->arguments_->size()); - // for (const auto *arg_expr : *div_function_expr->arguments_) { - // // The argument expression isn't root expression. - // // SharedPtr expr_ptr - // auto expr_ptr = BuildExpression(*arg_expr, bind_context_ptr, depth, false); - // arguments.emplace_back(expr_ptr); - // } - } - Vector> arguments; arguments.reserve(expr.arguments_->size()); for (const auto *arg_expr : *expr.arguments_) { diff --git a/test/sql/basic.slt b/test/sql/basic.slt index 9cbe8f2ba2..63b82e97cd 100644 --- a/test/sql/basic.slt +++ b/test/sql/basic.slt @@ -2,6 +2,9 @@ # description: Test basic sql statement for sample # group: [basic] +statement ok +DROP TABLE IF EXISTS NATION; + # Expecting IDENTIFIER or PRIMARY or UNIQUE statement error CREATE TABLE NATION ( diff --git a/test/sql/dql/aggregate/test_simple_agg.slt b/test/sql/dql/aggregate/test_simple_agg.slt index 91d6c79730..19f8fea87c 100644 --- a/test/sql/dql/aggregate/test_simple_agg.slt +++ b/test/sql/dql/aggregate/test_simple_agg.slt @@ -22,7 +22,7 @@ SELECT SUM(c2) FROM simple_agg query I SELECT AVG(c1) FROM simple_agg ---- -2.000000 +2 query II SELECT AVG(c2) FROM simple_agg @@ -54,6 +54,51 @@ SELECT COUNT(c1) FROM simple_agg ---- 3 +query III +SELECT SUM(c1)+SUM(c1) FROM simple_agg; +---- +12 + +query III +SELECT MAX(c1)+SUM(c1) FROM simple_agg; +---- +9 + +query III +SELECT MAX(c1)+SUM(c2) FROM simple_agg; +---- +9.000000 + +query III +SELECT MAX(c1)*SUM(c2) FROM simple_agg; +---- +18.000000 + +query III 
+SELECT MAX(c1)*SUM(c2) FROM simple_agg; +---- +18.000000 + +query III +SELECT MAX(c1)-SUM(c2) FROM simple_agg; +---- +-3.000000 + +query III +SELECT MAX(c1)/SUM(c2) FROM simple_agg; +---- +0.500000 + +query III +SELECT MAX(c1)/AVG(c2) FROM simple_agg; +---- +1.500000 + +query III +SELECT MAX(c1)*AVG(c2) FROM simple_agg; +---- +6.000000 + statement ok DROP TABLE simple_agg; diff --git a/test/sql/dql/rbo_rule/column_pruner.slt b/test/sql/dql/rbo_rule/column_pruner.slt index f6d516a46b..583854f0e5 100644 --- a/test/sql/dql/rbo_rule/column_pruner.slt +++ b/test/sql/dql/rbo_rule/column_pruner.slt @@ -112,24 +112,25 @@ EXPLAIN LOGICAL SELECT MIN(c1 + 1), AVG(c2) FROM t1; ---- PROJECT (4) - table index: #4 - - expressions: [min((c1 + 1)) (#0), avg(c2) (#1)] + - expressions: [min((c1 + 1)) (#0), sum(c2) (#1) / count(c2) (#2)] -> AGGREGATE (3) - aggregate table index: #3 - - aggregate: [MIN(CAST(c1 (#0) AS BigInt) + 1), AVG(c2 (#1))] + - aggregate: [MIN(CAST(c1 (#0) AS BigInt) + 1), SUM(c2 (#1)), COUNT(c2 (#1))] -> TABLE SCAN (2) - table name: t1(default.t1) - table index: #1 - output columns: [c1, c2, __rowid] + query I EXPLAIN LOGICAL SELECT Min(c1 + 1), AVG(c2) FROM t1 GROUP BY c1; ---- PROJECT (4) - table index: #4 - - expressions: [min((c1 + 1)) (#1), avg(c2) (#2)] + - expressions: [min((c1 + 1)) (#1), sum(c2) (#2) / count(c2) (#3)] -> AGGREGATE (3) - aggregate table index: #3 - - aggregate: [MIN(CAST(c1 (#0) AS BigInt) + 1), AVG(c2 (#1))] + - aggregate: [MIN(CAST(c1 (#0) AS BigInt) + 1), SUM(c2 (#1)), COUNT(c2 (#1))] - group by table index: #2 - group by: [c1 (#0)] -> TABLE SCAN (2) @@ -137,6 +138,7 @@ EXPLAIN LOGICAL SELECT Min(c1 + 1), AVG(c2) FROM t1 GROUP BY c1; - table index: #1 - output columns: [c1, c2, __rowid] + query I EXPLAIN LOGICAL DELETE FROM t1 WHERE c1=1; ---- diff --git a/tools/generate_aggregate.py b/tools/generate_aggregate.py new file mode 100644 index 0000000000..f630733dc4 --- /dev/null +++ b/tools/generate_aggregate.py @@ -0,0 +1,123 @@ +import 
numpy as np +import random +import os +import argparse + + +def generate(generate_if_exists: bool, copy_dir: str): + row_n = 8000 + sort_dir = "./test/data/csv" + slt_dir = "./test/sql/dql" + + table_name = "test_simple_agg_big_cpp" + agg_path = sort_dir + "/test_simple_agg_big.csv" + slt_path = slt_dir + "/test_simple_agg_big.slt" + copy_path = copy_dir + "/test_simple_agg_big.csv" + + os.makedirs(sort_dir, exist_ok=True) + os.makedirs(slt_dir, exist_ok=True) + if os.path.exists(agg_path) and os.path.exists(slt_path) and generate_if_exists: + print( + "File {} and {} already existed exists. Skip Generating.".format( + slt_path, agg_path + ) + ) + return + with open(agg_path, "w") as agg_file, open(slt_path, "w") as slt_file: + slt_file.write("statement ok\n") + slt_file.write("DROP TABLE IF EXISTS {};\n".format(table_name)) + slt_file.write("\n") + slt_file.write("statement ok\n") + slt_file.write( + "CREATE TABLE {} (c1 int, c2 float);\n".format(table_name) + ) + slt_file.write("\n") + slt_file.write("query I\n") + slt_file.write( + "COPY {} FROM '{}' WITH ( DELIMITER ',' );\n".format( + table_name, copy_path + ) + ) + + + sequence = np.arange(1, 8001) + + for i in sequence: + agg_file.write(str(i) + "," + str(i)) + agg_file.write("\n") + + + # select max(c1) from test_simple_agg_big + + slt_file.write("\n") + slt_file.write("query I\n") + slt_file.write("SELECT max(c1) FROM {};\n".format(table_name)) + slt_file.write("----\n") + slt_file.write(str(8000)) + slt_file.write("\n") + + # select min(c2) from test_simple_agg_big + + slt_file.write("\n") + slt_file.write("query I\n") + slt_file.write("SELECT min(c1) FROM {};\n".format(table_name)) + slt_file.write("----\n") + slt_file.write(str(1)) + slt_file.write("\n") + + + # select sum(c1) from test_simple_agg_big + + slt_file.write("\n") + slt_file.write("query I\n") + slt_file.write("SELECT sum(c1) FROM {};\n".format(table_name)) + slt_file.write("----\n") + slt_file.write(str(np.sum(sequence))) + 
slt_file.write("\n") + + + # select avg(c1) from test_simple_agg_big + + # slt_file.write("\n") + # slt_file.write("query I\n") + # slt_file.write("SELECT AVG(c1) FROM {};\n".format(table_name)) + # slt_file.write("----\n") + # slt_file.write(str(np.mean(sequence))) + # slt_file.write("\n") + + # select count(c1) from test_simple_agg_big + slt_file.write("\n") + slt_file.write("query I\n") + slt_file.write("SELECT count(c1) FROM {};\n".format(table_name)) + slt_file.write("----\n") + slt_file.write(str(8000)) + slt_file.write("\n") + + + + + slt_file.write("\n") + slt_file.write("statement ok\n") + slt_file.write("DROP TABLE {};\n".format(table_name)) + random.random() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Generate limit data for test") + + parser.add_argument( + "-g", + "--generate", + type=bool, + default=False, + dest="generate_if_exists", + ) + parser.add_argument( + "-c", + "--copy", + type=str, + default="/tmp/infinity/test_data", + dest="copy_dir", + ) + args = parser.parse_args() + generate(args.generate_if_exists, args.copy_dir) diff --git a/tools/sqllogictest.py b/tools/sqllogictest.py index 5c49dccfb1..0c4f99cb9b 100644 --- a/tools/sqllogictest.py +++ b/tools/sqllogictest.py @@ -6,6 +6,7 @@ from generate_fvecs import generate as generate2 from generate_sort import generate as generate3 from generate_limit import generate as generate4 +from generate_aggregate import generate as generate5 def python_skd_test(python_test_dir: str): @@ -108,6 +109,7 @@ def copy_all(data_dir, copy_dir): generate2(args.generate_if_exists, args.copy) generate3(args.generate_if_exists, args.copy) generate4(args.generate_if_exists, args.copy) + generate5(args.generate_if_exists, args.copy) print("Generate file finshed.") python_skd_test(python_test_dir) test_process(args.path, args.test, args.data, args.copy)