diff --git a/torchrec/modules/feature_processor_.py b/torchrec/modules/feature_processor_.py
index c88d3b45a..24427cec3 100644
--- a/torchrec/modules/feature_processor_.py
+++ b/torchrec/modules/feature_processor_.py
@@ -10,7 +10,7 @@
 #!/usr/bin/env python3
 
 import abc
-from typing import Dict, Optional
+from typing import Dict, List, Optional
 
 import torch
 
@@ -150,6 +150,13 @@ def get_weights_list(
     return torch.cat(weights_list) if weights_list else features.weights_or_none()
 
 
+@torch.fx.wrap
+def get_stride_per_key_per_rank(kjt: KeyedJaggedTensor) -> Optional[List[List[int]]]:
+    if not kjt.variable_stride_per_key():
+        return None
+    return kjt.stride_per_key_per_rank()
+
+
 class PositionWeightedModuleCollection(FeatureProcessorsCollection, CopyMixIn):
     def __init__(
         self, max_feature_lengths: Dict[str, int], device: Optional[torch.device] = None
@@ -193,6 +200,7 @@ def forward(self, features: KeyedJaggedTensor) -> KeyedJaggedTensor:
             offsets=features.offsets(),
             stride=features.stride(),
             length_per_key=features.length_per_key(),
+            stride_per_key_per_rank=get_stride_per_key_per_rank(features),
         )
 
     def copy(self, device: torch.device) -> nn.Module:
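
For context, a minimal sketch (not part of the diff) of how the new `get_stride_per_key_per_rank` helper behaves on uniform-stride vs. variable-stride KeyedJaggedTensors. The keys, values, lengths, and per-rank strides below are made-up example data, assuming a standard torchrec installation.

```python
import torch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
from torchrec.modules.feature_processor_ import get_stride_per_key_per_rank

# Uniform-stride KJT: the helper returns None, so the KJT rebuilt in
# PositionWeightedModuleCollection.forward falls back to the plain `stride` argument.
kjt = KeyedJaggedTensor(
    keys=["f1", "f2"],
    values=torch.tensor([1, 2, 3, 4, 5, 6]),
    lengths=torch.tensor([1, 2, 1, 2]),
)
assert get_stride_per_key_per_rank(kjt) is None

# Variable-stride KJT (e.g. variable batch size per key/rank): the helper forwards
# stride_per_key_per_rank so the rebuilt KJT preserves the variable-stride metadata.
vb_kjt = KeyedJaggedTensor(
    keys=["f1", "f2"],
    values=torch.tensor([1, 2, 3, 4, 5, 6]),
    lengths=torch.tensor([1, 2, 1, 2]),
    stride_per_key_per_rank=[[1, 1], [2, 0]],
)
assert vb_kjt.variable_stride_per_key()
print(get_stride_per_key_per_rank(vb_kjt))  # [[1, 1], [2, 0]]
```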