-
Notifications
You must be signed in to change notification settings - Fork 80
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add dataset parser for vector search #424
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,242 @@ | ||
# SPDX-License-Identifier: Apache-2.0 | ||
# | ||
# The OpenSearch Contributors require contributions made to | ||
# this file be licensed under the Apache-2.0 license or a | ||
# compatible open source license. | ||
|
||
import os | ||
import struct | ||
from abc import ABC, ABCMeta, abstractmethod | ||
from enum import Enum | ||
from typing import cast | ||
|
||
import h5py | ||
import numpy as np | ||
|
||
from osbenchmark.exceptions import InvalidExtensionException | ||
from osbenchmark.utils.parse import ConfigurationError | ||
|
||
|
||
class Context(Enum):
    """Role that a data-set plays when it is read.

    A reader can use this value to decide how the contents of a data-set
    file should be interpreted.
    """

    # Data-set holds the vectors to be indexed.
    INDEX = 1
    # Data-set holds the vectors used as queries.
    QUERY = 2
    # Data-set holds neighbors.
    NEIGHBORS = 3
|
||
|
||
class DataSet(ABC):
    """DataSet interface. Used for reading data-sets from files.

    Subclasses must implement every method below; the class cannot be
    instantiated directly.

    Methods:
        read: Read a chunk of data from the data-set
        seek: Get to position in the data-set
        size: Gets the number of items in the data-set
        reset: Resets internal state of data-set to beginning
    """
    # NOTE: inheriting from ABC already installs ABCMeta as the metaclass.
    # The Python 2 style ``__metaclass__`` attribute has no effect in
    # Python 3, so it is intentionally not declared here.

    # Position of the first item in a data-set.
    BEGINNING = 0

    @abstractmethod
    def read(self, chunk_size: int):
        """Read the next chunk of items from the current position.

        @param chunk_size: upper bound on the number of items to read
        """

    @abstractmethod
    def seek(self, offset: int):
        """Move the reader to the given item offset.

        @param offset: value to move reader pointer to
        """

    @abstractmethod
    def size(self):
        """Return the size of the data-set."""

    @abstractmethod
    def reset(self):
        """Reset the data-set reader."""
|
||
|
||
def get_data_set(data_set_format: str, path: str, context: Context):
    """
    Build the DataSet reader that matches a data-set format name.
    Args:
        data_set_format: File format like hdf5, bigann
        path: Data set file path
        context: Dataset Context Enum (only passed on to the hdf5 reader)
    Returns: DataSet instance
    Raises: ConfigurationError when the format name is not recognised
    """
    known_formats = (HDF5DataSet.FORMAT_NAME, BigANNVectorDataSet.FORMAT_NAME)
    if data_set_format not in known_formats:
        raise ConfigurationError("Invalid data set format")

    if data_set_format == HDF5DataSet.FORMAT_NAME:
        return HDF5DataSet(path, context)

    # The BigANN reader takes the file path only; it has no context argument.
    return BigANNVectorDataSet(path)
|
||
|
||
class HDF5DataSet(DataSet):
    """ Data-set format corresponding to `ANN Benchmarks
    <https://github.com/erikbern/ann-benchmarks#data-sets>`_

    The context selects which named entry of the HDF5 file is read
    (see ``parse_context``).
    """

    FORMAT_NAME = "hdf5"

    def __init__(self, dataset_path: str, context: Context):
        """Open the HDF5 file and select the entry matching ``context``.

        @param dataset_path: path of the HDF5 file
        @param context: which part of the data-set to read
        """
        # Open explicitly read-only: this class never writes, and relying on
        # the library default mode can create or modify the file on older
        # h5py releases.
        file = h5py.File(dataset_path, "r")
        try:
            self.data = cast(h5py.Dataset, file[self.parse_context(context)])
        except Exception:
            # Do not leave the file open when the entry cannot be resolved.
            file.close()
            raise
        self.current = self.BEGINNING

    def read(self, chunk_size: int):
        """Return up to ``chunk_size`` items starting at the current position.

        Advances the current position past the items returned.

        @param chunk_size: maximum number of items to read
        @return: the slice of items, or None once the end has been reached
        """
        if self.current >= self.size():
            return None

        # Clamp so the final chunk stops at the end of the data-set.
        end_offset = min(self.current + chunk_size, self.size())

        vectors = cast(np.ndarray, self.data[self.current:end_offset])
        self.current = end_offset
        return vectors

    def seek(self, offset: int):
        """Move the current position to ``offset``.

        @param offset: item index, must satisfy 0 <= offset < size()
        @raise Exception: when the offset is outside that range
        """
        if offset < self.BEGINNING:
            raise Exception("Offset must be greater than or equal to 0")

        if offset >= self.size():
            raise Exception("Offset must be less than the data set size")

        self.current = offset

    def size(self):
        """Return the number of items in the selected entry."""
        return self.data.len()

    def reset(self):
        """Move the current position back to the first item."""
        self.current = self.BEGINNING

    @staticmethod
    def parse_context(context: Context) -> str:
        """Map a Context member to the name of the HDF5 entry to read.

        @param context: data-set context
        @return: "neighbors", "train" or "test"
        @raise Exception: when the context is not one of the known members
        """
        if context == Context.NEIGHBORS:
            return "neighbors"

        if context == Context.INDEX:
            return "train"

        if context == Context.QUERY:
            return "test"

        raise Exception("Unsupported context")
|
||
|
||
class BigANNVectorDataSet(DataSet):
    """ Data-set format for vector data-sets for `Big ANN Benchmarks
    <https://big-ann-benchmarks.com/index.html#bench-datasets>`_

    Layout as parsed by this class: an 8 byte header made of two
    little-endian 4 byte integers (number of points, then dimension),
    followed by ``num_points * dimension`` values. The width of each value
    is derived from the file extension (1 byte for u8bin, 4 bytes for fbin).
    """

    DATA_SET_HEADER_LENGTH = 8
    U8BIN_EXTENSION = "u8bin"
    FBIN_EXTENSION = "fbin"
    FORMAT_NAME = "bigann"
    SUPPORTED_EXTENSION = [
        FBIN_EXTENSION, U8BIN_EXTENSION
    ]

    BYTES_PER_U8INT = 1
    BYTES_PER_FLOAT = 4

    def __init__(self, dataset_path: str):
        """Open the file, parse its header and validate the file size.

        @param dataset_path: path of a ``.u8bin`` or ``.fbin`` file
        @raise InvalidExtensionException: when the extension is not supported
        @raise Exception: when the file is shorter than the header or its
            size does not match the header
        """
        self.file = open(dataset_path, 'rb')
        try:
            # Seek to the end to learn the total file size, then rewind.
            self.file.seek(BigANNVectorDataSet.BEGINNING, os.SEEK_END)
            num_bytes = self.file.tell()
            self.file.seek(BigANNVectorDataSet.BEGINNING)

            if num_bytes < BigANNVectorDataSet.DATA_SET_HEADER_LENGTH:
                raise Exception("Invalid file: file size cannot be less than {} bytes".format(
                    BigANNVectorDataSet.DATA_SET_HEADER_LENGTH))

            self.num_points = int.from_bytes(self.file.read(4), "little")
            self.dimension = int.from_bytes(self.file.read(4), "little")
            self.bytes_per_num = self._get_data_size(dataset_path)

            if (num_bytes - BigANNVectorDataSet.DATA_SET_HEADER_LENGTH) != (
                    self.num_points * self.dimension * self.bytes_per_num):
                raise Exception("Invalid file. File size is not matching with expected estimated "
                                "value based on number of points, dimension and bytes per point")

            self.reader = self._value_reader(dataset_path)
        except Exception:
            # Release the handle right away instead of waiting for garbage
            # collection when the file turns out to be invalid.
            self.file.close()
            raise
        self.current = BigANNVectorDataSet.BEGINNING

    def read(self, chunk_size: int):
        """Return up to ``chunk_size`` vectors from the current position.

        Advances the current position past the vectors returned.

        @param chunk_size: maximum number of vectors to read
        @return: numpy array with one row per vector, or None once the end
            has been reached
        """
        if self.current >= self.size():
            return None

        # Clamp so the final chunk stops at the end of the data-set.
        end_offset = min(self.current + chunk_size, self.size())

        vectors = np.asarray(
            [self._read_vector() for _ in range(end_offset - self.current)]
        )
        self.current = end_offset
        return vectors

    def seek(self, offset: int):
        """Move the reader to the vector at index ``offset``.

        @param offset: vector index, must satisfy 0 <= offset < size()
        @raise Exception: when the offset is outside that range
        """
        if offset < self.BEGINNING:
            raise Exception("Offset must be greater than or equal to 0")

        if offset >= self.size():
            raise Exception("Offset must be less than the data set size")

        # Skip the header plus ``offset`` complete vectors.
        bytes_offset = BigANNVectorDataSet.DATA_SET_HEADER_LENGTH + \
            self.dimension * self.bytes_per_num * offset

        self.file.seek(bytes_offset)
        self.current = offset

    def _read_vector(self):
        """Read ``dimension`` values from the file's current position."""
        return np.asarray([self.reader(self.file) for _ in
                           range(self.dimension)])

    def size(self):
        """Return the number of points declared in the file header."""
        return self.num_points

    def reset(self):
        """Move the reader back to the first vector (just after the header)."""
        self.file.seek(BigANNVectorDataSet.DATA_SET_HEADER_LENGTH)
        self.current = BigANNVectorDataSet.BEGINNING

    def __del__(self):
        # ``file`` is missing when open() itself failed in __init__.
        file = getattr(self, "file", None)
        if file is not None:
            file.close()

    @staticmethod
    def _get_extension(file_name):
        """Return the text after the last '.' of ``file_name``.

        @raise InvalidExtensionException: when it is not a supported extension
        """
        ext = file_name.split('.')[-1]
        if ext not in BigANNVectorDataSet.SUPPORTED_EXTENSION:
            raise InvalidExtensionException(
                "Unknown extension :{}, supported extensions are: {}".format(
                    ext, str(BigANNVectorDataSet.SUPPORTED_EXTENSION)))
        return ext

    @staticmethod
    def _get_data_size(file_name):
        """Return the number of bytes per stored value for ``file_name``."""
        ext = BigANNVectorDataSet._get_extension(file_name)
        if ext == BigANNVectorDataSet.U8BIN_EXTENSION:
            return BigANNVectorDataSet.BYTES_PER_U8INT

        if ext == BigANNVectorDataSet.FBIN_EXTENSION:
            return BigANNVectorDataSet.BYTES_PER_FLOAT

    @staticmethod
    def _value_reader(file_name):
        """Return a callable that reads one value from an open binary file.

        Both readers return a plain Python float so that vectors have the
        same shape regardless of the file extension.
        """
        ext = BigANNVectorDataSet._get_extension(file_name)
        if ext == BigANNVectorDataSet.U8BIN_EXTENSION:
            return lambda file: float(int.from_bytes(file.read(BigANNVectorDataSet.BYTES_PER_U8INT), "little"))

        if ext == BigANNVectorDataSet.FBIN_EXTENSION:
            # struct.unpack always returns a tuple; take its single element.
            return lambda file: struct.unpack('<f', file.read(BigANNVectorDataSet.BYTES_PER_FLOAT))[0]
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
# SPDX-License-Identifier: Apache-2.0 | ||
# | ||
# The OpenSearch Contributors require contributions made to | ||
# this file be licensed under the Apache-2.0 license or a | ||
# compatible open source license. | ||
from osbenchmark.exceptions import ConfigurationError | ||
|
||
|
||
def parse_string_parameter(key: str, params: dict, default: str = None) -> str:
    """Fetch a string-valued parameter from ``params``.

    Args:
        key: name of the parameter
        params: mapping of parameter names to values
        default: value returned when ``key`` is absent; None means "required"
    Returns: the value stored under ``key``, or ``default`` when absent
    Raises: ConfigurationError when the key is absent and no default is
        given, or when the stored value is not a string
    """
    if key in params:
        value = params[key]
        if not isinstance(value, str):
            raise ConfigurationError("Value must be a string for param {}".format(key))
        return value

    if default is None:
        raise ConfigurationError(
            "Value cannot be None for param {}".format(key)
        )
    return default
|
||
|
||
def parse_int_parameter(key: str, params: dict, default: int = None) -> int:
    """Fetch an integer-valued parameter from ``params``.

    Args:
        key: name of the parameter
        params: mapping of parameter names to values
        default: value returned when ``key`` is absent; None means "required"
    Returns: the value stored under ``key``, or ``default`` when absent
    Raises: ConfigurationError when the key is absent and no default is
        given, or when the stored value is not an int
    """
    # NOTE(review): a reviewer suggested folding the three parse_*_parameter
    # helpers into one generic function; they were kept separate for
    # simplicity.
    if key not in params:
        # Compare against None (not truthiness) so a default of 0 is
        # honoured, matching parse_string_parameter.
        if default is not None:
            return default
        raise ConfigurationError(
            "Value cannot be None for param {}".format(key)
        )

    # NOTE: bool is a subclass of int, so True/False pass this check.
    if isinstance(params[key], int):
        return params[key]

    raise ConfigurationError("Value must be a int for param {}".format(key))
|
||
|
||
def parse_float_parameter(key: str, params: dict, default: float = None) -> float:
    """Fetch a float-valued parameter from ``params``.

    Args:
        key: name of the parameter
        params: mapping of parameter names to values
        default: value returned when ``key`` is absent; None means "required"
    Returns: the value stored under ``key``, or ``default`` when absent
    Raises: ConfigurationError when the key is absent and no default is
        given, or when the stored value is not a float (ints are rejected)
    """
    if key not in params:
        # Compare against None (not truthiness) so a default of 0.0 is
        # honoured, matching parse_string_parameter.
        if default is not None:
            return default
        raise ConfigurationError(
            "Value cannot be None for param {}".format(key)
        )

    if isinstance(params[key], float):
        return params[key]

    raise ConfigurationError("Value must be a float for param {}".format(key))
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for doing this!