Skip to content

Commit

Permalink
benchmark of fbgemm op - regroup_keyed_tensor
Browse files Browse the repository at this point in the history
Differential Revision: D58907223
  • Loading branch information
Huanyu He authored and facebook-github-bot committed Jul 10, 2024
1 parent 1dc3dde commit e5625d5
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 9 deletions.
19 changes: 14 additions & 5 deletions torchrec/sparse/jagged_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,19 @@ def permute_multi_embedding(
return permuted_values


@torch.fx.wrap
def keyed_tensor_regroup(
keyed_tensors: List["KeyedTensor"], groups: List[List["str"]]
) -> List[torch.Tensor]:
keys, lengths, values = _desugar_keyed_tensors(keyed_tensors)
return torch.ops.fbgemm.regroup_keyed_tensor(
values,
keys,
lengths,
groups,
)


@torch.fx.wrap
def _fbgemm_permute_pooled_embs(
keyed_tensors: List["KeyedTensor"], groups: List[List["str"]]
Expand Down Expand Up @@ -2726,11 +2739,7 @@ def to_dict(self) -> Dict[str, torch.Tensor]:
def regroup(
keyed_tensors: List["KeyedTensor"], groups: List[List[str]]
) -> List[torch.Tensor]:
# Fast path, one-to-one correspondence between keyed_tensors and groups
if _all_keys_used_once(keyed_tensors, groups) is True:
return _fbgemm_permute_pooled_embs(keyed_tensors, groups)
else: # Fallback to slow path otherwise
return _regroup_keyed_tensors(keyed_tensors, groups)
return permute_multi_embedding(keyed_tensors, groups)

@staticmethod
def regroup_as_dict(
Expand Down
33 changes: 29 additions & 4 deletions torchrec/sparse/tests/jagged_tensor_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@
from torchrec.distributed.benchmark.benchmark_utils import benchmark, BenchmarkResult
from torchrec.modules.regroup import KTRegroupAsDict
from torchrec.sparse.jagged_tensor import (
_fbgemm_permute_pooled_embs,
_regroup_keyed_tensors,
keyed_tensor_regroup,
KeyedJaggedTensor,
KeyedTensor,
permute_multi_embedding,
Expand Down Expand Up @@ -213,7 +215,7 @@ def main(
).float()
groups = build_groups(kts, n_groups, duplicates=duplicates)
bench(
"_regroup_keyed_tenors" + dup,
"[pytorch generic] backward_fallback" + dup,
labels,
batch_size,
n_dense + n_sparse,
Expand All @@ -224,7 +226,7 @@ def main(
profile,
)
bench(
"KeyedTensor.regroup" + dup,
"[Prod] KeyedTensor.regroup" + dup,
labels,
batch_size,
n_dense + n_sparse,
Expand All @@ -235,7 +237,7 @@ def main(
profile,
)
bench(
"KTRegroupAsDict" + dup,
"[Module] KTRegroupAsDict" + dup,
labels,
batch_size,
n_dense + n_sparse,
Expand All @@ -248,7 +250,7 @@ def main(
profile,
)
bench(
"permute_multi_embs" + dup,
"[2 Ops] permute_multi_embs" + dup,
labels,
batch_size,
n_dense + n_sparse,
Expand All @@ -258,6 +260,29 @@ def main(
{"keyed_tensors": kts, "groups": groups},
profile,
)
bench(
"[1 Op] KT_regroup" + dup,
labels,
batch_size,
n_dense + n_sparse,
device_type,
run_backward,
keyed_tensor_regroup,
{"keyed_tensors": kts, "groups": groups},
profile,
)
if not duplicates:
bench(
"[Old Prod] permute_pooled_embs" + dup,
labels,
batch_size,
n_dense + n_sparse,
device_type,
run_backward,
_fbgemm_permute_pooled_embs,
{"keyed_tensors": kts, "groups": groups},
profile,
)


if __name__ == "__main__":
Expand Down

0 comments on commit e5625d5

Please sign in to comment.