Skip to content

Commit

Permalink
refactor _maybe_compute_length_per_key to avoid graph break
Browse files Browse the repository at this point in the history
Summary:
# context
* in `PositionWeightedModuleCollection` KJT was created by providing both offsets and lengths [codepointer](https://fburl.com/code/e1zqkn2n)
```
return KeyedJaggedTensor(
    keys=features.keys(),
    values=features.values(),
    weights=get_weights_list(cat_seq, features, self.position_weights_dict),
    lengths=features.lengths(),
    offsets=features.offsets(),
    stride=features.stride(),
    length_per_key=features.length_per_key(),
)
```
* however, in jagged_tensor, offsets logic is in front lengths
```
        if len(keys) and values is not None and values.is_meta:
            # create dummy lengths per key when on meta device
            total_length = values.numel()
            _length = [total_length // len(keys)] * len(keys)
            _length[0] += total_length % len(keys)
        elif len(keys) and offsets is not None and len(offsets) > 0:
            _length: List[int] = (
                _length_per_key_from_stride_per_key(torch.diff(offsets), stride_per_key)
                if variable_stride_per_key
                else torch.sum(torch.diff(offsets).view(-1, stride), dim=1).tolist()
            )
        elif len(keys) and lengths is not None:
            _length: List[int] = (
                _length_per_key_from_stride_per_key(lengths, stride_per_key)
                if variable_stride_per_key
                else (
                    torch.sum(
                        pt2_check_size_nonzero(lengths.view(len(keys), stride)), dim=1
                    ).tolist()
                    if pt2_guard_size_oblivious(lengths.numel() != 0)
                    else [0] * len(keys)
                )
            )
```
* we actually perfer lengths over offsets in PT2 compile, so this diff changes the order.

# latest results:
* tlparse [06f45619f](https://interncache-all.fbcdn.net/manifold/tlparse_reports/tree/logs/.tmpPW5dZB/index.html)

# previous issue
* tlparse [df3d7729e](https://interncache-all.fbcdn.net/manifold/tlparse_reports/tree/logs/.tmpxXZ2em/index.html)
* P1530103883
```
Tried to use data-dependent value in the subsequent computation. This can happen when we encounter unbounded dynamic value that is unknown during tracing time.  You will need to explicitly give hint to the compiler. Please take a look at torch._check OR torch._check_is_size APIs.  Could not guard on data-dependent expression ((5*u37 + u38)//(u37 + u38)) < 0 (unhinted: ((5*u37 + u38)//(u37 + u38)) < 0).  (Size-like symbols: u38, u37)
ATTENTION: guard_size_oblivious would fix the error, evaluating expression to False.
Maybe you need to add guard_size_oblivious to framework code, see doc below for more guidance.
Potential framework code culprit (scroll up for full backtrace):
  File "/data/users/hhy/fbsource/buck-out/v2/gen/fbcode/e99934938a0abe90/aps_models/ads/icvr/__icvr_launcher_live__/icvr_launcher_live#link-tree/torch/_decomp/decompositions.py", line 771, in slice_forward
    if end_val < 0:
For more information, run with TORCH_LOGS="dynamic"
For extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="u38,u37"
If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
For more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing
User Stack (most recent call last):
  (snipped, see stack below for prefix)
  File "/data/users/hhy/fbsource/buck-out/v2/gen/fbcode/e99934938a0abe90/aps_models/ads/icvr/__icvr_launcher_live__/icvr_launcher_live#link-tree/aps_models/ads/icvr/models/experimental/fmc/ig_fm_v4_mini.py", line 1326, in forward
    embs_kt_list = self.sparse_arch(
  File "/data/users/hhy/fbsource/buck-out/v2/gen/fbcode/e99934938a0abe90/aps_models/ads/icvr/__icvr_launcher_live__/icvr_launcher_live#link-tree/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data/users/hhy/fbsource/buck-out/v2/gen/fbcode/e99934938a0abe90/aps_models/ads/icvr/__icvr_launcher_live__/icvr_launcher_live#link-tree/apf/rec/modules/embedding_bag_collections_sparse_arch.py", line 358, in forward
    ret.append(self.position_ebc(id_list_features))
  File "/data/users/hhy/fbsource/buck-out/v2/gen/fbcode/e99934938a0abe90/aps_models/ads/icvr/__icvr_launcher_live__/icvr_launcher_live#link-tree/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data/users/hhy/fbsource/buck-out/v2/gen/fbcode/e99934938a0abe90/aps_models/ads/icvr/__icvr_launcher_live__/icvr_launcher_live#link-tree/torchrec/distributed/types.py", line 896, in forward
    return self.compute_and_output_dist(ctx, dist_input)
  File "/data/users/hhy/fbsource/buck-out/v2/gen/fbcode/e99934938a0abe90/aps_models/ads/icvr/__icvr_launcher_live__/icvr_launcher_live#link-tree/torchrec/distributed/fp_embeddingbag.py", line 131, in compute_and_output_dist
    fp_features = self.apply_feature_processors_to_kjt_list(input)
  File "/data/users/hhy/fbsource/buck-out/v2/gen/fbcode/e99934938a0abe90/aps_models/ads/icvr/__icvr_launcher_live__/icvr_launcher_live#link-tree/torchrec/distributed/fp_embeddingbag.py", line 101, in apply_feature_processors_to_kjt_list
    kjt_list.append(self._feature_processors(features))
  File "/data/users/hhy/fbsource/buck-out/v2/gen/fbcode/e99934938a0abe90/aps_models/ads/icvr/__icvr_launcher_live__/icvr_launcher_live#link-tree/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data/users/hhy/fbsource/buck-out/v2/gen/fbcode/e99934938a0abe90/aps_models/ads/icvr/__icvr_launcher_live__/icvr_launcher_live#link-tree/torchrec/modules/feature_processor_.py", line 191, in forward
    weights=get_weights_list(cat_seq, features, self.position_weights_dict),
  File "/data/users/hhy/fbsource/buck-out/v2/gen/fbcode/e99934938a0abe90/aps_models/ads/icvr/__icvr_launcher_live__/icvr_launcher_live#link-tree/torchrec/modules/feature_processor_.py", line 142, in get_weights_list
    seqs = torch.split(cat_seq, features.length_per_key())
  File "/data/users/hhy/fbsource/buck-out/v2/gen/fbcode/e99934938a0abe90/aps_models/ads/icvr/__icvr_launcher_live__/icvr_launcher_live#link-tree/torchrec/sparse/jagged_tensor.py", line 1874, in length_per_key
    _length_per_key = _maybe_compute_length_per_key(
  File "/data/users/hhy/fbsource/buck-out/v2/gen/fbcode/e99934938a0abe90/aps_models/ads/icvr/__icvr_launcher_live__/icvr_launcher_live#link-tree/torchrec/sparse/jagged_tensor.py", line 1009, in _maybe_compute_length_per_key
    _length: List[int] = torch.diff(strided_offsets).tolist()
```

Differential Revision: D57217616
  • Loading branch information
TroyGarden authored and facebook-github-bot committed Aug 16, 2024
1 parent 2597d08 commit 4899d34
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions torchrec/sparse/jagged_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -993,12 +993,6 @@ def _maybe_compute_length_per_key(
total_length = values.numel()
_length = [total_length // len(keys)] * len(keys)
_length[0] += total_length % len(keys)
elif len(keys) and offsets is not None and len(offsets) > 0:
_length: List[int] = (
_length_per_key_from_stride_per_key(torch.diff(offsets), stride_per_key)
if variable_stride_per_key
else torch.sum(torch.diff(offsets).view(-1, stride), dim=1).tolist()
)
elif len(keys) and lengths is not None:
_length: List[int] = (
_length_per_key_from_stride_per_key(lengths, stride_per_key)
Expand All @@ -1011,6 +1005,12 @@ def _maybe_compute_length_per_key(
else [0] * len(keys)
)
)
elif len(keys) and offsets is not None and len(offsets) > 0:
_length: List[int] = (
_length_per_key_from_stride_per_key(torch.diff(offsets), stride_per_key)
if variable_stride_per_key
else torch.sum(torch.diff(offsets).view(-1, stride), dim=1).tolist()
)
else:
_length: List[int] = []
length_per_key = _length
Expand Down

0 comments on commit 4899d34

Please sign in to comment.