diff --git a/torchrec/sparse/jagged_tensor.py b/torchrec/sparse/jagged_tensor.py index 4b5359f0d..ebd2ac6de 100644 --- a/torchrec/sparse/jagged_tensor.py +++ b/torchrec/sparse/jagged_tensor.py @@ -2489,7 +2489,10 @@ def split(self, segments: List[int]) -> List["KeyedJaggedTensor"]: return split_list def permute( - self, indices: List[int], indices_tensor: Optional[torch.Tensor] = None + self, + indices: List[int], + indices_tensor: Optional[torch.Tensor] = None, + include_inverse_indices: bool = False, ) -> "KeyedJaggedTensor": """ Permutes the KeyedJaggedTensor. @@ -2587,7 +2590,9 @@ def permute( offset_per_key=None, index_per_key=None, jt_dict=None, - inverse_indices=None, + inverse_indices=( + self.inverse_indices_or_none() if include_inverse_indices else None + ), ) return kjt diff --git a/torchrec/sparse/tests/test_jagged_tensor.py b/torchrec/sparse/tests/test_jagged_tensor.py index 782728a81..31341c7e9 100644 --- a/torchrec/sparse/tests/test_jagged_tensor.py +++ b/torchrec/sparse/tests/test_jagged_tensor.py @@ -1372,16 +1372,27 @@ def test_permute_vb(self) -> None: lengths = torch.IntTensor([1, 0, 1, 3, 0, 1, 0, 2, 0]) keys = ["index_0", "index_1", "index_2"] stride_per_key_per_rank = [[2], [4], [3]] + inverse_indices = ( + ["index_0", "index_1", "index_2"], + torch.Tensor( + [ + [0, 0, 0, 0, 0, 1, 1, 1, 1], + [0, 0, 1, 1, 3, 3, 2, 2, 1], + [2, 2, 1, 0, 0, 2, 1, 2, 0], + ] + ), + ) jag_tensor = KeyedJaggedTensor.from_lengths_sync( values=values, keys=keys, lengths=lengths, stride_per_key_per_rank=stride_per_key_per_rank, + inverse_indices=inverse_indices, ) indices = [1, 0, 2] - permuted_jag_tensor = jag_tensor.permute(indices) + permuted_jag_tensor = jag_tensor.permute(indices, include_inverse_indices=True) self.assertEqual(permuted_jag_tensor.keys(), ["index_1", "index_0", "index_2"]) self.assertEqual( @@ -1401,6 +1412,15 @@ def test_permute_vb(self) -> None: ) ) self.assertEqual(permuted_jag_tensor.weights_or_none(), None) + self.assertEqual( + jag_tensor.inverse_indices()[0], permuted_jag_tensor.inverse_indices()[0] + ) + self.assertTrue( + torch.equal( + jag_tensor.inverse_indices()[1], + permuted_jag_tensor.inverse_indices()[1], + ) + ) def test_permute_vb_duplicate(self) -> None: values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])