Skip to content
This repository has been archived by the owner on Aug 16, 2023. It is now read-only.

Commit

Permalink
Replace diskann compute function with faiss's compute function (#568)
Browse the repository at this point in the history
Signed-off-by: cqy123456 <[email protected]>

Signed-off-by: cqy123456 <[email protected]>
  • Loading branch information
cqy123456 authored Nov 25, 2022
1 parent a020a61 commit a007b8a
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 34 deletions.
88 changes: 56 additions & 32 deletions thirdparty/DiskANN/src/distance.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,25 @@

namespace diskann {

namespace {
#define ALIGNED(x) __attribute__((aligned(x)))

  // Builds an __m128 from the first `d` floats of `x`, where 0 <= d < 4.
  // Lanes [0, d) receive x[0..d-1]; every remaining lane is 0.0f.
  // Any `d` outside 1..3 reads nothing from `x` and yields an all-zero
  // vector.
  //
  // The values are staged through a 16-byte aligned stack buffer so the
  // aligned load below is valid and no memory past x[d-1] is touched.
  // A masked intrinsic is not used here (original note: cannot use AVX2
  // _mm_mask_set1_epi32).
  inline __m128 masked_read(int d, const float* x) {
    ALIGNED(16) float lanes[4] = {0.0f, 0.0f, 0.0f, 0.0f};
    if (d >= 1 && d <= 3) {
      for (int lane = 0; lane < d; ++lane) {
        lanes[lane] = x[lane];
      }
    }
    return _mm_load_ps(lanes);
  }
}  // namespace

// Cosine similarity.
float DistanceCosineInt8::compare(const int8_t *a, const int8_t *b,
uint32_t length) const {
Expand Down Expand Up @@ -124,47 +143,52 @@ namespace diskann {
}

#ifndef _WINDOWS
float DistanceL2Float::compare(const float *a, const float *b,
uint32_t size) const {
a = (const float *) __builtin_assume_aligned(a, 32);
b = (const float *) __builtin_assume_aligned(b, 32);
#else
float DistanceL2Float::compare(const float *a, const float *b,
uint32_t size) const {
#endif
float DistanceL2Float::compare(const float *x, const float *y,
uint32_t d) const {
__m256 msum1 = _mm256_setzero_ps();

while (d >= 8) {
__m256 mx = _mm256_loadu_ps(x);
x += 8;
__m256 my = _mm256_loadu_ps(y);
y += 8;
const __m256 a_m_b1 = _mm256_sub_ps(mx, my);
msum1 = _mm256_add_ps(msum1, _mm256_mul_ps(a_m_b1, a_m_b1));
d -= 8;
}

float result = 0;
#ifdef USE_AVX2
// assume size is divisible by 8
uint16_t niters = (uint16_t)(size / 8);
__m256 sum = _mm256_setzero_ps();
for (uint16_t j = 0; j < niters; j++) {
// scope is a[8j:8j+7], b[8j:8j+7]
// load a_vec
if (j < (niters - 1)) {
_mm_prefetch((char *) (a + 8 * (j + 1)), _MM_HINT_T0);
_mm_prefetch((char *) (b + 8 * (j + 1)), _MM_HINT_T0);
}
__m256 a_vec = _mm256_load_ps(a + 8 * j);
// load b_vec
__m256 b_vec = _mm256_load_ps(b + 8 * j);
// a_vec - b_vec
__m256 tmp_vec = _mm256_sub_ps(a_vec, b_vec);
__m128 msum2 = _mm256_extractf128_ps(msum1, 1);
msum2 = _mm_add_ps(msum2, _mm256_extractf128_ps(msum1, 0));

if (d >= 4) {
__m128 mx = _mm_loadu_ps(x);
x += 4;
__m128 my = _mm_loadu_ps(y);
y += 4;
const __m128 a_m_b1 = _mm_sub_ps(mx, my);
msum2 = _mm_add_ps(msum2, _mm_mul_ps(a_m_b1, a_m_b1));
d -= 4;
}

sum = _mm256_fmadd_ps(tmp_vec, tmp_vec, sum);
if (d > 0) {
__m128 mx = masked_read(d, x);
__m128 my = masked_read(d, y);
__m128 a_m_b1 = _mm_sub_ps(mx, my);
msum2 = _mm_add_ps(msum2, _mm_mul_ps(a_m_b1, a_m_b1));
}

// horizontal add sum
result = _mm256_reduce_add_ps(sum);
msum2 = _mm_hadd_ps(msum2, msum2);
msum2 = _mm_hadd_ps(msum2, msum2);
return _mm_cvtss_f32(msum2);
#else
#ifndef _WINDOWS
#pragma omp simd reduction(+ : result) aligned(a, b : 32)
#pragma omp simd reduction(+ : result) aligned(x, y : 32)
#endif
for (int32_t i = 0; i < (int32_t) size; i++) {
result += (a[i] - b[i]) * (a[i] - b[i]);
for (int32_t i = 0; i < (int32_t) d; i++) {
result += (x[i] - y[i]) * (x[i] - y[i]);
}
#endif
return result;
#endif
}

float SlowDistanceL2Float::compare(const float *a, const float *b,
Expand Down
4 changes: 2 additions & 2 deletions thirdparty/DiskANN/src/index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1395,7 +1395,7 @@ namespace diskann {
omp_set_num_threads(num_threads);

uint32_t num_syncs =
(unsigned) DIV_ROUND_UP(_nd + _num_frozen_pts, (64 * 64));
(unsigned) DIV_ROUND_UP(_nd + _num_frozen_pts, 8192);
if (num_syncs < 40)
num_syncs = 40;
LOG(DEBUG) << "Number of syncs: " << num_syncs;
Expand All @@ -1413,7 +1413,7 @@ namespace diskann {
unsigned L = _indexingQueueSize;

std::vector<unsigned> Lvec;
Lvec.push_back(L);
Lvec.push_back(unsigned(0.8 * L));
Lvec.push_back(L);
const unsigned NUM_RNDS = 2;

Expand Down

0 comments on commit a007b8a

Please sign in to comment.