Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Backport 2.x] Introduce a loading layer in NMSLIB. (#2185) #2220

Merged
merged 3 commits into from
Oct 21, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
## [Unreleased 3.0](https://github.com/opensearch-project/k-NN/compare/2.x...HEAD)
### Features
### Enhancements
* Introduce a loading layer in native engine [#2185](https://github.com/opensearch-project/k-NN/pull/2185)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be in 2.x

### Bug Fixes
### Infrastructure
### Documentation
Expand Down
3 changes: 2 additions & 1 deletion jni/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,8 @@ if ("${WIN32}" STREQUAL "")
tests/commons_test.cpp
tests/faiss_stream_support_test.cpp
tests/faiss_index_service_test.cpp
)
tests/nmslib_stream_support_test.cpp
)

target_link_libraries(
jni_test
Expand Down
3 changes: 1 addition & 2 deletions jni/cmake/init-nmslib.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,13 @@ if (NOT EXISTS ${NMS_REPO_DIR})
execute_process(COMMAND git submodule update --init -- external/nmslib WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR})
endif ()


# Apply patches
if(NOT DEFINED APPLY_LIB_PATCHES OR "${APPLY_LIB_PATCHES}" STREQUAL true)
# Define list of patch files
set(PATCH_FILE_LIST)
list(APPEND PATCH_FILE_LIST "${CMAKE_CURRENT_SOURCE_DIR}/patches/nmslib/0001-Initialize-maxlevel-during-add-from-enterpoint-level.patch")
list(APPEND PATCH_FILE_LIST "${CMAKE_CURRENT_SOURCE_DIR}/patches/nmslib/0002-Adds-ability-to-pass-ef-parameter-in-the-query-for-h.patch")
list(APPEND PATCH_FILE_LIST "${CMAKE_CURRENT_SOURCE_DIR}/patches/nmslib/0003-Adding-two-apis-using-stream-to-load-save-in-Hnsw.patch")
list(APPEND PATCH_FILE_LIST "${CMAKE_CURRENT_SOURCE_DIR}/patches/nmslib/0003-Added-streaming-apis-for-vector-index-loading-in-Hnsw.patch")

# Get patch id of the last commit
execute_process(COMMAND sh -c "git --no-pager show HEAD | git patch-id --stable" OUTPUT_VARIABLE PATCH_ID_OUTPUT_FROM_COMMIT WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/external/nmslib)
Expand Down
81 changes: 4 additions & 77 deletions jni/include/faiss_stream_support.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,12 @@
* GitHub history for details.
*/

#ifndef OPENSEARCH_KNN_JNI_STREAM_SUPPORT_H
#define OPENSEARCH_KNN_JNI_STREAM_SUPPORT_H
#ifndef OPENSEARCH_KNN_JNI_FAISS_STREAM_SUPPORT_H
#define OPENSEARCH_KNN_JNI_FAISS_STREAM_SUPPORT_H

#include "faiss/impl/io.h"
#include "jni_util.h"
#include "native_engines_stream_support.h"

#include <jni.h>
#include <stdexcept>
Expand All @@ -23,80 +24,6 @@
namespace knn_jni {
namespace stream {

/**
* This class contains Java IndexInputWithBuffer reference and calls its API to copy required bytes into a read buffer.
*/

class NativeEngineIndexInputMediator {
public:
// Expect IndexInputWithBuffer is given as `_indexInput`.
NativeEngineIndexInputMediator(JNIUtilInterface *_jni_interface,
JNIEnv *_env,
jobject _indexInput)
: jni_interface(_jni_interface),
env(_env),
indexInput(_indexInput),
bufferArray((jbyteArray) (_jni_interface->GetObjectField(_env,
_indexInput,
getBufferFieldId(_jni_interface, _env)))),
copyBytesMethod(getCopyBytesMethod(_jni_interface, _env)) {
}

void copyBytes(int64_t nbytes, uint8_t *destination) {
while (nbytes > 0) {
// Call `copyBytes` to read bytes as many as possible.
const auto readBytes =
jni_interface->CallIntMethodLong(env, indexInput, copyBytesMethod, nbytes);

// === Critical Section Start ===

// Get primitive array pointer, no copy is happening in OpenJDK.
auto primitiveArray =
(jbyte *) jni_interface->GetPrimitiveArrayCritical(env, bufferArray, nullptr);

// Copy Java bytes to C++ destination address.
std::memcpy(destination, primitiveArray, readBytes);

// Release the acquired primitive array pointer.
// JNI_ABORT tells JVM to directly free memory without copying back to Java byte[].
// Since we're merely copying data, we don't need to copying back.
jni_interface->ReleasePrimitiveArrayCritical(env, bufferArray, primitiveArray, JNI_ABORT);

// === Critical Section End ===

destination += readBytes;
nbytes -= readBytes;
} // End while
}

private:
static jclass getIndexInputWithBufferClass(JNIUtilInterface *jni_interface, JNIEnv *env) {
static jclass INDEX_INPUT_WITH_BUFFER_CLASS =
jni_interface->FindClassFromJNIEnv(env, "org/opensearch/knn/index/store/IndexInputWithBuffer");
return INDEX_INPUT_WITH_BUFFER_CLASS;
}

static jmethodID getCopyBytesMethod(JNIUtilInterface *jni_interface, JNIEnv *env) {
static jmethodID COPY_METHOD_ID =
jni_interface->GetMethodID(env, getIndexInputWithBufferClass(jni_interface, env), "copyBytes", "(J)I");
return COPY_METHOD_ID;
}

static jfieldID getBufferFieldId(JNIUtilInterface *jni_interface, JNIEnv *env) {
static jfieldID BUFFER_FIELD_ID =
jni_interface->GetFieldID(env, getIndexInputWithBufferClass(jni_interface, env), "buffer", "[B");
return BUFFER_FIELD_ID;
}

JNIUtilInterface *jni_interface;
JNIEnv *env;

// `IndexInputWithBuffer` instance having `IndexInput` instance obtained from `Directory` for reading.
jobject indexInput;
jbyteArray bufferArray;
jmethodID copyBytesMethod;
}; // class NativeEngineIndexInputMediator



/**
Expand Down Expand Up @@ -133,4 +60,4 @@ class FaissOpenSearchIOReader final : public faiss::IOReader {
}
}

#endif //OPENSEARCH_KNN_JNI_STREAM_SUPPORT_H
#endif //OPENSEARCH_KNN_JNI_FAISS_STREAM_SUPPORT_H
9 changes: 7 additions & 2 deletions jni/include/jni_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,11 @@ namespace knn_jni {

virtual void ReleasePrimitiveArrayCritical(JNIEnv * env, jarray array, void *carray, jint mode) = 0;

virtual jint CallIntMethodLong(JNIEnv * env, jobject obj, jmethodID methodID, int64_t longArg) = 0;
virtual jint CallNonvirtualIntMethodA(JNIEnv *env, jobject obj, jclass clazz,
jmethodID methodID, jvalue *args) = 0;

virtual jlong CallNonvirtualLongMethodA(JNIEnv * env, jobject obj, jclass clazz,
jmethodID methodID, jvalue* args) = 0;

// --------------------------------------------------------------------------
};
Expand Down Expand Up @@ -194,7 +198,8 @@ namespace knn_jni {
jclass FindClassFromJNIEnv(JNIEnv * env, const char *name) final;
jmethodID GetMethodID(JNIEnv * env, jclass clazz, const char *name, const char *sig) final;
jfieldID GetFieldID(JNIEnv * env, jclass clazz, const char *name, const char *sig) final;
jint CallIntMethodLong(JNIEnv * env, jobject obj, jmethodID methodID, int64_t longArg) final;
jint CallNonvirtualIntMethodA(JNIEnv *env, jobject obj, jclass clazz, jmethodID methodID, jvalue *args) final;
jlong CallNonvirtualLongMethodA(JNIEnv * env, jobject obj, jclass clazz, jmethodID methodID, jvalue* args) final;
void * GetPrimitiveArrayCritical(JNIEnv * env, jarray array, jboolean *isCopy) final;
void ReleasePrimitiveArrayCritical(JNIEnv * env, jarray array, void *carray, jint mode) final;

Expand Down
125 changes: 125 additions & 0 deletions jni/include/native_engines_stream_support.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/

#ifndef OPENSEARCH_KNN_JNI_STREAM_SUPPORT_H
#define OPENSEARCH_KNN_JNI_STREAM_SUPPORT_H

#include "jni_util.h"

#include <jni.h>
#include <stdexcept>
#include <iostream>
#include <cstring>

namespace knn_jni {
namespace stream {



/**
* This class contains Java IndexInputWithBuffer reference and calls its API to copy required bytes into a read buffer.
*/
class NativeEngineIndexInputMediator {
public:
// Expect IndexInputWithBuffer is given as `_indexInput`.
NativeEngineIndexInputMediator(JNIUtilInterface *_jni_interface,
JNIEnv *_env,
jobject _indexInput)
: jni_interface(_jni_interface),
env(_env),
indexInput(_indexInput),
bufferArray((jbyteArray) (_jni_interface->GetObjectField(_env,
_indexInput,
getBufferFieldId(_jni_interface, _env)))),
copyBytesMethod(getCopyBytesMethod(_jni_interface, _env)),
remainingBytesMethod(getRemainingBytesMethod(_jni_interface, _env)) {
}

void copyBytes(int64_t nbytes, uint8_t *destination) {
auto jclazz = getIndexInputWithBufferClass(jni_interface, env);

while (nbytes > 0) {
// Call `copyBytes` to read bytes as many as possible.
jvalue args;
args.j = nbytes;
const auto readBytes =
jni_interface->CallNonvirtualIntMethodA(env, indexInput, jclazz, copyBytesMethod, &args);

// === Critical Section Start ===

// Get primitive array pointer, no copy is happening in OpenJDK.
auto primitiveArray =
(jbyte *) jni_interface->GetPrimitiveArrayCritical(env, bufferArray, nullptr);

// Copy Java bytes to C++ destination address.
std::memcpy(destination, primitiveArray, readBytes);

// Release the acquired primitive array pointer.
// JNI_ABORT tells JVM to directly free memory without copying back to Java byte[].
// Since we're merely copying data, we don't need to copying back.
jni_interface->ReleasePrimitiveArrayCritical(env, bufferArray, primitiveArray, JNI_ABORT);

// === Critical Section End ===

destination += readBytes;
nbytes -= readBytes;
} // End while
}

int64_t remainingBytes() {
return jni_interface->CallNonvirtualLongMethodA(env,
indexInput,
getIndexInputWithBufferClass(jni_interface, env),
remainingBytesMethod,
nullptr);
}

private:
static jclass getIndexInputWithBufferClass(JNIUtilInterface *jni_interface, JNIEnv *env) {
static jclass INDEX_INPUT_WITH_BUFFER_CLASS =
jni_interface->FindClassFromJNIEnv(env, "org/opensearch/knn/index/store/IndexInputWithBuffer");
return INDEX_INPUT_WITH_BUFFER_CLASS;
}

static jmethodID getCopyBytesMethod(JNIUtilInterface *jni_interface, JNIEnv *env) {
static jmethodID COPY_METHOD_ID =
jni_interface->GetMethodID(env, getIndexInputWithBufferClass(jni_interface, env), "copyBytes", "(J)I");
return COPY_METHOD_ID;
}

static jmethodID getRemainingBytesMethod(JNIUtilInterface *jni_interface, JNIEnv *env) {
static jmethodID COPY_METHOD_ID =
jni_interface->GetMethodID(env, getIndexInputWithBufferClass(jni_interface, env), "remainingBytes", "()J");
return COPY_METHOD_ID;
}

static jfieldID getBufferFieldId(JNIUtilInterface *jni_interface, JNIEnv *env) {
static jfieldID BUFFER_FIELD_ID =
jni_interface->GetFieldID(env, getIndexInputWithBufferClass(jni_interface, env), "buffer", "[B");
return BUFFER_FIELD_ID;
}

JNIUtilInterface *jni_interface;
JNIEnv *env;

// `IndexInputWithBuffer` instance having `IndexInput` instance obtained from `Directory` for reading.
jobject indexInput;
jbyteArray bufferArray;
jmethodID copyBytesMethod;
jmethodID remainingBytesMethod;
}; // class NativeEngineIndexInputMediator



}
}

#endif //OPENSEARCH_KNN_JNI_STREAM_SUPPORT_H
51 changes: 51 additions & 0 deletions jni/include/nmslib_stream_support.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/

#ifndef OPENSEARCH_KNN_JNI_NMSLIB_STREAM_SUPPORT_H
#define OPENSEARCH_KNN_JNI_NMSLIB_STREAM_SUPPORT_H

#include "native_engines_stream_support.h"

namespace knn_jni {
namespace stream {



/**
* NmslibIOReader implementation delegating NativeEngineIndexInputMediator to read bytes.
*/
class NmslibOpenSearchIOReader final : public similarity::NmslibIOReader {
public:
explicit NmslibOpenSearchIOReader(NativeEngineIndexInputMediator *_mediator)
: mediator(_mediator) {
}

void read(char *bytes, size_t len) final {
if (len > 0) {
// Mediator calls IndexInput, then copy read bytes to `ptr`.
mediator->copyBytes(len, (uint8_t *) bytes);
}
}

size_t remainingBytes() final {
return mediator->remainingBytes();
}

private:
NativeEngineIndexInputMediator *mediator;
}; // class NmslibOpenSearchIOReader



}
}

#endif //OPENSEARCH_KNN_JNI_NMSLIB_STREAM_SUPPORT_H
8 changes: 8 additions & 0 deletions jni/include/nmslib_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,14 @@ namespace knn_jni {
// Return a pointer to the loaded index
jlong LoadIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jstring indexPathJ, jobject parametersJ);

// Load an index via an input stream into memory. Use parametersJ to set any query time parameters
//
// Return a pointer to the loaded index
jlong LoadIndexWithStream(knn_jni::JNIUtilInterface * jniUtil,
JNIEnv * env,
jobject readStream,
jobject parametersJ);

// Execute a query against the index located in memory at indexPointerJ.
//
// Return an array of KNNQueryResults
Expand Down
8 changes: 8 additions & 0 deletions jni/include/org_opensearch_knn_jni_NmslibService.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,14 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_NmslibService_createIndex
JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_NmslibService_loadIndex
(JNIEnv *, jclass, jstring, jobject);

/*
* Class: org_opensearch_knn_jni_NmslibService
* Method: loadIndexWithStream
* Signature: (Lorg/opensearch/knn/index/store/IndexInputWithBuffer;Ljava/util/Map;)J
*/
JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_NmslibService_loadIndexWithStream
(JNIEnv *, jclass, jobject, jobject);

/*
* Class: org_opensearch_knn_jni_NmslibService
* Method: queryIndex
Expand Down
Loading
Loading