-
Notifications
You must be signed in to change notification settings - Fork 129
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Made a patch in NMSLIB to avoid frequently calling JNI for better loa…
…ding index performance. Signed-off-by: Dooyong Kim <[email protected]>
- Loading branch information
Dooyong Kim
committed
Oct 11, 2024
1 parent
ce98151
commit 2398d49
Showing
14 changed files
with
328 additions
and
139 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
222 changes: 222 additions & 0 deletions
222
jni/patches/nmslib/0003-Added-streaming-apis-for-vector-index-loading-in-Hnsw.patch
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,222 @@ | ||
From ea99c7ce2cb1775d8130da9eaaffeff89bb6ffd3 Mon Sep 17 00:00:00 2001 | ||
From: Dooyong Kim <[email protected]> | ||
Date: Fri, 11 Oct 2024 14:05:40 -0700 | ||
Subject: [PATCH] Added streaming apis for vector index loading in Hnsw. | ||
|
||
Signed-off-by: Dooyong Kim <[email protected]> | ||
--- | ||
similarity_search/include/method/hnsw.h | 3 + | ||
similarity_search/include/utils.h | 12 ++ | ||
similarity_search/src/method/hnsw.cc | 139 +++++++++++++++++++++++- | ||
3 files changed, 153 insertions(+), 1 deletion(-) | ||
|
||
diff --git a/similarity_search/include/method/hnsw.h b/similarity_search/include/method/hnsw.h | ||
index e6dcea7..433f98f 100644 | ||
--- a/similarity_search/include/method/hnsw.h | ||
+++ b/similarity_search/include/method/hnsw.h | ||
@@ -457,6 +457,8 @@ namespace similarity { | ||
|
||
virtual void LoadIndex(const string &location) override; | ||
|
||
+ void LoadIndexWithStream(similarity::NmslibIOReader& in); | ||
+ | ||
Hnsw(bool PrintProgress, const Space<dist_t> &space, const ObjectVector &data); | ||
void CreateIndex(const AnyParams &IndexParams) override; | ||
|
||
@@ -500,6 +502,7 @@ namespace similarity { | ||
|
||
void SaveOptimizedIndex(std::ostream& output); | ||
void LoadOptimizedIndex(std::istream& input); | ||
+ void LoadOptimizedIndex(NmslibIOReader& input); | ||
|
||
void SaveRegularIndexBin(std::ostream& output); | ||
void LoadRegularIndexBin(std::istream& input); | ||
diff --git a/similarity_search/include/utils.h b/similarity_search/include/utils.h | ||
index b521c26..a3931b7 100644 | ||
--- a/similarity_search/include/utils.h | ||
+++ b/similarity_search/include/utils.h | ||
@@ -299,12 +299,24 @@ inline void WriteField(ostream& out, const string& fieldName, const FieldType& f | ||
} | ||
} | ||
|
||
+struct NmslibIOReader { | ||
+ virtual ~NmslibIOReader() = default; | ||
+ | ||
+ virtual void read(char* bytes, size_t len) = 0; | ||
+ | ||
+ virtual size_t remainingBytes() = 0; | ||
+}; | ||
|
||
template <typename T> | ||
void writeBinaryPOD(ostream& out, const T& podRef) { | ||
out.write((char*)&podRef, sizeof(T)); | ||
} | ||
|
||
+template <typename T> | ||
+static void readBinaryPOD(NmslibIOReader& in, T& podRef) { | ||
+ in.read((char*)&podRef, sizeof(T)); | ||
+} | ||
+ | ||
template <typename T> | ||
static void readBinaryPOD(istream& in, T& podRef) { | ||
in.read((char*)&podRef, sizeof(T)); | ||
diff --git a/similarity_search/src/method/hnsw.cc b/similarity_search/src/method/hnsw.cc | ||
index 4080b3b..63482fd 100644 | ||
--- a/similarity_search/src/method/hnsw.cc | ||
+++ b/similarity_search/src/method/hnsw.cc | ||
@@ -950,7 +950,6 @@ namespace similarity { | ||
" read so far doesn't match the number of read lines: " + ConvertToString(lineNum)); | ||
} | ||
|
||
- | ||
template <typename dist_t> | ||
void | ||
Hnsw<dist_t>::LoadRegularIndexBin(std::istream& input) { | ||
@@ -1034,6 +1033,144 @@ namespace similarity { | ||
|
||
} | ||
|
||
+ constexpr bool _isLittleEndian() { | ||
+ uint32_t value = 1; | ||
+ return (value & 0xFFU) == 1; | ||
+ } | ||
+ | ||
+ SIZEMASS_TYPE _readIntBigEndian(uint8_t byte0, uint8_t byte1, uint8_t byte2, uint8_t byte3) noexcept { | ||
+ return (static_cast<SIZEMASS_TYPE>(byte0) << 24) | | ||
+ (static_cast<SIZEMASS_TYPE>(byte1) << 16) | | ||
+ (static_cast<SIZEMASS_TYPE>(byte2) << 8) | | ||
+ static_cast<SIZEMASS_TYPE>(byte3); | ||
+ } | ||
+ | ||
+ SIZEMASS_TYPE _readIntLittleEndian(uint8_t byte0, uint8_t byte1, uint8_t byte2, uint8_t byte3) noexcept { | ||
+ return (static_cast<SIZEMASS_TYPE>(byte3) << 24) | | ||
+ (static_cast<SIZEMASS_TYPE>(byte2) << 16) | | ||
+ (static_cast<SIZEMASS_TYPE>(byte1) << 8) | | ||
+ static_cast<SIZEMASS_TYPE>(byte0); | ||
+ } | ||
+ | ||
+ template <typename dist_t> | ||
+ void Hnsw<dist_t>::LoadIndexWithStream(NmslibIOReader& input) { | ||
+ LOG(LIB_INFO) << "Loading index from an input stream(NmslibIOReader)."; | ||
+ | ||
+ unsigned int optimIndexFlag= 0; | ||
+ readBinaryPOD(input, optimIndexFlag); | ||
+ | ||
+ if (!optimIndexFlag) { | ||
+ throw std::runtime_error("With stream, we only support optimized index type."); | ||
+ } else { | ||
+ LoadOptimizedIndex(input); | ||
+ } | ||
+ | ||
+ LOG(LIB_INFO) << "Finished loading index"; | ||
+ visitedlistpool = new VisitedListPool(1, totalElementsStored_); | ||
+ } | ||
+ | ||
+ template <typename dist_t> | ||
+ void Hnsw<dist_t>::LoadOptimizedIndex(NmslibIOReader& input) { | ||
+ static_assert(sizeof(SIZEMASS_TYPE) == 4, "Expected sizeof(SIZEMASS_TYPE) == 4."); | ||
+ | ||
+ LOG(LIB_INFO) << "Loading optimized index(NmslibIOReader)."; | ||
+ | ||
+ readBinaryPOD(input, totalElementsStored_); | ||
+ readBinaryPOD(input, memoryPerObject_); | ||
+ readBinaryPOD(input, offsetLevel0_); | ||
+ readBinaryPOD(input, offsetData_); | ||
+ readBinaryPOD(input, maxlevel_); | ||
+ readBinaryPOD(input, enterpointId_); | ||
+ readBinaryPOD(input, maxM_); | ||
+ readBinaryPOD(input, maxM0_); | ||
+ readBinaryPOD(input, dist_func_type_); | ||
+ readBinaryPOD(input, searchMethod_); | ||
+ | ||
+ LOG(LIB_INFO) << "searchMethod: " << searchMethod_; | ||
+ | ||
+ fstdistfunc_ = getDistFunc(dist_func_type_); | ||
+ iscosine_ = (dist_func_type_ == kNormCosine); | ||
+ CHECK_MSG(fstdistfunc_ != nullptr, "Unknown distance function code: " + ConvertToString(dist_func_type_)); | ||
+ | ||
+ LOG(LIB_INFO) << "Total: " << totalElementsStored_ << ", Memory per object: " << memoryPerObject_; | ||
+ size_t data_plus_links0_size = memoryPerObject_ * totalElementsStored_; | ||
+ | ||
+ // we allocate a few extra bytes to prevent prefetch from accessing out of range memory | ||
+ data_level0_memory_ = (char *)malloc(data_plus_links0_size + EXTRA_MEM_PAD_SIZE); | ||
+ CHECK(data_level0_memory_); | ||
+ input.read(data_level0_memory_, data_plus_links0_size); | ||
+ // we allocate a few extra bytes to prevent prefetch from accessing out of range memory | ||
+ linkLists_ = (char **)malloc( (sizeof(void *) * totalElementsStored_) + EXTRA_MEM_PAD_SIZE); | ||
+ CHECK(linkLists_); | ||
+ | ||
+ data_rearranged_.resize(totalElementsStored_); | ||
+ | ||
+ const size_t bufferSize = 64 * 1024; // 64KB | ||
+ std::unique_ptr<char[]> buffer (new char[bufferSize]); | ||
+ uint32_t end = 0; | ||
+ uint32_t pos = 0; | ||
+ constexpr bool isLittleEndian = _isLittleEndian(); | ||
+ | ||
+ for (size_t i = 0, remainingBytes = input.remainingBytes(); i < totalElementsStored_; i++) { | ||
+ if ((pos + sizeof(SIZEMASS_TYPE)) >= end) { | ||
+ // Underflow during reading an integer size field. | ||
+ // So the idea is to move the first partial bytes (which is < 4 bytes) to the beginning section of | ||
+ // buffer. | ||
+ // Ex: buffer -> [..., b0, b1] where we only have two bytes and still need to read two bytes more | ||
+ // buffer -> [b0, b1, ...] after move the first part. firstPartLen = 2. | ||
+ const auto firstPartLen = end - pos; | ||
+ if (firstPartLen > 0) { | ||
+ std::memcpy(buffer.get(), buffer.get() + pos, firstPartLen); | ||
+ } | ||
+ // Then, bulk load bytes from input stream. Note that the first few bytes are already occupied by | ||
+ // earlier moving logic, hence required bytes are bufferSize - firstPartLen. | ||
+ const auto copyBytes = std::min(remainingBytes, bufferSize - firstPartLen); | ||
+ input.read(buffer.get() + firstPartLen, copyBytes); | ||
+ remainingBytes -= copyBytes; | ||
+ end = copyBytes + firstPartLen; | ||
+ pos = 0; | ||
+ } | ||
+ | ||
+ // Read data size field. | ||
+ // Since NMSLIB directly write 4 bytes integer casting to char*, bytes outline may differ among systems. | ||
+ SIZEMASS_TYPE linkListSize = 0; | ||
+ if (isLittleEndian) { | ||
+ linkListSize = _readIntLittleEndian(buffer[pos], buffer[pos + 1], buffer[pos + 2], buffer[pos + 3]); | ||
+ } else { | ||
+ linkListSize = _readIntBigEndian(buffer[pos], buffer[pos + 1], buffer[pos + 2], buffer[pos + 3]); | ||
+ } | ||
+ pos += sizeof(SIZEMASS_TYPE); | ||
+ | ||
+ if (linkListSize == 0) { | ||
+ linkLists_[i] = nullptr; | ||
+ } else { | ||
+ linkLists_[i] = (char *)malloc(linkListSize); | ||
+ CHECK(linkLists_[i]); | ||
+ | ||
+ SIZEMASS_TYPE leftLinkListData = linkListSize; | ||
+ auto dataPtr = linkLists_[i]; | ||
+ while (leftLinkListData > 0) { | ||
+ if (pos >= end) { | ||
+ // Underflow during read linked list bytes. | ||
+ const auto copyBytes = std::min(remainingBytes, bufferSize); | ||
+ input.read(buffer.get(), copyBytes); | ||
+ remainingBytes -= copyBytes; | ||
+ end = copyBytes; | ||
+ pos = 0; | ||
+ } | ||
+ | ||
+ // Read linked list bytes. | ||
+ const auto copyBytes = std::min(leftLinkListData, end - pos); | ||
+ std::memcpy(dataPtr, buffer.get() + pos, copyBytes); | ||
+ dataPtr += copyBytes; | ||
+ leftLinkListData -= copyBytes; | ||
+ pos += copyBytes; | ||
+ } // End while | ||
+ } // End if | ||
+ | ||
+ data_rearranged_[i] = new Object(data_level0_memory_ + (i)*memoryPerObject_ + offsetData_); | ||
+ } // End for | ||
+ } | ||
|
||
template <typename dist_t> | ||
void | ||
-- | ||
2.39.5 (Apple Git-154) | ||
|
Oops, something went wrong.