diff --git a/torchrec/sparse/jagged_tensor.py b/torchrec/sparse/jagged_tensor.py index 1af1f9046..6e6afbe73 100644 --- a/torchrec/sparse/jagged_tensor.py +++ b/torchrec/sparse/jagged_tensor.py @@ -631,9 +631,9 @@ def _maybe_compute_stride_kjt_scripted( def _length_per_key_from_stride_per_key( lengths: torch.Tensor, stride_per_key: List[int] ) -> List[int]: - return [ - int(torch.sum(chunk).item()) for chunk in torch.split(lengths, stride_per_key) - ] + return torch.cat( + [torch.sum(chunk).view(1) for chunk in torch.split(lengths, stride_per_key)] + ).tolist() def _maybe_compute_length_per_key(