diff --git a/torchrec/distributed/tests/test_infer_shardings.py b/torchrec/distributed/tests/test_infer_shardings.py index 4e4687b04..07eb0c19a 100755 --- a/torchrec/distributed/tests/test_infer_shardings.py +++ b/torchrec/distributed/tests/test_infer_shardings.py @@ -134,7 +134,6 @@ def test_rw(self, weight_dtype: torch.dtype) -> None: ) @settings(max_examples=4, deadline=None) def test_cw(self, test_permute: bool, weight_dtype: torch.dtype) -> None: - test_permute = False num_embeddings = 64 emb_dim = 512 emb_dim_4 = emb_dim // 4