forked from facebookresearch/faiss
-
Notifications
You must be signed in to change notification settings - Fork 11
/
IndexBinaryFromFloat.cpp
78 lines (60 loc) · 1.85 KB
/
IndexBinaryFromFloat.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
// -*- c++ -*-
#include "IndexBinaryFromFloat.h"
#include <memory>
#include "utils.h"
namespace faiss {
/**
 * Default constructor. Does no work of its own: every member keeps
 * whatever initial value the class declaration gives it (the declaration
 * is not visible from this file).
 */
IndexBinaryFromFloat::IndexBinaryFromFloat() = default;
/**
 * Wrap an existing float index.
 *
 * The binary index takes its dimensionality from index->d and starts out
 * mirroring the wrapped index's is_trained and ntotal. The pointer is
 * stored without taking ownership (own_fields = false), so the destructor
 * will not delete it unless own_fields is set later.
 *
 * @param index  float index to delegate to; dereferenced here, so it must
 *               not be null
 */
IndexBinaryFromFloat::IndexBinaryFromFloat(Index *index)
    : IndexBinary(index->d), index(index), own_fields(false) {
  // `index` names the constructor argument here (it shadows the member).
  const Index &wrapped = *index;
  this->is_trained = wrapped.is_trained;
  this->ntotal = wrapped.ntotal;
}
/**
 * Destructor. Deletes the wrapped index only when own_fields is set;
 * otherwise the pointer is left alone.
 */
IndexBinaryFromFloat::~IndexBinaryFromFloat() {
  if (!own_fields) {
    return;
  }
  delete index;
}
/**
 * Add n binary vectors to the wrapped float index.
 *
 * The input is processed in batches of at most 32768 vectors: each batch
 * is expanded to floats with binary_to_real() and forwarded to
 * index->add(). Afterwards ntotal is re-synchronised with the wrapped
 * index.
 *
 * @param n  number of vectors to add; when n <= 0 nothing is added
 * @param x  packed input codes, code_size bytes per vector
 */
void IndexBinaryFromFloat::add(idx_t n, const uint8_t *x) {
  constexpr idx_t bs = 32768;

  if (n > 0) {
    // Size the scratch buffer to the largest batch that will actually be
    // processed instead of always reserving bs * d floats: adding a single
    // vector should not allocate room for 32768 of them.
    const idx_t max_bn = std::min(bs, n);
    std::unique_ptr<float[]> xf(new float[max_bn * d]);

    for (idx_t b = 0; b < n; b += bs) {
      const idx_t bn = std::min(bs, n - b);
      binary_to_real(bn * d, x + b * code_size, xf.get());
      index->add(bn, xf.get());
    }
  }

  // Kept outside the guard so ntotal is refreshed even for an empty add.
  ntotal = index->ntotal;
}
void IndexBinaryFromFloat::reset() {
index->reset();
ntotal = index->ntotal;
}
/**
 * Search the wrapped float index with n binary query vectors.
 *
 * Queries are processed in batches of at most 32768: each batch is
 * expanded to floats with binary_to_real(), searched through
 * index->search(), and the float distances are divided by 4 and rounded
 * to the nearest integer before being written to `distances`.
 *
 * NOTE(review): the division by 4 presumably converts squared L2
 * distances between +/-1 encoded vectors into Hamming distances; this
 * depends on binary_to_real() and on the wrapped index's metric, neither
 * of which is visible here -- confirm.
 *
 * @param n          number of query vectors; when n <= 0 nothing is done
 * @param x          packed query codes, code_size bytes per vector
 * @param k          number of results requested per query
 * @param distances  output, n * k integer distances
 * @param labels     output, n * k labels written directly by the wrapped
 *                   index
 */
void IndexBinaryFromFloat::search(idx_t n, const uint8_t *x, idx_t k,
                                  int32_t *distances, idx_t *labels) const {
  constexpr idx_t bs = 32768;

  if (n <= 0) {
    return;
  }

  // Size the scratch buffers to the largest batch that will actually be
  // processed; a single-query search should not reserve room for a full
  // batch of 32768 queries.
  const idx_t max_bn = std::min(bs, n);
  std::unique_ptr<float[]> xf(new float[max_bn * d]);
  std::unique_ptr<float[]> df(new float[max_bn * k]);

  for (idx_t b = 0; b < n; b += bs) {
    const idx_t bn = std::min(bs, n - b);
    binary_to_real(bn * d, x + b * code_size, xf.get());

    index->search(bn, xf.get(), k, df.get(), labels + b * k);

    // The counter must be idx_t: bn * k can exceed INT_MAX (e.g. a full
    // batch with k >= 65536), which would overflow an int counter.
    int32_t *out = distances + b * k;
    const idx_t count = bn * k;
    for (idx_t i = 0; i < count; ++i) {
      out[i] = static_cast<int32_t>(std::round(df[i] / 4.0));
    }
  }
}
/**
 * Train the wrapped float index on n binary vectors.
 *
 * The whole training set is expanded to floats in one go with
 * binary_to_real() and handed to index->train(). Afterwards this index is
 * marked as trained and ntotal is refreshed from the wrapped index.
 *
 * @param n  number of training vectors
 * @param x  packed training codes
 */
void IndexBinaryFromFloat::train(idx_t n, const uint8_t *x) {
  const auto total_dims = n * d;
  std::unique_ptr<float[]> as_float(new float[total_dims]);
  binary_to_real(total_dims, x, as_float.get());

  index->train(n, as_float.get());

  ntotal = index->ntotal;
  is_trained = true;
}
} // namespace faiss