Skip to content

Commit

Permalink
Accept body params from workload
Browse files Browse the repository at this point in the history
Will allow body params from workload except query parameter.
We will replace that with vector search query.
In future we will extend this query to support other
features like post filter, efficient filter, etc...

Signed-off-by: Vijayan Balasubramanian <[email protected]>
  • Loading branch information
VijayanB committed Dec 26, 2023
1 parent a523752 commit 6782594
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 17 deletions.
26 changes: 17 additions & 9 deletions osbenchmark/workload/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -907,6 +907,8 @@ class VectorSearchPartitionParamSource(VectorDataSetPartitionParamSource):
"""
PARAMS_NAME_K = "k"
PARAMS_NAME_BODY = "body"
PARAMS_NAME_SIZE = "size"
PARAMS_NAME_QUERY = "query"
PARAMS_NAME_REPETITIONS = "repetitions"
PARAMS_NAME_NEIGHBORS_DATA_SET_FORMAT = "neighbors_data_set_format"
PARAMS_NAME_NEIGHBORS_DATA_SET_PATH = "neighbors_data_set_path"
Expand All @@ -919,6 +921,7 @@ class VectorSearchPartitionParamSource(VectorDataSetPartitionParamSource):

def __init__(self, params, query_params):
super().__init__(params, Context.QUERY)
self.logger = logging.getLogger(__name__)
self.k = parse_int_parameter(self.PARAMS_NAME_K, params)
self.repetitions = parse_int_parameter(self.PARAMS_NAME_REPETITIONS, params, 1)
self.current_rep = 1
Expand All @@ -945,6 +948,17 @@ def _update_request_params(self):
self.PARAMS_NAME_ALLOW_PARTIAL_RESULTS, "false")
self.query_params.update({self.PARAMS_NAME_REQUEST_PARAMS: request_params})

def _update_body_params(self, vector):
# accept body params if passed from workload, else, create empty dictionary
body_params = self.query_params.get(self.PARAMS_NAME_BODY) or dict()
if self.PARAMS_NAME_SIZE not in body_params:
body_params[self.PARAMS_NAME_SIZE] = self.k
if self.PARAMS_NAME_QUERY in body_params:
self.logger.warning("[%s] param from body will be replaced with vector search query.", self.PARAMS_NAME_QUERY)
# override query params with vector search query
body_params[self.PARAMS_NAME_QUERY] = self._build_vector_search_query_body(vector)
self.query_params.update({self.PARAMS_NAME_BODY: body_params})

def params(self):
"""
Returns: A query parameter with a vector and neighbor from a data set
Expand All @@ -964,33 +978,27 @@ def params(self):
"neighbors": true_neighbors,
})
self._update_request_params()

self.query_params.update({
self.PARAMS_NAME_BODY: self._build_vector_search_query_body(self.field_name, vector)})
self._update_body_params(vector)
self.current += 1
self.percent_completed = self.current / self.total
return self.query_params

def _build_vector_search_query_body(self, field_name: str, vector) -> dict:
def _build_vector_search_query_body(self, vector) -> dict:
"""Builds a k-NN request that can be used to execute an approximate nearest
neighbor search against a k-NN plugin index
Args:
field_name: name of field to search
vector: vector used for query
Returns:
A dictionary containing the body used for search query
"""
return {
"size": self.k,
"query": {
"knn": {
field_name: {
self.field_name: {
"vector": vector,
"k": self.k
}
}
}
}


def get_target(workload, params):
Expand Down
72 changes: 64 additions & 8 deletions tests/workload/params_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2652,7 +2652,7 @@ def setUp(self) -> None:
def tearDown(self):
shutil.rmtree(self.data_set_dir)

def test_params(self):
def test_params_default(self):
# Create a data set
k = 12
data_set_path = create_data_set(
Expand All @@ -2676,10 +2676,62 @@ def test_params(self):
"data_set_format": self.DEFAULT_TYPE,
"data_set_path": data_set_path,
"neighbors_data_set_path": neighbors_data_set_path,
"k": k,
"k": k
}
query_param_source = VectorSearchPartitionParamSource(
test_param_source_params, {"index": self.DEFAULT_INDEX_NAME, "request-params": {}}
test_param_source_params, {
"index": self.DEFAULT_INDEX_NAME,
"request-params": {},
}
)

# Check each
for _ in range(DEFAULT_NUM_VECTORS):
self._check_params(
query_param_source.params(),
self.DEFAULT_FIELD_NAME,
self.DEFAULT_DIMENSION,
k,
)

# Assert last call creates stop iteration
with self.assertRaises(StopIteration):
query_param_source.params()

def test_params_custom_body(self):
# Create a data set
k = 12
data_set_path = create_data_set(
self.DEFAULT_NUM_VECTORS,
self.DEFAULT_DIMENSION,
self.DEFAULT_TYPE,
Context.QUERY,
self.data_set_dir
)
neighbors_data_set_path = create_data_set(
self.DEFAULT_NUM_VECTORS,
self.DEFAULT_DIMENSION,
self.DEFAULT_TYPE,
Context.NEIGHBORS,
self.data_set_dir
)

# Create a QueryVectorsFromDataSetParamSource with relevant params
test_param_source_params = {
"field": self.DEFAULT_FIELD_NAME,
"data_set_format": self.DEFAULT_TYPE,
"data_set_path": data_set_path,
"neighbors_data_set_path": neighbors_data_set_path,
"k": k
}
query_param_source = VectorSearchPartitionParamSource(
test_param_source_params, {
"index": self.DEFAULT_INDEX_NAME,
"request-params": {},
"body": {
"size": 100,
}
}
)

# Check each
Expand All @@ -2688,7 +2740,8 @@ def test_params(self):
query_param_source.params(),
self.DEFAULT_FIELD_NAME,
self.DEFAULT_DIMENSION,
k
k,
100,
)

# Assert last call creates stop iteration
Expand All @@ -2697,12 +2750,13 @@ def test_params(self):

def _check_params(
self,
params: dict,
actual_params: dict,
expected_field: str,
expected_dimension: int,
expected_k: int
expected_k: int,
expected_size=None,
):
body = params.get("body")
body = actual_params.get("body")
self.assertIsInstance(body, dict)
query = body.get("query")
self.assertIsInstance(query, dict)
Expand All @@ -2715,6 +2769,8 @@ def _check_params(
self.assertEqual(len(list(vector)), expected_dimension)
k = field.get("k")
self.assertEqual(k, expected_k)
neighbor = params.get("neighbors")
neighbor = actual_params.get("neighbors")
self.assertIsInstance(neighbor, list)
self.assertEqual(len(neighbor), expected_dimension)
size = body.get("size")
self.assertEqual(size, expected_size if expected_size else expected_k)

0 comments on commit 6782594

Please sign in to comment.