From 39bc6af6dc6dd14987f8036ff5b641c4bb276e63 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C5=81ukasz=20Smoli=C5=84ski?= <29839376+lukaszsmolinski@users.noreply.github.com> Date: Sat, 4 Nov 2023 12:17:25 +0100 Subject: [PATCH] Add missing normalization check to BFIndex --- python_bindings/bindings.cpp | 46 ++++++++++++++++++++++++++---------- 1 file changed, 33 insertions(+), 13 deletions(-) diff --git a/python_bindings/bindings.cpp b/python_bindings/bindings.cpp index 14f1cabe..56ce9beb 100644 --- a/python_bindings/bindings.cpp +++ b/python_bindings/bindings.cpp @@ -871,19 +871,39 @@ class BFIndex { CustomFilterFunctor idFilter(filter); CustomFilterFunctor* p_idFilter = filter ? &idFilter : nullptr; - ParallelFor(0, rows, num_threads, [&](size_t row, size_t threadId) { - std::priority_queue> result = alg->searchKnn( - (void*)items.data(row), k, p_idFilter); - if (result.size() != k) - throw std::runtime_error( - "Cannot return the results in a contiguous 2D array. There are not enough elements."); - for (int i = k - 1; i >= 0; i--) { - auto& result_tuple = result.top(); - data_numpy_d[row * k + i] = result_tuple.first; - data_numpy_l[row * k + i] = result_tuple.second; - result.pop(); - } - }); + if (!normalize) { + ParallelFor(0, rows, num_threads, [&](size_t row, size_t threadId) { + std::priority_queue> result = alg->searchKnn( + (void*)items.data(row), k, p_idFilter); + if (result.size() != k) + throw std::runtime_error( + "Cannot return the results in a contiguous 2D array. There are not enough elements."); + for (int i = k - 1; i >= 0; i--) { + auto& result_tuple = result.top(); + data_numpy_d[row * k + i] = result_tuple.first; + data_numpy_l[row * k + i] = result_tuple.second; + result.pop(); + } + }); + } else { + std::vector norm_array(num_threads * features); + ParallelFor(0, rows, num_threads, [&](size_t row, size_t threadId) { + size_t start_idx = threadId * dim; + normalize_vector((float*)items.data(row), norm_array.data() + start_idx); + + std::priority_queue> result = alg->searchKnn( + (void*)(norm_array.data() + start_idx), k, p_idFilter); + if (result.size() != k) + throw std::runtime_error( + "Cannot return the results in a contiguous 2D array. There are not enough elements."); + for (int i = k - 1; i >= 0; i--) { + auto& result_tuple = result.top(); + data_numpy_d[row * k + i] = result_tuple.first; + data_numpy_l[row * k + i] = result_tuple.second; + result.pop(); + } + }); + } } py::capsule free_when_done_l(data_numpy_l, [](void *f) {