Skip to content

Commit

Permalink
Add vector search bulk param and runner (#431)
Browse files Browse the repository at this point in the history
Signed-off-by: Vijayan Balasubramanian <[email protected]>
  • Loading branch information
VijayanB authored Jan 3, 2024
1 parent 1313e1f commit 6a073e7
Show file tree
Hide file tree
Showing 5 changed files with 272 additions and 9 deletions.
2 changes: 1 addition & 1 deletion osbenchmark/utils/parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@


def parse_string_parameter(key: str, params: dict, default: str = None) -> str:
if key not in params:
if key not in params or not params[key]:
if default is not None:
return default
raise ConfigurationError(
Expand Down
34 changes: 33 additions & 1 deletion osbenchmark/worker_coordinator/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,13 @@
from typing import List, Optional

import ijson
from opensearchpy import ConnectionTimeout

from osbenchmark import exceptions, workload
from osbenchmark.utils import convert

# Mapping from operation type to specific runner
from osbenchmark.utils.parse import parse_int_parameter, parse_string_parameter

__RUNNERS = {}

Expand All @@ -58,6 +60,7 @@ def register_default_runners():
register_runner(workload.OperationType.PaginatedSearch, Query(), async_runner=True)
register_runner(workload.OperationType.ScrollSearch, Query(), async_runner=True)
register_runner(workload.OperationType.VectorSearch, Query(), async_runner=True)
register_runner(workload.OperationType.BulkVectorDataSet, BulkVectorDataSet(), async_runner=True)
register_runner(workload.OperationType.RawRequest, RawRequest(), async_runner=True)
register_runner(workload.OperationType.Composite, Composite(), async_runner=True)
register_runner(workload.OperationType.SubmitAsyncSearch, SubmitAsyncSearch(), async_runner=True)
Expand Down Expand Up @@ -625,6 +628,35 @@ def __repr__(self, *args, **kwargs):
return "bulk-index"


# TODO: Add retry logic to BulkIndex, so that we can remove BulkVectorDataSet and use BulkIndex.
class BulkVectorDataSet(Runner):
"""
Bulk inserts vector search dataset of type hdf5, bigann
"""

NAME = "bulk-vector-data-set"

async def __call__(self, opensearch, params):
size = parse_int_parameter("size", params)
retries = parse_int_parameter("retries", params, 0) + 1

for attempt in range(retries):
try:
await opensearch.bulk(
body=params["body"]
)

return size, "docs"
except ConnectionTimeout:
self.logger.warning("Bulk vector ingestion timed out. Retrying attempt: %d", attempt)

raise TimeoutError("Failed to submit bulk request in specified number "
"of retries: {}".format(retries))

def __repr__(self, *args, **kwargs):
return self.NAME


class ForceMerge(Runner):
"""
Runs a force merge operation against OpenSearch.
Expand Down Expand Up @@ -1051,7 +1083,7 @@ def calculate_recall(predictions, neighbors, top_k):
if _is_empty_search_results(response_json):
self.logger.info("Vector search query returned no results.")
return result
id_field = params.get("id-field-name", "_id")
id_field = parse_string_parameter("id-field-name", params, "_id")
candidates = []
for hit in response_json['hits']['hits']:
field_value = _get_field_value(hit, id_field)
Expand Down
98 changes: 92 additions & 6 deletions osbenchmark/workload/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@
import time
from abc import ABC, abstractmethod
from enum import Enum
from typing import List, Dict, Any

import numpy as np

from osbenchmark import exceptions
from osbenchmark.utils import io
Expand Down Expand Up @@ -803,7 +806,7 @@ def params(self):
class VectorSearchParamSource(SearchParamSource):
def __init__(self, workload, params, **kwargs):
super().__init__(workload, params, **kwargs)
self.delegate_param_source = VectorSearchPartitionParamSource(params, self.query_params)
self.delegate_param_source = VectorSearchPartitionParamSource(workload, params, self.query_params, **kwargs)

def partition(self, partition_index, total_partitions):
return self.delegate_param_source.partition(partition_index, total_partitions)
Expand All @@ -812,7 +815,7 @@ def params(self):
raise exceptions.WorkloadConfigError("Do not use a VectorSearchParamSource without partitioning")


class VectorDataSetPartitionParamSource(ABC):
class VectorDataSetPartitionParamSource(ParamSource):
""" Abstract class that can read vectors from a data set and partition the
vectors across multiple clients.
Expand All @@ -831,7 +834,8 @@ class VectorDataSetPartitionParamSource(ABC):
multiple partitions
"""

def __init__(self, params, context: Context):
def __init__(self, workload, params, context: Context, **kwargs):
super().__init__(workload, params, **kwargs)
self.field_name: str = parse_string_parameter("field", params)

self.context = context
Expand All @@ -847,10 +851,13 @@ def __init__(self, params, context: Context):
num_vectors < 0 or num_vectors > self.data_set.size()) else num_vectors
self.total = self.num_vectors
self.current = 0
self.infinite = False
self.percent_completed = 0
self.offset = 0

@property
def infinite(self):
return False

def _is_last_partition(self, partition_index, total_partitions):
return partition_index == total_partitions - 1

Expand Down Expand Up @@ -919,8 +926,8 @@ class VectorSearchPartitionParamSource(VectorDataSetPartitionParamSource):
PARAMS_NAME_SOURCE = "_source"
PARAMS_NAME_ALLOW_PARTIAL_RESULTS = "allow_partial_search_results"

def __init__(self, params, query_params):
super().__init__(params, Context.QUERY)
def __init__(self, workloads, params, query_params, **kwargs):
super().__init__(workloads, params, Context.QUERY, **kwargs)
self.logger = logging.getLogger(__name__)
self.k = parse_int_parameter(self.PARAMS_NAME_K, params)
self.repetitions = parse_int_parameter(self.PARAMS_NAME_REPETITIONS, params, 1)
Expand Down Expand Up @@ -1001,6 +1008,84 @@ def _build_vector_search_query_body(self, vector) -> dict:
}


class BulkVectorsFromDataSetParamSource(VectorDataSetPartitionParamSource):
""" Create bulk index requests from a data set of vectors.
Attributes:
bulk_size: number of vectors per request
retries: number of times to retry the request when it fails
"""

DEFAULT_RETRIES = 10
PARAMS_NAME_ID_FIELD_NAME = "id-field-name"
DEFAULT_ID_FIELD_NAME = "_id"

def __init__(self, workload, params, **kwargs):
super().__init__(workload, params, Context.INDEX, **kwargs)
self.bulk_size: int = parse_int_parameter("bulk_size", params)
self.retries: int = parse_int_parameter("retries", params,
self.DEFAULT_RETRIES)
self.index_name: str = parse_string_parameter("index", params)
self.id_field_name: str = parse_string_parameter(
self.PARAMS_NAME_ID_FIELD_NAME, params, self.DEFAULT_ID_FIELD_NAME)

def bulk_transform(self, partition: np.ndarray, action) -> List[Dict[str, Any]]:
"""Partitions and transforms a list of vectors into OpenSearch's bulk
injection format.
Args:
offset: to start counting from
partition: An array of vectors to transform.
action: Bulk API action.
Returns:
An array of transformed vectors in bulk format.
"""
actions = []
_ = [
actions.extend([action(self.id_field_name, i + self.current), None])
for i in range(len(partition))
]
bulk_contents = []
add_id_field_to_body = self.id_field_name != self.DEFAULT_ID_FIELD_NAME
for vec, identifier in zip(partition.tolist(), range(self.current, self.current + len(partition))):
row = {self.field_name: vec}
if add_id_field_to_body:
row.update({self.id_field_name: identifier})
bulk_contents.append(row)
actions[1::2] = bulk_contents
return actions

def params(self):
"""
Returns: A bulk index parameter with vectors from a data set.
"""
# TODO: Fix below logic to make sure we index only total number of documents as mentioned in the params.
if self.current >= self.num_vectors + self.offset:
raise StopIteration

def action(id_field_name, doc_id):
# support only index operation
bulk_action = 'index'
metadata = {
'_index': self.index_name
}
# Add id field to metadata only if it is _id
if id_field_name == self.DEFAULT_ID_FIELD_NAME:
metadata.update({id_field_name: doc_id})
return {bulk_action: metadata}

partition = self.data_set.read(self.bulk_size)
body = self.bulk_transform(partition, action)
size = len(body) // 2
self.current += size
self.percent_completed = self.current / self.total

return {
"body": body,
"retries": self.retries,
"size": size
}


def get_target(workload, params):
if len(workload.indices) == 1:
default_target = workload.indices[0].name
Expand Down Expand Up @@ -1418,6 +1503,7 @@ def read_bulk(self):


register_param_source_for_operation(workload.OperationType.Bulk, BulkIndexParamSource)
register_param_source_for_operation(workload.OperationType.BulkVectorDataSet, BulkVectorsFromDataSetParamSource)
register_param_source_for_operation(workload.OperationType.Search, SearchParamSource)
register_param_source_for_operation(workload.OperationType.VectorSearch, VectorSearchParamSource)
register_param_source_for_operation(workload.OperationType.CreateIndex, CreateIndexParamSource)
Expand Down
3 changes: 3 additions & 0 deletions osbenchmark/workload/workload.py
Original file line number Diff line number Diff line change
Expand Up @@ -585,6 +585,7 @@ class OperationType(Enum):
DeletePointInTime = 15
ListAllPointInTime = 16
VectorSearch = 17
BulkVectorDataSet = 18

# administrative actions
ForceMerge = 1001
Expand Down Expand Up @@ -644,6 +645,8 @@ def from_hyphenated_string(cls, v):
return OperationType.PaginatedSearch
elif v == "vector-search":
return OperationType.VectorSearch
elif v == "bulk-vector-data-set":
return OperationType.BulkVectorDataSet
elif v == "cluster-health":
return OperationType.ClusterHealth
elif v == "bulk":
Expand Down
Loading

0 comments on commit 6a073e7

Please sign in to comment.