Skip to content

Commit

Permalink
modify the zero-size values in KJT for dynamic shape compatibility (p…
Browse files Browse the repository at this point in the history
…ytorch#2250)

Summary:
Pull Request resolved: pytorch#2250

# context
* dynamic shape usually has a minimum value requirement: `dynamic_shape >= 2`
* however, in reality, the actual KJT._values could be empty
* this issue was discribed in D57998381

> run the local-default every a few time one could get an error for some of the dynamic_shape being zero
This is because in some corner case (not very rare though), the some dynamic_shape dim of the `sample_input` could be zero,
and 0-size dynamic shape is handled differently during torch.export. **Bascially it will assume this dynamic shape should always be zero.**
* error log: P1462233278
```
[rank0]:   - Not all values of vlen5 = L['args'][0][0].event_id_list_features_seqs['marketplace']._values.size()[0] in the specified range are valid because vlen5 was inferred to be a constant (0).
[rank0]: Suggested fixes:
[rank0]:   vlen5 = 0
```

# method
* padding the kjt._values with the minimum required size `(2, )`
* in the case of empty values, kjt._lengths and kjt._offsets should all be zeros
* it doesn't affect the true logic/mathematic values of the kjt

# issues
1. exported_program.module can't take in empty-value input.
2. deserialized unflattened model can't take in empty-value input, which could happen in real data.
3. deserialized unflattened model can't take in altered input, which could be a potential workaround if can't resolve pytorch#2.

NOTE: Please check the in-line comments in the test file for details

# Other Concerns
1. the inconsistency in the KJT (lengths are zeros, but values is non empty) might be incompatible with some downstream functions/operators, will need more tests to confirm.
2.

Differential Revision: D45410437
  • Loading branch information
TroyGarden authored and facebook-github-bot committed Jul 27, 2024
1 parent 35defde commit 14f79a7
Showing 1 changed file with 6 additions and 0 deletions.
6 changes: 6 additions & 0 deletions torchrec/ir/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,12 @@ def mark_dynamic_kjt(
"""
if shapes_collection is None:
shapes_collection = ShapesCollection()
if kjt._values.numel() == 0:
# if the values is empty, we need to set the shape to (2,) to make it compatible with dynamic shape
# a 0-size dynamic shape will cause error in torch.export.
# logically when the values is empty, the lengths and offsets should all be zero-value tensors.
# And this makes the actual values irrelavent to the downstream process.
kjt._values = torch.ones(2, device=kjt._values.device, dtype=kjt._values.dtype)
vlen = _get_dim("vlen") if vlen is None else vlen
shapes_collection[kjt._values] = (vlen,)
if kjt._weights is not None:
Expand Down

0 comments on commit 14f79a7

Please sign in to comment.