Skip to content

Commit

Permalink
added unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Nanbo Liu committed Oct 3, 2023
1 parent e1a77d3 commit b6d9ae6
Showing 1 changed file with 68 additions and 0 deletions.
68 changes: 68 additions & 0 deletions runtimes/huggingface/tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,74 @@ def test_pipeline_is_initialised_with_correct_model_param(
assert pipeline_call_args.kwargs["model"] == expected


@pytest.mark.parametrize(
"model_kwargs, expected",
[
(None, None),
(
{"load_in_8bit": True},
{"load_in_8bit": True},
),
],
)
@patch("mlserver_huggingface.common._get_pipeline_class")
def test_pipeline_is_initialised_with_correct_model_kwargs(
mock_pipeline_factory,
model_kwargs: Optional[dict],
expected: Optional[str],
):
mock_pipeline_factory.return_value = MagicMock()

hf_settings = HuggingFaceSettings(model_kwargs=model_kwargs)

model_settings = ModelSettings(
name="foo",
implementation=HuggingFaceRuntime,
)

_ = load_pipeline_from_settings(hf_settings, model_settings)

mock_pipeline_factory.return_value.assert_called_once()
pipeline_call_args = mock_pipeline_factory.return_value.call_args

assert pipeline_call_args.kwargs["model_kwargs"] == expected


@pytest.mark.parametrize(
"pretrained_model, model_kwargs, expected_model_kwargs",
[
(
"hf-internal-testing/tiny-bert-for-token-classification",
{"load_in_8bit": True},
True,
),
(
"hf-internal-testing/tiny-bert-for-token-classification",
None,
False,
),
],
)
def test_pipeline_uses_model_kwargs(
pretrained_model: str,
model_kwargs: Optional[dict],
expected_model_kwargs: Optional[str],
):
hf_settings = HuggingFaceSettings(
pretrained_model=pretrained_model,
task="token-classification",
model_kwargs=model_kwargs,
)
model_settings = ModelSettings(
name="foo",
implementation=HuggingFaceRuntime,
)

m = load_pipeline_from_settings(hf_settings, model_settings)

assert m.model.is_loaded_in_8bit == expected_model_kwargs


@pytest.mark.parametrize(
"pretrained_model, task, input_batch_size, expected_batch_size",
[
Expand Down

0 comments on commit b6d9ae6

Please sign in to comment.