Skip to content

Commit

Permalink
added model_kwargs to huggingface model (#1417)
Browse files Browse the repository at this point in the history
Co-authored-by: Nanbo Liu <[email protected]>
  • Loading branch information
nanbo-liu and Nanbo Liu authored Oct 18, 2023
1 parent b4374da commit 02e95d1
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 3 deletions.
2 changes: 1 addition & 1 deletion runtimes/huggingface/mlserver_huggingface/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ def load_pipeline_from_settings(
hf_settings: HuggingFaceSettings, settings: ModelSettings
) -> Pipeline:
pipeline = _get_pipeline_class(hf_settings)

batch_size = 1
if settings.max_batch_size:
batch_size = settings.max_batch_size
Expand Down Expand Up @@ -54,6 +53,7 @@ def load_pipeline_from_settings(
hf_pipeline = pipeline(
hf_settings.task_name,
model=model,
model_kwargs=hf_settings.model_kwargs,
tokenizer=tokenizer,
device=hf_settings.device,
batch_size=batch_size,
Expand Down
5 changes: 4 additions & 1 deletion runtimes/huggingface/mlserver_huggingface/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,10 @@ class Config:
"""
Name of the model that should be loaded in the pipeline.
"""

model_kwargs: Optional[dict] = None
"""
model kwargs that should be loaded in the pipeline.
"""
pretrained_tokenizer: Optional[str] = None
"""
Name of the tokenizer that should be loaded in the pipeline.
Expand Down
67 changes: 66 additions & 1 deletion runtimes/huggingface/tests/test_common.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from unittest.mock import MagicMock, patch

import pytest

import torch
from typing import Dict, Optional
from optimum.onnxruntime.modeling_ort import ORTModelForQuestionAnswering
from transformers.models.distilbert.modeling_distilbert import (
Expand Down Expand Up @@ -104,6 +104,71 @@ def test_pipeline_is_initialised_with_correct_model_param(
assert pipeline_call_args.kwargs["model"] == expected


@pytest.mark.parametrize(
    "model_kwargs, expected",
    [
        (None, None),
        (
            {"load_in_8bit": True},
            {"load_in_8bit": True},
        ),
    ],
)
@patch("mlserver_huggingface.common._get_pipeline_class")
def test_pipeline_is_initialised_with_correct_model_kwargs(
    mock_pipeline_factory,
    model_kwargs: Optional[dict],
    expected: Optional[dict],
):
    """Check that ``model_kwargs`` from the settings reach the pipeline call.

    ``_get_pipeline_class`` is patched so that the pipeline factory it
    returns is a ``MagicMock``; the test then inspects the keyword arguments
    that factory was called with.

    Args:
        mock_pipeline_factory: Mock injected by ``@patch`` in place of
            ``_get_pipeline_class``.
        model_kwargs: Value set on ``HuggingFaceSettings.model_kwargs``.
        expected: Value the factory should receive as its ``model_kwargs``
            keyword argument (``None`` or a dict, mirroring the input).
    """
    mock_pipeline_factory.return_value = MagicMock()

    hf_settings = HuggingFaceSettings(model_kwargs=model_kwargs)
    model_params = ModelParameters(uri="dummy_uri")
    model_settings = ModelSettings(
        name="foo", implementation=HuggingFaceRuntime, parameters=model_params
    )
    # The returned pipeline is itself a mock; only the call arguments matter.
    load_pipeline_from_settings(hf_settings, model_settings)

    mock_pipeline_factory.return_value.assert_called_once()
    pipeline_call_args = mock_pipeline_factory.return_value.call_args

    assert pipeline_call_args.kwargs["model_kwargs"] == expected


@pytest.mark.parametrize(
    "pretrained_model, model_kwargs, expected",
    [
        pytest.param(
            "hf-internal-testing/tiny-bert-for-token-classification",
            {"torch_dtype": torch.float16},
            torch.float16,
        ),
        pytest.param(
            "hf-internal-testing/tiny-bert-for-token-classification",
            None,
            torch.float32,
        ),
    ],
)
def test_pipeline_uses_model_kwargs(
    pretrained_model: str,
    model_kwargs: Optional[dict],
    expected: torch.dtype,
):
    """Load a pipeline from settings and compare the model's dtype.

    Each case builds ``HuggingFaceSettings`` for a token-classification task
    with the given ``model_kwargs`` and asserts that the loaded pipeline's
    model reports the ``expected`` dtype.
    """
    pipeline_settings = HuggingFaceSettings(
        pretrained_model=pretrained_model,
        task="token-classification",
        model_kwargs=model_kwargs,
    )
    runtime_settings = ModelSettings(name="foo", implementation=HuggingFaceRuntime)

    loaded_pipeline = load_pipeline_from_settings(pipeline_settings, runtime_settings)
    actual_dtype = loaded_pipeline.model.dtype

    assert actual_dtype == expected


@pytest.mark.parametrize(
"pretrained_model, task, input_batch_size, expected_batch_size",
[
Expand Down

0 comments on commit 02e95d1

Please sign in to comment.