Skip to content

Commit

Permalink
Back out "Minimize gpu -> cpu data transfers in `_length_per_key_from…
Browse files Browse the repository at this point in the history
…_stride_per_key`" (#1489)

Summary:
Pull Request resolved: #1489

segment_sum_csr causes a regression:

"The kernel is not suitable for the case where the number of segments is small and each segment is large because we parallelize different segments across thread blocks and use one thread block per segment. In your case, if every element is in the same segment, this op will be very slow."

Reviewed By: bigning

Differential Revision: D51039432

fbshipit-source-id: 0d050e80c0046defa395d1b870bc5e677e2f3856
  • Loading branch information
joshuadeng authored and facebook-github-bot committed Nov 8, 2023
1 parent 3d192f9 commit dcd7139
Showing 1 changed file with 3 additions and 4 deletions.
7 changes: 3 additions & 4 deletions torchrec/sparse/jagged_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -631,10 +631,9 @@ def _maybe_compute_stride_kjt_scripted(
def _length_per_key_from_stride_per_key(
lengths: torch.Tensor, stride_per_key: List[int]
) -> List[int]:
stride_per_key_offsets = _to_offsets(
_pin_and_move(torch.tensor(stride_per_key, dtype=torch.int32), lengths.device)
)
return torch.ops.fbgemm.segment_sum_csr(1, stride_per_key_offsets, lengths).tolist()
return [
int(torch.sum(chunk).item()) for chunk in torch.split(lengths, stride_per_key)
]


def _maybe_compute_length_per_key(
Expand Down

0 comments on commit dcd7139

Please sign in to comment.