From 86f198471a58e410219ecc8fe2d699e176c324c8 Mon Sep 17 00:00:00 2001
From: Huanyu He
Date: Sun, 7 Jul 2024 03:06:58 -0700
Subject: [PATCH] use new op in KTRegroupAsDict module

Summary:
# context
* adding PackedTensorAccessor for passing the index tensor to the kernel
* GPU trace shows a slowdown from 2.20 ms to 2.26 ms

# traces
* previous ~4.90s
{F1747994738}
* after ~2.00ms
{F1747994032}

Differential Revision: D53590566
---
 torchrec/modules/regroup.py            | 54 ++++++++++----------------
 torchrec/modules/tests/test_regroup.py |  4 ++--
 2 files changed, 23 insertions(+), 35 deletions(-)

diff --git a/torchrec/modules/regroup.py b/torchrec/modules/regroup.py
index 4fcf590d0..e4792110f 100644
--- a/torchrec/modules/regroup.py
+++ b/torchrec/modules/regroup.py
@@ -12,17 +12,7 @@
 from typing import Dict, List, Optional, Tuple
 
 import torch
-from torchrec.sparse.jagged_tensor import (
-    _all_keys_used_once,
-    _desugar_keyed_tensors,
-    _remap_to_groups,
-    KeyedTensor,
-)
-
-
-@torch.fx.wrap
-def _concat_values(kts: List[KeyedTensor], dim: int) -> torch.Tensor:
-    return torch.cat([kt.values() for kt in kts], dim=dim)
+from torchrec.sparse.jagged_tensor import _desugar_keyed_tensors, KeyedTensor
 
 
 @torch.fx.wrap
@@ -80,23 +70,22 @@ def __init__(self, groups: List[List[str]], keys: List[str]) -> None:
         self._use_fbgemm_regroup: bool = False
         self._splits: List[int] = []
         self._idx_key_pairs: List[Tuple[int, str]] = []
-        self._permute_tensor: Optional[torch.Tensor] = None
-        self._inv_permute_tensor: Optional[torch.Tensor] = None
-        self._offsets_tensor: Optional[torch.Tensor] = None
-        self._inv_offsets_tensor: Optional[torch.Tensor] = None
+        self._permutes: Optional[torch.Tensor] = None
+        self._in_shapes: Optional[torch.Tensor] = None
+        self._out_shapes: Optional[torch.Tensor] = None
+        self._out_lengths: Optional[List[int]] = None
 
     def _init_fbgemm_regroup(self, kts: List[KeyedTensor]) -> None:
         self._use_fbgemm_regroup = True
         keys, lengths, values = _desugar_keyed_tensors(kts)
-        permute, inv_permute, offsets, inv_offsets, splits = _remap_to_groups(
-            keys, lengths, self._groups
+        self._permutes, self._in_shapes, self._out_shapes, self._out_lengths = (
+            torch.ops.fbgemm.kt_regroup_permutes(
+                values[0],
+                keys,
+                lengths,
+                self._groups,
+            )
         )
-        # no need to pin_memory() or to(..., non_blocking=True) since occurs only once
-        self._permute_tensor = permute.to(self.device)
-        self._inv_permute_tensor = inv_permute.to(self.device)
-        self._offsets_tensor = offsets.to(self.device)
-        self._inv_offsets_tensor = inv_offsets.to(self.device)
-        self._splits = splits
 
     def _init_regroup(self, kts: List[KeyedTensor]) -> None:
         lengths = [kt.length_per_key() for kt in kts]
@@ -137,24 +126,23 @@ def forward(self, keyed_tensors: List[KeyedTensor]) -> Dict[str, torch.Tensor]:
             ), "All inputs should have the same key_dim"
             self._dim = keyed_tensors[0].key_dim()
 
-            if _all_keys_used_once(keyed_tensors, self._groups) and self._dim == 1:
+            if self._dim == 1:
                 self._init_fbgemm_regroup(keyed_tensors)
             else:
                 self._init_regroup(keyed_tensors)
             self._is_inited = True
 
         if self._use_fbgemm_regroup:
-            values = _concat_values(keyed_tensors, self._dim)
-            permuted_values = torch.ops.fbgemm.permute_pooled_embs_auto_grad(
-                values,
-                self._offsets_tensor,
-                self._permute_tensor,
-                self._inv_offsets_tensor,
-                self._inv_permute_tensor,
+            permuted_values = torch.ops.fbgemm.permute_multi_embedding(
+                [kt.values() for kt in keyed_tensors],
+                self._permutes,
+                self._in_shapes,
+                self._out_shapes,
+                self._out_lengths,
             )
+            return {key: tensor for key, tensor in zip(self._keys, permuted_values)}
         else:
             permuted_values = _permuted_values(
                 keyed_tensors, self._idx_key_pairs, self._dim
             )
-
-        return _build_dict(self._keys, permuted_values, self._splits, self._dim)
+            return _build_dict(self._keys, permuted_values, self._splits, self._dim)
diff --git a/torchrec/modules/tests/test_regroup.py b/torchrec/modules/tests/test_regroup.py
index 4f00b99c1..78fd73ebe 100644
--- a/torchrec/modules/tests/test_regroup.py
+++ b/torchrec/modules/tests/test_regroup.py
@@ -26,19 +26,19 @@ def setUp(self) -> None:
             dim_dense=64,
             dim_sparse=128,
             batch_size=128,
-            device=torch.device("cpu"),
+            device=torch.device("cuda"),
             run_backward=True,
         )
         self.num_groups = 2
         self.keys = ["user", "object"]
-        self.labels = torch.randint(0, 1, (128,), device=torch.device("cpu")).float()
+        self.labels = torch.randint(0, 1, (128,), device=torch.device("cuda")).float()
 
     def test_regroup_backward_skips_and_duplicates(self) -> None:
         groups = build_groups(
             kts=self.kts, num_groups=self.num_groups, skips=True, duplicates=True
         )
         assert _all_keys_used_once(self.kts, groups) is False
 
         regroup_module = KTRegroupAsDict(groups=groups, keys=self.keys)
         tensor_groups = regroup_module(self.kts)
         pred0 = tensor_groups["user"].sum(dim=1).mul(tensor_groups["object"].sum(dim=1))
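
A minimal usage sketch of KTRegroupAsDict after this change. It assumes a recent
fbgemm_gpu build that provides torch.ops.fbgemm.kt_regroup_permutes and
permute_multi_embedding, plus a CUDA device as in the updated tests; the keys,
dims, and groups below are illustrative only.

    import torch

    from torchrec.modules.regroup import KTRegroupAsDict
    from torchrec.sparse.jagged_tensor import KeyedTensor

    device = torch.device("cuda")

    # Two pooled-embedding KeyedTensors with key_dim == 1 (illustrative keys/dims).
    kt0 = KeyedTensor.from_tensor_list(
        keys=["f1", "f2"],
        tensors=[torch.randn(8, 4, device=device), torch.randn(8, 16, device=device)],
    )
    kt1 = KeyedTensor.from_tensor_list(
        keys=["f3"],
        tensors=[torch.randn(8, 32, device=device)],
    )

    # With this patch, any key_dim == 1 input takes the fbgemm path: the first
    # forward computes the permute metadata once via kt_regroup_permutes, and
    # every forward regroups with a single permute_multi_embedding call.
    regroup = KTRegroupAsDict(groups=[["f1", "f3"], ["f2"]], keys=["user", "object"])
    out = regroup([kt0, kt1])
    print(out["user"].shape)    # torch.Size([8, 36])  -> f1 (4) + f3 (32)
    print(out["object"].shape)  # torch.Size([8, 16])  -> f2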