Commit

add in george and nanbos PRs and fix small bug with _model
geodavic authored and NickMcKillip committed Mar 4, 2024
1 parent af1086a commit 0c879cf
Showing 7 changed files with 213 additions and 16 deletions.
18 changes: 18 additions & 0 deletions runtimes/huggingface/README.md
@@ -66,6 +66,24 @@ Models in the HuggingFace hub can be loaded by specifying their name in `parameters.
If `parameters.extra.pretrained_model` is specified, it takes precedence over `parameters.uri`.
````

#### Model Inference
Model inference is performed by the HuggingFace pipeline, which supports running inference on a batch of inputs. Extra inference kwargs can be passed via `parameters.extra`.
```{code-block} json
{
  "inputs": [
    {
      "name": "text_inputs",
      "shape": [2],
      "datatype": "BYTES",
      "data": ["My kitten's name is JoJo,", "Tell me a story:"]
    }
  ],
  "parameters": {
    "extra": {"max_new_tokens": 200, "return_full_text": false}
  }
}
```
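For example, the request above can be posted to a running MLServer instance over the V2 REST API. This is a minimal sketch: the host, port and model name (`text-generation-model`) are placeholders for your deployment.

```{code-block} python
# Minimal sketch: POST the request above to MLServer's V2 inference endpoint.
# The host/port and the model name "text-generation-model" are placeholders.
import requests

inference_request = {
    "inputs": [
        {
            "name": "text_inputs",
            "shape": [2],
            "datatype": "BYTES",
            "data": ["My kitten's name is JoJo,", "Tell me a story:"],
        }
    ],
    "parameters": {"extra": {"max_new_tokens": 200, "return_full_text": False}},
}

response = requests.post(
    "http://localhost:8080/v2/models/text-generation-model/infer",
    json=inference_request,
)
print(response.json())
```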

### Reference

You can find the full reference of the accepted extra settings for the
17 changes: 17 additions & 0 deletions runtimes/huggingface/mlserver_huggingface/codecs/base.py
@@ -1,3 +1,4 @@
import logging
from typing import Optional, Type, Any, Dict, List, Union, Sequence
from mlserver.codecs.utils import (
    has_decoded,
@@ -170,6 +171,10 @@ def encode_request(cls, payload: Dict[str, Any], **kwargs) -> InferenceRequest:

    @classmethod
    def decode_request(cls, request: InferenceRequest) -> Dict[str, Any]:
        """
        Decode an InferenceRequest into a dictionary.
        Extra inference kwargs are extracted from `InferenceRequest.parameters.extra`.
        """
        values = {}
        field_codecs = cls._find_decode_codecs(request)
        for item in request.inputs:
@@ -181,6 +186,18 @@ def decode_request(cls, request: InferenceRequest) -> Dict[str, Any]:

            value = get_decoded_or_raw(item)
            values[item.name] = value

        if request.parameters is not None:
            if hasattr(request.parameters, "extra"):
                extra = request.parameters.extra
                if isinstance(extra, dict):
                    values.update(extra)
                else:
                    logging.warn(
                        "Extra parameters is provided with "
                        + f"value '{extra}' and type '{type(extra)}' \n"
                        + "Extra parameters cannot be parsed, expected a dictionary."
                    )
        return values


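For illustration, here is a minimal sketch of the new `decode_request` behaviour: extra kwargs from `parameters.extra` are merged into the same dictionary as the decoded inputs. It assumes `HuggingfaceRequestCodec` is importable from `mlserver_huggingface.codecs`, as in the tests further down.

```python
# Hedged sketch of the new decode_request behaviour; the codec import path is
# assumed to match the one used by the tests (mlserver_huggingface.codecs).
from mlserver.types import InferenceRequest, RequestInput
from mlserver.types.dataplane import Parameters
from mlserver_huggingface.codecs import HuggingfaceRequestCodec

request = InferenceRequest(
    parameters=Parameters(content_type="str", extra={"max_new_tokens": 10}),
    inputs=[
        RequestInput(name="foo", datatype="BYTES", data=["bar1"], shape=[1, 1]),
    ],
)

# Decoded inputs and extra kwargs end up in the same flat dictionary:
# {"foo": ["bar1"], "max_new_tokens": 10}
print(HuggingfaceRequestCodec.decode_request(request))
```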
14 changes: 8 additions & 6 deletions runtimes/huggingface/mlserver_huggingface/runtime.py
@@ -25,12 +25,14 @@ def __init__(self, settings: ModelSettings):
    async def load(self) -> bool:
        logger.info(f"Loading model for task '{self.hf_settings.task_name}'...")
        loop = asyncio.get_running_loop()
        self._model = await asyncio.gather(loop.run_in_executor(
            None,
            load_pipeline_from_settings,
            self.hf_settings,
            self.settings,
        ))
        [self._model] = await asyncio.gather(
            loop.run_in_executor(
                None,
                load_pipeline_from_settings,
                self.hf_settings,
                self.settings,
            )
        )
        self._merge_metadata()
        return True

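The `_model` bug mentioned in the commit message comes from `asyncio.gather` always resolving to a list of results, even for a single awaitable. A minimal, standalone sketch of the behaviour (plain asyncio, no MLServer imports):

```python
import asyncio


async def load_pipeline():
    return "pipeline"  # stand-in for load_pipeline_from_settings


async def main():
    # Before the fix: _model is a one-element list, not the pipeline itself.
    _model = await asyncio.gather(load_pipeline())
    assert _model == ["pipeline"]

    # After the fix: unpacking binds _model to the pipeline object.
    [_model] = await asyncio.gather(load_pipeline())
    assert _model == "pipeline"


asyncio.run(main())
```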
4 changes: 2 additions & 2 deletions runtimes/huggingface/mlserver_huggingface/settings.py
@@ -83,10 +83,10 @@ class Config:
    runtime.
    """

    device: int = -1
    device: Optional[Union[int, str]] = None
    """
    Device in which this pipeline will be loaded (e.g., "cpu", "cuda:1", "mps",
    or a GPU ordinal rank like 1).
    or a GPU ordinal rank like 1). The default value of None falls back to CPU.
    """

    inter_op_threads: Optional[int] = None
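A minimal sketch of the widened `device` field, mirroring `test_pipeline_cpu_device_set` further down; the pretrained model name is only an example.

```python
# Hedged sketch: device now accepts None, a GPU ordinal, or a device string.
from mlserver_huggingface.settings import HuggingFaceSettings

for device in (None, -1, "cpu"):
    hf_settings = HuggingFaceSettings(
        pretrained_model="hf-internal-testing/tiny-bert-for-token-classification",
        task="token-classification",
        device=device,
    )
    print(device, "->", hf_settings.device)
```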
80 changes: 77 additions & 3 deletions runtimes/huggingface/tests/test_codecs.py
@@ -1,5 +1,5 @@
import pytest

import logging
from mlserver.types import (
    InferenceRequest,
    InferenceResponse,
@@ -28,15 +28,89 @@
                ]
            ),
            {"foo": ["bar1", "bar2"], "foo2": ["var1"]},
        )
        ),
        (
            InferenceRequest(
                parameters=Parameters(content_type="str", extra={"foo3": "var2"}),
                inputs=[
                    RequestInput(
                        name="foo",
                        datatype="BYTES",
                        data=["bar1", "bar2"],
                        shape=[2, 1],
                    ),
                    RequestInput(
                        name="foo2", datatype="BYTES", data=["var1"], shape=[1, 1]
                    ),
                ],
            ),
            {"foo": ["bar1", "bar2"], "foo2": ["var1"], "foo3": "var2"},
        ),
    ],
)
def test_decode_request(inference_request, expected):
    payload = HuggingfaceRequestCodec.decode_request(inference_request)

    assert payload == expected


@pytest.mark.parametrize(
    "inference_request, expected_payload, expected_log_msg",
    [
        (
            InferenceRequest(
                parameters=Parameters(content_type="str", extra="foo3"),
                inputs=[
                    RequestInput(
                        name="foo",
                        datatype="BYTES",
                        data=["bar1", "bar2"],
                        shape=[2, 1],
                    ),
                    RequestInput(
                        name="foo2", datatype="BYTES", data=["var1"], shape=[1, 1]
                    ),
                ],
            ),
            {"foo": ["bar1", "bar2"], "foo2": ["var1"]},
            "Extra parameters is provided with "
            "value 'foo3' and type '<class 'str'>' \n"
            "Extra parameters cannot be parsed, expected a dictionary.",
        ),
        (
            InferenceRequest(
                parameters=Parameters(content_type="str", extra=1234),
                inputs=[
                    RequestInput(
                        name="foo",
                        datatype="BYTES",
                        data=["bar1", "bar2"],
                        shape=[2, 1],
                    ),
                    RequestInput(
                        name="foo2", datatype="BYTES", data=["var1"], shape=[1, 1]
                    ),
                ],
            ),
            {"foo": ["bar1", "bar2"], "foo2": ["var1"]},
            "Extra parameters is provided with "
            "value '1234' and type '<class 'int'>' \n"
            "Extra parameters cannot be parsed, expected a dictionary.",
        ),
    ],
)
def test_decode_request_with_invalid_parameter_extra(
    inference_request, expected_payload, expected_log_msg, caplog
):
    caplog.set_level(logging.WARN)
    payload = HuggingfaceRequestCodec.decode_request(inference_request)
    assert payload == expected_payload
    assert expected_log_msg in caplog.text


@pytest.mark.parametrize(
"payload, use_bytes, expected",
[
88 changes: 87 additions & 1 deletion runtimes/huggingface/tests/test_common.py
@@ -2,7 +2,7 @@

import pytest
import torch
from typing import Dict, Optional
from typing import Dict, Optional, Union
from optimum.onnxruntime.modeling_ort import ORTModelForQuestionAnswering
from transformers.models.distilbert.modeling_distilbert import (
    DistilBertForQuestionAnswering,
@@ -13,6 +13,9 @@
from mlserver_huggingface.runtime import HuggingFaceRuntime
from mlserver_huggingface.settings import HuggingFaceSettings
from mlserver_huggingface.common import load_pipeline_from_settings
from mlserver.types import InferenceRequest, RequestInput
from mlserver.types.dataplane import Parameters
from mlserver_huggingface.codecs.base import MultiInputRequestCodec


@pytest.mark.parametrize(
@@ -169,6 +172,43 @@ def test_pipeline_uses_model_kwargs(
    assert m.model.dtype == expected


@pytest.mark.parametrize(
    "pretrained_model, device, expected",
    [
        (
            "hf-internal-testing/tiny-bert-for-token-classification",
            None,
            torch.device("cpu"),
        ),
        (
            "hf-internal-testing/tiny-bert-for-token-classification",
            -1,
            torch.device("cpu"),
        ),
        (
            "hf-internal-testing/tiny-bert-for-token-classification",
            "cpu",
            torch.device("cpu"),
        ),
    ],
)
def test_pipeline_cpu_device_set(
    pretrained_model: str,
    device: Optional[Union[str, int]],
    expected: torch.device,
):
    hf_settings = HuggingFaceSettings(
        pretrained_model=pretrained_model, task="token-classification", device=device
    )
    model_settings = ModelSettings(
        name="foo",
        implementation=HuggingFaceRuntime,
    )
    m = load_pipeline_from_settings(hf_settings, model_settings)

    assert m.model.device == expected


@pytest.mark.parametrize(
"pretrained_model, task, input_batch_size, expected_batch_size",
[
Expand Down Expand Up @@ -210,3 +250,49 @@ def test_pipeline_checks_for_eos_and_pad_token(
m = load_pipeline_from_settings(hf_settings, model_settings)

assert m._batch_size == expected_batch_size


@pytest.mark.parametrize(
    "inference_kwargs, expected_num_tokens",
    [
        ({"max_new_tokens": 10, "return_full_text": False}, 10),
        ({"max_new_tokens": 20, "return_full_text": False}, 20),
    ],
)
async def test_pipeline_uses_inference_kwargs(
    inference_kwargs: Optional[dict],
    expected_num_tokens: int,
):
    model_settings = ModelSettings(
        name="foo",
        implementation=HuggingFaceRuntime,
        parameters=ModelParameters(
            extra={
                "pretrained_model": "Maykeye/TinyLLama-v0",
                "task": "text-generation",
            }
        ),
    )
    runtime = HuggingFaceRuntime(model_settings)
    runtime.ready = await runtime.load()
    payload = InferenceRequest(
        inputs=[
            RequestInput(
                name="args",
                shape=[1],
                datatype="BYTES",
                data=["This is a test"],
            )
        ],
        parameters=Parameters(extra=inference_kwargs),
    )
    tokenizer = runtime._model.tokenizer

    prediction = await runtime.predict(payload)
    decoded_prediction = MultiInputRequestCodec.decode_response(prediction)
    if isinstance(decoded_prediction, dict):
        generated_text = decoded_prediction["output"][0]["generated_text"]
    assert isinstance(generated_text, str)
    tokenized_generated_text = tokenizer.tokenize(generated_text)
    num_predicted_tokens = len(tokenized_generated_text)
    assert num_predicted_tokens == expected_num_tokens
8 changes: 4 additions & 4 deletions runtimes/huggingface/tests/test_settings.py
@@ -98,7 +98,7 @@ def test_merge_huggingface_settings_extra_raises(model_settings_extra_none):
                pretrained_tokenizer=None,
                framework=None,
                optimum_model=False,
                device=-1,
                device=None,
                inter_op_threads=None,
                intra_op_threads=None,
            ),
@@ -113,7 +113,7 @@ def test_merge_huggingface_settings_extra_raises(model_settings_extra_none):
                pretrained_tokenizer=None,
                framework=None,
                optimum_model=False,
                device=-1,
                device=None,
                inter_op_threads=None,
                intra_op_threads=None,
            ),
@@ -128,7 +128,7 @@ def test_merge_huggingface_settings_extra_raises(model_settings_extra_none):
                pretrained_tokenizer=None,
                framework=None,
                optimum_model=False,
                device=-1,
                device=None,
                inter_op_threads=None,
                intra_op_threads=None,
            ),
@@ -143,7 +143,7 @@ def test_merge_huggingface_settings_extra_raises(model_settings_extra_none):
                pretrained_tokenizer=None,
                framework=None,
                optimum_model=False,
                device=-1,
                device=None,
                inter_op_threads=None,
                intra_op_threads=None,
            ),
