diff --git a/osbenchmark/utils/dataset.py b/osbenchmark/utils/dataset.py index d8ad2b74f..41b303073 100644 --- a/osbenchmark/utils/dataset.py +++ b/osbenchmark/utils/dataset.py @@ -24,6 +24,8 @@ class Context(Enum): INDEX = 1 QUERY = 2 NEIGHBORS = 3 + MAX_DISTANCE_NEIGHBORS = 4 + MIN_SCORE_NEIGHBORS = 5 class DataSet(ABC): @@ -141,6 +143,12 @@ def parse_context(context: Context) -> str: if context == Context.QUERY: return "test" + if context == Context.MAX_DISTANCE_NEIGHBORS: + return "max_distance_neighbors" + + if context == Context.MIN_SCORE_NEIGHBORS: + return "min_score_neighbors" + raise Exception("Unsupported context") diff --git a/osbenchmark/worker_coordinator/runner.py b/osbenchmark/worker_coordinator/runner.py index 1941c90ac..a2efc2013 100644 --- a/osbenchmark/worker_coordinator/runner.py +++ b/osbenchmark/worker_coordinator/runner.py @@ -65,7 +65,8 @@ def register_default_runners(): register_runner(workload.OperationType.RawRequest, RawRequest(), async_runner=True) register_runner(workload.OperationType.Composite, Composite(), async_runner=True) register_runner(workload.OperationType.SubmitAsyncSearch, SubmitAsyncSearch(), async_runner=True) - register_runner(workload.OperationType.GetAsyncSearch, Retry(GetAsyncSearch(), retry_until_success=True), async_runner=True) + register_runner(workload.OperationType.GetAsyncSearch, Retry(GetAsyncSearch(), retry_until_success=True), + async_runner=True) register_runner(workload.OperationType.DeleteAsyncSearch, DeleteAsyncSearch(), async_runner=True) register_runner(workload.OperationType.CreatePointInTime, CreatePointInTime(), async_runner=True) register_runner(workload.OperationType.DeletePointInTime, DeletePointInTime(), async_runner=True) @@ -85,15 +86,19 @@ def register_default_runners(): register_runner(workload.OperationType.DeleteIndex, Retry(DeleteIndex()), async_runner=True) register_runner(workload.OperationType.CreateComponentTemplate, Retry(CreateComponentTemplate()), async_runner=True) register_runner(workload.OperationType.DeleteComponentTemplate, Retry(DeleteComponentTemplate()), async_runner=True) - register_runner(workload.OperationType.CreateComposableTemplate, Retry(CreateComposableTemplate()), async_runner=True) - register_runner(workload.OperationType.DeleteComposableTemplate, Retry(DeleteComposableTemplate()), async_runner=True) + register_runner(workload.OperationType.CreateComposableTemplate, Retry(CreateComposableTemplate()), + async_runner=True) + register_runner(workload.OperationType.DeleteComposableTemplate, Retry(DeleteComposableTemplate()), + async_runner=True) register_runner(workload.OperationType.CreateDataStream, Retry(CreateDataStream()), async_runner=True) register_runner(workload.OperationType.DeleteDataStream, Retry(DeleteDataStream()), async_runner=True) register_runner(workload.OperationType.CreateIndexTemplate, Retry(CreateIndexTemplate()), async_runner=True) register_runner(workload.OperationType.DeleteIndexTemplate, Retry(DeleteIndexTemplate()), async_runner=True) register_runner(workload.OperationType.ShrinkIndex, Retry(ShrinkIndex()), async_runner=True) - register_runner(workload.OperationType.DeleteSnapshotRepository, Retry(DeleteSnapshotRepository()), async_runner=True) - register_runner(workload.OperationType.CreateSnapshotRepository, Retry(CreateSnapshotRepository()), async_runner=True) + register_runner(workload.OperationType.DeleteSnapshotRepository, Retry(DeleteSnapshotRepository()), + async_runner=True) + register_runner(workload.OperationType.CreateSnapshotRepository, 
Retry(CreateSnapshotRepository()), + async_runner=True) register_runner(workload.OperationType.WaitForSnapshotCreate, Retry(WaitForSnapshotCreate()), async_runner=True) register_runner(workload.OperationType.WaitForRecovery, Retry(IndicesRecovery()), async_runner=True) register_runner(workload.OperationType.PutSettings, Retry(PutSettings()), async_runner=True) @@ -131,7 +136,8 @@ def register_runner(operation_type, runner, **kwargs): if not async_runner: raise exceptions.BenchmarkAssertionError( - "Runner [{}] must be implemented as async runner and registered with async_runner=True.".format(str(runner))) + "Runner [{}] must be implemented as async runner and registered with async_runner=True.".format( + str(runner))) if getattr(runner, "multi_cluster", False): if "__aenter__" in dir(runner) and "__aexit__" in dir(runner): @@ -140,7 +146,8 @@ def register_runner(operation_type, runner, **kwargs): cluster_aware_runner = _multi_cluster_runner(runner, str(runner), context_manager_enabled=True) else: if logger.isEnabledFor(logging.DEBUG): - logger.debug("Registering context-manager capable runner object [%s] for [%s].", str(runner), str(operation_type)) + logger.debug("Registering context-manager capable runner object [%s] for [%s].", str(runner), + str(operation_type)) cluster_aware_runner = _multi_cluster_runner(runner, str(runner)) # we'd rather use callable() but this will erroneously also classify a class as callable... elif isinstance(runner, types.FunctionType): @@ -149,7 +156,8 @@ def register_runner(operation_type, runner, **kwargs): cluster_aware_runner = _single_cluster_runner(runner, runner.__name__) elif "__aenter__" in dir(runner) and "__aexit__" in dir(runner): if logger.isEnabledFor(logging.DEBUG): - logger.debug("Registering context-manager capable runner object [%s] for [%s].", str(runner), str(operation_type)) + logger.debug("Registering context-manager capable runner object [%s] for [%s].", str(runner), + str(operation_type)) cluster_aware_runner = _single_cluster_runner(runner, str(runner), context_manager_enabled=True) else: if logger.isEnabledFor(logging.DEBUG): @@ -158,6 +166,7 @@ def register_runner(operation_type, runner, **kwargs): __RUNNERS[operation_type] = _with_completion(_with_assertions(cluster_aware_runner)) + # Only intended for unit-testing! 
def remove_runner(operation_type): del __RUNNERS[operation_type] @@ -200,7 +209,7 @@ def _default_kw_params(self, params): "params": "request-params", "request_timeout": "request-timeout", } - full_result = {k: params.get(v) for (k, v) in kw_dict.items()} + full_result = {k: params.get(v) for (k, v) in kw_dict.items()} # filter Nones return dict(filter(lambda kv: kv[1] is not None, full_result.items())) @@ -215,8 +224,10 @@ def _transport_request_params(self, params): headers.update({"x-opaque-id": opaque_id}) return request_params, headers + request_context_holder = RequestContextHolder() + def time_func(func): async def advised(*args, **kwargs): request_context_holder.on_client_request_start() @@ -225,6 +236,7 @@ async def advised(*args, **kwargs): return response finally: request_context_holder.on_client_request_end() + return advised @@ -232,6 +244,7 @@ class Delegator: """ Mixin to unify delegate handling """ + def __init__(self, delegate, *args, **kwargs): super().__init__(*args, **kwargs) self.delegate = delegate @@ -500,7 +513,7 @@ async def __call__(self, opensearch, params): if "request-params" in params: bulk_params.update(params["request-params"]) - params.pop( "request-params" ) + params.pop("request-params") api_kwargs = self._default_kw_params(params) @@ -521,7 +534,8 @@ async def __call__(self, opensearch, params): response = await opensearch.bulk(doc_type=params.get("type"), params=bulk_params, **api_kwargs) request_context_holder.on_client_request_end() - stats = self.detailed_stats(params, response) if detailed_results else self.simple_stats(bulk_size, unit, response) + stats = self.detailed_stats(params, response) if detailed_results else self.simple_stats(bulk_size, unit, + response) meta_data = { "index": params.get("index"), @@ -926,7 +940,7 @@ async def _search_after_query(opensearch, params): if pit_op: pit_id = CompositeContext.get(pit_op) body["pit"] = {"id": pit_id, - "keep_alive": "1m" } + "keep_alive": "1m"} response = await self._raw_search( opensearch, doc_type=None, index=index, body=body.copy(), @@ -1015,16 +1029,16 @@ async def _scroll_query(opensearch, params): else: request_context_holder.on_client_request_start() r = await opensearch.transport.perform_request("GET", "/_search/scroll", - body={"scroll_id": scroll_id, "scroll": "10s"}, - params=request_params, - headers=headers) + body={"scroll_id": scroll_id, "scroll": "10s"}, + params=request_params, + headers=headers) request_context_holder.on_client_request_end() props = parse(r, ["timed_out", "took"], ["hits.hits"]) timed_out = timed_out or props.get("timed_out", False) took += props.get("took", 0) # is the list of hits empty? all_results_collected = props.get("hits.hits", False) - retrieved_pages +=1 + retrieved_pages += 1 if all_results_collected: break finally: @@ -1033,8 +1047,9 @@ async def _scroll_query(opensearch, params): try: await opensearch.clear_scroll(body={"scroll_id": [scroll_id]}) except BaseException: - self.logger.exception("Could not clear scroll [%s]. This will lead to excessive resource usage in " - "OpenSearch and will skew your benchmark results.", scroll_id) + self.logger.exception( + "Could not clear scroll [%s]. 
This will lead to excessive resource usage in " + "OpenSearch and will skew your benchmark results.", scroll_id) return { "weight": retrieved_pages, @@ -1057,6 +1072,10 @@ async def _vector_search_query_with_recall(opensearch, params): "success": True, "recall@k": 0, "recall@1": 0, + "recall@min_score": 0, + "recall@min_score_1": 0, + "recall@max_distance": 0, + "recall@max_distance_1": 0, } def _is_empty_search_results(content): @@ -1080,23 +1099,33 @@ def _get_field_value(content, field_name): return _get_field_value(content["_source"], field_name) return None - def calculate_recall(predictions, neighbors, top_k): + def calculate_recall(predictions, neighbors, top_1_recall=False): """ Calculates the recall by comparing top_k neighbors with predictions. recall = Sum of matched neighbors from predictions / total number of neighbors from ground truth Args: predictions: list containing ids of results returned by OpenSearch. neighbors: list containing ids of the actual neighbors for a set of queries - top_k: number of top results to check from the neighbors and should be greater than zero + top_1_recall: boolean to calculate recall@1 Returns: Recall between predictions and top k neighbors from ground truth """ correct = 0.0 - if neighbors is None: + try: + n = neighbors.index('-1') + # Slice the list to have a length of n + truth_set = neighbors[:n] + except ValueError: + # If '-1' is not found in the list, use the entire list + truth_set = neighbors + min_num_of_results = len(truth_set) + if min_num_of_results == 0: self.logger.info("No neighbors are provided for recall calculation") - return 0.0 - min_num_of_results = min(top_k, len(neighbors)) - truth_set = neighbors[:min_num_of_results] + return 1 + + if top_1_recall: + min_num_of_results = 1 + for j in range(min_num_of_results): if j >= len(predictions): self.logger.info("No more neighbors in prediction to compare against ground truth.\n" @@ -1106,10 +1135,11 @@ def calculate_recall(predictions, neighbors, top_k): if predictions[j] in truth_set: correct += 1.0 - return correct / min_num_of_results + return float(correct) / min_num_of_results doc_type = params.get("type") response = await self._raw_search(opensearch, doc_type, index, body, request_params, headers=headers) + recall_processing_start = time.perf_counter() if detailed_results: props = parse(response, ["hits.total", "hits.total.value", "hits.total.relation", "timed_out", "took"]) @@ -1137,12 +1167,19 @@ def calculate_recall(predictions, neighbors, top_k): continue candidates.append(field_value) neighbors_dataset = params["neighbors"] - num_neighbors = params.get("k", 1) - recall_k = calculate_recall(candidates, neighbors_dataset, num_neighbors) - result.update({"recall@k": recall_k}) - - recall_1 = calculate_recall(candidates, neighbors_dataset, 1) - result.update({"recall@1": recall_1}) + recall_threshold = calculate_recall(candidates, neighbors_dataset) + recall_top_1 = calculate_recall(candidates, neighbors_dataset, True) + max_distance = params.get("max_distance") + min_score = params.get("min_score") + if min_score: + result.update({"recall@min_score": recall_threshold}) + result.update({"recall@min_score_1": recall_top_1}) + elif max_distance: + result.update({"recall@max_distance": recall_threshold}) + result.update({"recall@max_distance_1": recall_top_1}) + else: + result.update({"recall@k": recall_threshold}) + result.update({"recall@1": recall_top_1}) recall_processing_end = time.perf_counter() recall_processing_time = convert.seconds_to_ms(recall_processing_end - 
recall_processing_start) @@ -1172,7 +1209,8 @@ async def _raw_search(self, opensearch, doc_type, index, body, params, headers=N components.append("_search") path = "/".join(components) request_context_holder.on_client_request_start() - response = await opensearch.transport.perform_request("GET", "/" + path, params=params, body=body, headers=headers) + response = await opensearch.transport.perform_request("GET", "/" + path, params=params, body=body, + headers=headers) request_context_holder.on_client_request_end() return response @@ -1205,7 +1243,7 @@ def __call__(self, response: BytesIO, get_point_in_time: bool, hits_total: Optio if get_point_in_time and not parsed.get("pit_id"): raise exceptions.BenchmarkAssertionError("Paginated query failure: " - "pit_id was expected but not found in the response.") + "pit_id was expected but not found in the response.") # standardize these before returning... parsed["hits.total.value"] = parsed.pop("hits.total.value", parsed.pop("hits.total", hits_total)) parsed["hits.total.relation"] = parsed.get("hits.total.relation", "eq") @@ -1269,7 +1307,8 @@ def status(v): result = { "weight": 1, "unit": "ops", - "success": status(cluster_status) >= status(expected_cluster_status) and relocating_shards <= expected_relocating_shards, + "success": status(cluster_status) >= status( + expected_cluster_status) and relocating_shards <= expected_relocating_shards, "cluster-status": cluster_status, "relocating-shards": relocating_shards } @@ -1285,14 +1324,15 @@ class PutPipeline(Runner): @time_func async def __call__(self, opensearch, params): await opensearch.ingest.put_pipeline(id=mandatory(params, "id", self), - body=mandatory(params, "body", self), - master_timeout=params.get("master-timeout"), - timeout=params.get("timeout"), - ) + body=mandatory(params, "body", self), + master_timeout=params.get("master-timeout"), + timeout=params.get("timeout"), + ) def __repr__(self, *args, **kwargs): return "put-pipeline" + class DeletePipeline(Runner): @time_func async def __call__(self, opensearch, params): @@ -1307,6 +1347,7 @@ async def __call__(self, opensearch, params): def __repr__(self, *args, **kwargs): return "delete-pipeline" + # TODO: refactor it after python client support search pipeline https://github.com/opensearch-project/opensearch-py/issues/474 class CreateSearchPipeline(Runner): @time_func @@ -1317,6 +1358,7 @@ async def __call__(self, opensearch, params): def __repr__(self, *args, **kwargs): return "create-search-pipeline" + class Refresh(Runner): @time_func async def __call__(self, opensearch, params): @@ -1431,7 +1473,7 @@ async def __call__(self, opensearch, params): for template, body in templates: request_context_holder.on_client_request_start() await opensearch.cluster.put_component_template(name=template, body=body, - params=request_params) + params=request_params) request_context_holder.on_client_request_end() return { "weight": len(templates), @@ -1461,7 +1503,8 @@ async def _exists(name): for template_name in template_names: if not only_if_exists: request_context_holder.on_client_request_start() - await opensearch.cluster.delete_component_template(name=template_name, params=request_params, ignore=[404]) + await opensearch.cluster.delete_component_template(name=template_name, params=request_params, + ignore=[404]) request_context_holder.on_client_request_end() ops_count += 1 elif only_if_exists and await _exists(template_name): @@ -1476,7 +1519,6 @@ async def _exists(name): "success": True } - def __repr__(self, *args, **kwargs): return 
"delete-component-template" @@ -1541,8 +1583,8 @@ async def __call__(self, opensearch, params): for template, body in templates: request_context_holder.on_client_request_start() await opensearch.indices.put_template(name=template, - body=body, - params=request_params) + body=body, + params=request_params) request_context_holder.on_client_request_end() return { "weight": len(templates), @@ -1628,7 +1670,8 @@ async def __call__(self, opensearch, params): if "data" in node["roles"]: node_names.append(node["name"]) if not node_names: - raise exceptions.BenchmarkAssertionError("Could not choose a suitable shrink-node automatically. Specify it explicitly.") + raise exceptions.BenchmarkAssertionError( + "Could not choose a suitable shrink-node automatically. Specify it explicitly.") for source_index in source_indices: shrink_node = random.choice(node_names) @@ -1637,13 +1680,13 @@ async def __call__(self, opensearch, params): # prepare index for shrinking await opensearch.indices.put_settings(index=source_index, - body={ - "settings": { - "index.routing.allocation.require._name": shrink_node, - "index.blocks.write": "true" - } - }, - preserve_existing=True) + body={ + "settings": { + "index.routing.allocation.require._name": shrink_node, + "index.blocks.write": "true" + } + }, + preserve_existing=True) self.logger.info("Waiting for relocation to finish for index [%s] ...", source_index) await self._wait_for(opensearch, source_index, f"shard relocation for index [{source_index}]") @@ -1654,7 +1697,7 @@ async def __call__(self, opensearch, params): target_body["settings"]["index.blocks.write"] = None # kick off the shrink operation index_suffix = remove_prefix(source_index, source_indices_stem) - final_target_index = target_index if len(index_suffix) == 0 else target_index+index_suffix + final_target_index = target_index if len(index_suffix) == 0 else target_index + index_suffix request_context_holder.on_client_request_start() await opensearch.indices.shrink(index=source_index, target=final_target_index, body=target_body) request_context_holder.on_client_request_end() @@ -1681,17 +1724,18 @@ async def __call__(self, opensearch, params): path = mandatory(params, "path", self) if not path.startswith("/"): self.logger.error("RawRequest failed. Path parameter: [%s] must begin with a '/'.", path) - raise exceptions.BenchmarkAssertionError(f"RawRequest [{path}] failed. Path parameter must begin with a '/'.") + raise exceptions.BenchmarkAssertionError( + f"RawRequest [{path}] failed. 
Path parameter must begin with a '/'.") if not bool(headers): #counter-intuitive, but preserves prior behavior headers = None request_context_holder.on_client_request_start() await opensearch.transport.perform_request(method=params.get("method", "GET"), - url=path, - headers=headers, - body=params.get("body"), - params=request_params) + url=path, + headers=headers, + body=params.get("body"), + params=request_params) request_context_holder.on_client_request_end() def __repr__(self, *args, **kwargs): @@ -1721,6 +1765,7 @@ class DeleteSnapshotRepository(Runner): """ Deletes a snapshot repository """ + @time_func async def __call__(self, opensearch, params): await opensearch.snapshot.delete_repository(repository=mandatory(params, "repository", repr(self))) @@ -1733,12 +1778,13 @@ class CreateSnapshotRepository(Runner): """ Creates a new snapshot repository """ + @time_func async def __call__(self, opensearch, params): request_params = params.get("request-params", {}) await opensearch.snapshot.create_repository(repository=mandatory(params, "repository", repr(self)), - body=mandatory(params, "body", repr(self)), - params=request_params) + body=mandatory(params, "body", repr(self)), + params=request_params) def __repr__(self, *args, **kwargs): return "create-snapshot-repository" @@ -1748,6 +1794,7 @@ class CreateSnapshot(Runner): """ Creates a new snapshot repository """ + @time_func async def __call__(self, opensearch, params): wait_for_completion = params.get("wait-for-completion", False) @@ -1757,9 +1804,9 @@ async def __call__(self, opensearch, params): mandatory(params, "body", repr(self)) api_kwargs = self._default_kw_params(params) await opensearch.snapshot.create(repository=repository, - snapshot=snapshot, - wait_for_completion=wait_for_completion, - **api_kwargs) + snapshot=snapshot, + wait_for_completion=wait_for_completion, + **api_kwargs) def __repr__(self, *args, **kwargs): return "create-snapshot" @@ -1776,8 +1823,8 @@ async def __call__(self, opensearch, params): while not snapshot_done: response = await opensearch.snapshot.status(repository=repository, - snapshot=snapshot, - ignore_unavailable=True) + snapshot=snapshot, + ignore_unavailable=True) if "snapshots" in response: response_state = response["snapshots"][0]["state"] @@ -1815,13 +1862,14 @@ class RestoreSnapshot(Runner): """ Restores a snapshot from an already registered repository """ + @time_func async def __call__(self, opensearch, params): api_kwargs = self._default_kw_params(params) await opensearch.snapshot.restore(repository=mandatory(params, "repository", repr(self)), - snapshot=mandatory(params, "snapshot", repr(self)), - wait_for_completion=params.get("wait-for-completion", False), - **api_kwargs) + snapshot=mandatory(params, "snapshot", repr(self)), + wait_for_completion=params.get("wait-for-completion", False), + **api_kwargs) def __repr__(self, *args, **kwargs): return "restore-snapshot" @@ -1897,7 +1945,8 @@ async def __call__(self, opensearch, params): transform_id = mandatory(params, "transform-id", self) body = mandatory(params, "body", self) defer_validation = params.get("defer-validation", False) - await opensearch.transform.put_transform(transform_id=transform_id, body=body, defer_validation=defer_validation) + await opensearch.transform.put_transform(transform_id=transform_id, body=body, + defer_validation=defer_validation) def __repr__(self, *args, **kwargs): return "create-transform" @@ -1966,10 +2015,10 @@ async def __call__(self, opensearch, params): if not self._start_time: self._start_time = 
time.monotonic() await opensearch.transform.stop_transform(transform_id=transform_id, - force=force, - timeout=timeout, - wait_for_completion=False, - wait_for_checkpoint=wait_for_checkpoint) + force=force, + timeout=timeout, + wait_for_completion=False, + wait_for_checkpoint=wait_for_checkpoint) while True: stats_response = await opensearch.transform.get_transform_stats(transform_id=transform_id) @@ -2046,8 +2095,8 @@ class SubmitAsyncSearch(Runner): async def __call__(self, opensearch, params): request_params = params.get("request-params", {}) response = await opensearch.async_search.submit(body=mandatory(params, "body", self), - index=params.get("index"), - params=request_params) + index=params.get("index"), + params=request_params) op_name = mandatory(params, "name", self) # id may be None if the operation has already returned @@ -2076,7 +2125,7 @@ async def __call__(self, opensearch, params): for search_id, search in async_search_ids(searches): request_context_holder.on_client_request_start() response = await opensearch.async_search.get(id=search_id, - params=request_params) + params=request_params) request_context_holder.on_client_request_end() is_running = response["is_running"] success = success and not is_running @@ -2197,13 +2246,15 @@ def _ctx(): try: return CompositeContext.ctx.get() except LookupError: - raise exceptions.BenchmarkAssertionError("This operation is only allowed inside a composite operation.") from None + raise exceptions.BenchmarkAssertionError( + "This operation is only allowed inside a composite operation.") from None class Composite(Runner): """ Executes a complex request structure which is measured by Benchmark as one composite operation. """ + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.supported_op_types = [ @@ -2246,7 +2297,8 @@ async def run_stream(self, opensearch, stream, connection_limit): timings.append(timing) else: - raise exceptions.BenchmarkAssertionError("Requests structure must contain [stream] or [operation-type].") + raise exceptions.BenchmarkAssertionError( + "Requests structure must contain [stream] or [operation-type].") except BaseException: # stop all already created tasks in case of exceptions for s in streams: @@ -2397,10 +2449,11 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): def __repr__(self, *args, **kwargs): return "retryable %s" % repr(self.delegate) + class DeleteMlModel(Runner): @time_func async def __call__(self, opensearch, params): - body= { + body = { "query": { "match_phrase": { "name": { @@ -2422,12 +2475,13 @@ async def __call__(self, opensearch, params): model_ids.add(id) for model_id in model_ids: - resp=await opensearch.transport.perform_request('POST', '/_plugins/_ml/models/' + model_id + '/_undeploy') - resp=await opensearch.transport.perform_request('DELETE', '/_plugins/_ml/models/' + model_id) + resp = await opensearch.transport.perform_request('POST', '/_plugins/_ml/models/' + model_id + '/_undeploy') + resp = await opensearch.transport.perform_request('DELETE', '/_plugins/_ml/models/' + model_id) def __repr__(self, *args, **kwargs): return "delete-ml-model" + class RegisterMlModel(Runner): @time_func async def __call__(self, opensearch, params): @@ -2475,12 +2529,13 @@ async def __call__(self, opensearch, params): model_id = resp.get('model_id') with open('model_id.json', 'w') as f: - d = { 'model_id': model_id } + d = {'model_id': model_id} f.write(json.dumps(d)) def __repr__(self, *args, **kwargs): return "register-ml-model" + class DeployMlModel(Runner): @time_func 
async def __call__(self, opensearch, params): diff --git a/osbenchmark/workload/params.py b/osbenchmark/workload/params.py index 59ebe27d1..ee6337844 100644 --- a/osbenchmark/workload/params.py +++ b/osbenchmark/workload/params.py @@ -40,7 +40,7 @@ from osbenchmark import exceptions from osbenchmark.utils import io from osbenchmark.utils.dataset import DataSet, get_data_set, Context -from osbenchmark.utils.parse import parse_string_parameter, parse_int_parameter +from osbenchmark.utils.parse import parse_string_parameter, parse_int_parameter, parse_float_parameter from osbenchmark.workload import workload __PARAM_SOURCES_BY_OP = {} @@ -49,6 +49,7 @@ __STANDARD_VALUE_SOURCES = {} __STANDARD_VALUES = {} + def param_source_for_operation(op_type, workload, params, task_name): try: # we know that this can only be a Benchmark core parameter source @@ -65,6 +66,7 @@ def param_source_for_name(name, workload, params): else: return param_source(workload, params) + def get_standard_value_source(op_name, field_name): try: return __STANDARD_VALUE_SOURCES[op_name][field_name] @@ -76,7 +78,8 @@ def get_standard_value_source(op_name, field_name): def ensure_valid_param_source(param_source): if not inspect.isfunction(param_source) and not inspect.isclass(param_source): - raise exceptions.BenchmarkAssertionError(f"Parameter source [{param_source}] must be either a function or a class.") + raise exceptions.BenchmarkAssertionError( + f"Parameter source [{param_source}] must be either a function or a class.") def register_param_source_for_operation(op_type, param_source_class): @@ -88,12 +91,14 @@ def register_param_source_for_name(name, param_source_class): ensure_valid_param_source(param_source_class) __PARAM_SOURCES_BY_NAME[name] = param_source_class + def register_standard_value_source(op_name, field_name, standard_value_source): if op_name in __STANDARD_VALUE_SOURCES: __STANDARD_VALUE_SOURCES[op_name][field_name] = standard_value_source # We have to allow re-registration for the same op/field, since plugins are loaded many times when a workload is run else: - __STANDARD_VALUE_SOURCES[op_name] = {field_name:standard_value_source} + __STANDARD_VALUE_SOURCES[op_name] = {field_name: standard_value_source} + def generate_standard_values_if_absent(op_name, field_name, n): if not op_name in __STANDARD_VALUES: @@ -109,11 +114,13 @@ def generate_standard_values_if_absent(op_name, field_name, n): for _i in range(n): __STANDARD_VALUES[op_name][field_name].append(standard_value_source()) + def get_standard_value(op_name, field_name, i): try: return __STANDARD_VALUES[op_name][field_name][i] except KeyError: - raise exceptions.SystemSetupError("No standard values generated for operation {}, field {}".format(op_name, field_name)) + raise exceptions.SystemSetupError( + "No standard values generated for operation {}, field {}".format(op_name, field_name)) except IndexError: raise exceptions.SystemSetupError( "Standard value index {} out of range for operation {}, field name {} ({} values total)" @@ -126,11 +133,13 @@ def _unregister_param_source_for_name(name): # something is fishy with the test and we'd rather know early. 
__PARAM_SOURCES_BY_NAME.pop(name) + # only intended for tests def _clear_standard_values(): __STANDARD_VALUES = {} __STANDARD_VALUE_SOURCES = {} + # Default class ParamSource: """ @@ -288,6 +297,7 @@ def validate_index_codec(self, settings): if "index.codec" in settings: return workload.IndexCodec.is_codec_valid(settings["index.codec"]) + class CreateDataStreamParamSource(ParamSource): def __init__(self, workload, params, **kwargs): super().__init__(workload, params, **kwargs) @@ -307,7 +317,8 @@ def __init__(self, workload, params, **kwargs): for ds in data_streams: self.data_stream_definitions.append(ds) except KeyError: - raise exceptions.InvalidSyntax("Please set the property 'data-stream' for the create-data-stream operation") + raise exceptions.InvalidSyntax( + "Please set the property 'data-stream' for the create-data-stream operation") def params(self): p = {} @@ -404,7 +415,8 @@ def __init__(self, workload, params, **kwargs): try: self.template_definitions.append((params["template"], params["body"])) except KeyError: - raise exceptions.InvalidSyntax("Please set the properties 'template' and 'body' for the create-index-template operation") + raise exceptions.InvalidSyntax( + "Please set the properties 'template' and 'body' for the create-index-template operation") def params(self): p = {} @@ -427,12 +439,14 @@ def __init__(self, workload, params, **kwargs): filter_template = params.get("template") for template in workload.templates: if not filter_template or template.name == filter_template: - self.template_definitions.append((template.name, template.delete_matching_indices, template.pattern)) + self.template_definitions.append( + (template.name, template.delete_matching_indices, template.pattern)) else: try: template = params["template"] except KeyError: - raise exceptions.InvalidSyntax(f"Please set the property 'template' for the {params.get('operation-type')} operation") + raise exceptions.InvalidSyntax( + f"Please set the property 'template' for the {params.get('operation-type')} operation") delete_matching = params.get("delete-matching-indices", False) try: @@ -470,7 +484,8 @@ def __init__(self, workload, params, **kwargs): template = params["template"] self.template_definitions.append(template) except KeyError: - raise exceptions.InvalidSyntax(f"Please set the property 'template' for the {params.get('operation-type')} operation.") + raise exceptions.InvalidSyntax( + f"Please set the property 'template' for the {params.get('operation-type')} operation.") def params(self): return { @@ -581,7 +596,8 @@ def __init__(self, workload, params, **kwargs): # for paginated queries the value does not matter because detailed results are always retrieved. 
is_paginated = bool(pages) if not is_paginated: - raise exceptions.InvalidSyntax("The property [detailed-results] must be [true] if assertions are defined") + raise exceptions.InvalidSyntax( + "The property [detailed-results] must be [true] if assertions are defined") self.query_params["assertions"] = params["assertions"] # Ensure we pass global parameters @@ -623,12 +639,14 @@ def __init__(self, workload, params, **kwargs): raise exceptions.InvalidSyntax("'conflicts' cannot be used with 'data-streams'") if self.id_conflicts != IndexIdConflict.NoConflicts: - self.conflict_probability = self.float_param(params, name="conflict-probability", default_value=25, min_value=0, max_value=100, + self.conflict_probability = self.float_param(params, name="conflict-probability", default_value=25, + min_value=0, max_value=100, min_operator=operator.lt) self.on_conflict = params.get("on-conflict", "index") if self.on_conflict not in ["index", "update"]: raise exceptions.InvalidSyntax("Unknown 'on-conflict' setting [{}]".format(self.on_conflict)) - self.recency = self.float_param(params, name="recency", default_value=0, min_value=0, max_value=1, min_operator=operator.lt) + self.recency = self.float_param(params, name="recency", default_value=0, min_value=0, max_value=1, + min_operator=operator.lt) else: self.conflict_probability = None @@ -638,16 +656,18 @@ def __init__(self, workload, params, **kwargs): self.corpora = self.used_corpora(workload, params) if len(self.corpora) == 0: - raise exceptions.InvalidSyntax(f"There is no document corpus definition for workload {workload}. You must add at " - f"least one before making bulk requests to OpenSearch.") + raise exceptions.InvalidSyntax( + f"There is no document corpus definition for workload {workload}. You must add at " + f"least one before making bulk requests to OpenSearch.") for corpus in self.corpora: for document_set in corpus.documents: if document_set.includes_action_and_meta_data and self.id_conflicts != IndexIdConflict.NoConflicts: file_name = document_set.document_archive if document_set.has_compressed_corpus() else document_set.document_file - raise exceptions.InvalidSyntax("Cannot generate id conflicts [%s] as [%s] in document corpus [%s] already contains an " - "action and meta-data line." % (id_conflicts, file_name, corpus)) + raise exceptions.InvalidSyntax( + "Cannot generate id conflicts [%s] as [%s] in document corpus [%s] already contains an " + "action and meta-data line." 
% (id_conflicts, file_name, corpus)) self.pipeline = params.get("pipeline", None) try: @@ -670,7 +690,8 @@ def __init__(self, workload, params, **kwargs): except ValueError: raise exceptions.InvalidSyntax("'batch-size' must be numeric") - self.ingest_percentage = self.float_param(params, name="ingest-percentage", default_value=100, min_value=0, max_value=100) + self.ingest_percentage = self.float_param(params, name="ingest-percentage", default_value=100, min_value=0, + max_value=100) self.param_source = PartitionBulkIndexParamSource(self.corpora, self.batch_size, self.bulk_size, self.ingest_percentage, self.id_conflicts, self.conflict_probability, self.on_conflict, @@ -682,7 +703,8 @@ def float_param(self, params, name, default_value, min_value, max_value, min_ope if min_operator(value, min_value) or value > max_value: interval_min = "(" if min_operator is operator.le else "[" raise exceptions.InvalidSyntax( - "'{}' must be in the range {}{:.1f}, {:.1f}] but was {:.1f}".format(name, interval_min, min_value, max_value, value)) + "'{}' must be in the range {}{:.1f}, {:.1f}] but was {:.1f}".format(name, interval_min, min_value, + max_value, value)) return value except ValueError: raise exceptions.InvalidSyntax("'{}' must be numeric".format(name)) @@ -705,7 +727,7 @@ def used_corpora(self, t, params): # the workload has corpora but none of them match if t.corpora and not corpora: raise exceptions.BenchmarkAssertionError("The provided corpus %s does not match any of the corpora %s." % - (corpora_names, workload_corpora_names)) + (corpora_names, workload_corpora_names)) return corpora @@ -953,7 +975,6 @@ def partition(self, partition_index, total_partitions): if self.data_set is None: self.data_set: DataSet = get_data_set( self.data_set_format, self.data_set_path, self.context) - # if value is -1 or greater than dataset size, use dataset size as num_vectors if self.total_num_vectors < 0 or self.total_num_vectors > self.data_set.size(): self.total_num_vectors = self.data_set.size() self.total = self.total_num_vectors @@ -965,7 +986,8 @@ def partition(self, partition_index, total_partitions): partition_x.num_vectors = min_num_vectors_per_partition # if partition is not divided equally, add extra docs to the last partition - if self.total_num_vectors % total_partitions != 0 and self._is_last_partition(partition_index, total_partitions): + if self.total_num_vectors % total_partitions != 0 and self._is_last_partition(partition_index, + total_partitions): remaining_vectors = self.total_num_vectors - (min_num_vectors_per_partition * total_partitions) partition_x.num_vectors += remaining_vectors @@ -1027,6 +1049,8 @@ class VectorSearchPartitionParamSource(VectorDataSetPartitionParamSource): request-params: query parameters that can be passed to search request """ PARAMS_NAME_K = "k" + PARAMS_NAME_MAX_DISTANCE = "max_distance" + PARAMS_NAME_MIN_SCORE = "min_score" PARAMS_NAME_BODY = "body" PARAMS_NAME_SIZE = "size" PARAMS_NAME_QUERY = "query" @@ -1041,11 +1065,26 @@ class VectorSearchPartitionParamSource(VectorDataSetPartitionParamSource): PARAMS_NAME_REQUEST_PARAMS = "request-params" PARAMS_NAME_SOURCE = "_source" PARAMS_NAME_ALLOW_PARTIAL_RESULTS = "allow_partial_search_results" + MIN_SCORE_QUERY_TYPE = "min_score" + MAX_DISTANCE_QUERY_TYPE = "max_distance" + KNN_QUERY_TYPE = "knn" def __init__(self, workloads, params, query_params, **kwargs): super().__init__(workloads, params, Context.QUERY, **kwargs) self.logger = logging.getLogger(__name__) - self.k = parse_int_parameter(self.PARAMS_NAME_K, params) 
+ self.k = None + self.distance = None + self.score = None + if self.PARAMS_NAME_K in params: + self.k = parse_int_parameter(self.PARAMS_NAME_K, params) + self.query_type = self.KNN_QUERY_TYPE + if self.PARAMS_NAME_MAX_DISTANCE in params: + self.distance = parse_float_parameter(self.PARAMS_NAME_MAX_DISTANCE, params) + self.query_type = self.MAX_DISTANCE_QUERY_TYPE + if self.PARAMS_NAME_MIN_SCORE in params: + self.score = parse_float_parameter(self.PARAMS_NAME_MIN_SCORE, params) + self.query_type = self.MIN_SCORE_QUERY_TYPE + self.logger.info("query type is set up to %s", self.query_type) self.repetitions = parse_int_parameter(self.PARAMS_NAME_REPETITIONS, params, 1) self.current_rep = 1 self.neighbors_data_set_format = parse_string_parameter( @@ -1058,13 +1097,24 @@ def __init__(self, workloads, params, query_params, **kwargs): self.PARAMS_VALUE_VECTOR_SEARCH) self.query_params = query_params self.query_params.update({ - self.PARAMS_NAME_K: self.k, self.PARAMS_NAME_OPERATION_TYPE: operation_type, self.PARAMS_NAME_ID_FIELD_NAME: params.get(self.PARAMS_NAME_ID_FIELD_NAME), }) + if self.PARAMS_NAME_K in params: + self.query_params.update({ + self.PARAMS_NAME_K: self.k + }) + if self.PARAMS_NAME_MAX_DISTANCE in params: + self.query_params.update({ + self.PARAMS_NAME_MAX_DISTANCE: self.distance + }) + if self.PARAMS_NAME_MIN_SCORE in params: + self.query_params.update({ + self.PARAMS_NAME_MIN_SCORE: self.score + }) if self.PARAMS_NAME_FILTER in params: self.query_params.update({ - self.PARAMS_NAME_FILTER: params.get(self.PARAMS_NAME_FILTER) + self.PARAMS_NAME_FILTER: params.get(self.PARAMS_NAME_FILTER) }) # if neighbors data set is defined as corpus, extract corresponding corpus from workload # and add it to corpora list @@ -1086,15 +1136,29 @@ def _update_request_params(self): self.PARAMS_NAME_ALLOW_PARTIAL_RESULTS, "false") self.query_params.update({self.PARAMS_NAME_REQUEST_PARAMS: request_params}) - def _update_body_params(self, vector): + def _get_query_neighbors(self): + if self.query_type == self.KNN_QUERY_TYPE: + return Context.NEIGHBORS + elif self.query_type == self.MIN_SCORE_QUERY_TYPE: + return Context.MIN_SCORE_NEIGHBORS + elif self.query_type == self.MAX_DISTANCE_QUERY_TYPE: + return Context.MAX_DISTANCE_NEIGHBORS + else: + raise exceptions.InvalidSyntax("Unknown query type [%s]" % self.query_type) + + def _update_body_params(self, vector, distance, score): # accept body params if passed from workload, else, create empty dictionary body_params = self.query_params.get(self.PARAMS_NAME_BODY) or dict() if self.PARAMS_NAME_SIZE not in body_params: - body_params[self.PARAMS_NAME_SIZE] = self.k + if self.query_type == self.KNN_QUERY_TYPE: + body_params[self.PARAMS_NAME_SIZE] = self.k + else: + # if distance is set, set size to 10000, which is the maximum number results returned by default + body_params[self.PARAMS_NAME_SIZE] = 10000 if self.PARAMS_NAME_QUERY in body_params: self.logger.warning( "[%s] param from body will be replaced with vector search query.", self.PARAMS_NAME_QUERY) - efficient_filter=self.query_params.get(self.PARAMS_NAME_FILTER) + efficient_filter = self.query_params.get(self.PARAMS_NAME_FILTER) # override query params with vector search query body_params[self.PARAMS_NAME_QUERY] = self._build_vector_search_query_body(vector, efficient_filter) self.query_params.update({self.PARAMS_NAME_BODY: body_params}) @@ -1108,9 +1172,10 @@ def partition(self, partition_index, total_partitions): self.neighbors_data_set_path = neighbors_data_set_path[0] if not 
self.neighbors_data_set_path: self.neighbors_data_set_path = self.data_set_path + # add neighbor instance to partition partition.neighbors_data_set = get_data_set( - self.neighbors_data_set_format, self.neighbors_data_set_path, Context.NEIGHBORS) + self.neighbors_data_set_format, self.neighbors_data_set_path, self._get_query_neighbors()) partition.neighbors_data_set.seek(partition.offset) return partition @@ -1128,29 +1193,57 @@ def params(self): elif is_dataset_exhausted: raise StopIteration vector = self.data_set.read(1)[0] + if self.distance is not None: + max_distance = self.distance + else: + max_distance = None + if self.score is not None: + min_score = self.score + else: + min_score = None neighbor = self.neighbors_data_set.read(1)[0] - true_neighbors = list(map(str, neighbor[:self.k])) + if self.k: + true_neighbors = list(map(str, neighbor[:self.k])) + else: + true_neighbors = list(map(str, neighbor)) self.query_params.update({ "neighbors": true_neighbors, }) self._update_request_params() - self._update_body_params(vector) + self._update_body_params(vector, max_distance, min_score) self.current += 1 self.percent_completed = self.current / self.total return self.query_params def _build_vector_search_query_body(self, vector, efficient_filter=None) -> dict: - """Builds a k-NN request that can be used to execute an approximate nearest + """Builds a vector search request that can be used to execute an approximate nearest neighbor search against a k-NN plugin index Args: vector: vector used for query + efficient_filter: efficient filter used for query Returns: A dictionary containing the body used for search query """ - query = { + query = {} + if self.query_type == self.MAX_DISTANCE_QUERY_TYPE: + query.update({ + "max_distance": self.distance, + }) + elif self.query_type == self.MIN_SCORE_QUERY_TYPE: + query.update({ + "min_score": self.score, + }) + elif self.query_type == self.KNN_QUERY_TYPE: + query.update({ + "k": self.k, + }) + else: + raise exceptions.InvalidSyntax("Unknown query type [%s]" % self.query_type) + + query.update({ "vector": vector, - "k": self.k, - } + }) + if efficient_filter: query.update({ "filter": efficient_filter, @@ -1255,6 +1348,7 @@ def get_target(workload, params): target_name = params.get("data-stream", default_target) return target_name + def number_of_bulks(corpora, start_partition_index, end_partition_index, total_partitions, bulk_size): """ :return: The number of bulk operations that the given client will issue. 
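# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the patch): a standalone restatement of the
# request-body shapes that the vector-search param source above now produces
# for its three query types (k-NN, max_distance, min_score). The field name
# "location" and the example vector are borrowed from the unit tests further
# below; the size=10000 fallback for radial queries mirrors _update_body_params.
def build_vector_search_body(field_name, vector, k=None, max_distance=None, min_score=None):
    inner = {"vector": vector}
    if k is not None:
        inner["k"] = k
        size = k
    elif max_distance is not None:
        inner["max_distance"] = max_distance
        size = 10000  # radial queries have no k, so fall back to the default result cap
    elif min_score is not None:
        inner["min_score"] = min_score
        size = 10000
    else:
        raise ValueError("One of k, max_distance or min_score must be set")
    return {"size": size, "query": {"knn": {field_name: inner}}}

# For example, build_vector_search_body("location", [5, 4], min_score=0.80)
# yields the same body shape asserted in test_query_vector_radial_search_with_min_score.
# ---------------------------------------------------------------------------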
@@ -1317,7 +1411,8 @@ def create_default_reader(docs, offset, num_lines, num_docs, batch_size, bulk_si am_handler = GenerateActionMetaData(target, docs.target_type, build_conflicting_ids(id_conflicts, num_docs, offset), conflict_probability, on_conflict, recency, use_create=use_create) - return MetadataIndexDataReader(docs.document_file, batch_size, bulk_size, source, am_handler, target, docs.target_type) + return MetadataIndexDataReader(docs.document_file, batch_size, bulk_size, source, am_handler, target, + docs.target_type) def create_readers(num_clients, start_client_index, end_client_index, corpora, batch_size, bulk_size, id_conflicts, @@ -1332,9 +1427,10 @@ def create_readers(num_clients, start_client_index, end_client_index, corpora, b target = f"{docs.target_index}/{docs.target_type}" if docs.target_index else "/" if docs.target_data_stream: target = docs.target_data_stream - logger.info("Task-relative clients at index [%d-%d] will bulk index [%d] docs starting from line offset [%d] for [%s] " - "from corpus [%s].", start_client_index, end_client_index, num_docs, offset, - target, corpus.name) + logger.info( + "Task-relative clients at index [%d-%d] will bulk index [%d] docs starting from line offset [%d] for [%s] " + "from corpus [%s].", start_client_index, end_client_index, num_docs, offset, + target, corpus.name) readers.append(create_reader(docs, offset, num_lines, num_docs, batch_size, bulk_size, id_conflicts, conflict_probability, on_conflict, recency)) else: @@ -1396,7 +1492,8 @@ def bulk_generator(readers, pipeline, original_params): def bulk_data_based(num_clients, start_client_index, end_client_index, corpora, batch_size, bulk_size, id_conflicts, - conflict_probability, on_conflict, recency, pipeline, original_params, create_reader=create_default_reader): + conflict_probability, on_conflict, recency, pipeline, original_params, + create_reader=create_default_reader): """ Calculates the necessary schedule for bulk operations. 
@@ -1427,7 +1524,8 @@ def bulk_data_based(num_clients, start_client_index, end_client_index, corpora, class GenerateActionMetaData: RECENCY_SLOPE = 30 - def __init__(self, index_name, type_name, conflicting_ids=None, conflict_probability=None, on_conflict=None, recency=None, + def __init__(self, index_name, type_name, conflicting_ids=None, conflict_probability=None, on_conflict=None, + recency=None, rand=random.random, randint=random.randint, randexp=random.expovariate, use_create=False): if type_name: self.meta_data_index_with_id = '{"index": {"_index": "%s", "_type": "%s", "_id": "%s"}}\n' % \ @@ -1670,7 +1768,8 @@ def read_bulk(self): register_param_source_for_operation(workload.OperationType.DeleteIndexTemplate, DeleteIndexTemplateParamSource) register_param_source_for_operation(workload.OperationType.CreateComponentTemplate, CreateComponentTemplateParamSource) register_param_source_for_operation(workload.OperationType.DeleteComponentTemplate, DeleteComponentTemplateParamSource) -register_param_source_for_operation(workload.OperationType.CreateComposableTemplate, CreateComposableTemplateParamSource) +register_param_source_for_operation(workload.OperationType.CreateComposableTemplate, + CreateComposableTemplateParamSource) register_param_source_for_operation(workload.OperationType.DeleteComposableTemplate, DeleteIndexTemplateParamSource) register_param_source_for_operation(workload.OperationType.Sleep, SleepParamSource) register_param_source_for_operation(workload.OperationType.ForceMerge, ForceMergeParamSource) diff --git a/tests/worker_coordinator/runner_test.py b/tests/worker_coordinator/runner_test.py index 6969937c1..9b25be67b 100644 --- a/tests/worker_coordinator/runner_test.py +++ b/tests/worker_coordinator/runner_test.py @@ -2802,6 +2802,165 @@ async def test_query_vector_search_with_custom_id_field_inside_source(self, open ) + @mock.patch('osbenchmark.client.RequestContextHolder.on_client_request_end') + @mock.patch('osbenchmark.client.RequestContextHolder.on_client_request_start') + @mock.patch("opensearchpy.OpenSearch") + @run_async + async def test_query_vector_radial_search_with_min_score(self, opensearch, on_client_request_start, on_client_request_end): + search_response = { + "timed_out": False, + "took": 5, + "hits": { + "total": { + "value": 3, + "relation": "eq" + }, + "hits": [ + { + "_id": 101, + "_score": 0.95 + }, + { + "_id": 102, + "_score": 0.88 + }, + { + "_id": 103, + "_score": 0.87 + } + ] + } + } + opensearch.transport.perform_request.return_value = as_future(io.StringIO(json.dumps(search_response))) + + query_runner = runner.Query() + + params = { + "index": "unittest", + "operation-type": "vector-search", + "detailed-results": True, + "response-compression-enabled": False, + "min_score": 0.80, + "neighbors": [101, 102, 103], + "body": { + "query": { + "knn": { + "location": { + "vector": [ + 5, + 4 + ], + "min_score": 0.80, + } + } + } + } + } + + async with query_runner: + result = await query_runner(opensearch, params) + + self.assertEqual(1, result["weight"]) + self.assertEqual("ops", result["unit"]) + self.assertEqual(3, result["hits"]) + self.assertEqual("eq", result["hits_relation"]) + self.assertFalse(result["timed_out"]) + self.assertEqual(5, result["took"]) + self.assertIn("recall_time_ms", result.keys()) + self.assertIn("recall@min_score", result.keys()) + self.assertEqual(result["recall@min_score"], 1.0) + self.assertIn("recall@min_score_1", result.keys()) + self.assertEqual(result["recall@min_score_1"], 1.0) + self.assertNotIn("error-type", 
result.keys()) + + opensearch.transport.perform_request.assert_called_once_with( + "GET", + "/unittest/_search", + params={}, + body=params["body"], + headers={"Accept-Encoding": "identity"} + ) + + @mock.patch('osbenchmark.client.RequestContextHolder.on_client_request_end') + @mock.patch('osbenchmark.client.RequestContextHolder.on_client_request_start') + @mock.patch("opensearchpy.OpenSearch") + @run_async + async def test_query_vector_radial_search_with_max_distance(self, opensearch, on_client_request_start, on_client_request_end): + search_response = { + "timed_out": False, + "took": 5, + "hits": { + "total": { + "value": 3, + "relation": "eq" + }, + "hits": [ + { + "_id": 101, + "_score": 0.95 + }, + { + "_id": 102, + "_score": 0.88 + }, + { + "_id": 103, + "_score": 0.87 + } + ] + } + } + opensearch.transport.perform_request.return_value = as_future(io.StringIO(json.dumps(search_response))) + + query_runner = runner.Query() + + params = { + "index": "unittest", + "operation-type": "vector-search", + "detailed-results": True, + "response-compression-enabled": False, + "max_distance": 15.0, + "neighbors": [101, 102, 103, 104], + "body": { + "query": { + "knn": { + "location": { + "vector": [ + 5, + 4 + ], + "max_distance": 15.0, + } + } + } + } + } + + async with query_runner: + result = await query_runner(opensearch, params) + + self.assertEqual(1, result["weight"]) + self.assertEqual("ops", result["unit"]) + self.assertEqual(3, result["hits"]) + self.assertEqual("eq", result["hits_relation"]) + self.assertFalse(result["timed_out"]) + self.assertEqual(5, result["took"]) + self.assertIn("recall_time_ms", result.keys()) + self.assertIn("recall@max_distance", result.keys()) + self.assertEqual(result["recall@max_distance"], 0.75) + self.assertIn("recall@max_distance_1", result.keys()) + self.assertEqual(result["recall@max_distance_1"], 1.0) + self.assertNotIn("error-type", result.keys()) + + opensearch.transport.perform_request.assert_called_once_with( + "GET", + "/unittest/_search", + params={}, + body=params["body"], + headers={"Accept-Encoding": "identity"} + ) + + class PutPipelineRunnerTests(TestCase): @mock.patch('osbenchmark.client.RequestContextHolder.on_client_request_end') @mock.patch('osbenchmark.client.RequestContextHolder.on_client_request_start')
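# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the patch): the recall maths that the two
# radial-search tests above assert, restated as a standalone function. Like the
# patched calculate_recall, it truncates the ground-truth list at the first
# '-1' padding entry and computes recall against the full truth set instead of
# a fixed k; recall@1 only inspects the first returned result.
def radial_recall(predictions, neighbors, top_1=False):
    try:
        truth_set = neighbors[:neighbors.index('-1')]  # drop '-1' padding, if present
    except ValueError:
        truth_set = list(neighbors)
    if not truth_set:
        return 1.0
    denominator = 1 if top_1 else len(truth_set)
    correct = sum(1 for p in predictions[:denominator] if p in truth_set)
    return correct / denominator

# test_query_vector_radial_search_with_max_distance returns 3 hits against 4
# ground-truth neighbors: radial_recall([101, 102, 103], [101, 102, 103, 104])
# == 0.75, while radial_recall([101, 102, 103], [101, 102, 103, 104], top_1=True) == 1.0.
# ---------------------------------------------------------------------------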