diff --git a/src/executor/fragment_builder.cpp b/src/executor/fragment_builder.cpp index 4ace7f677d..580e46c6df 100644 --- a/src/executor/fragment_builder.cpp +++ b/src/executor/fragment_builder.cpp @@ -233,7 +233,6 @@ void FragmentBuilder::BuildFragments(PhysicalOperator *phys_op, PlanFragment *cu phys_op->left()->GetOutputTypes()); BuildFragments(phys_op->left(), next_plan_fragment.get()); current_fragment_ptr->AddChild(Move(next_plan_fragment)); - if (phys_op->right() != nullptr) { auto next_plan_fragment = MakeUnique(GetFragmentId()); next_plan_fragment->SetSinkNode(query_context_ptr_, diff --git a/src/executor/operator/physical_merge_aggregate.cpp b/src/executor/operator/physical_merge_aggregate.cpp index 1fd86a4ab2..07afa931b2 100644 --- a/src/executor/operator/physical_merge_aggregate.cpp +++ b/src/executor/operator/physical_merge_aggregate.cpp @@ -42,9 +42,9 @@ bool PhysicalMergeAggregate::Execute(QueryContext *query_context, OperatorState if (merge_aggregate_op_state->input_complete_) { LOG_TRACE("PhysicalMergeAggregate::Input is complete"); - // for (auto &output_block : merge_aggregate_op_state->data_block_array_) { - // output_block->Finalize(); - // } + for (auto &output_block : merge_aggregate_op_state->data_block_array_) { + output_block->Finalize(); + } merge_aggregate_op_state->SetComplete(); return true; @@ -83,6 +83,7 @@ void PhysicalMergeAggregate::SimpleMergeAggregateExecute(MergeAggregateOperatorS IntegerT new_int = MaxValue(input_int, out_int); WriteIntegerAtPosition(merge_aggregate_op_state, 0, col_idx, 0, new_int); } else if (String(function_name) == String("SUM")) { + LOG_TRACE("PhysicalAggregate::Execute:: COUNT"); UpdateBlockData(merge_aggregate_op_state, col_idx); } diff --git a/src/planner/binder/project_binder.cpp b/src/planner/binder/project_binder.cpp index 7f2386fce5..c86cf58bad 100644 --- a/src/planner/binder/project_binder.cpp +++ b/src/planner/binder/project_binder.cpp @@ -32,6 +32,58 @@ namespace infinity { SharedPtr 
ProjectBinder::BuildExpression(const ParsedExpr &expr, BindContext *bind_context_ptr, i64 depth, bool root) { String expr_name = expr.GetName(); + // Covert avg function expr to (sum / count) function expr + if (expr.type_ == ParsedExprType::kFunction) { + auto special_function = TryBuildSpecialFuncExpr((FunctionExpr &)expr, bind_context_ptr, depth); + + if (special_function.has_value()) { + return ExpressionBinder::BuildExpression(expr, bind_context_ptr, depth, root); + } + + SharedPtr function_set_ptr = FunctionSet::GetFunctionSet(query_context_->storage()->catalog(), (FunctionExpr &)expr); + + CheckFuncType(function_set_ptr->type_); + if (IsEqual(function_set_ptr->name(), String("AVG"))) { + auto &fun_expr = (FunctionExpr &)expr; + Vector col_names{}; + FunctionExpr *div_function_expr = new FunctionExpr(); + div_function_expr->func_name_ = String("/"); + div_function_expr->arguments_ = new Vector(); + + if (fun_expr.arguments_->size() == 1) { + if ((*fun_expr.arguments_)[0]->type_ == ParsedExprType::kColumn) { + ColumnExpr *col_expr = (ColumnExpr *)(*fun_expr.arguments_)[0]; + col_names = Move(col_expr->names_); + + fun_expr.func_name_ = "/"; + delete fun_expr.arguments_; + fun_expr.arguments_ = nullptr; + } + + FunctionExpr *sum_function_expr = new FunctionExpr(); + sum_function_expr->func_name_ = String("sum"); + sum_function_expr->arguments_ = new Vector(); + ColumnExpr *column_expr_for_sum = new ColumnExpr(); + column_expr_for_sum->names_.emplace_back(col_names[0]); + sum_function_expr->arguments_->emplace_back(column_expr_for_sum); + + FunctionExpr *count_function_expr = new FunctionExpr(); + count_function_expr->func_name_ = String("count"); + count_function_expr->arguments_ = new Vector(); + ColumnExpr *column_expr_for_count = new ColumnExpr(); + column_expr_for_count->names_.emplace_back(col_names[0]); + count_function_expr->arguments_->emplace_back(column_expr_for_count); + + div_function_expr->arguments_->emplace_back(sum_function_expr); + 
div_function_expr->arguments_->emplace_back(count_function_expr); + + auto div_func_name = div_function_expr->GetName(); + + return ExpressionBinder::BuildExpression(*div_function_expr, bind_context_ptr, depth, root); + } + } + } + // If the expr isn't from aggregate function and coming from group by lists. if (!this->binding_agg_func_ && bind_context_ptr->group_index_by_name_.contains(expr_name)) { i64 groupby_index = bind_context_ptr->group_index_by_name_[expr_name]; @@ -94,8 +146,13 @@ SharedPtr ProjectBinder::BuildFuncExpr(const FunctionExpr &expr, SharedPtr function_set_ptr = FunctionSet::GetFunctionSet(query_context_->storage()->catalog(), expr); if (function_set_ptr->type_ == FunctionType::kAggregate) { if (this->binding_agg_func_) { - Error("Aggregate function is called in another aggregate function."); + Error(Format("Aggregate function {} is called in another aggregate function.", function_set_ptr->name())); } else { + // if (IsEqual(function_set_ptr->name(),String("AVG"))) { + // this->binding_agg_func_ = false; + // }else { + // this->binding_agg_func_ = true; + // } this->binding_agg_func_ = true; } } diff --git a/src/planner/expression_binder.cpp b/src/planner/expression_binder.cpp index 7d767dd875..3ea44fd0b9 100644 --- a/src/planner/expression_binder.cpp +++ b/src/planner/expression_binder.cpp @@ -281,48 +281,6 @@ SharedPtr ExpressionBinder::BuildFuncExpr(const FunctionExpr &ex } } - // covert avg function expr to (sum / count) function expr - - if (function_set_ptr->name() == "AVG") { - Vector col_names{}; - ColumnExpr *col_expr = nullptr; - FunctionExpr *div_function_expr = new FunctionExpr(); - div_function_expr->func_name_ = String("div"); - div_function_expr->arguments_ = new Vector(); - - if (expr.arguments_->size() == 1) { - if ((*expr.arguments_)[0]->type_ == ParsedExprType::kColumn) { - col_expr = (ColumnExpr *)(*expr.arguments_)[0]; - col_names = col_expr->names_; - } - FunctionExpr *sum_function_expr = new FunctionExpr(); - 
sum_function_expr->func_name_ = String("sum"); - sum_function_expr->arguments_ = new Vector(); - ColumnExpr* column_expr_for_sum = new ColumnExpr(); - column_expr_for_sum->names_.emplace_back(col_names[0]); - sum_function_expr->arguments_->emplace_back(column_expr_for_sum); - - FunctionExpr *count_function_expr = new FunctionExpr(); - count_function_expr->func_name_= String("count"); - count_function_expr->arguments_ = new Vector(); - ColumnExpr* column_expr_for_count = new ColumnExpr(); - column_expr_for_sum->names_.emplace_back(col_names[0]); - count_function_expr->arguments_->emplace_back(column_expr_for_count); - - div_function_expr->arguments_->emplace_back(sum_function_expr); - div_function_expr->arguments_->emplace_back(count_function_expr); - } - - // Vector> arguments; - // arguments.reserve(div_function_expr->arguments_->size()); - // for (const auto *arg_expr : *div_function_expr->arguments_) { - // // The argument expression isn't root expression. - // // SharedPtr expr_ptr - // auto expr_ptr = BuildExpression(*arg_expr, bind_context_ptr, depth, false); - // arguments.emplace_back(expr_ptr); - // } - } - Vector> arguments; arguments.reserve(expr.arguments_->size()); for (const auto *arg_expr : *expr.arguments_) { diff --git a/test/sql/basic.slt b/test/sql/basic.slt index 9cbe8f2ba2..63b82e97cd 100644 --- a/test/sql/basic.slt +++ b/test/sql/basic.slt @@ -2,6 +2,9 @@ # description: Test basic sql statement for sample # group: [basic] +statement ok +DROP TABLE IF EXISTS NATION; + # Expecting IDENTIFIER or PRIMARY or UNIQUE statement error CREATE TABLE NATION ( diff --git a/test/sql/dql/aggregate/test_simple_agg.slt b/test/sql/dql/aggregate/test_simple_agg.slt index 91d6c79730..19f8fea87c 100644 --- a/test/sql/dql/aggregate/test_simple_agg.slt +++ b/test/sql/dql/aggregate/test_simple_agg.slt @@ -22,7 +22,7 @@ SELECT SUM(c2) FROM simple_agg query I SELECT AVG(c1) FROM simple_agg ---- -2.000000 +2 query II SELECT AVG(c2) FROM simple_agg @@ -54,6 +54,51 @@ 
SELECT COUNT(c1) FROM simple_agg ---- 3 +query III +SELECT SUM(c1)+SUM(c1) FROM simple_agg; +---- +12 + +query III +SELECT MAX(c1)+SUM(c1) FROM simple_agg; +---- +9 + +query III +SELECT MAX(c1)+SUM(c2) FROM simple_agg; +---- +9.000000 + +query III +SELECT MAX(c1)*SUM(c2) FROM simple_agg; +---- +18.000000 + +query III +SELECT MAX(c1)*SUM(c2) FROM simple_agg; +---- +18.000000 + +query III +SELECT MAX(c1)-SUM(c2) FROM simple_agg; +---- +-3.000000 + +query III +SELECT MAX(c1)/SUM(c2) FROM simple_agg; +---- +0.500000 + +query III +SELECT MAX(c1)/AVG(c2) FROM simple_agg; +---- +1.500000 + +query III +SELECT MAX(c1)*AVG(c2) FROM simple_agg; +---- +6.000000 + statement ok DROP TABLE simple_agg; diff --git a/test/sql/dql/rbo_rule/column_pruner.slt b/test/sql/dql/rbo_rule/column_pruner.slt index f6d516a46b..583854f0e5 100644 --- a/test/sql/dql/rbo_rule/column_pruner.slt +++ b/test/sql/dql/rbo_rule/column_pruner.slt @@ -112,24 +112,25 @@ EXPLAIN LOGICAL SELECT MIN(c1 + 1), AVG(c2) FROM t1; ---- PROJECT (4) - table index: #4 - - expressions: [min((c1 + 1)) (#0), avg(c2) (#1)] + - expressions: [min((c1 + 1)) (#0), sum(c2) (#1) / count(c2) (#2)] -> AGGREGATE (3) - aggregate table index: #3 - - aggregate: [MIN(CAST(c1 (#0) AS BigInt) + 1), AVG(c2 (#1))] + - aggregate: [MIN(CAST(c1 (#0) AS BigInt) + 1), SUM(c2 (#1)), COUNT(c2 (#1))] -> TABLE SCAN (2) - table name: t1(default.t1) - table index: #1 - output columns: [c1, c2, __rowid] + query I EXPLAIN LOGICAL SELECT Min(c1 + 1), AVG(c2) FROM t1 GROUP BY c1; ---- PROJECT (4) - table index: #4 - - expressions: [min((c1 + 1)) (#1), avg(c2) (#2)] + - expressions: [min((c1 + 1)) (#1), sum(c2) (#2) / count(c2) (#3)] -> AGGREGATE (3) - aggregate table index: #3 - - aggregate: [MIN(CAST(c1 (#0) AS BigInt) + 1), AVG(c2 (#1))] + - aggregate: [MIN(CAST(c1 (#0) AS BigInt) + 1), SUM(c2 (#1)), COUNT(c2 (#1))] - group by table index: #2 - group by: [c1 (#0)] -> TABLE SCAN (2) @@ -137,6 +138,7 @@ EXPLAIN LOGICAL SELECT Min(c1 + 1), AVG(c2) FROM t1 
GROUP BY c1; - table index: #1 - output columns: [c1, c2, __rowid] + query I EXPLAIN LOGICAL DELETE FROM t1 WHERE c1=1; ----
diff --git a/tools/generate_aggregate.py b/tools/generate_aggregate.py
new file mode 100644
index 0000000000..f630733dc4
--- /dev/null
+++ b/tools/generate_aggregate.py
@@ -0,0 +1,127 @@
+import numpy as np
+import random
+import os
+import argparse
+
+
+def _parse_bool(text: str) -> bool:
+    """Interpret a command-line flag value as a boolean.
+
+    argparse's ``type=bool`` calls ``bool()`` on the raw string, so every
+    non-empty value -- including "False" and "0" -- would parse as True.
+    Only the explicit false spellings below (and the empty string) are False.
+    """
+    return text.strip().lower() not in ("", "0", "false", "f", "no", "n")
+
+
+def _write_query(slt_file, sql: str, expected) -> None:
+    """Append one single-column ``query I`` case with its expected value."""
+    slt_file.write("\n")
+    slt_file.write("query I\n")
+    slt_file.write(sql + "\n")
+    slt_file.write("----\n")
+    slt_file.write(str(expected))
+    slt_file.write("\n")
+
+
+def generate(generate_if_exists: bool, copy_dir: str):
+    """Write the CSV data and the matching .slt test for large aggregates.
+
+    The CSV holds ``row_n`` rows of ``i,i`` for i in 1..row_n. The .slt file
+    creates the table, COPYs the CSV from ``copy_dir`` and checks max, min,
+    sum and count over column c1.
+
+    :param generate_if_exists: when True and both output files already exist,
+        nothing is regenerated.
+    :param copy_dir: directory the generated COPY statement reads the CSV from.
+    """
+    row_n = 8000
+    sort_dir = "./test/data/csv"
+    slt_dir = "./test/sql/dql"
+
+    table_name = "test_simple_agg_big_cpp"
+    agg_path = sort_dir + "/test_simple_agg_big.csv"
+    slt_path = slt_dir + "/test_simple_agg_big.slt"
+    copy_path = copy_dir + "/test_simple_agg_big.csv"
+
+    os.makedirs(sort_dir, exist_ok=True)
+    os.makedirs(slt_dir, exist_ok=True)
+    if os.path.exists(agg_path) and os.path.exists(slt_path) and generate_if_exists:
+        print(
+            "File {} and {} already exist. Skip Generating.".format(
+                slt_path, agg_path
+            )
+        )
+        return
+    with open(agg_path, "w") as agg_file, open(slt_path, "w") as slt_file:
+        slt_file.write("statement ok\n")
+        slt_file.write("DROP TABLE IF EXISTS {};\n".format(table_name))
+        slt_file.write("\n")
+        slt_file.write("statement ok\n")
+        slt_file.write(
+            "CREATE TABLE {} (c1 int, c2 float);\n".format(table_name)
+        )
+        slt_file.write("\n")
+        slt_file.write("query I\n")
+        slt_file.write(
+            "COPY {} FROM '{}' WITH ( DELIMITER ',' );\n".format(
+                table_name, copy_path
+            )
+        )
+
+        # Rows are 1..row_n; every expected value below is derived from
+        # row_n or this sequence so data and expectations cannot drift apart.
+        sequence = np.arange(1, row_n + 1)
+
+        for i in sequence:
+            agg_file.write(str(i) + "," + str(i))
+            agg_file.write("\n")
+
+        _write_query(
+            slt_file, "SELECT max(c1) FROM {};".format(table_name), row_n
+        )
+        _write_query(
+            slt_file, "SELECT min(c1) FROM {};".format(table_name), 1
+        )
+        _write_query(
+            slt_file,
+            "SELECT sum(c1) FROM {};".format(table_name),
+            np.sum(sequence),
+        )
+
+        # The AVG(c1) case is left disabled:
+        # _write_query(
+        #     slt_file,
+        #     "SELECT AVG(c1) FROM {};".format(table_name),
+        #     np.mean(sequence),
+        # )
+        #
+
+        _write_query(
+            slt_file, "SELECT count(c1) FROM {};".format(table_name), row_n
+        )
+
+        slt_file.write("\n")
+        slt_file.write("statement ok\n")
+        slt_file.write("DROP TABLE {};\n".format(table_name))
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Generate aggregate data for test")
+
+    parser.add_argument(
+        "-g",
+        "--generate",
+        type=_parse_bool,
+        default=False,
+        dest="generate_if_exists",
+    )
+    parser.add_argument(
+        "-c",
+        "--copy",
+        type=str,
+        default="/tmp/infinity/test_data",
+        dest="copy_dir",
+    )
+    args = parser.parse_args()
+    generate(args.generate_if_exists, args.copy_dir)
diff --git a/tools/sqllogictest.py b/tools/sqllogictest.py index 5c49dccfb1..0c4f99cb9b 100644 --- a/tools/sqllogictest.py +++ b/tools/sqllogictest.py @@ -6,6 +6,7 @@ from generate_fvecs import generate as generate2 from generate_sort import generate as generate3 from generate_limit import generate as generate4 +from generate_aggregate import generate as generate5 def python_skd_test(python_test_dir: str): @@ -108,6 +109,7 @@ def copy_all(data_dir, copy_dir): generate2(args.generate_if_exists, args.copy) generate3(args.generate_if_exists, args.copy) generate4(args.generate_if_exists, 
args.copy) + generate5(args.generate_if_exists, args.copy) print("Generate file finshed.") python_skd_test(python_test_dir) test_process(args.path, args.test, args.data, args.copy)