Skip to content

Commit

Permalink
For avg functions replacing it with (sum / count), count(*) (#399)
Browse files Browse the repository at this point in the history
* Fix top fragment source's type

* Implement PhysicalMergeAggregate operator execution

This commit introduces the execution functionality in the PhysicalMergeAggregate operator. It provides data procedures for various aggregate functions like COUNT, MIN, MAX, and SUM across different DataTypes. Additionally, proper handling for input and Output dataBlocks has been implemented. Consequently, various changes to the TaskScheduler, code refactoring and logging adjustments were also done.

* Update PhysicalMergeAggregate operator execution logic

The PhysicalMergeAggregate operator execution logic has been updated to correctly handle COUNT, MIN, MAX, and SUM aggregate functions for different DataTypes. This commit also implements appropriate handling for input and output dataBlocks. Concurrently, various changes are performed in the TaskScheduler. Several code sections were refactored for better readability and logs have been properly adjusted for detailed tracing.

* Remove unused functions in physical_merge_aggregate files

Deleted several unused template functions in `physical_merge_aggregate.cppm` and `physical_merge_aggregate.cpp`. Also removed a sizable chunk of unused logic related to execution of different aggregate functions from `PhysicalMergeAggregate::SimpleMergeAggregateExecute()` method. This declutters the codebase, making it easier to read and understand.

* Build new func expr

* Refactor merge aggregate and expression binder code

The PhysicalMergeAggregate class was refactored by introducing a new method, SimpleMergeAggregateExecute, to encapsulate some of the implementation details. The newly defined method simplifies the complexity of the code and makes the execute method easier to understand. ExpressionBinder class was also modified to improve clarity of the division function expression building process, simplifying it and making code more maintainable. Minor tweaks were also made to the fragment_context.cpp file to further enhance readability.

* For avg functions replacing it with (sum / count)

The commit involves converting avg function expressions to (sum / count) function expressions in the SQL planner and test updates accordingly. Tests cover both the regular and exceptional cases. Corresponding changes are reflected in other parts of the code like ProjectBinder and Aggregate operators.

* Fix memory leak

* Fix typo

* Refactor AVG function conversion in project binder

The conversion of AVG function to SUM / COUNT has been refactored for efficiency and readability. A helper function, ConvertAvgToSumDivideCount, has been introduced to modularize this logic. The updated code significantly shrinks the BuildExpression method, improving its maintainability.

* Update division functions and test cases

Modified the div functions to return results as Double type instead of previous specific integer types. Additionally, adjusted relevant test cases to match the changes. This allows for more accurate division results across the codebase.

* Implement count(*) functionality for SQL tables

A "count(*)" functionality was implemented, enabling users to count the number of rows of SQL tables. Several test scripts were modified to include a "SELECT count(*)" command immediately after table operations to verify the row count. The 'Update' and 'Delete' operations were adjusted in 'table_collection_entry.cpp', and 'StateUpdate' was modified in 'aggregate_function.cppm'. New 'CountStar' related files were added to implement the new functionality.
  • Loading branch information
loloxwg authored Jan 3, 2024
1 parent deaac04 commit 597cb05
Show file tree
Hide file tree
Showing 32 changed files with 806 additions and 126 deletions.
11 changes: 6 additions & 5 deletions src/executor/expression/expression_evaluator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,10 +79,10 @@ void ExpressionEvaluator::Execute(const SharedPtr<AggregateExpression> &expr,
// Create output chunk.
// TODO: Now output chunk is pre-allocate memory in expression state
// TODO: In the future, it can be implemented as on-demand allocation.
SharedPtr<ColumnVector> &child_output = child_state->OutputColumnVector();
Execute(child_expr, child_state, child_output);
SharedPtr<ColumnVector> &child_output_col = child_state->OutputColumnVector();
this->Execute(child_expr, child_state, child_output_col);

if (expr->aggregate_function_.argument_type_ != *child_output->data_type()) {
if (expr->aggregate_function_.argument_type_ != *child_output_col->data_type()) {
Error<ExecutorException>("Argument type isn't matched with the child expression output");
}
if (expr->aggregate_function_.return_type_ != *output_column_vector->data_type()) {
Expand All @@ -93,7 +93,7 @@ void ExpressionEvaluator::Execute(const SharedPtr<AggregateExpression> &expr,
expr->aggregate_function_.init_func_(expr->aggregate_function_.GetState());

// 2. Loop to fill the aggregate state
expr->aggregate_function_.update_func_(expr->aggregate_function_.GetState(), child_output);
expr->aggregate_function_.update_func_(expr->aggregate_function_.GetState(), child_output_col);

// 3. Get the aggregate result and append to output column vector.

Expand Down Expand Up @@ -157,7 +157,8 @@ void ExpressionEvaluator::Execute(const SharedPtr<ValueExpression> &expr,
SharedPtr<ExpressionState> &,
SharedPtr<ColumnVector> &output_column_vector) {
// memory copy here.
output_column_vector->SetValue(0, expr->GetValue());
auto value = expr->GetValue();
output_column_vector->SetValue(0, value);
output_column_vector->Finalize(1);
}

Expand Down
2 changes: 1 addition & 1 deletion src/executor/fragment_builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ UniquePtr<PlanFragment> FragmentBuilder::BuildFragment(PhysicalOperator *phys_op
auto plan_fragment = MakeUnique<PlanFragment>(GetFragmentId());
plan_fragment->SetSinkNode(query_context_ptr_, SinkType::kResult, phys_op->GetOutputNames(), phys_op->GetOutputTypes());
BuildFragments(phys_op, plan_fragment.get());
if (plan_fragment->GetSourceNode() != nullptr) {
if (plan_fragment->GetSourceNode() == nullptr) {
plan_fragment->SetSourceNode(query_context_ptr_, SourceType::kEmpty, phys_op->GetOutputNames(), phys_op->GetOutputTypes());
}
return plan_fragment;
Expand Down
86 changes: 30 additions & 56 deletions src/executor/operator/physical_aggregate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
module;

#include <string>
#include <vector>
import stl;
import txn;
import query_context;
Expand Down Expand Up @@ -49,14 +50,15 @@ bool PhysicalAggregate::Execute(QueryContext *query_context, OperatorState *oper
Vector<SharedPtr<ColumnDef>> groupby_columns;
SizeT group_count = groups_.size();

if(group_count == 0) {
if (group_count == 0) {
// Aggregate without group by expression
// e.g. SELECT count(a) FROM table;
if (SimpleAggregate(this->output_, prev_op_state, aggregate_operator_state)) {
return true;
} else {
return false;
auto result = SimpleAggregateExecute(prev_op_state->data_block_array_, aggregate_operator_state->data_block_array_);
prev_op_state->data_block_array_.clear();
if (prev_op_state->Complete()) {
aggregate_operator_state->SetComplete();
}
return result;
}
#if 0
groupby_columns.reserve(group_count);
Expand Down Expand Up @@ -404,17 +406,19 @@ void PhysicalAggregate::GenerateGroupByResult(const SharedPtr<DataTable> &input_
}
case kVarchar: {
Error<NotImplementException>("Varchar data shuffle isn't implemented.");
// VarcharT &dst_ref = ((VarcharT *)(output_datablock->column_vectors[column_id]->data()))[block_row_idx];
// VarcharT &src_ref = ((VarcharT *)(input_datablocks[input_block_id]->column_vectors[column_id]->data()))[input_offset];
// if (src_ref.IsInlined()) {
// Memcpy((char *)&dst_ref, (char *)&src_ref, sizeof(VarcharT));
// } else {
// dst_ref.length = src_ref.length;
// Memcpy(dst_ref.prefix, src_ref.prefix, VarcharT::PREFIX_LENGTH);
//
// dst_ref.ptr = output_datablock->column_vectors[column_id]->buffer_->fix_heap_mgr_->Allocate(src_ref.length);
// Memcpy(dst_ref.ptr, src_ref.ptr, src_ref.length);
// }
// VarcharT &dst_ref = ((VarcharT *)(output_datablock->column_vectors[column_id]->data()))[block_row_idx];
// VarcharT &src_ref = ((VarcharT
// *)(input_datablocks[input_block_id]->column_vectors[column_id]->data()))[input_offset]; if
// (src_ref.IsInlined()) {
// Memcpy((char *)&dst_ref, (char *)&src_ref, sizeof(VarcharT));
// } else {
// dst_ref.length = src_ref.length;
// Memcpy(dst_ref.prefix, src_ref.prefix, VarcharT::PREFIX_LENGTH);
//
// dst_ref.ptr =
// output_datablock->column_vectors[column_id]->buffer_->fix_heap_mgr_->Allocate(src_ref.length);
// Memcpy(dst_ref.ptr, src_ref.ptr, src_ref.length);
// }
break;
}
case kDate: {
Expand Down Expand Up @@ -559,9 +563,7 @@ void PhysicalAggregate::GenerateGroupByResult(const SharedPtr<DataTable> &input_
#endif
}

bool PhysicalAggregate::SimpleAggregate(SharedPtr<DataTable> &output_table,
OperatorState *pre_operator_state,
AggregateOperatorState *aggregate_operator_state) {
bool PhysicalAggregate::SimpleAggregateExecute(const Vector<UniquePtr<DataBlock>> &input_blocks, Vector<UniquePtr<DataBlock>> &output_blocks) {
SizeT aggregates_count = aggregates_.size();
if (aggregates_count <= 0) {
Error<ExecutorException>("Simple Aggregate without aggregate expression.");
Expand All @@ -579,19 +581,16 @@ bool PhysicalAggregate::SimpleAggregate(SharedPtr<DataTable> &output_table,
Vector<SharedPtr<DataType>> output_types;
output_types.reserve(aggregates_count);

SizeT input_block_count = pre_operator_state->data_block_array_.size();
SizeT input_block_count = input_blocks.size();

for (i64 idx = 0; auto &expr: aggregates_) {
for (i64 idx = 0; auto &expr : aggregates_) {
// expression state
expr_states.emplace_back(ExpressionState::CreateState(expr));

SharedPtr<DataType> output_type = MakeShared<DataType>(expr->Type());

// column definition
SharedPtr<ColumnDef> col_def = MakeShared<ColumnDef>(idx,
output_type,
expr->Name(),
HashSet<ConstraintType>());
SharedPtr<ColumnDef> col_def = MakeShared<ColumnDef>(idx, output_type, expr->Name(), HashSet<ConstraintType>());
aggregate_columns.emplace_back(col_def);

// for output block
Expand All @@ -605,51 +604,26 @@ bool PhysicalAggregate::SimpleAggregate(SharedPtr<DataTable> &output_table,
LOG_TRACE("No input, no aggregate result");
return true;
}
// Loop blocks

// ExpressionEvaluator evaluator;
// //evaluator.Init(input_table_->data_blocks_);
// for (SizeT expr_idx = 0; expr_idx < aggregates_count; ++expr_idx) {
//
// ExpressionEvaluator evaluator;
// evaluator.Init(aggregates_[])
// Vector<SharedPtr<ColumnVector>> blocks_column;
// blocks_column.emplace_back(output_data_block->column_vectors[expr_idx]);
// evaluator.Execute(aggregates_[expr_idx], expr_states[expr_idx], blocks_column[expr_idx]);
// if(blocks_column[0].get() != output_data_block->column_vectors[expr_idx].get()) {
// // column vector in blocks column might be changed to the column vector from column reference.
// // This check and assignment is to make sure the right column vector are assign to output_data_block
// output_data_block->column_vectors[expr_idx] = blocks_column[0];
// }
// }
//
// output_data_block->Finalize();

for (SizeT block_idx = 0; block_idx < input_block_count; ++block_idx) {
DataBlock *input_data_block = pre_operator_state->data_block_array_[block_idx].get();
DataBlock *input_data_block = input_blocks[block_idx].get();

output_blocks.emplace_back(DataBlock::MakeUniquePtr());

aggregate_operator_state->data_block_array_.emplace_back(DataBlock::MakeUniquePtr());
DataBlock *output_data_block = aggregate_operator_state->data_block_array_.back().get();
DataBlock *output_data_block = output_blocks.back().get();
output_data_block->Init(*GetOutputTypes());

ExpressionEvaluator evaluator;
evaluator.Init(input_data_block);

SizeT expression_count = aggregates_count;
// Prepare the expression states

// calculate every columns value
for (SizeT expr_idx = 0; expr_idx < expression_count; ++expr_idx) {
// Vector<SharedPtr<ColumnVector>> blocks_column;
// blocks_column.emplace_back(output_data_block->column_vectors[expr_idx]);
LOG_TRACE("Physical aggregate Execute");
evaluator.Execute(aggregates_[expr_idx], expr_states[expr_idx], output_data_block->column_vectors[expr_idx]);
}
output_data_block->Finalize();
}

pre_operator_state->data_block_array_.clear();
if (pre_operator_state->Complete()) {
aggregate_operator_state->SetComplete();
}
return true;
}

Expand Down
9 changes: 4 additions & 5 deletions src/executor/operator/physical_aggregate.cppm
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import hash_table;
import base_expression;
import load_meta;
import infinity_exception;
import data_block;

export module physical_aggregate;

Expand All @@ -44,8 +45,8 @@ public:
Vector<SharedPtr<BaseExpression>> aggregates,
u64 aggregate_index,
SharedPtr<Vector<LoadMeta>> load_metas)
: PhysicalOperator(PhysicalOperatorType::kAggregate, Move(left), nullptr, id, load_metas), groups_(Move(groups)), aggregates_(Move(aggregates)),
groupby_index_(groupby_index), aggregate_index_(aggregate_index) {}
: PhysicalOperator(PhysicalOperatorType::kAggregate, Move(left), nullptr, id, load_metas), groups_(Move(groups)),
aggregates_(Move(aggregates)), groupby_index_(groupby_index), aggregate_index_(aggregate_index) {}

~PhysicalAggregate() override = default;

Expand All @@ -66,9 +67,7 @@ public:
Vector<SharedPtr<BaseExpression>> aggregates_{};
HashTable hash_table_;

bool SimpleAggregate(SharedPtr<DataTable> &output_table,
OperatorState *pre_operator_state,
AggregateOperatorState *aggregate_operator_state);
bool SimpleAggregateExecute(const Vector<UniquePtr<DataBlock>> &input_blocks, Vector<UniquePtr<DataBlock>> &output_blocks);

inline u64 GroupTableIndex() const { return groupby_index_; }

Expand Down
147 changes: 146 additions & 1 deletion src/executor/operator/physical_merge_aggregate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,164 @@

module;

#include <string>
#include <vector>
import stl;
import query_context;
import operator_state;
import infinity_exception;
import logger;
import value;
import data_block;
import parser;
import physical_aggregate;
import aggregate_expression;

module physical_merge_aggregate;

namespace infinity {

template <typename T>
using MathOperation = StdFunction<T(T, T)>;

void PhysicalMergeAggregate::Init() {}

bool PhysicalMergeAggregate::Execute(QueryContext *query_context, OperatorState *operator_state) {
Error<NotImplementException>("Not Implement");
LOG_TRACE("PhysicalMergeAggregate::Execute:: mark");
auto merge_aggregate_op_state = static_cast<MergeAggregateOperatorState *>(operator_state);

SimpleMergeAggregateExecute(merge_aggregate_op_state);

if (merge_aggregate_op_state->input_complete_) {

LOG_TRACE("PhysicalMergeAggregate::Input is complete");
for (auto &output_block : merge_aggregate_op_state->data_block_array_) {
output_block->Finalize();
}

merge_aggregate_op_state->SetComplete();
return true;
}

return false;
}

void PhysicalMergeAggregate::SimpleMergeAggregateExecute(MergeAggregateOperatorState *op_state) {
if (op_state->data_block_array_.empty()) {
op_state->data_block_array_.emplace_back(Move(op_state->input_data_block_));
} else {
auto agg_op = dynamic_cast<PhysicalAggregate *>(this->left());
auto aggs_size = agg_op->aggregates_.size();
for (SizeT col_idx = 0; col_idx < aggs_size; ++col_idx) {
auto agg_expression = static_cast<AggregateExpression *>(agg_op->aggregates_[col_idx].get());

auto function_name = agg_expression->aggregate_function_.GetFuncName();

auto func_return_type = agg_expression->aggregate_function_.return_type_;

switch (func_return_type.type()) {
case kInteger: {
HandleAggregateFunction<IntegerT>(function_name, op_state, col_idx);
break;
}
case kBigInt: {
HandleAggregateFunction<BigIntT>(function_name, op_state, col_idx);
break;
}
case kFloat: {
HandleAggregateFunction<FloatT>(function_name, op_state, col_idx);
break;
}
case kDouble: {
HandleAggregateFunction<DoubleT>(function_name, op_state, col_idx);
break;
}
default:
Error<NotImplementException>("input_value_type not Implement");
}
}
}
}

template <typename T>
void PhysicalMergeAggregate::HandleAggregateFunction(const String &function_name, MergeAggregateOperatorState *op_state, SizeT col_idx) {
if (String(function_name) == String("COUNT")) {
LOG_TRACE("COUNT");
HandleCount<T>(op_state, col_idx);
} else if (String(function_name) == String("MIN")) {
LOG_TRACE("MIN");
HandleMin<T>(op_state, col_idx);
} else if (String(function_name) == String("MAX")) {
LOG_TRACE("MAX");
HandleMax<T>(op_state, col_idx);
} else if (String(function_name) == String("SUM")) {
LOG_TRACE("SUM");
HandleSum<T>(op_state, col_idx);
}
}

template <typename T>
void PhysicalMergeAggregate::HandleMin(MergeAggregateOperatorState *op_state, SizeT col_idx) {
MathOperation<T> minOperation = [](T a, T b) -> T { return (a < b) ? a : b; };
UpdateData<T>(op_state, minOperation, col_idx);
}

template <typename T>
void PhysicalMergeAggregate::HandleMax(MergeAggregateOperatorState *op_state, SizeT col_idx) {
MathOperation<T> maxOperation = [](T a, T b) -> T { return (a > b) ? a : b; };
UpdateData<T>(op_state, maxOperation, col_idx);
}

template <typename T>
void PhysicalMergeAggregate::HandleCount(MergeAggregateOperatorState *op_state, SizeT col_idx) {
MathOperation<T> countOperation = [](T a, T b) -> T { return a + b; };
UpdateData<T>(op_state, countOperation, col_idx);
}

template <typename T>
void PhysicalMergeAggregate::HandleSum(MergeAggregateOperatorState *op_state, SizeT col_idx) {
MathOperation<T> sumOperation = [](T a, T b) -> T { return a + b; };
UpdateData<T>(op_state, sumOperation, col_idx);
}

template <typename T>
T PhysicalMergeAggregate::GetInputData(MergeAggregateOperatorState *op_state, SizeT block_index, SizeT col_idx, SizeT row_idx) {
Value value = op_state->input_data_block_->GetValue(col_idx, row_idx);
return value.GetValue<T>();
}

template <typename T>
T PhysicalMergeAggregate::GetOutputData(MergeAggregateOperatorState *op_state, SizeT block_index, SizeT col_idx, SizeT row_idx) {
Value value = op_state->data_block_array_[block_index]->GetValue(col_idx, row_idx);
return value.GetValue<T>();
}

template <typename T>
void PhysicalMergeAggregate::WriteValueAtPosition(MergeAggregateOperatorState *op_state, SizeT block_index, SizeT col_idx, SizeT row_idx, T value) {
op_state->data_block_array_[block_index]->SetValue(col_idx, row_idx, CreateValue(value));
}

template <typename T>
void PhysicalMergeAggregate::UpdateData(MergeAggregateOperatorState *op_state, MathOperation<T> operation, SizeT col_idx) {
T input = GetInputData<T>(op_state, 0, col_idx, 0);
T output = GetOutputData<T>(op_state, 0, col_idx, 0);
T new_value = operation(input, output);
WriteValueAtPosition<T>(op_state, 0, col_idx, 0, new_value);
}

template <typename T>
T PhysicalMergeAggregate::AddData(T a, T b) {
return a + b;
}

template <typename T>
T PhysicalMergeAggregate::MinValue(T a, T b) {
return (a < b) ? a : b;
}

template <typename T>
T PhysicalMergeAggregate::MaxValue(T a, T b) {
return (a > b) ? a : b;
}

} // namespace infinity
Loading

0 comments on commit 597cb05

Please sign in to comment.