diff --git a/torchrec/sparse/jagged_tensor.py b/torchrec/sparse/jagged_tensor.py index b269b7820..773f219cf 100644 --- a/torchrec/sparse/jagged_tensor.py +++ b/torchrec/sparse/jagged_tensor.py @@ -994,11 +994,19 @@ def _maybe_compute_length_per_key( _length = [total_length // len(keys)] * len(keys) _length[0] += total_length % len(keys) elif len(keys) and offsets is not None and len(offsets) > 0: - _length: List[int] = ( - _length_per_key_from_stride_per_key(torch.diff(offsets), stride_per_key) - if variable_stride_per_key - else torch.sum(torch.diff(offsets).view(-1, stride), dim=1).tolist() - ) + if variable_stride_per_key: + _length: List[int] = _length_per_key_from_stride_per_key( + torch.diff(offsets), stride_per_key + ) + else: + if not torch.jit.is_scripting(): + torch._check(stride > 0) + strided_offsets = offsets[::stride] + if not torch.jit.is_scripting(): + torch._check( + strided_offsets.numel() > 1 + ) # len(offsets) = len(lengths) + 1 >= 2 + _length: List[int] = torch.diff(strided_offsets).tolist() elif len(keys) and lengths is not None: _length: List[int] = ( _length_per_key_from_stride_per_key(lengths, stride_per_key)