Skip to content

Commit

Permalink
Made a patch in NMSLIB to avoid frequently calling JNI for better loa…
Browse files Browse the repository at this point in the history
…ding index performance.

Signed-off-by: Dooyong Kim <[email protected]>
  • Loading branch information
Dooyong Kim committed Oct 11, 2024
1 parent ce98151 commit 2398d49
Show file tree
Hide file tree
Showing 14 changed files with 328 additions and 139 deletions.
3 changes: 1 addition & 2 deletions jni/cmake/init-nmslib.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,13 @@ if (NOT EXISTS ${NMS_REPO_DIR})
execute_process(COMMAND git submodule update --init -- external/nmslib WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR})
endif ()


# Apply patches
if(NOT DEFINED APPLY_LIB_PATCHES OR "${APPLY_LIB_PATCHES}" STREQUAL true)
# Define list of patch files
set(PATCH_FILE_LIST)
list(APPEND PATCH_FILE_LIST "${CMAKE_CURRENT_SOURCE_DIR}/patches/nmslib/0001-Initialize-maxlevel-during-add-from-enterpoint-level.patch")
list(APPEND PATCH_FILE_LIST "${CMAKE_CURRENT_SOURCE_DIR}/patches/nmslib/0002-Adds-ability-to-pass-ef-parameter-in-the-query-for-h.patch")
list(APPEND PATCH_FILE_LIST "${CMAKE_CURRENT_SOURCE_DIR}/patches/nmslib/0003-Adding-two-apis-using-stream-to-load-save-in-Hnsw.patch")
list(APPEND PATCH_FILE_LIST "${CMAKE_CURRENT_SOURCE_DIR}/patches/nmslib/0003-Added-streaming-apis-for-vector-index-loading-in-Hnsw.patch")

# Get patch id of the last commit
execute_process(COMMAND sh -c "git --no-pager show HEAD | git patch-id --stable" OUTPUT_VARIABLE PATCH_ID_OUTPUT_FROM_COMMIT WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/external/nmslib)
Expand Down
9 changes: 7 additions & 2 deletions jni/include/jni_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,11 @@ namespace knn_jni {

virtual void ReleasePrimitiveArrayCritical(JNIEnv * env, jarray array, void *carray, jint mode) = 0;

virtual jint CallIntMethodLong(JNIEnv * env, jobject obj, jmethodID methodID, int64_t longArg) = 0;
virtual jint CallNonvirtualIntMethodA(JNIEnv *env, jobject obj, jclass clazz,
jmethodID methodID, jvalue *args) = 0;

virtual jlong CallNonvirtualLongMethodA(JNIEnv * env, jobject obj, jclass clazz,
jmethodID methodID, jvalue* args) = 0;

// --------------------------------------------------------------------------
};
Expand Down Expand Up @@ -194,7 +198,8 @@ namespace knn_jni {
jclass FindClassFromJNIEnv(JNIEnv * env, const char *name) final;
jmethodID GetMethodID(JNIEnv * env, jclass clazz, const char *name, const char *sig) final;
jfieldID GetFieldID(JNIEnv * env, jclass clazz, const char *name, const char *sig) final;
jint CallIntMethodLong(JNIEnv * env, jobject obj, jmethodID methodID, int64_t longArg) final;
jint CallNonvirtualIntMethodA(JNIEnv *env, jobject obj, jclass clazz, jmethodID methodID, jvalue *args) final;
jlong CallNonvirtualLongMethodA(JNIEnv * env, jobject obj, jclass clazz, jmethodID methodID, jvalue* args) final;
void * GetPrimitiveArrayCritical(JNIEnv * env, jarray array, jboolean *isCopy) final;
void ReleasePrimitiveArrayCritical(JNIEnv * env, jarray array, void *carray, jint mode) final;

Expand Down
24 changes: 22 additions & 2 deletions jni/include/native_engines_stream_support.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,19 @@ class NativeEngineIndexInputMediator {
bufferArray((jbyteArray) (_jni_interface->GetObjectField(_env,
_indexInput,
getBufferFieldId(_jni_interface, _env)))),
copyBytesMethod(getCopyBytesMethod(_jni_interface, _env)) {
copyBytesMethod(getCopyBytesMethod(_jni_interface, _env)),
remainingBytesMethod(getRemainingBytesMethod(_jni_interface, _env)) {
}

void copyBytes(int64_t nbytes, uint8_t *destination) {
auto jclazz = getIndexInputWithBufferClass(jni_interface, env);

while (nbytes > 0) {
// Call `copyBytes` to read as many bytes as possible.
jvalue args;
args.j = nbytes;
const auto readBytes =
jni_interface->CallIntMethodLong(env, indexInput, copyBytesMethod, nbytes);
jni_interface->CallNonvirtualIntMethodA(env, indexInput, jclazz, copyBytesMethod, &args);

// === Critical Section Start ===

Expand All @@ -69,6 +74,14 @@ class NativeEngineIndexInputMediator {
} // End while
}

int64_t remainingBytes() {
return jni_interface->CallNonvirtualLongMethodA(env,
indexInput,
getIndexInputWithBufferClass(jni_interface, env),
remainingBytesMethod,
nullptr);
}

private:
static jclass getIndexInputWithBufferClass(JNIUtilInterface *jni_interface, JNIEnv *env) {
static jclass INDEX_INPUT_WITH_BUFFER_CLASS =
Expand All @@ -82,6 +95,12 @@ class NativeEngineIndexInputMediator {
return COPY_METHOD_ID;
}

// Resolves the JNI method ID for `remainingBytes` (signature "()J": no arguments, returns long) on the
// class returned by getIndexInputWithBufferClass. The ID is kept in a function-local static, so the
// GetMethodID lookup runs only the first time this function is called.
static jmethodID getRemainingBytesMethod(JNIUtilInterface *jni_interface, JNIEnv *env) {
  // Named after the method it caches; the previous name COPY_METHOD_ID was a leftover from getCopyBytesMethod.
  static jmethodID REMAINING_BYTES_METHOD_ID =
      jni_interface->GetMethodID(env, getIndexInputWithBufferClass(jni_interface, env), "remainingBytes", "()J");
  return REMAINING_BYTES_METHOD_ID;
}

static jfieldID getBufferFieldId(JNIUtilInterface *jni_interface, JNIEnv *env) {
static jfieldID BUFFER_FIELD_ID =
jni_interface->GetFieldID(env, getIndexInputWithBufferClass(jni_interface, env), "buffer", "[B");
Expand All @@ -95,6 +114,7 @@ class NativeEngineIndexInputMediator {
jobject indexInput;
jbyteArray bufferArray;
jmethodID copyBytesMethod;
jmethodID remainingBytesMethod;
}; // class NativeEngineIndexInputMediator


Expand Down
34 changes: 15 additions & 19 deletions jni/include/nmslib_stream_support.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,40 +12,36 @@
#ifndef OPENSEARCH_KNN_JNI_NMSLIB_STREAM_SUPPORT_H
#define OPENSEARCH_KNN_JNI_NMSLIB_STREAM_SUPPORT_H

#include "jni_util.h"
#include "native_engines_stream_support.h"

#include <jni.h>
#include <stdexcept>
#include <iostream>
#include <cstring>

namespace knn_jni {
namespace stream {



/**
* std::streambuf implementation delegating NativeEngineIndexInputMediator to read bytes.
* This class is expected to be wrapped as std::istream, then to be passed to NMSLIB.
* NMSLIB will rely on the passed std::istream to read required bytes.
* NmslibIOReader implementation that delegates byte reads to NativeEngineIndexInputMediator.
*/
class NmslibMediatorInputStreamBuffer final : public std::streambuf {
class NmslibOpenSearchIOReader final : public similarity::NmslibIOReader {
public:
explicit NmslibMediatorInputStreamBuffer(NativeEngineIndexInputMediator *_mediator)
: std::streambuf(),
mediator(_mediator) {
explicit NmslibOpenSearchIOReader(NativeEngineIndexInputMediator *_mediator)
: mediator(_mediator) {
}

protected:
std::streamsize xsgetn(std::streambuf::char_type *destination, std::streamsize count) final {
if (count > 0) {
mediator->copyBytes(count, (uint8_t *) destination);
void read(char *bytes, size_t len) final {
if (len > 0) {
// Mediator calls IndexInput, then copies the bytes it read into `bytes`.
mediator->copyBytes(len, (uint8_t *) bytes);
}
return count;
}

size_t remainingBytes() final {
return mediator->remainingBytes();
}

private:
NativeEngineIndexInputMediator *mediator;
}; // NmslibMediatorInputStreamBuffer
}; // class NmslibOpenSearchIOReader



Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,222 @@
From ea99c7ce2cb1775d8130da9eaaffeff89bb6ffd3 Mon Sep 17 00:00:00 2001
From: Dooyong Kim <[email protected]>
Date: Fri, 11 Oct 2024 14:05:40 -0700
Subject: [PATCH] Added streaming apis for vector index loading in Hnsw.

Signed-off-by: Dooyong Kim <[email protected]>
---
similarity_search/include/method/hnsw.h | 3 +
similarity_search/include/utils.h | 12 ++
similarity_search/src/method/hnsw.cc | 139 +++++++++++++++++++++++-
3 files changed, 153 insertions(+), 1 deletion(-)

diff --git a/similarity_search/include/method/hnsw.h b/similarity_search/include/method/hnsw.h
index e6dcea7..433f98f 100644
--- a/similarity_search/include/method/hnsw.h
+++ b/similarity_search/include/method/hnsw.h
@@ -457,6 +457,8 @@ namespace similarity {

virtual void LoadIndex(const string &location) override;

+ void LoadIndexWithStream(similarity::NmslibIOReader& in);
+
Hnsw(bool PrintProgress, const Space<dist_t> &space, const ObjectVector &data);
void CreateIndex(const AnyParams &IndexParams) override;

@@ -500,6 +502,7 @@ namespace similarity {

void SaveOptimizedIndex(std::ostream& output);
void LoadOptimizedIndex(std::istream& input);
+ void LoadOptimizedIndex(NmslibIOReader& input);

void SaveRegularIndexBin(std::ostream& output);
void LoadRegularIndexBin(std::istream& input);
diff --git a/similarity_search/include/utils.h b/similarity_search/include/utils.h
index b521c26..a3931b7 100644
--- a/similarity_search/include/utils.h
+++ b/similarity_search/include/utils.h
@@ -299,12 +299,24 @@ inline void WriteField(ostream& out, const string& fieldName, const FieldType& f
}
}

+struct NmslibIOReader {  // Abstract byte source consumed by the NmslibIOReader overloads of the index loading code.
+ virtual ~NmslibIOReader() = default;  // Virtual so implementations can be destroyed through this interface.
+
+ virtual void read(char* bytes, size_t len) = 0;  // Callers pass a destination buffer and the number of bytes they expect to be placed in it.
+
+ virtual size_t remainingBytes() = 0;  // Callers use the result to bound bulk reads; presumably the count of unread bytes -- confirm with implementers.
+};

template <typename T>
void writeBinaryPOD(ostream& out, const T& podRef) {
out.write((char*)&podRef, sizeof(T));
}

+template <typename T>
+static void readBinaryPOD(NmslibIOReader& in, T& podRef) {  // Fills `podRef` with sizeof(T) raw bytes taken from `in`.
+  in.read(reinterpret_cast<char*>(&podRef), sizeof(T));
+}
+
template <typename T>
static void readBinaryPOD(istream& in, T& podRef) {
in.read((char*)&podRef, sizeof(T));
diff --git a/similarity_search/src/method/hnsw.cc b/similarity_search/src/method/hnsw.cc
index 4080b3b..63482fd 100644
--- a/similarity_search/src/method/hnsw.cc
+++ b/similarity_search/src/method/hnsw.cc
@@ -950,7 +950,6 @@ namespace similarity {
" read so far doesn't match the number of read lines: " + ConvertToString(lineNum));
}

-
template <typename dist_t>
void
Hnsw<dist_t>::LoadRegularIndexBin(std::istream& input) {
@@ -1034,6 +1033,144 @@ namespace similarity {

}

+    constexpr bool _isLittleEndian() {
+        // Host byte order via GCC/Clang predefined macros; masking a value (`value & 0xFFU`) is endian-independent and was always true.
+        return __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__;
+    }
+
+    SIZEMASS_TYPE _readIntBigEndian(uint8_t byte0, uint8_t byte1, uint8_t byte2, uint8_t byte3) noexcept {
+        // byte0 is the most significant byte and byte3 the least significant one.
+        const SIZEMASS_TYPE upperHalf = (static_cast<SIZEMASS_TYPE>(byte0) << 8) | byte1;
+        const SIZEMASS_TYPE lowerHalf = (static_cast<SIZEMASS_TYPE>(byte2) << 8) | byte3;
+        return (upperHalf << 16) | lowerHalf;
+    }
+
+    SIZEMASS_TYPE _readIntLittleEndian(uint8_t byte0, uint8_t byte1, uint8_t byte2, uint8_t byte3) noexcept {
+        // byte0 is the least significant byte and byte3 the most significant one.
+        const SIZEMASS_TYPE upperHalf = (static_cast<SIZEMASS_TYPE>(byte3) << 8) | byte2;
+        const SIZEMASS_TYPE lowerHalf = (static_cast<SIZEMASS_TYPE>(byte1) << 8) | byte0;
+        return (upperHalf << 16) | lowerHalf;
+    }
+
+    template <typename dist_t>
+    void Hnsw<dist_t>::LoadIndexWithStream(NmslibIOReader& input) {
+        // Loads an index from `input`; only the optimized layout is supported, any other flag value throws.
+        LOG(LIB_INFO) << "Loading index from an input stream(NmslibIOReader).";
+
+        unsigned int optimizedFlag = 0;
+        readBinaryPOD(input, optimizedFlag);
+        if (optimizedFlag == 0) {
+            throw std::runtime_error("With stream, we only support optimized index type.");
+        }
+
+        LoadOptimizedIndex(input);
+
+        LOG(LIB_INFO) << "Finished loading index";
+        visitedlistpool = new VisitedListPool(1, totalElementsStored_);
+    }
+
+    template <typename dist_t>
+    void Hnsw<dist_t>::LoadOptimizedIndex(NmslibIOReader& input) {
+        // Loads an optimized index from `input`: header fields, the level-0 data block, then one size-prefixed
+        // link list per element, staged through a 64KB buffer. Truncated input fails a CHECK_MSG.
+        static_assert(sizeof(SIZEMASS_TYPE) == 4, "Expected sizeof(SIZEMASS_TYPE) == 4.");
+
+        LOG(LIB_INFO) << "Loading optimized index(NmslibIOReader).";
+
+        readBinaryPOD(input, totalElementsStored_);
+        readBinaryPOD(input, memoryPerObject_);
+        readBinaryPOD(input, offsetLevel0_);
+        readBinaryPOD(input, offsetData_);
+        readBinaryPOD(input, maxlevel_);
+        readBinaryPOD(input, enterpointId_);
+        readBinaryPOD(input, maxM_);
+        readBinaryPOD(input, maxM0_);
+        readBinaryPOD(input, dist_func_type_);
+        readBinaryPOD(input, searchMethod_);
+
+        LOG(LIB_INFO) << "searchMethod: " << searchMethod_;
+
+        fstdistfunc_ = getDistFunc(dist_func_type_);
+        iscosine_ = (dist_func_type_ == kNormCosine);
+        CHECK_MSG(fstdistfunc_ != nullptr, "Unknown distance function code: " + ConvertToString(dist_func_type_));
+
+        LOG(LIB_INFO) << "Total: " << totalElementsStored_ << ", Memory per object: " << memoryPerObject_;
+        size_t data_plus_links0_size = memoryPerObject_ * totalElementsStored_;
+
+        // we allocate a few extra bytes to prevent prefetch from accessing out of range memory
+        data_level0_memory_ = (char *)malloc(data_plus_links0_size + EXTRA_MEM_PAD_SIZE);
+        CHECK(data_level0_memory_);
+        input.read(data_level0_memory_, data_plus_links0_size);
+        // the table of per-element link list pointers gets the same extra padding
+        linkLists_ = (char **)malloc( (sizeof(void *) * totalElementsStored_) + EXTRA_MEM_PAD_SIZE);
+        CHECK(linkLists_);
+
+        data_rearranged_.resize(totalElementsStored_);
+
+        const size_t bufferSize = 64 * 1024; // 64KB
+        std::unique_ptr<char[]> buffer (new char[bufferSize]);
+        uint32_t end = 0;
+        uint32_t pos = 0;
+        constexpr bool isLittleEndian = _isLittleEndian();
+
+        for (size_t i = 0, remainingBytes = input.remainingBytes(); i < totalElementsStored_; i++) {
+            if ((pos + sizeof(SIZEMASS_TYPE)) >= end) {
+                // At most 4 unread bytes are buffered, so refill before reading the size field.
+                // Slide the unread tail to the front first, e.g. [..., b0, b1] -> [b0, b1, ...] with
+                // firstPartLen = 2; memmove is used because the source and destination ranges may overlap.
+                const auto firstPartLen = end - pos;
+                std::memmove(buffer.get(), buffer.get() + pos, firstPartLen);
+                // Then bulk load from the input. The first firstPartLen bytes are already occupied, so at most
+                // bufferSize - firstPartLen bytes are requested.
+                const auto copyBytes = std::min(remainingBytes, bufferSize - firstPartLen);
+                input.read(buffer.get() + firstPartLen, copyBytes);
+                remainingBytes -= copyBytes;
+                end = copyBytes + firstPartLen;
+                pos = 0;
+                CHECK_MSG(end >= sizeof(SIZEMASS_TYPE), "Unexpected end of input at link list size of element " + ConvertToString(i));
+            }
+
+            // Read the size field. NMSLIB writes the 4-byte integer by casting it to char*, so the byte order
+            // in the file is that of the writing system.
+            SIZEMASS_TYPE linkListSize = 0;
+            if (isLittleEndian) {
+                linkListSize = _readIntLittleEndian(buffer[pos], buffer[pos + 1], buffer[pos + 2], buffer[pos + 3]);
+            } else {
+                linkListSize = _readIntBigEndian(buffer[pos], buffer[pos + 1], buffer[pos + 2], buffer[pos + 3]);
+            }
+            pos += sizeof(SIZEMASS_TYPE);
+
+            if (linkListSize == 0) {
+                linkLists_[i] = nullptr;
+            } else {
+                linkLists_[i] = (char *)malloc(linkListSize);
+                CHECK(linkLists_[i]);
+
+                SIZEMASS_TYPE leftLinkListData = linkListSize;
+                auto dataPtr = linkLists_[i];
+                while (leftLinkListData > 0) {
+                    if (pos >= end) {
+                        // Staging buffer exhausted while copying link list bytes: refill it, or fail if no input is left.
+                        const auto copyBytes = std::min(remainingBytes, bufferSize);
+                        CHECK_MSG(copyBytes > 0, "Unexpected end of input in link list data of element " + ConvertToString(i));
+                        input.read(buffer.get(), copyBytes);
+                        remainingBytes -= copyBytes;
+                        end = copyBytes;
+                        pos = 0;
+                    }
+
+                    // Copy as many link list bytes as are currently buffered.
+                    const auto copyBytes = std::min(leftLinkListData, end - pos);
+                    std::memcpy(dataPtr, buffer.get() + pos, copyBytes);
+                    dataPtr += copyBytes;
+                    leftLinkListData -= copyBytes;
+                    pos += copyBytes;
+                }  // End while
+            }  // End if
+
+            data_rearranged_[i] = new Object(data_level0_memory_ + (i)*memoryPerObject_ + offsetData_);
+        }  // End for
+    }

template <typename dist_t>
void
--
2.39.5 (Apple Git-154)

Loading

0 comments on commit 2398d49

Please sign in to comment.