diff --git a/python/pyspark/ml/tests/connect/test_legacy_mode_feature.py b/python/pyspark/ml/tests/connect/test_legacy_mode_feature.py index 4915d4706b872..d90e4a4315d5f 100644 --- a/python/pyspark/ml/tests/connect/test_legacy_mode_feature.py +++ b/python/pyspark/ml/tests/connect/test_legacy_mode_feature.py @@ -26,7 +26,10 @@ from pyspark.util import is_remote_only from pyspark.sql import SparkSession from pyspark.testing.connectutils import should_test_connect, connect_requirement_message - +from pyspark.ml.tests.connect.test_connect_classification import ( + have_torch, + torch_requirement_message, +) if should_test_connect: from pyspark.ml.connect.feature import ( @@ -196,8 +199,10 @@ def test_array_assembler(self): @unittest.skipIf( - not should_test_connect or is_remote_only(), - connect_requirement_message or "pyspark-connect cannot test classic Spark", + not should_test_connect or not have_torch or is_remote_only(), + connect_requirement_message + or torch_requirement_message + or "pyspark-connect cannot test classic Spark", ) class FeatureTests(FeatureTestsMixin, unittest.TestCase): def setUp(self) -> None: