Move data set loading outside constructor
Param sources will be instantiated before corpora are loaded. Hence, set
data set parameters to their initial values during initialization and load
the data set during the partition stage.

Signed-off-by: Vijayan Balasubramanian <[email protected]>
VijayanB committed Jan 30, 2024
1 parent c74c993 commit 36ae46a
Showing 3 changed files with 82 additions and 46 deletions.
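
The pattern applied across all three files is lazy initialization: constructors only record parameters, and the expensive file I/O runs once, on first use, behind an idempotent `_load()` guard. A minimal sketch of the pattern, assuming nothing beyond the standard library (the names are illustrative, not the repo's):

```python
class LazyDataSet:
    """Records parameters in __init__; opens the file only on first access."""

    BEGINNING = 0

    def __init__(self, path: str):
        self.path = path
        self.current = self.BEGINNING
        self.data = None  # loaded lazily

    def _load(self):
        # Idempotent: only the first call performs any I/O.
        if self.data is None:
            with open(self.path) as f:
                self.data = f.read().splitlines()

    def size(self) -> int:
        self._load()
        return len(self.data)

    def read(self, chunk_size: int):
        self._load()
        if self.current >= self.size():
            return None
        chunk = self.data[self.current:self.current + chunk_size]
        self.current += len(chunk)
        return chunk
```

Construction stays cheap no matter how large the file is, which is what allows param sources to be instantiated before the corpora exist on disk.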
52 changes: 40 additions & 12 deletions osbenchmark/utils/dataset.py
@@ -89,11 +89,18 @@ class HDF5DataSet(DataSet):
FORMAT_NAME = "hdf5"

def __init__(self, dataset_path: str, context: Context):
file = h5py.File(dataset_path)
self.data = cast(h5py.Dataset, file[self.parse_context(context)])
self.dataset_path = dataset_path
self.context = self.parse_context(context)
self.current = self.BEGINNING
self.data = None

def _load(self):
if self.data is None:
file = h5py.File(self.dataset_path)
self.data = cast(h5py.Dataset, file[self.context])

def read(self, chunk_size: int):
self._load()
if self.current >= self.size():
return None

@@ -106,7 +113,7 @@ def read(self, chunk_size: int):
return vectors

def seek(self, offset: int):

# load file first before seek
if offset < self.BEGINNING:
raise Exception("Offset must be greater than or equal to 0")

@@ -116,6 +123,8 @@ def read(self, chunk_size: int):
self.current = offset

def size(self):
# load file first before return size
self._load()
return self.data.len()

def reset(self):
@@ -152,7 +161,15 @@ class BigANNVectorDataSet(DataSet):
BYTES_PER_FLOAT = 4

def __init__(self, dataset_path: str):
self.file = open(dataset_path, 'rb')
self.dataset_path = dataset_path
self.file = None
self.current = BigANNVectorDataSet.BEGINNING
self.num_points = 0
self.dimension = 0
self.bytes_per_num = 0

def _init_internal_params(self):
self.file = open(self.dataset_path, 'rb')
self.file.seek(BigANNVectorDataSet.BEGINNING, os.SEEK_END)
num_bytes = self.file.tell()
self.file.seek(BigANNVectorDataSet.BEGINNING)
@@ -163,17 +180,23 @@ def __init__(self, dataset_path: str):

self.num_points = int.from_bytes(self.file.read(4), "little")
self.dimension = int.from_bytes(self.file.read(4), "little")
self.bytes_per_num = self._get_data_size(dataset_path)
self.bytes_per_num = self._get_data_size(self.dataset_path)

if (num_bytes - BigANNVectorDataSet.DATA_SET_HEADER_LENGTH) != (
self.num_points * self.dimension * self.bytes_per_num):
raise Exception("Invalid file. File size is not matching with expected estimated "
"value based on number of points, dimension and bytes per point")

self.reader = self._value_reader(dataset_path)
self.current = BigANNVectorDataSet.BEGINNING
self.reader = self._value_reader(self.dataset_path)

def _load(self):
# load file if it is not loaded yet
if self.file is None:
self._init_internal_params()

def read(self, chunk_size: int):
# load file first before read
self._load()
if self.current >= self.size():
return None

@@ -188,15 +211,16 @@ def read(self, chunk_size: int):
return vectors

def seek(self, offset: int):

# load file first before seek
self._load()
if offset < self.BEGINNING:
raise Exception("Offset must be greater than or equal to 0")

if offset >= self.size():
raise Exception("Offset must be less than the data set size")

bytes_offset = BigANNVectorDataSet.DATA_SET_HEADER_LENGTH + \
self.dimension * self.bytes_per_num * offset
bytes_offset = BigANNVectorDataSet.DATA_SET_HEADER_LENGTH + (
self.dimension * self.bytes_per_num * offset)
self.file.seek(bytes_offset)
self.current = offset

@@ -205,14 +229,18 @@ def _read_vector(self):
range(self.dimension)])

def size(self):
# load file first before return size
self._load()
return self.num_points

def reset(self):
self.file.seek(BigANNVectorDataSet.DATA_SET_HEADER_LENGTH)
if self.file:
self.file.seek(BigANNVectorDataSet.DATA_SET_HEADER_LENGTH)
self.current = BigANNVectorDataSet.BEGINNING

def __del__(self):
self.file.close()
if self.file:
self.file.close()

@staticmethod
def _get_extension(file_name):
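
For reference, a sketch of how the deferred open behaves from the caller's side, assuming `HDF5DataSet` and `Context` are importable from `osbenchmark.utils.dataset` and that `Context.INDEX` maps to the `train` group (both assumptions based on this repo's layout):

```python
import h5py
import numpy as np

from osbenchmark.utils.dataset import Context, HDF5DataSet  # assumed import path

# Build a tiny HDF5 file with a "train" dataset (hypothetical path).
with h5py.File("/tmp/example.hdf5", "w") as f:
    f.create_dataset("train", data=np.random.rand(100, 8).astype(np.float32))

ds = HDF5DataSet("/tmp/example.hdf5", Context.INDEX)  # cheap: no file handle yet
print(ds.size())       # first access triggers _load() and opens the file -> 100
vectors = ds.read(10)  # reads the next 10 vectors from the current offset
```

The same deferral is why `reset()` and `__del__` in `BigANNVectorDataSet` now guard on `self.file`: either can run before any handle has been opened.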
38 changes: 23 additions & 15 deletions osbenchmark/workload/params.py
@@ -837,22 +837,15 @@ class VectorDataSetPartitionParamSource(ParamSource):
def __init__(self, workload, params, context: Context, **kwargs):
super().__init__(workload, params, **kwargs)
self.field_name: str = parse_string_parameter("field", params)

self.context = context
self.data_set_format = parse_string_parameter("data_set_format", params)
self.data_set_path = parse_string_parameter("data_set_path", params)
self.data_set: DataSet = get_data_set(
self.data_set_format, self.data_set_path, self.context)

num_vectors: int = parse_int_parameter(
"num_vectors", params, self.data_set.size())
# if value is -1 or greater than dataset size, use dataset size as num_vectors
self.num_vectors = self.data_set.size() if (
num_vectors < 0 or num_vectors > self.data_set.size()) else num_vectors
self.total = self.num_vectors
self.num_vectors: int = parse_int_parameter("num_vectors", params, -1)
self.total = 1
self.current = 0
self.percent_completed = 0
self.offset = 0
self.data_set: DataSet = None

@property
def infinite(self):
@@ -872,6 +865,14 @@ def partition(self, partition_index, total_partitions):
Returns:
The parameter source for this particular partition
"""
if self.data_set is None:
self.data_set: DataSet = get_data_set(
self.data_set_format, self.data_set_path, self.context)
# if value is -1 or greater than dataset size, use dataset size as num_vectors
if self.num_vectors < 0 or self.num_vectors > self.data_set.size():
self.num_vectors = self.data_set.size()
self.total = self.num_vectors

partition_x = copy.copy(self)

num_vectors = int(self.num_vectors / total_partitions)
@@ -934,10 +935,7 @@ def __init__(self, workloads, params, query_params, **kwargs):
self.current_rep = 1
self.neighbors_data_set_format = parse_string_parameter(
self.PARAMS_NAME_NEIGHBORS_DATA_SET_FORMAT, params, self.data_set_format)
self.neighbors_data_set_path = parse_string_parameter(
self.PARAMS_NAME_NEIGHBORS_DATA_SET_PATH, params, self.data_set_path)
self.neighbors_data_set: DataSet = get_data_set(
self.neighbors_data_set_format, self.neighbors_data_set_path, Context.NEIGHBORS)
self.neighbors_data_set_path = params.get(self.PARAMS_NAME_NEIGHBORS_DATA_SET_PATH)
operation_type = parse_string_parameter(self.PARAMS_NAME_OPERATION_TYPE, params,
self.PARAMS_VALUE_VECTOR_SEARCH)
self.query_params = query_params
@@ -961,11 +959,21 @@ def _update_body_params(self, vector):
if self.PARAMS_NAME_SIZE not in body_params:
body_params[self.PARAMS_NAME_SIZE] = self.k
if self.PARAMS_NAME_QUERY in body_params:
self.logger.warning("[%s] param from body will be replaced with vector search query.", self.PARAMS_NAME_QUERY)
self.logger.warning(
"[%s] param from body will be replaced with vector search query.", self.PARAMS_NAME_QUERY)
# override query params with vector search query
body_params[self.PARAMS_NAME_QUERY] = self._build_vector_search_query_body(vector)
self.query_params.update({self.PARAMS_NAME_BODY: body_params})

def partition(self, partition_index, total_partitions):
partition = super().partition(partition_index, total_partitions)
if not self.neighbors_data_set_path:
self.neighbors_data_set_path = self.data_set_path
# add neighbor instance to partition
partition.neighbors_data_set = get_data_set(
self.neighbors_data_set_format, self.neighbors_data_set_path, Context.NEIGHBORS)
return partition

def params(self):
"""
Returns: A query parameter with a vector and neighbor from a data set
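
The `partition()` change leans on `copy.copy`: the data set is loaded once on the source object, and each shallow copy shares that loaded data while carrying its own offsets. A reduced sketch of the flow, with toy names rather than the repo's API:

```python
import copy

class PartitionedSource:
    def __init__(self, total_items: int):
        self.total_items = total_items
        self.data = None   # deferred, as in the commit
        self.offset = 0
        self.num_items = 0

    def partition(self, index: int, total: int):
        if self.data is None:          # load once, on the first partition() call
            self.data = list(range(self.total_items))
        part = copy.copy(self)         # shallow copy shares self.data
        part.num_items = self.total_items // total
        part.offset = index * part.num_items
        return part

source = PartitionedSource(100)
p0 = source.partition(0, 10)
p1 = source.partition(1, 10)
assert p0.data is p1.data              # one loaded data set, shared by every partition
assert (p0.offset, p1.offset) == (0, 10)
```

Resolving `neighbors_data_set_path` inside `partition()` rather than `__init__` follows the same rule: nothing that touches the corpus files happens until partitioning begins.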
38 changes: 19 additions & 19 deletions tests/workload/params_test.py
@@ -2536,7 +2536,7 @@ def test_invalid_data_set_format(self):
workload.Workload(name="unit-test"),
test_param_source_params,
self.DEFAULT_CONTEXT
)
).partition(0, 1)
)

def test_invalid_data_set_path(self):
@@ -2553,11 +2553,12 @@ def test_invalid_data_set_path(self):
workload.Workload(name="unit-test"),
test_param_source_params,
self.DEFAULT_CONTEXT
)
).partition(0, 1)
)

def test_partition_hdf5(self):
num_vectors = 100
num_partitions = 10

hdf5_data_set_path = create_data_set(
num_vectors,
@@ -2579,8 +2580,7 @@
self.DEFAULT_CONTEXT
)

num_partitions = 10
vectors_per_partition = test_param_source.num_vectors // num_partitions
vectors_per_partition = num_vectors // num_partitions

self._test_partition(
test_param_source,
@@ -2590,6 +2590,7 @@

def test_partition_bigann(self):
num_vectors = 100
num_partitions = 10
float_extension = "fbin"

bigann_data_set_path = create_data_set(
@@ -2611,14 +2612,11 @@
test_param_source_params,
self.DEFAULT_CONTEXT
)

num_partitions = 10
vecs_per_partition = test_param_source.num_vectors // num_partitions

vectors_per_partition = num_vectors // num_partitions
self._test_partition(
test_param_source,
num_partitions,
vecs_per_partition
vectors_per_partition
)

def _test_partition(
Expand Down Expand Up @@ -2691,19 +2689,20 @@ def test_params_default(self):
"request-params": {},
}
)
query_param_source_partition = query_param_source.partition(0, 1)

# Check each
for _ in range(DEFAULT_NUM_VECTORS):
self._check_params(
query_param_source.params(),
query_param_source_partition.params(),
self.DEFAULT_FIELD_NAME,
self.DEFAULT_DIMENSION,
k,
)

# Assert last call creates stop iteration
with self.assertRaises(StopIteration):
query_param_source.params()
query_param_source_partition.params()

def test_params_custom_body(self):
# Create a data set
@@ -2741,11 +2740,12 @@
}
}
)
query_param_source_partition = query_param_source.partition(0, 1)

# Check each
for _ in range(DEFAULT_NUM_VECTORS):
self._check_params(
query_param_source.params(),
query_param_source_partition.params(),
self.DEFAULT_FIELD_NAME,
self.DEFAULT_DIMENSION,
k,
Expand All @@ -2754,7 +2754,7 @@ def test_params_custom_body(self):

# Assert last call creates stop iteration
with self.assertRaises(StopIteration):
query_param_source.params()
query_param_source_partition.params()

def _check_params(
self,
@@ -2822,12 +2822,12 @@ def test_params_default(self):
}
bulk_param_source = BulkVectorsFromDataSetParamSource(
workload.Workload(name="unit-test"), test_param_source_params)

bulk_param_source_partition = bulk_param_source.partition(0, 1)
# Check each payload returned
vectors_consumed = 0
while vectors_consumed < num_vectors:
expected_num_vectors = min(num_vectors - vectors_consumed, bulk_size)
actual_params = bulk_param_source.params()
actual_params = bulk_param_source_partition.params()
self._check_params(
actual_params,
self.DEFAULT_INDEX_NAME,
Expand All @@ -2840,7 +2840,7 @@ def test_params_default(self):

# Assert last call creates stop iteration
with self.assertRaises(StopIteration):
bulk_param_source.params()
bulk_param_source_partition.params()

def test_params_custom(self):
num_vectors = 49
@@ -2863,12 +2863,12 @@
}
bulk_param_source = BulkVectorsFromDataSetParamSource(
workload.Workload(name="unit-test"), test_param_source_params)

bulk_param_source_partition = bulk_param_source.partition(0, 1)
# Check each payload returned
vectors_consumed = 0
while vectors_consumed < num_vectors:
expected_num_vectors = min(num_vectors - vectors_consumed, bulk_size)
actual_params = bulk_param_source.params()
actual_params = bulk_param_source_partition.params()
self._check_params(
actual_params,
self.DEFAULT_INDEX_NAME,
Expand All @@ -2881,7 +2881,7 @@ def test_params_custom(self):

# Assert last call creates stop iteration
with self.assertRaises(StopIteration):
bulk_param_source.params()
bulk_param_source_partition.params()

def _check_params(
self,
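
Every updated test follows the same contract: `partition()` must now be called before `params()`, because that is where the data set gets loaded. A condensed sketch of that contract, with toy classes rather than the repo's fixtures:

```python
import unittest

class ToySource:
    """Data is loaded in partition(), consumed in params()."""

    def __init__(self, size: int):
        self.size = size
        self.data = None

    def partition(self, index: int, total: int):
        if self.data is None:
            self.data = iter(range(self.size))  # "loading" happens here, not in __init__
        return self

    def params(self):
        return {"value": next(self.data)}  # raises StopIteration when exhausted

class ContractTest(unittest.TestCase):
    def test_partition_then_params(self):
        source = ToySource(3).partition(0, 1)
        for _ in range(3):
            self.assertIn("value", source.params())
        with self.assertRaises(StopIteration):
            source.params()

if __name__ == "__main__":
    unittest.main()
```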
