Skip to content

Commit

Permalink
use inverse indices in KJT permute
Browse files Browse the repository at this point in the history
Summary: calling a kjt.permute() on a VBE KJT makes the output KJT no longer VBE. this diff fixes it such that the output KJT is VBE.

Reviewed By: joshuadeng

Differential Revision: D65621958
  • Loading branch information
iamzainhuda authored and facebook-github-bot committed Nov 7, 2024
1 parent 42c512c commit a8c47ee
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 3 deletions.
9 changes: 7 additions & 2 deletions torchrec/sparse/jagged_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2489,7 +2489,10 @@ def split(self, segments: List[int]) -> List["KeyedJaggedTensor"]:
return split_list

def permute(
self, indices: List[int], indices_tensor: Optional[torch.Tensor] = None
self,
indices: List[int],
indices_tensor: Optional[torch.Tensor] = None,
include_inverse_indices: bool = False,
) -> "KeyedJaggedTensor":
"""
Permutes the KeyedJaggedTensor.
Expand Down Expand Up @@ -2587,7 +2590,9 @@ def permute(
offset_per_key=None,
index_per_key=None,
jt_dict=None,
inverse_indices=None,
inverse_indices=(
self.inverse_indices_or_none() if include_inverse_indices else None
),
)
return kjt

Expand Down
22 changes: 21 additions & 1 deletion torchrec/sparse/tests/test_jagged_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1372,16 +1372,27 @@ def test_permute_vb(self) -> None:
lengths = torch.IntTensor([1, 0, 1, 3, 0, 1, 0, 2, 0])
keys = ["index_0", "index_1", "index_2"]
stride_per_key_per_rank = [[2], [4], [3]]
inverse_indices = (
["index_0", "index_1", "index_2"],
torch.Tensor(
[
[0, 0, 0, 0, 0, 1, 1, 1, 1],
[0, 0, 1, 1, 3, 3, 2, 2, 1],
[2, 2, 1, 0, 0, 2, 1, 2, 0],
]
),
)

jag_tensor = KeyedJaggedTensor.from_lengths_sync(
values=values,
keys=keys,
lengths=lengths,
stride_per_key_per_rank=stride_per_key_per_rank,
inverse_indices=inverse_indices,
)

indices = [1, 0, 2]
permuted_jag_tensor = jag_tensor.permute(indices)
permuted_jag_tensor = jag_tensor.permute(indices, include_inverse_indices=True)

self.assertEqual(permuted_jag_tensor.keys(), ["index_1", "index_0", "index_2"])
self.assertEqual(
Expand All @@ -1401,6 +1412,15 @@ def test_permute_vb(self) -> None:
)
)
self.assertEqual(permuted_jag_tensor.weights_or_none(), None)
self.assertEqual(
jag_tensor.inverse_indices()[0], permuted_jag_tensor.inverse_indices()[0]
)
self.assertTrue(
torch.equal(
jag_tensor.inverse_indices()[1],
permuted_jag_tensor.inverse_indices()[1],
)
)

def test_permute_vb_duplicate(self) -> None:
values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])
Expand Down

0 comments on commit a8c47ee

Please sign in to comment.