Skip to content

Commit

Permalink
Add TrainKNN Runner/Operation for Benchmarking Approximate KNN Algori…
Browse files Browse the repository at this point in the history
…thms (#556)

Signed-off-by: Finn Roblin <[email protected]>
  • Loading branch information
finnroblin authored Jul 18, 2024
1 parent 32593d5 commit 01346bb
Show file tree
Hide file tree
Showing 3 changed files with 508 additions and 3 deletions.
183 changes: 181 additions & 2 deletions osbenchmark/worker_coordinator/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
from osbenchmark.utils import convert
from osbenchmark.client import RequestContextHolder
# Mapping from operation type to specific runner
from osbenchmark.utils.parse import parse_int_parameter, parse_string_parameter
from osbenchmark.utils.parse import parse_int_parameter, parse_string_parameter, parse_float_parameter

__RUNNERS = {}

Expand Down Expand Up @@ -105,7 +105,8 @@ def register_default_runners():
register_runner(workload.OperationType.DeleteMlModel, Retry(DeleteMlModel()), async_runner=True)
register_runner(workload.OperationType.RegisterMlModel, Retry(RegisterMlModel()), async_runner=True)
register_runner(workload.OperationType.DeployMlModel, Retry(DeployMlModel()), async_runner=True)

register_runner(workload.OperationType.TrainKnnModel, Retry(TrainKnnModel()), async_runner=True)
register_runner(workload.OperationType.DeleteKnnModel, Retry(DeleteKnnModel()), async_runner=True)

def runner_for(operation_type):
try:
Expand Down Expand Up @@ -652,6 +653,184 @@ def __repr__(self, *args, **kwargs):
return "bulk-index"


class DeleteKnnModel(Runner):
"""
Deletes the K-NN model named model_id.
"""

NAME = "delete-knn-model"
MODEL_DOES_NOT_EXIST_STATUS_CODE = 404

async def __call__(self, opensearch, params):
model_id = parse_string_parameter("model_id", params)
ignore_if_model_does_not_exist = params.get(
"ignore-if-model-does-not-exist", False
)

method = "DELETE"
model_uri = f"/_plugins/_knn/models/{model_id}"

request_context_holder.on_client_request_start()

# 404 indicates the model has not been created. In that case, the runner's response depends on ignore_if_model_does_not_exist.
response = await opensearch.transport.perform_request(
method,
model_uri,
params={"ignore": [self.MODEL_DOES_NOT_EXIST_STATUS_CODE]},
)

request_context_holder.on_client_request_end()

# success condition.
if "result" in response.keys() and response["result"] == "deleted":
self.logger.debug("Model [%s] deleted successfully.", model_id)
return {"weight": 1, "unit": "ops", "success": True}

if "error" not in response.keys():
self.logger.warning(
"Request to delete model [%s] failed but no error, response: [%s]",
model_id,
response,
)
return {"weight": 1, "unit": "ops", "success": False}

if response["status"] != self.MODEL_DOES_NOT_EXIST_STATUS_CODE:
self.logger.warning(
"Request to delete model [%s] failed with status [%s] and response: [%s]",
model_id,
response["status"],
response,
)
return {"weight": 1, "unit": "ops", "success": False}

if ignore_if_model_does_not_exist:
self.logger.debug(
(
"Model [%s] does not exist so it could not be deleted, "
"however ignore-if-model-does-not-exist is True so the "
"DeleteKnnModel operation succeeded."
),
model_id,
)

return {"weight": 1, "unit": "ops", "success": True}

self.logger.warning(
(
"Request to delete model [%s] failed because the model does not exist "
"and ignore-if-model-does-not-exist was set to False. Response: [%s]"
),
model_id,
response,
)
return {"weight": 1, "unit": "ops", "success": False}

def __repr__(self, *args, **kwargs):
return self.NAME


class TrainKnnModel(Runner):
"""
Trains model named model_id until training is complete or retries are exhausted.
"""

NAME = "train-knn-model"
DEFAULT_RETRIES = 1000
DEFAULT_POLL_PERIOD = 0.5

async def __call__(self, opensearch, params):
"""
Create and train one model named model_id.
:param opensearch: The OpenSearch client.
:param params: A hash with all parameters. See below for details.
:return: A hash with meta data for this bulk operation. See below for details.
:raises: Exception if training fails, times out, or a different error occurs.
It expects a parameter dict with the following mandatory keys:
* ``body``: containing parameters to pass on to the train engine.
See https://opensearch.org/docs/latest/search-plugins/knn/api/#train-a-model for information.
* ``retries``: Maximum number of retries allowed for the training to complete (seconds).
* ``polling-interval``: Polling interval to see if the model has been trained yet (seconds).
* ``model_id``: ID of the model to train.
"""
body = params["body"]
model_id = parse_string_parameter("model_id", params)
max_retries = parse_int_parameter("retries", params, self.DEFAULT_RETRIES)
poll_period = parse_float_parameter(
"poll_period", params, self.DEFAULT_POLL_PERIOD
)

method = "POST"
model_uri = f"/_plugins/_knn/models/{model_id}"
request_context_holder.on_client_request_start()
await opensearch.transport.perform_request(
method, f"{model_uri}/_train", body=body
)

current_number_retries = 0
while True:
model_response = await opensearch.transport.perform_request(
"GET", model_uri
)

if "state" not in model_response.keys():
request_context_holder.on_client_request_end()
self.logger.error(
"Failed to create model [%s] with error response: [%s]",
model_id,
model_response,
)
raise Exception(
f"Failed to create model {model_id} with error response: {model_response}"
)

if current_number_retries > max_retries:
request_context_holder.on_client_request_end()
self.logger.error(
"Failed to create model [%s] within [%i] retries.",
model_id,
max_retries,
)
raise TimeoutError(
f"Failed to create model: {model_id} within {max_retries} retries"
)

if model_response["state"] == "training":
current_number_retries += 1
await asyncio.sleep(poll_period)
continue

# at this point, training either failed or finished.
request_context_holder.on_client_request_end()
if model_response["state"] == "created":
self.logger.info(
"Training model [%s] was completed successfully.", model_id
)
return

if model_response["state"] == "failed":
self.logger.error(
"Training for model [%s] failed. Response: [%s]",
model_id,
model_response,
)
raise Exception(f"Failed to create model {model_id}: {model_response}")

self.logger.error(
"Model [%s] in unknown state [%s], response: [%s]",
model_id,
model_response["state"],
model_response,
)
raise Exception(
f"Model {model_id} in unknown state {model_response['state']}, response: {model_response}"
)

def __repr__(self, *args, **kwargs):
return self.NAME


# TODO: Add retry logic to BulkIndex, so that we can remove BulkVectorDataSet and use BulkIndex.
class BulkVectorDataSet(Runner):
"""
Expand Down
6 changes: 6 additions & 0 deletions osbenchmark/workload/workload.py
Original file line number Diff line number Diff line change
Expand Up @@ -601,6 +601,8 @@ class OperationType(Enum):
ListAllPointInTime = 16
VectorSearch = 17
BulkVectorDataSet = 18
TrainKnnModel = 19
DeleteKnnModel = 20

# administrative actions
ForceMerge = 1001
Expand Down Expand Up @@ -746,6 +748,10 @@ def from_hyphenated_string(cls, v):
return OperationType.RegisterMlModel
elif v == "deploy-ml-model":
return OperationType.DeployMlModel
elif v == "train-knn-model":
return OperationType.TrainKnnModel
elif v == "delete-knn-model":
return OperationType.DeleteKnnModel
else:
raise KeyError(f"No enum value for [{v}]")

Expand Down
Loading

0 comments on commit 01346bb

Please sign in to comment.