Skip to content
This repository has been archived by the owner on Aug 16, 2023. It is now read-only.

Fix bruteforce cosine #1021

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/mergify.yml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ pull_request_rules:
- base=1.x
- and:
- -body~=\#[0-9]{1,6}(\s+|$)
- -body~=https://github.com/milvus-io/knowhere/issues/[0-9]{1,6}(\s+|$)
- -body~=https://github.com/zilliztech/Knowhere/issues/[0-9]{1,6}(\s+|$)
- -label=kind/improvement
- -title~=\[automated\]
actions:
Expand All @@ -55,7 +55,7 @@ pull_request_rules:
- or:
- or:
- body~=\#[0-9]{1,6}(\s+|$)
- body~=https://github.com/milvus-io/knowhere/issues/[0-9]{1,6}(\s+|$)
- body~=https://github.com/zilliztech/Knowhere/issues/[0-9]{1,6}(\s+|$)
- label=kind/improvement
actions:
label:
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
<img src="static/knowhere-logo.png" alt="Knowhere Logo"/>
</p>

This document will help you to build the Knowhere repository from source code and to run unit tests. Please [file an issue](https://github.com/milvus-io/knowhere/issues/new) if there's a problem.
This document will help you to build the Knowhere repository from source code and to run unit tests. Please [file an issue](https://github.com/zilliztech/knowhere/issues/new) if there's a problem.

## Introduction

Expand Down
33 changes: 15 additions & 18 deletions src/common/comp/brute_force.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,6 @@ expected<DataSetPtr>
BruteForce::Search(const DataSetPtr base_dataset, const DataSetPtr query_dataset, const Json& config,
const BitsetView& bitset) {
std::string metric_str = config[meta::METRIC_TYPE].get<std::string>();
bool is_cosine = IsMetricType(metric_str, metric::COSINE);
if (is_cosine) {
Normalize(*base_dataset);
}

auto xb = base_dataset->GetTensor();
auto nb = base_dataset->GetRows();
auto dim = base_dataset->GetDim();
Expand Down Expand Up @@ -71,11 +66,13 @@ BruteForce::Search(const DataSetPtr base_dataset, const DataSetPtr query_dataset
}
case faiss::METRIC_INNER_PRODUCT: {
auto cur_query = (float*)xq + dim * index;
if (is_cosine) {
faiss::float_minheap_array_t buf{(size_t)1, (size_t)topk, cur_labels, cur_distances};
if (IsMetricType(metric_str, metric::COSINE)) {
NormalizeVec(cur_query, dim);
faiss::knn_cosine(cur_query, (const float*)xb, dim, 1, nb, &buf, bitset);
} else {
faiss::knn_inner_product(cur_query, (const float*)xb, dim, 1, nb, &buf, bitset);
}
faiss::float_minheap_array_t buf{(size_t)1, (size_t)topk, cur_labels, cur_distances};
faiss::knn_inner_product(cur_query, (const float*)xb, dim, 1, nb, &buf, bitset);
break;
}
case faiss::METRIC_Jaccard: {
Expand Down Expand Up @@ -123,11 +120,6 @@ Status
BruteForce::SearchWithBuf(const DataSetPtr base_dataset, const DataSetPtr query_dataset, int64_t* ids, float* dis,
const Json& config, const BitsetView& bitset) {
std::string metric_str = config[meta::METRIC_TYPE].get<std::string>();
bool is_cosine = IsMetricType(metric_str, metric::COSINE);
if (is_cosine) {
Normalize(*base_dataset);
}

auto xb = base_dataset->GetTensor();
auto nb = base_dataset->GetRows();
auto dim = base_dataset->GetDim();
Expand Down Expand Up @@ -167,11 +159,13 @@ BruteForce::SearchWithBuf(const DataSetPtr base_dataset, const DataSetPtr query_
}
case faiss::METRIC_INNER_PRODUCT: {
auto cur_query = (float*)xq + dim * index;
if (is_cosine) {
faiss::float_minheap_array_t buf{(size_t)1, (size_t)topk, cur_labels, cur_distances};
if (IsMetricType(metric_str, metric::COSINE)) {
NormalizeVec(cur_query, dim);
faiss::knn_cosine(cur_query, (const float*)xb, dim, 1, nb, &buf, bitset);
} else {
faiss::knn_inner_product(cur_query, (const float*)xb, dim, 1, nb, &buf, bitset);
}
faiss::float_minheap_array_t buf{(size_t)1, (size_t)topk, cur_labels, cur_distances};
faiss::knn_inner_product(cur_query, (const float*)xb, dim, 1, nb, &buf, bitset);
break;
}
case faiss::METRIC_Jaccard: {
Expand Down Expand Up @@ -262,10 +256,13 @@ BruteForce::RangeSearch(const DataSetPtr base_dataset, const DataSetPtr query_da
case faiss::METRIC_INNER_PRODUCT: {
is_ip = true;
auto cur_query = (float*)xq + dim * index;
if (is_cosine) {
if (IsMetricType(metric_str, metric::COSINE)) {
NormalizeVec(cur_query, dim);
faiss::range_search_cosine(cur_query, (const float*)xb, dim, 1, nb, radius, &res, bitset);
} else {
faiss::range_search_inner_product(cur_query, (const float*)xb, dim, 1, nb, radius, &res,
bitset);
}
faiss::range_search_inner_product(cur_query, (const float*)xb, dim, 1, nb, radius, &res, bitset);
break;
}
case faiss::METRIC_Jaccard: {
Expand Down
154 changes: 154 additions & 0 deletions thirdparty/faiss/faiss/utils/distances.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include <cmath>
#include <cstdio>
#include <cstring>
#include "simd/hook.h"

#include <omp.h>

Expand Down Expand Up @@ -284,6 +285,44 @@ void exhaustive_L2sqr_seq(
}
}

namespace {
/// Inner product of x and y divided by the L2 norm of y (both of length d).
/// Only y's norm is divided out, so the value equals the cosine similarity
/// of the two vectors only when x already has unit length.
/// NOTE(review): a zero-norm y makes the divisor 0 and the result
/// non-finite — confirm callers never pass such vectors.
float fvec_cosine(const float* x, const float* y, size_t d) {
    const float y_norm = sqrtf(fvec_norm_L2sqr(y, d));
    const float dot = fvec_inner_product(x, y, d);
    return dot / y_norm;
}
} // namespace

/// Non-BLAS exhaustive scan: scores every query in x (nx vectors of
/// dimension d) against every database vector in y (ny vectors) with
/// fvec_cosine and feeds the scores to the result handler.
///
/// Queries are distributed over OpenMP threads; each thread builds its own
/// SingleResultHandler from res. Database entries whose bit is set in a
/// non-empty bitset are skipped.
template <class ResultHandler>
void exhaustive_cosine_seq(
        const float* x,
        const float* y,
        size_t d,
        size_t nx,
        size_t ny,
        ResultHandler& res,
        const BitsetView bitset) {
    using SingleResultHandler = typename ResultHandler::SingleResultHandler;
    // never start more threads than there are queries
    const int thread_count = std::min(int(nx), omp_get_max_threads());
    const int64_t query_count = static_cast<int64_t>(nx);

#pragma omp parallel num_threads(thread_count)
    {
        SingleResultHandler handler(res);
#pragma omp for
        for (int64_t q = 0; q < query_count; q++) {
            const float* query = x + q * d;
            handler.begin(q);
            for (size_t j = 0; j < ny; j++) {
                const bool filtered_out = !bitset.empty() && bitset.test(j);
                if (filtered_out) {
                    continue;
                }
                const float score = fvec_cosine(query, y + j * d, d);
                handler.add_result(score, j);
            }
            handler.end();
        }
    }
}

/** Find the nearest neighbors for nx queries in a set of ny vectors */
template <class ResultHandler>
void exhaustive_inner_product_blas(
Expand Down Expand Up @@ -426,6 +465,76 @@ void exhaustive_L2sqr_blas(
}
}

/// BLAS-based exhaustive scan for cosine scores.
///
/// Queries (x, nx vectors of dimension d) and database vectors (y, ny
/// vectors) are processed in blocks: sgemm_ computes the inner products of
/// a query block against a database block, then every inner product is
/// divided by the L2 norm of the corresponding database vector before the
/// block is handed to the result handler.
///
/// Only the database norm is divided out, so the scores are cosine
/// similarities only when the queries already have unit length.
/// NOTE(review): a zero-norm database vector yields a zero divisor —
/// confirm callers never store such vectors.
///
/// @param x       query vectors, nx * d floats
/// @param y       database vectors, ny * d floats
/// @param d       vector dimension
/// @param nx      number of queries
/// @param ny      number of database vectors
/// @param res     result handler receiving the score blocks
/// @param bitset  forwarded to res.add_results for filtering
template <class ResultHandler>
void exhaustive_cosine_blas(
        const float* x,
        const float* y,
        size_t d,
        size_t nx,
        size_t ny,
        ResultHandler& res,
        const BitsetView bitset = nullptr) {
    // BLAS does not like empty matrices
    if (nx == 0 || ny == 0)
        return;

    /* block sizes */
    const size_t bs_x = distance_compute_blas_query_bs;
    const size_t bs_y = distance_compute_blas_database_bs;
    std::unique_ptr<float[]> ip_block(new float[bs_x * bs_y]);

    // The divisor below is indexed by the database position j, so the norms
    // must be those of the ny database vectors in y (not of the queries).
    std::unique_ptr<float[]> y_norms(new float[ny]);
    fvec_norms_L2(y_norms.get(), y, d, ny);

    for (size_t i0 = 0; i0 < nx; i0 += bs_x) {
        size_t i1 = i0 + bs_x;
        if (i1 > nx)
            i1 = nx;

        res.begin_multiple(i0, i1);

        for (size_t j0 = 0; j0 < ny; j0 += bs_y) {
            size_t j1 = j0 + bs_y;
            if (j1 > ny)
                j1 = ny;
            /* compute the actual dot products */
            {
                float one = 1, zero = 0;
                FINTEGER nyi = j1 - j0, nxi = i1 - i0, di = d;
                sgemm_("Transpose",
                       "Not transpose",
                       &nyi,
                       &nxi,
                       &di,
                       &one,
                       y + j0 * d,
                       &di,
                       x + i0 * d,
                       &di,
                       &zero,
                       ip_block.get(),
                       &nyi);
            }
            /* turn each inner product into a cosine score in place */
#pragma omp parallel for
            for (int64_t i = i0; i < i1; i++) {
                // row (i - i0) of the block holds the scores of query i
                float* ip_line = ip_block.get() + (i - i0) * (j1 - j0);

                for (size_t j = j0; j < j1; j++) {
                    float ip = *ip_line;
                    float dis = ip / y_norms[j];
                    *ip_line = dis;
                    ip_line++;
                }
            }
            res.add_results(j0, j1, ip_block.get(), bitset);
        }
        res.end_multiple();
        InterruptCallback::check();
    }
}

template <class DistanceCorrection, class ResultHandler>
static void knn_jaccard_blas(
const float* x,
Expand Down Expand Up @@ -577,6 +686,34 @@ void knn_L2sqr(
}
}

/// Top-k search by cosine score for nx queries (x) over ny database
/// vectors (y), both of dimension d. Results are written through ha
/// (ha->nh, ha->val, ha->ids, ha->k).
///
/// A heap handler is used when ha->k is below
/// distance_compute_min_k_reservoir, a reservoir handler otherwise. For
/// each handler, fewer than distance_compute_blas_threshold queries take the
/// sequential path and larger batches take the BLAS path.
///
/// The scoring only divides by the database vector norm, so queries are
/// expected to be unit length for the scores to be true cosine similarities.
void knn_cosine(
        const float* x,
        const float* y,
        size_t d,
        size_t nx,
        size_t ny,
        float_minheap_array_t* ha,
        const BitsetView bitset) {
    if (ha->k < distance_compute_min_k_reservoir) {
        HeapResultHandler<CMin<float, int64_t>> res(
                ha->nh, ha->val, ha->ids, ha->k);
        if (nx < distance_compute_blas_threshold) {
            exhaustive_L2sqr_IP_seq(x, y, d, nx, ny, res, fvec_cosine, bitset);
        } else {
            exhaustive_cosine_blas(x, y, d, nx, ny, res, bitset);
        }
    } else {
        ReservoirResultHandler<CMin<float, int64_t>> res(
                ha->nh, ha->val, ha->ids, ha->k);
        if (nx < distance_compute_blas_threshold) {
            // Score with fvec_cosine exactly like the heap branch; using the
            // plain inner product here would make results depend on k.
            exhaustive_L2sqr_IP_seq(x, y, d, nx, ny, res, fvec_cosine, bitset);
        } else {
            exhaustive_cosine_blas(x, y, d, nx, ny, res, bitset);
        }
    }
}

struct NopDistanceCorrection {
float operator()(float dis, size_t /*qno*/, size_t /*bno*/) const {
return dis;
Expand Down Expand Up @@ -640,6 +777,23 @@ void range_search_inner_product(
}
}

void range_search_cosine(
const float* x,
const float* y,
size_t d,
size_t nx,
size_t ny,
float radius,
RangeSearchResult* res,
const BitsetView bitset) {
RangeSearchResultHandler<CMin<float, int64_t>> resh(res, radius);
if (nx < distance_compute_blas_threshold) {
exhaustive_cosine_seq(x, y, d, nx, ny, resh, bitset);
} else {
exhaustive_cosine_blas(x, y, d, nx, ny, resh, bitset);
}
}

/***************************************************************************
* compute a subset of distances
***************************************************************************/
Expand Down
19 changes: 19 additions & 0 deletions thirdparty/faiss/faiss/utils/distances.h
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,15 @@ void knn_L2sqr(
const float* y_norm2 = nullptr,
const BitsetView bitset = nullptr);

void knn_cosine(
const float* x,
const float* y,
size_t d,
size_t nx,
size_t ny,
float_minheap_array_t* ha,
const BitsetView bitset);

void knn_jaccard(
const float* x,
const float* y,
Expand Down Expand Up @@ -265,6 +274,16 @@ void range_search_inner_product(
RangeSearchResult* result,
const BitsetView bitset = nullptr);

/** Range search by cosine score over the database vectors.
 *
 * @param x       query vectors, size nx * d
 * @param y       database vectors, size ny * d
 * @param d       vector dimension
 * @param nx      number of query vectors
 * @param ny      number of database vectors
 * @param radius  score threshold used to select results
 *                (presumably scores above radius are kept, as for the
 *                inner-product range search — TODO confirm)
 * @param result  output structure receiving the matches
 * @param bitset  optional filter, defaults to nullptr
 */
void range_search_cosine(
const float* x,
const float* y,
size_t d,
size_t nx,
size_t ny,
float radius,
RangeSearchResult* result,
const BitsetView bitset = nullptr);

/***************************************************************************
* PQ tables computations
***************************************************************************/
Expand Down