Skip to content

Commit

Permalink
Add vector search param source
Browse files Browse the repository at this point in the history
Added new param source to partition vector dataset and
neighbors. This will be passed to runner to perform
search and compare response with neighbors for recall
calculation.

This param source extends Search ParamSource to inherit search's
other query parameters.
Vector Param Source will add additional paramter that are required
for vector serach operation type.

Signed-off-by: Vijayan Balasubramanian <[email protected]>
  • Loading branch information
VijayanB committed Dec 21, 2023
1 parent f7ae374 commit d4e227e
Show file tree
Hide file tree
Showing 2 changed files with 426 additions and 4 deletions.
186 changes: 183 additions & 3 deletions osbenchmark/workload/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,20 +23,22 @@
# under the License.

import collections
import copy
import inspect
import logging
import math
import numbers
import operator
import random
from abc import ABC

import time
from abc import ABC, abstractmethod
from enum import Enum

from osbenchmark import exceptions
from osbenchmark.workload import workload
from osbenchmark.utils import io
from osbenchmark.utils.dataset import DataSet, get_data_set, Context
from osbenchmark.utils.parse import parse_string_parameter, parse_int_parameter
from osbenchmark.workload import workload

__PARAM_SOURCES_BY_OP = {}
__PARAM_SOURCES_BY_NAME = {}
Expand Down Expand Up @@ -797,6 +799,183 @@ def params(self):
parsed_params.update(self._client_params())
return parsed_params


class VectorSearchParamSource(SearchParamSource):
def __init__(self, workload, params, **kwargs):
super().__init__(workload, params, **kwargs)
self.delegate_param_source = VectorSearchPartitionPartitionParamSource(params, self.query_params)

def partition(self, partition_index, total_partitions):
return self.delegate_param_source.partition(partition_index, total_partitions)


class VectorDataSetPartitionParamSource(ABC):
""" Abstract class that can read vectors from a data set and partition the
vectors across multiple clients.
Attributes:
field_name: Name of the field to generate the query for
data_set_format: Format data set is serialized with. bigann or hdf5
data_set_path: Path to data set
context: Context the data set will be used in.
data_set: Structure containing meta data about data and ability to read
num_vectors: Number of vectors to use from the data set
total: Number of vectors for the partition
current: Current vector offset in data set
infinite: Property of param source signalling that it can be exhausted
percent_completed: Progress indicator for how exhausted data set is
offset: Offset into the data set to start at. Relevant when there are
multiple partitions
"""

def __init__(self, params, context: Context):
self.field_name: str = parse_string_parameter("field", params)

self.context = context
self.data_set_format = parse_string_parameter("data_set_format", params)
self.data_set_path = parse_string_parameter("data_set_path", params)
self.data_set: DataSet = get_data_set(
self.data_set_format, self.data_set_path, self.context)

num_vectors: int = parse_int_parameter(
"num_vectors", params, self.data_set.size())
# if value is -1 or greater than dataset size, use dataset size as num_vectors
self.num_vectors = self.data_set.size() if (
num_vectors < 0 or num_vectors > self.data_set.size()) else num_vectors
self.total = self.num_vectors
self.current = 0
self.infinite = False
self.percent_completed = 0
self.offset = 0

def _is_last_partition(self, partition_index, total_partitions):
return partition_index == total_partitions - 1

def partition(self, partition_index, total_partitions):
"""
Splits up the parameters source so that multiple clients can read data
from it.
Args:
partition_index: index of one particular partition
total_partitions: total number of partitions data set is split into
Returns:
The parameter source for this particular partition
"""
partition_x = copy.copy(self)

num_vectors = int(self.num_vectors / total_partitions)

# if partition is not divided equally, add extra docs to the last partition
if self.num_vectors % total_partitions != 0 and self._is_last_partition(partition_index, total_partitions):
num_vectors += self.num_vectors - (num_vectors * total_partitions)

partition_x.num_vectors = num_vectors
partition_x.offset = int(partition_index * partition_x.num_vectors)
# We need to create a new instance of the data set for each client
partition_x.data_set = get_data_set(
self.data_set_format,
self.data_set_path,
self.context
)
partition_x.data_set.seek(partition_x.offset)
partition_x.current = partition_x.offset
return partition_x

@abstractmethod
def params(self):
"""
Returns: A single parameter from this source
"""


class VectorSearchPartitionPartitionParamSource(VectorDataSetPartitionParamSource):
""" Parameter source for k-NN. Queries are created from data set
provided.
Attributes:
k: The number of results to return for the search
repetitions: Number of times to re-run query dataset from beginning
"""
PARAMS_NAME_K = "k"
PARAMS_NAME_REPETITIONS = "repetitions"
PARAMS_NAME_NEIGHBORS_DATA_SET_FORMAT = "neighbors_data_set_format"
PARAMS_NAME_NEIGHBORS_DATA_SET_PATH = "neighbors_data_set_path"
PARAMS_NAME_OPERATION_TYPE = "operation-type"
PARAMS_VALUE_VECTOR_SEARCH = "vector-search"
PARAMS_NAME_ID_FIELD_NAME = "id_field_name"

def __init__(self, params, query_params):
super().__init__(params, Context.QUERY)
self.k = parse_int_parameter(self.PARAMS_NAME_K, params)
self.repetitions = parse_int_parameter(self.PARAMS_NAME_REPETITIONS, params, 1)
self.current_rep = 1
self.neighbors_data_set_format = parse_string_parameter(
self.PARAMS_NAME_NEIGHBORS_DATA_SET_FORMAT, params, self.data_set_format)
self.neighbors_data_set_path = parse_string_parameter(
self.PARAMS_NAME_NEIGHBORS_DATA_SET_PATH, params, self.data_set_path)
self.neighbors_data_set: DataSet = get_data_set(
self.neighbors_data_set_format, self.neighbors_data_set_path, Context.NEIGHBORS)
operation_type = parse_string_parameter(self.PARAMS_NAME_OPERATION_TYPE, params,
self.PARAMS_VALUE_VECTOR_SEARCH)
self.query_params = query_params
self.query_params.update({
self.PARAMS_NAME_K: self.k,
self.PARAMS_NAME_OPERATION_TYPE: operation_type,
self.PARAMS_NAME_ID_FIELD_NAME: params.get(self.PARAMS_NAME_ID_FIELD_NAME),
})

def params(self):
"""
Returns: A query parameter with a vector from a data set
"""
is_dataset_exhausted = self.current >= self.num_vectors + self.offset

if is_dataset_exhausted and self.current_rep < self.repetitions:
self.data_set.seek(self.offset)
self.current = self.offset
self.current_rep += 1
elif is_dataset_exhausted:
raise StopIteration
vector = self.data_set.read(1)[0]
neighbor = self.neighbors_data_set.read(1)[0]
true_neighbors = list(map(str, neighbor[:self.k]))
self.query_params.update({
"neighbors": true_neighbors,
"request-params": {
"_source": "false",
# we need to set it to true as this data source is used for actual queries
"allow_partial_search_results": "false"
}
})
self.query_params.update({
"body": self._build_vector_search_query_body(self.field_name, vector)})
self.current += 1
self.percent_completed = self.current / self.total
return self.query_params

def _build_vector_search_query_body(self, field_name: str, vector) -> dict:
"""Builds a k-NN request that can be used to execute an approximate nearest
neighbor search against a k-NN plugin index
Args:
field_name: name of field to search
vector: vector used for query
Returns:
A dictionary containing the body used for search query
"""
return {
"size": self.k,
"query": {
"knn": {
field_name: {
"vector": vector,
"k": self.k
}
}
}
}


def get_target(workload, params):
if len(workload.indices) == 1:
default_target = workload.indices[0].name
Expand Down Expand Up @@ -1215,6 +1394,7 @@ def read_bulk(self):

register_param_source_for_operation(workload.OperationType.Bulk, BulkIndexParamSource)
register_param_source_for_operation(workload.OperationType.Search, SearchParamSource)
register_param_source_for_operation(workload.OperationType.VectorSearch, VectorSearchParamSource)
register_param_source_for_operation(workload.OperationType.CreateIndex, CreateIndexParamSource)
register_param_source_for_operation(workload.OperationType.DeleteIndex, DeleteIndexParamSource)
register_param_source_for_operation(workload.OperationType.CreateDataStream, CreateDataStreamParamSource)
Expand Down
Loading

0 comments on commit d4e227e

Please sign in to comment.