diff --git a/tests/test_text_cross_encoder.py b/tests/test_text_cross_encoder.py index b97b2473..86e4ea4b 100644 --- a/tests/test_text_cross_encoder.py +++ b/tests/test_text_cross_encoder.py @@ -40,7 +40,11 @@ def test_rerank(): @pytest.mark.parametrize( "model_name", - [model_name for model_name in CANONICAL_SCORE_VALUES.keys()], + [ + model_desc["model"] + for model_desc in TextCrossEncoder.list_supported_models() + if model_desc["size_in_GB"] < 1 and model_desc["model"] in CANONICAL_SCORE_VALUES.keys() + ], ) def test_batch_rerank(model_name): is_ci = os.getenv("CI")