From 9d8ea07d328840029d65dae3d5ec349dd5178a1f Mon Sep 17 00:00:00 2001 From: Huanyu He Date: Thu, 15 Aug 2024 17:23:16 -0700 Subject: [PATCH] refactor _maybe_compute_length_per_key to avoid graph break Summary: # context * in `PositionWeightedModuleCollection` KJT was created by providing both offsets and lengths [codepointer](https://fburl.com/code/e1zqkn2n) ``` return KeyedJaggedTensor( keys=features.keys(), values=features.values(), weights=get_weights_list(cat_seq, features, self.position_weights_dict), lengths=features.lengths(), offsets=features.offsets(), stride=features.stride(), length_per_key=features.length_per_key(), ) ``` * however, in jagged_tensor, the offsets logic comes before the lengths logic ``` if len(keys) and values is not None and values.is_meta: # create dummy lengths per key when on meta device total_length = values.numel() _length = [total_length // len(keys)] * len(keys) _length[0] += total_length % len(keys) elif len(keys) and offsets is not None and len(offsets) > 0: _length: List[int] = ( _length_per_key_from_stride_per_key(torch.diff(offsets), stride_per_key) if variable_stride_per_key else torch.sum(torch.diff(offsets).view(-1, stride), dim=1).tolist() ) elif len(keys) and lengths is not None: _length: List[int] = ( _length_per_key_from_stride_per_key(lengths, stride_per_key) if variable_stride_per_key else ( torch.sum( pt2_check_size_nonzero(lengths.view(len(keys), stride)), dim=1 ).tolist() if pt2_guard_size_oblivious(lengths.numel() != 0) else [0] * len(keys) ) ) ``` * we actually prefer lengths over offsets in PT2 compile, so this diff changes the order. # latest results: * tlparse [34309d026](https://interncache-all.fbcdn.net/manifold/tlparse_reports/tree/logs/.tmpPW5dZB/index.html) # previous issue * tlparse [df3d7729e](https://interncache-all.fbcdn.net/manifold/tlparse_reports/tree/logs/.tmpxXZ2em/index.html) * P1530103883 ``` Tried to use data-dependent value in the subsequent computation. 
This can happen when we encounter unbounded dynamic value that is unknown during tracing time. You will need to explicitly give hint to the compiler. Please take a look at torch._check OR torch._check_is_size APIs. Could not guard on data-dependent expression ((5*u37 + u38)//(u37 + u38)) < 0 (unhinted: ((5*u37 + u38)//(u37 + u38)) < 0). (Size-like symbols: u38, u37) ATTENTION: guard_size_oblivious would fix the error, evaluating expression to False. Maybe you need to add guard_size_oblivious to framework code, see doc below for more guidance. Potential framework code culprit (scroll up for full backtrace): File "/data/users/hhy/fbsource/buck-out/v2/gen/fbcode/e99934938a0abe90/aps_models/ads/icvr/__icvr_launcher_live__/icvr_launcher_live#link-tree/torch/_decomp/decompositions.py", line 771, in slice_forward if end_val < 0: For more information, run with TORCH_LOGS="dynamic" For extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="u38,u37" If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1 For more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing User Stack (most recent call last): (snipped, see stack below for prefix) File "/data/users/hhy/fbsource/buck-out/v2/gen/fbcode/e99934938a0abe90/aps_models/ads/icvr/__icvr_launcher_live__/icvr_launcher_live#link-tree/aps_models/ads/icvr/models/experimental/fmc/ig_fm_v4_mini.py", line 1326, in forward embs_kt_list = self.sparse_arch( File "/data/users/hhy/fbsource/buck-out/v2/gen/fbcode/e99934938a0abe90/aps_models/ads/icvr/__icvr_launcher_live__/icvr_launcher_live#link-tree/torch/nn/modules/module.py", line 1747, in _call_impl return forward_call(*args, **kwargs) File "/data/users/hhy/fbsource/buck-out/v2/gen/fbcode/e99934938a0abe90/aps_models/ads/icvr/__icvr_launcher_live__/icvr_launcher_live#link-tree/apf/rec/modules/embedding_bag_collections_sparse_arch.py", line 358, in forward 
ret.append(self.position_ebc(id_list_features)) File "/data/users/hhy/fbsource/buck-out/v2/gen/fbcode/e99934938a0abe90/aps_models/ads/icvr/__icvr_launcher_live__/icvr_launcher_live#link-tree/torch/nn/modules/module.py", line 1747, in _call_impl return forward_call(*args, **kwargs) File "/data/users/hhy/fbsource/buck-out/v2/gen/fbcode/e99934938a0abe90/aps_models/ads/icvr/__icvr_launcher_live__/icvr_launcher_live#link-tree/torchrec/distributed/types.py", line 896, in forward return self.compute_and_output_dist(ctx, dist_input) File "/data/users/hhy/fbsource/buck-out/v2/gen/fbcode/e99934938a0abe90/aps_models/ads/icvr/__icvr_launcher_live__/icvr_launcher_live#link-tree/torchrec/distributed/fp_embeddingbag.py", line 131, in compute_and_output_dist fp_features = self.apply_feature_processors_to_kjt_list(input) File "/data/users/hhy/fbsource/buck-out/v2/gen/fbcode/e99934938a0abe90/aps_models/ads/icvr/__icvr_launcher_live__/icvr_launcher_live#link-tree/torchrec/distributed/fp_embeddingbag.py", line 101, in apply_feature_processors_to_kjt_list kjt_list.append(self._feature_processors(features)) File "/data/users/hhy/fbsource/buck-out/v2/gen/fbcode/e99934938a0abe90/aps_models/ads/icvr/__icvr_launcher_live__/icvr_launcher_live#link-tree/torch/nn/modules/module.py", line 1747, in _call_impl return forward_call(*args, **kwargs) File "/data/users/hhy/fbsource/buck-out/v2/gen/fbcode/e99934938a0abe90/aps_models/ads/icvr/__icvr_launcher_live__/icvr_launcher_live#link-tree/torchrec/modules/feature_processor_.py", line 191, in forward weights=get_weights_list(cat_seq, features, self.position_weights_dict), File "/data/users/hhy/fbsource/buck-out/v2/gen/fbcode/e99934938a0abe90/aps_models/ads/icvr/__icvr_launcher_live__/icvr_launcher_live#link-tree/torchrec/modules/feature_processor_.py", line 142, in get_weights_list seqs = torch.split(cat_seq, features.length_per_key()) File 
"/data/users/hhy/fbsource/buck-out/v2/gen/fbcode/e99934938a0abe90/aps_models/ads/icvr/__icvr_launcher_live__/icvr_launcher_live#link-tree/torchrec/sparse/jagged_tensor.py", line 1874, in length_per_key _length_per_key = _maybe_compute_length_per_key( File "/data/users/hhy/fbsource/buck-out/v2/gen/fbcode/e99934938a0abe90/aps_models/ads/icvr/__icvr_launcher_live__/icvr_launcher_live#link-tree/torchrec/sparse/jagged_tensor.py", line 1009, in _maybe_compute_length_per_key _length: List[int] = torch.diff(strided_offsets).tolist() ``` Differential Revision: D56339251 --- torchrec/sparse/jagged_tensor.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/torchrec/sparse/jagged_tensor.py b/torchrec/sparse/jagged_tensor.py index b269b7820..773f219cf 100644 --- a/torchrec/sparse/jagged_tensor.py +++ b/torchrec/sparse/jagged_tensor.py @@ -994,11 +994,19 @@ def _maybe_compute_length_per_key( _length = [total_length // len(keys)] * len(keys) _length[0] += total_length % len(keys) elif len(keys) and offsets is not None and len(offsets) > 0: - _length: List[int] = ( - _length_per_key_from_stride_per_key(torch.diff(offsets), stride_per_key) - if variable_stride_per_key - else torch.sum(torch.diff(offsets).view(-1, stride), dim=1).tolist() - ) + if variable_stride_per_key: + _length: List[int] = _length_per_key_from_stride_per_key( + torch.diff(offsets), stride_per_key + ) + else: + if not torch.jit.is_scripting(): + torch._check(stride > 0) + strided_offsets = offsets[::stride] + if not torch.jit.is_scripting(): + torch._check( + strided_offsets.numel() > 1 + ) # len(offsets) = len(lengths) + 1 >= 2 + _length: List[int] = torch.diff(strided_offsets).tolist() elif len(keys) and lengths is not None: _length: List[int] = ( _length_per_key_from_stride_per_key(lengths, stride_per_key)