Add support to parse response if fields or _source are enabled (#427)
Signed-off-by: Vijayan Balasubramanian <[email protected]>
VijayanB authored Dec 29, 2023
1 parent 1dc9de5 commit b861175
Showing 4 changed files with 274 additions and 20 deletions.
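
The change targets workloads whose vector-search operations disable _source or fetch the custom id through docvalue fields, so the id no longer appears as a top-level hit field. A hedged sketch of such operation parameters, modeled on the new unit tests further down (the index name and neighbor values are illustrative, not from the commit):

# Illustrative vector-search operation parameters.
# "_source": False suppresses the document source, while "docvalue_fields"
# asks OpenSearch to echo the custom id field back under each hit's "fields" key.
params = {
    "index": "target_index",            # illustrative index name
    "operation-type": "vector-search",
    "id-field-name": "id",              # custom id field instead of "_id"
    "k": 3,
    "neighbors": [0, 1, 2],             # ground-truth neighbor ids
    "request-params": {
        "docvalue_fields": "id",
        "_source": False,
    },
}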
17 changes: 15 additions & 2 deletions osbenchmark/worker_coordinator/runner.py
@@ -993,6 +993,16 @@ def _is_empty_search_results(content):
        return True
    return False

def _get_field_value(content, field_name):
    if field_name in content:  # will add to candidates if the field value is present
        return content[field_name]
    # if docvalue fields are requested via request params, the value is returned under "fields"
    if "fields" in content and field_name in content["fields"]:
        return content["fields"][field_name][0]  # "fields" always returns an array
    if "_source" in content:  # if _source is not disabled, retrieve the value from the source
        return _get_field_value(content["_source"], field_name)
    return None

def calculate_recall(predictions, neighbors, top_k):
    """
    Calculates the recall by comparing top_k neighbors with predictions.
@@ -1044,8 +1054,11 @@ def calculate_recall(predictions, neighbors, top_k):
id_field = params.get("id-field-name", "_id")
candidates = []
for hit in response_json['hits']['hits']:
    if id_field in hit: # Will add to candidates if field value is present
        candidates.append(hit[id_field])
    field_value = _get_field_value(hit, id_field)
    if field_value is None:  # skip the hit if no value could be resolved for the id field
        self.logger.warning("No value found for field %s", id_field)
        continue
    candidates.append(field_value)
neighbors_dataset = params["neighbors"]
num_neighbors = params.get("k", 1)
recall_k = calculate_recall(candidates, neighbors_dataset, num_neighbors)
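
To illustrate the resolution order the runner.py change introduces, here is a standalone sketch, not the module code itself: a top-level field wins, then the docvalue "fields" section, then a recursive look inside "_source". The get_field_value name and the sample hits are illustrative, mirroring the fixtures in the new tests.

def get_field_value(hit, field_name):
    """Resolve field_name from a search hit: top level, then 'fields', then '_source'."""
    if field_name in hit:
        return hit[field_name]
    if "fields" in hit and field_name in hit["fields"]:
        return hit["fields"][field_name][0]  # docvalue fields come back as arrays
    if "_source" in hit:
        return get_field_value(hit["_source"], field_name)
    return None


hits = [
    {"_id": "a", "id": 7},                # value stored on the hit itself
    {"_id": "b", "fields": {"id": [8]}},  # value returned via docvalue_fields
    {"_id": "c", "_source": {"id": 9}},   # value read from the document source
    {"_id": "d"},                         # nothing to resolve -> None
]
print([get_field_value(h, "id") for h in hits])  # [7, 8, 9, None]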
26 changes: 17 additions & 9 deletions osbenchmark/workload/params.py
@@ -907,6 +907,8 @@ class VectorSearchPartitionParamSource(VectorDataSetPartitionParamSource):
    """
    PARAMS_NAME_K = "k"
    PARAMS_NAME_BODY = "body"
    PARAMS_NAME_SIZE = "size"
    PARAMS_NAME_QUERY = "query"
    PARAMS_NAME_REPETITIONS = "repetitions"
    PARAMS_NAME_NEIGHBORS_DATA_SET_FORMAT = "neighbors_data_set_format"
    PARAMS_NAME_NEIGHBORS_DATA_SET_PATH = "neighbors_data_set_path"
@@ -919,6 +921,7 @@ class VectorSearchPartitionParamSource(VectorDataSetPartitionParamSource):

    def __init__(self, params, query_params):
        super().__init__(params, Context.QUERY)
        self.logger = logging.getLogger(__name__)
        self.k = parse_int_parameter(self.PARAMS_NAME_K, params)
        self.repetitions = parse_int_parameter(self.PARAMS_NAME_REPETITIONS, params, 1)
        self.current_rep = 1
@@ -945,6 +948,17 @@ def _update_request_params(self):
            self.PARAMS_NAME_ALLOW_PARTIAL_RESULTS, "false")
        self.query_params.update({self.PARAMS_NAME_REQUEST_PARAMS: request_params})

    def _update_body_params(self, vector):
        # accept body params if passed from the workload, else create an empty dictionary
        body_params = self.query_params.get(self.PARAMS_NAME_BODY) or dict()
        if self.PARAMS_NAME_SIZE not in body_params:
            body_params[self.PARAMS_NAME_SIZE] = self.k
        if self.PARAMS_NAME_QUERY in body_params:
            self.logger.warning("[%s] param from body will be replaced with vector search query.", self.PARAMS_NAME_QUERY)
        # override the query param with the vector search query
        body_params[self.PARAMS_NAME_QUERY] = self._build_vector_search_query_body(vector)
        self.query_params.update({self.PARAMS_NAME_BODY: body_params})

    def params(self):
        """
        Returns: A query parameter with a vector and neighbor from a data set
@@ -964,33 +978,27 @@ def params(self):
            "neighbors": true_neighbors,
        })
        self._update_request_params()

        self.query_params.update({
            self.PARAMS_NAME_BODY: self._build_vector_search_query_body(self.field_name, vector)})
        self._update_body_params(vector)
        self.current += 1
        self.percent_completed = self.current / self.total
        return self.query_params

    def _build_vector_search_query_body(self, field_name: str, vector) -> dict:
    def _build_vector_search_query_body(self, vector) -> dict:
        """Builds a k-NN request that can be used to execute an approximate nearest
        neighbor search against a k-NN plugin index
        Args:
            field_name: name of field to search
            vector: vector used for query
        Returns:
            A dictionary containing the body used for search query
        """
        return {
            "size": self.k,
            "query": {
                "knn": {
                    field_name: {
                    self.field_name: {
                        "vector": vector,
                        "k": self.k
                    }
                }
            }
        }


def get_target(workload, params):
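
The params.py change routes body construction through _update_body_params. A condensed, standalone sketch of that merge behavior under the same rules (the merge_body helper and the sample values are illustrative): keep whatever body the workload supplies, default size to k, and always replace query with the generated k-NN clause.

import logging

logger = logging.getLogger(__name__)


def merge_body(user_body, knn_query, k):
    """Keep the workload's body, default 'size' to k, replace 'query' with the k-NN query."""
    body = user_body or {}
    body.setdefault("size", k)
    if "query" in body:
        logger.warning("[query] param from body will be replaced with vector search query.")
    body["query"] = knn_query
    return body


knn_query = {"knn": {"location": {"vector": [5, 4], "k": 3}}}  # illustrative field and vector
print(merge_body({"size": 10, "docvalue_fields": ["id"]}, knn_query, k=3))
# the workload's size and docvalue_fields survive; 'query' is overridden with knn_query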
179 changes: 178 additions & 1 deletion tests/worker_coordinator/runner_test.py
@@ -2207,6 +2207,8 @@ async def test_search_pipeline_using_request_params(self, opensearch):
        )
        opensearch.clear_scroll.assert_not_called()


class VectorSearchQueryRunnerTests(TestCase):
    @mock.patch("opensearchpy.OpenSearch")
    @run_async
    async def test_query_vector_search_with_perfect_recall(self, opensearch):
@@ -2338,7 +2340,6 @@ async def test_query_vector_search_with_no_results(self, opensearch):
            headers={"Accept-Encoding": "identity"}
        )


    @mock.patch("opensearchpy.OpenSearch")
    @run_async
    async def test_query_vector_search_with_imperfect_recall(self, opensearch):
@@ -2559,6 +2560,182 @@ async def test_query_vector_search_with_zero_recall_1(self, opensearch):
            headers={"Accept-Encoding": "identity"}
        )

    @mock.patch("opensearchpy.OpenSearch")
    @run_async
    async def test_query_vector_search_with_custom_id_field(self, opensearch):
        search_response = {
            "timed_out": False,
            "took": 5,
            "hits": {
                "total": {
                    "value": 3,
                    "relation": "eq"
                },
                "hits": [
                    {
                        "_id": "random-id1",
                        "_score": 0.95,
                        "fields": {
                            "id": [0]
                        }
                    },
                    {
                        "_id": "random-id2",
                        "_score": 0.88,
                        "fields": {
                            "id": [1]
                        }
                    },
                    {
                        "_id": "random-id3",
                        "_score": 0.1,
                        "fields": {
                            "id": [2]
                        }
                    }
                ]
            }
        }
        opensearch.transport.perform_request.return_value = as_future(io.StringIO(json.dumps(search_response)))

        query_runner = runner.Query()

        params = {
            "index": "unittest",
            "operation-type": "vector-search",
            "detailed-results": True,
            "response-compression-enabled": False,
            "id-field-name": "id",
            "k": 3,
            "neighbors": [0, 1, 2],
            "request-params": {
                "docvalue_fields": "id",
                "_source": False,
            },
            "body": {
                "query": {
                    "knn": {
                        "location": {
                            "vector": [
                                5,
                                4
                            ],
                            "k": 3
                        }
                    }
                }
            }
        }

        async with query_runner:
            result = await query_runner(opensearch, params)

        self.assertEqual(1, result["weight"])
        self.assertEqual("ops", result["unit"])
        self.assertEqual(3, result["hits"])
        self.assertEqual("eq", result["hits_relation"])
        self.assertFalse(result["timed_out"])
        self.assertEqual(5, result["took"])
        self.assertIn("recall_time_ms", result.keys())
        self.assertIn("recall@k", result.keys())
        self.assertEqual(result["recall@k"], 1.0)
        self.assertIn("recall@1", result.keys())
        self.assertEqual(result["recall@1"], 1.0)
        self.assertNotIn("error-type", result.keys())

        opensearch.transport.perform_request.assert_called_once_with(
            "GET",
            "/unittest/_search",
            params={'docvalue_fields': "id", "_source": False},
            body=params["body"],
            headers={"Accept-Encoding": "identity"}
        )

    @mock.patch("opensearchpy.OpenSearch")
    @run_async
    async def test_query_vector_search_with_custom_id_field_inside_source(self, opensearch):
        search_response = {
            "timed_out": False,
            "took": 5,
            "hits": {
                "total": {
                    "value": 3,
                    "relation": "eq"
                },
                "hits": [
                    {
                        "_id": "random-id1",
                        "_score": 0.95,
                        "_source": {
                            "id": "101"
                        }
                    },
                    {
                        "_id": "random-id2",
                        "_score": 0.88,
                        "_source": {
                            "id": "102"
                        }
                    },
                    {
                        "_id": "random-id3",
                        "_score": 0.1,
                        "_source": {
                            "id": "103"
                        }
                    }
                ]
            }
        }
        opensearch.transport.perform_request.return_value = as_future(io.StringIO(json.dumps(search_response)))

        query_runner = runner.Query()

        params = {
            "index": "unittest",
            "operation-type": "vector-search",
            "detailed-results": True,
            "response-compression-enabled": False,
            "id-field-name": "id",
            "k": 3,
            "neighbors": ["101", "102", "103"],
            "body": {
                "query": {
                    "knn": {
                        "location": {
                            "vector": [
                                5,
                                4
                            ],
                            "k": 3
                        }
                    }
                }
            }
        }

        async with query_runner:
            result = await query_runner(opensearch, params)

        self.assertEqual(1, result["weight"])
        self.assertEqual("ops", result["unit"])
        self.assertEqual(3, result["hits"])
        self.assertEqual("eq", result["hits_relation"])
        self.assertFalse(result["timed_out"])
        self.assertEqual(5, result["took"])
        self.assertIn("recall_time_ms", result.keys())
        self.assertIn("recall@k", result.keys())
        self.assertEqual(result["recall@k"], 1.0)
        self.assertIn("recall@1", result.keys())
        self.assertEqual(result["recall@1"], 1.0)
        self.assertNotIn("error-type", result.keys())

        opensearch.transport.perform_request.assert_called_once_with(
            "GET",
            "/unittest/_search",
            params={},
            body=params["body"],
            headers={"Accept-Encoding": "identity"}
        )


class CreateIngestPipelineRunnerTests(TestCase):
    @mock.patch("opensearchpy.OpenSearch")
