use new op in KTRegroupAsDict module
Summary:
# context
* adding a PackedTensorAccessor to pass the index tensor to the kernel
* the GPU trace reading shows a slight slowdown, from 2.20 ms to 2.26 ms

# traces
* previous ~4.90s
 {F1747994738}
* after ~2.00ms
 {F1747994032}

Differential Revision: D53590566
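
At a high level, this diff switches `KTRegroupAsDict` from fbgemm's single-tensor `permute_pooled_embs_auto_grad` (which required concatenating all KeyedTensor values first) to the new multi-tensor `permute_multi_embedding` op. Below is a minimal sketch of the new call pattern; the op names and argument order are taken from the changed code in the diff, and `keys`/`lengths`/`values` are the per-KeyedTensor metadata returned by `_desugar_keyed_tensors`.

```python
import torch

# Sketch only: assumes fbgemm_gpu is installed and the value tensors live on GPU.
def regroup_with_new_op(values, keys, lengths, groups):
    # One-time setup: precompute permute metadata for the requested groups.
    permutes, in_shapes, out_shapes, out_lengths = torch.ops.fbgemm.kt_regroup_permutes(
        values[0], keys, lengths, groups
    )
    # Steady state: permute the list of value tensors directly; no torch.cat
    # of all KeyedTensor values is needed anymore.
    return torch.ops.fbgemm.permute_multi_embedding(
        values, permutes, in_shapes, out_shapes, out_lengths
    )
```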
TroyGarden authored and facebook-github-bot committed Jul 7, 2024
1 parent 8e1d37f commit 86f1984
Showing 2 changed files with 29 additions and 37 deletions.
54 changes: 21 additions & 33 deletions torchrec/modules/regroup.py
@@ -12,17 +12,7 @@
from typing import Dict, List, Optional, Tuple

import torch
from torchrec.sparse.jagged_tensor import (
_all_keys_used_once,
_desugar_keyed_tensors,
_remap_to_groups,
KeyedTensor,
)


@torch.fx.wrap
def _concat_values(kts: List[KeyedTensor], dim: int) -> torch.Tensor:
return torch.cat([kt.values() for kt in kts], dim=dim)
from torchrec.sparse.jagged_tensor import _desugar_keyed_tensors, KeyedTensor


@torch.fx.wrap
@@ -80,23 +70,22 @@ def __init__(self, groups: List[List[str]], keys: List[str]) -> None:
self._use_fbgemm_regroup: bool = False
self._splits: List[int] = []
self._idx_key_pairs: List[Tuple[int, str]] = []
self._permute_tensor: Optional[torch.Tensor] = None
self._inv_permute_tensor: Optional[torch.Tensor] = None
self._offsets_tensor: Optional[torch.Tensor] = None
self._inv_offsets_tensor: Optional[torch.Tensor] = None
self._permutes: Optional[torch.Tensor] = None
self._in_shapes: Optional[torch.Tensor] = None
self._out_shapes: Optional[torch.Tensor] = None
self._out_lengths: Optional[List[int]] = None

def _init_fbgemm_regroup(self, kts: List[KeyedTensor]) -> None:
self._use_fbgemm_regroup = True
keys, lengths, values = _desugar_keyed_tensors(kts)
permute, inv_permute, offsets, inv_offsets, splits = _remap_to_groups(
keys, lengths, self._groups
self._permutes, self._in_shapes, self._out_shapes, self._out_lengths = (
torch.ops.fbgemm.kt_regroup_permutes(
values[0],
keys,
lengths,
self._groups,
)
)
# no need to pin_memory() or to(..., non_blocking=True) since occurs only once
self._permute_tensor = permute.to(self.device)
self._inv_permute_tensor = inv_permute.to(self.device)
self._offsets_tensor = offsets.to(self.device)
self._inv_offsets_tensor = inv_offsets.to(self.device)
self._splits = splits

def _init_regroup(self, kts: List[KeyedTensor]) -> None:
lengths = [kt.length_per_key() for kt in kts]
@@ -137,24 +126,23 @@ def forward(self, keyed_tensors: List[KeyedTensor]) -> Dict[str, torch.Tensor]:
), "All inputs should have the same key_dim"
self._dim = keyed_tensors[0].key_dim()

if _all_keys_used_once(keyed_tensors, self._groups) and self._dim == 1:
if self._dim == 1:
self._init_fbgemm_regroup(keyed_tensors)
else:
self._init_regroup(keyed_tensors)
self._is_inited = True

if self._use_fbgemm_regroup:
values = _concat_values(keyed_tensors, self._dim)
permuted_values = torch.ops.fbgemm.permute_pooled_embs_auto_grad(
values,
self._offsets_tensor,
self._permute_tensor,
self._inv_offsets_tensor,
self._inv_permute_tensor,
permuted_values = torch.ops.fbgemm.permute_multi_embedding(
[kt.values() for kt in keyed_tensors],
self._permutes,
self._in_shapes,
self._out_shapes,
self._out_lengths,
)
return {key: tensor for key, tensor in zip(self._keys, permuted_values)}
else:
permuted_values = _permuted_values(
keyed_tensors, self._idx_key_pairs, self._dim
)

return _build_dict(self._keys, permuted_values, self._splits, self._dim)
return _build_dict(self._keys, permuted_values, self._splits, self._dim)
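
For illustration, here is a hypothetical end-to-end usage of the module after this change. The key names, sizes, and device below are made up; with `key_dim == 1` the module now takes the `permute_multi_embedding` path sketched earlier (assuming a GPU build with fbgemm_gpu available).

```python
import torch
from torchrec.modules.regroup import KTRegroupAsDict
from torchrec.sparse.jagged_tensor import KeyedTensor

device = torch.device("cuda")  # assumption: the fbgemm path expects GPU tensors

# Two KeyedTensors holding pooled embeddings: f1/f2 in one, f3 in the other.
kt1 = KeyedTensor(
    keys=["f1", "f2"],
    length_per_key=[8, 8],
    values=torch.randn(4, 16, device=device, requires_grad=True),
    key_dim=1,
)
kt2 = KeyedTensor(
    keys=["f3"],
    length_per_key=[8],
    values=torch.randn(4, 8, device=device, requires_grad=True),
    key_dim=1,
)

# Regroup f1 + f3 under "user" and f2 under "object".
module = KTRegroupAsDict(groups=[["f1", "f3"], ["f2"]], keys=["user", "object"])
out = module([kt1, kt2])
# out["user"] has shape (4, 16); out["object"] has shape (4, 8)
```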
12 changes: 8 additions & 4 deletions torchrec/modules/tests/test_regroup.py
@@ -26,19 +26,20 @@ def setUp(self) -> None:
dim_dense=64,
dim_sparse=128,
batch_size=128,
device=torch.device("cpu"),
device=torch.device("cuda"),
run_backward=True,
)
self.num_groups = 2
self.keys = ["user", "object"]
self.labels = torch.randint(0, 1, (128,), device=torch.device("cpu")).float()
self.labels = torch.randint(0, 1, (128,), device=torch.device("cuda")).float()

def test_regroup_backward_skips_and_duplicates(self) -> None:
print(38)
groups = build_groups(
kts=self.kts, num_groups=self.num_groups, skips=True, duplicates=True
)
assert _all_keys_used_once(self.kts, groups) is False

breakpoint()
regroup_module = KTRegroupAsDict(groups=groups, keys=self.keys)
tensor_groups = regroup_module(self.kts)
pred0 = tensor_groups["user"].sum(dim=1).mul(tensor_groups["object"].sum(dim=1))
@@ -65,11 +66,12 @@ def test_regroup_backward_skips_and_duplicates(self) -> None:
torch.allclose(actual_kt_1_grad, expected_kt_1_grad)

def test_regroup_backward(self) -> None:
print(70)
groups = build_groups(
kts=self.kts, num_groups=self.num_groups, skips=False, duplicates=False
)
assert _all_keys_used_once(self.kts, groups) is True

breakpoint()
regroup_module = KTRegroupAsDict(groups=groups, keys=self.keys)
tensor_groups = regroup_module(self.kts)
pred0 = tensor_groups["user"].sum(dim=1).mul(tensor_groups["object"].sum(dim=1))
@@ -96,6 +98,7 @@ def test_regroup_backward(self) -> None:
torch.allclose(actual_kt_1_grad, expected_kt_1_grad)

def test_fx_and_jit_regroup(self) -> None:
print(102)
groups = build_groups(
kts=self.kts, num_groups=self.num_groups, skips=False, duplicates=False
)
@@ -115,6 +118,7 @@ def test_fx_and_jit_regroup(self) -> None:
torch.allclose(out[key], eager_out[key])

def test_fx_and_jit_regroup_skips_and_duplicates(self) -> None:
print(122)
groups = build_groups(
kts=self.kts, num_groups=self.num_groups, skips=True, duplicates=True
)
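
As a rough sketch of what `test_fx_and_jit_regroup` above appears to exercise (reusing `regroup_module` and `kts` from the test; whether the test uses plain `torch.fx.symbolic_trace` or torchrec's own tracer is an assumption here), the module is traced, scripted, and compared against eager outputs:

```python
import torch

# Hypothetical round-trip check: eager vs. fx-traced + TorchScript-compiled module.
eager_out = regroup_module(kts)

gm = torch.fx.symbolic_trace(regroup_module)
scripted = torch.jit.script(gm)
out = scripted(kts)

for key in eager_out:
    assert torch.allclose(out[key], eager_out[key])
```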
