Skip to content
This repository has been archived by the owner on Aug 16, 2023. It is now read-only.

Commit

Permalink
Fix cosine bruteforce
Browse files Browse the repository at this point in the history
Signed-off-by: zh Wang <[email protected]>
  • Loading branch information
hhy3 committed Aug 8, 2023
1 parent 62c0a4b commit a661b44
Show file tree
Hide file tree
Showing 3 changed files with 188 additions and 18 deletions.
33 changes: 15 additions & 18 deletions src/common/comp/brute_force.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,6 @@ expected<DataSetPtr>
BruteForce::Search(const DataSetPtr base_dataset, const DataSetPtr query_dataset, const Json& config,
const BitsetView& bitset) {
std::string metric_str = config[meta::METRIC_TYPE].get<std::string>();
bool is_cosine = IsMetricType(metric_str, metric::COSINE);
if (is_cosine) {
Normalize(*base_dataset);
}

auto xb = base_dataset->GetTensor();
auto nb = base_dataset->GetRows();
auto dim = base_dataset->GetDim();
Expand Down Expand Up @@ -71,11 +66,13 @@ BruteForce::Search(const DataSetPtr base_dataset, const DataSetPtr query_dataset
}
case faiss::METRIC_INNER_PRODUCT: {
auto cur_query = (float*)xq + dim * index;
if (is_cosine) {
faiss::float_minheap_array_t buf{(size_t)1, (size_t)topk, cur_labels, cur_distances};
if (IsMetricType(metric_str, metric::COSINE)) {
NormalizeVec(cur_query, dim);
faiss::knn_cosine(cur_query, (const float*)xb, dim, 1, nb, &buf, bitset);
} else {
faiss::knn_inner_product(cur_query, (const float*)xb, dim, 1, nb, &buf, bitset);
}
faiss::float_minheap_array_t buf{(size_t)1, (size_t)topk, cur_labels, cur_distances};
faiss::knn_inner_product(cur_query, (const float*)xb, dim, 1, nb, &buf, bitset);
break;
}
case faiss::METRIC_Jaccard: {
Expand Down Expand Up @@ -123,11 +120,6 @@ Status
BruteForce::SearchWithBuf(const DataSetPtr base_dataset, const DataSetPtr query_dataset, int64_t* ids, float* dis,
const Json& config, const BitsetView& bitset) {
std::string metric_str = config[meta::METRIC_TYPE].get<std::string>();
bool is_cosine = IsMetricType(metric_str, metric::COSINE);
if (is_cosine) {
Normalize(*base_dataset);
}

auto xb = base_dataset->GetTensor();
auto nb = base_dataset->GetRows();
auto dim = base_dataset->GetDim();
Expand Down Expand Up @@ -167,11 +159,13 @@ BruteForce::SearchWithBuf(const DataSetPtr base_dataset, const DataSetPtr query_
}
case faiss::METRIC_INNER_PRODUCT: {
auto cur_query = (float*)xq + dim * index;
if (is_cosine) {
faiss::float_minheap_array_t buf{(size_t)1, (size_t)topk, cur_labels, cur_distances};
if (IsMetricType(metric_str, metric::COSINE)) {
NormalizeVec(cur_query, dim);
faiss::knn_cosine(cur_query, (const float*)xb, dim, 1, nb, &buf, bitset);
} else {
faiss::knn_inner_product(cur_query, (const float*)xb, dim, 1, nb, &buf, bitset);
}
faiss::float_minheap_array_t buf{(size_t)1, (size_t)topk, cur_labels, cur_distances};
faiss::knn_inner_product(cur_query, (const float*)xb, dim, 1, nb, &buf, bitset);
break;
}
case faiss::METRIC_Jaccard: {
Expand Down Expand Up @@ -262,10 +256,13 @@ BruteForce::RangeSearch(const DataSetPtr base_dataset, const DataSetPtr query_da
case faiss::METRIC_INNER_PRODUCT: {
is_ip = true;
auto cur_query = (float*)xq + dim * index;
if (is_cosine) {
if (IsMetricType(metric_str, metric::COSINE)) {
NormalizeVec(cur_query, dim);
faiss::range_search_cosine(cur_query, (const float*)xb, dim, 1, nb, radius, &res, bitset);
} else {
faiss::range_search_inner_product(cur_query, (const float*)xb, dim, 1, nb, radius, &res,
bitset);
}
faiss::range_search_inner_product(cur_query, (const float*)xb, dim, 1, nb, radius, &res, bitset);
break;
}
case faiss::METRIC_Jaccard: {
Expand Down
154 changes: 154 additions & 0 deletions thirdparty/faiss/faiss/utils/distances.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include <cmath>
#include <cstdio>
#include <cstring>
#include "simd/hook.h"

#include <omp.h>

Expand Down Expand Up @@ -284,6 +285,44 @@ void exhaustive_L2sqr_seq(
}
}

namespace {
/**
 * Inner product of x and y divided by the L2 norm of y.
 *
 * Only the norm of y is divided out; the norm of x is not touched here.
 * NOTE(review): presumably callers pass an already-normalized x so that the
 * result is a true cosine similarity -- confirm against the call sites.
 * NOTE(review): when fvec_norm_L2sqr(y, d) is 0 this divides by zero.
 *
 * @param x  first vector, d floats
 * @param y  second vector, d floats; its norm is the divisor
 * @param d  number of components read from each vector
 */
float fvec_cosine(const float* x, const float* y, size_t d) {
    const float dot = fvec_inner_product(x, y, d);
    const float y_length = sqrtf(fvec_norm_L2sqr(y, d));
    return dot / y_length;
}
} // namespace

/**
 * Exhaustive cosine scan without BLAS: every query row of x is compared with
 * every row of y through fvec_cosine, and each score is handed to a
 * per-thread SingleResultHandler built from res.
 *
 * @param x       query vectors, nx rows of d floats
 * @param y       database vectors, ny rows of d floats
 * @param d       vector dimension
 * @param nx      number of query vectors
 * @param ny      number of database vectors
 * @param res     result handler the per-thread handlers are constructed from
 * @param bitset  when non-empty, rows j for which bitset.test(j) is true are
 *                skipped
 */
template <class ResultHandler>
void exhaustive_cosine_seq(
        const float* x,
        const float* y,
        size_t d,
        size_t nx,
        size_t ny,
        ResultHandler& res,
        const BitsetView bitset) {
    using SingleResultHandler = typename ResultHandler::SingleResultHandler;
    // Never request more threads than there are queries.
    int thread_count = std::min(int(nx), omp_get_max_threads());

#pragma omp parallel num_threads(thread_count)
    {
        SingleResultHandler local_res(res);
#pragma omp for
        for (int64_t q = 0; q < nx; q++) {
            const float* query = x + q * d;
            local_res.begin(q);
            for (size_t b = 0; b < ny; b++) {
                // Skip rows that a non-empty bitset marks as set.
                if (!bitset.empty() && bitset.test(b)) {
                    continue;
                }
                float score = fvec_cosine(query, y + b * d, d);
                local_res.add_result(score, b);
            }
            local_res.end();
        }
    }
}

/** Find the nearest neighbors for nx queries in a set of ny vectors */
template <class ResultHandler>
void exhaustive_inner_product_blas(
Expand Down Expand Up @@ -426,6 +465,76 @@ void exhaustive_L2sqr_blas(
}
}

/**
 * Exhaustive cosine scan using BLAS: inner products between blocks of queries
 * and blocks of database vectors are computed with sgemm_, then each inner
 * product is divided by the L2 norm of the corresponding database vector
 * before being passed to the result handler.
 *
 * Only the database-side norm is divided out; the query norm is not.
 * NOTE(review): presumably callers pass already-normalized queries -- confirm
 * against the call sites.
 *
 * @param x       query vectors, nx rows of d floats
 * @param y       database vectors, ny rows of d floats
 * @param d       vector dimension
 * @param nx      number of query vectors
 * @param ny      number of database vectors
 * @param res     result handler receiving the per-block score matrices
 * @param bitset  forwarded unchanged to res.add_results
 */
template <class ResultHandler>
void exhaustive_cosine_blas(
        const float* x,
        const float* y,
        size_t d,
        size_t nx,
        size_t ny,
        ResultHandler& res,
        const BitsetView bitset = nullptr) {
    // BLAS does not like empty matrices
    if (nx == 0 || ny == 0)
        return;

    /* block sizes */
    const size_t bs_x = distance_compute_blas_query_bs;
    const size_t bs_y = distance_compute_blas_database_bs;
    std::unique_ptr<float[]> ip_block(new float[bs_x * bs_y]);

    // Norms of the DATABASE vectors. They are indexed by the database index j
    // below, so the buffer must hold ny entries computed from y (not nx
    // entries computed from x).
    std::unique_ptr<float[]> y_norms(new float[ny]);
    fvec_norms_L2(y_norms.get(), y, d, ny);

    for (size_t i0 = 0; i0 < nx; i0 += bs_x) {
        size_t i1 = i0 + bs_x;
        if (i1 > nx)
            i1 = nx;

        res.begin_multiple(i0, i1);

        for (size_t j0 = 0; j0 < ny; j0 += bs_y) {
            size_t j1 = j0 + bs_y;
            if (j1 > ny)
                j1 = ny;
            /* compute the actual dot products */
            {
                float one = 1, zero = 0;
                FINTEGER nyi = j1 - j0, nxi = i1 - i0, di = d;
                sgemm_("Transpose",
                       "Not transpose",
                       &nyi,
                       &nxi,
                       &di,
                       &one,
                       y + j0 * d,
                       &di,
                       x + i0 * d,
                       &di,
                       &zero,
                       ip_block.get(),
                       &nyi);
            }
            /* turn each inner product into ip / ||y_j|| in place */
#pragma omp parallel for
            for (int64_t i = i0; i < i1; i++) {
                // Row (i - i0) of the block holds the scores of query i
                // against database vectors j0 .. j1 - 1.
                float* ip_line = ip_block.get() + (i - i0) * (j1 - j0);

                for (size_t j = j0; j < j1; j++) {
                    float ip = *ip_line;
                    float dis = ip / y_norms[j];
                    *ip_line = dis;
                    ip_line++;
                }
            }
            res.add_results(j0, j1, ip_block.get(), bitset);
        }
        res.end_multiple();
        InterruptCallback::check();
    }
}

template <class DistanceCorrection, class ResultHandler>
static void knn_jaccard_blas(
const float* x,
Expand Down Expand Up @@ -577,6 +686,34 @@ void knn_L2sqr(
}
}

/**
 * k-nearest-neighbour search scored with fvec_cosine / exhaustive_cosine_blas.
 *
 * The result handler is chosen from ha->k: a heap handler below
 * distance_compute_min_k_reservoir, a reservoir handler otherwise. Within
 * each, the sequential scan is used when nx is below
 * distance_compute_blas_threshold and the BLAS scan otherwise.
 *
 * @param x       query vectors, nx rows of d floats
 * @param y       database vectors, ny rows of d floats
 * @param d       vector dimension
 * @param nx      number of query vectors
 * @param ny      number of database vectors
 * @param ha      output structure; its nh, val, ids and k fields are handed to
 *                the result handler
 * @param bitset  forwarded to the scan routines
 */
void knn_cosine(
        const float* x,
        const float* y,
        size_t d,
        size_t nx,
        size_t ny,
        float_minheap_array_t* ha,
        const BitsetView bitset) {
    if (ha->k < distance_compute_min_k_reservoir) {
        HeapResultHandler<CMin<float, int64_t>> res(
                ha->nh, ha->val, ha->ids, ha->k);
        if (nx < distance_compute_blas_threshold) {
            exhaustive_L2sqr_IP_seq(x, y, d, nx, ny, res, fvec_cosine, bitset);
        } else {
            exhaustive_cosine_blas(x, y, d, nx, ny, res, bitset);
        }
    } else {
        ReservoirResultHandler<CMin<float, int64_t>> res(
                ha->nh, ha->val, ha->ids, ha->k);
        if (nx < distance_compute_blas_threshold) {
            // Score with fvec_cosine here as well, so the sequential path
            // returns the same metric whichever result handler ha->k selects.
            exhaustive_L2sqr_IP_seq(x, y, d, nx, ny, res, fvec_cosine, bitset);
        } else {
            exhaustive_cosine_blas(x, y, d, nx, ny, res, bitset);
        }
    }
}

struct NopDistanceCorrection {
float operator()(float dis, size_t /*qno*/, size_t /*bno*/) const {
return dis;
Expand Down Expand Up @@ -640,6 +777,23 @@ void range_search_inner_product(
}
}

void range_search_cosine(
const float* x,
const float* y,
size_t d,
size_t nx,
size_t ny,
float radius,
RangeSearchResult* res,
const BitsetView bitset) {
RangeSearchResultHandler<CMin<float, int64_t>> resh(res, radius);
if (nx < distance_compute_blas_threshold) {
exhaustive_cosine_seq(x, y, d, nx, ny, resh, bitset);
} else {
exhaustive_cosine_blas(x, y, d, nx, ny, resh, bitset);
}
}

/***************************************************************************
* compute a subset of distances
***************************************************************************/
Expand Down
19 changes: 19 additions & 0 deletions thirdparty/faiss/faiss/utils/distances.h
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,15 @@ void knn_L2sqr(
const float* y_norm2 = nullptr,
const BitsetView bitset = nullptr);

void knn_cosine(
const float* x,
const float* y,
size_t d,
size_t nx,
size_t ny,
float_minheap_array_t* ha,
const BitsetView bitset);

void knn_jaccard(
const float* x,
const float* y,
Expand Down Expand Up @@ -265,6 +274,16 @@ void range_search_inner_product(
RangeSearchResult* result,
const BitsetView bitset = nullptr);

/** Range search scored with cosine similarity.
 *
 * @param x       query vectors, size nx * d
 * @param y       database vectors, size ny * d
 * @param d       vector dimension
 * @param nx      number of query vectors
 * @param ny      number of database vectors
 * @param radius  search radius / score threshold
 * @param result  output range-search result
 * @param bitset  optional filter; defaults to nullptr
 */
void range_search_cosine(
const float* x,
const float* y,
size_t d,
size_t nx,
size_t ny,
float radius,
RangeSearchResult* result,
const BitsetView bitset = nullptr);

/***************************************************************************
* PQ tables computations
***************************************************************************/
Expand Down

0 comments on commit a661b44

Please sign in to comment.