Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update bool type to store values as compact bits #406

Merged
merged 1 commit into from
Jan 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 27 additions & 31 deletions src/executor/expression/expression_selector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,51 +76,47 @@ void ExpressionSelector::Select(const SharedPtr<BaseExpression> &expr,
SizeT count,
SharedPtr<Selection> &output_true_select) {
SharedPtr<ColumnVector> bool_column = MakeShared<ColumnVector>(MakeShared<DataType>(LogicalType::kBoolean));
bool_column->Initialize();
bool_column->Initialize(ColumnVectorType::kCompactBit);

ExpressionEvaluator expr_evaluator;
expr_evaluator.Init(input_data_);
expr_evaluator.Execute(expr, state, bool_column);

const auto *bool_column_ptr = (const u8 *)(bool_column->data());
SharedPtr<Bitmask> &null_mask = bool_column->nulls_ptr_;

Select(bool_column_ptr, null_mask, count, output_true_select, true);
Select(bool_column, count, output_true_select, true);
}

void ExpressionSelector::Select(const u8 *__restrict bool_column,
const SharedPtr<Bitmask> &null_mask,
SizeT count,
SharedPtr<Selection> &output_true_select,
bool nullable) {
void ExpressionSelector::Select(const SharedPtr<ColumnVector> &bool_column, SizeT count, SharedPtr<Selection> &output_true_select, bool nullable) {
if (bool_column->vector_type() != ColumnVectorType::kCompactBit || bool_column->data_type()->type() != LogicalType::kBoolean) {
Error<ExecutorException>("Attempting to select non-boolean expression");
}
const auto &boolean_buffer = *(bool_column->buffer_);
const auto &null_mask = bool_column->nulls_ptr_;
if (nullable && !(null_mask->IsAllTrue())) {
const u64 *result_null_data = null_mask->GetData();
SizeT unit_count = BitmaskBuffer::UnitCount(count);
for (SizeT i = 0, start_index = 0, end_index = BitmaskBuffer::UNIT_BITS; i < unit_count; ++i, end_index += BitmaskBuffer::UNIT_BITS) {
end_index = Min(end_index, count);
if (result_null_data[i] == BitmaskBuffer::UNIT_MAX) {
// all data of 64 rows are not null
while (start_index < end_index) {
if (bool_column[start_index] > 0) {
output_true_select->Append(start_index);
}
++start_index;
const u64 *result_null_data = null_mask->GetData();
SizeT unit_count = BitmaskBuffer::UnitCount(count);
for (SizeT i = 0, start_index = 0, end_index = BitmaskBuffer::UNIT_BITS; i < unit_count; ++i, end_index += BitmaskBuffer::UNIT_BITS) {
end_index = Min(end_index, count);
if (result_null_data[i] == BitmaskBuffer::UNIT_MAX) {
// all data of 64 rows are not null
for (; start_index < end_index; ++start_index) {
if (boolean_buffer.GetCompactBit(start_index)) {
output_true_select->Append(start_index);
}
} else if (result_null_data[i] == BitmaskBuffer::UNIT_MIN) {
// all data of 64 rows are null
start_index = end_index;
} else {
while (start_index < end_index) {
if ((null_mask->IsTrue(start_index)) && (bool_column[start_index] > 0)) {
output_true_select->Append(start_index);
}
++start_index;
}
} else if (result_null_data[i] == BitmaskBuffer::UNIT_MIN) {
// all data of 64 rows are null
start_index = end_index;
} else {
for (; start_index < end_index; ++start_index) {
if ((null_mask->IsTrue(start_index)) && boolean_buffer.GetCompactBit(start_index)) {
output_true_select->Append(start_index);
}
}
}
}
} else {
for (SizeT idx = 0; idx < count; ++idx) {
if (bool_column[idx] > 0) {
if (boolean_buffer.GetCompactBit(idx)) {
output_true_select->Append(idx);
}
}
Expand Down
7 changes: 2 additions & 5 deletions src/executor/expression/expression_selector.cppm
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ class ExpressionState;
class DataBlock;
class Selection;
class Bitmask;
class ColumnVector;

export class ExpressionSelector {
public:
Expand All @@ -43,11 +44,7 @@ public:

void Select(const SharedPtr<BaseExpression> &expr, SharedPtr<ExpressionState> &state, SizeT count, SharedPtr<Selection> &output_true_select);

static void Select(const u8 *__restrict bool_column,
const SharedPtr<Bitmask> &null_mask,
SizeT count,
SharedPtr<Selection> &output_true_select,
bool nullable);
static void Select(const SharedPtr<ColumnVector> &bool_column, SizeT count, SharedPtr<Selection> &output_true_select, bool nullable);

private:
const DataBlock *input_data_{nullptr};
Expand Down
4 changes: 3 additions & 1 deletion src/executor/expression/expression_state.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,9 @@ SharedPtr<ExpressionState> ExpressionState::CreateState(const SharedPtr<Function
if (result_is_constant) {
result->column_vector_->Initialize(ColumnVectorType::kConstant, DEFAULT_VECTOR_SIZE);
} else {
result->column_vector_->Initialize(ColumnVectorType::kFlat, DEFAULT_VECTOR_SIZE);
auto column_vector_type =
(function_expr_data_type->type() == LogicalType::kBoolean) ? ColumnVectorType::kCompactBit : ColumnVectorType::kFlat;
result->column_vector_->Initialize(column_vector_type, DEFAULT_VECTOR_SIZE);
}

// result->output_data_block_.Init({function_expr->Type()});
Expand Down
21 changes: 11 additions & 10 deletions src/executor/operator/physical_knn_scan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import base_table_ref;
import block_entry;
import knn_scan_data;
import column_buffer;
import vector_buffer;
import block_column_entry;
import knn_distance;
import third_party;
Expand Down Expand Up @@ -92,15 +93,15 @@ void ReadDataBlock(DataBlock *output,
output->Finalize();
}

void MergeIntoBitmask(const u8 *__restrict input_bool_column,
void MergeIntoBitmask(const VectorBuffer *input_bool_column_buffer,
const SharedPtr<Bitmask> &input_null_mask,
const SizeT count,
Bitmask &bitmask,
bool nullable,
SizeT bitmask_offset = 0) {
if ((!nullable) || (input_null_mask->IsAllTrue())) {
for (SizeT idx = 0; idx < count; ++idx) {
if (input_bool_column[idx] == 0) {
if (!(input_bool_column_buffer->GetCompactBit(idx))) {
bitmask.SetFalse(idx + bitmask_offset);
}
}
Expand All @@ -115,7 +116,7 @@ void MergeIntoBitmask(const u8 *__restrict input_bool_column,
if (result_null_data[i] == BitmaskBuffer::UNIT_MAX) {
// all data of 64 rows are not null
for (; start_index < end_index; ++start_index) {
if (input_bool_column[start_index] == 0) {
if (!(input_bool_column_buffer->GetCompactBit(start_index))) {
bitmask.SetFalse(start_index + bitmask_offset);
}
}
Expand All @@ -134,7 +135,7 @@ void MergeIntoBitmask(const u8 *__restrict input_bool_column,
}
} else {
for (; start_index < end_index; ++start_index) {
if (!(input_null_mask->IsTrue(start_index)) || (input_bool_column[start_index] == 0)) {
if (!(input_null_mask->IsTrue(start_index)) || !(input_bool_column_buffer->GetCompactBit(start_index))) {
bitmask.SetFalse(start_index + bitmask_offset);
}
}
Expand Down Expand Up @@ -265,13 +266,13 @@ void PhysicalKnnScan::ExecuteInternal(QueryContext *query_context, KnnScanOperat
// filter and build bitmask, if filter_expression_ != nullptr
db_for_filter->Reset(row_count);
ReadDataBlock(db_for_filter, buffer_mgr, row_count, block_entry, base_table_ref_->column_ids_);
bool_column->Initialize(ColumnVectorType::kFlat, row_count);
bool_column->Initialize(ColumnVectorType::kCompactBit, row_count);
ExpressionEvaluator expr_evaluator;
expr_evaluator.Init(db_for_filter);
expr_evaluator.Execute(filter_expression_, filter_state_, bool_column);
const auto *bool_column_ptr = (const u8 *)(bool_column->data());
const VectorBuffer *bool_column_buffer = bool_column->buffer_.get();
SharedPtr<Bitmask> &null_mask = bool_column->nulls_ptr_;
MergeIntoBitmask(bool_column_ptr, null_mask, row_count, bitmask, true);
MergeIntoBitmask(bool_column_buffer, null_mask, row_count, bitmask, true);
bool_column->Reset();
}

Expand Down Expand Up @@ -314,12 +315,12 @@ void PhysicalKnnScan::ExecuteInternal(QueryContext *query_context, KnnScanOperat
auto row_count = block_entry->row_count_;
db_for_filter->Reset(row_count);
ReadDataBlock(db_for_filter, buffer_mgr, row_count, block_entry.get(), base_table_ref_->column_ids_);
bool_column->Initialize(ColumnVectorType::kFlat, row_count);
bool_column->Initialize(ColumnVectorType::kCompactBit, row_count);
expr_evaluator.Init(db_for_filter);
expr_evaluator.Execute(filter_expression_, filter_state_, bool_column);
const auto *bool_column_ptr = (const u8 *)(bool_column->data());
const VectorBuffer *bool_column_buffer = bool_column->buffer_.get();
SharedPtr<Bitmask> &null_mask = bool_column->nulls_ptr_;
MergeIntoBitmask(bool_column_ptr, null_mask, row_count, bitmask, true, segment_row_count_real);
MergeIntoBitmask(bool_column_buffer, null_mask, row_count, bitmask, true, segment_row_count_real);
segment_row_count_real += row_count;
bool_column->Reset();
}
Expand Down
11 changes: 10 additions & 1 deletion src/executor/operator/physical_sort.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,16 @@ class Comparator {

switch (type.type()) {
case kBoolean: {
COMPARE(BooleanT)
auto bool_left = left_result_vector->buffer_->GetCompactBit(left_index.offset);
auto bool_right = right_result_vector->buffer_->GetCompactBit(left_index.offset);
if (bool_left == bool_right) {
continue;
}
if (order_type == OrderType::kAsc) {
return bool_left < bool_right;
} else {
return bool_left > bool_right;
}
}
case kTinyInt: {
COMPARE(TinyIntT)
Expand Down
4 changes: 3 additions & 1 deletion src/executor/physical_operator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,9 @@ void PhysicalOperator::InputLoad(QueryContext *query_context, OperatorState *ope
// Filling ColumnVector
for (SizeT j = 0; j < load_column_count; ++j) {
SharedPtr<ColumnVector> column_vector = ColumnVector::Make(load_metas[j].type_);
column_vector->Initialize(ColumnVectorType::kFlat, capacity);
auto column_vector_type =
(load_metas[j].type_->type() == LogicalType::kBoolean) ? ColumnVectorType::kCompactBit : ColumnVectorType::kFlat;
column_vector->Initialize(column_vector_type, capacity);

input_block->InsertVector(column_vector, load_metas[j].index_);
}
Expand Down
28 changes: 26 additions & 2 deletions src/function/aggregate_function.cppm
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@
// limitations under the License.

module;

#include <type_traits>
import stl;
import function;
import function_data;
import column_vector;

import vector_buffer;
import infinity_exception;
import base_expression;
import parser;
Expand All @@ -43,6 +43,21 @@ public:
// Loop execute state update according to the input column vector

switch (input_column_vector->vector_type()) {
case ColumnVectorType::kCompactBit: {
if constexpr (!std::is_same_v<InputType, BooleanT>) {
Error<TypeException>("kCompactBit column vector only support Boolean type");
} else {
// only for count, min, max
SizeT row_count = input_column_vector->Size();
BooleanT value;
const VectorBuffer *buffer = input_column_vector->buffer_.get();
for (SizeT idx = 0; idx < row_count; ++idx) {
value = buffer->GetCompactBit(idx);
((AggregateState *)state)->Update(&value, 0);
}
}
break;
}
case ColumnVectorType::kFlat: {
SizeT row_count = input_column_vector->Size();
auto *input_ptr = (InputType *)(input_column_vector->data());
Expand All @@ -52,6 +67,15 @@ public:
break;
}
case ColumnVectorType::kConstant: {
if (input_column_vector->data_type()->type() == LogicalType::kBoolean) {
if constexpr (!std::is_same_v<InputType, BooleanT>) {
Error<TypeException>("types do not match");
} else {
BooleanT value = input_column_vector->buffer_->GetCompactBit(0);
((AggregateState *)state)->Update(&value, 0);
}
break;
}
auto *input_ptr = (InputType *)(input_column_vector->data());
((AggregateState *)state)->Update(input_ptr, 0);
break;
Expand Down
11 changes: 10 additions & 1 deletion src/function/scalar/and.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

module;

#include <type_traits>
import stl;
import new_catalog;

Expand All @@ -30,7 +31,15 @@ namespace infinity {
struct AndFunction {
template <typename TA, typename TB, typename TC>
static inline void Run(TA left, TB right, TC &result) {
result = left and right;
if constexpr (std::is_same_v<std::remove_cv_t<TA>, u8> && std::is_same_v<std::remove_cv_t<TB>, u8> &&
std::is_same_v<std::remove_cv_t<TC>, u8>) {
result = left & right;
} else if constexpr (std::is_same_v<std::remove_cv_t<TA>, BooleanT> && std::is_same_v<std::remove_cv_t<TB>, BooleanT> &&
std::is_same_v<std::remove_cv_t<TC>, BooleanT>) {
result = left and right;
} else {
Error<TypeException>("AND function accepts only u8 and BooleanT.");
}
}
};

Expand Down
9 changes: 7 additions & 2 deletions src/function/scalar/equals.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
// limitations under the License.

module;

#include <type_traits>
import stl;
import new_catalog;

Expand All @@ -30,7 +30,12 @@ namespace infinity {
struct EqualsFunction {
template <typename TA, typename TB, typename TC>
static inline void Run(TA left, TB right, TC &result) {
result = (left == right);
if constexpr (std::is_same_v<std::remove_cv_t<TA>, u8> && std::is_same_v<std::remove_cv_t<TB>, u8> &&
std::is_same_v<std::remove_cv_t<TC>, u8>) {
result = ~(left ^ right);
} else {
result = (left == right);
}
}
};

Expand Down
9 changes: 7 additions & 2 deletions src/function/scalar/inequals.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
// limitations under the License.

module;

#include <cmath>
#include <type_traits>

import stl;
import new_catalog;
Expand All @@ -32,7 +32,12 @@ namespace infinity {
struct InEqualsFunction {
template <typename TA, typename TB, typename TC>
static inline void Run(TA left, TB right, TC &result) {
result = (left != right);
if constexpr (std::is_same_v<std::remove_cv_t<TA>, u8> && std::is_same_v<std::remove_cv_t<TB>, u8> &&
std::is_same_v<std::remove_cv_t<TC>, u8>) {
result = (left ^ right);
} else {
result = (left != right);
}
}
};

Expand Down
9 changes: 8 additions & 1 deletion src/function/scalar/not.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

module;

#include <type_traits>
import stl;
import new_catalog;

Expand All @@ -30,7 +31,13 @@ namespace infinity {
struct NotFunction {
template <typename TA, typename TB>
static inline void Run(TA input, TB &result) {
result = !input;
if constexpr (std::is_same_v<std::remove_cv_t<TA>, u8> && std::is_same_v<std::remove_cv_t<TB>, u8>) {
result = ~input;
} else if constexpr (std::is_same_v<std::remove_cv_t<TA>, BooleanT> && std::is_same_v<std::remove_cv_t<TB>, BooleanT>) {
result = !input;
} else {
Error<TypeException>("NOT function accepts only u8 and BooleanT.");
}
}
};

Expand Down
13 changes: 11 additions & 2 deletions src/function/scalar/or.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,15 @@

module;

#include <type_traits>
import stl;
import new_catalog;

import infinity_exception;
import scalar_function;
import scalar_function_set;
import parser;
import third_party;
// import third_party;

module or_func;

Expand All @@ -30,7 +31,15 @@ namespace infinity {
struct OrFunction {
template <typename TA, typename TB, typename TC>
static inline void Run(TA left, TB right, TC &result) {
result = left or right;
if constexpr (std::is_same_v<std::remove_cv_t<TA>, u8> && std::is_same_v<std::remove_cv_t<TB>, u8> &&
std::is_same_v<std::remove_cv_t<TC>, u8>) {
result = left | right;
} else if constexpr (std::is_same_v<std::remove_cv_t<TA>, BooleanT> && std::is_same_v<std::remove_cv_t<TB>, BooleanT> &&
std::is_same_v<std::remove_cv_t<TC>, BooleanT>) {
result = left or right;
} else {
Error<TypeException>("OR function accepts only u8 and BooleanT.");
}
}
};

Expand Down
Loading
Loading