Skip to content

Commit

Permalink
For avg functions replacing it with (sum / count)
Browse files Browse the repository at this point in the history
The commit involves converting avg function expressions to (sum / count) function expressions in the SQL planner and test updates accordingly. Tests cover both the regular and exceptional cases. Corresponding changes are reflected in other parts of the code like ProjectBinder and Aggregate operators.
  • Loading branch information
loloxwg committed Dec 29, 2023
1 parent b05e7c4 commit 5a54225
Show file tree
Hide file tree
Showing 9 changed files with 242 additions and 52 deletions.
1 change: 0 additions & 1 deletion src/executor/fragment_builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,6 @@ void FragmentBuilder::BuildFragments(PhysicalOperator *phys_op, PlanFragment *cu
phys_op->left()->GetOutputTypes());
BuildFragments(phys_op->left(), next_plan_fragment.get());
current_fragment_ptr->AddChild(Move(next_plan_fragment));

if (phys_op->right() != nullptr) {
auto next_plan_fragment = MakeUnique<PlanFragment>(GetFragmentId());
next_plan_fragment->SetSinkNode(query_context_ptr_,
Expand Down
7 changes: 4 additions & 3 deletions src/executor/operator/physical_merge_aggregate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,9 @@ bool PhysicalMergeAggregate::Execute(QueryContext *query_context, OperatorState
if (merge_aggregate_op_state->input_complete_) {

LOG_TRACE("PhysicalMergeAggregate::Input is complete");
// for (auto &output_block : merge_aggregate_op_state->data_block_array_) {
// output_block->Finalize();
// }
for (auto &output_block : merge_aggregate_op_state->data_block_array_) {
output_block->Finalize();
}

merge_aggregate_op_state->SetComplete();
return true;
Expand Down Expand Up @@ -83,6 +83,7 @@ void PhysicalMergeAggregate::SimpleMergeAggregateExecute(MergeAggregateOperatorS
IntegerT new_int = MaxValue(input_int, out_int);
WriteIntegerAtPosition(merge_aggregate_op_state, 0, col_idx, 0, new_int);
} else if (String(function_name) == String("SUM")) {
LOG_TRACE("PhysicalAggregate::Execute:: COUNT");
UpdateBlockData(merge_aggregate_op_state, col_idx);
}

Expand Down
59 changes: 58 additions & 1 deletion src/planner/binder/project_binder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,58 @@ namespace infinity {
SharedPtr<BaseExpression> ProjectBinder::BuildExpression(const ParsedExpr &expr, BindContext *bind_context_ptr, i64 depth, bool root) {
String expr_name = expr.GetName();

// Covert avg function expr to (sum / count) function expr
if (expr.type_ == ParsedExprType::kFunction) {
auto special_function = TryBuildSpecialFuncExpr((FunctionExpr &)expr, bind_context_ptr, depth);

if (special_function.has_value()) {
return ExpressionBinder::BuildExpression(expr, bind_context_ptr, depth, root);
}

SharedPtr<FunctionSet> function_set_ptr = FunctionSet::GetFunctionSet(query_context_->storage()->catalog(), (FunctionExpr &)expr);

CheckFuncType(function_set_ptr->type_);
if (IsEqual(function_set_ptr->name(), String("AVG"))) {
auto &fun_expr = (FunctionExpr &)expr;
Vector<String> col_names{};
FunctionExpr *div_function_expr = new FunctionExpr();
div_function_expr->func_name_ = String("/");
div_function_expr->arguments_ = new Vector<ParsedExpr *>();

if (fun_expr.arguments_->size() == 1) {
if ((*fun_expr.arguments_)[0]->type_ == ParsedExprType::kColumn) {
ColumnExpr *col_expr = (ColumnExpr *)(*fun_expr.arguments_)[0];
col_names = Move(col_expr->names_);

fun_expr.func_name_ = "/";
delete fun_expr.arguments_;
fun_expr.arguments_ = nullptr;
}

FunctionExpr *sum_function_expr = new FunctionExpr();
sum_function_expr->func_name_ = String("sum");
sum_function_expr->arguments_ = new Vector<ParsedExpr *>();
ColumnExpr *column_expr_for_sum = new ColumnExpr();
column_expr_for_sum->names_.emplace_back(col_names[0]);
sum_function_expr->arguments_->emplace_back(column_expr_for_sum);

FunctionExpr *count_function_expr = new FunctionExpr();
count_function_expr->func_name_ = String("count");
count_function_expr->arguments_ = new Vector<ParsedExpr *>();
ColumnExpr *column_expr_for_count = new ColumnExpr();
column_expr_for_count->names_.emplace_back(col_names[0]);
count_function_expr->arguments_->emplace_back(column_expr_for_count);

div_function_expr->arguments_->emplace_back(sum_function_expr);
div_function_expr->arguments_->emplace_back(count_function_expr);

auto div_func_name = div_function_expr->GetName();

return ExpressionBinder::BuildExpression(*div_function_expr, bind_context_ptr, depth, root);
}
}
}

// If the expr isn't from aggregate function and coming from group by lists.
if (!this->binding_agg_func_ && bind_context_ptr->group_index_by_name_.contains(expr_name)) {
i64 groupby_index = bind_context_ptr->group_index_by_name_[expr_name];
Expand Down Expand Up @@ -94,8 +146,13 @@ SharedPtr<BaseExpression> ProjectBinder::BuildFuncExpr(const FunctionExpr &expr,
SharedPtr<FunctionSet> function_set_ptr = FunctionSet::GetFunctionSet(query_context_->storage()->catalog(), expr);
if (function_set_ptr->type_ == FunctionType::kAggregate) {
if (this->binding_agg_func_) {
Error<PlannerException>("Aggregate function is called in another aggregate function.");
Error<PlannerException>(Format("Aggregate function {} is called in another aggregate function.", function_set_ptr->name()));
} else {
// if (IsEqual(function_set_ptr->name(),String("AVG"))) {
// this->binding_agg_func_ = false;
// }else {
// this->binding_agg_func_ = true;
// }
this->binding_agg_func_ = true;
}
}
Expand Down
42 changes: 0 additions & 42 deletions src/planner/expression_binder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -281,48 +281,6 @@ SharedPtr<BaseExpression> ExpressionBinder::BuildFuncExpr(const FunctionExpr &ex
}
}

// covert avg function expr to (sum / count) function expr

if (function_set_ptr->name() == "AVG") {
Vector<String> col_names{};
ColumnExpr *col_expr = nullptr;
FunctionExpr *div_function_expr = new FunctionExpr();
div_function_expr->func_name_ = String("div");
div_function_expr->arguments_ = new Vector<ParsedExpr*>();

if (expr.arguments_->size() == 1) {
if ((*expr.arguments_)[0]->type_ == ParsedExprType::kColumn) {
col_expr = (ColumnExpr *)(*expr.arguments_)[0];
col_names = col_expr->names_;
}
FunctionExpr *sum_function_expr = new FunctionExpr();
sum_function_expr->func_name_ = String("sum");
sum_function_expr->arguments_ = new Vector<ParsedExpr*>();
ColumnExpr* column_expr_for_sum = new ColumnExpr();
column_expr_for_sum->names_.emplace_back(col_names[0]);
sum_function_expr->arguments_->emplace_back(column_expr_for_sum);

FunctionExpr *count_function_expr = new FunctionExpr();
count_function_expr->func_name_= String("count");
count_function_expr->arguments_ = new Vector<ParsedExpr*>();
ColumnExpr* column_expr_for_count = new ColumnExpr();
column_expr_for_sum->names_.emplace_back(col_names[0]);
count_function_expr->arguments_->emplace_back(column_expr_for_count);

div_function_expr->arguments_->emplace_back(sum_function_expr);
div_function_expr->arguments_->emplace_back(count_function_expr);
}

// Vector<SharedPtr<BaseExpression>> arguments;
// arguments.reserve(div_function_expr->arguments_->size());
// for (const auto *arg_expr : *div_function_expr->arguments_) {
// // The argument expression isn't root expression.
// // SharedPtr<BaseExpression> expr_ptr
// auto expr_ptr = BuildExpression(*arg_expr, bind_context_ptr, depth, false);
// arguments.emplace_back(expr_ptr);
// }
}

Vector<SharedPtr<BaseExpression>> arguments;
arguments.reserve(expr.arguments_->size());
for (const auto *arg_expr : *expr.arguments_) {
Expand Down
3 changes: 3 additions & 0 deletions test/sql/basic.slt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
# description: Test basic sql statement for sample
# group: [basic]

statement ok
DROP TABLE IF EXISTS NATION;

# Expecting IDENTIFIER or PRIMARY or UNIQUE
statement error
CREATE TABLE NATION (
Expand Down
47 changes: 46 additions & 1 deletion test/sql/dql/aggregate/test_simple_agg.slt
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ SELECT SUM(c2) FROM simple_agg
query I
SELECT AVG(c1) FROM simple_agg
----
2.000000
2

query II
SELECT AVG(c2) FROM simple_agg
Expand Down Expand Up @@ -54,6 +54,51 @@ SELECT COUNT(c1) FROM simple_agg
----
3

query III
SELECT SUM(c1)+SUM(c1) FROM simple_agg;
----
12

query III
SELECT MAX(c1)+SUM(c1) FROM simple_agg;
----
9

query III
SELECT MAX(c1)+SUM(c2) FROM simple_agg;
----
9.000000

query III
SELECT MAX(c1)*SUM(c2) FROM simple_agg;
----
18.000000

query III
SELECT MAX(c1)*SUM(c2) FROM simple_agg;
----
18.000000

query III
SELECT MAX(c1)-SUM(c2) FROM simple_agg;
----
-3.000000

query III
SELECT MAX(c1)/SUM(c2) FROM simple_agg;
----
0.500000

query III
SELECT MAX(c1)/AVG(c2) FROM simple_agg;
----
1.500000

query III
SELECT MAX(c1)*AVG(c2) FROM simple_agg;
----
6.000000


statement ok
DROP TABLE simple_agg;
10 changes: 6 additions & 4 deletions test/sql/dql/rbo_rule/column_pruner.slt
Original file line number Diff line number Diff line change
Expand Up @@ -112,31 +112,33 @@ EXPLAIN LOGICAL SELECT MIN(c1 + 1), AVG(c2) FROM t1;
----
PROJECT (4)
- table index: #4
- expressions: [min((c1 + 1)) (#0), avg(c2) (#1)]
- expressions: [min((c1 + 1)) (#0), sum(c2) (#1) / count(c2) (#2)]
-> AGGREGATE (3)
- aggregate table index: #3
- aggregate: [MIN(CAST(c1 (#0) AS BigInt) + 1), AVG(c2 (#1))]
- aggregate: [MIN(CAST(c1 (#0) AS BigInt) + 1), SUM(c2 (#1)), COUNT(c2 (#1))]
-> TABLE SCAN (2)
- table name: t1(default.t1)
- table index: #1
- output columns: [c1, c2, __rowid]


query I
EXPLAIN LOGICAL SELECT Min(c1 + 1), AVG(c2) FROM t1 GROUP BY c1;
----
PROJECT (4)
- table index: #4
- expressions: [min((c1 + 1)) (#1), avg(c2) (#2)]
- expressions: [min((c1 + 1)) (#1), sum(c2) (#2) / count(c2) (#3)]
-> AGGREGATE (3)
- aggregate table index: #3
- aggregate: [MIN(CAST(c1 (#0) AS BigInt) + 1), AVG(c2 (#1))]
- aggregate: [MIN(CAST(c1 (#0) AS BigInt) + 1), SUM(c2 (#1)), COUNT(c2 (#1))]
- group by table index: #2
- group by: [c1 (#0)]
-> TABLE SCAN (2)
- table name: t1(default.t1)
- table index: #1
- output columns: [c1, c2, __rowid]


query I
EXPLAIN LOGICAL DELETE FROM t1 WHERE c1=1;
----
Expand Down
123 changes: 123 additions & 0 deletions tools/generate_aggregate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
import numpy as np
import random
import os
import argparse


def generate(generate_if_exists: bool, copy_dir: str):
row_n = 8000
sort_dir = "./test/data/csv"
slt_dir = "./test/sql/dql"

table_name = "test_simple_agg_big_cpp"
agg_path = sort_dir + "/test_simple_agg_big.csv"
slt_path = slt_dir + "/test_simple_agg_big.slt"
copy_path = copy_dir + "/test_simple_agg_big.csv"

os.makedirs(sort_dir, exist_ok=True)
os.makedirs(slt_dir, exist_ok=True)
if os.path.exists(agg_path) and os.path.exists(slt_path) and generate_if_exists:
print(
"File {} and {} already existed exists. Skip Generating.".format(
slt_path, agg_path
)
)
return
with open(agg_path, "w") as agg_file, open(slt_path, "w") as slt_file:
slt_file.write("statement ok\n")
slt_file.write("DROP TABLE IF EXISTS {};\n".format(table_name))
slt_file.write("\n")
slt_file.write("statement ok\n")
slt_file.write(
"CREATE TABLE {} (c1 int, c2 float);\n".format(table_name)
)
slt_file.write("\n")
slt_file.write("query I\n")
slt_file.write(
"COPY {} FROM '{}' WITH ( DELIMITER ',' );\n".format(
table_name, copy_path
)
)


sequence = np.arange(1, 8001)

for i in sequence:
agg_file.write(str(i) + "," + str(i))
agg_file.write("\n")


# select max(c1) from test_simple_agg_big

slt_file.write("\n")
slt_file.write("query I\n")
slt_file.write("SELECT max(c1) FROM {};\n".format(table_name))
slt_file.write("----\n")
slt_file.write(str(8000))
slt_file.write("\n")

# select min(c2) from test_simple_agg_big

slt_file.write("\n")
slt_file.write("query I\n")
slt_file.write("SELECT min(c1) FROM {};\n".format(table_name))
slt_file.write("----\n")
slt_file.write(str(1))
slt_file.write("\n")


# select sum(c1) from test_simple_agg_big

slt_file.write("\n")
slt_file.write("query I\n")
slt_file.write("SELECT sum(c1) FROM {};\n".format(table_name))
slt_file.write("----\n")
slt_file.write(str(np.sum(sequence)))
slt_file.write("\n")


# select avg(c1) from test_simple_agg_big

# slt_file.write("\n")
# slt_file.write("query I\n")
# slt_file.write("SELECT AVG(c1) FROM {};\n".format(table_name))
# slt_file.write("----\n")
# slt_file.write(str(np.mean(sequence)))
# slt_file.write("\n")

# select count(c1) from test_simple_agg_big
slt_file.write("\n")
slt_file.write("query I\n")
slt_file.write("SELECT count(c1) FROM {};\n".format(table_name))
slt_file.write("----\n")
slt_file.write(str(8000))
slt_file.write("\n")




slt_file.write("\n")
slt_file.write("statement ok\n")
slt_file.write("DROP TABLE {};\n".format(table_name))
random.random()


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Generate limit data for test")

parser.add_argument(
"-g",
"--generate",
type=bool,
default=False,
dest="generate_if_exists",
)
parser.add_argument(
"-c",
"--copy",
type=str,
default="/tmp/infinity/test_data",
dest="copy_dir",
)
args = parser.parse_args()
generate(args.generate_if_exists, args.copy_dir)
Loading

0 comments on commit 5a54225

Please sign in to comment.