From e2abba18ea9370329ce6894a4eb3e98ad8f11cb6 Mon Sep 17 00:00:00 2001
From: mindest <30493312+mindest@users.noreply.github.com>
Date: Wed, 26 Jun 2024 11:15:50 +0800
Subject: [PATCH] Skip softmax BF16 test for ROCm (#21162)

### Description
Skip the softmax BF16 test on ROCm: BFloat16 is not supported by MIOpen, and `torch.cuda.is_available()` also returns `True` on ROCm, so it cannot be used to detect a CUDA build.

---
 .../test/python/orttraining_test_ortmodule_onnx_ops.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_onnx_ops.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_onnx_ops.py
index 88735ff18515e..35c5b736bd962 100644
--- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_onnx_ops.py
+++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_onnx_ops.py
@@ -148,8 +148,8 @@ def test_onnx_ops(self):
 
     @unittest.skipIf(not torch.cuda.is_bf16_supported(), "Test requires CUDA and BF16 support")
     def test_softmax_bf16_large(self):
-        if not torch.cuda.is_available():
-            # only test bf16 on cuda
+        if torch.version.cuda is None:
+            # Only run this test when CUDA is available, as on ROCm BF16 is not supported by MIOpen.
             return
 
         class Model(torch.nn.Module):
@@ -175,7 +175,7 @@ def forward(self, input):
         data_ort.requires_grad = True
         ort_res = ort_model(input=data_ort)
         ort_res.backward(gradient=init_grad)
-        # compara result
+        # compare result
         torch.testing.assert_close(data_torch.grad, data_ort.grad, rtol=1e-5, atol=1e-4)
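
Note (not part of the patch): a minimal sketch of the backend-detection behavior the fix relies on. `torch.cuda.is_available()` and `torch.version.cuda` / `torch.version.hip` are standard PyTorch attributes; the printed strings are illustrative only.

```python
import torch

# torch.cuda.is_available() returns True on both CUDA and ROCm builds of
# PyTorch, so it cannot tell the two backends apart.
print("GPU available:", torch.cuda.is_available())

# torch.version.cuda is a version string on CUDA builds and None on ROCm
# builds; torch.version.hip is the reverse. The patch gates the test on
# torch.version.cuda for exactly this reason.
if torch.version.cuda is not None:
    print("CUDA build:", torch.version.cuda)
elif torch.version.hip is not None:
    print("ROCm build:", torch.version.hip)
else:
    print("CPU-only build")
```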