diff --git a/torchrec/sparse/jagged_tensor.py b/torchrec/sparse/jagged_tensor.py
index 15952bfa5..fb76fba73 100644
--- a/torchrec/sparse/jagged_tensor.py
+++ b/torchrec/sparse/jagged_tensor.py
@@ -3031,13 +3031,17 @@ def dist_init(
 
 def _kjt_flatten(
     t: KeyedJaggedTensor,
-) -> Tuple[List[Optional[torch.Tensor]], List[str]]:
-    return [getattr(t, a) for a in KeyedJaggedTensor._fields], t._keys
+) -> Tuple[List[Optional[torch.Tensor]], Tuple[List[str], int]]:
+    # For the variable batch scenario, the stride cannot be computed from
+    # lengths/len(keys); instead, it must be computed from stride_per_key_per_rank,
+    # which is not included in the flatten spec. The stride is needed for the
+    # EBC output shape, so we store it in the context.
+    return [getattr(t, a) for a in KeyedJaggedTensor._fields], (t._keys, t.stride())
 
 
 def _kjt_flatten_with_keys(
     t: KeyedJaggedTensor,
-) -> Tuple[List[Tuple[KeyEntry, Optional[torch.Tensor]]], List[str]]:
+) -> Tuple[List[Tuple[KeyEntry, Optional[torch.Tensor]]], Tuple[List[str], int]]:
     values, context = _kjt_flatten(t)
     # pyre can't tell that GetAttrKey implements the KeyEntry protocol
     return [  # pyre-ignore[7]
@@ -3046,9 +3050,11 @@ def _kjt_flatten_with_keys(
 
 
 def _kjt_unflatten(
-    values: List[Optional[torch.Tensor]], context: List[str]  # context is the _keys
+    values: List[Optional[torch.Tensor]],
+    context: Tuple[List[str], int],  # context is (_keys, _stride)
 ) -> KeyedJaggedTensor:
-    return KeyedJaggedTensor(context, *values)
+    keys, stride = context
+    return KeyedJaggedTensor(keys, *values, stride=stride)
 
 
 def _kjt_flatten_spec(
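
Below is a minimal round-trip sketch (an illustration, not part of the patch) of the behavior this change enables: because the pytree context now carries `(keys, stride)`, unflattening recovers the original stride even when it cannot be recomputed from `lengths` and `len(keys)`. It assumes `torchrec.sparse.jagged_tensor` registers `KeyedJaggedTensor` as a pytree node via the functions above, and the tensor values are made-up sample data.

```python
# Illustrative sketch, not part of the patch: round-trip a variable-batch
# KeyedJaggedTensor through torch.utils._pytree and check that the stride
# survives unflattening.
import torch
import torch.utils._pytree as pytree

from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

# Variable batch: strides differ per key per rank, so the stride cannot be
# recovered as len(lengths) // len(keys) from the flattened leaves alone.
kjt = KeyedJaggedTensor(
    keys=["f1", "f2"],
    values=torch.tensor([1, 2, 3, 4, 5]),
    lengths=torch.tensor([1, 2, 1, 1]),
    stride_per_key_per_rank=[[1, 1], [2]],  # per-key, per-rank batch sizes
)

flat_values, spec = pytree.tree_flatten(kjt)
rebuilt = pytree.tree_unflatten(flat_values, spec)

# With (keys, stride) stored in the context, the rebuilt KJT reports the
# same stride, which the EBC output shape computation relies on.
assert rebuilt.stride() == kjt.stride()
```

Note that `stride_per_key_per_rank` itself still does not travel through the flatten spec; only the derived scalar stride is preserved, which per the patch comment is what the EBC output shape needs.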