Skip to content

Commit

Permalink
Build KNN index in parallel (#386)
Browse files Browse the repository at this point in the history
* Add fragment for parallel create index.

* Add concurrent insert hnsw alg.
Add simd func for residual dimension embedding.
Merge with yangzq. Remove some heap op in hnsw alg.

* Feat: Add create_index_prepare/do_finish.

* Fix create index source/sink bug.
Fix round robin bug.

* Add tmp schedule strategy.
  • Loading branch information
small-turtle-1 authored Dec 27, 2023
1 parent 774aeac commit f7dbfa2
Show file tree
Hide file tree
Showing 47 changed files with 1,505 additions and 723 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ parser.output
#query_parser.h
#query_parser.cpp

test/data/csv/test_sort.csv
test/sql/dql/sort.slt
test/data/csv/*big*.csv
test/sql/**/*big*.slt
test/data/fvecs/
Expand Down
79 changes: 49 additions & 30 deletions benchmark/embedding/hnsw_benchmark2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import stl;
import hnsw_alg;
import hnsw_common;
import local_file_system;
import file_system_type;
import file_system;
Expand All @@ -19,9 +20,13 @@ import compilation_config;
using namespace infinity;

int main() {
String base_file = String(test_data_path()) + "/benchmark/sift/base.fvecs";
String query_file = String(test_data_path()) + "/benchmark/sift/query.fvecs";
String groundtruth_file = String(test_data_path()) + "/benchmark/sift/l2_groundtruth.ivecs";
// String base_file = String(test_data_path()) + "/benchmark/text2image_10m/base.fvecs";
// String query_file = String(test_data_path()) + "/benchmark/text2image_10m/query.fvecs";
// String groundtruth_file = String(test_data_path()) + "/benchmark/text2image_10m/groundtruth.ivecs";

String base_file = String(test_data_path()) + "/benchmark/sift_1m/sift_base.fvecs";
String query_file = String(test_data_path()) + "/benchmark/sift_1m/sift_query.fvecs";
String groundtruth_file = String(test_data_path()) + "/benchmark/sift_1m/sift_groundtruth.ivecs";

LocalFileSystem fs;
std::string save_dir = tmp_data_path();
Expand All @@ -31,7 +36,8 @@ int main() {
size_t ef_construction = 200;
size_t embedding_count = 1000000;
size_t test_top = 100;
const int thread_n = 1;
const int build_thread_n = 1;
const int query_thread_n = 1;

using LabelT = uint64_t;

Expand All @@ -41,7 +47,7 @@ int main() {

using Hnsw = KnnHnsw<float, LabelT, LVQStore<float, int8_t, LVQL2Cache<float, int8_t>>, LVQL2Dist<float, int8_t>>;
SizeT init_args = {0};
std::string save_place = save_dir + "/my_sift_lvq8_l2.hnsw";
std::string save_place = save_dir + "/my_sift_lvq8_l2_1.hnsw";

// using Hnsw = KnnHnsw<float, LabelT, PlainStore<float>, PlainIPDist<float>>;
// std::tuple<> init_args = {};
Expand Down Expand Up @@ -69,26 +75,36 @@ int main() {
std::cout << "Begin memory cost: " << get_current_rss() << "B" << std::endl;
profiler.Begin();

if (false) {
for (size_t idx = 0; idx < embedding_count; ++idx) {
// insert data into index
const float *query = input_embeddings + idx * dimension;
knn_hnsw->Insert(query, idx);
if (idx % 100000 == 0) {
std::cout << idx << ", " << get_current_rss() << " B, " << profiler.ElapsedToString() << std::endl;
}
}
} else {
{
std::cout << "Build thread number: " << build_thread_n << std::endl;
auto labels = std::make_unique<LabelT[]>(embedding_count);
std::iota(labels.get(), labels.get() + embedding_count, 0);
knn_hnsw->Insert(input_embeddings, labels.get(), embedding_count);
VertexType start_i = knn_hnsw->StoreDataRaw(input_embeddings, labels.get(), embedding_count);
delete[] input_embeddings;
AtomicVtxType next_i = start_i;
std::vector<std::thread> threads;
for (int i = 0; i < build_thread_n; ++i) {
threads.emplace_back([&]() {
while (true) {
VertexType cur_i = next_i.fetch_add(1);
if (cur_i >= VertexType(start_i + embedding_count)) {
break;
}
knn_hnsw->Build<true>(cur_i);
if (cur_i && cur_i % 10000 == 0) {
std::cout << "Inserted " << cur_i << " / " << embedding_count << std::endl;
}
}
});
}
for (auto &thread : threads) {
thread.join();
}
}

profiler.End();
std::cout << "Insert data cost: " << profiler.ElapsedToString() << " memory cost: " << get_current_rss() << "B" << std::endl;

delete[] input_embeddings;

uint8_t file_flags = FileFlags::WRITE_FLAG | FileFlags::CREATE_FLAG;
std::unique_ptr<FileHandler> file_handler = fs.OpenFile(save_place, file_flags, FileLockType::kWriteLock);
knn_hnsw->Save(*file_handler);
Expand All @@ -100,7 +116,13 @@ int main() {
std::unique_ptr<FileHandler> file_handler = fs.OpenFile(save_place, file_flags, FileLockType::kReadLock);

knn_hnsw = Hnsw::Load(*file_handler, init_args);
std::cout << "Loaded" << std::endl;

// std::ofstream out("dump.txt");
// knn_hnsw->Dump(out);
// knn_hnsw->Check();
}
return 0;

size_t number_of_queries;
const float *queries = nullptr;
Expand All @@ -114,11 +136,10 @@ int main() {
Vector<HashSet<int>> ground_truth_sets; // number_of_queries * top_k matrix of ground-truth nearest-neighbors

{
// size_t *ground_truth;
// load ground-truth and convert int to long
size_t nq2;
int *gt_int = ivecs_read(groundtruth_file.c_str(), &top_k, &nq2);
assert(nq2 == number_of_queries || !"incorrect nb of ground truth entries");
assert(nq2 >= number_of_queries || !"incorrect nb of ground truth entries");
assert(top_k >= test_top || !"dataset does not provide enough ground truth data");

ground_truth_sets.resize(number_of_queries);
Expand All @@ -130,10 +151,10 @@ int main() {
}

infinity::BaseProfiler profiler;
int round = 10;
Vector<MaxHeap<Pair<float, LabelT>>> results;
results.reserve(number_of_queries);
std::cout << "thread number: " << thread_n << std::endl;
std::cout << "Start!" << std::endl;
int round = 3;
Vector<Vector<Pair<float, LabelT>>> results(number_of_queries);
std::cout << "Query thread number: " << query_thread_n << std::endl;
for (int ef = 100; ef <= 300; ef += 25) {
knn_hnsw->SetEf(ef);
int correct = 0;
Expand All @@ -142,15 +163,15 @@ int main() {
std::atomic_int idx(0);
std::vector<std::thread> threads;
profiler.Begin();
for (int j = 0; j < thread_n; ++j) {
for (int j = 0; j < query_thread_n; ++j) {
threads.emplace_back([&]() {
while (true) {
int cur_idx = idx.fetch_add(1);
if (cur_idx >= (int)number_of_queries) {
break;
}
const float *query = queries + cur_idx * dimension;
MaxHeap<Pair<float, LabelT>> result = knn_hnsw->KnnSearch(query, test_top);
auto result = knn_hnsw->KnnSearch<false>(query, test_top);
results[cur_idx] = std::move(result);
}
});
Expand All @@ -161,12 +182,10 @@ int main() {
profiler.End();
if (i == 0) {
for (size_t query_idx = 0; query_idx < number_of_queries; ++query_idx) {
auto &result = results[query_idx];
while (!result.empty()) {
if (ground_truth_sets[idx].contains(result.top().second)) {
for (const auto &[_, label] : results[query_idx]) {
if (ground_truth_sets[query_idx].contains(label)) {
++correct;
}
result.pop();
}
}
printf("Recall = %.4f\n", correct / float(test_top * number_of_queries));
Expand Down
22 changes: 9 additions & 13 deletions benchmark/local_infinity/knn/knn_import_benchmark.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,16 +37,11 @@ import query_result;

using namespace infinity;

int main(int argc, char *argv[]) {
if (argc != 2) {
std::cout << "import sift or gist" << std::endl;
return 1;
}
int main() {
bool sift = true;
if (strcmp(argv[1], "sift") && strcmp(argv[1], "gist")) {
return 1;
}
sift = strcmp(argv[1], "sift") == 0;
int M = 16;
int ef_construct = 200;
std::cout << "benchmark: " << (sift ? "sift" : "gist") << std::endl;

std::string data_path = "/tmp/infinity";

Expand Down Expand Up @@ -75,13 +70,14 @@ int main(int argc, char *argv[]) {
if (sift) {
col1_type = std::make_shared<DataType>(LogicalType::kEmbedding, std::make_shared<EmbeddingInfo>(EmbeddingDataType::kElemFloat, 128));
base_path += "/benchmark/sift_1m/sift_base.fvecs";
table_name = "sift_benchmark";
table_name = "sift_benchmark_M" + std::to_string(M) + "_ef" + std::to_string(ef_construct);
} else {
col1_type = std::make_shared<DataType>(LogicalType::kEmbedding, std::make_shared<EmbeddingInfo>(EmbeddingDataType::kElemFloat, 960));
base_path += "/benchmark/gist_1m/gist_base.fvecs";
table_name = "gist_benchmark";
table_name = "gist_benchmark_M" + std::to_string(M) + "_ef" + std::to_string(ef_construct);
}
std::cout << "Import from: " << base_path << std::endl;
std::cout << "table_name: " << table_name << std::endl;

std::string col1_name = "col1";
auto col1_def = std::make_unique<ColumnDef>(0, col1_type, col1_name, std::unordered_set<ConstraintType>());
Expand Down Expand Up @@ -123,8 +119,8 @@ int main(int argc, char *argv[]) {

{
auto index_param_list = new std::vector<InitParameter *>();
index_param_list->emplace_back(new InitParameter("M", std::to_string(16)));
index_param_list->emplace_back(new InitParameter("ef_construction", std::to_string(200)));
index_param_list->emplace_back(new InitParameter("M", std::to_string(M)));
index_param_list->emplace_back(new InitParameter("ef_construction", std::to_string(ef_construct)));
index_param_list->emplace_back(new InitParameter("ef", std::to_string(200)));
index_param_list->emplace_back(new InitParameter("metric", "l2"));
index_param_list->emplace_back(new InitParameter("encode", "lvq"));
Expand Down
33 changes: 14 additions & 19 deletions benchmark/local_infinity/knn/knn_query_benchmark.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,24 +84,18 @@ inline void ParallelFor(size_t start, size_t end, size_t numThreads, Function fn
}
}

int main(int argc, char *argv[]) {
if (argc != 3) {
std::cout << "query gist/sift ef=?" << std::endl;
return 1;
}
int main() {
bool sift = true;
if (strcmp(argv[1], "sift") && strcmp(argv[1], "gist")) {
return 1;
}
sift = strcmp(argv[1], "sift") == 0;
size_t ef = std::stoull(argv[2]);

int ef = 100;
int M = 16;
int ef_construct = 200;
size_t thread_num = 1;
size_t total_times = 1;
std::cout << "Please input thread_num, 0 means use all resources:" << std::endl;
std::cin >> thread_num;
std::cout << "Please input total_times:" << std::endl;
std::cin >> total_times;
size_t total_times = 3;

std::cout << "benchmark: " << (sift ? "sift" : "gist") << std::endl;
std::cout << "ef: " << ef << std::endl;
std::cout << "thread_n: " << thread_num << std::endl;
std::cout << "total_times: " << total_times << std::endl;

std::string path = "/tmp/infinity";
LocalFileSystem fs;
Expand All @@ -123,15 +117,16 @@ int main(int argc, char *argv[]) {
dimension = 128;
query_path += "/benchmark/sift_1m/sift_query.fvecs";
groundtruth_path += "/benchmark/sift_1m/sift_groundtruth.ivecs";
table_name = "sift_benchmark";
table_name = "sift_benchmark_M" + std::to_string(M) + "_ef" + std::to_string(ef_construct);
} else {
dimension = 960;
query_path += "/benchmark/gist_1m/gist_query.fvecs";
groundtruth_path += "/benchmark/gist_1m/gist_groundtruth.ivecs";
table_name = "gist_benchmark";
table_name = "gist_benchmark_M" + std::to_string(M) + "_ef" + std::to_string(ef_construct);
}
std::cout << "query from: " << query_path << std::endl;
std::cout << "groundtruth is: " << groundtruth_path << std::endl;
std::cout << "table_name: " << table_name << std::endl;

if (!fs.Exists(query_path)) {
std::cerr << "File: " << query_path << " doesn't exist" << std::endl;
Expand Down Expand Up @@ -218,7 +213,7 @@ int main(int argc, char *argv[]) {
query_results[query_idx].emplace_back(data[i].ToUint64());
}
}
// delete[] embedding_data_ptr;
// delete[] embedding_data_ptr;
};
BaseProfiler profiler;
profiler.Begin();
Expand Down
10 changes: 10 additions & 0 deletions src/common/stl.cppm
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ module;
#include <list>
#include <map>
#include <memory>
#include <mutex>
#include <optional>
#include <set>
#include <shared_mutex>
Expand Down Expand Up @@ -211,6 +212,7 @@ export {
using atomic_u32 = std::atomic_uint32_t;
using atomic_u64 = std::atomic_uint64_t;
using ai64 = std::atomic_int64_t;
using ai32 = std::atomic_int32_t;
using aptr = std::atomic_uintptr_t;
using atomic_bool = std::atomic_bool;

Expand Down Expand Up @@ -359,6 +361,8 @@ export {
template <typename T>
using LockGuard = std::lock_guard<T>;

using TryToLock = std::try_to_lock_t;

constexpr std::memory_order MemoryOrderRelax = std::memory_order::relaxed;
constexpr std::memory_order MemoryOrderConsume = std::memory_order::consume;
constexpr std::memory_order MemoryOrderRelease = std::memory_order::release;
Expand Down Expand Up @@ -433,4 +437,10 @@ struct CompareByFirst {
bool operator()(const P &lhs, const P &rhs) const { return lhs.first < rhs.first; }
};

export template <typename T1, typename T2>
struct CompareByFirstReverse {
using P = std::pair<T1, T2>;
bool operator()(const P &lhs, const P &rhs) const { return lhs.first > rhs.first; }
};

} // namespace infinity
45 changes: 45 additions & 0 deletions src/executor/fragment_builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,51 @@ void FragmentBuilder::BuildFragments(PhysicalOperator *phys_op, PlanFragment *cu
current_fragment_ptr->AddChild(Move(next_plan_fragment));
return;
}
case PhysicalOperatorType::kCreateIndexPrepare: {
if (phys_op->left() != nullptr || phys_op->right() != nullptr) {
Error<SchedulerException>(Format("Invalid input node of {}", phys_op->GetName()));
}
current_fragment_ptr->AddOperator(phys_op);
current_fragment_ptr->SetFragmentType(FragmentType::kSerialMaterialize);
current_fragment_ptr->SetSourceNode(query_context_ptr_, SourceType::kEmpty, phys_op->GetOutputNames(), phys_op->GetOutputTypes());
return;
}
case PhysicalOperatorType::kCreateIndexDo: {
if (phys_op->left() == nullptr || phys_op->right() != nullptr) {
Error<SchedulerException>(Format("Invalid input node of {}", phys_op->GetName()));
}
current_fragment_ptr->AddOperator(phys_op);
current_fragment_ptr->SetFragmentType(FragmentType::kParallelMaterialize);
current_fragment_ptr->SetSourceNode(query_context_ptr_, SourceType::kLocalQueue, phys_op->GetOutputNames(), phys_op->GetOutputTypes());

auto next_plan_fragment = MakeUnique<PlanFragment>(GetFragmentId());
next_plan_fragment->SetSinkNode(query_context_ptr_,
SinkType::kLocalQueue,
phys_op->left()->GetOutputNames(),
phys_op->left()->GetOutputTypes());
BuildFragments(phys_op->left(), next_plan_fragment.get());

current_fragment_ptr->AddChild(Move(next_plan_fragment));
return;
}
case PhysicalOperatorType::kCreateIndexFinish: {
if (phys_op->left() == nullptr || phys_op->right() != nullptr) {
Error<SchedulerException>(Format("Invalid input node of {}", phys_op->GetName()));
}
current_fragment_ptr->AddOperator(phys_op);
current_fragment_ptr->SetFragmentType(FragmentType::kSerialMaterialize);
current_fragment_ptr->SetSourceNode(query_context_ptr_, SourceType::kLocalQueue, phys_op->GetOutputNames(), phys_op->GetOutputTypes());

auto next_plan_fragment = MakeUnique<PlanFragment>(GetFragmentId());
next_plan_fragment->SetSinkNode(query_context_ptr_,
SinkType::kLocalQueue,
phys_op->left()->GetOutputNames(),
phys_op->left()->GetOutputTypes());
BuildFragments(phys_op->left(), next_plan_fragment.get());

current_fragment_ptr->AddChild(Move(next_plan_fragment));
return;
}
case PhysicalOperatorType::kParallelAggregate:
case PhysicalOperatorType::kFilter:
case PhysicalOperatorType::kHash:
Expand Down
2 changes: 1 addition & 1 deletion src/executor/fragment_builder.cppm
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@ public:

UniquePtr<PlanFragment> BuildFragment(PhysicalOperator *phys_op);

private:
void BuildFragments(PhysicalOperator *phys_op, PlanFragment *current_fragment_ptr);

private:
void BuildExplain(PhysicalOperator *phys_op, PlanFragment *current_fragment_ptr);

idx_t GetFragmentId() { return fragment_id_++; }
Expand Down
Loading

0 comments on commit f7dbfa2

Please sign in to comment.