diff --git a/torchrec/sparse/jagged_tensor.py b/torchrec/sparse/jagged_tensor.py index 7276457fb..1af1f9046 100644 --- a/torchrec/sparse/jagged_tensor.py +++ b/torchrec/sparse/jagged_tensor.py @@ -631,10 +631,9 @@ def _maybe_compute_stride_kjt_scripted( def _length_per_key_from_stride_per_key( lengths: torch.Tensor, stride_per_key: List[int] ) -> List[int]: - stride_per_key_offsets = _to_offsets( - _pin_and_move(torch.tensor(stride_per_key, dtype=torch.int32), lengths.device) - ) - return torch.ops.fbgemm.segment_sum_csr(1, stride_per_key_offsets, lengths).tolist() + return [ + int(torch.sum(chunk).item()) for chunk in torch.split(lengths, stride_per_key) + ] def _maybe_compute_length_per_key(