Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support to parse response if fields or _source are enabled #427

Merged
merged 3 commits into from
Dec 29, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions osbenchmark/worker_coordinator/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -993,6 +993,16 @@ def _is_empty_search_results(content):
return True
return False

def _get_field_value(content, field_name):
if field_name in content: # Will add to candidates if field value is present
return content[field_name]
# if fields are used in request params to return id_field's value
if "fields" in content and id_field in content["fields"]:
return content["fields"][id_field][0] # fields returns always an array
if "_source" in content: # if source is not disabled, retrieve value from source
return _get_field_value(content["_source"], field_name)
return None

def calculate_recall(predictions, neighbors, top_k):
"""
Calculates the recall by comparing top_k neighbors with predictions.
Expand Down Expand Up @@ -1044,8 +1054,11 @@ def calculate_recall(predictions, neighbors, top_k):
id_field = params.get("id-field-name", "_id")
candidates = []
for hit in response_json['hits']['hits']:
if id_field in hit: # Will add to candidates if field value is present
candidates.append(hit[id_field])
field_value = _get_field_value(hit, id_field)
if field_value is None: # Will add to candidates if field value is present
self.logger.warning("No value found for field %s", id_field)
VijayanB marked this conversation as resolved.
Show resolved Hide resolved
continue
candidates.append(field_value)
neighbors_dataset = params["neighbors"]
num_neighbors = params.get("k", 1)
recall_k = calculate_recall(candidates, neighbors_dataset, num_neighbors)
Expand Down
179 changes: 178 additions & 1 deletion tests/worker_coordinator/runner_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2207,6 +2207,8 @@ async def test_search_pipeline_using_request_params(self, opensearch):
)
opensearch.clear_scroll.assert_not_called()


class VectorSearchQueryRunnerTests(TestCase):
@mock.patch("opensearchpy.OpenSearch")
@run_async
async def test_query_vector_search_with_perfect_recall(self, opensearch):
Expand Down Expand Up @@ -2338,7 +2340,6 @@ async def test_query_vector_search_with_no_results(self, opensearch):
headers={"Accept-Encoding": "identity"}
)


@mock.patch("opensearchpy.OpenSearch")
@run_async
async def test_query_vector_search_with_imperfect_recall(self, opensearch):
Expand Down Expand Up @@ -2559,6 +2560,182 @@ async def test_query_vector_search_with_zero_recall_1(self, opensearch):
headers={"Accept-Encoding": "identity"}
)

@mock.patch("opensearchpy.OpenSearch")
@run_async
async def test_query_vector_search_with_custom_id_field(self, opensearch):
search_response = {
"timed_out": False,
"took": 5,
"hits": {
"total": {
"value": 3,
"relation": "eq"
},
"hits": [
{
"_id": "random-id1",
"_score": 0.95,
"fields": {
"id": [0]
}
},
{
"_id": "random-id2",
"_score": 0.88,
"fields": {
"id": [1]
}
},
{
"_id": "random-id3",
"_score": 0.1,
"fields": {
"id": [2]
}
}
]
}
}
opensearch.transport.perform_request.return_value = as_future(io.StringIO(json.dumps(search_response)))

query_runner = runner.Query()

params = {
"index": "unittest",
"operation-type": "vector-search",
"detailed-results": True,
"response-compression-enabled": False,
"id-field-name": "id",
"k": 3,
"neighbors": [0, 1, 2],
"request-params": {
"docvalue_fields": "id",
"_source": False,
},
"body": {
"query": {
"knn": {
"location": {
"vector": [
5,
4
],
"k": 3
}
}}
}
}

async with query_runner:
result = await query_runner(opensearch, params)

self.assertEqual(1, result["weight"])
self.assertEqual("ops", result["unit"])
self.assertEqual(3, result["hits"])
self.assertEqual("eq", result["hits_relation"])
self.assertFalse(result["timed_out"])
self.assertEqual(5, result["took"])
self.assertIn("recall_time_ms", result.keys())
self.assertIn("recall@k", result.keys())
self.assertEqual(result["recall@k"], 1.0)
self.assertIn("recall@1", result.keys())
self.assertEqual(result["recall@1"], 1.0)
self.assertNotIn("error-type", result.keys())

opensearch.transport.perform_request.assert_called_once_with(
"GET",
"/unittest/_search",
params={'docvalue_fields': "id", "_source": False},
body=params["body"],
headers={"Accept-Encoding": "identity"}
)

@mock.patch("opensearchpy.OpenSearch")
@run_async
async def test_query_vector_search_with_custom_id_field_inside_source(self, opensearch):
search_response = {
"timed_out": False,
"took": 5,
"hits": {
"total": {
"value": 3,
"relation": "eq"
},
"hits": [
{
"_id": "random-id1",
"_score": 0.95,
"_source": {
"id": "101"
}
},
{
"_id": "random-id2",
"_score": 0.88,
"_source": {
"id": "102"
}
},
{
"_id": "random-id3",
"_score": 0.1,
"_source": {
"id": "103",
}
}
]
}
}
opensearch.transport.perform_request.return_value = as_future(io.StringIO(json.dumps(search_response)))

query_runner = runner.Query()

params = {
"index": "unittest",
"operation-type": "vector-search",
"detailed-results": True,
"response-compression-enabled": False,
"id-field-name": "id",
"k": 3,
"neighbors": ["101", "102", "103"],
"body": {
"query": {
"knn": {
"location": {
"vector": [
5,
4
],
"k": 3
}
}}
}
}

async with query_runner:
result = await query_runner(opensearch, params)

self.assertEqual(1, result["weight"])
self.assertEqual("ops", result["unit"])
self.assertEqual(3, result["hits"])
self.assertEqual("eq", result["hits_relation"])
self.assertFalse(result["timed_out"])
self.assertEqual(5, result["took"])
self.assertIn("recall_time_ms", result.keys())
self.assertIn("recall@k", result.keys())
self.assertEqual(result["recall@k"], 1.0)
self.assertIn("recall@1", result.keys())
self.assertEqual(result["recall@1"], 1.0)
self.assertNotIn("error-type", result.keys())

opensearch.transport.perform_request.assert_called_once_with(
"GET",
"/unittest/_search",
params={},
body=params["body"],
headers={"Accept-Encoding": "identity"}
)


class CreateIngestPipelineRunnerTests(TestCase):
@mock.patch("opensearchpy.OpenSearch")
Expand Down
Loading