From f7dbfa26925865234a20e62cd643f1a11a717993 Mon Sep 17 00:00:00 2001 From: shen yushi Date: Wed, 27 Dec 2023 14:09:53 +0800 Subject: [PATCH] Build KNN index in parallel (#386) * Add fragment for parallel create index. * Add concurrent insert hnsw alg. Add simd func for residual dimension embedding. Merge with yangzq. Remove some heap op in hnsw alg. * Feat: Add create_index_prepare/do_finish. * Fix create index source/sink bug. Fix round robin bug. * Add tmp schedule strategy. --- .gitignore | 2 + benchmark/embedding/hnsw_benchmark2.cpp | 79 ++-- .../knn/knn_import_benchmark.cpp | 22 +- .../knn/knn_query_benchmark.cpp | 33 +- src/common/stl.cppm | 10 + src/executor/fragment_builder.cpp | 45 ++ src/executor/fragment_builder.cppm | 2 +- .../operator/physical_create_index_do.cpp | 191 ++++++++ .../operator/physical_create_index_do.cppm | 62 +++ .../operator/physical_create_index_finish.cpp | 57 +++ .../physical_create_index_finish.cppm | 60 +++ .../physical_create_index_prepare.cpp | 200 +++++++++ .../physical_create_index_prepare.cppm | 62 +++ src/executor/operator/physical_knn_scan.cpp | 86 +--- src/executor/operator/physical_sink.cpp | 51 ++- src/executor/operator_state.cpp | 53 ++- src/executor/operator_state.cppm | 27 +- src/executor/physical_operator_type.cpp | 6 + src/executor/physical_operator_type.cppm | 4 + src/executor/physical_planner.cpp | 53 ++- src/function/table/create_index_data.cppm | 36 ++ src/function/table/knn_scan_data.cpp | 18 +- src/function/table/knn_scan_data.cppm | 2 +- src/main/resource_manager.cppm | 2 +- src/planner/bound/base_table_ref.cppm | 7 +- src/planner/explain_logical_plan.cpp | 4 +- src/planner/logical_planner.cpp | 6 +- src/planner/node/logical_create_index.cpp | 3 +- src/planner/node/logical_create_index.cppm | 21 +- src/planner/query_binder.cpp | 19 +- src/planner/query_binder.cppm | 2 + src/scheduler/fragment_context.cpp | 228 ++++++---- src/scheduler/fragment_context.cppm | 12 +- src/scheduler/fragment_data.cppm | 33 +- 
src/scheduler/task_scheduler.cpp | 20 +- .../knnindex/knn_hnsw/dist_func_l2.cppm | 8 +- .../knnindex/knn_hnsw/graph_store.cppm | 12 +- src/storage/knnindex/knn_hnsw/hnsw_alg.cppm | 408 ++++++------------ .../knnindex/knn_hnsw/hnsw_common.cppm | 1 + src/storage/knnindex/knn_hnsw/simd_func.cppm | 60 ++- src/storage/meta/entry/segment_entry.cpp | 16 - src/storage/meta/iter/segment_iter.cppm | 16 + src/storage/txn/txn.cpp | 30 ++ src/storage/txn/txn.cppm | 4 + .../knnindex/knn_hnsw/test_dist_func.cpp | 14 +- src/unit_test/test_hnsw.cpp | 35 +- src/unit_test/test_hnsw_bitmask.cpp | 106 ++--- 47 files changed, 1505 insertions(+), 723 deletions(-) create mode 100644 src/executor/operator/physical_create_index_do.cpp create mode 100644 src/executor/operator/physical_create_index_do.cppm create mode 100644 src/executor/operator/physical_create_index_finish.cpp create mode 100644 src/executor/operator/physical_create_index_finish.cppm create mode 100644 src/executor/operator/physical_create_index_prepare.cpp create mode 100644 src/executor/operator/physical_create_index_prepare.cppm create mode 100644 src/function/table/create_index_data.cppm diff --git a/.gitignore b/.gitignore index defa5f2a45..f6ab52614b 100644 --- a/.gitignore +++ b/.gitignore @@ -76,6 +76,8 @@ parser.output #query_parser.h #query_parser.cpp +test/data/csv/test_sort.csv +test/sql/dql/sort.slt test/data/csv/*big*.csv test/sql/**/*big*.slt test/data/fvecs/ diff --git a/benchmark/embedding/hnsw_benchmark2.cpp b/benchmark/embedding/hnsw_benchmark2.cpp index de299ba891..f6b095b476 100644 --- a/benchmark/embedding/hnsw_benchmark2.cpp +++ b/benchmark/embedding/hnsw_benchmark2.cpp @@ -7,6 +7,7 @@ import stl; import hnsw_alg; +import hnsw_common; import local_file_system; import file_system_type; import file_system; @@ -19,9 +20,13 @@ import compilation_config; using namespace infinity; int main() { - String base_file = String(test_data_path()) + "/benchmark/sift/base.fvecs"; - String query_file = 
String(test_data_path()) + "/benchmark/sift/query.fvecs"; - String groundtruth_file = String(test_data_path()) + "/benchmark/sift/l2_groundtruth.ivecs"; + // String base_file = String(test_data_path()) + "/benchmark/text2image_10m/base.fvecs"; + // String query_file = String(test_data_path()) + "/benchmark/text2image_10m/query.fvecs"; + // String groundtruth_file = String(test_data_path()) + "/benchmark/text2image_10m/groundtruth.ivecs"; + + String base_file = String(test_data_path()) + "/benchmark/sift_1m/sift_base.fvecs"; + String query_file = String(test_data_path()) + "/benchmark/sift_1m/sift_query.fvecs"; + String groundtruth_file = String(test_data_path()) + "/benchmark/sift_1m/sift_groundtruth.ivecs"; LocalFileSystem fs; std::string save_dir = tmp_data_path(); @@ -31,7 +36,8 @@ int main() { size_t ef_construction = 200; size_t embedding_count = 1000000; size_t test_top = 100; - const int thread_n = 1; + const int build_thread_n = 1; + const int query_thread_n = 1; using LabelT = uint64_t; @@ -41,7 +47,7 @@ int main() { using Hnsw = KnnHnsw>, LVQL2Dist>; SizeT init_args = {0}; - std::string save_place = save_dir + "/my_sift_lvq8_l2.hnsw"; + std::string save_place = save_dir + "/my_sift_lvq8_l2_1.hnsw"; // using Hnsw = KnnHnsw, PlainIPDist>; // std::tuple<> init_args = {}; @@ -69,26 +75,36 @@ int main() { std::cout << "Begin memory cost: " << get_current_rss() << "B" << std::endl; profiler.Begin(); - if (false) { - for (size_t idx = 0; idx < embedding_count; ++idx) { - // insert data into index - const float *query = input_embeddings + idx * dimension; - knn_hnsw->Insert(query, idx); - if (idx % 100000 == 0) { - std::cout << idx << ", " << get_current_rss() << " B, " << profiler.ElapsedToString() << std::endl; - } - } - } else { + { + std::cout << "Build thread number: " << build_thread_n << std::endl; auto labels = std::make_unique(embedding_count); std::iota(labels.get(), labels.get() + embedding_count, 0); - knn_hnsw->Insert(input_embeddings, labels.get(), 
embedding_count); + VertexType start_i = knn_hnsw->StoreDataRaw(input_embeddings, labels.get(), embedding_count); + delete[] input_embeddings; + AtomicVtxType next_i = start_i; + std::vector threads; + for (int i = 0; i < build_thread_n; ++i) { + threads.emplace_back([&]() { + while (true) { + VertexType cur_i = next_i.fetch_add(1); + if (cur_i >= VertexType(start_i + embedding_count)) { + break; + } + knn_hnsw->Build(cur_i); + if (cur_i && cur_i % 10000 == 0) { + std::cout << "Inserted " << cur_i << " / " << embedding_count << std::endl; + } + } + }); + } + for (auto &thread : threads) { + thread.join(); + } } profiler.End(); std::cout << "Insert data cost: " << profiler.ElapsedToString() << " memory cost: " << get_current_rss() << "B" << std::endl; - delete[] input_embeddings; - uint8_t file_flags = FileFlags::WRITE_FLAG | FileFlags::CREATE_FLAG; std::unique_ptr file_handler = fs.OpenFile(save_place, file_flags, FileLockType::kWriteLock); knn_hnsw->Save(*file_handler); @@ -100,7 +116,13 @@ int main() { std::unique_ptr file_handler = fs.OpenFile(save_place, file_flags, FileLockType::kReadLock); knn_hnsw = Hnsw::Load(*file_handler, init_args); + std::cout << "Loaded" << std::endl; + + // std::ofstream out("dump.txt"); + // knn_hnsw->Dump(out); + // knn_hnsw->Check(); } + return 0; size_t number_of_queries; const float *queries = nullptr; @@ -114,11 +136,10 @@ int main() { Vector> ground_truth_sets; // number_of_queries * top_k matrix of ground-truth nearest-neighbors { - // size_t *ground_truth; // load ground-truth and convert int to long size_t nq2; int *gt_int = ivecs_read(groundtruth_file.c_str(), &top_k, &nq2); - assert(nq2 == number_of_queries || !"incorrect nb of ground truth entries"); + assert(nq2 >= number_of_queries || !"incorrect nb of ground truth entries"); assert(top_k >= test_top || !"dataset does not provide enough ground truth data"); ground_truth_sets.resize(number_of_queries); @@ -130,10 +151,10 @@ int main() { } infinity::BaseProfiler profiler; 
- int round = 10; - Vector>> results; - results.reserve(number_of_queries); - std::cout << "thread number: " << thread_n << std::endl; + std::cout << "Start!" << std::endl; + int round = 3; + Vector>> results(number_of_queries); + std::cout << "Query thread number: " << query_thread_n << std::endl; for (int ef = 100; ef <= 300; ef += 25) { knn_hnsw->SetEf(ef); int correct = 0; @@ -142,7 +163,7 @@ int main() { std::atomic_int idx(0); std::vector threads; profiler.Begin(); - for (int j = 0; j < thread_n; ++j) { + for (int j = 0; j < query_thread_n; ++j) { threads.emplace_back([&]() { while (true) { int cur_idx = idx.fetch_add(1); @@ -150,7 +171,7 @@ int main() { break; } const float *query = queries + cur_idx * dimension; - MaxHeap> result = knn_hnsw->KnnSearch(query, test_top); + auto result = knn_hnsw->KnnSearch(query, test_top); results[cur_idx] = std::move(result); } }); @@ -161,12 +182,10 @@ int main() { profiler.End(); if (i == 0) { for (size_t query_idx = 0; query_idx < number_of_queries; ++query_idx) { - auto &result = results[query_idx]; - while (!result.empty()) { - if (ground_truth_sets[idx].contains(result.top().second)) { + for (const auto &[_, label] : results[query_idx]) { + if (ground_truth_sets[query_idx].contains(label)) { ++correct; } - result.pop(); } } printf("Recall = %.4f\n", correct / float(test_top * number_of_queries)); diff --git a/benchmark/local_infinity/knn/knn_import_benchmark.cpp b/benchmark/local_infinity/knn/knn_import_benchmark.cpp index 2659d7cf32..fc4ed8e3ff 100644 --- a/benchmark/local_infinity/knn/knn_import_benchmark.cpp +++ b/benchmark/local_infinity/knn/knn_import_benchmark.cpp @@ -37,16 +37,11 @@ import query_result; using namespace infinity; -int main(int argc, char *argv[]) { - if (argc != 2) { - std::cout << "import sift or gist" << std::endl; - return 1; - } +int main() { bool sift = true; - if (strcmp(argv[1], "sift") && strcmp(argv[1], "gist")) { - return 1; - } - sift = strcmp(argv[1], "sift") == 0; + int M = 16; + 
int ef_construct = 200; + std::cout << "benchmark: " << (sift ? "sift" : "gist") << std::endl; std::string data_path = "/tmp/infinity"; @@ -75,13 +70,14 @@ int main(int argc, char *argv[]) { if (sift) { col1_type = std::make_shared(LogicalType::kEmbedding, std::make_shared(EmbeddingDataType::kElemFloat, 128)); base_path += "/benchmark/sift_1m/sift_base.fvecs"; - table_name = "sift_benchmark"; + table_name = "sift_benchmark_M" + std::to_string(M) + "_ef" + std::to_string(ef_construct); } else { col1_type = std::make_shared(LogicalType::kEmbedding, std::make_shared(EmbeddingDataType::kElemFloat, 960)); base_path += "/benchmark/gist_1m/gist_base.fvecs"; - table_name = "gist_benchmark"; + table_name = "gist_benchmark_M" + std::to_string(M) + "_ef" + std::to_string(ef_construct); } std::cout << "Import from: " << base_path << std::endl; + std::cout << "table_name: " << table_name << std::endl; std::string col1_name = "col1"; auto col1_def = std::make_unique(0, col1_type, col1_name, std::unordered_set()); @@ -123,8 +119,8 @@ int main(int argc, char *argv[]) { { auto index_param_list = new std::vector(); - index_param_list->emplace_back(new InitParameter("M", std::to_string(16))); - index_param_list->emplace_back(new InitParameter("ef_construction", std::to_string(200))); + index_param_list->emplace_back(new InitParameter("M", std::to_string(M))); + index_param_list->emplace_back(new InitParameter("ef_construction", std::to_string(ef_construct))); index_param_list->emplace_back(new InitParameter("ef", std::to_string(200))); index_param_list->emplace_back(new InitParameter("metric", "l2")); index_param_list->emplace_back(new InitParameter("encode", "lvq")); diff --git a/benchmark/local_infinity/knn/knn_query_benchmark.cpp b/benchmark/local_infinity/knn/knn_query_benchmark.cpp index 353fadb6c3..1c84e59207 100644 --- a/benchmark/local_infinity/knn/knn_query_benchmark.cpp +++ b/benchmark/local_infinity/knn/knn_query_benchmark.cpp @@ -84,24 +84,18 @@ inline void 
ParallelFor(size_t start, size_t end, size_t numThreads, Function fn } } -int main(int argc, char *argv[]) { - if (argc != 3) { - std::cout << "query gist/sift ef=?" << std::endl; - return 1; - } +int main() { bool sift = true; - if (strcmp(argv[1], "sift") && strcmp(argv[1], "gist")) { - return 1; - } - sift = strcmp(argv[1], "sift") == 0; - size_t ef = std::stoull(argv[2]); - + int ef = 100; + int M = 16; + int ef_construct = 200; size_t thread_num = 1; - size_t total_times = 1; - std::cout << "Please input thread_num, 0 means use all resources:" << std::endl; - std::cin >> thread_num; - std::cout << "Please input total_times:" << std::endl; - std::cin >> total_times; + size_t total_times = 3; + + std::cout << "benchmark: " << (sift ? "sift" : "gist") << std::endl; + std::cout << "ef: " << ef << std::endl; + std::cout << "thread_n: " << thread_num << std::endl; + std::cout << "total_times: " << total_times << std::endl; std::string path = "/tmp/infinity"; LocalFileSystem fs; @@ -123,15 +117,16 @@ int main(int argc, char *argv[]) { dimension = 128; query_path += "/benchmark/sift_1m/sift_query.fvecs"; groundtruth_path += "/benchmark/sift_1m/sift_groundtruth.ivecs"; - table_name = "sift_benchmark"; + table_name = "sift_benchmark_M" + std::to_string(M) + "_ef" + std::to_string(ef_construct); } else { dimension = 960; query_path += "/benchmark/gist_1m/gist_query.fvecs"; groundtruth_path += "/benchmark/gist_1m/gist_groundtruth.ivecs"; - table_name = "gist_benchmark"; + table_name = "gist_benchmark_M" + std::to_string(M) + "_ef" + std::to_string(ef_construct); } std::cout << "query from: " << query_path << std::endl; std::cout << "groundtruth is: " << groundtruth_path << std::endl; + std::cout << "table_name: " << table_name << std::endl; if (!fs.Exists(query_path)) { std::cerr << "File: " << query_path << " doesn't exist" << std::endl; @@ -218,7 +213,7 @@ int main(int argc, char *argv[]) { query_results[query_idx].emplace_back(data[i].ToUint64()); } } -// delete[] 
embedding_data_ptr; + // delete[] embedding_data_ptr; }; BaseProfiler profiler; profiler.Begin(); diff --git a/src/common/stl.cppm b/src/common/stl.cppm index 3a06a8a037..616d4c339f 100644 --- a/src/common/stl.cppm +++ b/src/common/stl.cppm @@ -30,6 +30,7 @@ module; #include #include #include +#include #include #include #include @@ -211,6 +212,7 @@ export { using atomic_u32 = std::atomic_uint32_t; using atomic_u64 = std::atomic_uint64_t; using ai64 = std::atomic_int64_t; + using ai32 = std::atomic_int32_t; using aptr = std::atomic_uintptr_t; using atomic_bool = std::atomic_bool; @@ -359,6 +361,8 @@ export { template using LockGuard = std::lock_guard; + using TryToLock = std::try_to_lock_t; + constexpr std::memory_order MemoryOrderRelax = std::memory_order::relaxed; constexpr std::memory_order MemoryOrderConsume = std::memory_order::consume; constexpr std::memory_order MemoryOrderRelease = std::memory_order::release; @@ -433,4 +437,10 @@ struct CompareByFirst { bool operator()(const P &lhs, const P &rhs) const { return lhs.first < rhs.first; } }; +export template +struct CompareByFirstReverse { + using P = std::pair; + bool operator()(const P &lhs, const P &rhs) const { return lhs.first > rhs.first; } +}; + } // namespace infinity diff --git a/src/executor/fragment_builder.cpp b/src/executor/fragment_builder.cpp index 9e93409ef6..cf69d1d76b 100644 --- a/src/executor/fragment_builder.cpp +++ b/src/executor/fragment_builder.cpp @@ -143,6 +143,51 @@ void FragmentBuilder::BuildFragments(PhysicalOperator *phys_op, PlanFragment *cu current_fragment_ptr->AddChild(Move(next_plan_fragment)); return; } + case PhysicalOperatorType::kCreateIndexPrepare: { + if (phys_op->left() != nullptr || phys_op->right() != nullptr) { + Error(Format("Invalid input node of {}", phys_op->GetName())); + } + current_fragment_ptr->AddOperator(phys_op); + current_fragment_ptr->SetFragmentType(FragmentType::kSerialMaterialize); + current_fragment_ptr->SetSourceNode(query_context_ptr_, 
SourceType::kEmpty, phys_op->GetOutputNames(), phys_op->GetOutputTypes()); + return; + } + case PhysicalOperatorType::kCreateIndexDo: { + if (phys_op->left() == nullptr || phys_op->right() != nullptr) { + Error(Format("Invalid input node of {}", phys_op->GetName())); + } + current_fragment_ptr->AddOperator(phys_op); + current_fragment_ptr->SetFragmentType(FragmentType::kParallelMaterialize); + current_fragment_ptr->SetSourceNode(query_context_ptr_, SourceType::kLocalQueue, phys_op->GetOutputNames(), phys_op->GetOutputTypes()); + + auto next_plan_fragment = MakeUnique(GetFragmentId()); + next_plan_fragment->SetSinkNode(query_context_ptr_, + SinkType::kLocalQueue, + phys_op->left()->GetOutputNames(), + phys_op->left()->GetOutputTypes()); + BuildFragments(phys_op->left(), next_plan_fragment.get()); + + current_fragment_ptr->AddChild(Move(next_plan_fragment)); + return; + } + case PhysicalOperatorType::kCreateIndexFinish: { + if (phys_op->left() == nullptr || phys_op->right() != nullptr) { + Error(Format("Invalid input node of {}", phys_op->GetName())); + } + current_fragment_ptr->AddOperator(phys_op); + current_fragment_ptr->SetFragmentType(FragmentType::kSerialMaterialize); + current_fragment_ptr->SetSourceNode(query_context_ptr_, SourceType::kLocalQueue, phys_op->GetOutputNames(), phys_op->GetOutputTypes()); + + auto next_plan_fragment = MakeUnique(GetFragmentId()); + next_plan_fragment->SetSinkNode(query_context_ptr_, + SinkType::kLocalQueue, + phys_op->left()->GetOutputNames(), + phys_op->left()->GetOutputTypes()); + BuildFragments(phys_op->left(), next_plan_fragment.get()); + + current_fragment_ptr->AddChild(Move(next_plan_fragment)); + return; + } case PhysicalOperatorType::kParallelAggregate: case PhysicalOperatorType::kFilter: case PhysicalOperatorType::kHash: diff --git a/src/executor/fragment_builder.cppm b/src/executor/fragment_builder.cppm index 006e5f8eb6..72bf804b76 100644 --- a/src/executor/fragment_builder.cppm +++ b/src/executor/fragment_builder.cppm 
@@ -29,9 +29,9 @@ public: UniquePtr BuildFragment(PhysicalOperator *phys_op); +private: void BuildFragments(PhysicalOperator *phys_op, PlanFragment *current_fragment_ptr); -private: void BuildExplain(PhysicalOperator *phys_op, PlanFragment *current_fragment_ptr); idx_t GetFragmentId() { return fragment_id_++; } diff --git a/src/executor/operator/physical_create_index_do.cpp b/src/executor/operator/physical_create_index_do.cpp new file mode 100644 index 0000000000..2e1526c802 --- /dev/null +++ b/src/executor/operator/physical_create_index_do.cpp @@ -0,0 +1,191 @@ +// Copyright(C) 2023 InfiniFlow, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +module; + +import stl; +import parser; +import physical_operator_type; +import physical_operator; +import query_context; +import operator_state; +import load_meta; + +import index_def; +import create_index_data; +import base_table_ref; +import table_collection_entry; +import table_index_entry; +import segment_column_index_entry; +import status; +import infinity_exception; +import buffer_handle; +import index_hnsw; +import index_base; +import hnsw_common; +import dist_func_l2; +import dist_func_ip; +import hnsw_alg; +import lvq_store; +import plain_store; +import buffer_manager; +import txn_store; +import third_party; +import logger; + +module physical_create_index_do; + +namespace infinity { +PhysicalCreateIndexDo::PhysicalCreateIndexDo(u64 id, + UniquePtr left, + SharedPtr base_table_ref, + SharedPtr index_name, + SharedPtr> output_names, + SharedPtr>> output_types, + SharedPtr> load_metas) + : PhysicalOperator(PhysicalOperatorType::kCreateIndexDo, Move(left), nullptr, id, load_metas), base_table_ref_(base_table_ref), + index_name_(index_name), output_names_(output_names), output_types_(output_types) {} + +void PhysicalCreateIndexDo::Init() {} + +// FIXME: fetch and add a block one time +template +void InsertHnsw(HashMap &create_index_idxes, + const HashMap> &segment_column_index_entries, + BufferManager *buffer_mgr) { + for (auto &[segment_id, create_index_idx] : create_index_idxes) { + auto iter = segment_column_index_entries.find(segment_id); + if (iter == segment_column_index_entries.end()) { + Error("Segment id not found in column index entry."); + } + auto *segment_column_index_entry = iter->second.get(); + + auto buffer_handle = SegmentColumnIndexEntry::GetIndex(segment_column_index_entry, buffer_mgr); + auto *hnsw_index = static_cast(buffer_handle.GetDataMut()); + + SizeT vertex_n = hnsw_index->GetVertexNum(); + while (true) { + SizeT idx = create_index_idx.fetch_add(1); + if (idx % 10000 == 0) { + LOG_INFO(Format("Insert index: {}", idx)); + } + if 
(idx >= vertex_n) { + break; + } + hnsw_index->Build(idx); + } + } +} + +bool PhysicalCreateIndexDo::Execute(QueryContext *query_context, OperatorState *operator_state) { + auto *txn = query_context->GetTxn(); + auto *create_index_do_state = static_cast(operator_state); + auto &create_index_idxes = create_index_do_state->create_index_shared_data_->create_index_idxes_; + + TableCollectionEntry *table_entry = nullptr; + Status get_table_entry_status = txn->GetTableEntry(*base_table_ref_->schema_name(), *base_table_ref_->table_name(), table_entry); + if (!get_table_entry_status.ok()) { + operator_state->error_message_ = Move(get_table_entry_status.msg_); + return false; + } + + TxnTableStore *table_store = txn->GetTxnTableStore(table_entry); + auto iter = table_store->txn_indexes_store_.find(*index_name_); + if (iter == table_store->txn_indexes_store_.end()) { + // the table is empty + operator_state->SetComplete(); + return true; + } + TxnIndexStore &txn_index_store = iter->second; + + auto *table_index_entry = txn_index_store.table_index_entry_; + if (table_index_entry->index_def_->index_array_.size() != 1) { + Error("Not implemented"); + } + auto *index_base = table_index_entry->index_def_->index_array_[0].get(); + auto *hnsw_def = static_cast(index_base); + + if (txn_index_store.index_entry_map_.size() != 1) { + Error("Not implemented"); + } + const auto &[column_id, segment_column_index_entries] = *txn_index_store.index_entry_map_.begin(); + + auto *column_def = table_entry->columns_[column_id].get(); + if (column_def->type()->type() != LogicalType::kEmbedding) { + Error("Create index on non-embedding column is not supported yet."); + } + TypeInfo *type_info = column_def->type()->type_info().get(); + auto embedding_info = static_cast(type_info); + + switch (embedding_info->Type()) { + case kElemFloat: { + switch (hnsw_def->encode_type_) { + case HnswEncodeType::kPlain: { + switch (hnsw_def->metric_type_) { + case MetricType::kMerticInnerProduct: { + InsertHnsw, 
PlainIPDist>>(create_index_idxes, + segment_column_index_entries, + txn->GetBufferMgr()); + break; + } + case MetricType::kMerticL2: { + InsertHnsw, PlainL2Dist>>(create_index_idxes, + segment_column_index_entries, + txn->GetBufferMgr()); + break; + } + default: { + Error("Not implemented"); + } + } + break; + } + case HnswEncodeType::kLVQ: { + switch (hnsw_def->metric_type_) { + case MetricType::kMerticInnerProduct: { + InsertHnsw>, LVQIPDist>>( + create_index_idxes, + segment_column_index_entries, + txn->GetBufferMgr()); + break; + } + case MetricType::kMerticL2: { + InsertHnsw>, LVQL2Dist>>( + create_index_idxes, + segment_column_index_entries, + txn->GetBufferMgr()); + break; + } + default: { + Error("Not implemented"); + } + } + break; + } + default: { + Error("Not implemented"); + } + } + break; + } + default: { + Error("Create index on non-float embedding column is not supported yet."); + } + } + operator_state->SetComplete(); + + return true; +} + +}; // namespace infinity \ No newline at end of file diff --git a/src/executor/operator/physical_create_index_do.cppm b/src/executor/operator/physical_create_index_do.cppm new file mode 100644 index 0000000000..388324ae4b --- /dev/null +++ b/src/executor/operator/physical_create_index_do.cppm @@ -0,0 +1,62 @@ +// Copyright(C) 2023 InfiniFlow, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +module; + +import stl; +import parser; +import physical_operator_type; +import physical_operator; +import query_context; +import operator_state; +import load_meta; + +import index_def; +import base_table_ref; + +export module physical_create_index_do; + +namespace infinity { + +export class PhysicalCreateIndexDo : public PhysicalOperator { /* Physical operator for the "do" step of CREATE INDEX: its Execute (physical_create_index_do.cpp) calls Build on the HNSW index of the table referenced by base_table_ref_. */ +public: + PhysicalCreateIndexDo(u64 id, + UniquePtr left, + SharedPtr base_table_ref, + SharedPtr index_name, + SharedPtr> output_names, + SharedPtr>> output_types, + SharedPtr> load_metas); /* left becomes the only child; the remaining arguments are stored in the members below */ + +public: + void Init() override; + + bool Execute(QueryContext *query_context, OperatorState *operator_state) override; /* returns false only when the table entry lookup fails; otherwise marks the state complete and returns true */ + + SizeT TaskletCount() override { return 0; } /* NOTE(review): reports 0 tasklets whereas PhysicalCreateIndexFinish reports 1 - confirm this is intended */ + + SharedPtr> GetOutputNames() const override { return output_names_; } + + SharedPtr>> GetOutputTypes() const override { return output_types_; } + +public: + // for creating the fragment context + const SharedPtr base_table_ref_{}; + const SharedPtr index_name_{}; + + const SharedPtr> output_names_{}; + const SharedPtr>> output_types_{}; +}; + +} // namespace infinity \ No newline at end of file diff --git a/src/executor/operator/physical_create_index_finish.cpp b/src/executor/operator/physical_create_index_finish.cpp new file mode 100644 index 0000000000..81039e3f0f --- /dev/null +++ b/src/executor/operator/physical_create_index_finish.cpp @@ -0,0 +1,57 @@ +// Copyright(C) 2023 InfiniFlow, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +module; + +import stl; +import parser; +import physical_operator_type; +import physical_operator; +import query_context; +import operator_state; +import load_meta; + +import index_def; +import wal_entry; + +module physical_create_index_finish; + +namespace infinity { +PhysicalCreateIndexFinish::PhysicalCreateIndexFinish(u64 id, + UniquePtr left, + SharedPtr db_name, + SharedPtr table_name, + SharedPtr index_def, + SharedPtr> output_names, + SharedPtr>> output_types, + SharedPtr> load_metas) + : PhysicalOperator(PhysicalOperatorType::kCreateIndexFinish, Move(left), nullptr, id, load_metas), db_name_(db_name), table_name_(table_name), + index_def_(index_def), output_names_(output_names), output_types_(output_types) {} /* kCreateIndexFinish operator with left as its only child (right is nullptr); keeps the db/table names, the index definition and the output schema */ + +void PhysicalCreateIndexFinish::Init() {} /* nothing to initialise */ + +bool PhysicalCreateIndexFinish::Execute(QueryContext *query_context, OperatorState *operator_state) { /* Once the operator state reports input_complete_, adds a WAL command built from *db_name_, *table_name_ and index_def_ to the transaction, marks the state complete and returns true; until then it just returns false. */ + auto *txn = query_context->GetTxn(); + auto *create_index_finish_op_state = static_cast(operator_state); + + if (create_index_finish_op_state->input_complete_) { + txn->AddWalCmd(MakeShared(*db_name_, *table_name_, index_def_)); + + operator_state->SetComplete(); + return true; + } + return false; +} + +} // namespace infinity \ No newline at end of file diff --git a/src/executor/operator/physical_create_index_finish.cppm b/src/executor/operator/physical_create_index_finish.cppm new file mode 100644 index 0000000000..0e5a295955 --- /dev/null +++ b/src/executor/operator/physical_create_index_finish.cppm @@ -0,0 +1,60 @@ +// Copyright(C) 2023 InfiniFlow, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +module; + +import stl; +import parser; +import physical_operator_type; +import physical_operator; +import query_context; +import operator_state; +import load_meta; +import index_def; + +export module physical_create_index_finish; + +namespace infinity { +export class PhysicalCreateIndexFinish : public PhysicalOperator { /* Physical operator for the "finish" step of CREATE INDEX: its Execute (physical_create_index_finish.cpp) adds a WAL command for the new index to the transaction once its input is complete. */ +public: + PhysicalCreateIndexFinish(u64 id, + UniquePtr left, + SharedPtr db_name, + SharedPtr table_name, + SharedPtr index_def, + SharedPtr> output_names, + SharedPtr>> output_types, + SharedPtr> load_metas); /* left becomes the only child; the remaining arguments are stored in the members below */ + +public: + void Init() override; + + bool Execute(QueryContext *query_context, OperatorState *operator_state) override; /* returns false until the operator state has input_complete_ set */ + + SizeT TaskletCount() override { return 1; } + + SharedPtr> GetOutputNames() const override { return output_names_; } + + SharedPtr>> GetOutputTypes() const override { return output_types_; } + +public: + const SharedPtr db_name_{}; /* db_name_, table_name_ and index_def_ are the arguments of the WAL command built in Execute */ + const SharedPtr table_name_{}; + const SharedPtr index_def_{}; + + const SharedPtr> output_names_{}; + const SharedPtr>> output_types_{}; +}; + +} // namespace infinity \ No newline at end of file diff --git a/src/executor/operator/physical_create_index_prepare.cpp b/src/executor/operator/physical_create_index_prepare.cpp new file mode 100644 index 0000000000..4e6e8efddd --- /dev/null +++ b/src/executor/operator/physical_create_index_prepare.cpp @@ -0,0 +1,200 @@ +// Copyright(C) 2023 InfiniFlow, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +module; + +import stl; +import parser; +import physical_operator_type; +import physical_operator; +import query_context; +import operator_state; +import load_meta; + +import index_def; +import table_collection_entry; +import status; +import infinity_exception; +import base_entry; +import segment_entry; +import table_index_entry; +import column_index_entry; +import segment_column_index_entry; +import index_base; +import index_file_worker; +import segment_iter; +import buffer_manager; +import buffer_handle; +import index_hnsw; +import default_values; +import txn_store; + +import hnsw_common; +import dist_func_l2; +import dist_func_ip; +import hnsw_alg; +import lvq_store; +import plain_store; + +module physical_create_index_prepare; + +namespace infinity { +PhysicalCreateIndexPrepare::PhysicalCreateIndexPrepare(u64 id, + SharedPtr schema_name, + SharedPtr table_name, + SharedPtr index_definition, + ConflictType conflict_type, + SharedPtr> output_names, + SharedPtr>> output_types, + SharedPtr> load_metas) + : PhysicalOperator(PhysicalOperatorType::kCreateIndexPrepare, nullptr, nullptr, id, load_metas), schema_name_(schema_name), + table_name_(table_name), index_def_ptr_(index_definition), conflict_type_(conflict_type), output_names_(output_names), + output_types_(output_types) {} /* kCreateIndexPrepare operator without children (both inputs are nullptr); keeps the schema/table names, index definition, conflict type and output schema */ + +void PhysicalCreateIndexPrepare::Init() {} /* nothing to initialise */ + +template +void InsertHnswPrepare(BufferHandle buffer_handle, const SegmentEntry *segment_entry, u32 column_id) { /* Collects one row id per row of every block in segment_entry and passes them, together with a OneColumnIterator over column_id, to StoreData on the HNSW index held by buffer_handle. */ + auto hnsw_index = static_cast(buffer_handle.GetDataMut()); + + u32 segment_offset = 0; + Vector row_ids; + 
const auto &block_entries = segment_entry->block_entries_; + for (SizeT i = 0; i < block_entries.size(); ++i) { + const auto &block_entry = block_entries[i]; + SizeT block_row_cnt = block_entry->row_count_; + + for (SizeT block_offset = 0; block_offset < block_row_cnt; ++block_offset) { + RowID row_id(segment_entry->segment_id_, segment_offset + block_offset); + row_ids.push_back(row_id.ToUint64()); + } + segment_offset += DEFAULT_BLOCK_CAPACITY; /* advances by the full block capacity, not by row_count_, so each block starts at a fixed offset */ + } + OneColumnIterator one_column_iter(segment_entry, column_id); + + hnsw_index->StoreData(one_column_iter, row_ids.data(), row_ids.size()); +} + +bool PhysicalCreateIndexPrepare::Execute(QueryContext *query_context, OperatorState *operator_state) { + auto *txn = query_context->GetTxn(); + TxnTimeStamp begin_ts = txn->BeginTS(); + BufferManager *buffer_mgr = txn->GetBufferMgr(); + + TableCollectionEntry *table_entry = nullptr; + Status get_table_status = txn->GetTableEntry(*schema_name_, *table_name_, table_entry); + if (!get_table_status.ok()) { + operator_state->error_message_ = Move(get_table_status.msg_); + return false; + } + + TableIndexEntry *table_index_entry = nullptr; + Status create_index_status = txn->CreateIndex(table_entry, index_def_ptr_, conflict_type_, table_index_entry); + if (!create_index_status.ok()) { + operator_state->error_message_ = Move(create_index_status.msg_); + return false; + } + + if (table_index_entry->irs_index_entry_.get() != nullptr) { + Error("TableCollectionEntry::CreateIndexFilePrepare"); + } + if (table_index_entry->column_index_map_.size() != 1) { + Error("TableCollectionEntry::CreateIndexFilePrepare"); + } + auto [column_id, base_entry] = *table_index_entry->column_index_map_.begin(); + SharedPtr column_def = table_entry->columns_[column_id]; + auto *column_index_entry = static_cast(base_entry.get()); + + for (const auto &[segment_id, segment_entry] : table_entry->segment_map_) { + u64 column_id = column_def->id(); + IndexBase *index_base = column_index_entry->index_base_.get();
+ UniquePtr create_index_param = SegmentEntry::GetCreateIndexParam(segment_entry.get(), index_base, column_def.get()); + SharedPtr segment_column_index_entry = + SegmentColumnIndexEntry::NewIndexEntry(column_index_entry, segment_entry->segment_id_, begin_ts, buffer_mgr, create_index_param.get()); + + if (index_base->index_type_ != IndexType::kHnsw) { + Error("Only HNSW index is supported."); + } + auto *index_hnsw = static_cast(index_base); + if (column_def->type()->type() != LogicalType::kEmbedding) { + Error("HNSW supports embedding type."); + } + TypeInfo *type_info = column_def->type()->type_info().get(); + auto embedding_info = static_cast(type_info); + + BufferHandle buffer_handle = SegmentColumnIndexEntry::GetIndex(segment_column_index_entry.get(), buffer_mgr); + switch (embedding_info->Type()) { + case kElemFloat: { + switch (index_hnsw->encode_type_) { + case HnswEncodeType::kPlain: { + switch (index_hnsw->metric_type_) { + case MetricType::kMerticInnerProduct: { + InsertHnswPrepare, PlainIPDist>>(buffer_handle, + segment_entry.get(), + column_id); + break; + } + case MetricType::kMerticL2: { + InsertHnswPrepare, PlainL2Dist>>(buffer_handle, + segment_entry.get(), + column_id); + break; + } + default: { + Error("Not implemented"); + } + } + break; + } + case HnswEncodeType::kLVQ: { + switch (index_hnsw->metric_type_) { + case MetricType::kMerticInnerProduct: { + InsertHnswPrepare>, LVQIPDist>>( + buffer_handle, + segment_entry.get(), + column_id); + break; + } + case MetricType::kMerticL2: { + InsertHnswPrepare>, LVQL2Dist>>( + buffer_handle, + segment_entry.get(), + column_id); + break; + } + default: { + Error("Not implemented"); + } + } + break; + } + default: { + Error("Not implemented"); + } + } + break; + } + default: { + Error("Not implemented"); + } + } + TxnTableStore *table_store = txn->GetTxnTableStore(table_entry); + table_store->CreateIndexFile(table_index_entry, column_id, segment_id, segment_column_index_entry); + + 
column_index_entry->index_by_segment.emplace(segment_id, segment_column_index_entry); + } + + operator_state->SetComplete(); + return true; +} +} // namespace infinity \ No newline at end of file diff --git a/src/executor/operator/physical_create_index_prepare.cppm b/src/executor/operator/physical_create_index_prepare.cppm new file mode 100644 index 0000000000..b63d9ff43f --- /dev/null +++ b/src/executor/operator/physical_create_index_prepare.cppm @@ -0,0 +1,62 @@ +// Copyright(C) 2023 InfiniFlow, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +module; + +import stl; +import parser; +import physical_operator_type; +import physical_operator; +import query_context; +import operator_state; +import load_meta; + +import index_def; + +export module physical_create_index_prepare; + +namespace infinity { +export class PhysicalCreateIndexPrepare : public PhysicalOperator { +public: + PhysicalCreateIndexPrepare(u64 id, + SharedPtr schema_name, + SharedPtr table_name, + SharedPtr index_definition, + ConflictType conflict_type, + SharedPtr> output_names, + SharedPtr>> output_types, + SharedPtr> load_metas); + +public: + void Init() override; + + bool Execute(QueryContext *query_context, OperatorState *operator_state) override; + + SizeT TaskletCount() override { return 1; } + + SharedPtr> GetOutputNames() const override { return output_names_; } + + SharedPtr>> GetOutputTypes() const override { return output_types_; } + +public: + const SharedPtr schema_name_{}; + const SharedPtr table_name_{}; + const SharedPtr index_def_ptr_{}; + const ConflictType conflict_type_{}; + + const SharedPtr> output_names_{}; + const SharedPtr>> output_types_{}; +}; + +} // namespace infinity \ No newline at end of file diff --git a/src/executor/operator/physical_knn_scan.cpp b/src/executor/operator/physical_knn_scan.cpp index bc0bd52f2c..202c802e68 100644 --- a/src/executor/operator/physical_knn_scan.cpp +++ b/src/executor/operator/physical_knn_scan.cpp @@ -147,8 +147,8 @@ void PhysicalKnnScan::Init() {} bool PhysicalKnnScan::Execute(QueryContext *query_context, OperatorState *operator_state) { auto *knn_scan_operator_state = static_cast(operator_state); - auto elem_type = knn_scan_operator_state->knn_scan_function_data_->shared_data_->elem_type_; - auto dist_type = knn_scan_operator_state->knn_scan_function_data_->shared_data_->knn_distance_type_; + auto elem_type = knn_scan_operator_state->knn_scan_function_data_->knn_scan_shared_data_->elem_type_; + auto dist_type = 
knn_scan_operator_state->knn_scan_function_data_->knn_scan_shared_data_->knn_distance_type_; switch (elem_type) { case kElemFloat: { switch (dist_type) { @@ -239,7 +239,7 @@ SizeT PhysicalKnnScan::BlockEntryCount() const { return base_table_ref_->block_i template typename C> void PhysicalKnnScan::ExecuteInternal(QueryContext *query_context, KnnScanOperatorState *operator_state) { auto knn_scan_function_data = operator_state->knn_scan_function_data_.get(); - auto knn_scan_shared_data = knn_scan_function_data->shared_data_; + auto knn_scan_shared_data = knn_scan_function_data->knn_scan_shared_data_; auto dist_func = static_cast *>(knn_scan_function_data->knn_distance_.get()); auto merge_heap = static_cast *>(knn_scan_function_data->merge_knn_base_.get()); @@ -373,7 +373,7 @@ void PhysicalKnnScan::ExecuteInternal(QueryContext *query_context, KnnScanOperat case IndexType::kHnsw: { BufferHandle index_handle = SegmentColumnIndexEntry::GetIndex(segment_column_index_entry, buffer_mgr); auto index_hnsw = static_cast(segment_column_index_entry->column_index_entry_->index_base_.get()); - auto KnnScanOld = [&](auto *index) { + auto KnnScan = [&](auto *index) { Vector dists(knn_scan_shared_data->topk_ * knn_scan_shared_data->query_count_); Vector row_ids(knn_scan_shared_data->topk_ * knn_scan_shared_data->query_count_); @@ -388,15 +388,14 @@ void PhysicalKnnScan::ExecuteInternal(QueryContext *query_context, KnnScanOperat for (u64 query_idx = 0; query_idx < knn_scan_shared_data->query_count_; ++query_idx) { const DataType *query = static_cast(knn_scan_shared_data->query_embedding_) + query_idx * knn_scan_shared_data->dimension_; - MaxHeap> heap = index->KnnSearch(query, knn_scan_shared_data->topk_, bitmask); + Vector> result = index->KnnSearch(query, knn_scan_shared_data->topk_, bitmask); if (result_n < 0) { - result_n = heap.size(); - } else if (result_n != (i64)heap.size()) { + result_n = result.size(); + } else if (result_n != (i64)result.size()) { throw 
ExecutorException("Bug"); } u64 id = 0; - while (!heap.empty()) { - const auto &[dist, row_id] = heap.top(); + for (const auto &[dist, row_id] : result) { row_ids[query_idx * knn_scan_shared_data->topk_ + id] = RowID::FromUint64(row_id); switch (knn_scan_shared_data->knn_distance_type_) { case KnnDistanceType::kInvalid: { @@ -414,81 +413,22 @@ void PhysicalKnnScan::ExecuteInternal(QueryContext *query_context, KnnScanOperat } } ++id; - heap.pop(); } } merge_heap->Search(dists.data(), row_ids.data(), result_n); }; - auto KnnScanUseHeap = [&](auto *index) { - if constexpr (!std::is_same_v) { - Error("Bug: Hnsw LabelType must be u64"); - } - for (const auto &opt_param : knn_scan_shared_data->opt_params_) { - if (opt_param.param_name_ == "ef") { - u64 ef = std::stoull(opt_param.param_value_); - index->SetEf(ef); - } - } - i64 result_n = -1; - for (u64 query_idx = 0; query_idx < knn_scan_shared_data->query_count_; ++query_idx) { - const DataType *query = - static_cast(knn_scan_shared_data->query_embedding_) + query_idx * knn_scan_shared_data->dimension_; - auto search_result = index->KnnSearchReturnPair(query, knn_scan_shared_data->topk_, bitmask); - auto &[result_size, unique_ptr_pair] = search_result; - auto &[d_ptr, l_ptr] = unique_ptr_pair; - if (result_n < 0) { - result_n = result_size; - } else if (result_n != (i64)result_size) { - throw ExecutorException("Bug"); - } - if (result_size <= 0) { - continue; - } - UniquePtr row_ids_ptr; - RowID *row_ids = nullptr; - if constexpr (sizeof(RowID) == sizeof(LabelType)) { - row_ids = reinterpret_cast(l_ptr.get()); - } else { - row_ids_ptr = MakeUniqueForOverwrite(result_size); - row_ids = row_ids_ptr.get(); - } - for (SizeT i = 0; i < result_size; ++i) { - row_ids[i] = RowID::FromUint64(l_ptr[i]); - } - switch (knn_scan_shared_data->knn_distance_type_) { - case KnnDistanceType::kInvalid: { - throw ExecutorException("Bug"); - } - case KnnDistanceType::kL2: - case KnnDistanceType::kHamming: { - break; - } - case 
KnnDistanceType::kCosine: - case KnnDistanceType::kInnerProduct: { - for (SizeT i = 0; i < result_size; ++i) { - d_ptr[i] = -d_ptr[i]; - } - break; - } - } - merge_heap->Search(0, d_ptr.get(), row_ids, result_size); - } - }; - auto KnnScan = [&](auto *index) { - using LabelType = typename std::remove_pointer_t::HnswLabelType; - KnnScanUseHeap.template operator()(index); - }; + using LabelType = u64; switch (index_hnsw->encode_type_) { case HnswEncodeType::kPlain: { switch (index_hnsw->metric_type_) { case MetricType::kMerticInnerProduct: { - using Hnsw = KnnHnsw, PlainIPDist>; + using Hnsw = KnnHnsw, PlainIPDist>; // Fixme: const_cast here. may have bug. KnnScan(const_cast(static_cast(index_handle.GetData()))); break; } case MetricType::kMerticL2: { - using Hnsw = KnnHnsw, PlainL2Dist>; + using Hnsw = KnnHnsw, PlainL2Dist>; KnnScan(const_cast(static_cast(index_handle.GetData()))); break; } @@ -501,12 +441,12 @@ void PhysicalKnnScan::ExecuteInternal(QueryContext *query_context, KnnScanOperat case HnswEncodeType::kLVQ: { switch (index_hnsw->metric_type_) { case MetricType::kMerticInnerProduct: { - using Hnsw = KnnHnsw>, LVQIPDist>; + using Hnsw = KnnHnsw>, LVQIPDist>; KnnScan(const_cast(static_cast(index_handle.GetData()))); break; } case MetricType::kMerticL2: { - using Hnsw = KnnHnsw>, LVQL2Dist>; + using Hnsw = KnnHnsw>, LVQL2Dist>; KnnScan(const_cast(static_cast(index_handle.GetData()))); break; } diff --git a/src/executor/operator/physical_sink.cpp b/src/executor/operator/physical_sink.cpp index c970fb199a..1cfc3acd6d 100644 --- a/src/executor/operator/physical_sink.cpp +++ b/src/executor/operator/physical_sink.cpp @@ -131,19 +131,6 @@ void PhysicalSink::FillSinkStateFromLastOperatorState(MaterializeSinkState *mate } break; } - case PhysicalOperatorType::kKnnScan: { - throw ExecutorException("KnnScan shouldn't be here"); - KnnScanOperatorState *knn_output_state = static_cast(task_op_state); - if (knn_output_state->data_block_array_.empty()) { - Error("Empty knn 
scan output"); - } - - for (auto &data_block : knn_output_state->data_block_array_) { - materialize_sink_state->data_block_array_.emplace_back(Move(data_block)); - } - knn_output_state->data_block_array_.clear(); - break; - } case PhysicalOperatorType::kAggregate: { AggregateOperatorState *agg_output_state = static_cast(task_op_state); if (agg_output_state->data_block_array_.empty()) { @@ -325,6 +312,17 @@ void PhysicalSink::FillSinkStateFromLastOperatorState(ResultSinkState *result_si } break; } + case PhysicalOperatorType::kCreateIndexFinish: { + auto *output_state = static_cast(task_operator_state); + if (output_state->error_message_.get() != nullptr) { + result_sink_state->error_message_ = Move(output_state->error_message_); + break; + } + result_sink_state->result_def_ = { + MakeShared(0, MakeShared(LogicalType::kInteger), "OK", HashSet()), + }; + break; + } default: { Error(Format("{} isn't supported here.", PhysicalOperatorToString(task_operator_state->operator_type_))); } @@ -351,15 +349,13 @@ void PhysicalSink::FillSinkStateFromLastOperatorState(MessageSinkState *message_ } void PhysicalSink::FillSinkStateFromLastOperatorState(QueueSinkState *queue_sink_state, OperatorState *task_operator_state) { - if(queue_sink_state->error_message_.get() != nullptr) { + if (queue_sink_state->error_message_.get() != nullptr) { LOG_TRACE(Format("Error: {} is sent to notify next fragment", *queue_sink_state->error_message_)); - SharedPtr fragment_data = MakeShared(); - fragment_data->error_message_ = MakeUnique(*queue_sink_state->error_message_); - fragment_data->fragment_id_ = queue_sink_state->fragment_id_; + auto fragment_error = MakeShared(queue_sink_state->fragment_id_, MakeUnique(*queue_sink_state->error_message_)); for (const auto &next_fragment_queue : queue_sink_state->fragment_data_queues_) { - next_fragment_queue->Enqueue(fragment_data); + next_fragment_queue->Enqueue(fragment_error); } - return ; + return; } if (!task_operator_state->Complete()) { @@ -368,15 
+364,18 @@ void PhysicalSink::FillSinkStateFromLastOperatorState(QueueSinkState *queue_sink } SizeT output_data_block_count = task_operator_state->data_block_array_.size(); if (output_data_block_count == 0) { - Error("No output from knn scan"); + for (const auto &next_fragment_queue : queue_sink_state->fragment_data_queues_) { + next_fragment_queue->Enqueue(MakeShared(queue_sink_state->fragment_id_)); + } + return; + // Error("No output from last operator."); } for (SizeT idx = 0; idx < output_data_block_count; ++idx) { - SharedPtr fragment_data = MakeShared(); - fragment_data->fragment_id_ = queue_sink_state->fragment_id_; - fragment_data->task_id_ = queue_sink_state->task_id_; - fragment_data->data_block_ = Move(task_operator_state->data_block_array_[idx]); - fragment_data->data_count_ = output_data_block_count; - fragment_data->data_idx_ = idx; + auto fragment_data = MakeShared(queue_sink_state->fragment_id_, + Move(task_operator_state->data_block_array_[idx]), + queue_sink_state->task_id_, + idx, + output_data_block_count); for (const auto &next_fragment_queue : queue_sink_state->fragment_data_queues_) { next_fragment_queue->Enqueue(fragment_data); } diff --git a/src/executor/operator_state.cpp b/src/executor/operator_state.cpp index 1ca896a67f..0c439534c7 100644 --- a/src/executor/operator_state.cpp +++ b/src/executor/operator_state.cpp @@ -39,37 +39,66 @@ void QueueSourceState::MarkCompletedTask(u64 fragment_id) { // A false return value indicate there are more data need to read from source. // True or false doesn't mean the source data is error or not. bool QueueSourceState::GetData() { - SharedPtr fragment_data = nullptr; - source_queue_.Dequeue(fragment_data); + SharedPtr fragment_data_base = nullptr; + source_queue_.Dequeue(fragment_data_base); - if(fragment_data->error_message_.get() != nullptr) { - if(this->error_message_.get() == nullptr) { - // Only record the first error of input data. 
- this->error_message_ = Move(fragment_data->error_message_); + switch (fragment_data_base->type_) { + case FragmentDataType::kData: { + auto *fragment_data = static_cast(fragment_data_base.get()); + if (fragment_data->data_idx_ + 1 == fragment_data->data_count_) { + // Get an all data from this + MarkCompletedTask(fragment_data->fragment_id_); + } + break; + } + case FragmentDataType::kError: { + auto *fragment_error = static_cast(fragment_data_base.get()); + if (this->error_message_.get() == nullptr) { + // Only record the first error of input data. + this->error_message_ = Move(fragment_error->error_message_); + } + // Get an error message from predecessor fragment + MarkCompletedTask(fragment_error->fragment_id_); + break; + } + case FragmentDataType::kNone: { + auto *fragment_none = static_cast(fragment_data_base.get()); + MarkCompletedTask(fragment_none->fragment_id_); + break; + } + default: { + Error("Not support fragment data type"); + break; } - - // Get an error message from predecessor fragment - MarkCompletedTask(fragment_data->fragment_id_); - } else if (fragment_data->data_idx_ + 1 == fragment_data->data_count_) { - // Get an all data from this - MarkCompletedTask(fragment_data->fragment_id_); } bool completed = num_tasks_.empty(); OperatorState *next_op_state = this->next_op_state_; switch (next_op_state->operator_type_) { case PhysicalOperatorType::kMergeKnn: { + auto *fragment_data = static_cast(fragment_data_base.get()); MergeKnnOperatorState *merge_knn_op_state = (MergeKnnOperatorState *)next_op_state; merge_knn_op_state->input_data_block_ = Move(fragment_data->data_block_); merge_knn_op_state->input_complete_ = completed; break; } case PhysicalOperatorType::kFusion: { + auto *fragment_data = static_cast(fragment_data_base.get()); FusionOperatorState *fusion_op_state = (FusionOperatorState *)next_op_state; fusion_op_state->input_data_blocks_[fragment_data->fragment_id_].push_back(Move(fragment_data->data_block_)); 
fusion_op_state->input_complete_ = completed; break; } + case PhysicalOperatorType::kCreateIndexDo: { + auto *create_index_do_op_state = static_cast(next_op_state); + create_index_do_op_state->input_complete_ = completed; + break; + } + case PhysicalOperatorType::kCreateIndexFinish: { + auto *create_index_finish_op_state = static_cast(next_op_state); + create_index_finish_op_state->input_complete_ = completed; + break; + } default: { Error("Not support operator type"); break; diff --git a/src/executor/operator_state.cppm b/src/executor/operator_state.cppm index 25a7f5b479..80500c5647 100644 --- a/src/executor/operator_state.cppm +++ b/src/executor/operator_state.cppm @@ -24,6 +24,7 @@ import knn_scan_data; import table_def; import parser; import merge_knn_data; +import create_index_data; import blocking_queue; export module operator_state; @@ -232,6 +233,26 @@ export struct CreateIndexOperatorState : public OperatorState { inline explicit CreateIndexOperatorState() : OperatorState(PhysicalOperatorType::kCreateIndex) {} }; +export struct CreateIndexPrepareOperatorState : public OperatorState { + inline explicit CreateIndexPrepareOperatorState() : OperatorState(PhysicalOperatorType::kCreateIndexPrepare) {} + + UniquePtr result_msg_{}; +}; + +export struct CreateIndexDoOperatorState : public OperatorState { + inline explicit CreateIndexDoOperatorState() : OperatorState(PhysicalOperatorType::kCreateIndexDo) {} + + bool input_complete_ = false; + CreateIndexSharedData *create_index_shared_data_; +}; + +export struct CreateIndexFinishOperatorState : public OperatorState { + inline explicit CreateIndexFinishOperatorState() : OperatorState(PhysicalOperatorType::kCreateIndexFinish) {} + + bool input_complete_ = false; + UniquePtr error_message_{}; +}; + // Create Collection export struct CreateCollectionOperatorState : public OperatorState { inline explicit CreateCollectionOperatorState() : OperatorState(PhysicalOperatorType::kCreateCollection) {} @@ -339,7 +360,7 @@ export 
struct QueueSourceState : public SourceState { bool GetData(); - BlockingQueue> source_queue_{}; + BlockingQueue> source_queue_{}; Map num_tasks_; // fragment_id -> number of pending tasks @@ -356,7 +377,7 @@ export struct AggregateSourceState : public SourceState { i64 hash_start_{}; i64 hash_end_{}; - BlockingQueue> source_queue_{}; + BlockingQueue> source_queue_{}; }; export struct TableScanSourceState : public SourceState { @@ -406,7 +427,7 @@ export struct QueueSinkState : public SinkState { inline explicit QueueSinkState(u64 fragment_id, u64 task_id) : SinkState(SinkStateType::kQueue, fragment_id, task_id) {} Vector> data_block_array_{}; - Vector> *> fragment_data_queues_; + Vector> *> fragment_data_queues_; }; export struct MaterializeSinkState : public SinkState { diff --git a/src/executor/physical_operator_type.cpp b/src/executor/physical_operator_type.cpp index 3fea83d9b7..c93d36b97a 100644 --- a/src/executor/physical_operator_type.cpp +++ b/src/executor/physical_operator_type.cpp @@ -132,6 +132,12 @@ String PhysicalOperatorToString(PhysicalOperatorType type) { return "Fusion"; case PhysicalOperatorType::kMergeAggregate: return "MergeAggregate"; + case PhysicalOperatorType::kCreateIndexPrepare: + return "CreateIndexPrepare"; + case PhysicalOperatorType::kCreateIndexDo: + return "CreateIndexDo"; + case PhysicalOperatorType::kCreateIndexFinish: + return "CreateIndexFinish"; } Error("Unknown physical operator type"); diff --git a/src/executor/physical_operator_type.cppm b/src/executor/physical_operator_type.cppm index 5d24b12b76..e9be8c31ed 100644 --- a/src/executor/physical_operator_type.cppm +++ b/src/executor/physical_operator_type.cppm @@ -70,6 +70,7 @@ export enum class PhysicalOperatorType : i8 { kInsert, kImport, kExport, + kCreateIndexDo, // DDL kAlter, @@ -84,6 +85,9 @@ export enum class PhysicalOperatorType : i8 { kDropDatabase, kDropView, + kCreateIndexPrepare, + kCreateIndexFinish, + // misc kExplain, kPreparedPlan, diff --git 
a/src/executor/physical_planner.cpp b/src/executor/physical_planner.cpp index ed60458b5b..011e2ae19b 100644 --- a/src/executor/physical_planner.cpp +++ b/src/executor/physical_planner.cpp @@ -75,6 +75,9 @@ import physical_drop_index; import physical_command; import physical_match; import physical_fusion; +import physical_create_index_prepare; +import physical_create_index_do; +import physical_create_index_finish; import logical_node; import logical_node_type; @@ -307,14 +310,48 @@ UniquePtr PhysicalPlanner::BuildCreateTable(const SharedPtr PhysicalPlanner::BuildCreateIndex(const SharedPtr &logical_operator) const { auto logical_create_index = static_pointer_cast(logical_operator); - return PhysicalCreateIndex::Make(logical_create_index->schema_name(), - logical_create_index->table_name(), - logical_create_index->index_definition(), - logical_create_index->conflict_type(), - logical_create_index->GetOutputNames(), - logical_create_index->GetOutputTypes(), - logical_create_index->node_id(), - logical_operator->load_metas()); + SharedPtr schema_name = logical_create_index->base_table_ref()->schema_name(); + SharedPtr table_name = logical_create_index->base_table_ref()->table_name(); + const auto &index_def_ptr = logical_create_index->index_definition(); + if (false || index_def_ptr->index_array_.size() != 1 || index_def_ptr->index_array_[0]->index_type_ != IndexType::kHnsw) { + // TODO: invalidate multiple index in one statement. + // TODO: support other index types build in parallel. 
+ // use old `PhysicalCreateIndex` + return PhysicalCreateIndex::Make(schema_name, + table_name, + logical_create_index->index_definition(), + logical_create_index->conflict_type(), + logical_create_index->GetOutputNames(), + logical_create_index->GetOutputTypes(), + logical_create_index->node_id(), + logical_operator->load_metas()); + } + + // use new `PhysicalCreateIndexPrepare` `PhysicalCreateIndexDo` `PhysicalCreateIndexFinish` + auto create_index_prepare = MakeUnique(logical_create_index->node_id(), + schema_name, + table_name, + logical_create_index->index_definition(), + logical_create_index->conflict_type(), + logical_create_index->GetOutputNames(), + logical_create_index->GetOutputTypes(), + logical_create_index->load_metas()); + auto create_index_do = MakeUnique(logical_create_index->node_id(), + Move(create_index_prepare), + logical_create_index->base_table_ref(), + logical_create_index->index_definition()->index_name_, + logical_create_index->GetOutputNames(), + logical_create_index->GetOutputTypes(), + logical_create_index->load_metas()); + auto create_index_finish = MakeUnique(logical_create_index->node_id(), + Move(create_index_do), + schema_name, + table_name, + logical_create_index->index_definition(), + logical_create_index->GetOutputNames(), + logical_create_index->GetOutputTypes(), + logical_create_index->load_metas()); + return create_index_finish; } UniquePtr PhysicalPlanner::BuildCreateCollection(const SharedPtr &logical_operator) const { diff --git a/src/function/table/create_index_data.cppm b/src/function/table/create_index_data.cppm new file mode 100644 index 0000000000..89b49f46df --- /dev/null +++ b/src/function/table/create_index_data.cppm @@ -0,0 +1,36 @@ +// Copyright(C) 2023 InfiniFlow, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +module; + +#include + +import stl; +import segment_entry; + +export module create_index_data; + +namespace infinity { + +export struct CreateIndexSharedData { + CreateIndexSharedData(const Map> &segment_map) { + for (const auto &[segment_id, _] : segment_map) { + create_index_idxes_.emplace(segment_id, 0); + } + } + + HashMap create_index_idxes_{}; +}; + +}; // namespace infinity \ No newline at end of file diff --git a/src/function/table/knn_scan_data.cpp b/src/function/table/knn_scan_data.cpp index a947bb5b8d..f76575d8ff 100644 --- a/src/function/table/knn_scan_data.cpp +++ b/src/function/table/knn_scan_data.cpp @@ -63,8 +63,8 @@ KnnDistance1::KnnDistance1(KnnDistanceType dist_type) { // -------------------------------------------- KnnScanFunctionData::KnnScanFunctionData(KnnScanSharedData* shared_data, u32 current_parallel_idx) - : shared_data_(shared_data), task_id_(current_parallel_idx) { - switch (shared_data_->elem_type_) { + : knn_scan_shared_data_(shared_data), task_id_(current_parallel_idx) { + switch (knn_scan_shared_data_->elem_type_) { case EmbeddingDataType::kElemFloat: { Init(); break; @@ -77,32 +77,32 @@ KnnScanFunctionData::KnnScanFunctionData(KnnScanSharedData* shared_data, u32 cur template void KnnScanFunctionData::Init() { - switch (shared_data_->knn_distance_type_) { + switch (knn_scan_shared_data_->knn_distance_type_) { case KnnDistanceType::kInvalid: { throw ExecutorException("Invalid Knn distance type"); } case KnnDistanceType::kL2: case KnnDistanceType::kHamming: { - auto merge_knn_max = 
MakeUnique>(shared_data_->query_count_, shared_data_->topk_); + auto merge_knn_max = MakeUnique>(knn_scan_shared_data_->query_count_, knn_scan_shared_data_->topk_); merge_knn_max->Begin(); merge_knn_base_ = Move(merge_knn_max); break; } case KnnDistanceType::kCosine: case KnnDistanceType::kInnerProduct: { - auto merge_knn_min = MakeUnique>(shared_data_->query_count_, shared_data_->topk_); + auto merge_knn_min = MakeUnique>(knn_scan_shared_data_->query_count_, knn_scan_shared_data_->topk_); merge_knn_min->Begin(); merge_knn_base_ = Move(merge_knn_min); break; } } - knn_distance_ = MakeUnique>(shared_data_->knn_distance_type_); + knn_distance_ = MakeUnique>(knn_scan_shared_data_->knn_distance_type_); - if (shared_data_->filter_expression_) { - filter_state_ = ExpressionState::CreateState(shared_data_->filter_expression_); + if (knn_scan_shared_data_->filter_expression_) { + filter_state_ = ExpressionState::CreateState(knn_scan_shared_data_->filter_expression_); db_for_filter_ = MakeUnique(); - db_for_filter_->Init(*(shared_data_->table_ref_->column_types_)); // default capacity + db_for_filter_->Init(*(knn_scan_shared_data_->table_ref_->column_types_)); // default capacity bool_column_ = ColumnVector::Make(MakeShared(LogicalType::kBoolean)); // default capacity } } diff --git a/src/function/table/knn_scan_data.cppm b/src/function/table/knn_scan_data.cppm index 509dc8fd20..d1182996c1 100644 --- a/src/function/table/knn_scan_data.cppm +++ b/src/function/table/knn_scan_data.cppm @@ -123,7 +123,7 @@ private: void Init(); public: - KnnScanSharedData* shared_data_; + KnnScanSharedData* knn_scan_shared_data_; const u32 task_id_; UniquePtr merge_knn_base_{}; diff --git a/src/main/resource_manager.cppm b/src/main/resource_manager.cppm index 41380621b2..a8d1192d47 100644 --- a/src/main/resource_manager.cppm +++ b/src/main/resource_manager.cppm @@ -30,7 +30,7 @@ public: return cpu_count; } - inline u64 GetCpuResource() { return GetCpuResource(4); } + inline u64 GetCpuResource() 
{ return GetCpuResource(Thread::hardware_concurrency()); } inline u64 GetMemoryResource(u64 memory_size) { total_memory_ -= memory_size; diff --git a/src/planner/bound/base_table_ref.cppm b/src/planner/bound/base_table_ref.cppm index 6def9fd542..ef4be06ed2 100644 --- a/src/planner/bound/base_table_ref.cppm +++ b/src/planner/bound/base_table_ref.cppm @@ -22,6 +22,7 @@ import table_collection_entry; import parser; import table_function; import block_index; +import db_entry; import infinity_exception; @@ -42,7 +43,7 @@ public: block_index_(Move(block_index)), column_names_(Move(column_names)), column_types_(Move(column_types)), table_index_(table_index) {} void RetainColumnByIndices(const Vector &&indices) { - // OPT1212: linear judge in assert + // FIXME: linear judge in assert if (!std::is_sorted(indices.cbegin(), indices.cend())) { Error("Indices must be in order"); } @@ -52,6 +53,10 @@ public: replace_field>(*column_types_, indices); }; + SharedPtr schema_name() const { return table_entry_ptr_->table_collection_meta_->db_entry_->db_name_; } + + SharedPtr table_name() const { return table_entry_ptr_->table_collection_name_; } + TableCollectionEntry *table_entry_ptr_{}; Vector column_ids_{}; SharedPtr block_index_{}; diff --git a/src/planner/explain_logical_plan.cpp b/src/planner/explain_logical_plan.cpp index 9a9227dc88..82af2b453d 100644 --- a/src/planner/explain_logical_plan.cpp +++ b/src/planner/explain_logical_plan.cpp @@ -313,12 +313,12 @@ void ExplainLogicalPlan::Explain(const LogicalCreateIndex *create_node, SharedPt } { - String schema_name_str = String(intent_size, ' ') + " - schema name: " + *create_node->schema_name(); + String schema_name_str = String(intent_size, ' ') + " - schema name: " + *create_node->base_table_ref()->schema_name(); result->emplace_back(MakeShared(schema_name_str)); } { - String table_name_str = String(intent_size, ' ') + " - table name: " + *create_node->table_name(); + String table_name_str = String(intent_size, ' ') + " - table 
name: " + *create_node->base_table_ref()->table_name(); result->emplace_back(MakeShared(table_name_str)); } diff --git a/src/planner/logical_planner.cpp b/src/planner/logical_planner.cpp index 032c51079d..d78772fc1a 100644 --- a/src/planner/logical_planner.cpp +++ b/src/planner/logical_planner.cpp @@ -68,6 +68,7 @@ import index_base; import index_ivfflat; import index_hnsw; import index_full_text; +import base_table_ref; module logical_planner; @@ -490,8 +491,11 @@ Status LogicalPlanner::BuildCreateIndex(const CreateStatement *statement, Shared index_def_ptr->index_array_.emplace_back(base_index_ptr); } + UniquePtr query_binder_ptr = MakeUnique(this->query_context_ptr_, bind_context_ptr); + auto base_table_ref = std::static_pointer_cast(query_binder_ptr->GetTableRef(*schema_name, *table_name)); + auto logical_create_index_operator = - LogicalCreateIndex::Make(bind_context_ptr->GetNewLogicalNodeId(), schema_name, table_name, index_def_ptr, create_index_info->conflict_type_); + MakeShared(bind_context_ptr->GetNewLogicalNodeId(), base_table_ref, index_def_ptr, create_index_info->conflict_type_); this->logical_plan_ = logical_create_index_operator; this->names_ptr_->emplace_back("OK"); diff --git a/src/planner/node/logical_create_index.cpp b/src/planner/node/logical_create_index.cpp index b622d99390..4163106191 100644 --- a/src/planner/node/logical_create_index.cpp +++ b/src/planner/node/logical_create_index.cpp @@ -47,7 +47,8 @@ String LogicalCreateIndex::ToString(i64 &space) const { space -= 4; arrow_str = "-> "; } - ss << String(space, ' ') << arrow_str << "Create Table: " << *schema_name_ << "." << index_definition_->ToString(); + ss << String(space, ' ') << arrow_str << "Create Table: " << *base_table_ref_->table_name() << "." 
+ << index_definition_->ToString(); space += arrow_str.size(); return ss.str(); diff --git a/src/planner/node/logical_create_index.cppm b/src/planner/node/logical_create_index.cppm index 25d3f8f86e..e9d2c92578 100644 --- a/src/planner/node/logical_create_index.cppm +++ b/src/planner/node/logical_create_index.cppm @@ -19,6 +19,7 @@ import logical_node_type; import column_binding; import logical_node; import parser; +import base_table_ref; export module logical_create_index; @@ -39,31 +40,19 @@ public: inline String name() override { return "LogicalCreateIndex"; } public: - static inline SharedPtr - Make(u64 node_id, SharedPtr schema_name, SharedPtr table_name, SharedPtr index_def, ConflictType conflict_type) { - return MakeShared(node_id, schema_name, table_name, index_def, conflict_type); - } - - inline LogicalCreateIndex(u64 node_id, - SharedPtr schema_name, - SharedPtr table_name, - SharedPtr index_def, - ConflictType conflict_type) - : LogicalNode(node_id, LogicalNodeType::kCreateIndex), schema_name_(schema_name), table_name_(table_name), index_definition_(index_def), + inline LogicalCreateIndex(u64 node_id, SharedPtr base_table_ref, SharedPtr index_def, ConflictType conflict_type) + : LogicalNode(node_id, LogicalNodeType::kCreateIndex), base_table_ref_(base_table_ref), index_definition_(index_def), conflict_type_(conflict_type) {} public: - [[nodiscard]] inline SharedPtr schema_name() const { return schema_name_; } - - [[nodiscard]] inline SharedPtr table_name() const { return table_name_; } + [[nodiscard]] inline SharedPtr base_table_ref() const { return base_table_ref_; } [[nodiscard]] inline SharedPtr index_definition() const { return index_definition_; } [[nodiscard]] inline ConflictType conflict_type() const { return conflict_type_; } private: - SharedPtr schema_name_{}; - SharedPtr table_name_{}; + SharedPtr base_table_ref_{}; SharedPtr index_definition_{}; ConflictType conflict_type_{ConflictType::kInvalid}; }; diff --git a/src/planner/query_binder.cpp 
b/src/planner/query_binder.cpp index 113c2ad93d..b9db45022a 100644 --- a/src/planner/query_binder.cpp +++ b/src/planner/query_binder.cpp @@ -913,13 +913,18 @@ void QueryBinder::CheckKnnAndOrderBy(KnnDistanceType distance_type, OrderType or } } +SharedPtr QueryBinder::GetTableRef(const String &db_name, const String &table_name) { + TableReference from_table; + from_table.db_name_ = db_name; + from_table.table_name_ = table_name; + return BuildBaseTable(this->query_context_ptr_, &from_table); +} + UniquePtr QueryBinder::BindDelete(const DeleteStatement &statement) { // refers to QueryBinder::BindSelect UniquePtr bound_delete_statement = BoundDeleteStatement::Make(bind_context_ptr_); - TableReference from_table; - from_table.db_name_ = statement.schema_name_; - from_table.table_name_ = statement.table_name_; - SharedPtr base_table_ref = QueryBinder::BuildBaseTable(this->query_context_ptr_, &from_table); + SharedPtr base_table_ref = GetTableRef(statement.schema_name_, statement.table_name_); + bound_delete_statement->table_ref_ptr_ = base_table_ref; if (base_table_ref.get() == nullptr) { Error(Format("Cannot bind {}.{} to a table", statement.schema_name_, statement.table_name_)); @@ -937,10 +942,8 @@ UniquePtr QueryBinder::BindDelete(const DeleteStatement &s UniquePtr QueryBinder::BindUpdate(const UpdateStatement &statement) { // refers to QueryBinder::BindSelect UniquePtr bound_update_statement = BoundUpdateStatement::Make(bind_context_ptr_); - TableReference from_table; - from_table.db_name_ = statement.schema_name_; - from_table.table_name_ = statement.table_name_; - SharedPtr base_table_ref = QueryBinder::BuildBaseTable(this->query_context_ptr_, &from_table); + SharedPtr base_table_ref = GetTableRef(statement.schema_name_, statement.table_name_); + bound_update_statement->table_ref_ptr_ = base_table_ref; if (base_table_ref.get() == nullptr) { Error(Format("Cannot bind {}.{} to a table", statement.schema_name_, statement.table_name_)); diff --git 
a/src/planner/query_binder.cppm b/src/planner/query_binder.cppm index 6d57072199..4707027aa5 100644 --- a/src/planner/query_binder.cppm +++ b/src/planner/query_binder.cppm @@ -42,6 +42,8 @@ public: UniquePtr BindUpdate(const UpdateStatement &statement); + SharedPtr GetTableRef(const String &db_name, const String &table_name); + QueryContext *query_context_ptr_; SharedPtr bind_context_ptr_; diff --git a/src/scheduler/fragment_context.cpp b/src/scheduler/fragment_context.cpp index 7d9ad8f448..c2ad399599 100644 --- a/src/scheduler/fragment_context.cpp +++ b/src/scheduler/fragment_context.cpp @@ -31,6 +31,7 @@ import physical_table_scan; import physical_knn_scan; import physical_aggregate; import physical_explain; +import physical_create_index_do; import global_block_id; import knn_expression; @@ -44,6 +45,7 @@ import data_table; import data_block; import physical_merge_knn; import merge_knn_data; +import create_index_data; import logger; import plan_fragment; @@ -58,6 +60,13 @@ UniquePtr MakeTaskStateTemplate(PhysicalOperator *physical_op) { return MakeUnique(); } +UniquePtr MakeCreateIndexDoState(PhysicalCreateIndexDo *physical_create_index_do, FragmentTask *task, FragmentContext *fragment_ctx) { + UniquePtr operator_state = MakeUnique(); + auto *parallel_materialize_fragment_ctx = static_cast(fragment_ctx); + operator_state->create_index_shared_data_ = parallel_materialize_fragment_ctx->create_index_shared_data_.get(); + return operator_state; +} + UniquePtr MakeTableScanState(PhysicalTableScan *physical_table_scan, FragmentTask *task) { SourceState *source_state = task->source_state_.get(); @@ -88,13 +97,13 @@ UniquePtr MakeKnnScanState(PhysicalKnnScan *physical_knn_scan, Fr case FragmentType::kSerialMaterialize: { SerialMaterializedFragmentCtx *serial_materialize_fragment_ctx = static_cast(fragment_ctx); knn_scan_op_state_ptr->knn_scan_function_data_ = - MakeUnique(serial_materialize_fragment_ctx->shared_data_.get(), task->TaskID()); + 
MakeUnique(serial_materialize_fragment_ctx->knn_scan_shared_data_.get(), task->TaskID()); break; } case FragmentType::kParallelMaterialize: { ParallelMaterializedFragmentCtx *parallel_materialize_fragment_ctx = static_cast(fragment_ctx); knn_scan_op_state_ptr->knn_scan_function_data_ = - MakeUnique(parallel_materialize_fragment_ctx->shared_data_.get(), task->TaskID()); + MakeUnique(parallel_materialize_fragment_ctx->knn_scan_shared_data_.get(), task->TaskID()); break; } default: { @@ -218,6 +227,16 @@ MakeTaskState(SizeT operator_id, const Vector &physical_ops, case PhysicalOperatorType::kCreateIndex: { return MakeTaskStateTemplate(physical_ops[operator_id]); } + case PhysicalOperatorType::kCreateIndexPrepare: { + return MakeTaskStateTemplate(physical_ops[operator_id]); + } + case PhysicalOperatorType::kCreateIndexDo: { + auto *physical_create_index_do = static_cast(physical_ops[operator_id]); + return MakeCreateIndexDoState(physical_create_index_do, task, fragment_ctx); + } + case PhysicalOperatorType::kCreateIndexFinish: { + return MakeTaskStateTemplate(physical_ops[operator_id]); + } case PhysicalOperatorType::kCreateCollection: { return MakeTaskStateTemplate(physical_ops[operator_id]); } @@ -454,32 +473,32 @@ SizeT InitKnnScanFragmentContext(PhysicalKnnScan *knn_scan_operator, FragmentCon switch (fragment_context->ContextType()) { case FragmentType::kSerialMaterialize: { SerialMaterializedFragmentCtx *serial_materialize_fragment_ctx = static_cast(fragment_context); - serial_materialize_fragment_ctx->shared_data_ = MakeUnique(knn_scan_operator->base_table_ref_, - knn_scan_operator->filter_expression_, - Move(knn_scan_operator->block_column_entries_), - Move(knn_scan_operator->index_entries_), - Move(knn_expr->opt_params_), - knn_expr->topn_, - knn_expr->dimension_, - 1, - knn_expr->query_embedding_.ptr, - knn_expr->embedding_data_type_, - knn_expr->distance_type_); + serial_materialize_fragment_ctx->knn_scan_shared_data_ = 
MakeUnique(knn_scan_operator->base_table_ref_, + knn_scan_operator->filter_expression_, + Move(knn_scan_operator->block_column_entries_), + Move(knn_scan_operator->index_entries_), + Move(knn_expr->opt_params_), + knn_expr->topn_, + knn_expr->dimension_, + 1, + knn_expr->query_embedding_.ptr, + knn_expr->embedding_data_type_, + knn_expr->distance_type_); break; } case FragmentType::kParallelMaterialize: { ParallelMaterializedFragmentCtx *parallel_materialize_fragment_ctx = static_cast(fragment_context); - parallel_materialize_fragment_ctx->shared_data_ = MakeUnique(knn_scan_operator->base_table_ref_, - knn_scan_operator->filter_expression_, - Move(knn_scan_operator->block_column_entries_), - Move(knn_scan_operator->index_entries_), - Move(knn_expr->opt_params_), - knn_expr->topn_, - knn_expr->dimension_, - 1, - knn_expr->query_embedding_.ptr, - knn_expr->embedding_data_type_, - knn_expr->distance_type_); + parallel_materialize_fragment_ctx->knn_scan_shared_data_ = MakeUnique(knn_scan_operator->base_table_ref_, + knn_scan_operator->filter_expression_, + Move(knn_scan_operator->block_column_entries_), + Move(knn_scan_operator->index_entries_), + Move(knn_expr->opt_params_), + knn_expr->topn_, + knn_expr->dimension_, + 1, + knn_expr->query_embedding_.ptr, + knn_expr->embedding_data_type_, + knn_expr->distance_type_); break; } default: { @@ -490,65 +509,18 @@ SizeT InitKnnScanFragmentContext(PhysicalKnnScan *knn_scan_operator, FragmentCon return task_n; } -// Allocate tasks for the fragment and determine the sink and source -void FragmentContext::CreateTasks(i64 cpu_count, i64 operator_count) { - i64 parallel_count = cpu_count; - PhysicalOperator *first_operator = this->GetOperators().back(); - switch (first_operator->operator_type()) { - case PhysicalOperatorType::kTableScan: { - auto *table_scan_operator = static_cast(first_operator); - parallel_count = Min(parallel_count, (i64)(table_scan_operator->TaskletCount())); - if (parallel_count == 0) { - parallel_count = 1; 
- } - break; - } - case PhysicalOperatorType::kKnnScan: { - auto *knn_scan_operator = static_cast(first_operator); - SizeT task_n = InitKnnScanFragmentContext(knn_scan_operator, this, query_context_); - parallel_count = Min(parallel_count, (i64)task_n); - if (parallel_count == 0) { - parallel_count = 1; - } - break; - } - case PhysicalOperatorType::kMatch: - case PhysicalOperatorType::kMergeKnn: - case PhysicalOperatorType::kProjection: { - // Serial Materialize - parallel_count = 1; - break; - } - default: { - break; - } - } +SizeT InitCreateIndexDoFragmentContext(const PhysicalCreateIndexDo *create_index_do_operator, FragmentContext *fragment_ctx) { + auto *table_entry = create_index_do_operator->base_table_ref_->table_entry_ptr_; + // FIXME: to create index on unsealed_segment + SizeT segment_cnt = table_entry->segment_map_.size(); - switch (fragment_type_) { - case FragmentType::kInvalid: { - Error("Invalid fragment type"); - } - case FragmentType::kSerialMaterialize: { - UniqueLock locker(locker_); - parallel_count = 1; - tasks_.reserve(parallel_count); - tasks_.emplace_back(MakeUnique(this, 0, operator_count)); - IncreaseTask(); - break; - } - case FragmentType::kParallelMaterialize: - case FragmentType::kParallelStream: { - UniqueLock locker(locker_); - tasks_.reserve(parallel_count); - for (i64 task_id = 0; task_id < parallel_count; ++task_id) { - tasks_.emplace_back(MakeUnique(this, task_id, operator_count)); - IncreaseTask(); - } - break; - } - } + auto *parallel_materialize_fragment_ctx = static_cast(fragment_ctx); + parallel_materialize_fragment_ctx->create_index_shared_data_ = MakeUnique(table_entry->segment_map_); + return segment_cnt; +} - // Determine which type of source state. 
+void FragmentContext::MakeSourceState(i64 parallel_count) { + PhysicalOperator *first_operator = this->GetOperators().back(); switch (first_operator->operator_type()) { case PhysicalOperatorType::kInvalid: { Error("Unexpected operator type"); @@ -589,7 +561,8 @@ void FragmentContext::CreateTasks(i64 cpu_count, i64 operator_count) { case PhysicalOperatorType::kMergeTop: case PhysicalOperatorType::kMergeSort: case PhysicalOperatorType::kMergeKnn: - case PhysicalOperatorType::kFusion: { + case PhysicalOperatorType::kFusion: + case PhysicalOperatorType::kCreateIndexFinish: { if (fragment_type_ != FragmentType::kSerialMaterialize) { Error( Format("{} should be serial materialized fragment", PhysicalOperatorToString(first_operator->operator_type()))); @@ -602,6 +575,16 @@ void FragmentContext::CreateTasks(i64 cpu_count, i64 operator_count) { tasks_[0]->source_state_ = MakeUnique(); break; } + case PhysicalOperatorType::kCreateIndexDo: { + if (fragment_type_ != FragmentType::kParallelMaterialize) { + Error( + Format("{} should in parallel materialized fragment", PhysicalOperatorToString(first_operator->operator_type()))); + } + for (auto &task : tasks_) { + task->source_state_ = MakeUnique(); + } + break; + } case PhysicalOperatorType::kUnionAll: case PhysicalOperatorType::kIntersect: case PhysicalOperatorType::kExcept: @@ -650,6 +633,7 @@ void FragmentContext::CreateTasks(i64 cpu_count, i64 operator_count) { case PhysicalOperatorType::kAlter: case PhysicalOperatorType::kCreateTable: case PhysicalOperatorType::kCreateIndex: + case PhysicalOperatorType::kCreateIndexPrepare: case PhysicalOperatorType::kCreateCollection: case PhysicalOperatorType::kCreateDatabase: case PhysicalOperatorType::kCreateView: @@ -680,8 +664,10 @@ void FragmentContext::CreateTasks(i64 cpu_count, i64 operator_count) { Error(Format("Unexpected operator type: {}", PhysicalOperatorToString(first_operator->operator_type()))); } } +} - // Determine which type of the sink state. 
+void FragmentContext::MakeSinkState(i64 parallel_count) { + PhysicalOperator *first_operator = this->GetOperators().back(); PhysicalOperator *last_operator = this->GetOperators().front(); switch (last_operator->operator_type()) { @@ -753,7 +739,9 @@ void FragmentContext::CreateTasks(i64 cpu_count, i64 operator_count) { break; } case PhysicalOperatorType::kSort: - case PhysicalOperatorType::kKnnScan: { + case PhysicalOperatorType::kKnnScan: + case PhysicalOperatorType::kCreateIndexPrepare: + case PhysicalOperatorType::kCreateIndexDo: { if (fragment_type_ != FragmentType::kParallelMaterialize && fragment_type_ != FragmentType::kSerialMaterialize) { Error( Format("{} should in parallel/serial materialized fragment", PhysicalOperatorToString(first_operator->operator_type()))); @@ -850,6 +838,7 @@ void FragmentContext::CreateTasks(i64 cpu_count, i64 operator_count) { case PhysicalOperatorType::kCommand: case PhysicalOperatorType::kCreateTable: case PhysicalOperatorType::kCreateIndex: + case PhysicalOperatorType::kCreateIndexFinish: case PhysicalOperatorType::kCreateCollection: case PhysicalOperatorType::kCreateDatabase: case PhysicalOperatorType::kCreateView: @@ -878,6 +867,77 @@ void FragmentContext::CreateTasks(i64 cpu_count, i64 operator_count) { } } +// Allocate tasks for the fragment and determine the sink and source +void FragmentContext::CreateTasks(i64 cpu_count, i64 operator_count) { + i64 parallel_count = cpu_count; + PhysicalOperator *first_operator = this->GetOperators().back(); + switch (first_operator->operator_type()) { + case PhysicalOperatorType::kTableScan: { + auto *table_scan_operator = static_cast(first_operator); + parallel_count = Min(parallel_count, (i64)(table_scan_operator->TaskletCount())); + if (parallel_count == 0) { + parallel_count = 1; + } + break; + } + case PhysicalOperatorType::kKnnScan: { + auto *knn_scan_operator = static_cast(first_operator); + SizeT task_n = InitKnnScanFragmentContext(knn_scan_operator, this, query_context_); + 
parallel_count = Min(parallel_count, (i64)task_n); + if (parallel_count == 0) { + parallel_count = 1; + } + break; + } + case PhysicalOperatorType::kMatch: + case PhysicalOperatorType::kMergeKnn: + case PhysicalOperatorType::kProjection: { + // Serial Materialize + parallel_count = 1; + break; + } + case PhysicalOperatorType::kCreateIndexDo: { + const auto *create_index_do_operator = static_cast(first_operator); + SizeT segment_n = InitCreateIndexDoFragmentContext(create_index_do_operator, this); + parallel_count = Max(parallel_count, 1l); + break; + } + default: { + break; + } + } + + switch (fragment_type_) { + case FragmentType::kInvalid: { + Error("Invalid fragment type"); + } + case FragmentType::kSerialMaterialize: { + UniqueLock locker(locker_); + parallel_count = 1; + tasks_.reserve(parallel_count); + tasks_.emplace_back(MakeUnique(this, 0, operator_count)); + IncreaseTask(); + break; + } + case FragmentType::kParallelMaterialize: + case FragmentType::kParallelStream: { + UniqueLock locker(locker_); + tasks_.reserve(parallel_count); + for (i64 task_id = 0; task_id < parallel_count; ++task_id) { + tasks_.emplace_back(MakeUnique(this, task_id, operator_count)); + IncreaseTask(); + } + break; + } + } + + // Determine which type of source state. + MakeSourceState(parallel_count); + + // Determine which type of the sink state. 
+ MakeSinkState(parallel_count); +} + SharedPtr SerialMaterializedFragmentCtx::GetResultInternal() { // Only one sink state if (tasks_.size() != 1) { diff --git a/src/scheduler/fragment_context.cppm b/src/scheduler/fragment_context.cppm index 8d2f3743a9..807c5caf83 100644 --- a/src/scheduler/fragment_context.cppm +++ b/src/scheduler/fragment_context.cppm @@ -24,6 +24,7 @@ import physical_sink; import data_table; import data_block; import knn_scan_data; +import create_index_data; export module fragment_context; @@ -96,6 +97,11 @@ public: [[nodiscard]] inline FragmentType ContextType() const { return fragment_type_; } +private: + void MakeSourceState(i64 parallel_count); + + void MakeSinkState(i64 parallel_count); + protected: virtual SharedPtr GetResultInternal() = 0; @@ -128,7 +134,7 @@ public: SharedPtr GetResultInternal() final; public: - UniquePtr shared_data_{}; + UniquePtr knn_scan_shared_data_{}; }; export class ParallelMaterializedFragmentCtx final : public FragmentContext { @@ -141,7 +147,9 @@ public: SharedPtr GetResultInternal() final; public: - UniquePtr shared_data_{}; + UniquePtr knn_scan_shared_data_{}; + + UniquePtr create_index_shared_data_{}; protected: HashMap>> task_results_{}; diff --git a/src/scheduler/fragment_data.cppm b/src/scheduler/fragment_data.cppm index d6323a060f..d8dfcebbdb 100644 --- a/src/scheduler/fragment_data.cppm +++ b/src/scheduler/fragment_data.cppm @@ -21,13 +21,40 @@ export module fragment_data; namespace infinity { -export struct FragmentData { - UniquePtr data_block_{}; - UniquePtr error_message_{}; +export enum class FragmentDataType { + kData, + kNone, + kError, + kInvalid, +}; + +export struct FragmentDataBase { + FragmentDataType type_{FragmentDataType::kInvalid}; u64 fragment_id_{u64_max}; + + FragmentDataBase(FragmentDataType type, u64 fragment_id) : type_(type), fragment_id_(fragment_id) {} +}; + +export struct FragmentError : public FragmentDataBase { + UniquePtr error_message_{}; + + FragmentError(u64 fragment_id, 
UniquePtr error_message) + : FragmentDataBase(FragmentDataType::kError, fragment_id), error_message_(Move(error_message)) {} +}; + +export struct FragmentData : public FragmentDataBase { + UniquePtr data_block_{}; i64 task_id_{-1}; SizeT data_idx_{u64_max}; SizeT data_count_{u64_max}; + + FragmentData(u64 fragment_id, UniquePtr data_block, i64 task_id, SizeT data_idx, SizeT data_count) + : FragmentDataBase(FragmentDataType::kData, fragment_id), data_block_(Move(data_block)), task_id_(task_id), data_idx_(data_idx), + data_count_(data_count) {} +}; + +export struct FragmentNone : public FragmentDataBase { + FragmentNone(u64 fragment_id) : FragmentDataBase(FragmentDataType::kNone, fragment_id) {} }; } // namespace infinity diff --git a/src/scheduler/task_scheduler.cpp b/src/scheduler/task_scheduler.cpp index deb58d29cc..47b6edf2b1 100644 --- a/src/scheduler/task_scheduler.cpp +++ b/src/scheduler/task_scheduler.cpp @@ -30,6 +30,7 @@ import query_context; import plan_fragment; import fragment_context; import default_values; +import physical_operator_type; module task_scheduler; @@ -93,7 +94,20 @@ void TaskScheduler::Schedule(QueryContext *query_context, const VectorGetOperators().empty()) { + Error("Empty fragment"); + } + auto *last_operator = plan_fragment->GetOperators()[0]; + switch (last_operator->operator_type()) { + case PhysicalOperatorType::kCreateIndexFinish: { + ScheduleRoundRobin(query_context, tasks, plan_fragment); + break; + } + default: { + ScheduleOneWorkerIfPossible(query_context, tasks, plan_fragment); + break; + } + } } void TaskScheduler::ScheduleOneWorkerPerQuery(QueryContext *query_context, const Vector &tasks, PlanFragment *plan_fragment) { @@ -138,9 +152,11 @@ void TaskScheduler::ScheduleOneWorkerIfPossible(QueryContext *query_context, con void TaskScheduler::ScheduleRoundRobin(QueryContext *query_context, const Vector &tasks, PlanFragment *plan_fragment) { LOG_TRACE(Format("Schedule {} tasks of query id: {} into scheduler with RR policy", 
tasks.size(), query_context->query_id())); + u64 worker_id = 0; for (const auto &fragment_task : tasks) { - u64 worker_id = ProposedWorkerID(worker_count_); ScheduleTask(fragment_task, worker_id); + worker_id++; + worker_id %= worker_count_; } } diff --git a/src/storage/knnindex/knn_hnsw/dist_func_l2.cppm b/src/storage/knnindex/knn_hnsw/dist_func_l2.cppm index b9e97aeed5..15ac7f6952 100644 --- a/src/storage/knnindex/knn_hnsw/dist_func_l2.cppm +++ b/src/storage/knnindex/knn_hnsw/dist_func_l2.cppm @@ -45,13 +45,13 @@ public: if (dim % 16 == 0) { SIMDFunc = F32L2AVX512; } else { - SIMDFunc = F32L2BF; + SIMDFunc = F32L2AVX512Residual; } #elif defined(USE_AVX) if (dim % 16 == 0) { SIMDFunc = F32L2AVX; } else { - SIMDFunc = F32L2BF; + SIMDFunc = F32L2AVXResidual; } #else SIMDFunc = F32L2BF; @@ -101,13 +101,13 @@ public: if (dim % 64 == 0) { SIMDFunc = I8IPAVX512; } else { - SIMDFunc = I8IPBF; + SIMDFunc = I8IPAVX512Residual; } #elif defined(USE_AVX) if (dim % 32 == 0) { SIMDFunc = I8IPAVX; } else { - SIMDFunc = I8IPBF; + SIMDFunc = I8IPAVXResidual; } #else SIMDFunc = I8IPBF; diff --git a/src/storage/knnindex/knn_hnsw/graph_store.cppm b/src/storage/knnindex/knn_hnsw/graph_store.cppm index af76768c2b..712d3b794a 100644 --- a/src/storage/knnindex/knn_hnsw/graph_store.cppm +++ b/src/storage/knnindex/knn_hnsw/graph_store.cppm @@ -103,6 +103,7 @@ private: graph_(static_cast(operator new[](max_vertex * level0_size_, std::align_val_t(8)))), // loaded_vertex_n_(loaded_vertex_n), // loaded_layers_(loaded_layers) // + // {} void Init() { @@ -261,6 +262,9 @@ public: VertexType neighbor_idx = neighbors[i]; assert(neighbor_idx < cur_vertex_n && neighbor_idx >= 0); assert(neighbor_idx != vertex_i); + + int n_layer = GetLevel0(neighbor_idx).GetLayers().second; + assert(n_layer >= layer_i); } } } @@ -283,7 +287,13 @@ public: os << "layer " << layer << std::endl; for (VertexType vertex_i : layer2vertex[layer]) { os << vertex_i << ": "; - auto [neighbors, neighbor_n] = 
GetLevel0(vertex_i).GetNeighbors(); + const int *neighbors = nullptr; + int neighbor_n = 0; + if (layer == 0) { + std::tie(neighbors, neighbor_n) = GetLevel0(vertex_i).GetNeighbors(); + } else { + std::tie(neighbors, neighbor_n) = GetLevelX(GetLevel0(vertex_i), layer).GetNeighbors(); + } for (int i = 0; i < neighbor_n; ++i) { os << neighbors[i] << ", "; } diff --git a/src/storage/knnindex/knn_hnsw/hnsw_alg.cppm b/src/storage/knnindex/knn_hnsw/hnsw_alg.cppm index 683e666424..c637e7ae64 100644 --- a/src/storage/knnindex/knn_hnsw/hnsw_alg.cppm +++ b/src/storage/knnindex/knn_hnsw/hnsw_alg.cppm @@ -13,6 +13,7 @@ // limitations under the License. module; +#include #include #include @@ -49,8 +50,8 @@ public: using PDV = Pair; using CMP = CompareByFirst; + using CMPReverse = CompareByFirstReverse; using DistHeap = Heap; - using HnswLabelType = LabelType; constexpr static int prefetch_offset_ = 0; constexpr static int prefetch_step_ = 2; @@ -71,6 +72,9 @@ private: Distance distance_; const UniquePtr labels_; + Mutex global_mutex_; + mutable Vector vertex_mutex_; + private: KnnHnsw(SizeT M, SizeT Mmax, @@ -87,7 +91,8 @@ private: data_store_(Move(data_store)), // graph_store_(Move(graph_store)), // distance_(Move(distance)), // - labels_(Move(labels)) { + labels_(Move(labels)), // + vertex_mutex_(data_store_.max_vec_num()) { if (ef == 0) { ef = ef_construction_; } @@ -126,83 +131,19 @@ private: return static_cast(r); } - template - requires DataIteratorConcept - VertexType StoreData(Iterator iter, const LabelType *labels, SizeT insert_n) { - auto ret = data_store_.AddVec(iter, insert_n); - if (ret == DataStore::ERR_IDX) { - Error("Data index is not enough."); - } - std::copy(labels, labels + insert_n, labels_.get() + ret); - return ret; - } - -public: // return the nearest `ef_construction_` neighbors of `query` in layer `layer_idx` - // return value is a max heap of distance - DistHeap SearchLayer(VertexType enter_point, const StoreType &query, i32 layer_idx, SizeT 
candidate_n) const { - DistHeap result; + template + Vector SearchLayer(VertexType enter_point, const StoreType &query, i32 layer_idx, SizeT result_n, const Bitmask &bitmask = Bitmask()) const { + Vector result; DistHeap candidate; data_store_.Prefetch(enter_point); - DataType dist = distance_(query, data_store_.GetVec(enter_point), data_store_); - + // enter_point will not be added to result_handler, the distance is not used + auto dist = bitmask.IsTrue(enter_point) ? distance_(query, data_store_.GetVec(enter_point), data_store_) : 0; candidate.emplace(-dist, enter_point); - result.emplace(dist, enter_point); - - Vector visited(data_store_.cur_vec_num(), false); - visited[enter_point] = true; - - while (!candidate.empty()) { - const auto [minus_c_dist, c_idx] = candidate.top(); - candidate.pop(); - if (-minus_c_dist > result.top().first && result.size() == candidate_n) { - break; - } - const auto [neighbors_p, neighbor_size] = graph_store_.GetNeighbors(c_idx, layer_idx); - int prefetch_start = neighbor_size - 1 - prefetch_offset_; - for (int i = neighbor_size - 1; i >= 0; --i) { - VertexType n_idx = neighbors_p[i]; - if (visited[n_idx]) { - continue; - } - visited[n_idx] = true; - if (prefetch_start >= 0) { - int lower = Max(0, prefetch_start - prefetch_step_); - for (int i = prefetch_start; i >= lower; --i) { - data_store_.Prefetch(neighbors_p[i]); - } - prefetch_start -= prefetch_step_; - } - dist = distance_(query, data_store_.GetVec(n_idx), data_store_); - if (dist < result.top().first || result.size() < candidate_n) { - candidate.emplace(-dist, n_idx); - result.emplace(dist, n_idx); - if (result.size() > candidate_n) { - result.pop(); - } - } - } - } - return result; - } - - DistHeap SearchLayer(VertexType enter_point, const StoreType &query, i32 layer_idx, SizeT candidate_n, const Bitmask &bitmask) const { - if (bitmask.IsAllTrue()) { - return SearchLayer(enter_point, query, layer_idx, candidate_n); - } - DistHeap result; - DistHeap candidate; - - DataType 
dist{}; if (bitmask.IsTrue(enter_point)) { - data_store_.Prefetch(enter_point); - dist = distance_(query, data_store_.GetVec(enter_point), data_store_); - - candidate.emplace(-dist, enter_point); - result.emplace(dist, enter_point); - } else { - candidate.emplace(LimitMax(), enter_point); + result.emplace_back(dist, enter_point); + std::push_heap(result.begin(), result.end(), CMP()); } Vector visited(data_store_.cur_vec_num(), false); @@ -211,118 +152,15 @@ public: while (!candidate.empty()) { const auto [minus_c_dist, c_idx] = candidate.top(); candidate.pop(); - if (result.size() == candidate_n && -minus_c_dist > result.top().first) { + if (result.size() == result_n && -minus_c_dist > result[0].first) { break; } - const auto [neighbors_p, neighbor_size] = graph_store_.GetNeighbors(c_idx, layer_idx); - int prefetch_start = neighbor_size - 1 - prefetch_offset_; - for (int i = neighbor_size - 1; i >= 0; --i) { - VertexType n_idx = neighbors_p[i]; - if (visited[n_idx]) { - continue; - } - visited[n_idx] = true; - if (prefetch_start >= 0) { - int lower = Max(0, prefetch_start - prefetch_step_); - for (int i = prefetch_start; i >= lower; --i) { - data_store_.Prefetch(neighbors_p[i]); - } - prefetch_start -= prefetch_step_; - } - dist = distance_(query, data_store_.GetVec(n_idx), data_store_); - if (result.size() < candidate_n || dist < result.top().first) { - candidate.emplace(-dist, n_idx); - if (bitmask.IsTrue(n_idx)) { - result.emplace(dist, n_idx); - if (result.size() > candidate_n) { - result.pop(); - } - } - } - } - } - return result; - } - - Pair, UniquePtr>> - SearchLayerReturnPair(VertexType enter_point, const StoreType &query, i32 layer_idx, SizeT candidate_n) const { - auto d_ptr = MakeUniqueForOverwrite(candidate_n); - auto i_ptr = MakeUniqueForOverwrite(candidate_n); - HeapResultHandler> result_handler(1, candidate_n, d_ptr.get(), i_ptr.get()); - result_handler.Begin(); - DistHeap candidate; - data_store_.Prefetch(enter_point); - auto dist = 
distance_(query, data_store_.GetVec(enter_point), data_store_); - - candidate.emplace(-dist, enter_point); - result_handler.AddResult(0, dist, enter_point); - - Vector visited(data_store_.cur_vec_num(), false); - visited[enter_point] = true; - - while (!candidate.empty()) { - const auto [minus_c_dist, c_idx] = candidate.top(); - candidate.pop(); - if (result_handler.GetSize(0) == candidate_n && -minus_c_dist > result_handler.GetDistance0(0)) { - break; + SharedLock lock; + if constexpr (WithLock) { + lock = SharedLock(vertex_mutex_[c_idx]); } - const auto [neighbors_p, neighbor_size] = graph_store_.GetNeighbors(c_idx, layer_idx); - int prefetch_start = neighbor_size - 1 - prefetch_offset_; - for (int i = neighbor_size - 1; i >= 0; --i) { - VertexType n_idx = neighbors_p[i]; - if (visited[n_idx]) { - continue; - } - visited[n_idx] = true; - if (prefetch_start >= 0) { - int lower = Max(0, prefetch_start - prefetch_step_); - for (int i = prefetch_start; i >= lower; --i) { - data_store_.Prefetch(neighbors_p[i]); - } - prefetch_start -= prefetch_step_; - } - auto dist = distance_(query, data_store_.GetVec(n_idx), data_store_); - if (result_handler.GetSize(0) < candidate_n || dist < result_handler.GetDistance0(0)) { - candidate.emplace(-dist, n_idx); - result_handler.AddResult(0, dist, n_idx); - } - } - } - result_handler.EndWithoutSort(); - return {result_handler.GetSize(0), MakePair(Move(d_ptr), Move(i_ptr))}; - } - - Pair, UniquePtr>> - SearchLayerReturnPair(VertexType enter_point, const StoreType &query, i32 layer_idx, SizeT candidate_n, const Bitmask &bitmask) const { - if (bitmask.IsAllTrue()) { - return SearchLayerReturnPair(enter_point, query, layer_idx, candidate_n); - } - auto d_ptr = MakeUniqueForOverwrite(candidate_n); - auto v_ptr = MakeUniqueForOverwrite(candidate_n); - HeapResultHandler> result_handler(1, candidate_n, d_ptr.get(), v_ptr.get()); - result_handler.Begin(); - DistHeap candidate; - - if (bitmask.IsTrue(enter_point)) { - 
data_store_.Prefetch(enter_point); - auto dist = distance_(query, data_store_.GetVec(enter_point), data_store_); - candidate.emplace(-dist, enter_point); - result_handler.AddResult(0, dist, enter_point); - } else { - candidate.emplace(LimitMax(), enter_point); - } - - Vector visited(data_store_.cur_vec_num(), false); - visited[enter_point] = true; - - while (!candidate.empty()) { - const auto [minus_c_dist, c_idx] = candidate.top(); - candidate.pop(); - if (result_handler.GetSize(0) == candidate_n && -minus_c_dist > result_handler.GetDistance0(0)) { - break; - } const auto [neighbors_p, neighbor_size] = graph_store_.GetNeighbors(c_idx, layer_idx); int prefetch_start = neighbor_size - 1 - prefetch_offset_; for (int i = neighbor_size - 1; i >= 0; --i) { @@ -339,24 +177,36 @@ public: prefetch_start -= prefetch_step_; } auto dist = distance_(query, data_store_.GetVec(n_idx), data_store_); - if (result_handler.GetSize(0) < candidate_n || dist < result_handler.GetDistance0(0)) { + if (result.size() < result_n || dist < result[0].first) { candidate.emplace(-dist, n_idx); if (bitmask.IsTrue(n_idx)) { - result_handler.AddResult(0, dist, n_idx); + if (result.size() == result_n) { + std::pop_heap(result.begin(), result.end(), CMP()); + result.pop_back(); + } + result.emplace_back(dist, n_idx); + std::push_heap(result.begin(), result.end(), CMP()); } } } } - result_handler.EndWithoutSort(); - return {result_handler.GetSize(0), MakePair(Move(d_ptr), Move(v_ptr))}; + + return result; } + template VertexType SearchLayerNearest(VertexType enter_point, const StoreType &query, i32 layer_idx) const { VertexType cur_p = enter_point; DataType cur_dist = distance_(query, data_store_.GetVec(cur_p), data_store_); bool check = true; while (check) { check = false; + + SharedLock lock; + if constexpr (WithLock) { + lock = SharedLock(vertex_mutex_[cur_p]); + } + const auto [neighbors_p, neighbor_size] = graph_store_.GetNeighbors(cur_p, layer_idx); for (int i = neighbor_size - 1; i >= 0; --i) 
{ VertexType n_idx = neighbors_p[i]; @@ -371,24 +221,25 @@ public: return cur_p; } - // DistHeap is the min heap whose key is the minus distance to query - // result distance is increasing - void SelectNeighborsHeuristic(DistHeap &candidates, SizeT M, VertexType *result_p, VertexListSize *result_size_p) const { + // the function does not need mutex because the lock of `result_p` is already acquired + void SelectNeighborsHeuristic(Vector candidates, SizeT M, VertexType *result_p, VertexListSize *result_size_p) const { VertexListSize result_size = 0; - if (SizeT c_size = candidates.size(); c_size < M) { - while (!candidates.empty()) { - result_p[result_size++] = candidates.top().second; - candidates.pop(); + if (candidates.size() < M) { + // std::sort(candidates.begin(), candidates.end(), CMPReverse()); + for (const auto &[_, idx] : candidates) { + result_p[result_size++] = idx; } } else { + std::make_heap(candidates.begin(), candidates.end(), CMPReverse()); while (!candidates.empty() && SizeT(result_size) < M) { - const auto &[minus_c_dist, c_idx] = candidates.top(); + std::pop_heap(candidates.begin(), candidates.end(), CMPReverse()); + const auto &[c_dist, c_idx] = candidates.back(); StoreType c_data = data_store_.GetVec(c_idx); bool check = true; for (SizeT i = 0; i < SizeT(result_size); ++i) { VertexType r_idx = result_p[i]; DataType cr_dist = distance_(c_data, data_store_.GetVec(r_idx), data_store_); - if (cr_dist < -minus_c_dist) { + if (cr_dist < c_dist) { check = false; break; } @@ -396,15 +247,23 @@ public: if (check) { result_p[result_size++] = c_idx; } - candidates.pop(); + candidates.pop_back(); } + // std::reverse(result_p, result_p + result_size); } *result_size_p = result_size; } + template void ConnectNeighbors(VertexType vertex_i, const VertexType *q_neighbors_p, VertexListSize q_neighbor_size, i32 layer_idx) { for (int i = 0; i < q_neighbor_size; ++i) { VertexType n_idx = q_neighbors_p[i]; + + UniqueLock lock; + if constexpr (WithLock) { + lock = 
UniqueLock(vertex_mutex_[n_idx]); + } + auto [n_neighbors_p, n_neighbor_size_p] = graph_store_.GetNeighborsMut(n_idx, layer_idx); VertexListSize n_neighbor_size = *n_neighbor_size_p; SizeT Mmax = layer_idx == 0 ? Mmax0_ : Mmax_; @@ -416,124 +275,111 @@ public: StoreType n_data = data_store_.GetVec(n_idx); DataType n_dist = distance_(n_data, data_store_.GetVec(vertex_i), data_store_); - Vector tmp; - tmp.reserve(n_neighbor_size + 1); - tmp.emplace_back(-n_dist, vertex_i); + Vector candidates; + candidates.reserve(n_neighbor_size + 1); + candidates.emplace_back(n_dist, vertex_i); for (int i = 0; i < n_neighbor_size; ++i) { - tmp.emplace_back(-distance_(n_data, data_store_.GetVec(n_neighbors_p[i]), data_store_), n_neighbors_p[i]); + candidates.emplace_back(distance_(n_data, data_store_.GetVec(n_neighbors_p[i]), data_store_), n_neighbors_p[i]); } - DistHeap candidates(tmp.begin(), tmp.end()); - SelectNeighborsHeuristic(candidates, Mmax, n_neighbors_p, n_neighbor_size_p); // write in memory + SelectNeighborsHeuristic(Move(candidates), Mmax, n_neighbors_p, n_neighbor_size_p); // write in memory } } +public: + // This function will be removed template requires DataIteratorConcept - void InsertVecs(Iterator query_iter, const LabelType *labels, SizeT insert_n) { - const VertexType vertex_i1 = StoreData(Move(query_iter), labels, insert_n); - for (SizeT i = 0; i < insert_n; ++i) { - StoreType query = data_store_.GetVec(vertex_i1 + i); - i32 q_layer = GenerateRandomLayer(); - i32 max_layer = graph_store_.max_layer(); - VertexType ep = graph_store_.enterpoint(); - VertexType vertex_i = vertex_i1 + i; - graph_store_.AddVertex(vertex_i, q_layer); - - for (i32 cur_layer = max_layer; cur_layer > q_layer; --cur_layer) { - ep = SearchLayerNearest(ep, query, cur_layer); - } - for (i32 cur_layer = Min(q_layer, max_layer); cur_layer >= 0; --cur_layer) { - DistHeap search_result = SearchLayer(ep, query, cur_layer, ef_construction_); // TODO:: use pool - DistHeap candidates; - while 
(!search_result.empty()) { - const auto &[dist, idx] = search_result.top(); - candidates.emplace(-dist, idx); - search_result.pop(); - } - const auto [q_neighbors_p, q_neighbor_size_p] = graph_store_.GetNeighborsMut(vertex_i, cur_layer); - SelectNeighborsHeuristic(candidates, M_, q_neighbors_p, q_neighbor_size_p); - ep = q_neighbors_p[0]; - ConnectNeighbors(vertex_i, q_neighbors_p, *q_neighbor_size_p, cur_layer); - } - if (i && i % 10000 == 0) { - std::cout << "Inserted " << i << " / " << insert_n << std::endl; - } + void InsertVecs(Iterator iter, const LabelType *labels, SizeT insert_n) { + VertexType start_i = StoreData(iter, labels, insert_n); + for (VertexType vertex_i = start_i; vertex_i < VertexType(start_i + insert_n); ++vertex_i) { + Build(vertex_i); } } - // this two interface is for test and benchmark - void Insert(const DataType *queries, const LabelType *labels, SizeT insert_n) { - InsertVecs(DenseVectorIter(queries, data_store_.dim(), insert_n), labels, insert_n); + // This function for test + void InsertVecs(const DataType *query, const LabelType *labels, SizeT insert_n) { + VertexType start_i = StoreDataRaw(query, labels, insert_n); + for (VertexType vertex_i = start_i; vertex_i < VertexType(start_i + insert_n); ++vertex_i) { + Build(vertex_i); + } } - void Insert(const DataType *query, LabelType label) { Insert(query, &label, 1); } - MaxHeap> KnnSearch(const DataType *q, SizeT k) const { - auto query = data_store_.MakeQuery(q); - VertexType ep = graph_store_.enterpoint(); - for (i32 cur_layer = graph_store_.max_layer(); cur_layer > 0; --cur_layer) { - ep = SearchLayerNearest(ep, query, cur_layer); + template + requires DataIteratorConcept + VertexType StoreData(Iterator iter, const LabelType *labels, SizeT insert_n) { + auto ret = data_store_.AddVec(iter, insert_n); + if (ret == DataStore::ERR_IDX) { + Error("Data index is not enough."); } - DistHeap search_result = SearchLayer(ep, query, 0, Max(k, ef_)); - while (search_result.size() > k) { - 
search_result.pop(); + std::copy(labels, labels + insert_n, labels_.get() + ret); + return ret; + } + + VertexType StoreDataRaw(const DataType *query, const LabelType *labels, SizeT insert_n) { + return StoreData(DenseVectorIter(query, data_store_.dim(), insert_n), labels, insert_n); + } + + VertexType StoreDataRaw(const DataType *query, LabelType label) { return StoreDataRaw(query, &label, 1); } + + template + void Build(VertexType vertex_i) { + UniqueLock global_lock; + if constexpr (WithLock) { + global_lock = UniqueLock(global_mutex_); } - MaxHeap> result; // TODO:: reserve - while (!search_result.empty()) { - const auto &[dist, idx] = search_result.top(); - result.emplace(dist, labels_[idx]); - search_result.pop(); + + i32 q_layer = GenerateRandomLayer(); + i32 max_layer = graph_store_.max_layer(); + if constexpr (WithLock) { + if (q_layer <= max_layer) { + global_lock.unlock(); + } } - return result; - } - MaxHeap> KnnSearch(const DataType *q, SizeT k, const Bitmask &bitmask) const { - auto query = data_store_.MakeQuery(q); - VertexType ep = graph_store_.enterpoint(); - for (i32 cur_layer = graph_store_.max_layer(); cur_layer > 0; --cur_layer) { - ep = SearchLayerNearest(ep, query, cur_layer); + UniqueLock lock; + if constexpr (WithLock) { + lock = UniqueLock(vertex_mutex_[vertex_i]); } - DistHeap search_result = SearchLayer(ep, query, 0, Max(k, ef_), bitmask); - while (search_result.size() > k) { - search_result.pop(); + StoreType query = data_store_.GetVec(vertex_i); + + VertexType ep = graph_store_.enterpoint(); + graph_store_.AddVertex(vertex_i, q_layer); + + for (i32 cur_layer = max_layer; cur_layer > q_layer; --cur_layer) { + ep = SearchLayerNearest(ep, query, cur_layer); } - MaxHeap> result; // TODO:: reserve - while (!search_result.empty()) { - const auto &[dist, idx] = search_result.top(); - result.emplace(dist, labels_[idx]); - search_result.pop(); + for (i32 cur_layer = Min(q_layer, max_layer); cur_layer >= 0; --cur_layer) { + Vector search_result 
= SearchLayer(ep, query, cur_layer, ef_construction_); + + const auto [q_neighbors_p, q_neighbor_size_p] = graph_store_.GetNeighborsMut(vertex_i, cur_layer); + SelectNeighborsHeuristic(Move(search_result), M_, q_neighbors_p, q_neighbor_size_p); + ep = q_neighbors_p[0]; + ConnectNeighbors(vertex_i, q_neighbors_p, *q_neighbor_size_p, cur_layer); } - return result; } - Pair, UniquePtr>> KnnSearchReturnPair(const DataType *q, SizeT k, const Bitmask &bitmask) const { + template + Vector> KnnSearch(const DataType *q, SizeT k, const Bitmask &bitmask = Bitmask()) const { auto query = data_store_.MakeQuery(q); VertexType ep = graph_store_.enterpoint(); for (i32 cur_layer = graph_store_.max_layer(); cur_layer > 0; --cur_layer) { - ep = SearchLayerNearest(ep, query, cur_layer); + ep = SearchLayerNearest(ep, query, cur_layer); } - auto search_result = SearchLayerReturnPair(ep, query, 0, Max(k, ef_), bitmask); - auto &[result_size, unique_ptr_pair] = search_result; - auto &[d_ptr, v_ptr] = unique_ptr_pair; - UniquePtr l_ptr; - if constexpr (sizeof(LabelType) == sizeof(VertexType)) { - auto label_ptr = reinterpret_cast(v_ptr.get()); - for (SizeT i = 0; i < result_size; ++i) { - label_ptr[i] = labels_[v_ptr[i]]; - } - l_ptr.reset(reinterpret_cast(v_ptr.release())); - } else { - l_ptr = MakeUniqueForOverwrite(result_size); - for (SizeT i = 0; i < result_size; ++i) { - l_ptr[i] = labels_[v_ptr[i]]; - } + Vector search_result = SearchLayer(ep, query, 0, Max(k, ef_), bitmask); + + std::sort_heap(search_result.begin(), search_result.end(), CMP()); + Vector> result; + for (SizeT i = 0; i < Min(k, search_result.size()); ++i) { + result.emplace_back(search_result[i].first, labels_[search_result[i].second]); } - return {result_size, MakePair(Move(d_ptr), Move(l_ptr))}; + return result; } -public: void SetEf(SizeT ef) { ef_ = ef; } + SizeT GetVertexNum() const { return data_store_.cur_vec_num(); } + void Save(FileHandler &file_handler) { file_handler.Write(&M_, sizeof(M_)); 
file_handler.Write(&ef_construction_, sizeof(ef_construction_)); diff --git a/src/storage/knnindex/knn_hnsw/hnsw_common.cppm b/src/storage/knnindex/knn_hnsw/hnsw_common.cppm index 51b6876eb1..9e9ba58dc4 100644 --- a/src/storage/knnindex/knn_hnsw/hnsw_common.cppm +++ b/src/storage/knnindex/knn_hnsw/hnsw_common.cppm @@ -124,6 +124,7 @@ public: }; export using VertexType = i32; +export using AtomicVtxType = ai32; export using VertexListSize = i32; export using LayerSize = i32; diff --git a/src/storage/knnindex/knn_hnsw/simd_func.cppm b/src/storage/knnindex/knn_hnsw/simd_func.cppm index 22852205c0..0fb8e5eadf 100644 --- a/src/storage/knnindex/knn_hnsw/simd_func.cppm +++ b/src/storage/knnindex/knn_hnsw/simd_func.cppm @@ -40,6 +40,14 @@ void log_m256(const __m256i &value) { std::cout << "]" << std::endl; } +export int32_t I8IPBF(const int8_t *pv1, const int8_t *pv2, size_t dim) { + int32_t res = 0; + for (size_t i = 0; i < dim; i++) { + res += (int16_t)(pv1[i]) * pv2[i]; + } + return res; +} + #if defined(USE_AVX512) export int32_t I8IPAVX512(const int8_t *pv1, const int8_t *pv2, size_t dim) { size_t dim64 = dim >> 6; @@ -66,6 +74,10 @@ export int32_t I8IPAVX512(const int8_t *pv1, const int8_t *pv2, size_t dim) { // Reduce add return _mm512_reduce_add_epi32(sum); } + +export int32_t I8IPAVX512Residual(const int8_t *pv1, const int8_t *pv2, size_t dim) { + return I8IPAVX512(pv1, pv2, dim) + I8IPBF(pv1 + (dim & ~63), pv2 + (dim & ~63), dim & 63); +} #endif #if defined(USE_AVX) @@ -98,18 +110,24 @@ export int32_t I8IPAVX(const int8_t *pv1, const int8_t *pv2, size_t dim) { // Extract the result return _mm256_extract_epi32(sum, 0) + _mm256_extract_epi32(sum, 4); } + +export int32_t I8IPAVXResidual(const int8_t *pv1, const int8_t *pv2, size_t dim) { + return I8IPAVX(pv1, pv2, dim) + I8IPBF(pv1 + (dim & ~31), pv2 + (dim & ~31), dim & 31); +} + #endif -export int32_t I8IPBF(const int8_t *pv1, const int8_t *pv2, size_t dim) { - int32_t res = 0; 
+//------------------------------//------------------------------//------------------------------ + +export float F32L2BF(const float *pv1, const float *pv2, size_t dim) { + float res = 0; for (size_t i = 0; i < dim; i++) { - res += (int16_t)(pv1[i]) * pv2[i]; + float t = pv1[i] - pv2[i]; + res += t * t; } return res; } -//------------------------------//------------------------------//------------------------------ - #if defined(USE_AVX512) export float F32L2AVX512(const float *pv1, const float *pv2, size_t dim) { @@ -138,6 +156,10 @@ export float F32L2AVX512(const float *pv1, const float *pv2, size_t dim) { return (res); } +export float F32L2AVX512Residual(const float *pv1, const float *pv2, size_t dim) { + return F32L2AVX512(pv1, pv2, dim) + F32L2BF(pv1 + (dim & ~15), pv2 + (dim & ~15), dim & 15); +} + #endif #if defined(USE_AVX) @@ -171,18 +193,22 @@ export float F32L2AVX(const float *pv1, const float *pv2, size_t dim) { return TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3] + TmpRes[4] + TmpRes[5] + TmpRes[6] + TmpRes[7]; } +export float F32L2AVXResidual(const float *pv1, const float *pv2, size_t dim) { + return F32L2AVX(pv1, pv2, dim) + F32L2BF(pv1 + (dim & ~15), pv2 + (dim & ~15), dim & 15); +} + #endif -export float F32L2BF(const float *pv1, const float *pv2, size_t dim) { +//------------------------------//------------------------------//------------------------------ + +export float F32IPBF(const float *pv1, const float *pv2, size_t dim) { float res = 0; for (size_t i = 0; i < dim; i++) { - float t = pv1[i] - pv2[i]; - res += t * t; + res += pv1[i] * pv2[i]; } return res; } -//------------------------------//------------------------------//------------------------------ #if defined(USE_AVX512) export float F32IPAVX512(const float *pVect1, const float *pVect2, SizeT qty) { @@ -211,6 +237,10 @@ export float F32IPAVX512(const float *pVect1, const float *pVect2, SizeT qty) { return sum; } +export float F32IPAVX512Residual(const float *pVect1, const float 
*pVect2, SizeT qty) { + return F32IPAVX512(pVect1, pVect2, qty) + F32IPBF(pVect1 + (qty & ~15), pVect2 + (qty & ~15), qty & 15); +} + #endif #if defined(USE_AVX) @@ -246,14 +276,10 @@ export float F32IPAVX(const float *pVect1, const float *pVect2, SizeT qty) { return sum; } -#endif - -export float F32IPBF(const float *pv1, const float *pv2, size_t dim) { - float res = 0; - for (size_t i = 0; i < dim; i++) { - res += pv1[i] * pv2[i]; - } - return res; +export float F32IPAVXResidual(const float *pVect1, const float *pVect2, SizeT qty) { + return F32IPAVX(pVect1, pVect2, qty) + F32IPBF(pVect1 + (qty & ~15), pVect2 + (qty & ~15), qty & 15); } +#endif + } // namespace infinity \ No newline at end of file diff --git a/src/storage/meta/entry/segment_entry.cpp b/src/storage/meta/entry/segment_entry.cpp index 63ec20b8f2..93fdd7a9ae 100644 --- a/src/storage/meta/entry/segment_entry.cpp +++ b/src/storage/meta/entry/segment_entry.cpp @@ -164,22 +164,6 @@ void SegmentEntry::DeleteData(SegmentEntry *segment_entry, Txn *txn_ptr, const H } } -template -class OneColumnIterator { -public: - OneColumnIterator(const SegmentEntry *entry, SizeT column_id) : segment_iter_(entry, MakeShared>(Vector{column_id})) {} - - Optional Next() { - if (auto ret = segment_iter_.Next(); ret) { - return reinterpret_cast((*ret)[0]); - } - return None; - } - -private: - SegmentIter segment_iter_; -}; - SharedPtr SegmentEntry::CreateIndexFile(SegmentEntry *segment_entry, ColumnIndexEntry *column_index_entry, SharedPtr column_def, diff --git a/src/storage/meta/iter/segment_iter.cppm b/src/storage/meta/iter/segment_iter.cppm index 0c2a88653d..bdf1e39051 100644 --- a/src/storage/meta/iter/segment_iter.cppm +++ b/src/storage/meta/iter/segment_iter.cppm @@ -56,4 +56,20 @@ private: SharedPtr> column_ids_; }; +export template +class OneColumnIterator { +public: + OneColumnIterator(const SegmentEntry *entry, SizeT column_id) : segment_iter_(entry, MakeShared>(Vector{column_id})) {} + + Optional Next() { + if 
(auto ret = segment_iter_.Next(); ret) { + return reinterpret_cast((*ret)[0]); + } + return None; + } + +private: + SegmentIter segment_iter_; +}; + } // namespace infinity \ No newline at end of file diff --git a/src/storage/txn/txn.cpp b/src/storage/txn/txn.cpp index 85e5fcd9f2..d4cd406a51 100644 --- a/src/storage/txn/txn.cpp +++ b/src/storage/txn/txn.cpp @@ -83,6 +83,21 @@ Status Txn::GetTableEntry(const String &db_name, const String &table_name, Table return Status::OK(); } +Status Txn::GetTableIndexEntry(const String &db_name, const String &table_name, const String &index_name, TableIndexEntry *&table_index_entry) { + TableCollectionEntry *table_entry = nullptr; + Status table_status = GetTableEntry(db_name, table_name, table_entry); + if (!table_status.ok()) { + return table_status; + } + + BaseEntry *base_entry = nullptr; + TableCollectionEntry::GetIndex(table_entry, index_name, txn_id_, txn_context_.GetBeginTS(), base_entry); + table_index_entry = static_cast(base_entry); + + return Status::OK(); +} + + Status Txn::Append(const String &db_name, const String &table_name, const SharedPtr &input_block) { TableCollectionEntry *table_entry{nullptr}; Status status = GetTableEntry(db_name, table_name, table_entry); @@ -421,6 +436,21 @@ Status Txn::CreateIndex(const String &db_name, const String &table_name, const S return index_status; } +Status +Txn::CreateIndex(TableCollectionEntry *table_entry, const SharedPtr &index_def, ConflictType conflict_type, TableIndexEntry *&table_index_entry) { + TxnTimeStamp begin_ts = txn_context_.GetBeginTS(); + + BaseEntry *base_entry{nullptr}; + Status index_status = TableCollectionEntry::CreateIndex(table_entry, index_def, conflict_type, txn_id_, begin_ts, txn_mgr_, base_entry); + if (!index_status.ok()) { + return index_status; + } + + table_index_entry = static_cast(base_entry); + txn_indexes_.emplace(*index_def->index_name_, table_index_entry); + return index_status; +} + Status Txn::DropIndexByName(const String &db_name, 
const String &table_name, const String &index_name, ConflictType conflict_type) { TxnState txn_state = txn_context_.GetTxnState(); if (txn_state != TxnState::kStarted) { diff --git a/src/storage/txn/txn.cppm b/src/storage/txn/txn.cppm index 3ff614d67d..7114d0667a 100644 --- a/src/storage/txn/txn.cppm +++ b/src/storage/txn/txn.cppm @@ -92,9 +92,13 @@ public: Status GetTableEntry(const String &db_name, const String &table_name, TableCollectionEntry *&table_entry); + Status GetTableIndexEntry(const String &db_name, const String &table_name, const String &index_name, TableIndexEntry *&table_index_entry); + // Index OPs Status CreateIndex(const String &db_name, const String &table_name, const SharedPtr &index_def, ConflictType conflict_type); + Status CreateIndex(TableCollectionEntry *table_entry, const SharedPtr &index_def, ConflictType conflict_type, TableIndexEntry *&table_index_entry); + Status DropIndexByName(const String &db_name, const String &table_name, const String &index_name, ConflictType conflict_type); // View Ops diff --git a/src/unit_test/storage/knnindex/knn_hnsw/test_dist_func.cpp b/src/unit_test/storage/knnindex/knn_hnsw/test_dist_func.cpp index ca6e88fbc4..d9b42f5d2a 100644 --- a/src/unit_test/storage/knnindex/knn_hnsw/test_dist_func.cpp +++ b/src/unit_test/storage/knnindex/knn_hnsw/test_dist_func.cpp @@ -43,14 +43,12 @@ float F32L2Test(const float *v1, const float *v2, size_t dim) { } TEST_F(DistFuncTest, test1) { - size_t dim = 32; + size_t dim = 200; size_t vec_n = 10000; auto vecs1 = std::make_unique(dim * vec_n); auto vecs2 = std::make_unique(dim * vec_n); - assert(dim % 32 == 0); - // generate a random vector of int8_t std::default_random_engine rng; std::uniform_int_distribution dist(-128, 127); @@ -64,7 +62,7 @@ TEST_F(DistFuncTest, test1) { for (size_t i = 0; i < vec_n; ++i) { auto v1 = vecs1.get() + i * dim; auto v2 = vecs2.get() + i * dim; - int32_t res = I8IPAVX(v1, v2, dim); + int32_t res = I8IPAVXResidual(v1, v2, dim); int32_t res2 = 
I8IPTest(v1, v2, dim); EXPECT_EQ(res, res2); } @@ -75,7 +73,7 @@ TEST_F(DistFuncTest, test2) { using Distance = LVQL2Dist; using LVQ8Data = LVQ8Store::StoreType; - size_t dim = 128; + size_t dim = 200; size_t vec_n = 10000; auto vecs1 = std::make_unique(dim * vec_n); @@ -83,11 +81,11 @@ TEST_F(DistFuncTest, test2) { // generate a random vector of float std::default_random_engine rng; - std::uniform_real_distribution dist(0, 1); + std::uniform_real_distribution rdist(0, 1); for (size_t i = 0; i < vec_n; ++i) { for (size_t j = 0; j < dim; ++j) { - vecs1[i * dim + j] = dist(rng); - vecs2[i * dim + j] = dist(rng); + vecs1[i * dim + j] = rdist(rng); + vecs2[i * dim + j] = rdist(rng); } } diff --git a/src/unit_test/test_hnsw.cpp b/src/unit_test/test_hnsw.cpp index 1879e2718a..d7e896a66a 100644 --- a/src/unit_test/test_hnsw.cpp +++ b/src/unit_test/test_hnsw.cpp @@ -13,6 +13,7 @@ // limitations under the License. #include +#include #include #include #include @@ -33,21 +34,23 @@ using namespace infinity; int main() { using LabelT = uint64_t; - using RetHeap = std::priority_queue>; std::string save_dir = tmp_data_path(); - int dim = 128; + int dim = 16; int element_size = 1000; + int M = 16; + int ef_construction = 200; - // using Hnsw = KnnHnsw, PlainL2Dist>; - using Hnsw = KnnHnsw>, LVQL2Dist>; + using Hnsw = KnnHnsw, PlainL2Dist>; + // using Hnsw = KnnHnsw>, LVQL2Dist>; // NOTE: inner product correct rate is not 1. 
(the vector and itself's distance is not the smallest) // using Hnsw = KnnHnsw, PlainIPDist>; // using Hnsw = KnnHnsw>, LVQIPDist>; - std::default_random_engine rng; + std::mt19937 rng; + rng.seed(0); std::uniform_real_distribution distrib_real; auto data = std::make_unique(dim * element_size); @@ -57,27 +60,23 @@ int main() { LocalFileSystem fs; - int M = 16; - int ef_construction = 200; - { auto hnsw_index = Hnsw::Make(element_size, dim, M, ef_construction, {}); auto labels = std::make_unique(element_size); std::iota(labels.get(), labels.get() + element_size, 0); - hnsw_index->Insert(data.get(), labels.get(), element_size); - // hnsw_index->Dump(std::cout); - hnsw_index->Check(); - - // hnsw_index->Dump(std::cout); + hnsw_index->InsertVecs(data.get(), labels.get(), element_size); + std::ofstream os("tmp/dump.txt"); + hnsw_index->Dump(os); hnsw_index->Check(); + return 0; hnsw_index->SetEf(10); int correct = 0; for (int i = 0; i < element_size; ++i) { const float *query = data.get() + i * dim; - RetHeap result = hnsw_index->KnnSearch(query, 1); - if (result.top().second == (LabelT)i) { + auto result = hnsw_index->KnnSearch(query, 1); + if (result[0].second == (LabelT)i) { ++correct; } } @@ -97,12 +96,12 @@ int main() { hnsw_index->SetEf(10); // hnsw_index->Dump(std::cout); - // hnsw_index->Check(); + hnsw_index->Check(); int correct = 0; for (int i = 0; i < element_size; ++i) { const float *query = data.get() + i * dim; - std::priority_queue> result = hnsw_index->KnnSearch(query, 1); - if (result.top().second == (LabelT)i) { + auto result = hnsw_index->KnnSearch(query, 1); + if (result[0].second == (LabelT)i) { ++correct; } } diff --git a/src/unit_test/test_hnsw_bitmask.cpp b/src/unit_test/test_hnsw_bitmask.cpp index d74c94b4b4..0deca90354 100644 --- a/src/unit_test/test_hnsw_bitmask.cpp +++ b/src/unit_test/test_hnsw_bitmask.cpp @@ -24,14 +24,17 @@ import hnsw_alg; using namespace infinity; -#define EXPECT_VALUE_EQ(a, b) if (auto f = f64(a) - f64(b); Max(f, -f) > 
1e-4) { std::cerr << "values aren't equal at line\t" << __LINE__ << "\tvalues: " << a << " != " << b << std::endl;} +#define EXPECT_VALUE_EQ(a, b) \ + if (auto f = f64(a) - f64(b); Max(f, -f) > 1e-4) { \ + std::cerr << "values aren't equal at line\t" << __LINE__ << "\tvalues: " << a << " != " << b << std::endl; \ + } int main() { i64 dimension = 4; i64 top_k = 4; i64 base_embedding_count = 4; - UniquePtr < f32[] > base_embedding = MakeUnique(sizeof(f32) * dimension * base_embedding_count); - UniquePtr < f32[] > query_embedding = MakeUnique(sizeof(f32) * dimension); + UniquePtr base_embedding = MakeUnique(sizeof(f32) * dimension * base_embedding_count); + UniquePtr query_embedding = MakeUnique(sizeof(f32) * dimension); { base_embedding[0] = 0.1; @@ -69,8 +72,7 @@ int main() { } using LabelT = u64; - using RetHeap = MaxHeap >; - using Hnsw = KnnHnsw >, LVQL2Dist < f32, i8 >>; + using Hnsw = KnnHnsw>, LVQL2Dist>; int M = 16; int ef_construction = 200; auto hnsw_index = Hnsw::Make(base_embedding_count, dimension, M, ef_construction, {}); @@ -78,80 +80,60 @@ int main() { for (int i = 0; i < base_embedding_count; ++i) { labels[i] = i; } - hnsw_index->Insert(base_embedding.get(), labels.get(), base_embedding_count); + hnsw_index->InsertVecs(base_embedding.get(), labels.get(), base_embedding_count); - Vector distance_array(top_k); - Vector id_array(top_k); + Vector distance_array(top_k); + Vector id_array(top_k); { - RetHeap result = hnsw_index->KnnSearch(query_embedding.get(), top_k); - for (int i = 0; i < top_k; ++i) { - distance_array[top_k - 1 - i] = result.top().first; - id_array[top_k - 1 - i] = result.top().second; - result.pop(); - } - EXPECT_VALUE_EQ(distance_array[0], 0); - EXPECT_VALUE_EQ(id_array[0], 0); - - EXPECT_VALUE_EQ(distance_array[1], 0.02); - EXPECT_VALUE_EQ(id_array[1], 1); - - EXPECT_VALUE_EQ(distance_array[2], 0.08); - EXPECT_VALUE_EQ(id_array[2], 2); - - EXPECT_VALUE_EQ(distance_array[3], 0.2); - EXPECT_VALUE_EQ(id_array[3], 3); + auto result = 
hnsw_index->KnnSearch(query_embedding.get(), top_k); + + EXPECT_VALUE_EQ(result[0].first, 0); + EXPECT_VALUE_EQ(result[0].second, 0); + + EXPECT_VALUE_EQ(result[1].first, 0.02); + EXPECT_VALUE_EQ(result[1].second, 1); + + EXPECT_VALUE_EQ(result[2].first, 0.08); + EXPECT_VALUE_EQ(result[2].second, 2); + + EXPECT_VALUE_EQ(result[3].first, 0.2); + EXPECT_VALUE_EQ(result[3].second, 3); } auto p_bitmask = Bitmask::Make(64); p_bitmask->SetFalse(1); --top_k; { - RetHeap result = hnsw_index->KnnSearch(query_embedding.get(), top_k, *p_bitmask); - for (int i = 0; i < top_k; ++i) { - distance_array[top_k - 1 - i] = result.top().first; - id_array[top_k - 1 - i] = result.top().second; - result.pop(); - } - EXPECT_VALUE_EQ(distance_array[0], 0); - EXPECT_VALUE_EQ(id_array[0], 0); - - EXPECT_VALUE_EQ(distance_array[1], 0.08); - EXPECT_VALUE_EQ(id_array[1], 2); - - EXPECT_VALUE_EQ(distance_array[2], 0.2); - EXPECT_VALUE_EQ(id_array[2], 3); + auto result = hnsw_index->KnnSearch(query_embedding.get(), top_k, *p_bitmask); + + EXPECT_VALUE_EQ(result[0].first, 0); + EXPECT_VALUE_EQ(result[0].second, 0); + + EXPECT_VALUE_EQ(result[1].first, 0.08); + EXPECT_VALUE_EQ(result[1].second, 2); + + EXPECT_VALUE_EQ(result[2].first, 0.2); + EXPECT_VALUE_EQ(result[2].second, 3); } p_bitmask->SetFalse(0); --top_k; { - RetHeap result = hnsw_index->KnnSearch(query_embedding.get(), top_k, *p_bitmask); - for (int i = 0; i < top_k; ++i) { - distance_array[top_k - 1 - i] = result.top().first; - id_array[top_k - 1 - i] = result.top().second; - result.pop(); - } - - EXPECT_VALUE_EQ(distance_array[0], 0.08); - EXPECT_VALUE_EQ(id_array[0], 2); - - EXPECT_VALUE_EQ(distance_array[1], 0.2); - EXPECT_VALUE_EQ(id_array[1], 3); + auto result = hnsw_index->KnnSearch(query_embedding.get(), top_k, *p_bitmask); + + EXPECT_VALUE_EQ(result[0].first, 0.08); + EXPECT_VALUE_EQ(result[0].second, 2); + + EXPECT_VALUE_EQ(result[1].first, 0.2); + EXPECT_VALUE_EQ(result[1].second, 3); } p_bitmask->SetFalse(2); --top_k; { - 
RetHeap result = hnsw_index->KnnSearch(query_embedding.get(), top_k, *p_bitmask); - for (int i = 0; i < top_k; ++i) { - distance_array[top_k - 1 - - i] = result.top().first; - id_array[top_k - 1 - - i] = result.top().second; - result.pop(); - } - - EXPECT_VALUE_EQ(distance_array[0], 0.2); - EXPECT_VALUE_EQ(id_array[0], 3); + auto result = hnsw_index->KnnSearch(query_embedding.get(), top_k, *p_bitmask); + + EXPECT_VALUE_EQ(result[0].first, 0.2); + EXPECT_VALUE_EQ(result[0].second, 3); } }