From 20ea1058e8dca8a932d601b3cd41a89850a12f94 Mon Sep 17 00:00:00 2001 From: zh Wang Date: Fri, 4 Aug 2023 20:30:08 +0800 Subject: [PATCH] Fix bruteforce cosine Signed-off-by: zh Wang --- src/common/comp/brute_force.cc | 63 ++++--- src/common/metric.h | 2 +- thirdparty/faiss/faiss/MetricType.h | 2 + thirdparty/faiss/faiss/utils/distances.cpp | 185 ++++++++++++++++++++- thirdparty/faiss/faiss/utils/distances.h | 23 +++ 5 files changed, 247 insertions(+), 28 deletions(-) diff --git a/src/common/comp/brute_force.cc b/src/common/comp/brute_force.cc index a1f283a89..1fdf3b0d4 100644 --- a/src/common/comp/brute_force.cc +++ b/src/common/comp/brute_force.cc @@ -33,10 +33,6 @@ expected BruteForce::Search(const DataSetPtr base_dataset, const DataSetPtr query_dataset, const Json& config, const BitsetView& bitset) { std::string metric_str = config[meta::METRIC_TYPE].get(); - bool is_cosine = IsMetricType(metric_str, metric::COSINE); - if (is_cosine) { - Normalize(*base_dataset); - } auto xb = base_dataset->GetTensor(); auto nb = base_dataset->GetRows(); @@ -54,6 +50,13 @@ BruteForce::Search(const DataSetPtr base_dataset, const DataSetPtr query_dataset auto labels = new int64_t[nq * topk]; auto distances = new float[nq * topk]; + bool is_cosine = IsMetricType(metric_str, metric::COSINE); + std::unique_ptr norms = nullptr; + if (is_cosine) { + norms = std::make_unique(nb); + faiss::fvec_norms_L2(norms.get(), (const float*)xb, dim, nb); + } + auto pool = ThreadPool::GetGlobalThreadPool(); std::vector> futs; futs.reserve(nq); @@ -71,13 +74,17 @@ BruteForce::Search(const DataSetPtr base_dataset, const DataSetPtr query_dataset } case faiss::METRIC_INNER_PRODUCT: { auto cur_query = (float*)xq + dim * index; - if (is_cosine) { - NormalizeVec(cur_query, dim); - } faiss::float_minheap_array_t buf{(size_t)1, (size_t)topk, cur_labels, cur_distances}; faiss::knn_inner_product(cur_query, (const float*)xb, dim, 1, nb, &buf, bitset); break; } + case faiss::METRIC_COSINE: { + auto cur_query = (float*)xq + dim * index; + NormalizeVec(cur_query, dim); + faiss::float_minheap_array_t buf{(size_t)1, (size_t)topk, cur_labels, cur_distances}; + faiss::knn_cosine(cur_query, (const float*)xb, dim, 1, nb, &buf, norms.get(), bitset); + break; + } case faiss::METRIC_Jaccard: { auto cur_query = (const uint8_t*)xq + (dim / 8) * index; faiss::float_maxheap_array_t res = {size_t(1), size_t(topk), cur_labels, cur_distances}; @@ -123,10 +130,6 @@ Status BruteForce::SearchWithBuf(const DataSetPtr base_dataset, const DataSetPtr query_dataset, int64_t* ids, float* dis, const Json& config, const BitsetView& bitset) { std::string metric_str = config[meta::METRIC_TYPE].get(); - bool is_cosine = IsMetricType(metric_str, metric::COSINE); - if (is_cosine) { - Normalize(*base_dataset); - } auto xb = base_dataset->GetTensor(); auto nb = base_dataset->GetRows(); @@ -150,6 +153,13 @@ BruteForce::SearchWithBuf(const DataSetPtr base_dataset, const DataSetPtr query_ auto faiss_metric_type = metric_type.value(); + bool is_cosine = IsMetricType(metric_str, metric::COSINE); + std::unique_ptr norms = nullptr; + if (is_cosine) { + norms = std::make_unique(nb); + faiss::fvec_norms_L2(norms.get(), (const float*)xb, dim, nb); + } + auto pool = ThreadPool::GetGlobalThreadPool(); std::vector> futs; futs.reserve(nq); @@ -167,13 +177,17 @@ BruteForce::SearchWithBuf(const DataSetPtr base_dataset, const DataSetPtr query_ } case faiss::METRIC_INNER_PRODUCT: { auto cur_query = (float*)xq + dim * index; - if (is_cosine) { - NormalizeVec(cur_query, dim); - } faiss::float_minheap_array_t buf{(size_t)1, (size_t)topk, cur_labels, cur_distances}; faiss::knn_inner_product(cur_query, (const float*)xb, dim, 1, nb, &buf, bitset); break; } + case faiss::METRIC_COSINE: { + auto cur_query = (float*)xq + dim * index; + NormalizeVec(cur_query, dim); + faiss::float_minheap_array_t buf{(size_t)1, (size_t)topk, cur_labels, cur_distances}; + faiss::knn_cosine(cur_query, (const float*)xb, dim, 1, nb, &buf, norms.get(), bitset); + break; + } case faiss::METRIC_Jaccard: { auto cur_query = (const uint8_t*)xq + (dim / 8) * index; faiss::float_maxheap_array_t res = {size_t(1), size_t(topk), cur_labels, cur_distances}; @@ -221,11 +235,6 @@ expected BruteForce::RangeSearch(const DataSetPtr base_dataset, const DataSetPtr query_dataset, const Json& config, const BitsetView& bitset) { std::string metric_str = config[meta::METRIC_TYPE].get(); - bool is_cosine = IsMetricType(metric_str, metric::COSINE); - if (is_cosine) { - Normalize(*base_dataset); - } - auto xb = base_dataset->GetTensor(); auto nb = base_dataset->GetRows(); auto dim = base_dataset->GetDim(); @@ -241,6 +250,12 @@ BruteForce::RangeSearch(const DataSetPtr base_dataset, const DataSetPtr query_da float range_filter = cfg.range_filter.value(); ASSIGN_OR_RETURN(faiss::MetricType, faiss_metric_type, Str2FaissMetricType(cfg.metric_type.value())); + bool is_cosine = IsMetricType(metric_str, metric::COSINE); + std::unique_ptr norms = nullptr; + if (is_cosine) { + norms = std::make_unique(nb); + faiss::fvec_norms_L2(norms.get(), (const float*)xb, dim, nb); + } auto pool = ThreadPool::GetGlobalThreadPool(); std::vector> result_id_array(nq); @@ -262,12 +277,16 @@ BruteForce::RangeSearch(const DataSetPtr base_dataset, const DataSetPtr query_da case faiss::METRIC_INNER_PRODUCT: { is_ip = true; auto cur_query = (float*)xq + dim * index; - if (is_cosine) { - NormalizeVec(cur_query, dim); - } faiss::range_search_inner_product(cur_query, (const float*)xb, dim, 1, nb, radius, &res, bitset); break; } + case faiss::METRIC_COSINE: { + auto cur_query = (float*)xq + dim * index; + NormalizeVec(cur_query, dim); + faiss::range_search_cosine(cur_query, (const float*)xb, dim, 1, nb, radius, &res, norms.get(), + bitset); + break; + } case faiss::METRIC_Jaccard: { auto cur_query = (const uint8_t*)xq + (dim / 8) * index; faiss::binary_range_search, float>( diff --git a/src/common/metric.h b/src/common/metric.h index 2ac90090d..82aaeffbb 100644 --- a/src/common/metric.h +++ b/src/common/metric.h @@ -27,7 +27,7 @@ Str2FaissMetricType(std::string metric) { static const std::unordered_map metric_map = { {metric::L2, faiss::MetricType::METRIC_L2}, {metric::IP, faiss::MetricType::METRIC_INNER_PRODUCT}, - {metric::COSINE, faiss::MetricType::METRIC_INNER_PRODUCT}, + {metric::COSINE, faiss::MetricType::METRIC_COSINE}, {metric::HAMMING, faiss::MetricType::METRIC_Hamming}, {metric::JACCARD, faiss::MetricType::METRIC_Jaccard}, {metric::SUBSTRUCTURE, faiss::MetricType::METRIC_Substructure}, diff --git a/thirdparty/faiss/faiss/MetricType.h b/thirdparty/faiss/faiss/MetricType.h index 068335c8e..f66892223 100644 --- a/thirdparty/faiss/faiss/MetricType.h +++ b/thirdparty/faiss/faiss/MetricType.h @@ -30,6 +30,8 @@ enum MetricType { METRIC_Substructure, ///< Tversky case alpha = 0, beta = 1 METRIC_Superstructure, ///< Tversky case alpha = 1, beta = 0 + METRIC_COSINE, + /// some additional metrics defined in scipy.spatial.distance METRIC_Canberra = 20, METRIC_BrayCurtis, diff --git a/thirdparty/faiss/faiss/utils/distances.cpp b/thirdparty/faiss/faiss/utils/distances.cpp index b8c5e2dbb..16c90c73f 100644 --- a/thirdparty/faiss/faiss/utils/distances.cpp +++ b/thirdparty/faiss/faiss/utils/distances.cpp @@ -14,6 +14,7 @@ #include #include #include +#include "simd/hook.h" #include @@ -109,7 +110,8 @@ void exhaustive_parallel_on_nx( size_t ny, ResultHandler& res, decltype(fvec_inner_product) dis_compute_func, - const BitsetView bitset) { + const BitsetView bitset, + const float* y_norm = nullptr) { using SingleResultHandler = typename ResultHandler::SingleResultHandler; #pragma omp parallel { @@ -123,6 +125,9 @@ void exhaustive_parallel_on_nx( for (size_t j = 0; j < ny; j++) { if (bitset.empty() || !bitset.test(j)) { float ip = dis_compute_func(x_i, y_j, d); + if (y_norm) { + ip /= y_norm[j]; + } resi.add_result(ip, j); } y_j += d; @@ -141,7 +146,8 @@ void exhaustive_parallel_on_ny( size_t ny, ResultHandler& res, decltype(fvec_inner_product) dis_compute_func, - const BitsetView bitset) { + const BitsetView bitset, + const float* y_norm2 = nullptr) { using SingleResultHandler = typename ResultHandler::SingleResultHandler; size_t k = res.k; size_t thread_max_num = omp_get_max_threads(); @@ -174,6 +180,9 @@ void exhaustive_parallel_on_ny( const float* x_i = x + x_from * d; for (size_t i = 0; i < size; i++) { float ip = dis_compute_func(x_i, y_j, d); + if (y_norm2) { + ip /= y_norm2[j]; + } ress[t].add_single_result(i, ip, j); x_i += d; } @@ -207,15 +216,16 @@ void exhaustive_L2sqr_IP_seq( size_t ny, ResultHandler& res, decltype(fvec_inner_product) dis_compute_func, - const BitsetView bitset) { + const BitsetView bitset, + const float* y_norm = nullptr) { size_t thread_max_num = omp_get_max_threads(); if (ny > parallel_policy_threshold || (nx < thread_max_num / 2 && ny >= thread_max_num * 32)) { exhaustive_parallel_on_ny( - x, y, d, nx, ny, res, dis_compute_func, bitset); + x, y, d, nx, ny, res, dis_compute_func, bitset, y_norm); } else { exhaustive_parallel_on_nx( - x, y, d, nx, ny, res, dis_compute_func, bitset); + x, y, d, nx, ny, res, dis_compute_func, bitset, y_norm); } } @@ -284,6 +294,40 @@ void exhaustive_L2sqr_seq( } } +template +void exhaustive_cosine_seq( + const float* x, + const float* y, + size_t d, + size_t nx, + size_t ny, + ResultHandler& res, + const float* y_norm, + const BitsetView bitset) { + using SingleResultHandler = typename ResultHandler::SingleResultHandler; + int nt = std::min(int(nx), omp_get_max_threads()); + +#pragma omp parallel num_threads(nt) + { + SingleResultHandler resi(res); +#pragma omp for + for (int64_t i = 0; i < nx; i++) { + const float* x_i = x + i * d; + const float* y_j = y; + resi.begin(i); + for (size_t j = 0; j < ny; j++) { + if (bitset.empty() || !bitset.test(j)) { + float ip = fvec_inner_product(x_i, y_j, d); + ip /= y_norm[j]; + resi.add_result(ip, j); + } + y_j += d; + } + resi.end(); + } + } +} + /** Find the nearest neighbors for nx queries in a set of ny vectors */ template void exhaustive_inner_product_blas( @@ -426,6 +470,89 @@ void exhaustive_L2sqr_blas( } } +// distance correction is an operator that can be applied to transform +// the distances +template +void exhaustive_cosine_blas( + const float* x, + const float* y, + size_t d, + size_t nx, + size_t ny, + ResultHandler& res, + const float* y_norm = nullptr, + const BitsetView bitset = nullptr) { + // BLAS does not like empty matrices + if (nx == 0 || ny == 0) + return; + + /* block sizes */ + const size_t bs_x = distance_compute_blas_query_bs; + const size_t bs_y = distance_compute_blas_database_bs; + // const size_t bs_x = 16, bs_y = 16; + std::unique_ptr ip_block(new float[bs_x * bs_y]); + std::unique_ptr del2; + + if (!y_norm) { + float* y_norms2 = new float[ny]; + del2.reset(y_norms2); + fvec_norms_L2sqr(y_norms2, y, d, ny); + y_norm = y_norms2; + } + + for (size_t i0 = 0; i0 < nx; i0 += bs_x) { + size_t i1 = i0 + bs_x; + if (i1 > nx) + i1 = nx; + + res.begin_multiple(i0, i1); + + for (size_t j0 = 0; j0 < ny; j0 += bs_y) { + size_t j1 = j0 + bs_y; + if (j1 > ny) + j1 = ny; + /* compute the actual dot products */ + { + float one = 1, zero = 0; + FINTEGER nyi = j1 - j0, nxi = i1 - i0, di = d; + sgemm_("Transpose", + "Not transpose", + &nyi, + &nxi, + &di, + &one, + y + j0 * d, + &di, + x + i0 * d, + &di, + &zero, + ip_block.get(), + &nyi); + } +#pragma omp parallel for + for (int64_t i = i0; i < i1; i++) { + float* ip_line = ip_block.get() + (i - i0) * (j1 - j0); + + for (size_t j = j0; j < j1; j++) { + float ip = *ip_line; + float dis = ip / y_norm[i]; + + // negative values can occur for identical vectors + // due to roundoff errors + if (dis < 0) + dis = 0; + + *ip_line = dis; + ip_line++; + } + } + res.add_results(j0, j1, ip_block.get(), bitset); + } + res.end_multiple(); + InterruptCallback::check(); + } +} + template static void knn_jaccard_blas( const float* x, @@ -577,6 +704,36 @@ void knn_L2sqr( } } +void knn_cosine( + const float* x, + const float* y, + size_t d, + size_t nx, + size_t ny, + float_minheap_array_t* ha, + const float* y_norm, + const BitsetView bitset) { + if (ha->k < distance_compute_min_k_reservoir) { + HeapResultHandler> res( + ha->nh, ha->val, ha->ids, ha->k); + if (nx < distance_compute_blas_threshold) { + exhaustive_L2sqr_IP_seq( + x, y, d, nx, ny, res, fvec_inner_product, bitset, y_norm); + } else { + exhaustive_cosine_blas(x, y, d, nx, ny, res, y_norm, bitset); + } + } else { + ReservoirResultHandler> res( + ha->nh, ha->val, ha->ids, ha->k); + if (nx < distance_compute_blas_threshold) { + exhaustive_L2sqr_IP_seq( + x, y, d, nx, ny, res, fvec_inner_product, bitset, y_norm); + } else { + exhaustive_cosine_blas(x, y, d, nx, ny, res, y_norm, bitset); + } + } +} + struct NopDistanceCorrection { float operator()(float dis, size_t /*qno*/, size_t /*bno*/) const { return dis; @@ -640,6 +797,24 @@ void range_search_inner_product( } } +void range_search_cosine( + const float* x, + const float* y, + size_t d, + size_t nx, + size_t ny, + float radius, + RangeSearchResult* res, + const float* y_norm, + const BitsetView bitset) { + RangeSearchResultHandler> resh(res, radius); + if (nx < distance_compute_blas_threshold) { + exhaustive_cosine_seq(x, y, d, nx, ny, resh, y_norm, bitset); + } else { + exhaustive_cosine_blas(x, y, d, nx, ny, resh, y_norm, bitset); + } +} + /*************************************************************************** * compute a subset of distances ***************************************************************************/ diff --git a/thirdparty/faiss/faiss/utils/distances.h b/thirdparty/faiss/faiss/utils/distances.h index ebc51f7f2..048952263 100644 --- a/thirdparty/faiss/faiss/utils/distances.h +++ b/thirdparty/faiss/faiss/utils/distances.h @@ -199,6 +199,19 @@ void knn_L2sqr( const float* y_norm2 = nullptr, const BitsetView bitset = nullptr); +/** Same as knn_inner_product, for the COSINE distance + * @param norms norms for the y vectors (size ny) + */ +void knn_cosine( + const float* x, + const float* y, + size_t d, + size_t nx, + size_t ny, + float_minheap_array_t* res, + const float* y_norm2, + const BitsetView bitset = nullptr); + void knn_jaccard( const float* x, const float* y, @@ -265,6 +278,16 @@ void range_search_inner_product( RangeSearchResult* result, const BitsetView bitset = nullptr); +void range_search_cosine( + const float* x, + const float* y, + size_t d, + size_t nx, + size_t ny, + float radius, + RangeSearchResult* res, + const float* y_norm, + const BitsetView bitset = nullptr); /*************************************************************************** * PQ tables computations ***************************************************************************/