Skip to content

Commit

Permalink
Resolve feedbacks and add tests
Browse files Browse the repository at this point in the history
Signed-off-by: Junqiu Lei <[email protected]>
  • Loading branch information
junqiu-lei committed Jul 23, 2024
1 parent 0f2ad7e commit dc2a97d
Show file tree
Hide file tree
Showing 3 changed files with 112 additions and 18 deletions.
28 changes: 16 additions & 12 deletions osbenchmark/worker_coordinator/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1318,32 +1318,33 @@ def calculate_radial_search_recall(predictions, neighbors, enable_top_1_recall=F

return float(correct) / min_num_of_results

doc_type = params.get("type")
response = await self._raw_search(opensearch, doc_type, index, body, request_params, headers=headers)

result = {
"weight": 1,
"unit": "ops",
"success": True,
}

if params.get("k"):
# Add recall@k and recall@1 to the initial result only if k is present in the params
if "k" in params:
result.update({
"recall@k": 0,
"recall@1": 0
})
elif params.get("max_distance"):
# Add recall@max_distance and recall@max_distance_1 to the initial result only if max_distance is present in the params
elif "max_distance" in params:
result.update({
"recall@max_distance": 0,
"recall@max_distance_1": 0
})
elif params.get("min_score"):
# Add recall@min_score and recall@min_score_1 to the initial result only if min_score is present in the params
elif "min_score" in params:
result.update({
"recall@min_score": 0,
"recall@min_score_1": 0
})

recall_processing_start = time.perf_counter()
doc_type = params.get("type")
response = await self._raw_search(opensearch, doc_type, index, body, request_params, headers=headers)

if detailed_results:
props = parse(response, ["hits.total", "hits.total.value", "hits.total.relation", "timed_out", "took"])
hits_total = props.get("hits.total.value", props.get("hits.total", 0))
Expand All @@ -1357,6 +1358,8 @@ def calculate_radial_search_recall(predictions, neighbors, enable_top_1_recall=F
"timed_out": timed_out,
"took": took
})

recall_processing_start = time.perf_counter()
response_json = json.loads(response.getvalue())
if _is_empty_search_results(response_json):
self.logger.info("Vector search query returned no results.")
Expand All @@ -1371,19 +1374,20 @@ def calculate_radial_search_recall(predictions, neighbors, enable_top_1_recall=F
candidates.append(field_value)
neighbors_dataset = params["neighbors"]

if params.get("k"):
if "k" in params:
num_neighbors = params.get("k", 1)
recall_top_k = calculate_topk_search_recall(candidates, neighbors_dataset, num_neighbors)
recall_top_1 = calculate_topk_search_recall(candidates, neighbors_dataset, 1)
result.update({"recall@k": recall_top_k})
result.update({"recall@1": recall_top_1})
else:

if "max_distance" in params or "min_score" in params:
recall_threshold = calculate_radial_search_recall(candidates, neighbors_dataset)
recall_top_1 = calculate_radial_search_recall(candidates, neighbors_dataset, True)
if params.get("min_score"):
if "min_score" in params:
result.update({"recall@min_score": recall_threshold})
result.update({"recall@min_score_1": recall_top_1})
elif params.get("max_distance"):
elif "max_distance" in params:
result.update({"recall@max_distance": recall_threshold})
result.update({"recall@max_distance_1": recall_top_1})

Expand Down
12 changes: 6 additions & 6 deletions osbenchmark/workload/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -1046,7 +1046,7 @@ class VectorSearchPartitionParamSource(VectorDataSetPartitionParamSource):
MIN_SCORE_QUERY_TYPE = "min_score"
MAX_DISTANCE_QUERY_TYPE = "max_distance"
KNN_QUERY_TYPE = "knn"
RADIAL_SEARCH_QUERY_RESULT_SIZE = 10000
DEFAULT_RADIAL_SEARCH_QUERY_RESULT_SIZE = 10000

def __init__(self, workloads, params, query_params, **kwargs):
super().__init__(workloads, params, Context.QUERY, **kwargs)
Expand Down Expand Up @@ -1124,12 +1124,12 @@ def _update_request_params(self):
def _get_query_neighbors(self):
if self.query_type == self.KNN_QUERY_TYPE:
return Context.NEIGHBORS
elif self.query_type == self.MIN_SCORE_QUERY_TYPE:
if self.query_type == self.MIN_SCORE_QUERY_TYPE:
return Context.MIN_SCORE_NEIGHBORS
elif self.query_type == self.MAX_DISTANCE_QUERY_TYPE:
if self.query_type == self.MAX_DISTANCE_QUERY_TYPE:
return Context.MAX_DISTANCE_NEIGHBORS
else:
raise exceptions.InvalidSyntax("Unknown query type [%s]" % self.query_type)
raise Exception("Unknown query type [%s]" % self.query_type)

def _update_body_params(self, vector):
# accept body params if passed from workload, else, create empty dictionary
Expand All @@ -1138,7 +1138,7 @@ def _update_body_params(self, vector):
if self.query_type == self.KNN_QUERY_TYPE:
body_params[self.PARAMS_NAME_SIZE] = self.k
else:
body_params[self.PARAMS_NAME_SIZE] = self.RADIAL_SEARCH_QUERY_RESULT_SIZE
body_params[self.PARAMS_NAME_SIZE] = self.DEFAULT_RADIAL_SEARCH_QUERY_RESULT_SIZE
if self.PARAMS_NAME_QUERY in body_params:
self.logger.warning(
"[%s] param from body will be replaced with vector search query.", self.PARAMS_NAME_QUERY)
Expand Down Expand Up @@ -1213,7 +1213,7 @@ def _build_vector_search_query_body(self, vector, efficient_filter=None) -> dict
"k": self.k,
})
else:
raise exceptions.InvalidSyntax("Unknown query type [%s]" % self.query_type)
raise Exception("Unknown query type [%s]" % self.query_type)

query.update({
"vector": vector,
Expand Down
90 changes: 90 additions & 0 deletions tests/workload/params_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2957,6 +2957,96 @@ def test_params_custom_body(self):
with self.assertRaises(StopIteration):
query_param_source_partition.params()

def test_params_when_multiple_query_type_provided_then_raise_exception(self):
# Create a data set
data_set_path = create_data_set(
self.DEFAULT_NUM_VECTORS,
self.DEFAULT_DIMENSION,
self.DEFAULT_TYPE,
Context.QUERY,
self.data_set_dir
)
neighbors_data_set_path = create_data_set(
self.DEFAULT_NUM_VECTORS,
self.DEFAULT_DIMENSION,
self.DEFAULT_TYPE,
Context.NEIGHBORS,
self.data_set_dir
)
filter_body = {
"key": "value"
}

test_param_source_params_1 = {
"field": self.DEFAULT_FIELD_NAME,
"data_set_format": self.DEFAULT_TYPE,
"data_set_path": data_set_path,
"neighbors_data_set_path": neighbors_data_set_path,
"k": 10,
"min_score": 0.5,
}

with self.assertRaisesRegex(ValueError, "Only one of k, max_distance, or min_score can be specified in vector search."):
query_param_source = VectorSearchPartitionParamSource(
workload.Workload(name="unit-test"),
test_param_source_params_1, {
"index": self.DEFAULT_INDEX_NAME,
"request-params": {},
"body": {
"size": 100,
}
}
)
# This line won't be executed if exception is raised during initialization
query_param_source.partition(0, 1)

test_param_source_params_2 = {
"field": self.DEFAULT_FIELD_NAME,
"data_set_format": self.DEFAULT_TYPE,
"data_set_path": data_set_path,
"neighbors_data_set_path": neighbors_data_set_path,
"k": 10,
"max_distance": 100.0,
}

with self.assertRaisesRegex(ValueError, "Only one of k, max_distance, or min_score can be specified in vector search."):
query_param_source = VectorSearchPartitionParamSource(
workload.Workload(name="unit-test"),
test_param_source_params_2, {
"index": self.DEFAULT_INDEX_NAME,
"request-params": {},
"body": {
"size": 100,
}
}
)
# This line won't be executed if exception is raised during initialization
query_param_source.partition(0, 1)

test_param_source_params_3 = {
"field": self.DEFAULT_FIELD_NAME,
"data_set_format": self.DEFAULT_TYPE,
"data_set_path": data_set_path,
"neighbors_data_set_path": neighbors_data_set_path,
"min_score": 0.5,
"max_distance": 100.0,
"k": 10,
}

with self.assertRaisesRegex(ValueError, "Only one of k, max_distance, or min_score can be specified in vector search."):
query_param_source = VectorSearchPartitionParamSource(
workload.Workload(name="unit-test"),
test_param_source_params_3, {
"index": self.DEFAULT_INDEX_NAME,
"request-params": {},
"body": {
"size": 100,
}
}
)
# This line won't be executed if exception is raised during initialization
query_param_source.partition(0, 1)

def _check_params(
self,
actual_params: dict,
Expand Down

0 comments on commit dc2a97d

Please sign in to comment.