From 6782594cbfa0b42b6fd7b0907cf2d4c522352133 Mon Sep 17 00:00:00 2001 From: Vijayan Balasubramanian Date: Tue, 26 Dec 2023 14:06:56 -0800 Subject: [PATCH] Accept body params from workload Allow body params to be passed from the workload, except for the query parameter, which will be replaced with the vector search query. In the future, we will extend this query to support other features such as post filter, efficient filter, etc. Signed-off-by: Vijayan Balasubramanian --- osbenchmark/workload/params.py | 26 +++++++----- tests/workload/params_test.py | 72 ++++++++++++++++++++++++++++++---- 2 files changed, 81 insertions(+), 17 deletions(-) diff --git a/osbenchmark/workload/params.py b/osbenchmark/workload/params.py index ec04f320c..0dd59cc3a 100644 --- a/osbenchmark/workload/params.py +++ b/osbenchmark/workload/params.py @@ -907,6 +907,8 @@ class VectorSearchPartitionParamSource(VectorDataSetPartitionParamSource): """ PARAMS_NAME_K = "k" PARAMS_NAME_BODY = "body" + PARAMS_NAME_SIZE = "size" + PARAMS_NAME_QUERY = "query" PARAMS_NAME_REPETITIONS = "repetitions" PARAMS_NAME_NEIGHBORS_DATA_SET_FORMAT = "neighbors_data_set_format" PARAMS_NAME_NEIGHBORS_DATA_SET_PATH = "neighbors_data_set_path" @@ -919,6 +921,7 @@ class VectorSearchPartitionParamSource(VectorDataSetPartitionParamSource): def __init__(self, params, query_params): super().__init__(params, Context.QUERY) + self.logger = logging.getLogger(__name__) self.k = parse_int_parameter(self.PARAMS_NAME_K, params) self.repetitions = parse_int_parameter(self.PARAMS_NAME_REPETITIONS, params, 1) self.current_rep = 1 @@ -945,6 +948,17 @@ def _update_request_params(self): self.PARAMS_NAME_ALLOW_PARTIAL_RESULTS, "false") self.query_params.update({self.PARAMS_NAME_REQUEST_PARAMS: request_params}) + def _update_body_params(self, vector): + # accept body params if passed from workload, else, create empty dictionary + body_params = self.query_params.get(self.PARAMS_NAME_BODY) or dict() + if self.PARAMS_NAME_SIZE not in body_params: + 
body_params[self.PARAMS_NAME_SIZE] = self.k + if self.PARAMS_NAME_QUERY in body_params: + self.logger.warning("[%s] param from body will be replaced with vector search query.", self.PARAMS_NAME_QUERY) + # override query params with vector search query + body_params[self.PARAMS_NAME_QUERY] = self._build_vector_search_query_body(vector) + self.query_params.update({self.PARAMS_NAME_BODY: body_params}) + def params(self): """ Returns: A query parameter with a vector and neighbor from a data set @@ -964,33 +978,27 @@ def params(self): "neighbors": true_neighbors, }) self._update_request_params() - - self.query_params.update({ - self.PARAMS_NAME_BODY: self._build_vector_search_query_body(self.field_name, vector)}) + self._update_body_params(vector) self.current += 1 self.percent_completed = self.current / self.total return self.query_params - def _build_vector_search_query_body(self, field_name: str, vector) -> dict: + def _build_vector_search_query_body(self, vector) -> dict: """Builds a k-NN request that can be used to execute an approximate nearest neighbor search against a k-NN plugin index Args: - field_name: name of field to search vector: vector used for query Returns: A dictionary containing the body used for search query """ return { - "size": self.k, - "query": { "knn": { - field_name: { + self.field_name: { "vector": vector, "k": self.k } } } - } def get_target(workload, params): diff --git a/tests/workload/params_test.py b/tests/workload/params_test.py index 211cfbed5..9e734e717 100644 --- a/tests/workload/params_test.py +++ b/tests/workload/params_test.py @@ -2652,7 +2652,7 @@ def setUp(self) -> None: def tearDown(self): shutil.rmtree(self.data_set_dir) - def test_params(self): + def test_params_default(self): # Create a data set k = 12 data_set_path = create_data_set( @@ -2676,10 +2676,62 @@ def test_params(self): "data_set_format": self.DEFAULT_TYPE, "data_set_path": data_set_path, "neighbors_data_set_path": neighbors_data_set_path, - "k": k, + "k": k } 
query_param_source = VectorSearchPartitionParamSource( - test_param_source_params, {"index": self.DEFAULT_INDEX_NAME, "request-params": {}} + test_param_source_params, { + "index": self.DEFAULT_INDEX_NAME, + "request-params": {}, + } + ) + + # Check each + for _ in range(DEFAULT_NUM_VECTORS): + self._check_params( + query_param_source.params(), + self.DEFAULT_FIELD_NAME, + self.DEFAULT_DIMENSION, + k, + ) + + # Assert last call creates stop iteration + with self.assertRaises(StopIteration): + query_param_source.params() + + def test_params_custom_body(self): + # Create a data set + k = 12 + data_set_path = create_data_set( + self.DEFAULT_NUM_VECTORS, + self.DEFAULT_DIMENSION, + self.DEFAULT_TYPE, + Context.QUERY, + self.data_set_dir + ) + neighbors_data_set_path = create_data_set( + self.DEFAULT_NUM_VECTORS, + self.DEFAULT_DIMENSION, + self.DEFAULT_TYPE, + Context.NEIGHBORS, + self.data_set_dir + ) + + # Create a QueryVectorsFromDataSetParamSource with relevant params + test_param_source_params = { + "field": self.DEFAULT_FIELD_NAME, + "data_set_format": self.DEFAULT_TYPE, + "data_set_path": data_set_path, + "neighbors_data_set_path": neighbors_data_set_path, + "k": k + } + query_param_source = VectorSearchPartitionParamSource( + test_param_source_params, { + "index": self.DEFAULT_INDEX_NAME, + "request-params": {}, + "body": { + "size": 100, + } + } ) # Check each @@ -2688,7 +2740,8 @@ def test_params(self): query_param_source.params(), self.DEFAULT_FIELD_NAME, self.DEFAULT_DIMENSION, - k + k, + 100, ) # Assert last call creates stop iteration @@ -2697,12 +2750,13 @@ def test_params(self): def _check_params( self, - params: dict, + actual_params: dict, expected_field: str, expected_dimension: int, - expected_k: int + expected_k: int, + expected_size=None, ): - body = params.get("body") + body = actual_params.get("body") self.assertIsInstance(body, dict) query = body.get("query") self.assertIsInstance(query, dict) @@ -2715,6 +2769,8 @@ def _check_params( 
self.assertEqual(len(list(vector)), expected_dimension) k = field.get("k") self.assertEqual(k, expected_k) - neighbor = params.get("neighbors") + neighbor = actual_params.get("neighbors") self.assertIsInstance(neighbor, list) self.assertEqual(len(neighbor), expected_dimension) + size = body.get("size") + self.assertEqual(size, expected_size if expected_size else expected_k)