Skip to content

Commit

Permalink
Disable cudnn_fusion_test on A100.
Browse files Browse the repository at this point in the history
This test only seems to pass on H100 at the moment.

PiperOrigin-RevId: 681070398
  • Loading branch information
hawkinsp authored and Google-ML-Automation committed Oct 1, 2024
1 parent 28098be commit 1260ebb
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 4 deletions.
3 changes: 1 addition & 2 deletions tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1496,9 +1496,8 @@ jax_py_test(
jax_multiplatform_test(
name = "cudnn_fusion_test",
srcs = ["cudnn_fusion_test.py"],
enable_backends = ["gpu"],
enable_backends = [],
enable_configs = [
"gpu_a100",
"gpu_h100",
],
tags = ["multiaccelerator"],
Expand Down
4 changes: 2 additions & 2 deletions tests/cudnn_fusion_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@
class CudnnFusionTest(jtu.JaxTestCase):
def setUp(self):
if (not jtu.test_device_matches(["cuda"]) or
not jtu.is_cuda_compute_capability_at_least("8.0")):
self.skipTest("Only works on >= sm80 GPUs")
not jtu.is_cuda_compute_capability_at_least("9.0")):
self.skipTest("Only works on >= sm90 GPUs")
super().setUp()

@parameterized.parameters(["", "pmap"])
Expand Down

0 comments on commit 1260ebb

Please sign in to comment.