diff --git a/tests/python/pytorch/graphbolt/impl/test_gpu_cached_feature.py b/tests/python/pytorch/graphbolt/impl/test_gpu_cached_feature.py
index 5d503aabda59..d2e046389b56 100644
--- a/tests/python/pytorch/graphbolt/impl/test_gpu_cached_feature.py
+++ b/tests/python/pytorch/graphbolt/impl/test_gpu_cached_feature.py
@@ -17,10 +17,22 @@ def to_on_disk_numpy(test_dir, name, t):
     return path
 
 
+def _skip_condition_cached_feature():
+    return (F._default_context_str != "gpu") or (
+        torch.cuda.get_device_capability()[0] < 7
+    )
+
+
+def _reason_to_skip_cached_feature():
+    if F._default_context_str != "gpu":
+        return "GPUCachedFeature tests are available only when testing the GPU backend."
+
+    return "GPUCachedFeature requires a Volta or later generation NVIDIA GPU."
+
+
 @unittest.skipIf(
-    F._default_context_str != "gpu"
-    or torch.cuda.get_device_capability()[0] < 7,
-    reason="GPUCachedFeature requires a Volta or later generation NVIDIA GPU.",
+    _skip_condition_cached_feature(),
+    reason=_reason_to_skip_cached_feature(),
 )
 @pytest.mark.parametrize(
     "dtype",
@@ -116,9 +128,8 @@ def test_gpu_cached_feature(dtype, cache_size_a, cache_size_b):
 
 
 @unittest.skipIf(
-    F._default_context_str != "gpu"
-    or torch.cuda.get_device_capability()[0] < 7,
-    reason="GPUCachedFeature requires a Volta or later generation NVIDIA GPU.",
+    _skip_condition_cached_feature(),
+    reason=_reason_to_skip_cached_feature(),
 )
 @pytest.mark.parametrize(
     "dtype",
@@ -155,9 +166,8 @@ def test_gpu_cached_feature_read_async(dtype, pin_memory):
 
 
 @unittest.skipIf(
-    F._default_context_str != "gpu"
-    or torch.cuda.get_device_capability()[0] < 7,
-    reason="GPUCachedFeature requires a Volta or later generation NVIDIA GPU.",
+    _skip_condition_cached_feature(),
+    reason=_reason_to_skip_cached_feature(),
 )
 @unittest.skipIf(
     not torch.ops.graphbolt.detect_io_uring(),
diff --git a/tests/python/pytorch/graphbolt/impl/test_gpu_graph_cache.py b/tests/python/pytorch/graphbolt/impl/test_gpu_graph_cache.py
index 6c6d242f14c1..dee0fb974e4d 100644
--- a/tests/python/pytorch/graphbolt/impl/test_gpu_graph_cache.py
+++ b/tests/python/pytorch/graphbolt/impl/test_gpu_graph_cache.py
@@ -11,7 +11,7 @@
 @unittest.skipIf(
     F._default_context_str != "gpu"
     or torch.cuda.get_device_capability()[0] < 7,
-    reason="GPUCachedFeature tests are available only on GPU."
+    reason="GPUCachedFeature tests are available only when testing the GPU backend."
     if F._default_context_str != "gpu"
     else "GPUCachedFeature requires a Volta or later generation NVIDIA GPU.",
 )
diff --git a/tests/python/pytorch/graphbolt/impl/test_hetero_cached_feature.py b/tests/python/pytorch/graphbolt/impl/test_hetero_cached_feature.py
index 620999e2a6a4..f3caf34da38e 100644
--- a/tests/python/pytorch/graphbolt/impl/test_hetero_cached_feature.py
+++ b/tests/python/pytorch/graphbolt/impl/test_hetero_cached_feature.py
@@ -15,7 +15,9 @@ def test_hetero_cached_feature(cached_feature_type):
         or torch.cuda.get_device_capability()[0] < 7
     ):
         pytest.skip(
-            "GPUCachedFeature requires a Volta or later generation NVIDIA GPU."
+            "GPUCachedFeature tests are available only when testing the GPU backend."
+            if F._default_context_str != "gpu"
+            else "GPUCachedFeature requires a Volta or later generation NVIDIA GPU."
         )
     device = F.ctx() if cached_feature_type == gb.gpu_cached_feature else None
     pin_memory = cached_feature_type == gb.gpu_cached_feature