From 504642ac8969bba75c5aef297b992a4ce9aef07b Mon Sep 17 00:00:00 2001 From: Albert Chen Date: Mon, 6 Jan 2025 08:14:16 -0800 Subject: [PATCH] Add VBE support for PositionWeightedModuleCollection (#2647) Summary: Pull Request resolved: https://github.com/pytorch/torchrec/pull/2647 As titled, we have seen wins from position encoding in modeling and would like to leverage PositionWeightedModuleCollection to reduce the cost https://fb.workplace.com/groups/204375858345877/permalink/884618276988295/ I have a stack locally that show NE equivalence between PositionWeightedModuleCollection and position encoding in modeling {F1974047979} Given IG has adopted VBE, I am adding necessary plumbing for VBE in PositionWeightedModuleCollection **Diffs will land after code freeze but publish first to get the review underway** Reviewed By: TroyGarden Differential Revision: D67526005 fbshipit-source-id: bf245d87f4e91998bcd31e2c79f120ec22736ab4 --- torchrec/modules/feature_processor_.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/torchrec/modules/feature_processor_.py b/torchrec/modules/feature_processor_.py index c88d3b45a..24427cec3 100644 --- a/torchrec/modules/feature_processor_.py +++ b/torchrec/modules/feature_processor_.py @@ -10,7 +10,7 @@ #!/usr/bin/env python3 import abc -from typing import Dict, Optional +from typing import Dict, List, Optional import torch @@ -150,6 +150,13 @@ def get_weights_list( return torch.cat(weights_list) if weights_list else features.weights_or_none() +@torch.fx.wrap +def get_stride_per_key_per_rank(kjt: KeyedJaggedTensor) -> Optional[List[List[int]]]: + if not kjt.variable_stride_per_key(): + return None + return kjt.stride_per_key_per_rank() + + class PositionWeightedModuleCollection(FeatureProcessorsCollection, CopyMixIn): def __init__( self, max_feature_lengths: Dict[str, int], device: Optional[torch.device] = None @@ -193,6 +200,7 @@ def forward(self, features: KeyedJaggedTensor) -> KeyedJaggedTensor: offsets=features.offsets(), stride=features.stride(), length_per_key=features.length_per_key(), + stride_per_key_per_rank=get_stride_per_key_per_rank(features), ) def copy(self, device: torch.device) -> nn.Module: