Start adding more documentation for quantizing/running models on QNN EP
adrianlizarraga committed Feb 6, 2024
1 parent fde74cb commit bd12c36
Showing 2 changed files with 111 additions and 0 deletions.
111 changes: 111 additions & 0 deletions docs/execution-providers/QNN-ExecutionProvider.md
@@ -78,6 +78,117 @@ The QNN Execution Provider supports a number of configuration options. The `prov
|'2'|longer preparation time, more optimal graph.|
|'3'|longest preparation time, most likely even more optimal graph.|
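These options are passed to QNN EP when the inference session is created. The snippet below is a minimal sketch of doing so from Python; it assumes the rows above describe the `htp_graph_finalization_optimization_mode` provider option and that the HTP backend library is `QnnHtp.dll` — adjust option names, values, and paths for your setup.

```Python3
import onnxruntime

# Minimal sketch of passing QNN EP provider options (option names/values are assumptions).
session = onnxruntime.InferenceSession(
    "model.qdq.onnx",
    providers=["QNNExecutionProvider"],
    provider_options=[{
        "backend_path": "QnnHtp.dll",                     # HTP (NPU) backend library
        "htp_graph_finalization_optimization_mode": "3",  # assumed option from the table above
    }],
)
```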

## Running a model with QNN EP's HTP backend (Python)
The QNN HTP backend, which offloads compute to the NPU, only supports quantized models. Models with 32-bit floating-point activations and weights must first be quantized to a lower integer precision (e.g., 8-bit or 16-bit integers).
This section provides instructions for quantizing a model and then running the quantized model on QNN EP's HTP backend using Python APIs. Please refer to the [quantization page](../performance/model-optimizations/quantization.md) for a broader overview of quantization concepts.

<p align="center"><img width="50%" src="../../images/qnn_ep_quant_workflow.png" alt="Offline workflow for quantizing an ONNX model for use on QNN EP"/></p>

### Quantizing a model
The ONNX Runtime Python package provides utilities for quantizing ONNX models via the `onnxruntime.quantization` module. Note that the quantization utilities are currently only supported on x86_64.
Therefore, it is recommended to either quantize models on an x64 machine or, alternatively, use a separate x64 Python installation on Windows ARM64 machines.

Install the ONNX Runtime x64 python package. We currently recommend installing the nightly version of ONNX Runtime to get the latest updates to the quantization utilities.
```shell
## Install nightly ORT built from main branch
python -m pip install -i https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/ORT-Nightly/pypi/simple/ ort-nightly
```

Model quantization for QNN EP requires the use of calibration input data to compute quantization parameters for all activations and weights in the model. Using a calibration dataset that is representative of typical model inputs is crucial in generating an accurate quantized model.
The following snippet defines a sample `CalibrationDataReader` class that provides calibration data to the quantization utilities. Note, however, that this example uses random inputs as the calibration dataset purely for simplicity; in practice, random input data produces an inaccurate quantized model.

```Python3
# data_reader.py

import numpy as np
import onnxruntime
from onnxruntime.quantization import CalibrationDataReader


# Map ONNX tensor element types to NumPy dtypes.
NP_TYPE = {
    "tensor(float)": np.float32,
    "tensor(uint8)": np.uint8,
    "tensor(int8)": np.int8,
    "tensor(uint16)": np.uint16,
    "tensor(int16)": np.int16,
}


def get_np_type(onnx_type):
    if onnx_type in NP_TYPE:
        return NP_TYPE[onnx_type]
    else:
        raise Exception(f"Unhandled ONNX type: {onnx_type}")


class DataReader(CalibrationDataReader):
    def __init__(self, model_path: str):
        self.enum_data = None

        # Use an inference session to get the model's input names, shapes, and types.
        session = onnxruntime.InferenceSession(model_path, providers=["CPUExecutionProvider"])
        inputs = session.get_inputs()

        self.data_list = []

        # Generate 10 random inputs.
        # TODO: Load valid calibration input data that is representative of your model's inputs.
        for _ in range(10):
            input_data = {inp.name: np.random.random(inp.shape).astype(get_np_type(inp.type)) for inp in inputs}
            self.data_list.append(input_data)

        self.datasize = len(self.data_list)

    def get_next(self):
        if self.enum_data is None:
            self.enum_data = iter(self.data_list)
        return next(self.enum_data, None)

    def rewind(self):
        self.enum_data = None

```

The following snippet pre-processes the original model and then quantizes the pre-processed model using the above `CalibrationDataReader` class.

```Python3
# quantize_model.py

import data_reader
from onnxruntime.quantization import QuantType, quantize
from onnxruntime.quantization.execution_providers.qnn import get_qnn_qdq_config, qnn_preprocess_model

if __name__ == "__main__":
    input_model_path = "model.onnx"  # TODO: Replace with your actual model
    output_model_path = "model.qdq.onnx"  # Name of final quantized model
    my_data_reader = data_reader.DataReader(input_model_path)

    # Pre-process the original float32 model.
    preproc_model_path = "model.preproc.onnx"
    model_changed = qnn_preprocess_model(input_model_path, preproc_model_path)
    model_to_quantize = preproc_model_path if model_changed else input_model_path

    # Generate a suitable quantization configuration for this model.
    # Note that we're choosing to use uint16 activations and uint8 weights.
    qnn_config = get_qnn_qdq_config(model_to_quantize,
                                    my_data_reader,
                                    activation_type=QuantType.QUInt16,  # uint16 activations
                                    weight_type=QuantType.QUInt8)       # uint8 weights

    # Quantize the model.
    quantize(model_to_quantize, output_model_path, qnn_config)
```

Running `python quantize_model.py` will generate a quantized model (`model.qdq.onnx`) that can be run on Windows ARM64 devices via ONNX Runtime's QNN EP.
Refer to [quantization/execution_providers/qnn/preprocess.py](https://github.com/microsoft/onnxruntime/blob/23996bbbbe0406a5c8edbf6b7dbd71e5780d3f4b/onnxruntime/python/tools/quantization/execution_providers/qnn/preprocess.py#L16) and
[quantization/execution_providers/qnn/quant_config.py](https://github.com/microsoft/onnxruntime/blob/23996bbbbe0406a5c8edbf6b7dbd71e5780d3f4b/onnxruntime/python/tools/quantization/execution_providers/qnn/quant_config.py#L20-L27)
for more information on available function parameters.
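The quantized model can then be loaded and executed on the HTP backend from Python. The following is a minimal sketch; the input name, shape, and dtype are placeholders that must be replaced with your model's actual inputs, and `backend_path` assumes the QNN HTP backend library (`QnnHtp.dll`) is discoverable on the device.

```Python3
# run_model.py

import numpy as np
import onnxruntime

# Create an inference session that assigns supported nodes to QNN EP's HTP backend.
session = onnxruntime.InferenceSession(
    "model.qdq.onnx",
    providers=["QNNExecutionProvider"],
    provider_options=[{"backend_path": "QnnHtp.dll"}],  # assumes QnnHtp.dll can be found
)

# TODO: Replace with real inputs. The name, shape, and dtype here are placeholders.
inputs = {"input": np.random.random((1, 3, 224, 224)).astype(np.float32)}

outputs = session.run(None, inputs)
print(outputs[0])
```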


## QNN context binary cache feature
There is a QNN context which contains the QNN graphs after converting, compiling, and finalizing the model. QNN can serialize this context into a binary file so that it can be used directly for further inference (without the QDQ model) to reduce model loading cost.
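As a rough sketch, the context binary can be generated once and then reused on later runs via session configuration entries; the `ep.context_enable` and `ep.context_file_path` keys and the generated file name used below are assumptions that may differ across ONNX Runtime releases.

```Python3
import onnxruntime

# 1) Load the QDQ model once with context caching enabled to dump the QNN context binary.
#    NOTE: the session configuration keys below are assumptions; check your ORT release.
so = onnxruntime.SessionOptions()
so.add_session_config_entry("ep.context_enable", "1")
so.add_session_config_entry("ep.context_file_path", "model.qdq.onnx_ctx.onnx")

onnxruntime.InferenceSession(
    "model.qdq.onnx",
    sess_options=so,
    providers=["QNNExecutionProvider"],
    provider_options=[{"backend_path": "QnnHtp.dll"}],
)

# 2) Subsequent sessions load the generated context model directly, skipping QNN compilation.
session = onnxruntime.InferenceSession(
    "model.qdq.onnx_ctx.onnx",
    providers=["QNNExecutionProvider"],
    provider_options=[{"backend_path": "QnnHtp.dll"}],
)
```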
Binary file added images/qnn_ep_quant_workflow.png
