forked from microsoft/onnxruntime
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[QNN EP] Support per-channel quantized weights (microsoft#20154)
### Description - Adds general support for per-channel quantized weights to QNN EP (HTP backend). - Add QNN EP unit tests for per-channel Conv - Update quantization tool to allow selecting which ops are quantized per-channel (and which axis) via tensor-level overrides. Currently, setting `per_channel=True` assumes all Convs, MatMuls, Gemms, InstanceNormalization, and LayerNormalization ops should be quantized per-channel using some assumed default axis. #### Creating QDQ per-channel Conv model example ```python from onnxruntime.quantization import CalibrationDataReader, QuantType, quantize from onnxruntime.quantization.execution_providers.qnn import get_qnn_qdq_config, qnn_preprocess_model class DataReader(CalibrationDataReader): # TODO: See ONNX Runtime QNN docs for example of a data reader # https://onnxruntime.ai/docs/execution-providers/QNN-ExecutionProvider.html#generating-a-quantized-model-x64 pass if __name__ == "__main__": input_model_path = "model.onnx" my_data_reader = DataReader(model_to_quantize) # Pre-process the original float32 model. preproc_model_path = "model.preproc.onnx" model_changed = qnn_preprocess_model(input_model_path, preproc_model_path) model_to_quantize = preproc_model_path if model_changed else input_model_path # RELEVANT TO THIS PR: # Make sure Conv's weight input is quantized to int8/symmetric/per-channel with axis == 0. # The presence of the 'axis' key indicates that this is a per-channel quantized weight. init_overrides = {'weight': [{'axis': 0, 'quant_type': QuantType.QInt8, 'symmetric': True}]} qnn_config = get_qnn_qdq_config(model_to_quantize, my_data_reader, init_overrides=init_overrides, activation_type=QuantType.QUInt16, # uint16 activations weight_type=QuantType.QUInt8) # uint8 weights by default quantize(model_to_quantize, "model.qdq.onnx", qnn_config) ``` float32 model: <img width="683" alt="image" src="https://github.com/microsoft/onnxruntime/assets/19691973/ca650e49-1ad0-47d8-8c46-17fbc224ca39"> QDQ model (per-channel Conv weight): <img width="748" alt="image" src="https://github.com/microsoft/onnxruntime/assets/19691973/6bd469f2-968b-4d11-9526-09b3e71f98e7"> ### Motivation and Context Support more models, especially models with int4 quantized weights.
- Loading branch information
1 parent
76b3fb6
commit ca9d9dc
Showing
45 changed files
with
2,064 additions
and
581 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.