diff --git a/src/executor/expression/expression_evaluator.cpp b/src/executor/expression/expression_evaluator.cpp index ac65293201..8bb7f029eb 100644 --- a/src/executor/expression/expression_evaluator.cpp +++ b/src/executor/expression/expression_evaluator.cpp @@ -79,10 +79,10 @@ void ExpressionEvaluator::Execute(const SharedPtr &expr, // Create output chunk. // TODO: Now output chunk is pre-allocate memory in expression state // TODO: In the future, it can be implemented as on-demand allocation. - SharedPtr &child_output = child_state->OutputColumnVector(); - Execute(child_expr, child_state, child_output); + SharedPtr &child_output_col = child_state->OutputColumnVector(); + this->Execute(child_expr, child_state, child_output_col); - if (expr->aggregate_function_.argument_type_ != *child_output->data_type()) { + if (expr->aggregate_function_.argument_type_ != *child_output_col->data_type()) { Error("Argument type isn't matched with the child expression output"); } if (expr->aggregate_function_.return_type_ != *output_column_vector->data_type()) { @@ -93,7 +93,7 @@ void ExpressionEvaluator::Execute(const SharedPtr &expr, expr->aggregate_function_.init_func_(expr->aggregate_function_.GetState()); // 2. Loop to fill the aggregate state - expr->aggregate_function_.update_func_(expr->aggregate_function_.GetState(), child_output); + expr->aggregate_function_.update_func_(expr->aggregate_function_.GetState(), child_output_col); // 3. Get the aggregate result and append to output column vector. @@ -157,7 +157,8 @@ void ExpressionEvaluator::Execute(const SharedPtr &expr, SharedPtr &, SharedPtr &output_column_vector) { // memory copy here. 
- output_column_vector->SetValue(0, expr->GetValue()); + auto value = expr->GetValue(); + output_column_vector->SetValue(0, value); output_column_vector->Finalize(1); } diff --git a/src/executor/fragment_builder.cpp b/src/executor/fragment_builder.cpp index cf69d1d76b..580e46c6df 100644 --- a/src/executor/fragment_builder.cpp +++ b/src/executor/fragment_builder.cpp @@ -38,7 +38,7 @@ UniquePtr FragmentBuilder::BuildFragment(PhysicalOperator *phys_op auto plan_fragment = MakeUnique(GetFragmentId()); plan_fragment->SetSinkNode(query_context_ptr_, SinkType::kResult, phys_op->GetOutputNames(), phys_op->GetOutputTypes()); BuildFragments(phys_op, plan_fragment.get()); - if (plan_fragment->GetSourceNode() != nullptr) { + if (plan_fragment->GetSourceNode() == nullptr) { plan_fragment->SetSourceNode(query_context_ptr_, SourceType::kEmpty, phys_op->GetOutputNames(), phys_op->GetOutputTypes()); } return plan_fragment; diff --git a/src/executor/operator/physical_aggregate.cpp b/src/executor/operator/physical_aggregate.cpp index 825f159c12..627542c6d7 100644 --- a/src/executor/operator/physical_aggregate.cpp +++ b/src/executor/operator/physical_aggregate.cpp @@ -15,6 +15,7 @@ module; #include +#include import stl; import txn; import query_context; @@ -49,14 +50,15 @@ bool PhysicalAggregate::Execute(QueryContext *query_context, OperatorState *oper Vector> groupby_columns; SizeT group_count = groups_.size(); - if(group_count == 0) { + if (group_count == 0) { // Aggregate without group by expression // e.g. 
SELECT count(a) FROM table; - if (SimpleAggregate(this->output_, prev_op_state, aggregate_operator_state)) { - return true; - } else { - return false; + auto result = SimpleAggregateExecute(prev_op_state->data_block_array_, aggregate_operator_state->data_block_array_); + prev_op_state->data_block_array_.clear(); + if (prev_op_state->Complete()) { + aggregate_operator_state->SetComplete(); } + return result; } #if 0 groupby_columns.reserve(group_count); @@ -404,17 +406,19 @@ void PhysicalAggregate::GenerateGroupByResult(const SharedPtr &input_ } case kVarchar: { Error("Varchar data shuffle isn't implemented."); -// VarcharT &dst_ref = ((VarcharT *)(output_datablock->column_vectors[column_id]->data()))[block_row_idx]; -// VarcharT &src_ref = ((VarcharT *)(input_datablocks[input_block_id]->column_vectors[column_id]->data()))[input_offset]; -// if (src_ref.IsInlined()) { -// Memcpy((char *)&dst_ref, (char *)&src_ref, sizeof(VarcharT)); -// } else { -// dst_ref.length = src_ref.length; -// Memcpy(dst_ref.prefix, src_ref.prefix, VarcharT::PREFIX_LENGTH); -// -// dst_ref.ptr = output_datablock->column_vectors[column_id]->buffer_->fix_heap_mgr_->Allocate(src_ref.length); -// Memcpy(dst_ref.ptr, src_ref.ptr, src_ref.length); -// } + // VarcharT &dst_ref = ((VarcharT *)(output_datablock->column_vectors[column_id]->data()))[block_row_idx]; + // VarcharT &src_ref = ((VarcharT + // *)(input_datablocks[input_block_id]->column_vectors[column_id]->data()))[input_offset]; if + // (src_ref.IsInlined()) { + // Memcpy((char *)&dst_ref, (char *)&src_ref, sizeof(VarcharT)); + // } else { + // dst_ref.length = src_ref.length; + // Memcpy(dst_ref.prefix, src_ref.prefix, VarcharT::PREFIX_LENGTH); + // + // dst_ref.ptr = + // output_datablock->column_vectors[column_id]->buffer_->fix_heap_mgr_->Allocate(src_ref.length); + // Memcpy(dst_ref.ptr, src_ref.ptr, src_ref.length); + // } break; } case kDate: { @@ -559,9 +563,7 @@ void PhysicalAggregate::GenerateGroupByResult(const SharedPtr &input_ 
#endif } -bool PhysicalAggregate::SimpleAggregate(SharedPtr &output_table, - OperatorState *pre_operator_state, - AggregateOperatorState *aggregate_operator_state) { +bool PhysicalAggregate::SimpleAggregateExecute(const Vector> &input_blocks, Vector> &output_blocks) { SizeT aggregates_count = aggregates_.size(); if (aggregates_count <= 0) { Error("Simple Aggregate without aggregate expression."); @@ -579,19 +581,16 @@ bool PhysicalAggregate::SimpleAggregate(SharedPtr &output_table, Vector> output_types; output_types.reserve(aggregates_count); - SizeT input_block_count = pre_operator_state->data_block_array_.size(); + SizeT input_block_count = input_blocks.size(); - for (i64 idx = 0; auto &expr: aggregates_) { + for (i64 idx = 0; auto &expr : aggregates_) { // expression state expr_states.emplace_back(ExpressionState::CreateState(expr)); SharedPtr output_type = MakeShared(expr->Type()); // column definition - SharedPtr col_def = MakeShared(idx, - output_type, - expr->Name(), - HashSet()); + SharedPtr col_def = MakeShared(idx, output_type, expr->Name(), HashSet()); aggregate_columns.emplace_back(col_def); // for output block @@ -605,51 +604,26 @@ bool PhysicalAggregate::SimpleAggregate(SharedPtr &output_table, LOG_TRACE("No input, no aggregate result"); return true; } - // Loop blocks - - // ExpressionEvaluator evaluator; - // //evaluator.Init(input_table_->data_blocks_); - // for (SizeT expr_idx = 0; expr_idx < aggregates_count; ++expr_idx) { - // - // ExpressionEvaluator evaluator; - // evaluator.Init(aggregates_[]) - // Vector> blocks_column; - // blocks_column.emplace_back(output_data_block->column_vectors[expr_idx]); - // evaluator.Execute(aggregates_[expr_idx], expr_states[expr_idx], blocks_column[expr_idx]); - // if(blocks_column[0].get() != output_data_block->column_vectors[expr_idx].get()) { - // // column vector in blocks column might be changed to the column vector from column reference. 
- // // This check and assignment is to make sure the right column vector are assign to output_data_block - // output_data_block->column_vectors[expr_idx] = blocks_column[0]; - // } - // } - // - // output_data_block->Finalize(); for (SizeT block_idx = 0; block_idx < input_block_count; ++block_idx) { - DataBlock *input_data_block = pre_operator_state->data_block_array_[block_idx].get(); + DataBlock *input_data_block = input_blocks[block_idx].get(); + + output_blocks.emplace_back(DataBlock::MakeUniquePtr()); - aggregate_operator_state->data_block_array_.emplace_back(DataBlock::MakeUniquePtr()); - DataBlock *output_data_block = aggregate_operator_state->data_block_array_.back().get(); + DataBlock *output_data_block = output_blocks.back().get(); output_data_block->Init(*GetOutputTypes()); ExpressionEvaluator evaluator; evaluator.Init(input_data_block); SizeT expression_count = aggregates_count; - // Prepare the expression states - + // calculate every columns value for (SizeT expr_idx = 0; expr_idx < expression_count; ++expr_idx) { - // Vector> blocks_column; - // blocks_column.emplace_back(output_data_block->column_vectors[expr_idx]); + LOG_TRACE("Physical aggregate Execute"); evaluator.Execute(aggregates_[expr_idx], expr_states[expr_idx], output_data_block->column_vectors[expr_idx]); } output_data_block->Finalize(); } - - pre_operator_state->data_block_array_.clear(); - if (pre_operator_state->Complete()) { - aggregate_operator_state->SetComplete(); - } return true; } diff --git a/src/executor/operator/physical_aggregate.cppm b/src/executor/operator/physical_aggregate.cppm index f9f14762f9..e25016ce0b 100644 --- a/src/executor/operator/physical_aggregate.cppm +++ b/src/executor/operator/physical_aggregate.cppm @@ -25,6 +25,7 @@ import hash_table; import base_expression; import load_meta; import infinity_exception; +import data_block; export module physical_aggregate; @@ -44,8 +45,8 @@ public: Vector> aggregates, u64 aggregate_index, SharedPtr> load_metas) - : 
PhysicalOperator(PhysicalOperatorType::kAggregate, Move(left), nullptr, id, load_metas), groups_(Move(groups)), aggregates_(Move(aggregates)), - groupby_index_(groupby_index), aggregate_index_(aggregate_index) {} + : PhysicalOperator(PhysicalOperatorType::kAggregate, Move(left), nullptr, id, load_metas), groups_(Move(groups)), + aggregates_(Move(aggregates)), groupby_index_(groupby_index), aggregate_index_(aggregate_index) {} ~PhysicalAggregate() override = default; @@ -66,9 +67,7 @@ public: Vector> aggregates_{}; HashTable hash_table_; - bool SimpleAggregate(SharedPtr &output_table, - OperatorState *pre_operator_state, - AggregateOperatorState *aggregate_operator_state); + bool SimpleAggregateExecute(const Vector> &input_blocks, Vector> &output_blocks); inline u64 GroupTableIndex() const { return groupby_index_; } diff --git a/src/executor/operator/physical_merge_aggregate.cpp b/src/executor/operator/physical_merge_aggregate.cpp index 7775588187..81fc1fe6ea 100644 --- a/src/executor/operator/physical_merge_aggregate.cpp +++ b/src/executor/operator/physical_merge_aggregate.cpp @@ -14,19 +14,164 @@ module; +#include +#include +import stl; import query_context; import operator_state; import infinity_exception; +import logger; +import value; +import data_block; +import parser; +import physical_aggregate; +import aggregate_expression; module physical_merge_aggregate; namespace infinity { +template +using MathOperation = StdFunction; + void PhysicalMergeAggregate::Init() {} bool PhysicalMergeAggregate::Execute(QueryContext *query_context, OperatorState *operator_state) { - Error("Not Implement"); + LOG_TRACE("PhysicalMergeAggregate::Execute:: mark"); + auto merge_aggregate_op_state = static_cast(operator_state); + + SimpleMergeAggregateExecute(merge_aggregate_op_state); + + if (merge_aggregate_op_state->input_complete_) { + + LOG_TRACE("PhysicalMergeAggregate::Input is complete"); + for (auto &output_block : merge_aggregate_op_state->data_block_array_) { + 
output_block->Finalize(); + } + + merge_aggregate_op_state->SetComplete(); + return true; + } + return false; } +void PhysicalMergeAggregate::SimpleMergeAggregateExecute(MergeAggregateOperatorState *op_state) { + if (op_state->data_block_array_.empty()) { + op_state->data_block_array_.emplace_back(Move(op_state->input_data_block_)); + } else { + auto agg_op = dynamic_cast(this->left()); + auto aggs_size = agg_op->aggregates_.size(); + for (SizeT col_idx = 0; col_idx < aggs_size; ++col_idx) { + auto agg_expression = static_cast(agg_op->aggregates_[col_idx].get()); + + auto function_name = agg_expression->aggregate_function_.GetFuncName(); + + auto func_return_type = agg_expression->aggregate_function_.return_type_; + + switch (func_return_type.type()) { + case kInteger: { + HandleAggregateFunction(function_name, op_state, col_idx); + break; + } + case kBigInt: { + HandleAggregateFunction(function_name, op_state, col_idx); + break; + } + case kFloat: { + HandleAggregateFunction(function_name, op_state, col_idx); + break; + } + case kDouble: { + HandleAggregateFunction(function_name, op_state, col_idx); + break; + } + default: + Error("input_value_type not Implement"); + } + } + } +} + +template +void PhysicalMergeAggregate::HandleAggregateFunction(const String &function_name, MergeAggregateOperatorState *op_state, SizeT col_idx) { + if (String(function_name) == String("COUNT")) { + LOG_TRACE("COUNT"); + HandleCount(op_state, col_idx); + } else if (String(function_name) == String("MIN")) { + LOG_TRACE("MIN"); + HandleMin(op_state, col_idx); + } else if (String(function_name) == String("MAX")) { + LOG_TRACE("MAX"); + HandleMax(op_state, col_idx); + } else if (String(function_name) == String("SUM")) { + LOG_TRACE("SUM"); + HandleSum(op_state, col_idx); + } +} + +template +void PhysicalMergeAggregate::HandleMin(MergeAggregateOperatorState *op_state, SizeT col_idx) { + MathOperation minOperation = [](T a, T b) -> T { return (a < b) ? 
a : b; }; + UpdateData(op_state, minOperation, col_idx); +} + +template +void PhysicalMergeAggregate::HandleMax(MergeAggregateOperatorState *op_state, SizeT col_idx) { + MathOperation maxOperation = [](T a, T b) -> T { return (a > b) ? a : b; }; + UpdateData(op_state, maxOperation, col_idx); +} + +template +void PhysicalMergeAggregate::HandleCount(MergeAggregateOperatorState *op_state, SizeT col_idx) { + MathOperation countOperation = [](T a, T b) -> T { return a + b; }; + UpdateData(op_state, countOperation, col_idx); +} + +template +void PhysicalMergeAggregate::HandleSum(MergeAggregateOperatorState *op_state, SizeT col_idx) { + MathOperation sumOperation = [](T a, T b) -> T { return a + b; }; + UpdateData(op_state, sumOperation, col_idx); +} + +template +T PhysicalMergeAggregate::GetInputData(MergeAggregateOperatorState *op_state, SizeT block_index, SizeT col_idx, SizeT row_idx) { + Value value = op_state->input_data_block_->GetValue(col_idx, row_idx); + return value.GetValue(); +} + +template +T PhysicalMergeAggregate::GetOutputData(MergeAggregateOperatorState *op_state, SizeT block_index, SizeT col_idx, SizeT row_idx) { + Value value = op_state->data_block_array_[block_index]->GetValue(col_idx, row_idx); + return value.GetValue(); +} + +template +void PhysicalMergeAggregate::WriteValueAtPosition(MergeAggregateOperatorState *op_state, SizeT block_index, SizeT col_idx, SizeT row_idx, T value) { + op_state->data_block_array_[block_index]->SetValue(col_idx, row_idx, CreateValue(value)); +} + +template +void PhysicalMergeAggregate::UpdateData(MergeAggregateOperatorState *op_state, MathOperation operation, SizeT col_idx) { + T input = GetInputData(op_state, 0, col_idx, 0); + T output = GetOutputData(op_state, 0, col_idx, 0); + T new_value = operation(input, output); + WriteValueAtPosition(op_state, 0, col_idx, 0, new_value); +} + +template +T PhysicalMergeAggregate::AddData(T a, T b) { + return a + b; +} + +template +T PhysicalMergeAggregate::MinValue(T a, T b) { + 
return (a < b) ? a : b; +} + +template +T PhysicalMergeAggregate::MaxValue(T a, T b) { + return (a > b) ? a : b; +} + } // namespace infinity diff --git a/src/executor/operator/physical_merge_aggregate.cppm b/src/executor/operator/physical_merge_aggregate.cppm index 8dda917458..aca0859317 100644 --- a/src/executor/operator/physical_merge_aggregate.cppm +++ b/src/executor/operator/physical_merge_aggregate.cppm @@ -23,6 +23,8 @@ import physical_operator_type; import load_meta; import base_table_ref; import infinity_exception; +import value; +import data_block; export module physical_merge_aggregate; @@ -53,6 +55,83 @@ public: Error("TaskletCount not Implement"); return 0; } + + template + T GetDataFromValueAtInputBlockPosition(MergeAggregateOperatorState *op_state, SizeT block_index, SizeT col_idx, SizeT row_idx); + + template + T GetDataFromValueAtOutputBlockPosition(MergeAggregateOperatorState *op_state, SizeT block_index, SizeT col_idx, SizeT row_idx); + + void WriteIntegerAtPosition(MergeAggregateOperatorState *op_state, SizeT block_index, SizeT col_idx, SizeT row_idx, IntegerT integer); + + template + T GetInputData(MergeAggregateOperatorState *op_state, SizeT block_index, SizeT col_idx, SizeT row_idx); + + template + T GetOutputData(MergeAggregateOperatorState *op_state, SizeT block_index, SizeT col_idx, SizeT row_idx); + + template + T MinValue(T a, T b); + + template + T MaxValue(T a, T b); + + template + T AddData(T a, T b); + + template + using MathOperation = StdFunction; + + void SimpleMergeAggregateExecute(MergeAggregateOperatorState *merge_aggregate_op_state); + + void UpdateBlockData(MergeAggregateOperatorState *merge_aggregate_op_state, SizeT col_idx); + + template + void UpdateData(MergeAggregateOperatorState *op_state, MathOperation operation, SizeT col_idx); + + template + void WriteValueAtPosition(MergeAggregateOperatorState *op_state, SizeT block_index, SizeT col_idx, SizeT row_idx, T value); + + template + void 
HandleSum(MergeAggregateOperatorState *op_state, SizeT col_idx); + + template + void HandleCount(MergeAggregateOperatorState *op_state, SizeT col_idx); + + template + void HandleMin(MergeAggregateOperatorState *op_state, SizeT col_idx); + + template + void HandleMax(MergeAggregateOperatorState *op_state, SizeT col_idx); + + template + void HandleAggregateFunction(const String &function_name, MergeAggregateOperatorState *op_state, SizeT col_idx); + + template + Value CreateValue(T value) { + Error("Unhandled type for makeValue"); + } + + template <> + Value CreateValue(IntegerT value) { + return Value::MakeInt(value); + } + + template <> + Value CreateValue(BigIntT value) { + return Value::MakeBigInt(value); + } + + template <> + Value CreateValue(DoubleT value) { + return Value::MakeDouble(value); + } + + template <> + Value CreateValue(FloatT value) { + return Value::MakeFloat(value); + } + private: SharedPtr> output_names_{}; SharedPtr>> output_types_{}; diff --git a/src/executor/operator_state.cpp b/src/executor/operator_state.cpp index 7f8967dabd..1d9ac95869 100644 --- a/src/executor/operator_state.cpp +++ b/src/executor/operator_state.cpp @@ -114,6 +114,14 @@ bool QueueSourceState::GetData() { } break; } + case PhysicalOperatorType::kMergeAggregate: { + auto *fragment_data = static_cast(fragment_data_base.get()); + MergeAggregateOperatorState *merge_aggregate_op_state = (MergeAggregateOperatorState *)next_op_state; + //merge_aggregate_op_state->input_data_blocks_.push_back(Move(fragment_data->data_block_)); + merge_aggregate_op_state->input_data_block_ = Move(fragment_data->data_block_); + merge_aggregate_op_state->input_complete_ = completed; + break; + } default: { Error("Not support operator type"); break; diff --git a/src/executor/operator_state.cppm b/src/executor/operator_state.cppm index a5b2515027..2d4e7635ff 100644 --- a/src/executor/operator_state.cppm +++ b/src/executor/operator_state.cppm @@ -60,6 +60,11 @@ export struct AggregateOperatorState : 
public OperatorState { // Merge Aggregate export struct MergeAggregateOperatorState : public OperatorState { inline explicit MergeAggregateOperatorState() : OperatorState(PhysicalOperatorType::kMergeAggregate) {} + + /// Since merge agg is the first op, no previous operator state. This ptr is to get input data. + //Vector> input_data_blocks_{nullptr}; + UniquePtr input_data_block_{nullptr}; + bool input_complete_{false}; }; // Merge Parallel Aggregate diff --git a/src/executor/physical_planner.cpp b/src/executor/physical_planner.cpp index eba2599e84..cdda936f1e 100644 --- a/src/executor/physical_planner.cpp +++ b/src/executor/physical_planner.cpp @@ -541,6 +541,8 @@ UniquePtr PhysicalPlanner::BuildAggregate(const SharedPtraggregate_index_, logical_operator->load_metas()); + + if (tasklet_count == 1) { return physical_agg_op; } else { diff --git a/src/function/aggregate/count.cpp b/src/function/aggregate/count.cpp index 08b461e273..6e93f86c13 100644 --- a/src/function/aggregate/count.cpp +++ b/src/function/aggregate/count.cpp @@ -159,15 +159,16 @@ void RegisterCountFunction(const UniquePtr &catalog_ptr) { function_set_ptr->AddFunction(count_function); } { -// AggregateFunction count_function = -// UnaryAggregate, PathT, BigIntT>(func_name, DataType(LogicalType::kPath), DataType(LogicalType::kBigInt)); -// function_set_ptr->AddFunction(count_function); + // AggregateFunction count_function = + // UnaryAggregate, PathT, BigIntT>(func_name, DataType(LogicalType::kPath), + // DataType(LogicalType::kBigInt)); + // function_set_ptr->AddFunction(count_function); } { -// AggregateFunction count_function = UnaryAggregate, PolygonT, BigIntT>(func_name, -// DataType(LogicalType::kPolygon), -// DataType(LogicalType::kBigInt)); -// function_set_ptr->AddFunction(count_function); + // AggregateFunction count_function = UnaryAggregate, PolygonT, BigIntT>(func_name, + // DataType(LogicalType::kPolygon), + // DataType(LogicalType::kBigInt)); + // 
function_set_ptr->AddFunction(count_function); } { AggregateFunction count_function = @@ -175,9 +176,10 @@ void RegisterCountFunction(const UniquePtr &catalog_ptr) { function_set_ptr->AddFunction(count_function); } { -// AggregateFunction count_function = -// UnaryAggregate, BitmapT, BigIntT>(func_name, DataType(LogicalType::kBitmap), DataType(LogicalType::kBigInt)); -// function_set_ptr->AddFunction(count_function); + // AggregateFunction count_function = + // UnaryAggregate, BitmapT, BigIntT>(func_name, DataType(LogicalType::kBitmap), + // DataType(LogicalType::kBigInt)); + // function_set_ptr->AddFunction(count_function); } { AggregateFunction count_function = @@ -185,9 +187,10 @@ void RegisterCountFunction(const UniquePtr &catalog_ptr) { function_set_ptr->AddFunction(count_function); } { -// AggregateFunction count_function = -// UnaryAggregate, BlobT, BigIntT>(func_name, DataType(LogicalType::kBlob), DataType(LogicalType::kBigInt)); -// function_set_ptr->AddFunction(count_function); + // AggregateFunction count_function = + // UnaryAggregate, BlobT, BigIntT>(func_name, DataType(LogicalType::kBlob), + // DataType(LogicalType::kBigInt)); + // function_set_ptr->AddFunction(count_function); } { AggregateFunction count_function = UnaryAggregate, EmbeddingT, BigIntT>(func_name, diff --git a/src/function/aggregate/count_star.cpp b/src/function/aggregate/count_star.cpp new file mode 100644 index 0000000000..11b983710d --- /dev/null +++ b/src/function/aggregate/count_star.cpp @@ -0,0 +1,59 @@ +// Copyright(C) 2023 InfiniFlow, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +module; + +import stl; +import new_catalog; + +import infinity_exception; +import aggregate_function; +import aggregate_function_set; +import parser; +import third_party; + +module count_star; + +namespace infinity { + +template +struct CountStarState { + i64 value_{}; + + void Initialize() { this->value_ = 0; } + + void Update(i64 *__restrict input, SizeT idx) { value_ = input[idx]; } + + inline void ConstantUpdate(i64 *__restrict input, SizeT idx, SizeT) { value_ = input[idx]; } + + inline ptr_t Finalize() { return (ptr_t)&value_; } + + inline static SizeT Size(const DataType &) { return sizeof(i64); } +}; + +void RegisterCountStarFunction(const UniquePtr &catalog_ptr) { + String func_name = "COUNT_STAR"; + + SharedPtr function_set_ptr = MakeShared(func_name); + + { + AggregateFunction count_function = + CountStarAggregate, BigIntT>(func_name, DataType(LogicalType::kBigInt), DataType(LogicalType::kBigInt)); + function_set_ptr->AddFunction(count_function); + } + + NewCatalog::AddFunctionSet(catalog_ptr.get(), function_set_ptr); +} + +} // namespace infinity diff --git a/src/function/aggregate/count_star.cppm b/src/function/aggregate/count_star.cppm new file mode 100644 index 0000000000..d281869a86 --- /dev/null +++ b/src/function/aggregate/count_star.cppm @@ -0,0 +1,27 @@ +// Copyright(C) 2023 InfiniFlow, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +module; + +import stl; + +export module count_star; + +namespace infinity { + +class NewCatalog; + +export void RegisterCountStarFunction(const UniquePtr &catalog_ptr); + +} \ No newline at end of file diff --git a/src/function/aggregate_function.cppm b/src/function/aggregate_function.cppm index 8be567b9fa..4c3c515a5c 100644 --- a/src/function/aggregate_function.cppm +++ b/src/function/aggregate_function.cppm @@ -95,6 +95,8 @@ public: [[nodiscard]] ptr_t GetState() const { return state_data_.get(); } + [[nodiscard]] String GetFuncName() const { return name_; } + public: AggregateInitializeFuncType init_func_; AggregateUpdateFuncType update_func_; @@ -117,4 +119,49 @@ inline AggregateFunction UnaryAggregate(const String &name, const DataType &inpu AggregateOperation::StateUpdate, AggregateOperation::StateFinalize); } + + +class CountStarAggregateOperation { +public: + template + static inline void StateInitialize(const ptr_t state) { + ((AggregateState *)state)->Initialize(); + } + + template + static inline void StateUpdate(const ptr_t state, const SharedPtr &input_column_vector) { + // Loop execute state update according to the input column vector + + switch (input_column_vector->vector_type()) { + case ColumnVectorType::kConstant: { + auto *input_ptr = (i64 *)(input_column_vector->data()); + ((AggregateState *)state)->Update(input_ptr, 0); + break; + } + default: { + Error("Other type of column vector isn't implemented"); + } + } + } + + template + static inline ptr_t StateFinalize(const ptr_t state) { + // Loop execute state update according 
to the input column vector + ptr_t result = ((AggregateState *)state)->Finalize(); + return result; + } +}; + + +export template +inline auto CountStarAggregate(String &name, const DataType &input_type, const DataType &return_type) -> AggregateFunction { + auto agg_function = AggregateFunction(name, + input_type, + return_type, + AggregateState::Size(input_type), + CountStarAggregateOperation::StateInitialize, + CountStarAggregateOperation::StateUpdate, + CountStarAggregateOperation::StateFinalize); + return agg_function; +} } // namespace infinity diff --git a/src/function/builtin_functions.cpp b/src/function/builtin_functions.cpp index 8117ed79ad..0afd36c1e6 100644 --- a/src/function/builtin_functions.cpp +++ b/src/function/builtin_functions.cpp @@ -22,6 +22,7 @@ import first; import max; import min; import sum; +import count_star; import add; import abs; @@ -52,6 +53,7 @@ import special_function; import parser; + module builtin_functions; namespace infinity { @@ -68,6 +70,7 @@ void BuiltinFunctions::Init() { void BuiltinFunctions::RegisterAggregateFunction() { RegisterAvgFunction(catalog_ptr_); RegisterCountFunction(catalog_ptr_); + RegisterCountStarFunction(catalog_ptr_); RegisterFirstFunction(catalog_ptr_); RegisterMaxFunction(catalog_ptr_); RegisterMinFunction(catalog_ptr_); diff --git a/src/function/scalar/divide.cpp b/src/function/scalar/divide.cpp index 0a6511cbd9..293bf348c8 100644 --- a/src/function/scalar/divide.cpp +++ b/src/function/scalar/divide.cpp @@ -38,14 +38,14 @@ struct DivFunction { if (left == std::numeric_limits::min() && right == -1) { return false; } - result = left / right; + result = DoubleT(left) / DoubleT(right); return true; } }; template <> inline bool DivFunction::Run(FloatT left, FloatT right, FloatT &result) { - result = left / right; + result = left / DoubleT(right); if (std::isnan(result) || std::isinf(result)) return false; return true; @@ -65,6 +65,12 @@ inline bool DivFunction::Run(HugeIntT, HugeIntT, HugeIntT &) { return 
false; } +template <> +inline bool DivFunction::Run(HugeIntT, HugeIntT, DoubleT &) { + Error("Not implement huge int divide operator."); + return false; +} + void RegisterDivFunction(const UniquePtr &catalog_ptr) { String func_name = "/"; @@ -72,38 +78,38 @@ void RegisterDivFunction(const UniquePtr &catalog_ptr) { ScalarFunction div_function_int8(func_name, {DataType(LogicalType::kTinyInt), DataType(LogicalType::kTinyInt)}, - {DataType(LogicalType::kTinyInt)}, - &ScalarFunction::BinaryFunctionWithFailure); + {DataType(LogicalType::kDouble)}, + &ScalarFunction::BinaryFunctionWithFailure); function_set_ptr->AddFunction(div_function_int8); ScalarFunction div_function_int16(func_name, {DataType(LogicalType::kSmallInt), DataType(LogicalType::kSmallInt)}, - {DataType(LogicalType::kSmallInt)}, - &ScalarFunction::BinaryFunctionWithFailure); + {DataType(LogicalType::kDouble)}, + &ScalarFunction::BinaryFunctionWithFailure); function_set_ptr->AddFunction(div_function_int16); ScalarFunction div_function_int32(func_name, {DataType(LogicalType::kInteger), DataType(LogicalType::kInteger)}, - {DataType(LogicalType::kInteger)}, - &ScalarFunction::BinaryFunctionWithFailure); + {DataType(LogicalType::kDouble)}, + &ScalarFunction::BinaryFunctionWithFailure); function_set_ptr->AddFunction(div_function_int32); ScalarFunction div_function_int64(func_name, {DataType(LogicalType::kBigInt), DataType(LogicalType::kBigInt)}, - {DataType(LogicalType::kBigInt)}, - &ScalarFunction::BinaryFunctionWithFailure); + {DataType(LogicalType::kDouble)}, + &ScalarFunction::BinaryFunctionWithFailure); function_set_ptr->AddFunction(div_function_int64); ScalarFunction div_function_int128(func_name, {DataType(LogicalType::kHugeInt), DataType(LogicalType::kHugeInt)}, - {DataType(LogicalType::kHugeInt)}, - &ScalarFunction::BinaryFunctionWithFailure); + {DataType(LogicalType::kDouble)}, + &ScalarFunction::BinaryFunctionWithFailure); function_set_ptr->AddFunction(div_function_int128); ScalarFunction 
div_function_float(func_name, {DataType(LogicalType::kFloat), DataType(LogicalType::kFloat)}, - {DataType(LogicalType::kFloat)}, - &ScalarFunction::BinaryFunctionWithFailure); + {DataType(LogicalType::kDouble)}, + &ScalarFunction::BinaryFunctionWithFailure); function_set_ptr->AddFunction(div_function_float); ScalarFunction div_function_double(func_name, diff --git a/src/planner/binder/project_binder.cpp b/src/planner/binder/project_binder.cpp index 7f2386fce5..c4512c44f1 100644 --- a/src/planner/binder/project_binder.cpp +++ b/src/planner/binder/project_binder.cpp @@ -27,11 +27,51 @@ import infinity_exception; module project_binder; +namespace { + +using namespace infinity; + +void ConvertAvgToSumDivideCount(FunctionExpr &func_expression, const Vector &column_names) { + func_expression.func_name_ = "/"; + func_expression.arguments_->clear(); + auto createFunctionWithColumnArg = [&column_names](const String &func_name) { + auto function_expression = MakeUnique(); + function_expression->func_name_ = func_name; + function_expression->arguments_ = new Vector(); + auto column_expr = MakeUnique(); + column_expr->names_.push_back(column_names[0]); + function_expression->arguments_->push_back(column_expr.release()); + return function_expression.release(); + }; + func_expression.arguments_->push_back(createFunctionWithColumnArg("sum")); + func_expression.arguments_->push_back(createFunctionWithColumnArg("count")); +} + +} // namespace + namespace infinity { SharedPtr ProjectBinder::BuildExpression(const ParsedExpr &expr, BindContext *bind_context_ptr, i64 depth, bool root) { String expr_name = expr.GetName(); + // Convert avg function expr to (sum / count) function expr + if (expr.type_ == ParsedExprType::kFunction) { + auto &function_expression = (FunctionExpr &)expr; + auto special_function = TryBuildSpecialFuncExpr(function_expression, bind_context_ptr, depth); + if (special_function.has_value()) { + return ExpressionBinder::BuildExpression(expr, bind_context_ptr, depth, 
root); + } + auto function_set_ptr = FunctionSet::GetFunctionSet(query_context_->storage()->catalog(), function_expression); + + if (IsEqual(function_set_ptr->name(), String("AVG")) && function_expression.arguments_->size() == 1 && + (*function_expression.arguments_)[0]->type_ == ParsedExprType::kColumn) { + auto column_expr = (ColumnExpr *)(*function_expression.arguments_)[0]; + Vector column_names(Move(column_expr->names_)); + delete column_expr; + ConvertAvgToSumDivideCount(function_expression, column_names); + return ExpressionBinder::BuildExpression(expr, bind_context_ptr, depth, root); + } + } // If the expr isn't from aggregate function and coming from group by lists. if (!this->binding_agg_func_ && bind_context_ptr->group_index_by_name_.contains(expr_name)) { i64 groupby_index = bind_context_ptr->group_index_by_name_[expr_name]; @@ -94,7 +134,7 @@ SharedPtr ProjectBinder::BuildFuncExpr(const FunctionExpr &expr, SharedPtr function_set_ptr = FunctionSet::GetFunctionSet(query_context_->storage()->catalog(), expr); if (function_set_ptr->type_ == FunctionType::kAggregate) { if (this->binding_agg_func_) { - Error("Aggregate function is called in another aggregate function."); + Error(Format("Aggregate function {} is called in another aggregate function.", function_set_ptr->name())); } else { this->binding_agg_func_ = true; } diff --git a/src/planner/expression_binder.cpp b/src/planner/expression_binder.cpp index 6aacad501a..8fe576b33a 100644 --- a/src/planner/expression_binder.cpp +++ b/src/planner/expression_binder.cpp @@ -64,6 +64,7 @@ namespace infinity { SharedPtr ExpressionBinder::Bind(const ParsedExpr &expr, BindContext *bind_context_ptr, i64 depth, bool root) { // Call implemented BuildExpression + SharedPtr result = BuildExpression(expr, bind_context_ptr, depth, root); if (result.get() == nullptr) { if (result.get() == nullptr) { @@ -274,7 +275,15 @@ SharedPtr ExpressionBinder::BuildFuncExpr(const FunctionExpr &ex ColumnExpr *col_expr = (ColumnExpr 
*)(*expr.arguments_)[0]; if (col_expr->star_) { delete (*expr.arguments_)[0]; - (*expr.arguments_)[0] = new ConstantExpr(LiteralType::kBoolean); + auto constant_exp = new ConstantExpr(LiteralType::kInteger); + // catulate row count + String &table_name = bind_context_ptr->table_names_[0]; + TableCollectionEntry *table_entry = bind_context_ptr->binding_by_name_[table_name]->table_collection_entry_ptr_; + constant_exp->integer_value_ = table_entry->row_count_; + (*expr.arguments_)[0] = constant_exp; + auto &expr_rewrite = (FunctionExpr &)expr; + expr_rewrite.func_name_ = "COUNT_STAR"; + return ExpressionBinder::BuildFuncExpr(expr_rewrite, bind_context_ptr, depth, true); } } } diff --git a/src/scheduler/fragment_context.cpp b/src/scheduler/fragment_context.cpp index f0c890723c..bf5fd969d3 100644 --- a/src/scheduler/fragment_context.cpp +++ b/src/scheduler/fragment_context.cpp @@ -674,7 +674,22 @@ void FragmentContext::MakeSinkState(i64 parallel_count) { case PhysicalOperatorType::kInvalid: { Error("Unexpected operator type"); } - case PhysicalOperatorType::kAggregate: + case PhysicalOperatorType::kAggregate: { + if (fragment_type_ != FragmentType::kParallelStream) { + Error(Format("{} should in parallel stream fragment", PhysicalOperatorToString(last_operator->operator_type()))); + } + + if ((i64)tasks_.size() != parallel_count) { + Error(Format("{} task count isn't correct.", PhysicalOperatorToString(last_operator->operator_type()))); + } + + for (u64 task_id = 0; (i64)task_id < parallel_count; ++task_id) { + auto sink_state = MakeUnique(fragment_ptr_->FragmentID(), task_id); + + tasks_[task_id]->sink_state_ = Move(sink_state); + } + break; + } case PhysicalOperatorType::kParallelAggregate: case PhysicalOperatorType::kHash: case PhysicalOperatorType::kTop: { @@ -712,6 +727,7 @@ void FragmentContext::MakeSinkState(i64 parallel_count) { break; } case PhysicalOperatorType::kMergeParallelAggregate: + case PhysicalOperatorType::kMergeAggregate: case 
PhysicalOperatorType::kMergeHash: case PhysicalOperatorType::kMergeLimit: case PhysicalOperatorType::kMergeTop: diff --git a/src/storage/meta/entry/table_collection_entry.cpp b/src/storage/meta/entry/table_collection_entry.cpp index 5561aeb884..2c99ef63c7 100644 --- a/src/storage/meta/entry/table_collection_entry.cpp +++ b/src/storage/meta/entry/table_collection_entry.cpp @@ -322,7 +322,7 @@ void TableCollectionEntry::CommitDelete(TableCollectionEntry *table_entry, Txn * SegmentEntry::CommitDelete(segment, txn_ptr, block_row_hashmap); row_count += block_row_hashmap.size(); } - table_entry->row_count_ += row_count; + table_entry->row_count_ -= row_count; } UniquePtr TableCollectionEntry::RollbackDelete(TableCollectionEntry *, Txn *, DeleteState &, BufferManager *) { @@ -348,7 +348,7 @@ UniquePtr TableCollectionEntry::ImportSegment(TableCollectionEntry *tabl } UniqueLock rw_locker(table_entry->rw_locker_); - table_entry->row_count_ += row_count; + table_entry->row_count_ = row_count; table_entry->segment_map_.emplace(segment->segment_id_, Move(segment)); return nullptr; } diff --git a/src/unit_test/function/scalar/div_functions.cpp b/src/unit_test/function/scalar/div_functions.cpp index 4017d797ed..55cf58e41a 100644 --- a/src/unit_test/function/scalar/div_functions.cpp +++ b/src/unit_test/function/scalar/div_functions.cpp @@ -53,7 +53,7 @@ TEST_F(DivFunctionsTest, div_func) { Vector> inputs; SharedPtr data_type = MakeShared(LogicalType::kTinyInt); - SharedPtr result_type = MakeShared(LogicalType::kTinyInt); + SharedPtr result_type = MakeShared(LogicalType::kDouble); SharedPtr col1_expr_ptr = MakeShared(*data_type, "t1", 1, "c1", 0, 0); SharedPtr col2_expr_ptr = MakeShared(*data_type, "t1", 1, "c2", 1, 0); @@ -61,7 +61,7 @@ TEST_F(DivFunctionsTest, div_func) { inputs.emplace_back(col2_expr_ptr); ScalarFunction func = scalar_function_set->GetMostMatchFunction(inputs); - EXPECT_STREQ("/(TinyInt, TinyInt)->TinyInt", func.ToString().c_str()); + EXPECT_STREQ("/(TinyInt, 
TinyInt)->Double", func.ToString().c_str()); Vector> column_types; column_types.emplace_back(data_type); @@ -98,8 +98,8 @@ TEST_F(DivFunctionsTest, div_func) { } else { Value v = result->GetValue(i); EXPECT_TRUE(result->nulls_ptr_->IsTrue(i)); - EXPECT_EQ(v.type_.type(), LogicalType::kTinyInt); - EXPECT_EQ(v.value_.tiny_int, 1); + EXPECT_EQ(v.type_.type(), LogicalType::kDouble); + EXPECT_EQ(v.value_.float64, 1); } } } @@ -108,7 +108,7 @@ TEST_F(DivFunctionsTest, div_func) { Vector> inputs; SharedPtr data_type = MakeShared(LogicalType::kSmallInt); - SharedPtr result_type = MakeShared(LogicalType::kSmallInt); + SharedPtr result_type = MakeShared(LogicalType::kDouble); SharedPtr col1_expr_ptr = MakeShared(*data_type, "t1", 1, "c1", 0, 0); SharedPtr col2_expr_ptr = MakeShared(*data_type, "t1", 1, "c2", 1, 0); @@ -116,7 +116,7 @@ TEST_F(DivFunctionsTest, div_func) { inputs.emplace_back(col2_expr_ptr); ScalarFunction func = scalar_function_set->GetMostMatchFunction(inputs); - EXPECT_STREQ("/(SmallInt, SmallInt)->SmallInt", func.ToString().c_str()); + EXPECT_STREQ("/(SmallInt, SmallInt)->Double", func.ToString().c_str()); Vector> column_types; column_types.emplace_back(data_type); @@ -152,8 +152,8 @@ TEST_F(DivFunctionsTest, div_func) { } else { Value v = result->GetValue(i); EXPECT_TRUE(result->nulls_ptr_->IsTrue(i)); - EXPECT_EQ(v.type_.type(), LogicalType::kSmallInt); - EXPECT_EQ(v.value_.small_int, 2); + EXPECT_EQ(v.type_.type(), LogicalType::kDouble); + EXPECT_EQ(v.value_.float64, 2); } } } @@ -162,7 +162,7 @@ TEST_F(DivFunctionsTest, div_func) { Vector> inputs; SharedPtr data_type = MakeShared(LogicalType::kInteger); - SharedPtr result_type = MakeShared(LogicalType::kInteger); + SharedPtr result_type = MakeShared(LogicalType::kDouble); SharedPtr col1_expr_ptr = MakeShared(*data_type, "t1", 1, "c1", 0, 0); SharedPtr col2_expr_ptr = MakeShared(*data_type, "t1", 1, "c2", 1, 0); @@ -170,7 +170,7 @@ TEST_F(DivFunctionsTest, div_func) { inputs.emplace_back(col2_expr_ptr); 
ScalarFunction func = scalar_function_set->GetMostMatchFunction(inputs); - EXPECT_STREQ("/(Integer, Integer)->Integer", func.ToString().c_str()); + EXPECT_STREQ("/(Integer, Integer)->Double", func.ToString().c_str()); Vector> column_types; column_types.emplace_back(data_type); @@ -206,8 +206,8 @@ TEST_F(DivFunctionsTest, div_func) { } else { Value v = result->GetValue(i); EXPECT_TRUE(result->nulls_ptr_->IsTrue(i)); - EXPECT_EQ(v.type_.type(), LogicalType::kInteger); - EXPECT_EQ(v.value_.integer, 3); + EXPECT_EQ(v.type_.type(), LogicalType::kDouble); + EXPECT_EQ(v.value_.float64, 3); } } } @@ -216,7 +216,7 @@ TEST_F(DivFunctionsTest, div_func) { Vector> inputs; SharedPtr data_type = MakeShared(LogicalType::kBigInt); - SharedPtr result_type = MakeShared(LogicalType::kBigInt); + SharedPtr result_type = MakeShared(LogicalType::kDouble); SharedPtr col1_expr_ptr = MakeShared(*data_type, "t1", 1, "c1", 0, 0); SharedPtr col2_expr_ptr = MakeShared(*data_type, "t1", 1, "c2", 1, 0); @@ -224,7 +224,7 @@ TEST_F(DivFunctionsTest, div_func) { inputs.emplace_back(col2_expr_ptr); ScalarFunction func = scalar_function_set->GetMostMatchFunction(inputs); - EXPECT_STREQ("/(BigInt, BigInt)->BigInt", func.ToString().c_str()); + EXPECT_STREQ("/(BigInt, BigInt)->Double", func.ToString().c_str()); Vector> column_types; column_types.emplace_back(data_type); @@ -260,8 +260,8 @@ TEST_F(DivFunctionsTest, div_func) { } else { Value v = result->GetValue(i); EXPECT_TRUE(result->nulls_ptr_->IsTrue(i)); - EXPECT_EQ(v.type_.type(), LogicalType::kBigInt); - EXPECT_EQ(v.value_.big_int, 4); + EXPECT_EQ(v.type_.type(), LogicalType::kDouble); + EXPECT_EQ(v.value_.float64, 4); } } } @@ -270,7 +270,7 @@ TEST_F(DivFunctionsTest, div_func) { Vector> inputs; DataType data_type(LogicalType::kHugeInt); - DataType result_type(LogicalType::kHugeInt); + DataType result_type(LogicalType::kDouble); SharedPtr col1_expr_ptr = MakeShared(data_type, "t1", 1, "c1", 0, 0); SharedPtr col2_expr_ptr = MakeShared(data_type, 
"t1", 1, "c2", 1, 0); @@ -278,7 +278,7 @@ TEST_F(DivFunctionsTest, div_func) { inputs.emplace_back(col2_expr_ptr); ScalarFunction func = scalar_function_set->GetMostMatchFunction(inputs); - EXPECT_STREQ("/(HugeInt, HugeInt)->HugeInt", func.ToString().c_str()); + EXPECT_STREQ("/(HugeInt, HugeInt)->Double", func.ToString().c_str()); // TODO: need to complete it. } @@ -287,7 +287,7 @@ TEST_F(DivFunctionsTest, div_func) { Vector> inputs; SharedPtr data_type = MakeShared(LogicalType::kFloat); - SharedPtr result_type = MakeShared(LogicalType::kFloat); + SharedPtr result_type = MakeShared(LogicalType::kDouble); SharedPtr col1_expr_ptr = MakeShared(*data_type, "t1", 1, "c1", 0, 0); SharedPtr col2_expr_ptr = MakeShared(*data_type, "t1", 1, "c2", 1, 0); @@ -295,7 +295,7 @@ TEST_F(DivFunctionsTest, div_func) { inputs.emplace_back(col2_expr_ptr); ScalarFunction func = scalar_function_set->GetMostMatchFunction(inputs); - EXPECT_STREQ("/(Float, Float)->Float", func.ToString().c_str()); + EXPECT_STREQ("/(Float, Float)->Double", func.ToString().c_str()); Vector> column_types; column_types.emplace_back(data_type); @@ -331,8 +331,8 @@ TEST_F(DivFunctionsTest, div_func) { } else { Value v = result->GetValue(i); EXPECT_TRUE(result->nulls_ptr_->IsTrue(i)); - EXPECT_EQ(v.type_.type(), LogicalType::kFloat); - EXPECT_FLOAT_EQ(v.value_.float32, 5); + EXPECT_EQ(v.type_.type(), LogicalType::kDouble); + EXPECT_FLOAT_EQ(v.value_.float64, 5); } } } diff --git a/test/sql/basic.slt b/test/sql/basic.slt index 9cbe8f2ba2..63b82e97cd 100644 --- a/test/sql/basic.slt +++ b/test/sql/basic.slt @@ -2,6 +2,9 @@ # description: Test basic sql statement for sample # group: [basic] +statement ok +DROP TABLE IF EXISTS NATION; + # Expecting IDENTIFIER or PRIMARY or UNIQUE statement error CREATE TABLE NATION ( diff --git a/test/sql/dml/delete.slt b/test/sql/dml/delete.slt index b1e1099bef..f0119bf31c 100644 --- a/test/sql/dml/delete.slt +++ b/test/sql/dml/delete.slt @@ -23,6 +23,11 @@ SELECT * FROM products; 5 6 
7 8 +query II +SELECT count(*) FROM products; +---- +4 + statement ok DELETE FROM products WHERE product_no = 3; @@ -33,6 +38,11 @@ SELECT * FROM products; 5 6 7 8 +query II +SELECT count(*) FROM products; +---- +3 + statement ok DELETE FROM products; diff --git a/test/sql/dml/import/test_embedding.slt b/test/sql/dml/import/test_embedding.slt index 70aebb35ab..245cfc4072 100644 --- a/test/sql/dml/import/test_embedding.slt +++ b/test/sql/dml/import/test_embedding.slt @@ -22,6 +22,11 @@ SELECT c1, c2 FROM test_embedding_type; 5 6,7,8 9 10,11,12 +query II +SELECT count(*) FROM test_embedding_type; +---- +3 + # Clean up statement ok diff --git a/test/sql/dml/import/test_jsonl.slt b/test/sql/dml/import/test_jsonl.slt index 69fb95cb8d..fc09489987 100644 --- a/test/sql/dml/import/test_jsonl.slt +++ b/test/sql/dml/import/test_jsonl.slt @@ -25,3 +25,11 @@ Ben 33 1,2,3,4,5 William 28 1,2,3,4,5 Chuck 29 1,2,3,4,5 Viola 35 1,2,3,4,5 + +query III +SELECT count(*) FROM test_jsonl; +---- +14 + +statement ok +DROP TABLE IF EXISTS test_jsonl; \ No newline at end of file diff --git a/test/sql/dml/import/test_varchar.slt b/test/sql/dml/import/test_varchar.slt index 83c772b813..de521843e9 100644 --- a/test/sql/dml/import/test_varchar.slt +++ b/test/sql/dml/import/test_varchar.slt @@ -23,6 +23,12 @@ SELECT c1, c2 FROM test_varchar_type; 3 hello world 4 hello hello hello hello hello + +query III +SELECT count(*) FROM test_varchar_type; +---- +4 + # Clean up statement ok DROP TABLE test_varchar_type; diff --git a/test/sql/dml/insert.slt b/test/sql/dml/insert.slt index 632a82a328..91c7b42438 100644 --- a/test/sql/dml/insert.slt +++ b/test/sql/dml/insert.slt @@ -19,6 +19,11 @@ SELECT * FROM products; ---- 1 2 a +query II +SELECT count(*) FROM products; +---- +1 + query I INSERT INTO products VALUES (3, 4, 'abcdef'), (5, 6, 'abcdefghijklmnopqrstuvwxyz'); ---- @@ -30,6 +35,11 @@ SELECT * FROM products; 3 4 abcdef 5 6 abcdefghijklmnopqrstuvwxyz +query II +SELECT count(*) FROM products; +---- 
+3 + # Clean up statement ok DROP TABLE products; diff --git a/test/sql/dml/update.slt b/test/sql/dml/update.slt index dfffbd1f96..26d41ae1b5 100644 --- a/test/sql/dml/update.slt +++ b/test/sql/dml/update.slt @@ -23,6 +23,11 @@ SELECT * FROM products; 5 6 7 8 +query II +SELECT count(*) FROM products; +---- +4 + statement ok UPDATE products SET price=100 WHERE product_no = 3; @@ -34,6 +39,12 @@ SELECT * FROM products; 5 6 7 8 +query II +SELECT count(*) FROM products; +---- +4 + + statement ok UPDATE products SET price=price+3 WHERE product_no = 1 OR product_no = 5; diff --git a/test/sql/dql/aggregate/test_simple_agg.slt b/test/sql/dql/aggregate/test_simple_agg.slt index 91d6c79730..e3f5a7af2c 100644 --- a/test/sql/dql/aggregate/test_simple_agg.slt +++ b/test/sql/dql/aggregate/test_simple_agg.slt @@ -54,6 +54,56 @@ SELECT COUNT(c1) FROM simple_agg ---- 3 +query III +SELECT SUM(c1)+SUM(c1) FROM simple_agg; +---- +12 + +query III +SELECT MAX(c1)+SUM(c1) FROM simple_agg; +---- +9 + +query III +SELECT MAX(c1)+SUM(c2) FROM simple_agg; +---- +9.000000 + +query III +SELECT MAX(c1)*SUM(c2) FROM simple_agg; +---- +18.000000 + +query III +SELECT MAX(c1)*SUM(c2) FROM simple_agg; +---- +18.000000 + +query III +SELECT MAX(c1)-SUM(c2) FROM simple_agg; +---- +-3.000000 + +query III +SELECT MAX(c1)/SUM(c2) FROM simple_agg; +---- +0.500000 + +query III +SELECT MAX(c1)/AVG(c2) FROM simple_agg; +---- +1.500000 + +query III +SELECT MAX(c1)*AVG(c2) FROM simple_agg; +---- +6.000000 + +query IIII +SELECT COUNT(*) FROM simple_agg; +---- +3 + statement ok DROP TABLE simple_agg; diff --git a/test/sql/dql/rbo_rule/column_pruner.slt b/test/sql/dql/rbo_rule/column_pruner.slt index f6d516a46b..583854f0e5 100644 --- a/test/sql/dql/rbo_rule/column_pruner.slt +++ b/test/sql/dql/rbo_rule/column_pruner.slt @@ -112,24 +112,25 @@ EXPLAIN LOGICAL SELECT MIN(c1 + 1), AVG(c2) FROM t1; ---- PROJECT (4) - table index: #4 - - expressions: [min((c1 + 1)) (#0), avg(c2) (#1)] + - expressions: [min((c1 + 1)) 
import numpy as np
import random
import os
import argparse


def _parse_bool(text: str) -> bool:
    """Convert a command-line word such as "true" or "0" into a bool.

    argparse's ``type=bool`` calls ``bool()`` on the raw string, so every
    non-empty value, including "False", becomes True. This parser maps the usual
    spellings explicitly and rejects anything else.
    """
    lowered = text.strip().lower()
    if lowered in ("1", "true", "t", "yes", "y", "on"):
        return True
    if lowered in ("0", "false", "f", "no", "n", "off", ""):
        return False
    raise argparse.ArgumentTypeError(
        "expected a boolean value, got {!r}".format(text)
    )


def _write_query(slt_file, sql: str, expected: str) -> None:
    """Append one ``query I`` record (blank line, header, SQL, separator, result)."""
    slt_file.write("\n")
    slt_file.write("query I\n")
    slt_file.write(sql + "\n")
    slt_file.write("----\n")
    slt_file.write(expected + "\n")


def generate(generate_if_exists: bool, copy_dir: str, row_n: int = 9000):
    """Write the CSV data file and the matching .slt script for the big aggregate test.

    The CSV holds ``row_n`` rows ``i,i`` for i = 1..row_n. The .slt script creates
    the table, copies the CSV in, and checks max/min/sum/avg/count over ``c1``
    plus ``count(*)`` before and after the load.

    Args:
        generate_if_exists: When True and both output files already exist, nothing
            is written and the function returns early.
            NOTE(review): the name suggests the opposite meaning; the original
            behaviour is kept unchanged here.
        copy_dir: Directory prefix used for the CSV path inside the COPY statement.
        row_n: Number of rows to generate (default 9000, the original fixed size).

    Raises:
        ValueError: If ``row_n`` is smaller than 1.
    """
    if row_n < 1:
        raise ValueError("row_n must be at least 1")

    sort_dir = "./test/data/csv"
    slt_dir = "./test/sql/dql/aggregate"

    table_name = "test_simple_agg_big_cpp"
    agg_path = sort_dir + "/test_simple_agg_big.csv"
    slt_path = slt_dir + "/test_simple_agg_big.slt"
    copy_path = copy_dir + "/test_simple_agg_big.csv"

    os.makedirs(sort_dir, exist_ok=True)
    os.makedirs(slt_dir, exist_ok=True)
    if os.path.exists(agg_path) and os.path.exists(slt_path) and generate_if_exists:
        print(
            "File {} and {} already exist. Skip generating.".format(
                slt_path, agg_path
            )
        )
        return

    sequence = np.arange(1, row_n + 1)

    with open(agg_path, "w") as agg_file, open(slt_path, "w") as slt_file:
        slt_file.write("statement ok\n")
        slt_file.write("DROP TABLE IF EXISTS {};\n".format(table_name))
        slt_file.write("\n")
        slt_file.write("statement ok\n")
        slt_file.write(
            "CREATE TABLE {} (c1 int, c2 float);\n".format(table_name)
        )

        # count(*) on the still-empty table
        _write_query(slt_file, "SELECT count(*) FROM {};".format(table_name), "0")

        # The COPY record is emitted without a result section, as before.
        slt_file.write("\n")
        slt_file.write("query I\n")
        slt_file.write(
            "COPY {} FROM '{}' WITH ( DELIMITER ',' );\n".format(
                table_name, copy_path
            )
        )

        # Both columns carry the same value on every row.
        for i in sequence:
            agg_file.write("{0},{0}\n".format(i))

        slt_file.write("\n")
        slt_file.write("statement ok\n")
        slt_file.write("SELECT c1 FROM {};\n".format(table_name))
        slt_file.write("\n")

        # (query template, expected single-cell result), in output order.
        expectations = [
            # max(c1)
            ("SELECT max(c1) FROM {};", str(row_n)),
            # min(c1)
            ("SELECT min(c1) FROM {};", "1"),
            # sum(c1)
            ("SELECT sum(c1) FROM {};", str(np.sum(sequence))),
            # avg(c1), printed with six decimals
            ("SELECT AVG(c1) FROM {};", "{:.6f}".format(np.mean(sequence))),
            # count(c1)
            ("SELECT count(c1) FROM {};", str(row_n)),
            # count(*) after the load
            ("SELECT count(*) FROM {};", str(row_n)),
        ]
        for sql_template, expected in expectations:
            _write_query(slt_file, sql_template.format(table_name), expected)

        slt_file.write("\n")
        slt_file.write("statement ok\n")
        slt_file.write("DROP TABLE {};\n".format(table_name))


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Generate aggregate data for test")

    parser.add_argument(
        "-g",
        "--generate",
        type=_parse_bool,
        default=False,
        dest="generate_if_exists",
    )
    parser.add_argument(
        "-c",
        "--copy",
        type=str,
        default="/tmp/infinity/test_data",
        dest="copy_dir",
    )
    args = parser.parse_args()
    generate(args.generate_if_exists, args.copy_dir)
generate5(args.generate_if_exists, args.copy) print("Generate file finshed.") python_skd_test(python_test_dir) test_process(args.path, args.test, args.data, args.copy)