
[Feature Request] Missing optimization of DequantizeLinear ∘ Flatten ∘ QuantizeLinear? #21375

Open
mcollinswisc opened this issue Jul 16, 2024 Discussed in #21167 · 1 comment
Labels: feature request (request for unsupported feature or enhancement), quantization (issues related to quantization)

Comments

@mcollinswisc (Contributor)

Discussed in #21167

Originally posted by mcollinswisc June 25, 2024
It looks like ONNX Runtime will optimize DequantizeLinear ∘ Reshape ∘ QuantizeLinear down to just the Reshape, eliminating the quantize/de-quantize round trip, if the scales & zero points on both sides are the same.

However, an equivalent Flatten is not optimized. Is this likely to be just a missing optimization, or is there some reason the QDQ pair would be preserved in this case?

Tested out in:
https://gist.github.com/mcollinswisc/d1cd9d13b4e5fbad01c75dca5c9ca576
with ONNXRuntime 1.18.0

github-actions bot added the quantization label Jul 16, 2024
@sophies927 added the feature request label Jul 18, 2024
@skottmckay (Contributor)

Should be possible to add to this list given the ONNX spec for Flatten allows 8-bit integers:

```cpp
qdq_selector_action_registry.RegisterSelectorAndAction(drop_action_name,
                                                       {{"Gather", {}},
                                                        {"Reshape", {}},
                                                        {"Transpose", {}},
                                                        {"Squeeze", {}},
                                                        {"Unsqueeze", {}}},
```

skottmckay pushed a commit that referenced this issue Aug 27, 2024
### Description

Extends the Drop QDQ optimization to remove DequantizeLinear and
QuantizeLinear nodes from around operators:

- Flatten
- Expand
- Tile
- Slice
- GatherElements
- ReduceMin
- ReduceMax

### Motivation and Context

To reduce floating-point conversions in quantized inference. Mainly
motivated by the Flatten case, since that shows up in graphs exported
from PyTorch to ONNX; for completeness, the change extends to the
larger set of ops for which this optimization is valid.

#21375

---------

Co-authored-by: Edward Chen <[email protected]>