[QNN Quant] Handle external data for QNN preprocessing/quant (#19670)
### Description
- Adds parameters to `qnn_preprocess_model()` to allow saving the new
model with external data.
- Updates `get_qnn_qdq_config()` to:
  - Load the model without external data (it is not needed).
  - Return a quantization configuration with `use_external_data_format`
    set to `True` if the model has external data or if the model is >= 2GB.

### Motivation and Context
Update QNN quantization to better handle large models that use external
data.
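
For illustration, a call that exercises the new parameters might look like this (the file names are placeholders, not part of this change; the helper forwards these arguments to `onnx.save_model()`):

```python
from onnxruntime.quantization.execution_providers.qnn import qnn_preprocess_model

# Save the pre-processed model with every tensor >= 1 KiB moved into a
# single side file next to the output model (placeholder file names).
modified = qnn_preprocess_model(
    "model.onnx",
    "model.qnn_pp.onnx",
    fuse_layernorm=True,
    save_as_external_data=True,
    all_tensors_to_one_file=True,
    external_data_location="model.qnn_pp.bin",
    external_data_size_threshold=1024,
)
```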
adrianlizarraga authored Feb 28, 2024
1 parent 7a147fc commit 913bdc7
Showing 4 changed files with 261 additions and 5 deletions.
onnxruntime/python/tools/quantization/execution_providers/qnn/preprocess.py
@@ -3,6 +3,8 @@
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
from __future__ import annotations

import logging
from pathlib import Path

@@ -13,7 +15,44 @@
from .fusion_lpnorm import FusionLpNormalization


def qnn_preprocess_model(model_input: Path, model_output: Path, fuse_layernorm: bool = False) -> bool:
def qnn_preprocess_model(
model_input: Path,
model_output: Path,
fuse_layernorm: bool = False,
save_as_external_data: bool = False,
all_tensors_to_one_file: bool = False,
external_data_location: str | None = None,
external_data_size_threshold: int = 1024,
external_data_convert_attribute: bool = False,
) -> bool:
"""
If necessary, this method creates a new "pre-processed" model in preparation for
quantization of a model to be used in QNN EP. Returns True if a new model was created.
This method performs the following operations:
- Fuse Erf sequence into a single Gelu node.
- Fuse ReduceL2 sequence into a single LpNormalization node (p == 2).
- (Optional) Fuse ReduceMean sequence into a single LayerNormalization node.
Args:
model_input: Path to the input model file.
model_output: Path to the output model file, which is only created if this method returns True.
fuse_layernorm: True if ReduceMean sequences should be fused into LayerNormalization nodes.
Defaults to False.
save_as_external_data: True if the output model should be saved with external data. Defaults to False.
all_tensors_to_one_file: Effective only if save_as_external_data is True. Defaults to False.
If True, save all tensors to one external file specified by external_data_location.
If False, save each tensor to a file named with the tensor name.
external_data_location: Effective only if save_as_external_data is True. Defaults to None.
Specify the external file to which all tensors are saved. Path is relative
to the model path. If not specified, the model's name is used.
external_data_size_threshold: Effective only if save_as_external_data is True. Defaults to 1024.
Tensors with a data size >= external_data_size_threshold are converted to external data.
To convert every tensor with raw data to external data, set to 0.
external_data_convert_attribute: Effective only if save_as_external_data is True. Defaults to False.
If True, convert all tensors to external data.
If False, convert only non-attribute tensors to external data.
"""
modified = False
model = onnx.load_model(model_input)
onnx_model = ONNXModel(model)
@@ -57,6 +96,14 @@ def qnn_preprocess_model(model_input: Path, model_output: Path, fuse_layernorm:

if modified:
onnx_model.topological_sort()
onnx.save_model(model, model_output)
onnx.save_model(
model,
model_output,
save_as_external_data=save_as_external_data,
all_tensors_to_one_file=all_tensors_to_one_file,
location=external_data_location,
size_threshold=external_data_size_threshold,
convert_attribute=external_data_convert_attribute,
)

return modified
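
Since `model_output` is only written when the function returns `True`, a caller that wants to quantize whichever model is current needs a small dispatch. A sketch, with placeholder paths:

```python
# Quantize the pre-processed model if one was produced, else the original.
preprocessed = qnn_preprocess_model("model.onnx", "model.qnn_pp.onnx")
model_to_quantize = "model.qnn_pp.onnx" if preprocessed else "model.onnx"
```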
onnxruntime/python/tools/quantization/execution_providers/qnn/quant_config.py
@@ -15,6 +15,7 @@
Q16_TYPES = {QuantType.QInt16, QuantType.QUInt16}
Q8_TYPES = {QuantType.QInt8, QuantType.QUInt8}
OP_TYPES_TO_EXCLUDE = {"Cast"}
MODEL_SIZE_THRESHOLD = 2147483648 # Quant model should use external data if >= 2GB


def get_qnn_qdq_config(
@@ -28,14 +29,21 @@ def get_qnn_qdq_config(
if per_channel:
raise ValueError("QNN EP does not yet support per-channel quantization.")

# Process model nodes to setup overrides.
model = onnx.load_model(model_input)
model = onnx.load_model(model_input, load_external_data=False)

op_types = set()
tensor_quant_overrides = {}
model_has_external_data = False
name_to_initializer = {}

name_to_initializer = {initializer.name: initializer for initializer in model.graph.initializer}
# Build map of initializers (name -> initializer) and
# check if the model has external data.
for initializer in model.graph.initializer:
name_to_initializer[initializer.name] = initializer
if onnx.external_data_helper.uses_external_data(initializer):
model_has_external_data = True

# Setup quantization overrides for specific operator types
for node in model.graph.node:
op_types.add(node.op_type)

@@ -89,5 +97,6 @@ def get_qnn_qdq_config(
activation_type=activation_type,
weight_type=weight_type,
op_types_to_quantize=list(op_types.difference(OP_TYPES_TO_EXCLUDE)),
use_external_data_format=(model_has_external_data or model.ByteSize() >= MODEL_SIZE_THRESHOLD),
extra_options=extra_options,
)
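
Downstream, the returned config is passed to `quantize()` as usual; when the input model references external tensor data (or the in-memory proto is >= 2GB), `use_external_data_format` now comes back `True`, so the QDQ model is saved with external data as well. A minimal sketch, assuming a calibration data reader and placeholder paths:

```python
from onnxruntime.quantization import quantize
from onnxruntime.quantization.execution_providers.qnn import get_qnn_qdq_config

my_data_reader = ...  # assumed: a CalibrationDataReader yielding representative inputs

qnn_config = get_qnn_qdq_config("model.qnn_pp.onnx", my_data_reader)

# use_external_data_format is True here if model.qnn_pp.onnx references
# external tensor data or the model is >= 2GB.
quantize("model.qnn_pp.onnx", "model.qnn_qdq.onnx", qnn_config)
```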
onnxruntime/test/python/quantization/test_qnn_preprocess_model.py (new file: 170 additions)
@@ -0,0 +1,170 @@
#!/usr/bin/env python
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------

import math
import unittest
from pathlib import Path

import numpy as np
import onnx

from onnxruntime.quantization.execution_providers.qnn import qnn_preprocess_model
from onnxruntime.quantization.quant_utils import model_has_external_data, ms_domain


class TestQnnPreprocessModel(unittest.TestCase):
def build_model(self, shape, scale_val, bias_val):
"""
Build a model that supports 3 kinds of fusions:
- Erf sequence to Gelu
- ReduceL2 sequence to LpNormalization
- ReduceMean sequence to LayerNormalization
"""
root_inp = onnx.helper.make_tensor_value_info("root", onnx.TensorProto.FLOAT, shape)
output = onnx.helper.make_tensor_value_info("output", onnx.TensorProto.FLOAT, shape)

# Erf sequence
one_const = onnx.numpy_helper.from_array(np.array(1.0, dtype=np.float32), "one_const")
half_const = onnx.numpy_helper.from_array(np.array(0.5, dtype=np.float32), "half_const")
root2_const = onnx.numpy_helper.from_array(np.array(math.sqrt(2.0), dtype=np.float32), "root2_const")

e_mul0_node = onnx.helper.make_node("Mul", ["root", "half_const"], ["e_mul0_out"])
e_div_node = onnx.helper.make_node("Div", ["root", "root2_const"], ["e_div_out"])
e_erf_node = onnx.helper.make_node("Erf", ["e_div_out"], ["e_erf_out"])
e_add_node = onnx.helper.make_node("Add", ["e_erf_out", "one_const"], ["e_add_out"])
e_mul1_node = onnx.helper.make_node("Mul", ["e_add_out", "e_mul0_out"], ["erf_seq_output"])

# ReduceL2 sequence
axes_const = onnx.numpy_helper.from_array(np.array([-1], dtype=np.int64), "axes_const")
eps_const = onnx.numpy_helper.from_array(np.array(1e-12, dtype=np.float32), "eps_const")
shape_const = onnx.numpy_helper.from_array(np.array(list(shape), dtype=np.int64), "shape_const")

l2_rl2_node = onnx.helper.make_node("ReduceL2", ["erf_seq_output", "axes_const"], ["l2_rl2_out"], keepdims=1)
l2_clip_node = onnx.helper.make_node("Clip", ["l2_rl2_out", "eps_const"], ["l2_clip_out"])
l2_expand_node = onnx.helper.make_node("Expand", ["l2_clip_out", "shape_const"], ["l2_expand_out"])
l2_div_node = onnx.helper.make_node("Div", ["erf_seq_output", "l2_expand_out"], ["l2_seq_output"])

# ReduceMean sequence
scale_const = onnx.numpy_helper.from_array(np.array(scale_val, dtype=np.float32), "scale_const")
bias_const = onnx.numpy_helper.from_array(np.array(bias_val, dtype=np.float32), "bias_const")
two_const = onnx.numpy_helper.from_array(np.array(2.0, dtype=np.float32), "two_const")

m_rm0_node = onnx.helper.make_node("ReduceMean", ["l2_seq_output", "axes_const"], ["m_rm0_out"])
m_sub_node = onnx.helper.make_node("Sub", ["l2_seq_output", "m_rm0_out"], ["m_sub_out"])
m_pow_node = onnx.helper.make_node("Pow", ["m_sub_out", "two_const"], ["m_pow_out"])
m_rm1_node = onnx.helper.make_node("ReduceMean", ["m_pow_out", "axes_const"], ["m_rm1_out"])
m_add0_node = onnx.helper.make_node("Add", ["m_rm1_out", "eps_const"], ["m_add0_out"])
m_sqrt_node = onnx.helper.make_node("Sqrt", ["m_add0_out"], ["m_sqrt_out"])
m_div_node = onnx.helper.make_node("Div", ["m_sub_out", "m_sqrt_out"], ["m_div_out"])
m_mul_node = onnx.helper.make_node("Mul", ["m_div_out", "scale_const"], ["m_mul_out"])
m_add1_node = onnx.helper.make_node("Add", ["m_mul_out", "bias_const"], ["output"])

graph = onnx.helper.make_graph(
[
e_mul0_node,
e_div_node,
e_erf_node,
e_add_node,
e_mul1_node,
l2_rl2_node,
l2_clip_node,
l2_expand_node,
l2_div_node,
m_rm0_node,
m_sub_node,
m_pow_node,
m_rm1_node,
m_add0_node,
m_sqrt_node,
m_div_node,
m_mul_node,
m_add1_node,
],
"qnn_f32_model",
[root_inp],
[output],
initializer=[
one_const,
half_const,
root2_const,
axes_const,
eps_const,
shape_const,
scale_const,
bias_const,
two_const,
],
)
opset_imports = [
onnx.helper.make_opsetid("", 18),
]
model = onnx.helper.make_model(graph, opset_imports=opset_imports)
return onnx.shape_inference.infer_shapes(model)

def test_all_fusions(self):
"""
Test calling qnn_preprocess_model() with a model that supports all 3 fusions.
"""
model = self.build_model((1, 2, 3), [2.0, 2.0, 2.0], [1.0, 1.0, 1.0])
onnx.save_model(model, "model.onnx")
modified = qnn_preprocess_model("model.onnx", "model.qnn_pp.onnx", fuse_layernorm=True)

self.assertTrue(modified)

fused_model = onnx.load_model("model.qnn_pp.onnx")

# 3 fused Ops: Gelu, LpNorm, LayerNorm
self.assertEqual(len(fused_model.graph.node), 3)
expected_op_types = {"Gelu", "LpNormalization", "LayerNormalization"}
for node in fused_model.graph.node:
self.assertIn(node.op_type, expected_op_types)

# Should have added "com.microsoft" opset import because we added a Gelu.
ms_domain_opset = next((opset for opset in fused_model.opset_import if opset.domain == ms_domain), None)
self.assertIsNotNone(ms_domain_opset)
self.assertEqual(ms_domain_opset.version, 1)

def test_external_data(self):
"""
Test calling qnn_preprocess_model() with a model that uses external data.
The new preprocessed model should also have external data.
"""
model = self.build_model((1, 2, 3), [2.0, 2.0, 2.0], [1.0, 1.0, 1.0])
onnx.save_model(
model,
"model.onnx",
save_as_external_data=True,
all_tensors_to_one_file=True,
location="weights.bin",
size_threshold=0,
)
modified = qnn_preprocess_model(
"model.onnx",
"model.qnn_pp.onnx",
fuse_layernorm=True,
save_as_external_data=True,
all_tensors_to_one_file=True,
external_data_location="weights2.bin",
external_data_size_threshold=0,
)

self.assertTrue(modified)

# Model should still have external data.
self.assertTrue(model_has_external_data(Path("model.qnn_pp.onnx")))

fused_model = onnx.load_model("model.qnn_pp.onnx", load_external_data=False)

# 3 fused Ops: Gelu, LpNorm, LayerNorm
self.assertEqual(len(fused_model.graph.node), 3)
expected_op_types = {"Gelu", "LpNormalization", "LayerNormalization"}
for node in fused_model.graph.node:
self.assertIn(node.op_type, expected_op_types)


if __name__ == "__main__":
unittest.main()
onnxruntime/test/python/quantization/test_tensor_quant_overrides_option.py
@@ -555,6 +555,36 @@ def test_get_qnn_qdq_config(self):
self.assertEqual(sig_out_zp.data_type, onnx.TensorProto.UINT16)
self.assertEqual(sig_out_sc.float_data[0], np.float32(1.0 / 65536.0))

def test_get_qnn_qdq_config_ext_data(self):
"""
Test that get_qnn_qdq_config() returns a config that enables external data
if the input model has external data.
"""

# Create model with a weight large enough (> 1024 bytes) to be stored externally.
large_weight = onnx.numpy_helper.from_array(np.random.random((1, 32, 32)).astype(np.float32), "weight")
graph = onnx.helper.make_graph(
[onnx.helper.make_node("Add", ["input", "weight"], ["output"])],
"add_ext_data",
[onnx.helper.make_tensor_value_info("input", onnx.TensorProto.FLOAT, (1, 32, 32))],
[onnx.helper.make_tensor_value_info("output", onnx.TensorProto.FLOAT, (1, 32, 32))],
initializer=[large_weight],
)
model = onnx.helper.make_model(
graph,
opset_imports=[onnx.helper.make_opsetid("", 18)],
)
onnx.save_model(
model,
"add_ext_data.onnx",
save_as_external_data=True,
all_tensors_to_one_file=True,
location="add_ext_data.bin",
)

qnn_config = get_qnn_qdq_config("add_ext_data.onnx", DummyDataReader(self.activations))
self.assertTrue(qnn_config.use_external_data_format)


if __name__ == "__main__":
t = TestTensorQuantOverridesOption()