From 534df4578d09b89215868a4ce804b4a219a8dd64 Mon Sep 17 00:00:00 2001
From: Huanyu He
Date: Thu, 12 Dec 2024 14:39:51 -0800
Subject: [PATCH] add stride into KJT pytree (#2587)

Summary:
# context
* Previously for a KJT, only the following tensor fields, together with `_keys`, were stored in the pytree flatten spec; all other attributes are re-derived from them on unflatten.
```
_fields = [
    "_values",
    "_weights",
    "_lengths",
    "_offsets",
]
```
* In particular, the `stride` (int) of a KJT, which represents its batch size, is recomputed by `_maybe_compute_stride_kjt` (a worked instance of this fallback arithmetic is sketched after the diff):
```
def _maybe_compute_stride_kjt(
    keys: List[str],
    stride: Optional[int],
    lengths: Optional[torch.Tensor],
    offsets: Optional[torch.Tensor],
    stride_per_key_per_rank: Optional[List[List[int]]],
) -> int:
    if stride is None:
        if len(keys) == 0:
            stride = 0
        elif stride_per_key_per_rank is not None and len(stride_per_key_per_rank) > 0:
            stride = max([sum(s) for s in stride_per_key_per_rank])
        elif offsets is not None and offsets.numel() > 0:
            stride = (offsets.numel() - 1) // len(keys)
        elif lengths is not None:
            stride = lengths.numel() // len(keys)
        else:
            stride = 0
    return stride
```
* The previously stored pytree flatten spec is sufficient when the batch size is static. However, it no longer suffices in a variable-batch-size scenario, where `stride_per_key_per_rank` is not `None`: since that field is not part of the flatten spec, an unflattened KJT falls back to deriving its stride from `lengths` or `offsets`, which can differ from the true stride.
* An example is `dedup_ebc`, where the actual batch size is variable (it depends on the deduplicated data), but the output of the EBC should always use the **true** (static) `stride`.
* During `ir_export`, the output shape is calculated from `kjt.stride()`, which would be incorrect if the pytree spec contained only the keys.
* This diff adds the `stride` to the KJT pytree flatten/unflatten functions so that a fakified KJT carries the correct stride value (see the round-trip sketch after the diff).

Differential Revision: D66400821
---
 torchrec/sparse/jagged_tensor.py | 16 +++++++++++-----
 1 file changed, 11 insertions(+), 5 deletions(-)

diff --git a/torchrec/sparse/jagged_tensor.py b/torchrec/sparse/jagged_tensor.py
index 15952bfa5..fb76fba73 100644
--- a/torchrec/sparse/jagged_tensor.py
+++ b/torchrec/sparse/jagged_tensor.py
@@ -3031,13 +3031,17 @@ def dist_init(
 
 
 def _kjt_flatten(
     t: KeyedJaggedTensor,
-) -> Tuple[List[Optional[torch.Tensor]], List[str]]:
-    return [getattr(t, a) for a in KeyedJaggedTensor._fields], t._keys
+) -> Tuple[List[Optional[torch.Tensor]], Tuple[List[str], int]]:
+    # In the variable batch scenario, the stride cannot be computed from
+    # lengths / len(keys); it should instead be computed from
+    # stride_per_key_per_rank, which is not included in the flatten spec. The
+    # stride is needed for the EBC output shape, so we store it in the context.
+    return [getattr(t, a) for a in KeyedJaggedTensor._fields], (t._keys, t.stride())
 
 
 def _kjt_flatten_with_keys(
     t: KeyedJaggedTensor,
-) -> Tuple[List[Tuple[KeyEntry, Optional[torch.Tensor]]], List[str]]:
+) -> Tuple[List[Tuple[KeyEntry, Optional[torch.Tensor]]], Tuple[List[str], int]]:
     values, context = _kjt_flatten(t)
     # pyre can't tell that GetAttrKey implements the KeyEntry protocol
     return [  # pyre-ignore[7]
@@ -3046,9 +3050,11 @@ def _kjt_flatten_with_keys(
 
 
 def _kjt_unflatten(
-    values: List[Optional[torch.Tensor]], context: List[str]  # context is the _keys
+    values: List[Optional[torch.Tensor]],
+    context: Tuple[List[str], int],  # context is (_keys, _stride)
 ) -> KeyedJaggedTensor:
-    return KeyedJaggedTensor(context, *values)
+    keys, stride = context
+    return KeyedJaggedTensor(keys, *values, stride=stride)
 
 
 def _kjt_flatten_spec(
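
For reference, a worked instance of the fallback arithmetic in `_maybe_compute_stride_kjt` (not part of this diff; the shapes are made up: two keys with a batch size of 2 per key, and each branch is exercised independently):
```
import torch

keys = ["f1", "f2"]
lengths = torch.tensor([2, 1, 1, 2])  # 4 entries = 2 keys x batch size 2
offsets = torch.tensor([0, 2, 3, 4, 6])  # cumulative lengths, numel = 5

# offsets branch: (5 - 1) // 2 == 2
assert (offsets.numel() - 1) // len(keys) == 2
# lengths branch: 4 // 2 == 2
assert lengths.numel() // len(keys) == 2

# variable batch branch: per-key, per-rank batch sizes; the stride is the
# largest per-key total, which need not match the lengths-derived value above
stride_per_key_per_rank = [[3, 1], [2, 2]]
assert max(sum(s) for s in stride_per_key_per_rank) == 4
```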
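
And a minimal round-trip sketch of the behavior this diff fixes (assumes a torchrec build with this patch applied; the keys, values, lengths, and explicit stride are made-up stand-ins for the dedup_ebc case, where the true stride exceeds what the deduped lengths imply):
```
import torch
import torch.utils._pytree as pytree

from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

# The deduped lengths imply a batch size of 2 (4 entries / 2 keys), but the
# true batch size of the original input is 4, recorded via the explicit stride.
kjt = KeyedJaggedTensor(
    keys=["f1", "f2"],
    values=torch.tensor([1, 2, 3, 4, 5, 6]),
    lengths=torch.tensor([2, 1, 1, 2]),
    stride=4,
)

flat, spec = pytree.tree_flatten(kjt)
rebuilt = pytree.tree_unflatten(flat, spec)

# With the stride stored in the pytree context, the round trip preserves the
# true batch size; previously it would have been recomputed as 2 from lengths.
assert rebuilt.stride() == 4
```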