Skip to content

Commit

Permalink
Merge pull request #514 from lukaszsmolinski/develop
Browse files Browse the repository at this point in the history
Fix incorrect results in bruteforce with filter
  • Loading branch information
yurymalkov authored Dec 11, 2023
2 parents 5a8fd34 + 39bc6af commit dbcef01
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 27 deletions.
24 changes: 7 additions & 17 deletions hnswlib/bruteforce.h
Original file line number Diff line number Diff line change
Expand Up @@ -107,27 +107,17 @@ class BruteforceSearch : public AlgorithmInterface<dist_t> {
searchKnn(const void *query_data, size_t k, BaseFilterFunctor* isIdAllowed = nullptr) const {
assert(k <= cur_element_count);
std::priority_queue<std::pair<dist_t, labeltype >> topResults;
if (cur_element_count == 0) return topResults;
for (int i = 0; i < k; i++) {
dist_t lastdist = std::numeric_limits<dist_t>::max();
for (int i = 0; i < cur_element_count; i++) {
dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_);
labeltype label = *((labeltype*) (data_ + size_per_element_ * i + data_size_));
if ((!isIdAllowed) || (*isIdAllowed)(label)) {
topResults.emplace(dist, label);
}
}
dist_t lastdist = topResults.empty() ? std::numeric_limits<dist_t>::max() : topResults.top().first;
for (int i = k; i < cur_element_count; i++) {
dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_);
if (dist <= lastdist) {
if (dist <= lastdist || topResults.size() < k) {
labeltype label = *((labeltype *) (data_ + size_per_element_ * i + data_size_));
if ((!isIdAllowed) || (*isIdAllowed)(label)) {
topResults.emplace(dist, label);
}
if (topResults.size() > k)
topResults.pop();

if (!topResults.empty()) {
lastdist = topResults.top().first;
if (topResults.size() > k)
topResults.pop();
if (!topResults.empty())
lastdist = topResults.top().first;
}
}
}
Expand Down
43 changes: 33 additions & 10 deletions python_bindings/bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -871,16 +871,39 @@ class BFIndex {
CustomFilterFunctor idFilter(filter);
CustomFilterFunctor* p_idFilter = filter ? &idFilter : nullptr;

ParallelFor(0, rows, num_threads, [&](size_t row, size_t threadId) {
std::priority_queue<std::pair<dist_t, hnswlib::labeltype >> result = alg->searchKnn(
(void*)items.data(row), k, p_idFilter);
for (int i = k - 1; i >= 0; i--) {
auto& result_tuple = result.top();
data_numpy_d[row * k + i] = result_tuple.first;
data_numpy_l[row * k + i] = result_tuple.second;
result.pop();
}
});
if (!normalize) {
ParallelFor(0, rows, num_threads, [&](size_t row, size_t threadId) {
std::priority_queue<std::pair<dist_t, hnswlib::labeltype >> result = alg->searchKnn(
(void*)items.data(row), k, p_idFilter);
if (result.size() != k)
throw std::runtime_error(
"Cannot return the results in a contiguous 2D array. There are not enough elements.");
for (int i = k - 1; i >= 0; i--) {
auto& result_tuple = result.top();
data_numpy_d[row * k + i] = result_tuple.first;
data_numpy_l[row * k + i] = result_tuple.second;
result.pop();
}
});
} else {
std::vector<float> norm_array(num_threads * features);
ParallelFor(0, rows, num_threads, [&](size_t row, size_t threadId) {
size_t start_idx = threadId * dim;
normalize_vector((float*)items.data(row), norm_array.data() + start_idx);

std::priority_queue<std::pair<dist_t, hnswlib::labeltype >> result = alg->searchKnn(
(void*)(norm_array.data() + start_idx), k, p_idFilter);
if (result.size() != k)
throw std::runtime_error(
"Cannot return the results in a contiguous 2D array. There are not enough elements.");
for (int i = k - 1; i >= 0; i--) {
auto& result_tuple = result.top();
data_numpy_d[row * k + i] = result_tuple.first;
data_numpy_l[row * k + i] = result_tuple.second;
result.pop();
}
});
}
}

py::capsule free_when_done_l(data_numpy_l, [](void *f) {
Expand Down

0 comments on commit dbcef01

Please sign in to comment.