From 597cb05e30373cd5f2a3edd9f1c65b638ca02575 Mon Sep 17 00:00:00 2001 From: Xwg Date: Wed, 3 Jan 2024 19:20:08 +0800 Subject: [PATCH] For avg functions replacing it with (sum / count), count(*) (#399) * Fix top fragment source's type * Implement PhysicalMergeAggregate operator execution This commit introduces the execution functionality in the PhysicalMergeAggregate operator. It provides data procedures for various aggregate functions like COUNT, MIN, MAX, and SUM across different DataTypes. Additionally, proper handling for input and Output dataBlocks has been implemented. Consequently, various changes to the TaskScheduler, code refactoring and logging adjustments were also done. * Update PhysicalMergeAggregate operator execution logic The PhysicalMergeAggregate operator execution logic has been updated to correctly handle COUNT, MIN, MAX, and SUM aggregate functions for different DataTypes. This commit also implements appropriate handling for input and output dataBlocks. Concurrently, various changes are performed in the TaskScheduler. Several code sections were refactored for better readability and logs have been properly adjusted for detailed tracing. * Remove unused functions in physical_merge_aggregate files Deleted several unused template functions in `physical_merge_aggregate.cppm` and `physical_merge_aggregate.cpp`. Also removed a sizable chunk of unused logic related to execution of different aggregate functions from `PhysicalMergeAggregate::SimpleMergeAggregateExecute()` method. This declutters the codebase, making it easier to read and understand. * Build new func expr * Refactor merge aggregate and expression binder code The PhysicalMergeAggregate class was refactored by introducing a new method, SimpleMergeAggregateExecute, to encapsulate some of the implementation details. The newly defined method simplifies the complexity of the code and makes the execute method easier to understand. ExpressionBinder class was also modified to improve clarity of the division function expression building process, simplifying it and making code more maintainable. Minor tweaks were also made to the fragment_context.cpp file to further enhance readability. * For avg functions replacing it with (sum / count) The commit involves converting avg function expressions to (sum / count) function expressions in the SQL planner and test updates accordingly. Tests cover both the regular and exceptional cases. Corresponding changes are reflected in other parts of the code like ProjectBinder and Aggregate operators. * Fix memory leak * Fix typo * Refactor AVG function conversion in project binder The conversion of AVG function to SUM / COUNT has been refactored for efficiency and readability. A helper function, ConvertAvgToSumDivideCount, has been introduced to modularize this logic. The updated code significantly shrinks the BuildExpression method, improving its maintainability. * Update division functions and test cases Modified the div functions to return results as Double type instead of previous specific integer types. Additionally, adjusted relevant test cases to match the changes. This allows for more accurate division results across the codebase. * Implement count(*) functionality for SQL tables A "count(*)" functionality was implemented, enabling users to count the number of rows of SQL tables. Several test scripts were modified to include a "SELECT count(*)" command immediately after table operations to verify the row count. The 'Update' and 'Delete' operations were adjusted in 'table_collection_entry.cpp', and 'StateUpdate' was modified in 'aggregate_function.cppm'. New 'CountStar' related files were added to implement the new functionality. --- .../expression/expression_evaluator.cpp | 11 +- src/executor/fragment_builder.cpp | 2 +- src/executor/operator/physical_aggregate.cpp | 86 ++++------ src/executor/operator/physical_aggregate.cppm | 9 +- .../operator/physical_merge_aggregate.cpp | 147 +++++++++++++++++- .../operator/physical_merge_aggregate.cppm | 79 ++++++++++ src/executor/operator_state.cpp | 8 + src/executor/operator_state.cppm | 5 + src/executor/physical_planner.cpp | 2 + src/function/aggregate/count.cpp | 29 ++-- src/function/aggregate/count_star.cpp | 59 +++++++ src/function/aggregate/count_star.cppm | 27 ++++ src/function/aggregate_function.cppm | 47 ++++++ src/function/builtin_functions.cpp | 3 + src/function/scalar/divide.cpp | 34 ++-- src/planner/binder/project_binder.cpp | 42 ++++- src/planner/expression_binder.cpp | 11 +- src/scheduler/fragment_context.cpp | 18 ++- .../meta/entry/table_collection_entry.cpp | 4 +- .../function/scalar/div_functions.cpp | 44 +++--- test/sql/basic.slt | 3 + test/sql/dml/delete.slt | 10 ++ test/sql/dml/import/test_embedding.slt | 5 + test/sql/dml/import/test_jsonl.slt | 8 + test/sql/dml/import/test_varchar.slt | 6 + test/sql/dml/insert.slt | 10 ++ test/sql/dml/update.slt | 11 ++ test/sql/dql/aggregate/test_simple_agg.slt | 50 ++++++ test/sql/dql/rbo_rule/column_pruner.slt | 10 +- test/sql/dql/select.slt | 6 + tools/generate_aggregate.py | 144 +++++++++++++++++ tools/sqllogictest.py | 2 + 32 files changed, 806 insertions(+), 126 deletions(-) create mode 100644 src/function/aggregate/count_star.cpp create mode 100644 src/function/aggregate/count_star.cppm create mode 100644 tools/generate_aggregate.py diff --git a/src/executor/expression/expression_evaluator.cpp b/src/executor/expression/expression_evaluator.cpp index ac65293201..8bb7f029eb 100644 --- a/src/executor/expression/expression_evaluator.cpp +++ b/src/executor/expression/expression_evaluator.cpp @@ -79,10 +79,10 @@ void ExpressionEvaluator::Execute(const SharedPtr &expr, // Create output chunk. // TODO: Now output chunk is pre-allocate memory in expression state // TODO: In the future, it can be implemented as on-demand allocation. - SharedPtr &child_output = child_state->OutputColumnVector(); - Execute(child_expr, child_state, child_output); + SharedPtr &child_output_col = child_state->OutputColumnVector(); + this->Execute(child_expr, child_state, child_output_col); - if (expr->aggregate_function_.argument_type_ != *child_output->data_type()) { + if (expr->aggregate_function_.argument_type_ != *child_output_col->data_type()) { Error("Argument type isn't matched with the child expression output"); } if (expr->aggregate_function_.return_type_ != *output_column_vector->data_type()) { @@ -93,7 +93,7 @@ void ExpressionEvaluator::Execute(const SharedPtr &expr, expr->aggregate_function_.init_func_(expr->aggregate_function_.GetState()); // 2. Loop to fill the aggregate state - expr->aggregate_function_.update_func_(expr->aggregate_function_.GetState(), child_output); + expr->aggregate_function_.update_func_(expr->aggregate_function_.GetState(), child_output_col); // 3. Get the aggregate result and append to output column vector. @@ -157,7 +157,8 @@ void ExpressionEvaluator::Execute(const SharedPtr &expr, SharedPtr &, SharedPtr &output_column_vector) { // memory copy here. - output_column_vector->SetValue(0, expr->GetValue()); + auto value = expr->GetValue(); + output_column_vector->SetValue(0, value); output_column_vector->Finalize(1); } diff --git a/src/executor/fragment_builder.cpp b/src/executor/fragment_builder.cpp index cf69d1d76b..580e46c6df 100644 --- a/src/executor/fragment_builder.cpp +++ b/src/executor/fragment_builder.cpp @@ -38,7 +38,7 @@ UniquePtr FragmentBuilder::BuildFragment(PhysicalOperator *phys_op auto plan_fragment = MakeUnique(GetFragmentId()); plan_fragment->SetSinkNode(query_context_ptr_, SinkType::kResult, phys_op->GetOutputNames(), phys_op->GetOutputTypes()); BuildFragments(phys_op, plan_fragment.get()); - if (plan_fragment->GetSourceNode() != nullptr) { + if (plan_fragment->GetSourceNode() == nullptr) { plan_fragment->SetSourceNode(query_context_ptr_, SourceType::kEmpty, phys_op->GetOutputNames(), phys_op->GetOutputTypes()); } return plan_fragment; diff --git a/src/executor/operator/physical_aggregate.cpp b/src/executor/operator/physical_aggregate.cpp index 825f159c12..627542c6d7 100644 --- a/src/executor/operator/physical_aggregate.cpp +++ b/src/executor/operator/physical_aggregate.cpp @@ -15,6 +15,7 @@ module; #include +#include import stl; import txn; import query_context; @@ -49,14 +50,15 @@ bool PhysicalAggregate::Execute(QueryContext *query_context, OperatorState *oper Vector> groupby_columns; SizeT group_count = groups_.size(); - if(group_count == 0) { + if (group_count == 0) { // Aggregate without group by expression // e.g. SELECT count(a) FROM table; - if (SimpleAggregate(this->output_, prev_op_state, aggregate_operator_state)) { - return true; - } else { - return false; + auto result = SimpleAggregateExecute(prev_op_state->data_block_array_, aggregate_operator_state->data_block_array_); + prev_op_state->data_block_array_.clear(); + if (prev_op_state->Complete()) { + aggregate_operator_state->SetComplete(); } + return result; } #if 0 groupby_columns.reserve(group_count); @@ -404,17 +406,19 @@ void PhysicalAggregate::GenerateGroupByResult(const SharedPtr &input_ } case kVarchar: { Error("Varchar data shuffle isn't implemented."); -// VarcharT &dst_ref = ((VarcharT *)(output_datablock->column_vectors[column_id]->data()))[block_row_idx]; -// VarcharT &src_ref = ((VarcharT *)(input_datablocks[input_block_id]->column_vectors[column_id]->data()))[input_offset]; -// if (src_ref.IsInlined()) { -// Memcpy((char *)&dst_ref, (char *)&src_ref, sizeof(VarcharT)); -// } else { -// dst_ref.length = src_ref.length; -// Memcpy(dst_ref.prefix, src_ref.prefix, VarcharT::PREFIX_LENGTH); -// -// dst_ref.ptr = output_datablock->column_vectors[column_id]->buffer_->fix_heap_mgr_->Allocate(src_ref.length); -// Memcpy(dst_ref.ptr, src_ref.ptr, src_ref.length); -// } + // VarcharT &dst_ref = ((VarcharT *)(output_datablock->column_vectors[column_id]->data()))[block_row_idx]; + // VarcharT &src_ref = ((VarcharT + // *)(input_datablocks[input_block_id]->column_vectors[column_id]->data()))[input_offset]; if + // (src_ref.IsInlined()) { + // Memcpy((char *)&dst_ref, (char *)&src_ref, sizeof(VarcharT)); + // } else { + // dst_ref.length = src_ref.length; + // Memcpy(dst_ref.prefix, src_ref.prefix, VarcharT::PREFIX_LENGTH); + // + // dst_ref.ptr = + // output_datablock->column_vectors[column_id]->buffer_->fix_heap_mgr_->Allocate(src_ref.length); + // Memcpy(dst_ref.ptr, src_ref.ptr, src_ref.length); + // } break; } case kDate: { @@ -559,9 +563,7 @@ void PhysicalAggregate::GenerateGroupByResult(const SharedPtr &input_ #endif } -bool PhysicalAggregate::SimpleAggregate(SharedPtr &output_table, - OperatorState *pre_operator_state, - AggregateOperatorState *aggregate_operator_state) { +bool PhysicalAggregate::SimpleAggregateExecute(const Vector> &input_blocks, Vector> &output_blocks) { SizeT aggregates_count = aggregates_.size(); if (aggregates_count <= 0) { Error("Simple Aggregate without aggregate expression."); @@ -579,19 +581,16 @@ bool PhysicalAggregate::SimpleAggregate(SharedPtr &output_table, Vector> output_types; output_types.reserve(aggregates_count); - SizeT input_block_count = pre_operator_state->data_block_array_.size(); + SizeT input_block_count = input_blocks.size(); - for (i64 idx = 0; auto &expr: aggregates_) { + for (i64 idx = 0; auto &expr : aggregates_) { // expression state expr_states.emplace_back(ExpressionState::CreateState(expr)); SharedPtr output_type = MakeShared(expr->Type()); // column definition - SharedPtr col_def = MakeShared(idx, - output_type, - expr->Name(), - HashSet()); + SharedPtr col_def = MakeShared(idx, output_type, expr->Name(), HashSet()); aggregate_columns.emplace_back(col_def); // for output block @@ -605,51 +604,26 @@ bool PhysicalAggregate::SimpleAggregate(SharedPtr &output_table, LOG_TRACE("No input, no aggregate result"); return true; } - // Loop blocks - - // ExpressionEvaluator evaluator; - // //evaluator.Init(input_table_->data_blocks_); - // for (SizeT expr_idx = 0; expr_idx < aggregates_count; ++expr_idx) { - // - // ExpressionEvaluator evaluator; - // evaluator.Init(aggregates_[]) - // Vector> blocks_column; - // blocks_column.emplace_back(output_data_block->column_vectors[expr_idx]); - // evaluator.Execute(aggregates_[expr_idx], expr_states[expr_idx], blocks_column[expr_idx]); - // if(blocks_column[0].get() != output_data_block->column_vectors[expr_idx].get()) { - // // column vector in blocks column might be changed to the column vector from column reference. - // // This check and assignment is to make sure the right column vector are assign to output_data_block - // output_data_block->column_vectors[expr_idx] = blocks_column[0]; - // } - // } - // - // output_data_block->Finalize(); for (SizeT block_idx = 0; block_idx < input_block_count; ++block_idx) { - DataBlock *input_data_block = pre_operator_state->data_block_array_[block_idx].get(); + DataBlock *input_data_block = input_blocks[block_idx].get(); + + output_blocks.emplace_back(DataBlock::MakeUniquePtr()); - aggregate_operator_state->data_block_array_.emplace_back(DataBlock::MakeUniquePtr()); - DataBlock *output_data_block = aggregate_operator_state->data_block_array_.back().get(); + DataBlock *output_data_block = output_blocks.back().get(); output_data_block->Init(*GetOutputTypes()); ExpressionEvaluator evaluator; evaluator.Init(input_data_block); SizeT expression_count = aggregates_count; - // Prepare the expression states - + // calculate every columns value for (SizeT expr_idx = 0; expr_idx < expression_count; ++expr_idx) { - // Vector> blocks_column; - // blocks_column.emplace_back(output_data_block->column_vectors[expr_idx]); + LOG_TRACE("Physical aggregate Execute"); evaluator.Execute(aggregates_[expr_idx], expr_states[expr_idx], output_data_block->column_vectors[expr_idx]); } output_data_block->Finalize(); } - - pre_operator_state->data_block_array_.clear(); - if (pre_operator_state->Complete()) { - aggregate_operator_state->SetComplete(); - } return true; } diff --git a/src/executor/operator/physical_aggregate.cppm b/src/executor/operator/physical_aggregate.cppm index f9f14762f9..e25016ce0b 100644 --- a/src/executor/operator/physical_aggregate.cppm +++ b/src/executor/operator/physical_aggregate.cppm @@ -25,6 +25,7 @@ import hash_table; import base_expression; import load_meta; import infinity_exception; +import data_block; export module physical_aggregate; @@ -44,8 +45,8 @@ public: Vector> aggregates, u64 aggregate_index, SharedPtr> load_metas) - : PhysicalOperator(PhysicalOperatorType::kAggregate, Move(left), nullptr, id, load_metas), groups_(Move(groups)), aggregates_(Move(aggregates)), - groupby_index_(groupby_index), aggregate_index_(aggregate_index) {} + : PhysicalOperator(PhysicalOperatorType::kAggregate, Move(left), nullptr, id, load_metas), groups_(Move(groups)), + aggregates_(Move(aggregates)), groupby_index_(groupby_index), aggregate_index_(aggregate_index) {} ~PhysicalAggregate() override = default; @@ -66,9 +67,7 @@ public: Vector> aggregates_{}; HashTable hash_table_; - bool SimpleAggregate(SharedPtr &output_table, - OperatorState *pre_operator_state, - AggregateOperatorState *aggregate_operator_state); + bool SimpleAggregateExecute(const Vector> &input_blocks, Vector> &output_blocks); inline u64 GroupTableIndex() const { return groupby_index_; } diff --git a/src/executor/operator/physical_merge_aggregate.cpp b/src/executor/operator/physical_merge_aggregate.cpp index 7775588187..81fc1fe6ea 100644 --- a/src/executor/operator/physical_merge_aggregate.cpp +++ b/src/executor/operator/physical_merge_aggregate.cpp @@ -14,19 +14,164 @@ module; +#include +#include +import stl; import query_context; import operator_state; import infinity_exception; +import logger; +import value; +import data_block; +import parser; +import physical_aggregate; +import aggregate_expression; module physical_merge_aggregate; namespace infinity { +template +using MathOperation = StdFunction; + void PhysicalMergeAggregate::Init() {} bool PhysicalMergeAggregate::Execute(QueryContext *query_context, OperatorState *operator_state) { - Error("Not Implement"); + LOG_TRACE("PhysicalMergeAggregate::Execute:: mark"); + auto merge_aggregate_op_state = static_cast(operator_state); + + SimpleMergeAggregateExecute(merge_aggregate_op_state); + + if (merge_aggregate_op_state->input_complete_) { + + LOG_TRACE("PhysicalMergeAggregate::Input is complete"); + for (auto &output_block : merge_aggregate_op_state->data_block_array_) { + output_block->Finalize(); + } + + merge_aggregate_op_state->SetComplete(); + return true; + } + return false; } +void PhysicalMergeAggregate::SimpleMergeAggregateExecute(MergeAggregateOperatorState *op_state) { + if (op_state->data_block_array_.empty()) { + op_state->data_block_array_.emplace_back(Move(op_state->input_data_block_)); + } else { + auto agg_op = dynamic_cast(this->left()); + auto aggs_size = agg_op->aggregates_.size(); + for (SizeT col_idx = 0; col_idx < aggs_size; ++col_idx) { + auto agg_expression = static_cast(agg_op->aggregates_[col_idx].get()); + + auto function_name = agg_expression->aggregate_function_.GetFuncName(); + + auto func_return_type = agg_expression->aggregate_function_.return_type_; + + switch (func_return_type.type()) { + case kInteger: { + HandleAggregateFunction(function_name, op_state, col_idx); + break; + } + case kBigInt: { + HandleAggregateFunction(function_name, op_state, col_idx); + break; + } + case kFloat: { + HandleAggregateFunction(function_name, op_state, col_idx); + break; + } + case kDouble: { + HandleAggregateFunction(function_name, op_state, col_idx); + break; + } + default: + Error("input_value_type not Implement"); + } + } + } +} + +template +void PhysicalMergeAggregate::HandleAggregateFunction(const String &function_name, MergeAggregateOperatorState *op_state, SizeT col_idx) { + if (String(function_name) == String("COUNT")) { + LOG_TRACE("COUNT"); + HandleCount(op_state, col_idx); + } else if (String(function_name) == String("MIN")) { + LOG_TRACE("MIN"); + HandleMin(op_state, col_idx); + } else if (String(function_name) == String("MAX")) { + LOG_TRACE("MAX"); + HandleMax(op_state, col_idx); + } else if (String(function_name) == String("SUM")) { + LOG_TRACE("SUM"); + HandleSum(op_state, col_idx); + } +} + +template +void PhysicalMergeAggregate::HandleMin(MergeAggregateOperatorState *op_state, SizeT col_idx) { + MathOperation minOperation = [](T a, T b) -> T { return (a < b) ? a : b; }; + UpdateData(op_state, minOperation, col_idx); +} + +template +void PhysicalMergeAggregate::HandleMax(MergeAggregateOperatorState *op_state, SizeT col_idx) { + MathOperation maxOperation = [](T a, T b) -> T { return (a > b) ? a : b; }; + UpdateData(op_state, maxOperation, col_idx); +} + +template +void PhysicalMergeAggregate::HandleCount(MergeAggregateOperatorState *op_state, SizeT col_idx) { + MathOperation countOperation = [](T a, T b) -> T { return a + b; }; + UpdateData(op_state, countOperation, col_idx); +} + +template +void PhysicalMergeAggregate::HandleSum(MergeAggregateOperatorState *op_state, SizeT col_idx) { + MathOperation sumOperation = [](T a, T b) -> T { return a + b; }; + UpdateData(op_state, sumOperation, col_idx); +} + +template +T PhysicalMergeAggregate::GetInputData(MergeAggregateOperatorState *op_state, SizeT block_index, SizeT col_idx, SizeT row_idx) { + Value value = op_state->input_data_block_->GetValue(col_idx, row_idx); + return value.GetValue(); +} + +template +T PhysicalMergeAggregate::GetOutputData(MergeAggregateOperatorState *op_state, SizeT block_index, SizeT col_idx, SizeT row_idx) { + Value value = op_state->data_block_array_[block_index]->GetValue(col_idx, row_idx); + return value.GetValue(); +} + +template +void PhysicalMergeAggregate::WriteValueAtPosition(MergeAggregateOperatorState *op_state, SizeT block_index, SizeT col_idx, SizeT row_idx, T value) { + op_state->data_block_array_[block_index]->SetValue(col_idx, row_idx, CreateValue(value)); +} + +template +void PhysicalMergeAggregate::UpdateData(MergeAggregateOperatorState *op_state, MathOperation operation, SizeT col_idx) { + T input = GetInputData(op_state, 0, col_idx, 0); + T output = GetOutputData(op_state, 0, col_idx, 0); + T new_value = operation(input, output); + WriteValueAtPosition(op_state, 0, col_idx, 0, new_value); +} + +template +T PhysicalMergeAggregate::AddData(T a, T b) { + return a + b; +} + +template +T PhysicalMergeAggregate::MinValue(T a, T b) { + return (a < b) ? a : b; +} + +template +T PhysicalMergeAggregate::MaxValue(T a, T b) { + return (a > b) ? a : b; +} + } // namespace infinity diff --git a/src/executor/operator/physical_merge_aggregate.cppm b/src/executor/operator/physical_merge_aggregate.cppm index 8dda917458..aca0859317 100644 --- a/src/executor/operator/physical_merge_aggregate.cppm +++ b/src/executor/operator/physical_merge_aggregate.cppm @@ -23,6 +23,8 @@ import physical_operator_type; import load_meta; import base_table_ref; import infinity_exception; +import value; +import data_block; export module physical_merge_aggregate; @@ -53,6 +55,83 @@ public: Error("TaskletCount not Implement"); return 0; } + + template + T GetDataFromValueAtInputBlockPosition(MergeAggregateOperatorState *op_state, SizeT block_index, SizeT col_idx, SizeT row_idx); + + template + T GetDataFromValueAtOutputBlockPosition(MergeAggregateOperatorState *op_state, SizeT block_index, SizeT col_idx, SizeT row_idx); + + void WriteIntegerAtPosition(MergeAggregateOperatorState *op_state, SizeT block_index, SizeT col_idx, SizeT row_idx, IntegerT integer); + + template + T GetInputData(MergeAggregateOperatorState *op_state, SizeT block_index, SizeT col_idx, SizeT row_idx); + + template + T GetOutputData(MergeAggregateOperatorState *op_state, SizeT block_index, SizeT col_idx, SizeT row_idx); + + template + T MinValue(T a, T b); + + template + T MaxValue(T a, T b); + + template + T AddData(T a, T b); + + template + using MathOperation = StdFunction; + + void SimpleMergeAggregateExecute(MergeAggregateOperatorState *merge_aggregate_op_state); + + void UpdateBlockData(MergeAggregateOperatorState *merge_aggregate_op_state, SizeT col_idx); + + template + void UpdateData(MergeAggregateOperatorState *op_state, MathOperation operation, SizeT col_idx); + + template + void WriteValueAtPosition(MergeAggregateOperatorState *op_state, SizeT block_index, SizeT col_idx, SizeT row_idx, T value); + + template + void HandleSum(MergeAggregateOperatorState *op_state, SizeT col_idx); + + template + void HandleCount(MergeAggregateOperatorState *op_state, SizeT col_idx); + + template + void HandleMin(MergeAggregateOperatorState *op_state, SizeT col_idx); + + template + void HandleMax(MergeAggregateOperatorState *op_state, SizeT col_idx); + + template + void HandleAggregateFunction(const String &function_name, MergeAggregateOperatorState *op_state, SizeT col_idx); + + template + Value CreateValue(T value) { + Error("Unhandled type for makeValue"); + } + + template <> + Value CreateValue(IntegerT value) { + return Value::MakeInt(value); + } + + template <> + Value CreateValue(BigIntT value) { + return Value::MakeBigInt(value); + } + + template <> + Value CreateValue(DoubleT value) { + return Value::MakeDouble(value); + } + + template <> + Value CreateValue(FloatT value) { + return Value::MakeFloat(value); + } + private: SharedPtr> output_names_{}; SharedPtr>> output_types_{}; diff --git a/src/executor/operator_state.cpp b/src/executor/operator_state.cpp index 7f8967dabd..1d9ac95869 100644 --- a/src/executor/operator_state.cpp +++ b/src/executor/operator_state.cpp @@ -114,6 +114,14 @@ bool QueueSourceState::GetData() { } break; } + case PhysicalOperatorType::kMergeAggregate: { + auto *fragment_data = static_cast(fragment_data_base.get()); + MergeAggregateOperatorState *merge_aggregate_op_state = (MergeAggregateOperatorState *)next_op_state; + //merge_aggregate_op_state->input_data_blocks_.push_back(Move(fragment_data->data_block_)); + merge_aggregate_op_state->input_data_block_ = Move(fragment_data->data_block_); + merge_aggregate_op_state->input_complete_ = completed; + break; + } default: { Error("Not support operator type"); break; diff --git a/src/executor/operator_state.cppm b/src/executor/operator_state.cppm index a5b2515027..2d4e7635ff 100644 --- a/src/executor/operator_state.cppm +++ b/src/executor/operator_state.cppm @@ -60,6 +60,11 @@ export struct AggregateOperatorState : public OperatorState { // Merge Aggregate export struct MergeAggregateOperatorState : public OperatorState { inline explicit MergeAggregateOperatorState() : OperatorState(PhysicalOperatorType::kMergeAggregate) {} + + /// Since merge agg is the first op, no previous operator state. This ptr is to get input data. + //Vector> input_data_blocks_{nullptr}; + UniquePtr input_data_block_{nullptr}; + bool input_complete_{false}; }; // Merge Parallel Aggregate diff --git a/src/executor/physical_planner.cpp b/src/executor/physical_planner.cpp index eba2599e84..cdda936f1e 100644 --- a/src/executor/physical_planner.cpp +++ b/src/executor/physical_planner.cpp @@ -541,6 +541,8 @@ UniquePtr PhysicalPlanner::BuildAggregate(const SharedPtraggregate_index_, logical_operator->load_metas()); + + if (tasklet_count == 1) { return physical_agg_op; } else { diff --git a/src/function/aggregate/count.cpp b/src/function/aggregate/count.cpp index 08b461e273..6e93f86c13 100644 --- a/src/function/aggregate/count.cpp +++ b/src/function/aggregate/count.cpp @@ -159,15 +159,16 @@ void RegisterCountFunction(const UniquePtr &catalog_ptr) { function_set_ptr->AddFunction(count_function); } { -// AggregateFunction count_function = -// UnaryAggregate, PathT, BigIntT>(func_name, DataType(LogicalType::kPath), DataType(LogicalType::kBigInt)); -// function_set_ptr->AddFunction(count_function); + // AggregateFunction count_function = + // UnaryAggregate, PathT, BigIntT>(func_name, DataType(LogicalType::kPath), + // DataType(LogicalType::kBigInt)); + // function_set_ptr->AddFunction(count_function); } { -// AggregateFunction count_function = UnaryAggregate, PolygonT, BigIntT>(func_name, -// DataType(LogicalType::kPolygon), -// DataType(LogicalType::kBigInt)); -// function_set_ptr->AddFunction(count_function); + // AggregateFunction count_function = UnaryAggregate, PolygonT, BigIntT>(func_name, + // DataType(LogicalType::kPolygon), + // DataType(LogicalType::kBigInt)); + // function_set_ptr->AddFunction(count_function); } { AggregateFunction count_function = @@ -175,9 +176,10 @@ void RegisterCountFunction(const UniquePtr &catalog_ptr) { function_set_ptr->AddFunction(count_function); } { -// AggregateFunction count_function = -// UnaryAggregate, BitmapT, BigIntT>(func_name, DataType(LogicalType::kBitmap), DataType(LogicalType::kBigInt)); -// function_set_ptr->AddFunction(count_function); + // AggregateFunction count_function = + // UnaryAggregate, BitmapT, BigIntT>(func_name, DataType(LogicalType::kBitmap), + // DataType(LogicalType::kBigInt)); + // function_set_ptr->AddFunction(count_function); } { AggregateFunction count_function = @@ -185,9 +187,10 @@ void RegisterCountFunction(const UniquePtr &catalog_ptr) { function_set_ptr->AddFunction(count_function); } { -// AggregateFunction count_function = -// UnaryAggregate, BlobT, BigIntT>(func_name, DataType(LogicalType::kBlob), DataType(LogicalType::kBigInt)); -// function_set_ptr->AddFunction(count_function); + // AggregateFunction count_function = + // UnaryAggregate, BlobT, BigIntT>(func_name, DataType(LogicalType::kBlob), + // DataType(LogicalType::kBigInt)); + // function_set_ptr->AddFunction(count_function); } { AggregateFunction count_function = UnaryAggregate, EmbeddingT, BigIntT>(func_name, diff --git a/src/function/aggregate/count_star.cpp b/src/function/aggregate/count_star.cpp new file mode 100644 index 0000000000..11b983710d --- /dev/null +++ b/src/function/aggregate/count_star.cpp @@ -0,0 +1,59 @@ +// Copyright(C) 2023 InfiniFlow, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +module; + +import stl; +import new_catalog; + +import infinity_exception; +import aggregate_function; +import aggregate_function_set; +import parser; +import third_party; + +module count_star; + +namespace infinity { + +template +struct CountStarState { + i64 value_{}; + + void Initialize() { this->value_ = 0; } + + void Update(i64 *__restrict input, SizeT idx) { value_ = input[idx]; } + + inline void ConstantUpdate(i64 *__restrict input, SizeT idx, SizeT) { value_ = input[idx]; } + + inline ptr_t Finalize() { return (ptr_t)&value_; } + + inline static SizeT Size(const DataType &) { return sizeof(i64); } +}; + +void RegisterCountStarFunction(const UniquePtr &catalog_ptr) { + String func_name = "COUNT_STAR"; + + SharedPtr function_set_ptr = MakeShared(func_name); + + { + AggregateFunction count_function = + CountStarAggregate, BigIntT>(func_name, DataType(LogicalType::kBigInt), DataType(LogicalType::kBigInt)); + function_set_ptr->AddFunction(count_function); + } + + NewCatalog::AddFunctionSet(catalog_ptr.get(), function_set_ptr); +} + +} // namespace infinity diff --git a/src/function/aggregate/count_star.cppm b/src/function/aggregate/count_star.cppm new file mode 100644 index 0000000000..d281869a86 --- /dev/null +++ b/src/function/aggregate/count_star.cppm @@ -0,0 +1,27 @@ +// Copyright(C) 2023 InfiniFlow, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +module; + +import stl; + +export module count_star; + +namespace infinity { + +class NewCatalog; + +export void RegisterCountStarFunction(const UniquePtr &catalog_ptr); + +} \ No newline at end of file diff --git a/src/function/aggregate_function.cppm b/src/function/aggregate_function.cppm index 8be567b9fa..4c3c515a5c 100644 --- a/src/function/aggregate_function.cppm +++ b/src/function/aggregate_function.cppm @@ -95,6 +95,8 @@ public: [[nodiscard]] ptr_t GetState() const { return state_data_.get(); } + [[nodiscard]] String GetFuncName() const { return name_; } + public: AggregateInitializeFuncType init_func_; AggregateUpdateFuncType update_func_; @@ -117,4 +119,49 @@ inline AggregateFunction UnaryAggregate(const String &name, const DataType &inpu AggregateOperation::StateUpdate, AggregateOperation::StateFinalize); } + + +class CountStarAggregateOperation { +public: + template + static inline void StateInitialize(const ptr_t state) { + ((AggregateState *)state)->Initialize(); + } + + template + static inline void StateUpdate(const ptr_t state, const SharedPtr &input_column_vector) { + // Loop execute state update according to the input column vector + + switch (input_column_vector->vector_type()) { + case ColumnVectorType::kConstant: { + auto *input_ptr = (i64 *)(input_column_vector->data()); + ((AggregateState *)state)->Update(input_ptr, 0); + break; + } + default: { + Error("Other type of column vector isn't implemented"); + } + } + } + + template + static inline ptr_t StateFinalize(const ptr_t state) { + // Loop execute state update according to the input column vector + ptr_t result = ((AggregateState *)state)->Finalize(); + return result; + } +}; + + +export template +inline auto CountStarAggregate(String &name, const DataType &input_type, const DataType &return_type) -> AggregateFunction { + auto agg_function = AggregateFunction(name, + input_type, + return_type, + AggregateState::Size(input_type), + CountStarAggregateOperation::StateInitialize, + CountStarAggregateOperation::StateUpdate, + CountStarAggregateOperation::StateFinalize); + return agg_function; +} } // namespace infinity diff --git a/src/function/builtin_functions.cpp b/src/function/builtin_functions.cpp index 8117ed79ad..0afd36c1e6 100644 --- a/src/function/builtin_functions.cpp +++ b/src/function/builtin_functions.cpp @@ -22,6 +22,7 @@ import first; import max; import min; import sum; +import count_star; import add; import abs; @@ -52,6 +53,7 @@ import special_function; import parser; + module builtin_functions; namespace infinity { @@ -68,6 +70,7 @@ void BuiltinFunctions::Init() { void BuiltinFunctions::RegisterAggregateFunction() { RegisterAvgFunction(catalog_ptr_); RegisterCountFunction(catalog_ptr_); + RegisterCountStarFunction(catalog_ptr_); RegisterFirstFunction(catalog_ptr_); RegisterMaxFunction(catalog_ptr_); RegisterMinFunction(catalog_ptr_); diff --git a/src/function/scalar/divide.cpp b/src/function/scalar/divide.cpp index 0a6511cbd9..293bf348c8 100644 --- a/src/function/scalar/divide.cpp +++ b/src/function/scalar/divide.cpp @@ -38,14 +38,14 @@ struct DivFunction { if (left == std::numeric_limits::min() && right == -1) { return false; } - result = left / right; + result = DoubleT(left) / DoubleT(right); return true; } }; template <> inline bool DivFunction::Run(FloatT left, FloatT right, FloatT &result) { - result = left / right; + result = left / DoubleT(right); if (std::isnan(result) || std::isinf(result)) return false; return true; @@ -65,6 +65,12 @@ inline bool DivFunction::Run(HugeIntT, HugeIntT, HugeIntT &) { return false; } +template <> +inline bool DivFunction::Run(HugeIntT, HugeIntT, DoubleT &) { + Error("Not implement huge int divide operator."); + return false; +} + void RegisterDivFunction(const UniquePtr &catalog_ptr) { String func_name = "/"; @@ -72,38 +78,38 @@ void RegisterDivFunction(const UniquePtr &catalog_ptr) { ScalarFunction div_function_int8(func_name, {DataType(LogicalType::kTinyInt), DataType(LogicalType::kTinyInt)}, - {DataType(LogicalType::kTinyInt)}, - &ScalarFunction::BinaryFunctionWithFailure); + {DataType(LogicalType::kDouble)}, + &ScalarFunction::BinaryFunctionWithFailure); function_set_ptr->AddFunction(div_function_int8); ScalarFunction div_function_int16(func_name, {DataType(LogicalType::kSmallInt), DataType(LogicalType::kSmallInt)}, - {DataType(LogicalType::kSmallInt)}, - &ScalarFunction::BinaryFunctionWithFailure); + {DataType(LogicalType::kDouble)}, + &ScalarFunction::BinaryFunctionWithFailure); function_set_ptr->AddFunction(div_function_int16); ScalarFunction div_function_int32(func_name, {DataType(LogicalType::kInteger), DataType(LogicalType::kInteger)}, - {DataType(LogicalType::kInteger)}, - &ScalarFunction::BinaryFunctionWithFailure); + {DataType(LogicalType::kDouble)}, + &ScalarFunction::BinaryFunctionWithFailure); function_set_ptr->AddFunction(div_function_int32); ScalarFunction div_function_int64(func_name, {DataType(LogicalType::kBigInt), DataType(LogicalType::kBigInt)}, - {DataType(LogicalType::kBigInt)}, - &ScalarFunction::BinaryFunctionWithFailure); + {DataType(LogicalType::kDouble)}, + &ScalarFunction::BinaryFunctionWithFailure); function_set_ptr->AddFunction(div_function_int64); ScalarFunction div_function_int128(func_name, {DataType(LogicalType::kHugeInt), DataType(LogicalType::kHugeInt)}, - {DataType(LogicalType::kHugeInt)}, - &ScalarFunction::BinaryFunctionWithFailure); + {DataType(LogicalType::kDouble)}, + &ScalarFunction::BinaryFunctionWithFailure); function_set_ptr->AddFunction(div_function_int128); ScalarFunction div_function_float(func_name, {DataType(LogicalType::kFloat), DataType(LogicalType::kFloat)}, - {DataType(LogicalType::kFloat)}, - &ScalarFunction::BinaryFunctionWithFailure); + {DataType(LogicalType::kDouble)}, + &ScalarFunction::BinaryFunctionWithFailure); function_set_ptr->AddFunction(div_function_float); ScalarFunction div_function_double(func_name, diff --git a/src/planner/binder/project_binder.cpp b/src/planner/binder/project_binder.cpp index 7f2386fce5..c4512c44f1 100644 --- a/src/planner/binder/project_binder.cpp +++ b/src/planner/binder/project_binder.cpp @@ -27,11 +27,51 @@ import infinity_exception; module project_binder; +namespace { + +using namespace infinity; + +void ConvertAvgToSumDivideCount(FunctionExpr &func_expression, const Vector &column_names) { + func_expression.func_name_ = "/"; + func_expression.arguments_->clear(); + auto createFunctionWithColumnArg = [&column_names](const String &func_name) { + auto function_expression = MakeUnique(); + function_expression->func_name_ = func_name; + function_expression->arguments_ = new Vector(); + auto column_expr = MakeUnique(); + column_expr->names_.push_back(column_names[0]); + function_expression->arguments_->push_back(column_expr.release()); + return function_expression.release(); + }; + func_expression.arguments_->push_back(createFunctionWithColumnArg("sum")); + func_expression.arguments_->push_back(createFunctionWithColumnArg("count")); +} + +} // namespace + namespace infinity { SharedPtr ProjectBinder::BuildExpression(const ParsedExpr &expr, BindContext *bind_context_ptr, i64 depth, bool root) { String expr_name = expr.GetName(); + // Covert avg function expr to (sum / count) function expr + if (expr.type_ == ParsedExprType::kFunction) { + auto &function_expression = (FunctionExpr &)expr; + auto special_function = TryBuildSpecialFuncExpr(function_expression, bind_context_ptr, depth); + if (special_function.has_value()) { + return ExpressionBinder::BuildExpression(expr, bind_context_ptr, depth, root); + } + auto function_set_ptr = FunctionSet::GetFunctionSet(query_context_->storage()->catalog(), function_expression); + + if (IsEqual(function_set_ptr->name(), String("AVG")) && function_expression.arguments_->size() == 1 && + (*function_expression.arguments_)[0]->type_ == ParsedExprType::kColumn) { + auto column_expr = (ColumnExpr *)(*function_expression.arguments_)[0]; + Vector column_names(Move(column_expr->names_)); + delete column_expr; + ConvertAvgToSumDivideCount(function_expression, column_names); + return ExpressionBinder::BuildExpression(expr, bind_context_ptr, depth, root); + } + } // If the expr isn't from aggregate function and coming from group by lists. if (!this->binding_agg_func_ && bind_context_ptr->group_index_by_name_.contains(expr_name)) { i64 groupby_index = bind_context_ptr->group_index_by_name_[expr_name]; @@ -94,7 +134,7 @@ SharedPtr ProjectBinder::BuildFuncExpr(const FunctionExpr &expr, SharedPtr function_set_ptr = FunctionSet::GetFunctionSet(query_context_->storage()->catalog(), expr); if (function_set_ptr->type_ == FunctionType::kAggregate) { if (this->binding_agg_func_) { - Error("Aggregate function is called in another aggregate function."); + Error(Format("Aggregate function {} is called in another aggregate function.", function_set_ptr->name())); } else { this->binding_agg_func_ = true; } diff --git a/src/planner/expression_binder.cpp b/src/planner/expression_binder.cpp index 6aacad501a..8fe576b33a 100644 --- a/src/planner/expression_binder.cpp +++ b/src/planner/expression_binder.cpp @@ -64,6 +64,7 @@ namespace infinity { SharedPtr ExpressionBinder::Bind(const ParsedExpr &expr, BindContext *bind_context_ptr, i64 depth, bool root) { // Call implemented BuildExpression + SharedPtr result = BuildExpression(expr, bind_context_ptr, depth, root); if (result.get() == nullptr) { if (result.get() == nullptr) { @@ -274,7 +275,15 @@ SharedPtr ExpressionBinder::BuildFuncExpr(const FunctionExpr &ex ColumnExpr *col_expr = (ColumnExpr *)(*expr.arguments_)[0]; if (col_expr->star_) { delete (*expr.arguments_)[0]; - (*expr.arguments_)[0] = new ConstantExpr(LiteralType::kBoolean); + auto constant_exp = new ConstantExpr(LiteralType::kInteger); + // catulate row count + String &table_name = bind_context_ptr->table_names_[0]; + TableCollectionEntry *table_entry = bind_context_ptr->binding_by_name_[table_name]->table_collection_entry_ptr_; + constant_exp->integer_value_ = table_entry->row_count_; + (*expr.arguments_)[0] = constant_exp; + auto &expr_rewrite = (FunctionExpr &)expr; + expr_rewrite.func_name_ = "COUNT_STAR"; + return ExpressionBinder::BuildFuncExpr(expr_rewrite, bind_context_ptr, depth, true); } } } diff --git a/src/scheduler/fragment_context.cpp b/src/scheduler/fragment_context.cpp index f0c890723c..bf5fd969d3 100644 --- a/src/scheduler/fragment_context.cpp +++ b/src/scheduler/fragment_context.cpp @@ -674,7 +674,22 @@ void FragmentContext::MakeSinkState(i64 parallel_count) { case PhysicalOperatorType::kInvalid: { Error("Unexpected operator type"); } - case PhysicalOperatorType::kAggregate: + case PhysicalOperatorType::kAggregate: { + if (fragment_type_ != FragmentType::kParallelStream) { + Error(Format("{} should in parallel stream fragment", PhysicalOperatorToString(last_operator->operator_type()))); + } + + if ((i64)tasks_.size() != parallel_count) { + Error(Format("{} task count isn't correct.", PhysicalOperatorToString(last_operator->operator_type()))); + } + + for (u64 task_id = 0; (i64)task_id < parallel_count; ++task_id) { + auto sink_state = MakeUnique(fragment_ptr_->FragmentID(), task_id); + + tasks_[task_id]->sink_state_ = Move(sink_state); + } + break; + } case PhysicalOperatorType::kParallelAggregate: case PhysicalOperatorType::kHash: case PhysicalOperatorType::kTop: { @@ -712,6 +727,7 @@ void FragmentContext::MakeSinkState(i64 parallel_count) { break; } case PhysicalOperatorType::kMergeParallelAggregate: + case PhysicalOperatorType::kMergeAggregate: case PhysicalOperatorType::kMergeHash: case PhysicalOperatorType::kMergeLimit: case PhysicalOperatorType::kMergeTop: diff --git a/src/storage/meta/entry/table_collection_entry.cpp b/src/storage/meta/entry/table_collection_entry.cpp index 5561aeb884..2c99ef63c7 100644 --- a/src/storage/meta/entry/table_collection_entry.cpp +++ b/src/storage/meta/entry/table_collection_entry.cpp @@ -322,7 +322,7 @@ void TableCollectionEntry::CommitDelete(TableCollectionEntry *table_entry, Txn * SegmentEntry::CommitDelete(segment, txn_ptr, block_row_hashmap); row_count += block_row_hashmap.size(); } - table_entry->row_count_ += row_count; + table_entry->row_count_ -= row_count; } UniquePtr TableCollectionEntry::RollbackDelete(TableCollectionEntry *, Txn *, DeleteState &, BufferManager *) { @@ -348,7 +348,7 @@ UniquePtr TableCollectionEntry::ImportSegment(TableCollectionEntry *tabl } UniqueLock rw_locker(table_entry->rw_locker_); - table_entry->row_count_ += row_count; + table_entry->row_count_ = row_count; table_entry->segment_map_.emplace(segment->segment_id_, Move(segment)); return nullptr; } diff --git a/src/unit_test/function/scalar/div_functions.cpp b/src/unit_test/function/scalar/div_functions.cpp index 4017d797ed..55cf58e41a 100644 --- a/src/unit_test/function/scalar/div_functions.cpp +++ b/src/unit_test/function/scalar/div_functions.cpp @@ -53,7 +53,7 @@ TEST_F(DivFunctionsTest, div_func) { Vector> inputs; SharedPtr data_type = MakeShared(LogicalType::kTinyInt); - SharedPtr result_type = MakeShared(LogicalType::kTinyInt); + SharedPtr result_type = MakeShared(LogicalType::kDouble); SharedPtr col1_expr_ptr = MakeShared(*data_type, "t1", 1, "c1", 0, 0); SharedPtr col2_expr_ptr = MakeShared(*data_type, "t1", 1, "c2", 1, 0); @@ -61,7 +61,7 @@ TEST_F(DivFunctionsTest, div_func) { inputs.emplace_back(col2_expr_ptr); ScalarFunction func = scalar_function_set->GetMostMatchFunction(inputs); - EXPECT_STREQ("/(TinyInt, TinyInt)->TinyInt", func.ToString().c_str()); + EXPECT_STREQ("/(TinyInt, TinyInt)->Double", func.ToString().c_str()); Vector> column_types; column_types.emplace_back(data_type); @@ -98,8 +98,8 @@ TEST_F(DivFunctionsTest, div_func) { } else { Value v = result->GetValue(i); EXPECT_TRUE(result->nulls_ptr_->IsTrue(i)); - EXPECT_EQ(v.type_.type(), LogicalType::kTinyInt); - EXPECT_EQ(v.value_.tiny_int, 1); + EXPECT_EQ(v.type_.type(), LogicalType::kDouble); + EXPECT_EQ(v.value_.float64, 1); } } } @@ -108,7 +108,7 @@ TEST_F(DivFunctionsTest, div_func) { Vector> inputs; SharedPtr data_type = MakeShared(LogicalType::kSmallInt); - SharedPtr result_type = MakeShared(LogicalType::kSmallInt); + SharedPtr result_type = MakeShared(LogicalType::kDouble); SharedPtr col1_expr_ptr = MakeShared(*data_type, "t1", 1, "c1", 0, 0); SharedPtr col2_expr_ptr = MakeShared(*data_type, "t1", 1, "c2", 1, 0); @@ -116,7 +116,7 @@ TEST_F(DivFunctionsTest, div_func) { inputs.emplace_back(col2_expr_ptr); ScalarFunction func = scalar_function_set->GetMostMatchFunction(inputs); - EXPECT_STREQ("/(SmallInt, SmallInt)->SmallInt", func.ToString().c_str()); + EXPECT_STREQ("/(SmallInt, SmallInt)->Double", func.ToString().c_str()); Vector> column_types; column_types.emplace_back(data_type); @@ -152,8 +152,8 @@ TEST_F(DivFunctionsTest, div_func) { } else { Value v = result->GetValue(i); EXPECT_TRUE(result->nulls_ptr_->IsTrue(i)); - EXPECT_EQ(v.type_.type(), LogicalType::kSmallInt); - EXPECT_EQ(v.value_.small_int, 2); + EXPECT_EQ(v.type_.type(), LogicalType::kDouble); + EXPECT_EQ(v.value_.float64, 2); } } } @@ -162,7 +162,7 @@ TEST_F(DivFunctionsTest, div_func) { Vector> inputs; SharedPtr data_type = MakeShared(LogicalType::kInteger); - SharedPtr result_type = MakeShared(LogicalType::kInteger); + SharedPtr result_type = MakeShared(LogicalType::kDouble); SharedPtr col1_expr_ptr = MakeShared(*data_type, "t1", 1, "c1", 0, 0); SharedPtr col2_expr_ptr = MakeShared(*data_type, "t1", 1, "c2", 1, 0); @@ -170,7 +170,7 @@ TEST_F(DivFunctionsTest, div_func) { inputs.emplace_back(col2_expr_ptr); ScalarFunction func = scalar_function_set->GetMostMatchFunction(inputs); - EXPECT_STREQ("/(Integer, Integer)->Integer", func.ToString().c_str()); + EXPECT_STREQ("/(Integer, Integer)->Double", func.ToString().c_str()); Vector> column_types; column_types.emplace_back(data_type); @@ -206,8 +206,8 @@ TEST_F(DivFunctionsTest, div_func) { } else { Value v = result->GetValue(i); EXPECT_TRUE(result->nulls_ptr_->IsTrue(i)); - EXPECT_EQ(v.type_.type(), LogicalType::kInteger); - EXPECT_EQ(v.value_.integer, 3); + EXPECT_EQ(v.type_.type(), LogicalType::kDouble); + EXPECT_EQ(v.value_.float64, 3); } } } @@ -216,7 +216,7 @@ TEST_F(DivFunctionsTest, div_func) { Vector> inputs; SharedPtr data_type = MakeShared(LogicalType::kBigInt); - SharedPtr result_type = MakeShared(LogicalType::kBigInt); + SharedPtr result_type = MakeShared(LogicalType::kDouble); SharedPtr col1_expr_ptr = MakeShared(*data_type, "t1", 1, "c1", 0, 0); SharedPtr col2_expr_ptr = MakeShared(*data_type, "t1", 1, "c2", 1, 0); @@ -224,7 +224,7 @@ TEST_F(DivFunctionsTest, div_func) { inputs.emplace_back(col2_expr_ptr); ScalarFunction func = scalar_function_set->GetMostMatchFunction(inputs); - EXPECT_STREQ("/(BigInt, BigInt)->BigInt", func.ToString().c_str()); + EXPECT_STREQ("/(BigInt, BigInt)->Double", func.ToString().c_str()); Vector> column_types; column_types.emplace_back(data_type); @@ -260,8 +260,8 @@ TEST_F(DivFunctionsTest, div_func) { } else { Value v = result->GetValue(i); EXPECT_TRUE(result->nulls_ptr_->IsTrue(i)); - EXPECT_EQ(v.type_.type(), LogicalType::kBigInt); - EXPECT_EQ(v.value_.big_int, 4); + EXPECT_EQ(v.type_.type(), LogicalType::kDouble); + EXPECT_EQ(v.value_.float64, 4); } } } @@ -270,7 +270,7 @@ TEST_F(DivFunctionsTest, div_func) { Vector> inputs; DataType data_type(LogicalType::kHugeInt); - DataType result_type(LogicalType::kHugeInt); + DataType result_type(LogicalType::kDouble); SharedPtr col1_expr_ptr = MakeShared(data_type, "t1", 1, "c1", 0, 0); SharedPtr col2_expr_ptr = MakeShared(data_type, "t1", 1, "c2", 1, 0); @@ -278,7 +278,7 @@ TEST_F(DivFunctionsTest, div_func) { inputs.emplace_back(col2_expr_ptr); ScalarFunction func = scalar_function_set->GetMostMatchFunction(inputs); - EXPECT_STREQ("/(HugeInt, HugeInt)->HugeInt", func.ToString().c_str()); + EXPECT_STREQ("/(HugeInt, HugeInt)->Double", func.ToString().c_str()); // TODO: need to complete it. } @@ -287,7 +287,7 @@ TEST_F(DivFunctionsTest, div_func) { Vector> inputs; SharedPtr data_type = MakeShared(LogicalType::kFloat); - SharedPtr result_type = MakeShared(LogicalType::kFloat); + SharedPtr result_type = MakeShared(LogicalType::kDouble); SharedPtr col1_expr_ptr = MakeShared(*data_type, "t1", 1, "c1", 0, 0); SharedPtr col2_expr_ptr = MakeShared(*data_type, "t1", 1, "c2", 1, 0); @@ -295,7 +295,7 @@ TEST_F(DivFunctionsTest, div_func) { inputs.emplace_back(col2_expr_ptr); ScalarFunction func = scalar_function_set->GetMostMatchFunction(inputs); - EXPECT_STREQ("/(Float, Float)->Float", func.ToString().c_str()); + EXPECT_STREQ("/(Float, Float)->Double", func.ToString().c_str()); Vector> column_types; column_types.emplace_back(data_type); @@ -331,8 +331,8 @@ TEST_F(DivFunctionsTest, div_func) { } else { Value v = result->GetValue(i); EXPECT_TRUE(result->nulls_ptr_->IsTrue(i)); - EXPECT_EQ(v.type_.type(), LogicalType::kFloat); - EXPECT_FLOAT_EQ(v.value_.float32, 5); + EXPECT_EQ(v.type_.type(), LogicalType::kDouble); + EXPECT_FLOAT_EQ(v.value_.float64, 5); } } } diff --git a/test/sql/basic.slt b/test/sql/basic.slt index 9cbe8f2ba2..63b82e97cd 100644 --- a/test/sql/basic.slt +++ b/test/sql/basic.slt @@ -2,6 +2,9 @@ # description: Test basic sql statement for sample # group: [basic] +statement ok +DROP TABLE IF EXISTS NATION; + # Expecting IDENTIFIER or PRIMARY or UNIQUE statement error CREATE TABLE NATION ( diff --git a/test/sql/dml/delete.slt b/test/sql/dml/delete.slt index b1e1099bef..f0119bf31c 100644 --- a/test/sql/dml/delete.slt +++ b/test/sql/dml/delete.slt @@ -23,6 +23,11 @@ SELECT * FROM products; 5 6 7 8 +query II +SELECT count(*) FROM products; +---- +4 + statement ok DELETE FROM products WHERE product_no = 3; @@ -33,6 +38,11 @@ SELECT * FROM products; 5 6 7 8 +query II +SELECT count(*) FROM products; +---- +3 + statement ok DELETE FROM products; diff --git a/test/sql/dml/import/test_embedding.slt b/test/sql/dml/import/test_embedding.slt index 70aebb35ab..245cfc4072 100644 --- a/test/sql/dml/import/test_embedding.slt +++ b/test/sql/dml/import/test_embedding.slt @@ -22,6 +22,11 @@ SELECT c1, c2 FROM test_embedding_type; 5 6,7,8 9 10,11,12 +query II +SELECT count(*) FROM test_embedding_type; +---- +3 + # Clean up statement ok diff --git a/test/sql/dml/import/test_jsonl.slt b/test/sql/dml/import/test_jsonl.slt index 69fb95cb8d..fc09489987 100644 --- a/test/sql/dml/import/test_jsonl.slt +++ b/test/sql/dml/import/test_jsonl.slt @@ -25,3 +25,11 @@ Ben 33 1,2,3,4,5 William 28 1,2,3,4,5 Chuck 29 1,2,3,4,5 Viola 35 1,2,3,4,5 + +query III +SELECT count(*) FROM test_jsonl; +---- +14 + +statement ok +DROP TABLE IF EXISTS test_jsonl; \ No newline at end of file diff --git a/test/sql/dml/import/test_varchar.slt b/test/sql/dml/import/test_varchar.slt index 83c772b813..de521843e9 100644 --- a/test/sql/dml/import/test_varchar.slt +++ b/test/sql/dml/import/test_varchar.slt @@ -23,6 +23,12 @@ SELECT c1, c2 FROM test_varchar_type; 3 hello world 4 hello hello hello hello hello + +query III +SELECT count(*) FROM test_varchar_type; +---- +4 + # Clean up statement ok DROP TABLE test_varchar_type; diff --git a/test/sql/dml/insert.slt b/test/sql/dml/insert.slt index 632a82a328..91c7b42438 100644 --- a/test/sql/dml/insert.slt +++ b/test/sql/dml/insert.slt @@ -19,6 +19,11 @@ SELECT * FROM products; ---- 1 2 a +query II +SELECT count(*) FROM products; +---- +1 + query I INSERT INTO products VALUES (3, 4, 'abcdef'), (5, 6, 'abcdefghijklmnopqrstuvwxyz'); ---- @@ -30,6 +35,11 @@ SELECT * FROM products; 3 4 abcdef 5 6 abcdefghijklmnopqrstuvwxyz +query II +SELECT count(*) FROM products; +---- +3 + # Clean up statement ok DROP TABLE products; diff --git a/test/sql/dml/update.slt b/test/sql/dml/update.slt index dfffbd1f96..26d41ae1b5 100644 --- a/test/sql/dml/update.slt +++ b/test/sql/dml/update.slt @@ -23,6 +23,11 @@ SELECT * FROM products; 5 6 7 8 +query II +SELECT count(*) FROM products; +---- +4 + statement ok UPDATE products SET price=100 WHERE product_no = 3; @@ -34,6 +39,12 @@ SELECT * FROM products; 5 6 7 8 +query II +SELECT count(*) FROM products; +---- +4 + + statement ok UPDATE products SET price=price+3 WHERE product_no = 1 OR product_no = 5; diff --git a/test/sql/dql/aggregate/test_simple_agg.slt b/test/sql/dql/aggregate/test_simple_agg.slt index 91d6c79730..e3f5a7af2c 100644 --- a/test/sql/dql/aggregate/test_simple_agg.slt +++ b/test/sql/dql/aggregate/test_simple_agg.slt @@ -54,6 +54,56 @@ SELECT COUNT(c1) FROM simple_agg ---- 3 +query III +SELECT SUM(c1)+SUM(c1) FROM simple_agg; +---- +12 + +query III +SELECT MAX(c1)+SUM(c1) FROM simple_agg; +---- +9 + +query III +SELECT MAX(c1)+SUM(c2) FROM simple_agg; +---- +9.000000 + +query III +SELECT MAX(c1)*SUM(c2) FROM simple_agg; +---- +18.000000 + +query III +SELECT MAX(c1)*SUM(c2) FROM simple_agg; +---- +18.000000 + +query III +SELECT MAX(c1)-SUM(c2) FROM simple_agg; +---- +-3.000000 + +query III +SELECT MAX(c1)/SUM(c2) FROM simple_agg; +---- +0.500000 + +query III +SELECT MAX(c1)/AVG(c2) FROM simple_agg; +---- +1.500000 + +query III +SELECT MAX(c1)*AVG(c2) FROM simple_agg; +---- +6.000000 + +query IIII +SELECT COUNT(*) FROM simple_agg; +---- +3 + statement ok DROP TABLE simple_agg; diff --git a/test/sql/dql/rbo_rule/column_pruner.slt b/test/sql/dql/rbo_rule/column_pruner.slt index f6d516a46b..583854f0e5 100644 --- a/test/sql/dql/rbo_rule/column_pruner.slt +++ b/test/sql/dql/rbo_rule/column_pruner.slt @@ -112,24 +112,25 @@ EXPLAIN LOGICAL SELECT MIN(c1 + 1), AVG(c2) FROM t1; ---- PROJECT (4) - table index: #4 - - expressions: [min((c1 + 1)) (#0), avg(c2) (#1)] + - expressions: [min((c1 + 1)) (#0), sum(c2) (#1) / count(c2) (#2)] -> AGGREGATE (3) - aggregate table index: #3 - - aggregate: [MIN(CAST(c1 (#0) AS BigInt) + 1), AVG(c2 (#1))] + - aggregate: [MIN(CAST(c1 (#0) AS BigInt) + 1), SUM(c2 (#1)), COUNT(c2 (#1))] -> TABLE SCAN (2) - table name: t1(default.t1) - table index: #1 - output columns: [c1, c2, __rowid] + query I EXPLAIN LOGICAL SELECT Min(c1 + 1), AVG(c2) FROM t1 GROUP BY c1; ---- PROJECT (4) - table index: #4 - - expressions: [min((c1 + 1)) (#1), avg(c2) (#2)] + - expressions: [min((c1 + 1)) (#1), sum(c2) (#2) / count(c2) (#3)] -> AGGREGATE (3) - aggregate table index: #3 - - aggregate: [MIN(CAST(c1 (#0) AS BigInt) + 1), AVG(c2 (#1))] + - aggregate: [MIN(CAST(c1 (#0) AS BigInt) + 1), SUM(c2 (#1)), COUNT(c2 (#1))] - group by table index: #2 - group by: [c1 (#0)] -> TABLE SCAN (2) @@ -137,6 +138,7 @@ EXPLAIN LOGICAL SELECT Min(c1 + 1), AVG(c2) FROM t1 GROUP BY c1; - table index: #1 - output columns: [c1, c2, __rowid] + query I EXPLAIN LOGICAL DELETE FROM t1 WHERE c1=1; ---- diff --git a/test/sql/dql/select.slt b/test/sql/dql/select.slt index 8d0444bb9e..2395c20dc0 100644 --- a/test/sql/dql/select.slt +++ b/test/sql/dql/select.slt @@ -1,6 +1,12 @@ statement ok DROP TABLE IF EXISTS select1; +statement ok +DROP TABLE IF EXISTS select2; + +statement ok +DROP TABLE IF EXISTS select3; + statement ok CREATE TABLE select1 (id INTEGER PRIMARY KEY, name VARCHAR, age INTEGER); diff --git a/tools/generate_aggregate.py b/tools/generate_aggregate.py new file mode 100644 index 0000000000..5ed6e97ef9 --- /dev/null +++ b/tools/generate_aggregate.py @@ -0,0 +1,144 @@ +import numpy as np +import random +import os +import argparse + + +def generate(generate_if_exists: bool, copy_dir: str): + row_n = 9000 + sort_dir = "./test/data/csv" + slt_dir = "./test/sql/dql/aggregate" + + table_name = "test_simple_agg_big_cpp" + agg_path = sort_dir + "/test_simple_agg_big.csv" + slt_path = slt_dir + "/test_simple_agg_big.slt" + copy_path = copy_dir + "/test_simple_agg_big.csv" + + os.makedirs(sort_dir, exist_ok=True) + os.makedirs(slt_dir, exist_ok=True) + if os.path.exists(agg_path) and os.path.exists(slt_path) and generate_if_exists: + print( + "File {} and {} already existed exists. Skip Generating.".format( + slt_path, agg_path + ) + ) + return + with open(agg_path, "w") as agg_file, open(slt_path, "w") as slt_file: + slt_file.write("statement ok\n") + slt_file.write("DROP TABLE IF EXISTS {};\n".format(table_name)) + slt_file.write("\n") + slt_file.write("statement ok\n") + slt_file.write( + "CREATE TABLE {} (c1 int, c2 float);\n".format(table_name) + ) + + # select count(*) from test_simple_agg_big + slt_file.write("\n") + slt_file.write("query I\n") + slt_file.write("SELECT count(*) FROM {};\n".format(table_name)) + slt_file.write("----\n") + slt_file.write(str(0)) + slt_file.write("\n") + + + slt_file.write("\n") + slt_file.write("query I\n") + slt_file.write( + "COPY {} FROM '{}' WITH ( DELIMITER ',' );\n".format( + table_name, copy_path + ) + ) + + sequence = np.arange(1, row_n+1) + + for i in sequence: + agg_file.write(str(i) + "," + str(i)) + agg_file.write("\n") + + + slt_file.write("\n") + slt_file.write("statement ok\n") + slt_file.write("SELECT c1 FROM {};\n".format(table_name)) + slt_file.write("\n") + + + # select max(c1) from test_simple_agg_big + + slt_file.write("\n") + slt_file.write("query I\n") + slt_file.write("SELECT max(c1) FROM {};\n".format(table_name)) + slt_file.write("----\n") + slt_file.write(str(row_n)) + slt_file.write("\n") + + # select min(c2) from test_simple_agg_big + + slt_file.write("\n") + slt_file.write("query I\n") + slt_file.write("SELECT min(c1) FROM {};\n".format(table_name)) + slt_file.write("----\n") + slt_file.write(str(1)) + slt_file.write("\n") + + + # select sum(c1) from test_simple_agg_big + + slt_file.write("\n") + slt_file.write("query I\n") + slt_file.write("SELECT sum(c1) FROM {};\n".format(table_name)) + slt_file.write("----\n") + slt_file.write(str(np.sum(sequence))) + slt_file.write("\n") + + + # select avg(c1) from test_simple_agg_big + slt_file.write("\n") + slt_file.write("query I\n") + slt_file.write("SELECT AVG(c1) FROM {};\n".format(table_name)) + slt_file.write("----\n") + slt_file.write(str(np.mean(sequence))+"00000") + slt_file.write("\n") + + # select count(c1) from test_simple_agg_big + slt_file.write("\n") + slt_file.write("query I\n") + slt_file.write("SELECT count(c1) FROM {};\n".format(table_name)) + slt_file.write("----\n") + slt_file.write(str(row_n)) + slt_file.write("\n") + + + # select count(*) from test_simple_agg_big + slt_file.write("\n") + slt_file.write("query I\n") + slt_file.write("SELECT count(*) FROM {};\n".format(table_name)) + slt_file.write("----\n") + slt_file.write(str(row_n)) + slt_file.write("\n") + + + slt_file.write("\n") + slt_file.write("statement ok\n") + slt_file.write("DROP TABLE {};\n".format(table_name)) + random.random() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Generate limit data for test") + + parser.add_argument( + "-g", + "--generate", + type=bool, + default=False, + dest="generate_if_exists", + ) + parser.add_argument( + "-c", + "--copy", + type=str, + default="/tmp/infinity/test_data", + dest="copy_dir", + ) + args = parser.parse_args() + generate(args.generate_if_exists, args.copy_dir) diff --git a/tools/sqllogictest.py b/tools/sqllogictest.py index 5c49dccfb1..0c4f99cb9b 100644 --- a/tools/sqllogictest.py +++ b/tools/sqllogictest.py @@ -6,6 +6,7 @@ from generate_fvecs import generate as generate2 from generate_sort import generate as generate3 from generate_limit import generate as generate4 +from generate_aggregate import generate as generate5 def python_skd_test(python_test_dir: str): @@ -108,6 +109,7 @@ def copy_all(data_dir, copy_dir): generate2(args.generate_if_exists, args.copy) generate3(args.generate_if_exists, args.copy) generate4(args.generate_if_exists, args.copy) + generate5(args.generate_if_exists, args.copy) print("Generate file finshed.") python_skd_test(python_test_dir) test_process(args.path, args.test, args.data, args.copy)