diff --git a/runtimes/huggingface/tests/test_common.py b/runtimes/huggingface/tests/test_common.py index a480f602b..aa6ad95bc 100644 --- a/runtimes/huggingface/tests/test_common.py +++ b/runtimes/huggingface/tests/test_common.py @@ -155,7 +155,7 @@ def test_pipeline_is_initialised_with_correct_model_kwargs( def test_pipeline_uses_model_kwargs( pretrained_model: str, model_kwargs: Optional[dict], - expected: bool, + expected: torch.dtype, ): hf_settings = HuggingFaceSettings( pretrained_model=pretrained_model,