Skip to content

Commit

Permalink
Add a safer mmap reading strategy
Browse files Browse the repository at this point in the history
  • Loading branch information
birchkwok committed Sep 23, 2024
1 parent ff2b60e commit b427209
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 8 deletions.
12 changes: 6 additions & 6 deletions lynse/execution_layer/indexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from ..index.binary import IndexBinaryJaccard, IndexBinaryHamming
from ..index.flat import IndexFlatIP, IndexFlatL2sq, IndexFlatCos
from .ivf import IVFCreator
from ..utils.utils import drop_duplicated_substr, find_first_file_with_substr
from ..utils.utils import drop_duplicated_substr, find_first_file_with_substr, safe_mmap_reader

_INDEX_ALIAS = {
'IVF-IP-SQ8': 'IVF-IP-SQ8',
Expand Down Expand Up @@ -146,8 +146,8 @@ def build_binary():
if not (self.index_data_path / f'{self.storage_worker.fingerprint}.bd').exists():
_index = _rebuild()
else:
binary_data = np.load(self.index_data_path / f'{self.storage_worker.fingerprint}.bd', mmap_mode='r')
binary_ids = np.load(self.index_ids_path / f'{self.storage_worker.fingerprint}.bi', mmap_mode='r')
binary_data = safe_mmap_reader(self.index_data_path / f'{self.storage_worker.fingerprint}.bd')
binary_ids = safe_mmap_reader(self.index_ids_path / f'{self.storage_worker.fingerprint}.bi')

# load sq8 data as a view, used for rescore
if (not (self.index_data_path / f'{self.storage_worker.fingerprint}.sqd').exists()) or (
Expand All @@ -164,7 +164,7 @@ def build_binary():
f'{self.storage_worker.fingerprint}.*SQ8.index')
)

sq8_data = np.load(self.index_data_path / f'{self.storage_worker.fingerprint}.sqd', mmap_mode='r')
sq8_data = safe_mmap_reader(self.index_data_path / f'{self.storage_worker.fingerprint}.sqd')

_index.data = binary_data
_index.ids = binary_ids
Expand All @@ -191,8 +191,8 @@ def build_binary():
):
_index = _rebuild()
else:
sq8_data = np.load(self.index_data_path / f'{self.storage_worker.fingerprint}.sqd', mmap_mode='r')
sq8_ids = np.load(self.index_ids_path / f'{self.storage_worker.fingerprint}.sqi', mmap_mode='r')
sq8_data = safe_mmap_reader(self.index_data_path / f'{self.storage_worker.fingerprint}.sqd')
sq8_ids = safe_mmap_reader(self.index_ids_path / f'{self.storage_worker.fingerprint}.sqi')

_index.data = sq8_data
_index.ids = sq8_ids
Expand Down
5 changes: 3 additions & 2 deletions lynse/storage_layer/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from ..core_components.limited_dict import LimitedDict
from ..core_components.locks import ThreadLock
from ..utils.utils import safe_mmap_reader


class PersistentFileStorage:
Expand Down Expand Up @@ -232,8 +233,8 @@ def file_exists(self):
return (self.collection_chunk_path / 'chunk_0').exists()

def mmap_read(self, filename):
# NOTE: this is a diff hunk shown without +/- markers or indentation, so two
# versions of the return statement appear back to back.
# Old version (appears to be the 2 deletions counted for storage.py): returns the
# results of np.load(..., mmap_mode='r') directly for the chunk file and the
# chunk-indices file.
return (np.load(self.collection_chunk_path / filename, mmap_mode='r'),
np.load(self.collection_chunk_indices_path / filename, mmap_mode='r'))
# New version (added by this commit): the same two paths, in the same order,
# opened through safe_mmap_reader instead.
return safe_mmap_reader(self.collection_chunk_path / filename), \
safe_mmap_reader(self.collection_chunk_indices_path / filename)

def warm_up(self):
"""Load the data from the file to the memory."""
Expand Down
17 changes: 17 additions & 0 deletions lynse/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,3 +210,20 @@ def sort_and_get_top_k(arr, k):
top_k_values = np.take_along_axis(arr, top_k_indices, axis=1)

return top_k_indices, top_k_values


def safe_mmap_reader(path, ids=None):
    """
    Read a ``.npy`` file through a memory map and return it as a plain ndarray.

    The file is opened with ``np.load(path, "r")`` (read-only memory-map mode).
    The result is then passed through ``memoryview`` and ``np.asarray`` so the
    caller receives a base ``np.ndarray`` built on the exported buffer rather
    than the object ``np.load`` itself returned.

    Parameters:
        path (str or Pathlike): The path to the file.
        ids (optional): Index applied to the memory-mapped array before it is
            re-wrapped; anything accepted inside ``arr[...]`` (list of row
            ids, slice, ...). ``None`` (the default) reads the whole array.

    Returns:
        np.ndarray: The numpy ndarray.
    """
    # Open the map once; both the "whole file" and the "subset" paths share it.
    mapped = np.load(path, "r")

    if ids is not None:
        mapped = mapped[ids]

    # Round-trip through the buffer protocol to drop the np.load return type.
    return np.asarray(memoryview(mapped))

0 comments on commit b427209

Please sign in to comment.