diff --git a/torchrec/sparse/jagged_tensor.py b/torchrec/sparse/jagged_tensor.py index b269b7820..7e71acb19 100644 --- a/torchrec/sparse/jagged_tensor.py +++ b/torchrec/sparse/jagged_tensor.py @@ -993,12 +993,6 @@ def _maybe_compute_length_per_key( total_length = values.numel() _length = [total_length // len(keys)] * len(keys) _length[0] += total_length % len(keys) - elif len(keys) and offsets is not None and len(offsets) > 0: - _length: List[int] = ( - _length_per_key_from_stride_per_key(torch.diff(offsets), stride_per_key) - if variable_stride_per_key - else torch.sum(torch.diff(offsets).view(-1, stride), dim=1).tolist() - ) elif len(keys) and lengths is not None: _length: List[int] = ( _length_per_key_from_stride_per_key(lengths, stride_per_key) @@ -1011,6 +1005,12 @@ def _maybe_compute_length_per_key( else [0] * len(keys) ) ) + elif len(keys) and offsets is not None and len(offsets) > 0: + _length: List[int] = ( + _length_per_key_from_stride_per_key(torch.diff(offsets), stride_per_key) + if variable_stride_per_key + else torch.sum(torch.diff(offsets).view(-1, stride), dim=1).tolist() + ) else: _length: List[int] = [] length_per_key = _length