Skip to content
This repository has been archived by the owner on Aug 16, 2023. It is now read-only.

Commit

Permalink
Use PQ and Refine Instead of Brute Force for DiskANN when Bitset Filt…
Browse files Browse the repository at this point in the history
…er Ratio High (#831)

Signed-off-by: Patrick Weizhi Xu <[email protected]>
  • Loading branch information
PwzXxm authored May 5, 2023
1 parent ecb4db3 commit 9727030
Show file tree
Hide file tree
Showing 8 changed files with 144 additions and 104 deletions.
4 changes: 3 additions & 1 deletion src/index/diskann/diskann.cc
Original file line number Diff line number Diff line change
Expand Up @@ -496,6 +496,7 @@ DiskANNIndexNode<T>::Search(const DataSet& dataset, const Config& cfg, const Bit
auto k = static_cast<uint64_t>(search_conf.k);
auto lsearch = static_cast<uint64_t>(search_conf.search_list_size);
auto beamwidth = static_cast<uint64_t>(search_conf.beamwidth);
auto filter_ratio = static_cast<float>(search_conf.filter_threshold);

auto nq = dataset.GetRows();
auto dim = dataset.GetDim();
Expand All @@ -519,7 +520,8 @@ DiskANNIndexNode<T>::Search(const DataSet& dataset, const Config& cfg, const Bit
for (int64_t row = 0; row < nq; ++row) {
futures.push_back(pool_->push([&, index = row]() {
pq_flash_index_->cached_beam_search(xq + (index * dim), k, lsearch, p_id + (index * k),
p_dist + (index * k), beamwidth, false, nullptr, feder_result, bitset);
p_dist + (index * k), beamwidth, false, nullptr, feder_result, bitset,
filter_ratio);
}));
}
for (auto& future : futures) {
Expand Down
9 changes: 9 additions & 0 deletions src/index/diskann/diskann_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,10 @@ class DiskANNConfig : public BaseConfig {
// DiskANN uses TopK search to simulate range search, this is the ratio of search list size and k. With larger
// ratio, the accuracy will get higher but throughput will get affected.
CFG_FLOAT search_list_and_k_ratio;
// The threshold which determines when to switch to PQ + Refine strategy based on the number of bits set. The
// value should be in range of [0.0, 1.0] which means when greater or equal to x% of the bits are set,
// use PQ + Refine. Default to -1.0f, negative vlaues will use dynamic threshold calculator given topk.
CFG_FLOAT filter_threshold;
KNOHWERE_DECLARE_CONFIG(DiskANNConfig) {
KNOWHERE_CONFIG_DECLARE_FIELD(index_prefix)
.description("path to load or save Diskann.")
Expand Down Expand Up @@ -146,6 +150,11 @@ class DiskANNConfig : public BaseConfig {
.set_default(2.0)
.set_range(1.0, 5.0)
.for_range_search();
KNOWHERE_CONFIG_DECLARE_FIELD(filter_threshold)
.description("the threshold of filter ratio to use PQ + Refine.")
.set_default(-1.0f)
.set_range(-1.0f, 1.0f)
.for_search();
}
};
} // namespace knowhere
Expand Down
30 changes: 16 additions & 14 deletions tests/ut/test_diskann.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
#include "catch2/catch_approx.hpp"
#include "catch2/catch_test_macros.hpp"
#include "catch2/generators/catch_generators.hpp"
#include "diskann/defines_export.h"
#include "index/diskann/diskann.cc"
#include "index/diskann/diskann_config.h"
#include "knowhere/comp/brute_force.h"
Expand Down Expand Up @@ -241,19 +240,22 @@ TEST_CASE("Test DiskANNIndexNode.", "[diskann]") {
// knn search with bitset
std::vector<std::function<std::vector<uint8_t>(size_t, size_t)>> gen_bitset_funcs = {
GenerateBitsetWithFirstTbitsSet, GenerateBitsetWithRandomTbitsSet};
const auto bitset_percentages =
GetBitsetTestPercentagesFromThreshold(diskann::kDiskAnnBruteForceFilterRate);
for (const float percentage : bitset_percentages) {
for (const auto& gen_func : gen_bitset_funcs) {
auto bitset_data = gen_func(kNumRows, percentage * kNumRows);
knowhere::BitsetView bitset(bitset_data.data(), kNumRows);
auto results = diskann.Search(*query_ds, knn_json, bitset);
auto gt = knowhere::BruteForce::Search(base_ds, query_ds, knn_json, bitset);
float recall = GetKNNRecall(*gt.value(), *results.value());
if (percentage > diskann::kDiskAnnBruteForceFilterRate) {
REQUIRE(recall >= 0.99f);
} else {
REQUIRE(recall >= kL2KnnRecall);
const auto bitset_percentages = {0.4f, 0.98f};
const auto bitset_thresholds = {-1.0f, 0.9f};
for (const float threshold : bitset_thresholds) {
knn_json["filter_threshold"] = threshold;
for (const float percentage : bitset_percentages) {
for (const auto& gen_func : gen_bitset_funcs) {
auto bitset_data = gen_func(kNumRows, percentage * kNumRows);
knowhere::BitsetView bitset(bitset_data.data(), kNumRows);
auto results = diskann.Search(*query_ds, knn_json, bitset);
auto gt = knowhere::BruteForce::Search(base_ds, query_ds, knn_json, bitset);
float recall = GetKNNRecall(*gt.value(), *results.value());
if (percentage == 0.98f) {
REQUIRE(recall >= 0.9f);
} else {
REQUIRE(recall >= kL2KnnRecall);
}
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion tests/ut/test_search.cc
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ TEST_CASE("Test All Mem Index Search", "[search]") {

std::vector<std::function<std::vector<uint8_t>(size_t, size_t)>> gen_bitset_funcs = {
GenerateBitsetWithFirstTbitsSet, GenerateBitsetWithRandomTbitsSet};
const auto bitset_percentages = GetBitsetTestPercentagesFromThreshold(threshold);
const auto bitset_percentages = {0.4f, 0.98f};
for (const float percentage : bitset_percentages) {
for (const auto& gen_func : gen_bitset_funcs) {
auto bitset_data = gen_func(nb, percentage * nb);
Expand Down
8 changes: 0 additions & 8 deletions tests/ut/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -115,14 +115,6 @@ GetRangeSearchRecall(const knowhere::DataSet& gt, const knowhere::DataSet& resul
return (1 + precision) * recall / 2;
}

// Generate two bitset percentages from given threshold, so that bruteforce
// strategy can be fully tested
inline std::vector<float>
GetBitsetTestPercentagesFromThreshold(const float threshold) {
assert(threshold >= 0 && threshold <= 1.0f);
return {threshold / 2, threshold + (1 - threshold) / 2};
}

// Return a n-bits bitset data with first t bits set to true
inline std::vector<uint8_t>
GenerateBitsetWithFirstTbitsSet(size_t n, size_t t) {
Expand Down
5 changes: 0 additions & 5 deletions thirdparty/DiskANN/include/diskann/defines_export.h

This file was deleted.

7 changes: 5 additions & 2 deletions thirdparty/DiskANN/include/diskann/pq_flash_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,8 @@ namespace diskann {
float *res_dists, const _u64 beam_width,
const bool use_reorder_data = false, QueryStats *stats = nullptr,
const knowhere::feder::diskann::FederResultUniq &feder = nullptr,
knowhere::BitsetView bitset_view = nullptr);
knowhere::BitsetView bitset_view = nullptr,
const float filter_ratio = -1.0f);

DISKANN_DLLEXPORT _u32 range_search(
const T *query1, const double range, const _u64 min_l_search,
Expand Down Expand Up @@ -163,7 +164,9 @@ namespace diskann {
// If there is no value, there is nothing to do with the given query
std::optional<float> init_thread_data(ThreadData<T> &data, const T *query1);

// Brute force search for the given query.
// Brute force search for the given query. Use beam search rather than
// sending whole bunch of requests at once to avoid all threads sending I/O
// requests and the time overlaps.
// The beam width is adjusted in the function.
void brute_force_beam_search(
ThreadData<T> &data, const float query_norm, const _u64 k_search,
Expand Down
183 changes: 110 additions & 73 deletions thirdparty/DiskANN/src/pq_flash_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include <atomic>
#include <chrono>
#include <cmath>
#include <cstdint>
#include <iterator>
#include <optional>
#include <random>
Expand All @@ -21,10 +22,8 @@
#include "diskann/parameters.h"
#include "diskann/timer.h"
#include "diskann/utils.h"
#include "diskann/defines_export.h"
#include "knowhere/heap.h"

#include "knowhere/log.h"
#include "knowhere/utils.h"
#include "tsl/robin_set.h"

Expand Down Expand Up @@ -91,7 +90,13 @@ namespace {
}
}
}
constexpr _u64 kBruteForceBeamWidthFactor = 1; // TODO: after experiment, change this value
constexpr _u64 kRefineBeamWidthFactor = 2;
constexpr _u64 kBruteForceTopkRefineExpansionFactor = 2;
auto calcFilterThreshold = [](const auto topk) -> const float {
return std::max(-0.04570166137874405f * log2(topk + 58.96422392240403) +
1.1982775974217197,
0.5);
};
} // namespace

namespace diskann {
Expand Down Expand Up @@ -856,95 +861,111 @@ namespace diskann {
template<typename T>
void PQFlashIndex<T>::brute_force_beam_search(
ThreadData<T> &data, const float query_norm, const _u64 k_search,
_s64 *indices, float *distances, const _u64 beam_width_parm, IOContext &ctx,
_s64 *indices, float *distances, const _u64 beam_width_param, IOContext &ctx,
QueryStats *stats, const knowhere::feder::diskann::FederResultUniq &feder,
knowhere::BitsetView bitset_view) {
auto query_scratch = &(data.scratch);
const T *query = data.scratch.aligned_query_T;
auto beam_width = beam_width_parm * kBruteForceBeamWidthFactor;

T *data_buf = query_scratch->coord_scratch;
auto beam_width = beam_width_param * kRefineBeamWidthFactor;
const float *query_float = data.scratch.aligned_query_float;
float *pq_dists = query_scratch->aligned_pqtable_dist_scratch;
pq_table.populate_chunk_distances(query_float, pq_dists);
float *dist_scratch = query_scratch->aligned_dist_scratch;
_u8 *pq_coord_scratch = query_scratch->aligned_pq_coord_scratch;
constexpr _u32 pq_batch_size = MAX_GRAPH_DEGREE;
std::vector<unsigned> pq_batch_ids;
pq_batch_ids.reserve(pq_batch_size);
const _u64 pq_topk = k_search * kBruteForceTopkRefineExpansionFactor;
knowhere::ResultMaxHeap<float, int64_t> pq_max_heap(pq_topk);
T *data_buf = query_scratch->coord_scratch;
std::unordered_map<_u64, std::vector<_u64>> nodes_in_sectors_to_visit;
std::vector<AlignedRead> frontier_read_reqs;
frontier_read_reqs.reserve(2 * beam_width);
frontier_read_reqs.reserve(beam_width);
char *sector_scratch = query_scratch->sector_scratch;
_u64 &sector_scratch_idx = query_scratch->sector_idx;

Timer io_timer, query_timer;
knowhere::ResultMaxHeap<float, _u64> max_heap(k_search);
// TODO: maybe we can pipeline here
assert(bitset_view.size() == num_points);
for (_u64 id = 0; id < num_points;) {
// gathering information about reading which nodes and which sector
Timer io_timer, query_timer;

// scan un-marked points and calculate pq dists
for (_u64 id = 0; id < num_points; ++id) {
if (!bitset_view.test(id)) {
const _u64 sector_offset = get_node_sector_offset(id);
if (nodes_in_sectors_to_visit.find(sector_offset) ==
nodes_in_sectors_to_visit.end()) {
while (id < num_points &&
get_node_sector_offset(id) == sector_offset) {
if (!bitset_view.test(id)) {
if (coord_cache.find(id) != coord_cache.end()) {
float dist =
dist_cmp(query, coord_cache.at(id), (size_t) aligned_dim);
max_heap.Push(dist, id);
} else {
nodes_in_sectors_to_visit[sector_offset].push_back(id);
}
}
++id;
}
} else {
LOG_KNOWHERE_ERROR_ << "Should not be visited twice";
pq_batch_ids.push_back(id);
}

if (pq_batch_ids.size() == pq_batch_size || id == num_points - 1) {
const size_t sz = pq_batch_ids.size();
aggregate_coords(pq_batch_ids.data(), sz, this->data, this->n_chunks,
pq_coord_scratch);
pq_dist_lookup(pq_coord_scratch, sz, this->n_chunks, pq_dists,
dist_scratch);
for (size_t i = 0; i < sz; ++i) {
pq_max_heap.Push(dist_scratch[i], pq_batch_ids[i]);
}
} else {
++id;
pq_batch_ids.clear();
}
if (id < num_points && nodes_in_sectors_to_visit.size() < beam_width)
}

// deduplicate sectors by ids
while (const auto opt = pq_max_heap.Pop()) {
const auto [dist, id] = opt.value();

// check if in cache
if (coord_cache.find(id) != coord_cache.end()) {
float dist = dist_cmp(query, coord_cache.at(id), (size_t) aligned_dim);
max_heap.Push(dist, id);
continue;
}

// perform I/O
for (const auto &[sector_offset, ids_in_sectors] :
nodes_in_sectors_to_visit) {
frontier_read_reqs.emplace_back(
sector_offset, read_len_for_node,
sector_scratch + sector_scratch_idx * read_len_for_node);
++sector_scratch_idx;
if (stats != nullptr) {
stats->n_4k++;
stats->n_ios++;
}
// deduplicate and prepare for I/O
const _u64 sector_offset = get_node_sector_offset(id);
nodes_in_sectors_to_visit[sector_offset].push_back(id);
}

for (auto it = nodes_in_sectors_to_visit.cbegin();
it != nodes_in_sectors_to_visit.cend();) {
const auto sector_offset = it->first;
frontier_read_reqs.emplace_back(
sector_offset, read_len_for_node,
sector_scratch + sector_scratch_idx * read_len_for_node);
++sector_scratch_idx, ++it;
if (stats != nullptr) {
stats->n_4k++;
stats->n_ios++;
}
io_timer.reset();

// perform I/Os and calculate exact distances
if (frontier_read_reqs.size() == beam_width ||
it == nodes_in_sectors_to_visit.cend()) {
io_timer.reset();
#ifdef USE_BING_INFRA
reader->read(frontier_read_reqs, ctx, true); // async reader windows.
reader->read(frontier_read_reqs, ctx, true); // async reader windows.
#else
reader->read(frontier_read_reqs, ctx); // synchronous IO linux
reader->read(frontier_read_reqs, ctx); // synchronous IO linux
#endif
if (stats != nullptr) {
stats->io_us += (double) io_timer.elapsed();
}
if (stats != nullptr) {
stats->io_us += (double) io_timer.elapsed();
}

T *node_fp_coords_copy = data_buf;
for (const auto &req : frontier_read_reqs) {
const _u64 sector_offset = req.offset;
char *sector_buf = reinterpret_cast<char *>(req.buf);
for (const auto cur_id : nodes_in_sectors_to_visit[sector_offset]) {
char *node_buf = get_offset_to_node(sector_buf, cur_id);
memcpy(node_fp_coords_copy, node_buf,
disk_bytes_per_point); // Do we really need memcpy here?
float dist =
dist_cmp(query, node_fp_coords_copy, (size_t) aligned_dim);
max_heap.Push(dist, cur_id);
if (feder != nullptr) {
feder->visit_info_.AddTopCandidateInfo(cur_id, dist);
feder->id_set_.insert(cur_id);
T *node_fp_coords_copy = data_buf;
for (const auto &req : frontier_read_reqs) {
const auto offset = req.offset;
char *sector_buf = reinterpret_cast<char *>(req.buf);
for (const auto cur_id : nodes_in_sectors_to_visit[offset]) {
char *node_buf = get_offset_to_node(sector_buf, cur_id);
memcpy(node_fp_coords_copy, node_buf,
disk_bytes_per_point); // Do we really need memcpy here?
float dist =
dist_cmp(query, node_fp_coords_copy, (size_t) aligned_dim);
max_heap.Push(dist, cur_id);
if (feder != nullptr) {
feder->visit_info_.AddTopCandidateInfo(cur_id, dist);
feder->id_set_.insert(cur_id);
}
}
}
frontier_read_reqs.clear();
sector_scratch_idx = 0;
}

nodes_in_sectors_to_visit.clear();
frontier_read_reqs.clear();
sector_scratch_idx = 0;
}

for (_s64 i = k_search - 1; i >= 0; --i) {
Expand All @@ -968,7 +989,7 @@ namespace diskann {
}
}
} else {
LOG_KNOWHERE_ERROR_ << "Size is incorrect";
LOG(ERROR) << "Size is incorrect";
}
}
if (stats != nullptr) {
Expand All @@ -982,7 +1003,7 @@ namespace diskann {
const T *query1, const _u64 k_search, const _u64 l_search, _s64 *indices,
float *distances, const _u64 beam_width, const bool use_reorder_data,
QueryStats *stats, const knowhere::feder::diskann::FederResultUniq &feder,
knowhere::BitsetView bitset_view) {
knowhere::BitsetView bitset_view, const float filter_ratio_in) {
if (beam_width > MAX_N_SECTOR_READS)
throw ANNException("Beamwidth can not be higher than MAX_N_SECTOR_READS",
-1, __FUNCSIG__, __FILE__, __LINE__);
Expand All @@ -1002,13 +1023,29 @@ namespace diskann {
float query_norm = query_norm_opt.value();
auto ctx = this->reader->get_ctx();

if (!bitset_view.empty() && bitset_view.count() >= bitset_view.size() * kDiskAnnBruteForceFilterRate) {
if (!bitset_view.empty()) {
const auto filter_threshold = filter_ratio_in < 0
? calcFilterThreshold(k_search)
: filter_ratio_in;
const auto bv_cnt = bitset_view.count();
if (bitset_view.size() == bv_cnt) {
for (_u64 i = 0; i < k_search; i++) {
indices[i] = -1;
if (distances != nullptr) {
distances[i] = -1;
}
}
return;
}

if (bv_cnt >= bitset_view.size() * filter_threshold) {
brute_force_beam_search(data, query_norm, k_search, indices, distances,
beam_width, ctx, stats, feder, bitset_view);
this->thread_data.push(data);
this->thread_data.push_notify_all();
this->reader->put_ctx(ctx);
return;
}
}

auto query_scratch = &(data.scratch);
Expand Down

0 comments on commit 9727030

Please sign in to comment.