Skip to content

Commit

Permalink
Update bool type to store values as compact bits
Browse files Browse the repository at this point in the history
  • Loading branch information
yangzq50 committed Jan 3, 2024
1 parent deaac04 commit db25f1a
Show file tree
Hide file tree
Showing 28 changed files with 1,302 additions and 219 deletions.
58 changes: 27 additions & 31 deletions src/executor/expression/expression_selector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,51 +76,47 @@ void ExpressionSelector::Select(const SharedPtr<BaseExpression> &expr,
SizeT count,
SharedPtr<Selection> &output_true_select) {
SharedPtr<ColumnVector> bool_column = MakeShared<ColumnVector>(MakeShared<DataType>(LogicalType::kBoolean));
bool_column->Initialize();
bool_column->Initialize(ColumnVectorType::kCompactBit);

ExpressionEvaluator expr_evaluator;
expr_evaluator.Init(input_data_);
expr_evaluator.Execute(expr, state, bool_column);

const auto *bool_column_ptr = (const u8 *)(bool_column->data());
SharedPtr<Bitmask> &null_mask = bool_column->nulls_ptr_;

Select(bool_column_ptr, null_mask, count, output_true_select, true);
Select(bool_column, count, output_true_select, true);
}

void ExpressionSelector::Select(const u8 *__restrict bool_column,
const SharedPtr<Bitmask> &null_mask,
SizeT count,
SharedPtr<Selection> &output_true_select,
bool nullable) {
void ExpressionSelector::Select(const SharedPtr<ColumnVector> &bool_column, SizeT count, SharedPtr<Selection> &output_true_select, bool nullable) {
if (bool_column->vector_type() != ColumnVectorType::kCompactBit || bool_column->data_type()->type() != LogicalType::kBoolean) {
Error<ExecutorException>("Attempting to select non-boolean expression");
}
const auto &boolean_buffer = *(bool_column->buffer_);
const auto &null_mask = bool_column->nulls_ptr_;
if (nullable && !(null_mask->IsAllTrue())) {
const u64 *result_null_data = null_mask->GetData();
SizeT unit_count = BitmaskBuffer::UnitCount(count);
for (SizeT i = 0, start_index = 0, end_index = BitmaskBuffer::UNIT_BITS; i < unit_count; ++i, end_index += BitmaskBuffer::UNIT_BITS) {
end_index = Min(end_index, count);
if (result_null_data[i] == BitmaskBuffer::UNIT_MAX) {
// all data of 64 rows are not null
while (start_index < end_index) {
if (bool_column[start_index] > 0) {
output_true_select->Append(start_index);
}
++start_index;
const u64 *result_null_data = null_mask->GetData();
SizeT unit_count = BitmaskBuffer::UnitCount(count);
for (SizeT i = 0, start_index = 0, end_index = BitmaskBuffer::UNIT_BITS; i < unit_count; ++i, end_index += BitmaskBuffer::UNIT_BITS) {
end_index = Min(end_index, count);
if (result_null_data[i] == BitmaskBuffer::UNIT_MAX) {
// all data of 64 rows are not null
for (; start_index < end_index; ++start_index) {
if (boolean_buffer.GetCompactBit(start_index)) {
output_true_select->Append(start_index);
}
} else if (result_null_data[i] == BitmaskBuffer::UNIT_MIN) {
// all data of 64 rows are null
start_index = end_index;
} else {
while (start_index < end_index) {
if ((null_mask->IsTrue(start_index)) && (bool_column[start_index] > 0)) {
output_true_select->Append(start_index);
}
++start_index;
}
} else if (result_null_data[i] == BitmaskBuffer::UNIT_MIN) {
// all data of 64 rows are null
start_index = end_index;
} else {
for (; start_index < end_index; ++start_index) {
if ((null_mask->IsTrue(start_index)) && boolean_buffer.GetCompactBit(start_index)) {
output_true_select->Append(start_index);
}
}
}
}
} else {
for (SizeT idx = 0; idx < count; ++idx) {
if (bool_column[idx] > 0) {
if (boolean_buffer.GetCompactBit(idx)) {
output_true_select->Append(idx);
}
}
Expand Down
7 changes: 2 additions & 5 deletions src/executor/expression/expression_selector.cppm
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ class ExpressionState;
class DataBlock;
class Selection;
class Bitmask;
class ColumnVector;

export class ExpressionSelector {
public:
Expand All @@ -43,11 +44,7 @@ public:

void Select(const SharedPtr<BaseExpression> &expr, SharedPtr<ExpressionState> &state, SizeT count, SharedPtr<Selection> &output_true_select);

static void Select(const u8 *__restrict bool_column,
const SharedPtr<Bitmask> &null_mask,
SizeT count,
SharedPtr<Selection> &output_true_select,
bool nullable);
static void Select(const SharedPtr<ColumnVector> &bool_column, SizeT count, SharedPtr<Selection> &output_true_select, bool nullable);

private:
const DataBlock *input_data_{nullptr};
Expand Down
4 changes: 3 additions & 1 deletion src/executor/expression/expression_state.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,9 @@ SharedPtr<ExpressionState> ExpressionState::CreateState(const SharedPtr<Function
if (result_is_constant) {
result->column_vector_->Initialize(ColumnVectorType::kConstant, DEFAULT_VECTOR_SIZE);
} else {
result->column_vector_->Initialize(ColumnVectorType::kFlat, DEFAULT_VECTOR_SIZE);
auto column_vector_type =
(function_expr_data_type->type() == LogicalType::kBoolean) ? ColumnVectorType::kCompactBit : ColumnVectorType::kFlat;
result->column_vector_->Initialize(column_vector_type, DEFAULT_VECTOR_SIZE);
}

// result->output_data_block_.Init({function_expr->Type()});
Expand Down
21 changes: 11 additions & 10 deletions src/executor/operator/physical_knn_scan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import base_table_ref;
import block_entry;
import knn_scan_data;
import column_buffer;
import vector_buffer;
import block_column_entry;
import knn_distance;
import third_party;
Expand Down Expand Up @@ -92,15 +93,15 @@ void ReadDataBlock(DataBlock *output,
output->Finalize();
}

void MergeIntoBitmask(const u8 *__restrict input_bool_column,
void MergeIntoBitmask(const VectorBuffer *input_bool_column_buffer,
const SharedPtr<Bitmask> &input_null_mask,
const SizeT count,
Bitmask &bitmask,
bool nullable,
SizeT bitmask_offset = 0) {
if ((!nullable) || (input_null_mask->IsAllTrue())) {
for (SizeT idx = 0; idx < count; ++idx) {
if (input_bool_column[idx] == 0) {
if (!(input_bool_column_buffer->GetCompactBit(idx))) {
bitmask.SetFalse(idx + bitmask_offset);
}
}
Expand All @@ -115,7 +116,7 @@ void MergeIntoBitmask(const u8 *__restrict input_bool_column,
if (result_null_data[i] == BitmaskBuffer::UNIT_MAX) {
// all data of 64 rows are not null
for (; start_index < end_index; ++start_index) {
if (input_bool_column[start_index] == 0) {
if (!(input_bool_column_buffer->GetCompactBit(start_index))) {
bitmask.SetFalse(start_index + bitmask_offset);
}
}
Expand All @@ -134,7 +135,7 @@ void MergeIntoBitmask(const u8 *__restrict input_bool_column,
}
} else {
for (; start_index < end_index; ++start_index) {
if (!(input_null_mask->IsTrue(start_index)) || (input_bool_column[start_index] == 0)) {
if (!(input_null_mask->IsTrue(start_index)) || !(input_bool_column_buffer->GetCompactBit(start_index))) {
bitmask.SetFalse(start_index + bitmask_offset);
}
}
Expand Down Expand Up @@ -265,13 +266,13 @@ void PhysicalKnnScan::ExecuteInternal(QueryContext *query_context, KnnScanOperat
// filter and build bitmask, if filter_expression_ != nullptr
db_for_filter->Reset(row_count);
ReadDataBlock(db_for_filter, buffer_mgr, row_count, block_entry, base_table_ref_->column_ids_);
bool_column->Initialize(ColumnVectorType::kFlat, row_count);
bool_column->Initialize(ColumnVectorType::kCompactBit, row_count);
ExpressionEvaluator expr_evaluator;
expr_evaluator.Init(db_for_filter);
expr_evaluator.Execute(filter_expression_, filter_state_, bool_column);
const auto *bool_column_ptr = (const u8 *)(bool_column->data());
const VectorBuffer *bool_column_buffer = bool_column->buffer_.get();
SharedPtr<Bitmask> &null_mask = bool_column->nulls_ptr_;
MergeIntoBitmask(bool_column_ptr, null_mask, row_count, bitmask, true);
MergeIntoBitmask(bool_column_buffer, null_mask, row_count, bitmask, true);
bool_column->Reset();
}

Expand Down Expand Up @@ -314,12 +315,12 @@ void PhysicalKnnScan::ExecuteInternal(QueryContext *query_context, KnnScanOperat
auto row_count = block_entry->row_count_;
db_for_filter->Reset(row_count);
ReadDataBlock(db_for_filter, buffer_mgr, row_count, block_entry.get(), base_table_ref_->column_ids_);
bool_column->Initialize(ColumnVectorType::kFlat, row_count);
bool_column->Initialize(ColumnVectorType::kCompactBit, row_count);
expr_evaluator.Init(db_for_filter);
expr_evaluator.Execute(filter_expression_, filter_state_, bool_column);
const auto *bool_column_ptr = (const u8 *)(bool_column->data());
const VectorBuffer *bool_column_buffer = bool_column->buffer_.get();
SharedPtr<Bitmask> &null_mask = bool_column->nulls_ptr_;
MergeIntoBitmask(bool_column_ptr, null_mask, row_count, bitmask, true, segment_row_count_real);
MergeIntoBitmask(bool_column_buffer, null_mask, row_count, bitmask, true, segment_row_count_real);
segment_row_count_real += row_count;
bool_column->Reset();
}
Expand Down
11 changes: 10 additions & 1 deletion src/executor/operator/physical_sort.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,16 @@ class Comparator {

switch (type.type()) {
case kBoolean: {
COMPARE(BooleanT)
auto bool_left = left_result_vector->buffer_->GetCompactBit(left_index.offset);
auto bool_right = right_result_vector->buffer_->GetCompactBit(left_index.offset);
if (bool_left == bool_right) {
continue;
}
if (order_type == OrderType::kAsc) {
return bool_left < bool_right;
} else {
return bool_left > bool_right;
}
}
case kTinyInt: {
COMPARE(TinyIntT)
Expand Down
4 changes: 3 additions & 1 deletion src/executor/physical_operator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,9 @@ void PhysicalOperator::InputLoad(QueryContext *query_context, OperatorState *ope
// Filling ColumnVector
for (SizeT j = 0; j < load_column_count; ++j) {
SharedPtr<ColumnVector> column_vector = ColumnVector::Make(load_metas[j].type_);
column_vector->Initialize(ColumnVectorType::kFlat, capacity);
auto column_vector_type =
(load_metas[j].type_->type() == LogicalType::kBoolean) ? ColumnVectorType::kCompactBit : ColumnVectorType::kFlat;
column_vector->Initialize(column_vector_type, capacity);

input_block->InsertVector(column_vector, load_metas[j].index_);
}
Expand Down
28 changes: 26 additions & 2 deletions src/function/aggregate_function.cppm
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@
// limitations under the License.

module;

#include <type_traits>
import stl;
import function;
import function_data;
import column_vector;

import vector_buffer;
import infinity_exception;
import base_expression;
import parser;
Expand All @@ -43,6 +43,21 @@ public:
// Loop execute state update according to the input column vector

switch (input_column_vector->vector_type()) {
case ColumnVectorType::kCompactBit: {
if constexpr (!std::is_same_v<InputType, BooleanT>) {
Error<TypeException>("kCompactBit column vector only support Boolean type");
} else {
// only for count, min, max
SizeT row_count = input_column_vector->Size();
BooleanT value;
const VectorBuffer *buffer = input_column_vector->buffer_.get();
for (SizeT idx = 0; idx < row_count; ++idx) {
value = buffer->GetCompactBit(idx);
((AggregateState *)state)->Update(&value, 0);
}
}
break;
}
case ColumnVectorType::kFlat: {
SizeT row_count = input_column_vector->Size();
auto *input_ptr = (InputType *)(input_column_vector->data());
Expand All @@ -52,6 +67,15 @@ public:
break;
}
case ColumnVectorType::kConstant: {
if (input_column_vector->data_type()->type() == LogicalType::kBoolean) {
if constexpr (!std::is_same_v<InputType, BooleanT>) {
Error<TypeException>("types do not match");
} else {
BooleanT value = input_column_vector->buffer_->GetCompactBit(0);
((AggregateState *)state)->Update(&value, 0);
}
break;
}
auto *input_ptr = (InputType *)(input_column_vector->data());
((AggregateState *)state)->Update(input_ptr, 0);
break;
Expand Down
11 changes: 10 additions & 1 deletion src/function/scalar/and.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

module;

#include <type_traits>
import stl;
import new_catalog;

Expand All @@ -30,7 +31,15 @@ namespace infinity {
struct AndFunction {
template <typename TA, typename TB, typename TC>
static inline void Run(TA left, TB right, TC &result) {
result = left and right;
if constexpr (std::is_same_v<std::remove_cv_t<TA>, u8> && std::is_same_v<std::remove_cv_t<TB>, u8> &&
std::is_same_v<std::remove_cv_t<TC>, u8>) {
result = left & right;
} else if constexpr (std::is_same_v<std::remove_cv_t<TA>, BooleanT> && std::is_same_v<std::remove_cv_t<TB>, BooleanT> &&
std::is_same_v<std::remove_cv_t<TC>, BooleanT>) {
result = left and right;
} else {
Error<TypeException>("AND function accepts only u8 and BooleanT.");
}
}
};

Expand Down
9 changes: 7 additions & 2 deletions src/function/scalar/equals.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
// limitations under the License.

module;

#include <type_traits>
import stl;
import new_catalog;

Expand All @@ -30,7 +30,12 @@ namespace infinity {
struct EqualsFunction {
template <typename TA, typename TB, typename TC>
static inline void Run(TA left, TB right, TC &result) {
result = (left == right);
if constexpr (std::is_same_v<std::remove_cv_t<TA>, u8> && std::is_same_v<std::remove_cv_t<TB>, u8> &&
std::is_same_v<std::remove_cv_t<TC>, u8>) {
result = ~(left ^ right);
} else {
result = (left == right);
}
}
};

Expand Down
9 changes: 7 additions & 2 deletions src/function/scalar/inequals.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
// limitations under the License.

module;

#include <cmath>
#include <type_traits>

import stl;
import new_catalog;
Expand All @@ -32,7 +32,12 @@ namespace infinity {
struct InEqualsFunction {
template <typename TA, typename TB, typename TC>
static inline void Run(TA left, TB right, TC &result) {
result = (left != right);
if constexpr (std::is_same_v<std::remove_cv_t<TA>, u8> && std::is_same_v<std::remove_cv_t<TB>, u8> &&
std::is_same_v<std::remove_cv_t<TC>, u8>) {
result = (left ^ right);
} else {
result = (left != right);
}
}
};

Expand Down
9 changes: 8 additions & 1 deletion src/function/scalar/not.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

module;

#include <type_traits>
import stl;
import new_catalog;

Expand All @@ -30,7 +31,13 @@ namespace infinity {
struct NotFunction {
template <typename TA, typename TB>
static inline void Run(TA input, TB &result) {
result = !input;
if constexpr (std::is_same_v<std::remove_cv_t<TA>, u8> && std::is_same_v<std::remove_cv_t<TB>, u8>) {
result = ~input;
} else if constexpr (std::is_same_v<std::remove_cv_t<TA>, BooleanT> && std::is_same_v<std::remove_cv_t<TB>, BooleanT>) {
result = !input;
} else {
Error<TypeException>("NOT function accepts only u8 and BooleanT.");
}
}
};

Expand Down
13 changes: 11 additions & 2 deletions src/function/scalar/or.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,15 @@

module;

#include <type_traits>
import stl;
import new_catalog;

import infinity_exception;
import scalar_function;
import scalar_function_set;
import parser;
import third_party;
// import third_party;

module or_func;

Expand All @@ -30,7 +31,15 @@ namespace infinity {
struct OrFunction {
template <typename TA, typename TB, typename TC>
static inline void Run(TA left, TB right, TC &result) {
result = left or right;
if constexpr (std::is_same_v<std::remove_cv_t<TA>, u8> && std::is_same_v<std::remove_cv_t<TB>, u8> &&
std::is_same_v<std::remove_cv_t<TC>, u8>) {
result = left | right;
} else if constexpr (std::is_same_v<std::remove_cv_t<TA>, BooleanT> && std::is_same_v<std::remove_cv_t<TB>, BooleanT> &&
std::is_same_v<std::remove_cv_t<TC>, BooleanT>) {
result = left or right;
} else {
Error<TypeException>("OR function accepts only u8 and BooleanT.");
}
}
};

Expand Down
Loading

0 comments on commit db25f1a

Please sign in to comment.