diff --git a/runtimes/mlflow/mlserver_mlflow/runtime.py b/runtimes/mlflow/mlserver_mlflow/runtime.py
index 37bf0a039..ffb322fe0 100644
--- a/runtimes/mlflow/mlserver_mlflow/runtime.py
+++ b/runtimes/mlflow/mlserver_mlflow/runtime.py
@@ -196,5 +196,8 @@ def _sync_metadata(self) -> None:
 
     async def predict(self, payload: InferenceRequest) -> InferenceResponse:
         decoded_payload = self.decode_request(payload)
-        model_output = self._model.predict(decoded_payload)
+        params = None
+        if payload.parameters and payload.parameters.model_extra:
+            params = payload.parameters.model_extra
+        model_output = self._model.predict(decoded_payload, params=params)
         return self.encode_response(model_output, default_codec=TensorDictCodec)
diff --git a/runtimes/mlflow/tests/test_runtime.py b/runtimes/mlflow/tests/test_runtime.py
index 8531ff987..409d3f48c 100644
--- a/runtimes/mlflow/tests/test_runtime.py
+++ b/runtimes/mlflow/tests/test_runtime.py
@@ -4,7 +4,6 @@
 
 import pandas as pd
 from typing import Any
-
 from mlserver.codecs import NumpyCodec, PandasCodec, StringCodec
 from mlserver.types import (
     InferenceRequest,
@@ -257,3 +256,51 @@ async def test_invocation_with_params(
         predict_mock.call_args[0][0].get("foo"), expected["data"]["foo"]
     )
     assert predict_mock.call_args.kwargs["params"] == expected["params"]
+
+
+@pytest.mark.parametrize(
+    "input, params",
+    [
+        (
+            InferenceRequest(
+                parameters=Parameters(
+                    content_type=NumpyCodec.ContentType, extra_param="extra_value"
+                ),
+                inputs=[
+                    RequestInput(
+                        name="predict",
+                        shape=[1, 10],
+                        data=[range(0, 10)],
+                        datatype="INT64",
+                        parameters=Parameters(extra_param2="extra_value2"),
+                    )
+                ],
+            ),
+            {"extra_param": "extra_value"},
+        ),
+        (
+            InferenceRequest(
+                parameters=Parameters(content_type=NumpyCodec.ContentType),
+                inputs=[
+                    RequestInput(
+                        name="predict",
+                        shape=[1, 10],
+                        data=[range(0, 10)],
+                        datatype="INT64",
+                    )
+                ],
+            ),
+            None,
+        ),
+    ],
+)
+async def test_predict_with_params(
+    runtime: MLflowRuntime,
+    input: InferenceRequest,
+    params: dict,
+):
+    with mock.patch.object(
+        runtime._model, "predict", return_value={"test": np.array([1, 2, 3])}
+    ) as predict_mock:
+        await runtime.predict(input)
+        assert predict_mock.call_args.kwargs == {"params": params}