Skip to content

Commit

Permalink
fix hypothesis strategy that skips entire test without CUDA (#2690)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #2690

# context
* original implementation will skip the entire test set if there is no cuda available
* the actual intention is to loop all devices (cpu, meta, cuda) and only skip cuda if not available.

Reviewed By: dstaay-fb

Differential Revision: D68373224

fbshipit-source-id: 28c8b12a61213ebfc794b07f51e5eff77f13938a
  • Loading branch information
TroyGarden authored and facebook-github-bot committed Jan 21, 2025
1 parent d0bf444 commit 53752ea
Showing 1 changed file with 11 additions and 11 deletions.
22 changes: 11 additions & 11 deletions torchrec/sparse/tests/test_tensor_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,14 @@
from torchrec.sparse.tensor_dict import maybe_td_to_kjt


class TestTensorDIct(unittest.TestCase):
@given(device_str=st.sampled_from(["cpu", "cuda", "meta"]))
@settings(verbosity=Verbosity.verbose, max_examples=5, deadline=None)
class TestTensorDict(unittest.TestCase):
# pyre-ignore[56]
@unittest.skipIf(
torch.cuda.device_count() <= 0,
"CUDA is not available",
@given(
device_str=st.sampled_from(
["cpu", "meta"] + (["cuda"] if torch.cuda.device_count() > 0 else [])
)
)
@settings(verbosity=Verbosity.verbose, max_examples=5, deadline=None)
def test_kjt_input(self, device_str: str) -> None:
device = torch.device(device_str)
values = torch.tensor([0, 1, 2, 3, 2, 3, 4], device=device)
Expand All @@ -36,13 +36,13 @@ def test_kjt_input(self, device_str: str) -> None:
features = maybe_td_to_kjt(kjt)
self.assertEqual(features, kjt)

@given(device_str=st.sampled_from(["cpu", "cuda", "meta"]))
@settings(verbosity=Verbosity.verbose, max_examples=5, deadline=None)
# pyre-ignore[56]
@unittest.skipIf(
torch.cuda.device_count() <= 0,
"CUDA is not available",
@given(
device_str=st.sampled_from(
["cpu", "meta"] + (["cuda"] if torch.cuda.device_count() > 0 else [])
)
)
@settings(verbosity=Verbosity.verbose, max_examples=5, deadline=None)
def test_td_kjt(self, device_str: str) -> None:
device = torch.device(device_str)
values = torch.tensor([0, 1, 2, 3, 2, 3, 4], device=device)
Expand Down

0 comments on commit 53752ea

Please sign in to comment.