From 41856eb806e11907a9052915a4a3951cb614f1a4 Mon Sep 17 00:00:00 2001 From: Nanbo Liu Date: Tue, 3 Oct 2023 17:25:45 +0000 Subject: [PATCH] fixed typo --- runtimes/huggingface/tests/test_common.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/runtimes/huggingface/tests/test_common.py b/runtimes/huggingface/tests/test_common.py index 9a33c1852..ecc9ab869 100644 --- a/runtimes/huggingface/tests/test_common.py +++ b/runtimes/huggingface/tests/test_common.py @@ -138,7 +138,7 @@ def test_pipeline_is_initialised_with_correct_model_kwargs( @pytest.mark.parametrize( - "pretrained_model, model_kwargs, expected_model_kwargs", + "pretrained_model, model_kwargs, expected", [ ( "hf-internal-testing/tiny-bert-for-token-classification", @@ -155,7 +155,7 @@ def test_pipeline_is_initialised_with_correct_model_kwargs( def test_pipeline_uses_model_kwargs( pretrained_model: str, model_kwargs: Optional[dict], - expected_model_kwargs: Optional[str], + expected: bool, ): hf_settings = HuggingFaceSettings( pretrained_model=pretrained_model, @@ -169,7 +169,7 @@ def test_pipeline_uses_model_kwargs( m = load_pipeline_from_settings(hf_settings, model_settings) - assert m.model.is_loaded_in_8bit == expected_model_kwargs + assert m.model.is_loaded_in_8bit == expected @pytest.mark.parametrize(