Skip to content
This repository has been archived by the owner on Aug 16, 2023. It is now read-only.

Commit

Permalink
Fix brute-force cosine search: use precomputed base-vector norms instead of normalizing the base dataset in place
Browse files Browse the repository at this point in the history
Signed-off-by: zh Wang <[email protected]>
  • Loading branch information
hhy3 committed Aug 4, 2023
1 parent 62c0a4b commit 20ea105
Show file tree
Hide file tree
Showing 5 changed files with 247 additions and 28 deletions.
63 changes: 41 additions & 22 deletions src/common/comp/brute_force.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,6 @@ expected<DataSetPtr>
BruteForce::Search(const DataSetPtr base_dataset, const DataSetPtr query_dataset, const Json& config,
const BitsetView& bitset) {
std::string metric_str = config[meta::METRIC_TYPE].get<std::string>();
bool is_cosine = IsMetricType(metric_str, metric::COSINE);
if (is_cosine) {
Normalize(*base_dataset);
}

auto xb = base_dataset->GetTensor();
auto nb = base_dataset->GetRows();
Expand All @@ -54,6 +50,13 @@ BruteForce::Search(const DataSetPtr base_dataset, const DataSetPtr query_dataset
auto labels = new int64_t[nq * topk];
auto distances = new float[nq * topk];

bool is_cosine = IsMetricType(metric_str, metric::COSINE);
std::unique_ptr<float[]> norms = nullptr;
if (is_cosine) {
norms = std::make_unique<float[]>(nb);
faiss::fvec_norms_L2(norms.get(), (const float*)xb, dim, nb);
}

auto pool = ThreadPool::GetGlobalThreadPool();
std::vector<folly::Future<Status>> futs;
futs.reserve(nq);
Expand All @@ -71,13 +74,17 @@ BruteForce::Search(const DataSetPtr base_dataset, const DataSetPtr query_dataset
}
case faiss::METRIC_INNER_PRODUCT: {
auto cur_query = (float*)xq + dim * index;
if (is_cosine) {
NormalizeVec(cur_query, dim);
}
faiss::float_minheap_array_t buf{(size_t)1, (size_t)topk, cur_labels, cur_distances};
faiss::knn_inner_product(cur_query, (const float*)xb, dim, 1, nb, &buf, bitset);
break;
}
case faiss::METRIC_COSINE: {
auto cur_query = (float*)xq + dim * index;
NormalizeVec(cur_query, dim);
faiss::float_minheap_array_t buf{(size_t)1, (size_t)topk, cur_labels, cur_distances};
faiss::knn_cosine(cur_query, (const float*)xb, dim, 1, nb, &buf, norms.get(), bitset);
break;
}
case faiss::METRIC_Jaccard: {
auto cur_query = (const uint8_t*)xq + (dim / 8) * index;
faiss::float_maxheap_array_t res = {size_t(1), size_t(topk), cur_labels, cur_distances};
Expand Down Expand Up @@ -123,10 +130,6 @@ Status
BruteForce::SearchWithBuf(const DataSetPtr base_dataset, const DataSetPtr query_dataset, int64_t* ids, float* dis,
const Json& config, const BitsetView& bitset) {
std::string metric_str = config[meta::METRIC_TYPE].get<std::string>();
bool is_cosine = IsMetricType(metric_str, metric::COSINE);
if (is_cosine) {
Normalize(*base_dataset);
}

auto xb = base_dataset->GetTensor();
auto nb = base_dataset->GetRows();
Expand All @@ -150,6 +153,13 @@ BruteForce::SearchWithBuf(const DataSetPtr base_dataset, const DataSetPtr query_

auto faiss_metric_type = metric_type.value();

bool is_cosine = IsMetricType(metric_str, metric::COSINE);
std::unique_ptr<float[]> norms = nullptr;
if (is_cosine) {
norms = std::make_unique<float[]>(nb);
faiss::fvec_norms_L2(norms.get(), (const float*)xb, dim, nb);
}

auto pool = ThreadPool::GetGlobalThreadPool();
std::vector<folly::Future<Status>> futs;
futs.reserve(nq);
Expand All @@ -167,13 +177,17 @@ BruteForce::SearchWithBuf(const DataSetPtr base_dataset, const DataSetPtr query_
}
case faiss::METRIC_INNER_PRODUCT: {
auto cur_query = (float*)xq + dim * index;
if (is_cosine) {
NormalizeVec(cur_query, dim);
}
faiss::float_minheap_array_t buf{(size_t)1, (size_t)topk, cur_labels, cur_distances};
faiss::knn_inner_product(cur_query, (const float*)xb, dim, 1, nb, &buf, bitset);
break;
}
case faiss::METRIC_COSINE: {
auto cur_query = (float*)xq + dim * index;
NormalizeVec(cur_query, dim);
faiss::float_minheap_array_t buf{(size_t)1, (size_t)topk, cur_labels, cur_distances};
faiss::knn_cosine(cur_query, (const float*)xb, dim, 1, nb, &buf, norms.get(), bitset);
break;
}
case faiss::METRIC_Jaccard: {
auto cur_query = (const uint8_t*)xq + (dim / 8) * index;
faiss::float_maxheap_array_t res = {size_t(1), size_t(topk), cur_labels, cur_distances};
Expand Down Expand Up @@ -221,11 +235,6 @@ expected<DataSetPtr>
BruteForce::RangeSearch(const DataSetPtr base_dataset, const DataSetPtr query_dataset, const Json& config,
const BitsetView& bitset) {
std::string metric_str = config[meta::METRIC_TYPE].get<std::string>();
bool is_cosine = IsMetricType(metric_str, metric::COSINE);
if (is_cosine) {
Normalize(*base_dataset);
}

auto xb = base_dataset->GetTensor();
auto nb = base_dataset->GetRows();
auto dim = base_dataset->GetDim();
Expand All @@ -241,6 +250,12 @@ BruteForce::RangeSearch(const DataSetPtr base_dataset, const DataSetPtr query_da
float range_filter = cfg.range_filter.value();

ASSIGN_OR_RETURN(faiss::MetricType, faiss_metric_type, Str2FaissMetricType(cfg.metric_type.value()));
bool is_cosine = IsMetricType(metric_str, metric::COSINE);
std::unique_ptr<float[]> norms = nullptr;
if (is_cosine) {
norms = std::make_unique<float[]>(nb);
faiss::fvec_norms_L2(norms.get(), (const float*)xb, dim, nb);
}
auto pool = ThreadPool::GetGlobalThreadPool();

std::vector<std::vector<int64_t>> result_id_array(nq);
Expand All @@ -262,12 +277,16 @@ BruteForce::RangeSearch(const DataSetPtr base_dataset, const DataSetPtr query_da
case faiss::METRIC_INNER_PRODUCT: {
is_ip = true;
auto cur_query = (float*)xq + dim * index;
if (is_cosine) {
NormalizeVec(cur_query, dim);
}
faiss::range_search_inner_product(cur_query, (const float*)xb, dim, 1, nb, radius, &res, bitset);
break;
}
case faiss::METRIC_COSINE: {
auto cur_query = (float*)xq + dim * index;
NormalizeVec(cur_query, dim);
faiss::range_search_cosine(cur_query, (const float*)xb, dim, 1, nb, radius, &res, norms.get(),
bitset);
break;
}
case faiss::METRIC_Jaccard: {
auto cur_query = (const uint8_t*)xq + (dim / 8) * index;
faiss::binary_range_search<faiss::CMin<float, int64_t>, float>(
Expand Down
2 changes: 1 addition & 1 deletion src/common/metric.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ Str2FaissMetricType(std::string metric) {
static const std::unordered_map<std::string, faiss::MetricType> metric_map = {
{metric::L2, faiss::MetricType::METRIC_L2},
{metric::IP, faiss::MetricType::METRIC_INNER_PRODUCT},
{metric::COSINE, faiss::MetricType::METRIC_INNER_PRODUCT},
{metric::COSINE, faiss::MetricType::METRIC_COSINE},
{metric::HAMMING, faiss::MetricType::METRIC_Hamming},
{metric::JACCARD, faiss::MetricType::METRIC_Jaccard},
{metric::SUBSTRUCTURE, faiss::MetricType::METRIC_Substructure},
Expand Down
2 changes: 2 additions & 0 deletions thirdparty/faiss/faiss/MetricType.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ enum MetricType {
METRIC_Substructure, ///< Tversky case alpha = 0, beta = 1
METRIC_Superstructure, ///< Tversky case alpha = 1, beta = 0

METRIC_COSINE,

/// some additional metrics defined in scipy.spatial.distance
METRIC_Canberra = 20,
METRIC_BrayCurtis,
Expand Down
Loading

0 comments on commit 20ea105

Please sign in to comment.