diff --git a/bindings/python/tests/test_simple.py b/bindings/python/tests/test_simple.py index 38ca7236..cf198aed 100644 --- a/bindings/python/tests/test_simple.py +++ b/bindings/python/tests/test_simple.py @@ -248,7 +248,9 @@ def test_torch_slice(self): tensor = slice_[:2] self.assertEqual(list(tensor.shape), [2, 5]) - torch.testing.assert_close(tensor, A[:2], f"{tensor} != {A[:2]}") + if not torch.allclose(tensor, A[:2]): + print(f"{tensor} != {A[:2]}") + torch.testing.assert_close(tensor, A[:2]) tensor = slice_[:, :2] self.assertEqual(list(tensor.shape), [10, 2])