Skip to content

Commit

Permalink
Add vector search param source (#425)
Browse files Browse the repository at this point in the history
Signed-off-by: Vijayan Balasubramanian <[email protected]>
  • Loading branch information
VijayanB authored Dec 22, 2023
1 parent a9f7c60 commit 1dc9de5
Show file tree
Hide file tree
Showing 3 changed files with 444 additions and 5 deletions.
2 changes: 1 addition & 1 deletion osbenchmark/worker_coordinator/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1041,7 +1041,7 @@ def calculate_recall(predictions, neighbors, top_k):
if _is_empty_search_results(response_json):
self.logger.info("Vector search query returned no results.")
return result
id_field = params.get("id_field_name", "_id")
id_field = params.get("id-field-name", "_id")
candidates = []
for hit in response_json['hits']['hits']:
if id_field in hit: # Will add to candidates if field value is present
Expand Down
203 changes: 200 additions & 3 deletions osbenchmark/workload/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,20 +23,22 @@
# under the License.

import collections
import copy
import inspect
import logging
import math
import numbers
import operator
import random
from abc import ABC

import time
from abc import ABC, abstractmethod
from enum import Enum

from osbenchmark import exceptions
from osbenchmark.workload import workload
from osbenchmark.utils import io
from osbenchmark.utils.dataset import DataSet, get_data_set, Context
from osbenchmark.utils.parse import parse_string_parameter, parse_int_parameter
from osbenchmark.workload import workload

__PARAM_SOURCES_BY_OP = {}
__PARAM_SOURCES_BY_NAME = {}
Expand Down Expand Up @@ -797,6 +799,200 @@ def params(self):
parsed_params.update(self._client_params())
return parsed_params


class VectorSearchParamSource(SearchParamSource):
def __init__(self, workload, params, **kwargs):
super().__init__(workload, params, **kwargs)
self.delegate_param_source = VectorSearchPartitionParamSource(params, self.query_params)

def partition(self, partition_index, total_partitions):
return self.delegate_param_source.partition(partition_index, total_partitions)

def params(self):
raise exceptions.WorkloadConfigError("Do not use a VectorSearchParamSource without partitioning")


class VectorDataSetPartitionParamSource(ABC):
""" Abstract class that can read vectors from a data set and partition the
vectors across multiple clients.
Attributes:
field_name: Name of the field to generate the query for
data_set_format: Format data set is serialized with. bigann or hdf5
data_set_path: Path to data set
context: Context the data set will be used in.
data_set: Structure containing meta data about data and ability to read
num_vectors: Number of vectors to use from the data set
total: Number of vectors for the partition
current: Current vector offset in data set
infinite: Property of param source signalling that it can be exhausted
percent_completed: Progress indicator for how exhausted data set is
offset: Offset into the data set to start at. Relevant when there are
multiple partitions
"""

def __init__(self, params, context: Context):
self.field_name: str = parse_string_parameter("field", params)

self.context = context
self.data_set_format = parse_string_parameter("data_set_format", params)
self.data_set_path = parse_string_parameter("data_set_path", params)
self.data_set: DataSet = get_data_set(
self.data_set_format, self.data_set_path, self.context)

num_vectors: int = parse_int_parameter(
"num_vectors", params, self.data_set.size())
# if value is -1 or greater than dataset size, use dataset size as num_vectors
self.num_vectors = self.data_set.size() if (
num_vectors < 0 or num_vectors > self.data_set.size()) else num_vectors
self.total = self.num_vectors
self.current = 0
self.infinite = False
self.percent_completed = 0
self.offset = 0

def _is_last_partition(self, partition_index, total_partitions):
return partition_index == total_partitions - 1

def partition(self, partition_index, total_partitions):
"""
Splits up the parameters source so that multiple clients can read data
from it.
Args:
partition_index: index of one particular partition
total_partitions: total number of partitions data set is split into
Returns:
The parameter source for this particular partition
"""
partition_x = copy.copy(self)

num_vectors = int(self.num_vectors / total_partitions)

# if partition is not divided equally, add extra docs to the last partition
if self.num_vectors % total_partitions != 0 and self._is_last_partition(partition_index, total_partitions):
num_vectors += self.num_vectors - (num_vectors * total_partitions)

partition_x.num_vectors = num_vectors
partition_x.offset = int(partition_index * partition_x.num_vectors)
# We need to create a new instance of the data set for each client
partition_x.data_set = get_data_set(
self.data_set_format,
self.data_set_path,
self.context
)
partition_x.data_set.seek(partition_x.offset)
partition_x.current = partition_x.offset
return partition_x

@abstractmethod
def params(self):
"""
Returns: A single parameter from this source
"""


class VectorSearchPartitionParamSource(VectorDataSetPartitionParamSource):
""" Parameter source for k-NN. Queries are created from data set
provided.
Attributes:
k: The number of results to return for the search
repetitions: Number of times to re-run query dataset from beginning
neighbors_data_set_format: neighbor's dataset format type like hdf5, bigann
neighbors_data_set_path: neighbor's dataset file path
operation-type: search method type
id-field-name: field name that will have unique identifier id in document
request-params: query parameters that can be passed to search request
"""
PARAMS_NAME_K = "k"
PARAMS_NAME_BODY = "body"
PARAMS_NAME_REPETITIONS = "repetitions"
PARAMS_NAME_NEIGHBORS_DATA_SET_FORMAT = "neighbors_data_set_format"
PARAMS_NAME_NEIGHBORS_DATA_SET_PATH = "neighbors_data_set_path"
PARAMS_NAME_OPERATION_TYPE = "operation-type"
PARAMS_VALUE_VECTOR_SEARCH = "vector-search"
PARAMS_NAME_ID_FIELD_NAME = "id-field-name"
PARAMS_NAME_REQUEST_PARAMS = "request-params"
PARAMS_NAME_SOURCE = "_source"
PARAMS_NAME_ALLOW_PARTIAL_RESULTS = "allow_partial_search_results"

def __init__(self, params, query_params):
super().__init__(params, Context.QUERY)
self.k = parse_int_parameter(self.PARAMS_NAME_K, params)
self.repetitions = parse_int_parameter(self.PARAMS_NAME_REPETITIONS, params, 1)
self.current_rep = 1
self.neighbors_data_set_format = parse_string_parameter(
self.PARAMS_NAME_NEIGHBORS_DATA_SET_FORMAT, params, self.data_set_format)
self.neighbors_data_set_path = parse_string_parameter(
self.PARAMS_NAME_NEIGHBORS_DATA_SET_PATH, params, self.data_set_path)
self.neighbors_data_set: DataSet = get_data_set(
self.neighbors_data_set_format, self.neighbors_data_set_path, Context.NEIGHBORS)
operation_type = parse_string_parameter(self.PARAMS_NAME_OPERATION_TYPE, params,
self.PARAMS_VALUE_VECTOR_SEARCH)
self.query_params = query_params
self.query_params.update({
self.PARAMS_NAME_K: self.k,
self.PARAMS_NAME_OPERATION_TYPE: operation_type,
self.PARAMS_NAME_ID_FIELD_NAME: params.get(self.PARAMS_NAME_ID_FIELD_NAME),
})

def _update_request_params(self):
request_params = self.query_params.get(self.PARAMS_NAME_REQUEST_PARAMS, {})
request_params[self.PARAMS_NAME_SOURCE] = request_params.get(
self.PARAMS_NAME_SOURCE, "false")
request_params[self.PARAMS_NAME_ALLOW_PARTIAL_RESULTS] = request_params.get(
self.PARAMS_NAME_ALLOW_PARTIAL_RESULTS, "false")
self.query_params.update({self.PARAMS_NAME_REQUEST_PARAMS: request_params})

def params(self):
"""
Returns: A query parameter with a vector and neighbor from a data set
"""
is_dataset_exhausted = self.current >= self.num_vectors + self.offset

if is_dataset_exhausted and self.current_rep < self.repetitions:
self.data_set.seek(self.offset)
self.current = self.offset
self.current_rep += 1
elif is_dataset_exhausted:
raise StopIteration
vector = self.data_set.read(1)[0]
neighbor = self.neighbors_data_set.read(1)[0]
true_neighbors = list(map(str, neighbor[:self.k]))
self.query_params.update({
"neighbors": true_neighbors,
})
self._update_request_params()

self.query_params.update({
self.PARAMS_NAME_BODY: self._build_vector_search_query_body(self.field_name, vector)})
self.current += 1
self.percent_completed = self.current / self.total
return self.query_params

def _build_vector_search_query_body(self, field_name: str, vector) -> dict:
"""Builds a k-NN request that can be used to execute an approximate nearest
neighbor search against a k-NN plugin index
Args:
field_name: name of field to search
vector: vector used for query
Returns:
A dictionary containing the body used for search query
"""
return {
"size": self.k,
"query": {
"knn": {
field_name: {
"vector": vector,
"k": self.k
}
}
}
}


def get_target(workload, params):
if len(workload.indices) == 1:
default_target = workload.indices[0].name
Expand Down Expand Up @@ -1215,6 +1411,7 @@ def read_bulk(self):

register_param_source_for_operation(workload.OperationType.Bulk, BulkIndexParamSource)
register_param_source_for_operation(workload.OperationType.Search, SearchParamSource)
register_param_source_for_operation(workload.OperationType.VectorSearch, VectorSearchParamSource)
register_param_source_for_operation(workload.OperationType.CreateIndex, CreateIndexParamSource)
register_param_source_for_operation(workload.OperationType.DeleteIndex, DeleteIndexParamSource)
register_param_source_for_operation(workload.OperationType.CreateDataStream, CreateDataStreamParamSource)
Expand Down
Loading

0 comments on commit 1dc9de5

Please sign in to comment.