From e2cc13ac875fcc049e03f46c3bc7ad9268104674 Mon Sep 17 00:00:00 2001
From: Joshua Deng
Date: Mon, 13 Nov 2023 13:43:18 -0800
Subject: [PATCH] Minimize d2h syncs in calculating `length_per_key` from `stride_per_key` (#1485)

Summary:
Pull Request resolved: https://github.com/pytorch/torchrec/pull/1485

For large numbers of features, we call .item() once per feature, causing a large number of d2h syncs. This diff combines the per-feature sums into a single tensor and calls .tolist() once.

Reviewed By: bigning

Differential Revision: D51046476

fbshipit-source-id: 26fd38767d1d48dade24057cd2136b15ea29c16c
---
 torchrec/sparse/jagged_tensor.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/torchrec/sparse/jagged_tensor.py b/torchrec/sparse/jagged_tensor.py
index 1af1f9046..6e6afbe73 100644
--- a/torchrec/sparse/jagged_tensor.py
+++ b/torchrec/sparse/jagged_tensor.py
@@ -631,9 +631,9 @@ def _maybe_compute_stride_kjt_scripted(
 def _length_per_key_from_stride_per_key(
     lengths: torch.Tensor, stride_per_key: List[int]
 ) -> List[int]:
-    return [
-        int(torch.sum(chunk).item()) for chunk in torch.split(lengths, stride_per_key)
-    ]
+    return torch.cat(
+        [torch.sum(chunk).view(1) for chunk in torch.split(lengths, stride_per_key)]
+    ).tolist()


 def _maybe_compute_length_per_key(
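
Note (not part of the patch): the following is a minimal standalone sketch illustrating why the change reduces device-to-host syncs. The function names here are illustrative, not torchrec APIs; the bodies mirror the removed and added code above.

from typing import List

import torch


def length_per_key_many_syncs(lengths: torch.Tensor, stride_per_key: List[int]) -> List[int]:
    # One .item() per key: when `lengths` lives on a CUDA device, each call
    # forces a separate device-to-host synchronization.
    return [
        int(torch.sum(chunk).item()) for chunk in torch.split(lengths, stride_per_key)
    ]


def length_per_key_single_sync(lengths: torch.Tensor, stride_per_key: List[int]) -> List[int]:
    # Sum each chunk on-device, concatenate the scalar sums into one tensor,
    # then transfer everything with a single .tolist() call.
    return torch.cat(
        [torch.sum(chunk).view(1) for chunk in torch.split(lengths, stride_per_key)]
    ).tolist()


if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    lengths = torch.tensor([1, 2, 3, 4, 5, 6], device=device)
    stride_per_key = [2, 2, 2]  # three keys, each spanning two entries of `lengths`
    assert length_per_key_many_syncs(lengths, stride_per_key) == [3, 7, 11]
    assert length_per_key_single_sync(lengths, stride_per_key) == [3, 7, 11]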