Supporting more device types, and fixing an unpacking bug in model loading.
NickMcKillip committed Mar 5, 2024
1 parent 0c879cf commit 9344b1e
Showing 4 changed files with 3 additions and 158 deletions.
18 changes: 0 additions & 18 deletions runtimes/huggingface/README.md
@@ -66,24 +66,6 @@ Models in the HuggingFace hub can be loaded by specifying their name in `parameters.
If `parameters.extra.pretrained_model` is specified, it takes precedence over `parameters.uri`.
````

#### Model Inference
Model inference is handled by the HuggingFace pipeline, which allows users to run inference on a batch of inputs. Extra inference kwargs can be passed in `parameters.extra`.
```{code-block} json
{
  "inputs": [
    {
      "name": "text_inputs",
      "shape": [1],
      "datatype": "BYTES",
      "data": ["My kitten's name is JoJo,", "Tell me a story:"]
    }
  ],
  "parameters": {
    "extra": {"max_new_tokens": 200, "return_full_text": false}
  }
}
```
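
For illustration, a request like the one above can be sent to a running MLServer instance over the V2 inference REST protocol. A minimal sketch, assuming the server listens on the default HTTP port 8080 and the model is named `text-model` (both are assumptions, not fixed by this README):

```{code-block} python
# Sketch: POST the inference request above to MLServer's V2 REST endpoint.
# "text-model" and the host/port are assumptions for this example.
import requests

payload = {
    "inputs": [
        {
            "name": "text_inputs",
            "shape": [1],
            "datatype": "BYTES",
            "data": ["My kitten's name is JoJo,", "Tell me a story:"],
        }
    ],
    "parameters": {"extra": {"max_new_tokens": 200, "return_full_text": False}},
}

response = requests.post(
    "http://localhost:8080/v2/models/text-model/infer",
    json=payload,
)
print(response.json())
```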

### Reference

You can find the full reference of the accepted extra settings for the
17 changes: 0 additions & 17 deletions runtimes/huggingface/mlserver_huggingface/codecs/base.py
@@ -1,4 +1,3 @@
import logging
from typing import Optional, Type, Any, Dict, List, Union, Sequence
from mlserver.codecs.utils import (
has_decoded,
@@ -171,10 +170,6 @@ def encode_request(cls, payload: Dict[str, Any], **kwargs) -> InferenceRequest:

@classmethod
def decode_request(cls, request: InferenceRequest) -> Dict[str, Any]:
    """
    Decode Inference request into dictionary
    extra Inference kwargs are extracted from 'InferenceRequest.parameters.extra'
    """
    values = {}
    field_codecs = cls._find_decode_codecs(request)
    for item in request.inputs:
@@ -186,18 +181,6 @@ def decode_request(cls, request: InferenceRequest) -> Dict[str, Any]:

        value = get_decoded_or_raw(item)
        values[item.name] = value

    if request.parameters is not None:
        if hasattr(request.parameters, "extra"):
            extra = request.parameters.extra
            if isinstance(extra, dict):
                values.update(extra)
            else:
                logging.warn(
                    "Extra parameters is provided with "
                    + f"value '{extra}' and type '{type(extra)}' \n"
                    + "Extra parameters cannot be parsed, expected a dictionary."
                )
    return values


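For reference, the lines removed above are what merged `parameters.extra` into the decoded inputs. A minimal sketch of that pre-commit behavior, assuming `mlserver` and `mlserver_huggingface` are installed and that `HuggingfaceRequestCodec` is importable from `mlserver_huggingface.codecs`, as in this repo's tests:

```python
# Sketch: how HuggingfaceRequestCodec.decode_request behaved before this
# commit, with a dict-valued parameters.extra merged into the decoded values.
from mlserver.types import InferenceRequest, Parameters, RequestInput
from mlserver_huggingface.codecs import HuggingfaceRequestCodec

request = InferenceRequest(
    parameters=Parameters(content_type="str", extra={"max_new_tokens": 10}),
    inputs=[
        RequestInput(name="foo", datatype="BYTES", data=["bar"], shape=[1, 1]),
    ],
)

# Pre-commit result: {'foo': ['bar'], 'max_new_tokens': 10}
# Post-commit, with the merging removed, only {'foo': ['bar']} is returned.
print(HuggingfaceRequestCodec.decode_request(request))
```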
80 changes: 3 additions & 77 deletions runtimes/huggingface/tests/test_codecs.py
@@ -1,5 +1,5 @@
import pytest
import logging

from mlserver.types import (
    InferenceRequest,
    InferenceResponse,
@@ -28,87 +28,13 @@
                ]
            ),
            {"foo": ["bar1", "bar2"], "foo2": ["var1"]},
        ),
        (
            InferenceRequest(
                parameters=Parameters(content_type="str", extra={"foo3": "var2"}),
                inputs=[
                    RequestInput(
                        name="foo",
                        datatype="BYTES",
                        data=["bar1", "bar2"],
                        shape=[2, 1],
                    ),
                    RequestInput(
                        name="foo2", datatype="BYTES", data=["var1"], shape=[1, 1]
                    ),
                ],
            ),
            {"foo": ["bar1", "bar2"], "foo2": ["var1"], "foo3": "var2"},
        ),
    )
    ],
)
def test_decode_request(inference_request, expected):
    payload = HuggingfaceRequestCodec.decode_request(inference_request)
    assert payload == expected


@pytest.mark.parametrize(
    "inference_request, expected_payload, expected_log_msg",
    [
        (
            InferenceRequest(
                parameters=Parameters(content_type="str", extra="foo3"),
                inputs=[
                    RequestInput(
                        name="foo",
                        datatype="BYTES",
                        data=["bar1", "bar2"],
                        shape=[2, 1],
                    ),
                    RequestInput(
                        name="foo2", datatype="BYTES", data=["var1"], shape=[1, 1]
                    ),
                ],
            ),
            {"foo": ["bar1", "bar2"]},
            logging.warn(
                "Extra parameters is provided with "
                + "value: 'foo3' and type '<class 'str'> \n"
                + "Extra parameters cannot be parsed, expected a dictionary."
            ),
        ),
        (
            InferenceRequest(
                parameters=Parameters(content_type="str", extra=1234),
                inputs=[
                    RequestInput(
                        name="foo",
                        datatype="BYTES",
                        data=["bar1", "bar2"],
                        shape=[2, 1],
                    ),
                    RequestInput(
                        name="foo2", datatype="BYTES", data=["var1"], shape=[1, 1]
                    ),
                ],
            ),
            {"foo": ["bar1", "bar2"]},
            logging.warn(
                "Extra parameters is provided with "
                + "value '1234' and type '<class 'int'> \n"
                + "Extra parameters cannot be parsed, expected a dictionary."
            ),
        ),
    ],
)
def test_decode_request_with_invalid_parameter_extra(
    inference_request, expected_payload, expected_log_msg, caplog
):
    caplog.set_level(logging.WARN)
    payload = HuggingfaceRequestCodec.decode_request(inference_request)
    assert payload == expected_payload
    assert expected_log_msg in caplog.text


@pytest.mark.parametrize(
46 changes: 0 additions & 46 deletions runtimes/huggingface/tests/test_common.py
@@ -250,49 +250,3 @@ def test_pipeline_checks_for_eos_and_pad_token(
    m = load_pipeline_from_settings(hf_settings, model_settings)

    assert m._batch_size == expected_batch_size


@pytest.mark.parametrize(
    "inference_kwargs, expected_num_tokens",
    [
        ({"max_new_tokens": 10, "return_full_text": False}, 10),
        ({"max_new_tokens": 20, "return_full_text": False}, 20),
    ],
)
async def test_pipeline_uses_inference_kwargs(
    inference_kwargs: Optional[dict],
    expected_num_tokens: int,
):
    model_settings = ModelSettings(
        name="foo",
        implementation=HuggingFaceRuntime,
        parameters=ModelParameters(
            extra={
                "pretrained_model": "Maykeye/TinyLLama-v0",
                "task": "text-generation",
            }
        ),
    )
    runtime = HuggingFaceRuntime(model_settings)
    runtime.ready = await runtime.load()
    payload = InferenceRequest(
        inputs=[
            RequestInput(
                name="args",
                shape=[1],
                datatype="BYTES",
                data=["This is a test"],
            )
        ],
        parameters=Parameters(extra=inference_kwargs),
    )
    tokenizer = runtime._model.tokenizer

    prediction = await runtime.predict(payload)
    decoded_prediction = MultiInputRequestCodec.decode_response(prediction)
    if isinstance(decoded_prediction, dict):
        generated_text = decoded_prediction["output"][0]["generated_text"]
    assert isinstance(generated_text, str)
    tokenized_generated_text = tokenizer.tokenize(generated_text)
    num_predicted_tokens = len(tokenized_generated_text)
    assert num_predicted_tokens == expected_num_tokens
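
The deleted test above checked that generation kwargs passed through `parameters.extra` reach the underlying pipeline. The same effect can be seen with a bare `transformers` pipeline; a rough sketch using the `Maykeye/TinyLLama-v0` checkpoint from the test, assuming the `transformers` package is installed:

```python
# Rough sketch of what the deleted test asserted: kwargs such as
# max_new_tokens and return_full_text are forwarded to the pipeline call,
# bounding how many new tokens come back.
from transformers import pipeline

generator = pipeline("text-generation", model="Maykeye/TinyLLama-v0")
output = generator(
    "This is a test",
    max_new_tokens=10,       # at most 10 newly generated tokens
    return_full_text=False,  # exclude the prompt from the output
)
print(output[0]["generated_text"])
```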
