diff --git a/runtimes/huggingface/README.md b/runtimes/huggingface/README.md
index 04db06ccb..ba15a39ca 100644
--- a/runtimes/huggingface/README.md
+++ b/runtimes/huggingface/README.md
@@ -66,6 +66,24 @@ Models in the HuggingFace hub can be loaded by specifying their name in `paramet
 If `parameters.extra.pretrained_model` is specified, it takes precedence over `parameters.uri`.
 ````
+#### Model Inference
+Model inference is handled by the HuggingFace pipeline, which lets users run inference on a batch of inputs. Extra inference kwargs can be passed through `parameters.extra`.
+```{code-block} json
+{
+  "inputs": [
+    {
+      "name": "text_inputs",
+      "shape": [2],
+      "datatype": "BYTES",
+      "data": ["My kitten's name is JoJo,", "Tell me a story:"]
+    }
+  ],
+  "parameters": {
+    "extra": {"max_new_tokens": 200, "return_full_text": false}
+  }
+}
+```
+
 
 ### Reference
 
 You can find the full reference of the accepted extra settings for the
diff --git a/runtimes/huggingface/mlserver_huggingface/codecs/base.py b/runtimes/huggingface/mlserver_huggingface/codecs/base.py
index 4fdfcbb08..3ea4a5be6 100644
--- a/runtimes/huggingface/mlserver_huggingface/codecs/base.py
+++ b/runtimes/huggingface/mlserver_huggingface/codecs/base.py
@@ -1,3 +1,4 @@
+import logging
 from typing import Optional, Type, Any, Dict, List, Union, Sequence
 from mlserver.codecs.utils import (
     has_decoded,
@@ -170,6 +171,10 @@ def encode_request(cls, payload: Dict[str, Any], **kwargs) -> InferenceRequest:
 
     @classmethod
     def decode_request(cls, request: InferenceRequest) -> Dict[str, Any]:
+        """
+        Decode an inference request into a dictionary.
+        Extra inference kwargs are extracted from `InferenceRequest.parameters.extra`.
+        """
         values = {}
         field_codecs = cls._find_decode_codecs(request)
         for item in request.inputs:
@@ -181,6 +186,18 @@ def decode_request(cls, request: InferenceRequest) -> Dict[str, Any]:
                 value = get_decoded_or_raw(item)
 
             values[item.name] = value
+
+        if request.parameters is not None:
+            if hasattr(request.parameters, "extra"):
+                extra = request.parameters.extra
+                if isinstance(extra, dict):
+                    values.update(extra)
+                else:
+                    logging.warning(
+                        "Extra parameters were provided with "
+                        + f"value '{extra}' and type '{type(extra)}'.\n"
+                        + "Extra parameters cannot be parsed; expected a dictionary."
+                    )
         return values
 
diff --git a/runtimes/huggingface/mlserver_huggingface/runtime.py b/runtimes/huggingface/mlserver_huggingface/runtime.py
index fed17eb35..c21d4c141 100644
--- a/runtimes/huggingface/mlserver_huggingface/runtime.py
+++ b/runtimes/huggingface/mlserver_huggingface/runtime.py
@@ -25,12 +25,14 @@ def __init__(self, settings: ModelSettings):
 
     async def load(self) -> bool:
         logger.info(f"Loading model for task '{self.hf_settings.task_name}'...")
         loop = asyncio.get_running_loop()
-        self._model = await asyncio.gather(loop.run_in_executor(
-            None,
-            load_pipeline_from_settings,
-            self.hf_settings,
-            self.settings,
-        ))
+        [self._model] = await asyncio.gather(
+            loop.run_in_executor(
+                None,
+                load_pipeline_from_settings,
+                self.hf_settings,
+                self.settings,
+            )
+        )
         self._merge_metadata()
         return True
 
diff --git a/runtimes/huggingface/mlserver_huggingface/settings.py b/runtimes/huggingface/mlserver_huggingface/settings.py
index b461080f3..7f7599aa6 100644
--- a/runtimes/huggingface/mlserver_huggingface/settings.py
+++ b/runtimes/huggingface/mlserver_huggingface/settings.py
@@ -83,10 +83,10 @@ class Config:
     runtime.
""" - device: int = -1 + device: Optional[Union[int, str]] = None """ Device in which this pipeline will be loaded (e.g., "cpu", "cuda:1", "mps", - or a GPU ordinal rank like 1). + or a GPU ordinal rank like 1). Default value of None becomes cpu. """ inter_op_threads: Optional[int] = None diff --git a/runtimes/huggingface/tests/test_codecs.py b/runtimes/huggingface/tests/test_codecs.py index 0aead7663..4c3395ca8 100644 --- a/runtimes/huggingface/tests/test_codecs.py +++ b/runtimes/huggingface/tests/test_codecs.py @@ -1,5 +1,5 @@ import pytest - +import logging from mlserver.types import ( InferenceRequest, InferenceResponse, @@ -28,15 +28,89 @@ ] ), {"foo": ["bar1", "bar2"], "foo2": ["var1"]}, - ) + ), + ( + InferenceRequest( + parameters=Parameters(content_type="str", extra={"foo3": "var2"}), + inputs=[ + RequestInput( + name="foo", + datatype="BYTES", + data=["bar1", "bar2"], + shape=[2, 1], + ), + RequestInput( + name="foo2", datatype="BYTES", data=["var1"], shape=[1, 1] + ), + ], + ), + {"foo": ["bar1", "bar2"], "foo2": ["var1"], "foo3": "var2"}, + ), ], ) def test_decode_request(inference_request, expected): payload = HuggingfaceRequestCodec.decode_request(inference_request) - assert payload == expected +@pytest.mark.parametrize( + "inference_request, expected_payload, expected_log_msg", + [ + ( + InferenceRequest( + parameters=Parameters(content_type="str", extra="foo3"), + inputs=[ + RequestInput( + name="foo", + datatype="BYTES", + data=["bar1", "bar2"], + shape=[2, 1], + ), + RequestInput( + name="foo2", datatype="BYTES", data=["var1"], shape=[1, 1] + ), + ], + ), + {"foo": ["bar1", "bar2"]}, + logging.warn( + "Extra parameters is provided with ", + +"value: 'foo3' and type ' \n", + +"Extra parameters cannot be parsed, expected a dictionary.", + ), + ), + ( + InferenceRequest( + parameters=Parameters(content_type="str", extra=1234), + inputs=[ + RequestInput( + name="foo", + datatype="BYTES", + data=["bar1", "bar2"], + shape=[2, 1], + ), + RequestInput( + name="foo2", datatype="BYTES", data=["var1"], shape=[1, 1] + ), + ], + ), + {"foo": ["bar1", "bar2"]}, + logging.warn( + "Extra parameters is provided with " + + "value '1234' and type ' \n", + +"Extra parameters cannot be parsed, expected a dictionary.", + ), + ), + ], +) +def test_decode_request_with_invalid_parameter_extra( + inference_request, expected_payload, expected_log_msg, caplog +): + caplog.set_level(logging.WARN) + payload = HuggingfaceRequestCodec.decode_request(inference_request) + assert payload == expected_payload + assert expected_log_msg in caplog.text + + @pytest.mark.parametrize( "payload, use_bytes, expected", [ diff --git a/runtimes/huggingface/tests/test_common.py b/runtimes/huggingface/tests/test_common.py index d7b60c51d..a96ea8dd4 100644 --- a/runtimes/huggingface/tests/test_common.py +++ b/runtimes/huggingface/tests/test_common.py @@ -2,7 +2,7 @@ import pytest import torch -from typing import Dict, Optional +from typing import Dict, Optional, Union from optimum.onnxruntime.modeling_ort import ORTModelForQuestionAnswering from transformers.models.distilbert.modeling_distilbert import ( DistilBertForQuestionAnswering, @@ -13,6 +13,9 @@ from mlserver_huggingface.runtime import HuggingFaceRuntime from mlserver_huggingface.settings import HuggingFaceSettings from mlserver_huggingface.common import load_pipeline_from_settings +from mlserver.types import InferenceRequest, RequestInput +from mlserver.types.dataplane import Parameters +from mlserver_huggingface.codecs.base import MultiInputRequestCodec 
 
 
 @pytest.mark.parametrize(
@@ -169,6 +172,43 @@ def test_pipeline_uses_model_kwargs(
     assert m.model.dtype == expected
 
 
+@pytest.mark.parametrize(
+    "pretrained_model, device, expected",
+    [
+        (
+            "hf-internal-testing/tiny-bert-for-token-classification",
+            None,
+            torch.device("cpu"),
+        ),
+        (
+            "hf-internal-testing/tiny-bert-for-token-classification",
+            -1,
+            torch.device("cpu"),
+        ),
+        (
+            "hf-internal-testing/tiny-bert-for-token-classification",
+            "cpu",
+            torch.device("cpu"),
+        ),
+    ],
+)
+def test_pipeline_cpu_device_set(
+    pretrained_model: str,
+    device: Optional[Union[str, int]],
+    expected: torch.device,
+):
+    hf_settings = HuggingFaceSettings(
+        pretrained_model=pretrained_model, task="token-classification", device=device
+    )
+    model_settings = ModelSettings(
+        name="foo",
+        implementation=HuggingFaceRuntime,
+    )
+    m = load_pipeline_from_settings(hf_settings, model_settings)
+
+    assert m.model.device == expected
+
+
 @pytest.mark.parametrize(
     "pretrained_model, task, input_batch_size, expected_batch_size",
     [
@@ -210,3 +250,49 @@ def test_pipeline_checks_for_eos_and_pad_token(
     m = load_pipeline_from_settings(hf_settings, model_settings)
 
     assert m._batch_size == expected_batch_size
+
+
+@pytest.mark.parametrize(
+    "inference_kwargs, expected_num_tokens",
+    [
+        ({"max_new_tokens": 10, "return_full_text": False}, 10),
+        ({"max_new_tokens": 20, "return_full_text": False}, 20),
+    ],
+)
+async def test_pipeline_uses_inference_kwargs(
+    inference_kwargs: Optional[dict],
+    expected_num_tokens: int,
+):
+    model_settings = ModelSettings(
+        name="foo",
+        implementation=HuggingFaceRuntime,
+        parameters=ModelParameters(
+            extra={
+                "pretrained_model": "Maykeye/TinyLLama-v0",
+                "task": "text-generation",
+            }
+        ),
+    )
+    runtime = HuggingFaceRuntime(model_settings)
+    runtime.ready = await runtime.load()
+    payload = InferenceRequest(
+        inputs=[
+            RequestInput(
+                name="args",
+                shape=[1],
+                datatype="BYTES",
+                data=["This is a test"],
+            )
+        ],
+        parameters=Parameters(extra=inference_kwargs),
+    )
+    tokenizer = runtime._model.tokenizer
+
+    prediction = await runtime.predict(payload)
+    decoded_prediction = MultiInputRequestCodec.decode_response(prediction)
+    if isinstance(decoded_prediction, dict):
+        generated_text = decoded_prediction["output"][0]["generated_text"]
+        assert isinstance(generated_text, str)
+        tokenized_generated_text = tokenizer.tokenize(generated_text)
+        num_predicted_tokens = len(tokenized_generated_text)
+        assert num_predicted_tokens == expected_num_tokens
diff --git a/runtimes/huggingface/tests/test_settings.py b/runtimes/huggingface/tests/test_settings.py
index 01a4c71a9..553918954 100644
--- a/runtimes/huggingface/tests/test_settings.py
+++ b/runtimes/huggingface/tests/test_settings.py
@@ -98,7 +98,7 @@ def test_merge_huggingface_settings_extra_raises(model_settings_extra_none):
             pretrained_tokenizer=None,
             framework=None,
             optimum_model=False,
-            device=-1,
+            device=None,
             inter_op_threads=None,
             intra_op_threads=None,
         ),
@@ -113,7 +113,7 @@ def test_merge_huggingface_settings_extra_raises(model_settings_extra_none):
             pretrained_tokenizer=None,
             framework=None,
             optimum_model=False,
-            device=-1,
+            device=None,
             inter_op_threads=None,
             intra_op_threads=None,
         ),
@@ -128,7 +128,7 @@ def test_merge_huggingface_settings_extra_raises(model_settings_extra_none):
             pretrained_tokenizer=None,
             framework=None,
             optimum_model=False,
-            device=-1,
+            device=None,
             inter_op_threads=None,
             intra_op_threads=None,
         ),
@@ -143,7 +143,7 @@ def test_merge_huggingface_settings_extra_raises(model_settings_extra_none):
             pretrained_tokenizer=None,
             framework=None,
             optimum_model=False,
-            device=-1,
+            device=None,
             inter_op_threads=None,
             intra_op_threads=None,
         ),
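
Reviewer note: the sketch below (not part of the diff above) shows how the new `parameters.extra` inference kwargs documented in the README hunk could be exercised against a running MLServer instance. The host, port, and model name (`text-generation-model`) are assumptions; adjust them to the actual deployment and `model-settings.json`.

```python
import requests

# Assumed deployment details (hypothetical): MLServer running locally on its
# default HTTP port, serving the HuggingFace model under the name
# "text-generation-model".
INFERENCE_URL = "http://localhost:8080/v2/models/text-generation-model/infer"

payload = {
    "inputs": [
        {
            "name": "text_inputs",
            "shape": [2],
            "datatype": "BYTES",
            "data": ["My kitten's name is JoJo,", "Tell me a story:"],
        }
    ],
    # Extra inference kwargs are merged into the pipeline call by the
    # codecs/base.py change above.
    "parameters": {"extra": {"max_new_tokens": 200, "return_full_text": False}},
}

response = requests.post(INFERENCE_URL, json=payload)
response.raise_for_status()

# Outputs follow the V2 inference protocol; the generated text is returned in
# the "data" field of each output.
for output in response.json()["outputs"]:
    print(output["name"], output["data"])
```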