diff --git a/torchrec/sparse/jagged_tensor.py b/torchrec/sparse/jagged_tensor.py index b62647228..a92cc7ca8 100644 --- a/torchrec/sparse/jagged_tensor.py +++ b/torchrec/sparse/jagged_tensor.py @@ -188,6 +188,19 @@ def permute_multi_embedding( return permuted_values +@torch.fx.wrap +def keyed_tensor_regroup( + keyed_tensors: List["KeyedTensor"], groups: List[List["str"]] +) -> List[torch.Tensor]: + keys, lengths, values = _desugar_keyed_tensors(keyed_tensors) + return torch.ops.fbgemm.regroup_keyed_tensor( + values, + keys, + lengths, + groups, + ) + + @torch.fx.wrap def _fbgemm_permute_pooled_embs( keyed_tensors: List["KeyedTensor"], groups: List[List["str"]] @@ -2708,11 +2721,7 @@ def to_dict(self) -> Dict[str, torch.Tensor]: def regroup( keyed_tensors: List["KeyedTensor"], groups: List[List[str]] ) -> List[torch.Tensor]: - # Fast path, one-to-one correspondence between keyed_tensors and groups - if _all_keys_used_once(keyed_tensors, groups) is True: - return _fbgemm_permute_pooled_embs(keyed_tensors, groups) - else: # Fallback to slow path otherwise - return _regroup_keyed_tensors(keyed_tensors, groups) + return permute_multi_embedding(keyed_tensors, groups) @staticmethod def regroup_as_dict( diff --git a/torchrec/sparse/tests/jagged_tensor_benchmark.py b/torchrec/sparse/tests/jagged_tensor_benchmark.py index aa426e448..63de780c0 100644 --- a/torchrec/sparse/tests/jagged_tensor_benchmark.py +++ b/torchrec/sparse/tests/jagged_tensor_benchmark.py @@ -18,7 +18,9 @@ from torchrec.distributed.benchmark.benchmark_utils import benchmark, BenchmarkResult from torchrec.modules.regroup import KTRegroupAsDict from torchrec.sparse.jagged_tensor import ( + _fbgemm_permute_pooled_embs, _regroup_keyed_tensors, + keyed_tensor_regroup, KeyedJaggedTensor, KeyedTensor, permute_multi_embedding, @@ -213,7 +215,7 @@ def main( ).float() groups = build_groups(kts, n_groups, duplicates=duplicates) bench( - "_regroup_keyed_tenors" + dup, + "python_native" + dup, labels, batch_size, n_dense + n_sparse, @@ -224,7 +226,7 @@ def main( profile, ) bench( - "KeyedTensor.regroup" + dup, + "[Prod] KeyedTensor.regroup" + dup, labels, batch_size, n_dense + n_sparse, @@ -235,7 +237,7 @@ def main( profile, ) bench( - "KTRegroupAsDict" + dup, + "[Module] KTRegroupAsDict" + dup, labels, batch_size, n_dense + n_sparse, @@ -248,7 +250,7 @@ def main( profile, ) bench( - "permute_multi_embs" + dup, + "[2 Ops] permute_multi_embs" + dup, labels, batch_size, n_dense + n_sparse, @@ -258,6 +260,29 @@ def main( {"keyed_tensors": kts, "groups": groups}, profile, ) + bench( + "[1 Op] KT_regroup" + dup, + labels, + batch_size, + n_dense + n_sparse, + device_type, + run_backward, + keyed_tensor_regroup, + {"keyed_tensors": kts, "groups": groups}, + profile, + ) + if not duplicates: + bench( + "[Old Prod] permute_pooled_embs" + dup, + labels, + batch_size, + n_dense + n_sparse, + device_type, + run_backward, + _fbgemm_permute_pooled_embs, + {"keyed_tensors": kts, "groups": groups}, + profile, + ) if __name__ == "__main__":