From b6d9ae6b8cf7ec3ecdfbce74d14db5dd5b2de3cf Mon Sep 17 00:00:00 2001 From: Nanbo Liu Date: Tue, 3 Oct 2023 17:22:26 +0000 Subject: [PATCH] added unit tests --- runtimes/huggingface/tests/test_common.py | 68 +++++++++++++++++++++++ 1 file changed, 68 insertions(+) diff --git a/runtimes/huggingface/tests/test_common.py b/runtimes/huggingface/tests/test_common.py index df58069d7..9a33c1852 100644 --- a/runtimes/huggingface/tests/test_common.py +++ b/runtimes/huggingface/tests/test_common.py @@ -104,6 +104,74 @@ def test_pipeline_is_initialised_with_correct_model_param( assert pipeline_call_args.kwargs["model"] == expected +@pytest.mark.parametrize( + "model_kwargs, expected", + [ + (None, None), + ( + {"load_in_8bit": True}, + {"load_in_8bit": True}, + ), + ], +) +@patch("mlserver_huggingface.common._get_pipeline_class") +def test_pipeline_is_initialised_with_correct_model_kwargs( + mock_pipeline_factory, + model_kwargs: Optional[dict], + expected: Optional[str], +): + mock_pipeline_factory.return_value = MagicMock() + + hf_settings = HuggingFaceSettings(model_kwargs=model_kwargs) + + model_settings = ModelSettings( + name="foo", + implementation=HuggingFaceRuntime, + ) + + _ = load_pipeline_from_settings(hf_settings, model_settings) + + mock_pipeline_factory.return_value.assert_called_once() + pipeline_call_args = mock_pipeline_factory.return_value.call_args + + assert pipeline_call_args.kwargs["model_kwargs"] == expected + + +@pytest.mark.parametrize( + "pretrained_model, model_kwargs, expected_model_kwargs", + [ + ( + "hf-internal-testing/tiny-bert-for-token-classification", + {"load_in_8bit": True}, + True, + ), + ( + "hf-internal-testing/tiny-bert-for-token-classification", + None, + False, + ), + ], +) +def test_pipeline_uses_model_kwargs( + pretrained_model: str, + model_kwargs: Optional[dict], + expected_model_kwargs: Optional[str], +): + hf_settings = HuggingFaceSettings( + pretrained_model=pretrained_model, + task="token-classification", + model_kwargs=model_kwargs, + ) + model_settings = ModelSettings( + name="foo", + implementation=HuggingFaceRuntime, + ) + + m = load_pipeline_from_settings(hf_settings, model_settings) + + assert m.model.is_loaded_in_8bit == expected_model_kwargs + + @pytest.mark.parametrize( "pretrained_model, task, input_batch_size, expected_batch_size", [