Merge branch 'main' into adrianl/set-ms-domain-from-tensor-quant-overrides
adrianlizarraga committed Feb 28, 2024
2 parents 3aaa783 + 913bdc7 commit 3b3468c
Showing 9 changed files with 351 additions and 39 deletions.
@@ -3,6 +3,8 @@
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
from __future__ import annotations

import logging
from pathlib import Path

@@ -13,7 +15,44 @@
from .fusion_lpnorm import FusionLpNormalization


def qnn_preprocess_model(model_input: Path, model_output: Path, fuse_layernorm: bool = False) -> bool:
def qnn_preprocess_model(
model_input: Path,
model_output: Path,
fuse_layernorm: bool = False,
save_as_external_data: bool = False,
all_tensors_to_one_file: bool = False,
external_data_location: str | None = None,
external_data_size_threshold: int = 1024,
external_data_convert_attribute: bool = False,
) -> bool:
"""
If necessary, this method creates a new "pre-processed" model in preparation for
quantization of a model to be used in QNN EP. Returns true if a new model was created.
This method performs the following operations:
- Fuse Erf sequence into a single Gelu node.
- Fuse ReduceL2 sequence into a single LpNormalization node (p == 2).
- (Optional) Fuse ReduceMean sequence into a single LayerNormalization node.
Args:
model_input: Path to the input model file.
model_output: Path to the output model file, which is only created if this method returns True.
fuse_layernorm: True if ReduceMean sequences should be fused into LayerNormalization nodes.
Defaults to False.
save_as_external_data: True if output model should be saved with external data. Defaults to false.
all_tensors_to_one_file: Effective only if save_as_external_data is true. Defaults to false.
If true, save all tensors to one external file specified by external_data_location.
If false, save each tensor to a file named with the tensor name.
external_data_location: Effective only if save_as_external_data is true. Defaults to None.
Specify the external file to which all tensors are saved. Path is relative
to the model path. If not specified, the model's name is used.
external_data_size_threshold: Effective only if save_as_external_data is true. Defaults to 1024.
Tensors with a data size >= external_data_size_threshold are converted to external data.
To convert every tensor with raw data to external data, set to 0.
external_data_convert_attribute: Effective only if save_as_external_data is true. Defaults to false.
If true, convert all tensors to external data.
If false, convert only non-attribute tensors to external data.
"""
modified = False
model = onnx.load_model(model_input)
onnx_model = ONNXModel(model)
@@ -57,6 +96,14 @@ def qnn_preprocess_model(model_input: Path, model_output: Path, fuse_layernorm:

if modified:
onnx_model.topological_sort()
onnx.save_model(model, model_output)
onnx.save_model(
model,
model_output,
save_as_external_data=save_as_external_data,
all_tensors_to_one_file=all_tensors_to_one_file,
location=external_data_location,
size_threshold=external_data_size_threshold,
convert_attribute=external_data_convert_attribute,
)

return modified
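
For a quick sense of the new options, here is a hedged usage sketch based on the parameters documented above; the model paths and option values are placeholders, not part of the change itself.

from onnxruntime.quantization.execution_providers.qnn import qnn_preprocess_model

# Placeholder paths; any float32 ONNX model intended for QNN EP works here.
modified = qnn_preprocess_model(
    "model.onnx",
    "model.qnn_pp.onnx",
    fuse_layernorm=True,                    # also fuse ReduceMean sequences into LayerNormalization
    save_as_external_data=True,             # store weights outside the .onnx file
    all_tensors_to_one_file=True,           # ...all in a single external file
    external_data_location="weights.bin",   # path relative to the output model
    external_data_size_threshold=0,         # externalize every tensor with raw data
)

if modified:
    print("Pre-processed model written to model.qnn_pp.onnx")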
@@ -15,6 +15,7 @@
Q16_TYPES = {QuantType.QInt16, QuantType.QUInt16}
Q8_TYPES = {QuantType.QInt8, QuantType.QUInt8}
OP_TYPES_TO_EXCLUDE = {"Cast"}
MODEL_SIZE_THRESHOLD = 2147483648 # Quant model should use external data if >= 2GB


def get_qnn_qdq_config(
@@ -28,14 +29,21 @@ def get_qnn_qdq_config(
if per_channel:
raise ValueError("QNN EP does not yet support per-channel quantization.")

# Process model nodes to setup overrides.
model = onnx.load_model(model_input)
model = onnx.load_model(model_input, load_external_data=False)

op_types = set()
tensor_quant_overrides = {}
model_has_external_data = False
name_to_initializer = {}

name_to_initializer = {initializer.name: initializer for initializer in model.graph.initializer}
# Build map of initializers (name -> initializer) and
# check if the model has external data.
for initializer in model.graph.initializer:
name_to_initializer[initializer.name] = initializer
if onnx.external_data_helper.uses_external_data(initializer):
model_has_external_data = True

# Setup quantization overrides for specific operator types
for node in model.graph.node:
op_types.add(node.op_type)

@@ -89,5 +97,6 @@
activation_type=activation_type,
weight_type=weight_type,
op_types_to_quantize=list(op_types.difference(OP_TYPES_TO_EXCLUDE)),
use_external_data_format=(model_has_external_data or model.ByteSize() >= MODEL_SIZE_THRESHOLD),
extra_options=extra_options,
)
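
A hedged sketch of how the external-data detection surfaces to callers: if the input model stores weights externally (or is at least 2 GB), the returned config should have use_external_data_format enabled so the quantized model is also saved with external data. The data-reader class, input name, and path below are illustrative only.

import numpy as np
from onnxruntime.quantization import CalibrationDataReader, QuantType
from onnxruntime.quantization.execution_providers.qnn import get_qnn_qdq_config


class RandomDataReader(CalibrationDataReader):
    """Feeds a few random calibration samples for a model with one float input."""

    def __init__(self, input_name, shape, num_samples=4):
        self._data = iter(
            [{input_name: np.random.random(shape).astype(np.float32)} for _ in range(num_samples)]
        )

    def get_next(self):
        return next(self._data, None)


# Illustrative path to a model previously saved with save_as_external_data=True.
qnn_config = get_qnn_qdq_config(
    "model_with_ext_data.onnx",
    RandomDataReader("input", (1, 32, 32)),
    activation_type=QuantType.QUInt16,
    weight_type=QuantType.QUInt8,
)

# An external-data input (or a model >= 2 GB) flags the quantized output for external data too.
assert qnn_config.use_external_data_format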
170 changes: 170 additions & 0 deletions onnxruntime/test/python/quantization/test_qnn_preprocess_model.py
@@ -0,0 +1,170 @@
#!/usr/bin/env python
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------

import math
import unittest
from pathlib import Path

import numpy as np
import onnx

from onnxruntime.quantization.execution_providers.qnn import qnn_preprocess_model
from onnxruntime.quantization.quant_utils import model_has_external_data, ms_domain


class TestQnnPreprocessModel(unittest.TestCase):
def build_model(self, shape, scale_val, bias_val):
"""
Build a model that supports 3 kinds of fusions:
- Erf sequence to Gelu
- ReduceL2 sequence to LpNormalization
- ReduceMean sequence to LayerNormalization
"""
root_inp = onnx.helper.make_tensor_value_info("root", onnx.TensorProto.FLOAT, shape)
output = onnx.helper.make_tensor_value_info("output", onnx.TensorProto.FLOAT, shape)

# Erf sequence
one_const = onnx.numpy_helper.from_array(np.array(1.0, dtype=np.float32), "one_const")
half_const = onnx.numpy_helper.from_array(np.array(0.5, dtype=np.float32), "half_const")
root2_const = onnx.numpy_helper.from_array(np.array(math.sqrt(2.0), dtype=np.float32), "root2_const")

e_mul0_node = onnx.helper.make_node("Mul", ["root", "half_const"], ["e_mul0_out"])
e_div_node = onnx.helper.make_node("Div", ["root", "root2_const"], ["e_div_out"])
e_erf_node = onnx.helper.make_node("Erf", ["e_div_out"], ["e_erf_out"])
e_add_node = onnx.helper.make_node("Add", ["e_erf_out", "one_const"], ["e_add_out"])
e_mul1_node = onnx.helper.make_node("Mul", ["e_add_out", "e_mul0_out"], ["erf_seq_output"])

# ReduceL2 sequence
axes_const = onnx.numpy_helper.from_array(np.array([-1], dtype=np.int64), "axes_const")
eps_const = onnx.numpy_helper.from_array(np.array(1e-12, dtype=np.float32), "eps_const")
shape_const = onnx.numpy_helper.from_array(np.array(list(shape), dtype=np.int64), "shape_const")

l2_rl2_node = onnx.helper.make_node("ReduceL2", ["erf_seq_output", "axes_const"], ["l2_rl2_out"], keepdims=1)
l2_clip_node = onnx.helper.make_node("Clip", ["l2_rl2_out", "eps_const"], ["l2_clip_out"])
l2_expand_node = onnx.helper.make_node("Expand", ["l2_clip_out", "shape_const"], ["l2_expand_out"])
l2_div_node = onnx.helper.make_node("Div", ["erf_seq_output", "l2_expand_out"], ["l2_seq_output"])

# ReduceMean sequence
scale_const = onnx.numpy_helper.from_array(np.array(scale_val, dtype=np.float32), "scale_const")
bias_const = onnx.numpy_helper.from_array(np.array(bias_val, dtype=np.float32), "bias_const")
two_const = onnx.numpy_helper.from_array(np.array(2.0, dtype=np.float32), "two_const")

m_rm0_node = onnx.helper.make_node("ReduceMean", ["l2_seq_output", "axes_const"], ["m_rm0_out"])
m_sub_node = onnx.helper.make_node("Sub", ["l2_seq_output", "m_rm0_out"], ["m_sub_out"])
m_pow_node = onnx.helper.make_node("Pow", ["m_sub_out", "two_const"], ["m_pow_out"])
m_rm1_node = onnx.helper.make_node("ReduceMean", ["m_pow_out", "axes_const"], ["m_rm1_out"])
m_add0_node = onnx.helper.make_node("Add", ["m_rm1_out", "eps_const"], ["m_add0_out"])
m_sqrt_node = onnx.helper.make_node("Sqrt", ["m_add0_out"], ["m_sqrt_out"])
m_div_node = onnx.helper.make_node("Div", ["m_sub_out", "m_sqrt_out"], ["m_div_out"])
m_mul_node = onnx.helper.make_node("Mul", ["m_div_out", "scale_const"], ["m_mul_out"])
m_add1_node = onnx.helper.make_node("Add", ["m_mul_out", "bias_const"], ["output"])

graph = onnx.helper.make_graph(
[
e_mul0_node,
e_div_node,
e_erf_node,
e_add_node,
e_mul1_node,
l2_rl2_node,
l2_clip_node,
l2_expand_node,
l2_div_node,
m_rm0_node,
m_sub_node,
m_pow_node,
m_rm1_node,
m_add0_node,
m_sqrt_node,
m_div_node,
m_mul_node,
m_add1_node,
],
"qnn_f32_model",
[root_inp],
[output],
initializer=[
one_const,
half_const,
root2_const,
axes_const,
eps_const,
shape_const,
scale_const,
bias_const,
two_const,
],
)
opset_imports = [
onnx.helper.make_opsetid("", 18),
]
model = onnx.helper.make_model(graph, opset_imports=opset_imports)
return onnx.shape_inference.infer_shapes(model)

def test_all_fusions(self):
"""
Test calling qnn_preprocess_model() with a model that supports all 3 fusions.
"""
model = self.build_model((1, 2, 3), [2.0, 2.0, 2.0], [1.0, 1.0, 1.0])
onnx.save_model(model, "model.onnx")
modified = qnn_preprocess_model("model.onnx", "model.qnn_pp.onnx", fuse_layernorm=True)

self.assertTrue(modified)

fused_model = onnx.load_model("model.qnn_pp.onnx")

# 3 fused Ops: Gelu, LpNorm, LayerNorm
self.assertEqual(len(fused_model.graph.node), 3)
expected_op_types = {"Gelu", "LpNormalization", "LayerNormalization"}
for node in fused_model.graph.node:
self.assertIn(node.op_type, expected_op_types)

# Should have added "com.microsoft" opset import because we added a Gelu.
ms_domain_opset = next((opset for opset in fused_model.opset_import if opset.domain == ms_domain), None)
self.assertIsNotNone(ms_domain_opset)
self.assertEqual(ms_domain_opset.version, 1)

def test_external_data(self):
"""
Test calling qnn_preprocess_model() with a model that uses external data.
The new preprocessed model should also have external data.
"""
model = self.build_model((1, 2, 3), [2.0, 2.0, 2.0], [1.0, 1.0, 1.0])
onnx.save_model(
model,
"model.onnx",
save_as_external_data=True,
all_tensors_to_one_file=True,
location="weights.bin",
size_threshold=0,
)
modified = qnn_preprocess_model(
"model.onnx",
"model.qnn_pp.onnx",
fuse_layernorm=True,
save_as_external_data=True,
all_tensors_to_one_file=True,
external_data_location="weights2.bin",
external_data_size_threshold=0,
)

self.assertTrue(modified)

# Model should still have external data.
self.assertTrue(model_has_external_data(Path("model.qnn_pp.onnx")))

fused_model = onnx.load_model("model.qnn_pp.onnx", load_external_data=False)

# 3 fused Ops: Gelu, LpNorm, LayerNorm
self.assertEqual(len(fused_model.graph.node), 3)
expected_op_types = {"Gelu", "LpNormalization", "LayerNormalization"}
for node in fused_model.graph.node:
self.assertIn(node.op_type, expected_op_types)


if __name__ == "__main__":
unittest.main()
@@ -585,6 +585,36 @@ def test_get_qnn_qdq_config(self):
self.assertEqual(sig_out_zp.data_type, onnx.TensorProto.UINT16)
self.assertEqual(sig_out_sc.float_data[0], np.float32(1.0 / 65536.0))

def test_get_qnn_qdq_config_ext_data(self):
"""
Test that get_qnn_qdq_config() returns a config that enables external data
if the input model has external data.
"""

# Create model with a weight large enough (> 1024 bytes) to be stored externally.
large_weight = onnx.numpy_helper.from_array(np.random.random((1, 32, 32)).astype(np.float32), "weight")
graph = onnx.helper.make_graph(
[onnx.helper.make_node("Add", ["input", "weight"], ["output"])],
"add_ext_data",
[onnx.helper.make_tensor_value_info("input", onnx.TensorProto.FLOAT, (1, 32, 32))],
[onnx.helper.make_tensor_value_info("output", onnx.TensorProto.FLOAT, (1, 32, 32))],
initializer=[large_weight],
)
model = onnx.helper.make_model(
graph,
opset_imports=[onnx.helper.make_opsetid("", 18)],
)
onnx.save_model(
model,
"add_ext_data.onnx",
save_as_external_data=True,
all_tensors_to_one_file=True,
location="add_ext_data.bin",
)

qnn_config = get_qnn_qdq_config("add_ext_data.onnx", DummyDataReader(self.activations))
self.assertTrue(qnn_config.use_external_data_format)


if __name__ == "__main__":
t = TestTensorQuantOverridesOption()
@@ -14,7 +14,7 @@
from sympy import Symbol, simplify
from sympy.parsing.sympy_parser import parse_expr

from onnxruntime.training.utils import PTable
from onnxruntime.training.utils import PTable, log_memory_usage

from ._execution_agent import TrainingAgent
from .options import _MemoryOptimizationLevel, _RuntimeOptions
@@ -509,6 +509,8 @@ def __init__(self, m: torch.nn.Module, logger: Logger):

self._is_first_inspect = True

self._m = m

def is_enabled(self) -> bool:
"""Check if memory inspector is enabled."""
return self._is_enabled
@@ -621,29 +623,13 @@ def inspect_memory(self, cur_phase: Phase):
need_print = self._current_step < 10 or (self._current_step & (self._current_step - 1) == 0)

if need_print:
cur_mem_allocated = self._normalize(torch.cuda.memory_allocated())
max_mem_allocated = self._normalize(torch.cuda.max_memory_allocated())
cur_mem_cached = self._normalize(torch.cuda.memory_reserved())
max_mem_cached = self._normalize(torch.cuda.max_memory_reserved())
torch_mem_stat = torch.cuda.memory_stats()
cur_mem_inactive = self._normalize(torch_mem_stat.get("inactive_split_bytes.all.current", 0))
max_mem_inactive = self._normalize(torch_mem_stat.get("inactive_split_bytes.all.peak", 0))

mem_stats = [
["phase", _convert_phase_to_string(cur_phase)],
["allocated", cur_mem_allocated], # current memory allocated for tensors
["max allocated", max_mem_allocated], # peak memory allocated for tensors
["cached", cur_mem_cached], # current memory cached for the caching allocator
["max cached", max_mem_cached], # peak memory cached for caching allocator.
["inactive", cur_mem_inactive], # amount of inactive, non-releasable memory
["max inactive", max_mem_inactive], # peak of inactive, non-releasable memory
]

summ = f"{self._rank_info} step {self._current_step} memory ({MemoryObserver.NORMALIZER_UNIT})"
for stat in mem_stats:
summ += f" | {stat[0]}: {stat[1]}"

self._logger.info(summ)
log_memory_usage(
_convert_phase_to_string(cur_phase),
rank_0_only=True,
step_info=f"step {self._current_step}",
logger=self._logger,
module=self._m,
)

if cur_phase == self._last_phase:
self._increase_step()
@@ -655,9 +641,6 @@ def _increase_step(self):
def _increase_step(self):
self._current_step += 1

def _normalize(self, mem_size_in_bytes: Union[float, int]) -> str:
return f"{float(mem_size_in_bytes) / MemoryObserver.NORMALIZER_FACTOR:.0f}"

def display_memory_optimization_plans(self, memory_optimizer_config, details=False) -> Tuple[List[str], PTable]:
mem_plan_count = len(self.cluster_id_combination_to_saving_symbolics_map)

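For context on the refactor above, here is a minimal sketch of what a log_memory_usage helper along these lines could do, assuming it reports the same torch.cuda statistics the removed inline code computed; the real helper in onnxruntime.training.utils.torch_profile_utils may differ in signature and detail.

import logging

import torch


def log_memory_usage(cur_phase_str, rank_0_only=True, step_info="", logger=None, module=None):
    """Sketch: log current and peak CUDA memory statistics for one phase."""
    # Assumption: with rank_0_only set, only rank 0 logs in a distributed run.
    if rank_0_only and torch.distributed.is_available() and torch.distributed.is_initialized():
        if torch.distributed.get_rank() != 0:
            return

    def _mib(num_bytes):
        return f"{float(num_bytes) / (1024 * 1024):.0f}"

    torch_mem_stat = torch.cuda.memory_stats()
    mem_stats = [
        ("phase", cur_phase_str),
        ("allocated", _mib(torch.cuda.memory_allocated())),          # current memory allocated for tensors
        ("max allocated", _mib(torch.cuda.max_memory_allocated())),  # peak memory allocated for tensors
        ("cached", _mib(torch.cuda.memory_reserved())),              # current memory held by the caching allocator
        ("max cached", _mib(torch.cuda.max_memory_reserved())),      # peak memory held by the caching allocator
        ("inactive", _mib(torch_mem_stat.get("inactive_split_bytes.all.current", 0))),
        ("max inactive", _mib(torch_mem_stat.get("inactive_split_bytes.all.peak", 0))),
    ]

    # `module` is unused in this sketch; the real helper may use it to report parameter sizes.
    summ = f"{step_info} memory (MiB)" + "".join(f" | {name}: {val}" for name, val in mem_stats)
    (logger or logging.getLogger(__name__)).info(summ)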
2 changes: 2 additions & 0 deletions orttraining/orttraining/python/training/utils/__init__.py
@@ -12,6 +12,7 @@
unflatten_data_using_schema,
)
from onnxruntime.training.utils.torch_profile_utils import (
log_memory_usage,
nvtx_function_decorator,
torch_nvtx_range_pop,
torch_nvtx_range_push,
@@ -31,6 +32,7 @@
"torch_nvtx_range_push",
"torch_nvtx_range_pop",
"nvtx_function_decorator",
"log_memory_usage",
"pytorch_type_to_onnx_dtype",
"onnx_dtype_to_pytorch_dtype",
"pytorch_scalar_type_to_pytorch_dtype",