From 71551dacd510a9b85d6ef9fa12af319fa4687592 Mon Sep 17 00:00:00 2001 From: Xiaoyu <85524621+xiaoyu-work@users.noreply.github.com> Date: Fri, 22 Mar 2024 18:40:58 -0700 Subject: [PATCH 01/11] Add ModelProto support for transformers optimize_model (#19990) ### Description Add `ModelProto` support as an input to transformers `optimize_model` API. ### Motivation and Context Currently, the `optimize_model` API only accepts a model path as the input model. However, for large models, saving and loading from disk can be time-consuming. By adding `ModelProto` as an input option to the `optimize_model` API, significant time can be saved. --- .../python/tools/transformers/onnx_utils.py | 55 +++++++++++++++ .../python/tools/transformers/optimizer.py | 69 +++++++++++++------ .../python/transformers/test_onnx_utils.py | 38 ++++++++++ 3 files changed, 140 insertions(+), 22 deletions(-) create mode 100644 onnxruntime/python/tools/transformers/onnx_utils.py create mode 100644 onnxruntime/test/python/transformers/test_onnx_utils.py diff --git a/onnxruntime/python/tools/transformers/onnx_utils.py b/onnxruntime/python/tools/transformers/onnx_utils.py new file mode 100644 index 0000000000000..64fade9369395 --- /dev/null +++ b/onnxruntime/python/tools/transformers/onnx_utils.py @@ -0,0 +1,55 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +from fusion_utils import NumpyHelper +from onnx import ModelProto, TensorProto +from onnx.external_data_helper import set_external_data +from onnx_model import OnnxModel + +from onnxruntime import OrtValue + + +def extract_raw_data_from_model(model: ModelProto): + """ + Extract external data from model and return the external data as a list of tuples (name, value). + Note this function does not handle external data that is not loaded into the model as raw data. + + Args: + model (ModelProto): the model proto to extract external data from. + Returns: + (external_names, external_values): a tuple of two lists of external data names and values. + """ + external_data = [] + onnx_model = OnnxModel(model) + for graph in onnx_model.graphs(): + for initializer in graph.initializer: + name = initializer.name + + if initializer.HasField("raw_data"): + numpy_tensor = NumpyHelper.to_array(initializer) + ort_value = OrtValue.ortvalue_from_numpy(numpy_tensor) + external_data.append((name, ort_value)) + # mimic set_external_data + set_external_data(initializer, location="foo.bin") + initializer.name = name + initializer.ClearField("raw_data") + + return zip(*external_data) + + +def has_external_data(model: ModelProto): + """ + Check if the model has external data. + + Args: + model (ModelProto): the model proto to check for external data. + Returns: + bool: True if the model has external data, False otherwise. 
+ """ + onnx_model = OnnxModel(model) + for graph in onnx_model.graphs(): + for initializer in graph.initializer: + if initializer.HasField("data_location") and initializer.data_location == TensorProto.EXTERNAL: + return True + return False diff --git a/onnxruntime/python/tools/transformers/optimizer.py b/onnxruntime/python/tools/transformers/optimizer.py index ce0be6b3449ed..068ccefef7d97 100644 --- a/onnxruntime/python/tools/transformers/optimizer.py +++ b/onnxruntime/python/tools/transformers/optimizer.py @@ -21,11 +21,12 @@ import logging import os import tempfile -from typing import Dict, List, Optional +from pathlib import Path +from typing import Dict, List, Optional, Union import coloredlogs from fusion_options import FusionOptions -from onnx import ModelProto, TensorProto, load_model +from onnx import ModelProto, load_model from onnx_model import OnnxModel from onnx_model_bart import BartOnnxModel from onnx_model_bert import BertOnnxModel @@ -40,6 +41,9 @@ from onnx_model_unet import UnetOnnxModel from onnx_model_vae import VaeOnnxModel +import onnxruntime +from onnxruntime.transformers.onnx_utils import extract_raw_data_from_model, has_external_data + logger = logging.getLogger(__name__) # Map model type to tuple: optimizer class, export tools (pytorch, tf2onnx, keras2onnx), and default opt_level @@ -64,7 +68,7 @@ def optimize_by_onnxruntime( - onnx_model_path: str, + onnx_model: Union[str, ModelProto], use_gpu: bool = False, optimized_model_path: Optional[str] = None, opt_level: Optional[int] = 99, @@ -80,7 +84,7 @@ def optimize_by_onnxruntime( Use onnxruntime to optimize model. Args: - onnx_model_path (str): the path of input onnx model. + onnx_model (str | ModelProto): the path of input onnx model or ModelProto. use_gpu (bool): whether the optimized model is targeted to run in GPU. optimized_model_path (str or None): the path of optimized model. opt_level (int): graph optimization level. @@ -95,8 +99,6 @@ def optimize_by_onnxruntime( assert opt_level in [1, 2, 99] from torch import version as torch_version - import onnxruntime - if ( use_gpu and provider is None @@ -105,9 +107,13 @@ def optimize_by_onnxruntime( ) ): logger.error("There is no gpu for onnxruntime to do optimization.") - return onnx_model_path + return onnx_model - model = OnnxModel(load_model(onnx_model_path, load_external_data=False)) + model = ( + OnnxModel(load_model(onnx_model, load_external_data=False)) + if isinstance(onnx_model, str) + else OnnxModel(onnx_model) + ) if model.use_float16() and not use_gpu: logger.warning( "This model uses float16 in the graph, use_gpu=False might cause extra Cast nodes. 
" @@ -125,7 +131,10 @@ def optimize_by_onnxruntime( sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL if optimized_model_path is None: - path_prefix = onnx_model_path[:-5] # remove .onnx suffix + if isinstance(onnx_model, str): + path_prefix = str(Path(onnx_model).with_suffix("")) # remove .onnx suffix + else: + path_prefix = "optimized_model" optimized_model_path = "{}_o{}_{}.onnx".format(path_prefix, opt_level, "gpu" if use_gpu else "cpu") sess_options.optimized_model_filepath = optimized_model_path @@ -174,7 +183,20 @@ def optimize_by_onnxruntime( else: providers.append("CUDAExecutionProvider") - onnxruntime.InferenceSession(onnx_model_path, sess_options, providers=providers, **kwargs) + # For large model, extract external data from model and add to session options + if isinstance(onnx_model, ModelProto): + if has_external_data(onnx_model): + raise ValueError( + "ModelProto has external data not loaded into memory, ORT cannot create session. " + "Please load external data before calling this function. " + "See https://onnx.ai/onnx/repo-docs/ExternalData.html for more information." + ) + external_names, external_values = extract_raw_data_from_model(onnx_model) + sess_options.add_external_initializers(list(external_names), list(external_values)) + + # Inference session is only used to optimize the model. + onnx_model = onnx_model.SerializeToString() if isinstance(onnx_model, ModelProto) else onnx_model + onnxruntime.InferenceSession(onnx_model, sess_options, providers=providers, **kwargs) assert os.path.exists(optimized_model_path) and os.path.isfile(optimized_model_path) logger.debug("Save optimized model by onnxruntime to %s", optimized_model_path) @@ -187,7 +209,7 @@ def optimize_by_fusion( num_heads: int = 0, hidden_size: int = 0, optimization_options: Optional[FusionOptions] = None, -): +) -> OnnxModel: """Optimize Model by graph fusion logic. Note that ONNXRuntime graph optimizations (like constant folding) will not be applied. So it is better to enable @@ -241,7 +263,7 @@ def optimize_by_fusion( def optimize_model( - input: str, + input: Union[str, ModelProto], model_type: str = "bert", num_heads: int = 0, hidden_size: int = 0, @@ -252,7 +274,7 @@ def optimize_model( verbose: bool = False, *, provider: Optional[str] = None, -): +) -> OnnxModel: """Optimize Model by OnnxRuntime and/or python fusion logic. ONNX Runtime has graph optimizations (https://onnxruntime.ai/docs/performance/model-optimizations/graph-optimizations.html). @@ -275,7 +297,7 @@ def optimize_model( For BERT model, num_heads and hidden_size are optional. For other model types, you need specify these parameters. Args: - input (str): input model path. + input (str | ModelProto): input model path or ModelProto. model_type (str, optional): model type - like bert, bert_tf, bert_keras or gpt2. Defaults to 'bert'. num_heads (int, optional): number of attention heads. Defaults to 0. 0 allows detect the parameter from graph automatically. 
@@ -298,9 +320,9 @@ def optimize_model( if model_type not in MODEL_TYPES: logger.warning(f"Unsupported model type: {model_type} for optimization, directly return model.") - return OnnxModel(load_model(input)) + return OnnxModel(load_model(input)) if isinstance(input, str) else OnnxModel(input) - (optimizer_class, _producer, default_opt_level) = MODEL_TYPES[model_type] + (optimizer_class, _, default_opt_level) = MODEL_TYPES[model_type] if opt_level is None: opt_level = default_opt_level @@ -316,11 +338,9 @@ def optimize_model( # Auto detect if input model has external data has_external_data_file = False - original_model = load_model(input, load_external_data=False) - for initializer in original_model.graph.initializer: - if initializer.HasField("data_location") and initializer.data_location == TensorProto.EXTERNAL: - has_external_data_file = True - break + original_model = load_model(input, load_external_data=False) if isinstance(input, str) else input + if has_external_data(original_model): + has_external_data_file = True del original_model if opt_level > 1: @@ -365,7 +385,12 @@ def optimize_model( if only_onnxruntime and not temp_model_path: logger.warning("Please specify a positive value for opt_level when only_onnxruntime is True") - model = load_model(temp_model_path or input) + if temp_model_path is not None: + model = load_model(temp_model_path) + elif isinstance(input, str): + model = load_model(input) + else: + model = input if only_onnxruntime: optimizer = optimizer_class(model, num_heads, hidden_size) diff --git a/onnxruntime/test/python/transformers/test_onnx_utils.py b/onnxruntime/test/python/transformers/test_onnx_utils.py new file mode 100644 index 0000000000000..974991359795e --- /dev/null +++ b/onnxruntime/test/python/transformers/test_onnx_utils.py @@ -0,0 +1,38 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
+# -------------------------------------------------------------------------- +import unittest + +import numpy +from onnx import ModelProto, TensorProto, helper +from onnx.external_data_helper import set_external_data + +from onnxruntime.transformers.onnx_utils import extract_raw_data_from_model, has_external_data + + +class TestOnnxUtils(unittest.TestCase): + def test_extract_raw_data_from_model(self): + model = self._get_model_proto_with_raw_data(False) + external_names, external_values = extract_raw_data_from_model(model) + self.assertEqual(list(external_names), ["inputs"]) + self.assertEqual(len(external_values), 1) + self.assertEqual(external_values[0].numpy(), [0.0]) + + def test_has_external_data(self): + model = self._get_model_proto_with_raw_data() + self.assertTrue(has_external_data(model)) + + def test_has_external_data_with_no_external_data(self): + model = self._get_model_proto_with_raw_data(False) + self.assertFalse(has_external_data(model)) + + def _get_model_proto_with_raw_data(self, has_external_data: bool = True) -> ModelProto: + input = helper.make_tensor_value_info("inputs", TensorProto.FLOAT, [None]) + output = helper.make_tensor_value_info("outputs", TensorProto.FLOAT, [None]) + raw_data = numpy.array([0.0], dtype=numpy.float32).tobytes() + tensor = helper.make_tensor("inputs", TensorProto.FLOAT, [1], raw_data, True) + if has_external_data: + set_external_data(tensor, location="foo.bin") + node = helper.make_node("Identity", inputs=["inputs"], outputs=["outputs"]) + return helper.make_model(helper.make_graph([node], "graph", [input], [output], initializer=[tensor])) From 3b4b99b90b7de7848e5c1e817ad19b32bf598b27 Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Sat, 23 Mar 2024 08:53:50 -0700 Subject: [PATCH 02/11] Fix a bug in WASM's GEMM (#20023) ### Description Fix a bug in WASM's GEMM. The bug was found when running "ConvAddActivationFusionTests.ConvGemmDirect" unit test in a wasm build with address sanitizer enabled. When CountK=25, CountN=1, lda=25, ldc=1, the function I am modifying triggered a read out of bound error. The bug fix was provided by @fs-eire. --- onnxruntime/core/mlas/lib/wasm_simd/SgemmKernelWasmSimd.cpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/onnxruntime/core/mlas/lib/wasm_simd/SgemmKernelWasmSimd.cpp b/onnxruntime/core/mlas/lib/wasm_simd/SgemmKernelWasmSimd.cpp index 955b7c5deee9a..43a12b37e4ffa 100644 --- a/onnxruntime/core/mlas/lib/wasm_simd/SgemmKernelWasmSimd.cpp +++ b/onnxruntime/core/mlas/lib/wasm_simd/SgemmKernelWasmSimd.cpp @@ -171,11 +171,9 @@ Return Value: if (k > 0) { Row0AElements0 = a[0]; - Row0AElements1 = a[1]; if (ProcessTwoRows) { Row1AElements0 = a[lda]; - Row1AElements1 = a[lda + 1]; } BElements0 = MlasLoadFloat32x4(B + 0); From cdc5d72ba9dfcba38462d7fcfa7047fd6005fa5a Mon Sep 17 00:00:00 2001 From: Adrian Lizarraga Date: Sat, 23 Mar 2024 11:05:08 -0700 Subject: [PATCH 03/11] [QDQ Quant] Support mixed-precision integer quantization via overrides (#19925) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### Description Adds support for specifying mixed precision QDQ models via tensor quantization overrides. ### Motivation and Context This PR implements an approach for supported "mixed precision" models. The following figure demonstrates an example mixed precision model as defined in this PR. 
![image](https://github.com/microsoft/onnxruntime/assets/19691973/40ae3bf9-b21a-4ba5-a1cd-41c1e08c21e7) A mixed precision QDQ model consists of regions with different activation/weight quantization data types. The boundary between regions converts between activation quantization data types (e.g., uint8 to uint16) using a DQ to Q sequence. The ability to specify regions with different quantization data types enables exploring the tradeoffs between accuracy and latency. A higher integer precision may improve accuracy at the expense of latency, so selectively promoting certain regions to a higher precision can aid in achieving a desirable balance in key metrics. #### Current support By default, the ORT quantizer supports specifying default activation and weight quantization data types for the entire model. A recent PR added support for specifying basic quantization overrides at the tensor level via the `extra_options["TensorQuantOverrides"]` configuration: ``` TensorQuantOverrides = dictionary : Default is {}. Set tensor quantization overrides. The key is a tensor name and the value is a list of dictionaries. For per-tensor quantization, the list contains a single dictionary. For per-channel quantization, the list contains a dictionary for each channel in the tensor. Each dictionary contains optional overrides with the following keys and values. 'quant_type' = QuantType : The tensor's quantization data type. 'scale' = Float : The scale value to use. Must also specify `zero_point` if set. 'zero_point' = Int : The zero-point value to use. Must also specify `scale` is set. 'symmetric' = Bool : If the tensor should use symmetric quantization. Invalid if also set `scale` or `zero_point`. 'reduce_range' = Bool : If the quantization range should be reduced. Invalid if also set `scale` or `zero_point`. 'rmax' = Float : Override the maximum real tensor value in calibration data. Invalid if also set `scale` or `zero_point`. 'rmin' = Float : Override the minimum real tensor value in calibration data. Invalid if also set `scale` or `zero_point`. ``` The tensor-level overrides are currently used to override the quantization type for weights/initializers or to set specific scale/zero-point values for a tensor (e.g., QNN requires Sigmoid to use a specific scale/zero-point at its output). However, these overrides are not typically used to override activation quantization types due in large part to operator data type constraints. Consider, for example, that all inputs and outputs to an Add operator must be of the same data type. Consequently, using tensor-level overrides to promote the Add’s output to 16-bits would force the inputs to also be overridden to 16-bit. In turn, this would have a cascading effect on potentially the entire graph. The solution implemented by this PR is to allow the specification of tensor boundaries where the activation quantization data type changes. #### The approach The following figure shows a model with a region that has been promoted to 16-bit from the default 8-bit activation type. ![image](https://github.com/microsoft/onnxruntime/assets/19691973/5998c301-ae20-4ac9-8a43-37f335cfcf8b) Note the following observations: - Op2’s output is consumed by Op4, Op7, and Op8. Op4 consumes the converted u16 type, while Op7 and Op8 consume the original u8 type. - Op3’s output is converted from u8 to u16. Op5 consumes the converted u16 type. - Op4’s output is just u16 (not converted). - Op5’s output is converted from u16 to u8. Op6 consumes the u8 type. 
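As context for the overrides described next: these per-tensor entries are supplied to the static quantization API through `extra_options["TensorQuantOverrides"]`. The sketch below is illustrative only; the model paths, input name/shape, and calibration data reader are placeholders, and the tensor/node names come from the figure above rather than a real model.

```python3
import numpy as np

from onnxruntime.quantization import CalibrationDataReader, QuantFormat, QuantType, quantize_static


class RandomDataReader(CalibrationDataReader):
    """Placeholder calibration reader; a real reader yields representative input samples."""

    def __init__(self, input_name, shape, num_samples=4):
        self._samples = iter(
            [{input_name: np.random.rand(*shape).astype(np.float32)} for _ in range(num_samples)]
        )

    def get_next(self):
        return next(self._samples, None)


# Promote only the region feeding Op4 to 16-bit activations; the rest of the model stays 8-bit.
mixed_precision_overrides = {
    "Op2_out": [
        {"quant_type": QuantType.QUInt8, "convert": {"quant_type": QuantType.QUInt16, "recv_nodes": {"Op4"}}}
    ],
    "Op4_out": [{"quant_type": QuantType.QUInt16}],
}

quantize_static(
    "model.onnx",  # placeholder input model path
    "model.qdq.onnx",  # placeholder output model path
    RandomDataReader("input_0", (1, 3, 224, 224)),  # placeholder input name and shape
    quant_format=QuantFormat.QDQ,
    activation_type=QuantType.QUInt8,
    weight_type=QuantType.QInt8,
    extra_options={"TensorQuantOverrides": mixed_precision_overrides},
)
```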
The approach implemented by this PR uses the tensor-level quantization overrides to specify a tensor’s quantization type at both the producer and consumer ends. **The following shows the overrides necessary to create this mixed precision QDQ model.** ```python3 overrides = { “Op2_out”: [{“quant_type”: QUInt8, “convert”: {“quant_type”: QUInt16, “recv_nodes”: {“Op4”}}}], “Op3_out”: [{“quant_type”: QUInt8, “convert”: {“quant_type”: QUInt16, “recv_nodes”: {“Op5”}}}], “Op4_out”: [{“quant_type”: QUInt16}], “Op5_out”: [{“quant_type”: QUInt16, “convert”: {“quant_type”: QUInt8, “recv_nodes”: {“Op6”}}}] } ``` --- .../tools/quantization/base_quantizer.py | 323 +------- .../python/tools/quantization/onnx_model.py | 10 + .../tools/quantization/onnx_quantizer.py | 227 +++++- .../tools/quantization/operators/conv.py | 2 +- .../tools/quantization/operators/direct_q8.py | 4 +- .../tools/quantization/operators/gather.py | 4 +- .../tools/quantization/operators/gemm.py | 4 +- .../tools/quantization/operators/norm.py | 2 +- .../tools/quantization/operators/softmax.py | 38 +- .../tools/quantization/operators/split.py | 2 +- .../tools/quantization/qdq_quantizer.py | 711 ++++++++++++++++-- .../python/tools/quantization/registry.py | 3 +- .../quantization/tensor_quant_overrides.py | 214 ++++++ .../test/python/quantization/test_qdq.py | 594 ++++++++++++++- 14 files changed, 1744 insertions(+), 394 deletions(-) create mode 100644 onnxruntime/python/tools/quantization/tensor_quant_overrides.py diff --git a/onnxruntime/python/tools/quantization/base_quantizer.py b/onnxruntime/python/tools/quantization/base_quantizer.py index 667d7047c1fbd..80617b7b5edaa 100644 --- a/onnxruntime/python/tools/quantization/base_quantizer.py +++ b/onnxruntime/python/tools/quantization/base_quantizer.py @@ -21,19 +21,15 @@ from .quant_utils import ( ONNX_TYPE_TO_NP_TYPE, TENSOR_NAME_QUANT_SUFFIX, - QuantizedValue, - QuantizedValueType, QuantType, - compute_scale_zp, - compute_scale_zp_float8, find_by_name, - get_qmin_qmax_for_qType, model_has_infer_metadata, quantize_data, quantize_nparray, save_and_reload_model_with_shape_infer, tensor_proto_to_array, ) +from .tensor_quant_overrides import TensorQuantOverridesHelper class QuantizationParams: @@ -121,27 +117,17 @@ def __init__( self.opset_version = self.check_opset_version() - # Map of all original value names to quantized value names - self.quantized_value_map = {} + # Get tensor-level quantization overrides and ensure they are valid. 
+ self.tensor_quant_overrides = TensorQuantOverridesHelper(self.extra_options.get("TensorQuantOverrides", {})) - self.tensor_quant_overrides, self.tensor_quant_override_types = self._get_and_check_tensor_quant_overrides() - self.quantization_params = self.calculate_quantization_params() - - # to store specified scale and zeropoint instead of calculated value, tensor_name->(scale, zeropoint) - self.used_scale_zp_map = {} - - def set_quant_scale_zp(self, tensor_name, value): - assert isinstance(value, tuple) and len(value) == 2, "value must be scale(float or float16) and zeropoint" - assert hasattr(value[0], "dtype") - assert tensor_name not in self.used_scale_zp_map, f"{tensor_name} has been setted before" - self.used_scale_zp_map[tensor_name] = value + initializer_names = {initzer.name for initzer in self.model.initializer()} + overrides_valid, overrides_err = self.tensor_quant_overrides.is_valid( + initializer_names, self.value_infos.keys(), activation_qType + ) + if not overrides_valid: + raise ValueError(overrides_err) - def find_quant_scale_zp(self, input_name): - if input_name in self.used_scale_zp_map: - return self.used_scale_zp_map[input_name] - if self.parent is not None: - return self.parent.find_quantized_value(input_name) - return (None, None) + self.tensor_quant_override_qtypes = self.tensor_quant_overrides.get_quant_types() def quantize_model(self): raise NotImplementedError @@ -212,36 +198,16 @@ def check_opset_version(self): return opset_version - def quantize_bias_static(self, bias_name, input_name, weight_name, beta=1.0): + def quantize_bias_static_impl(self, bias_name, input_scale, weight_scale, beta=1.0): """ Quantized the bias. Zero Point == 0 and Scale == Input_Scale * Weight_Scale """ - # Handle case where bias already in quantization map - if bias_name in self.quantized_value_map: - return self.quantized_value_map[bias_name].q_name - - # get scale for weight - weight_scale_name = self.quantized_value_map[weight_name].scale_name - weight_initializer = find_by_name(weight_scale_name, self.model.initializer()) - weight_scale = tensor_proto_to_array(weight_initializer) - # get bias bias_initializer = find_by_name(bias_name, self.model.initializer()) bias_data = tensor_proto_to_array(bias_initializer) quantized_bias_name = bias_name + TENSOR_NAME_QUANT_SUFFIX - # get scale for input - if input_name in self.quantized_value_map: - input_scale_name = self.quantized_value_map[input_name].scale_name - elif input_name in self.quantization_params: - _, input_scale_name, _, _, _ = self._get_quantization_params(input_name) - else: - raise ValueError(f"Expected {input_name} to be in quantized value map for static quantization") - - inputscale_initializer = find_by_name(input_scale_name, self.model.initializer()) - input_scale = tensor_proto_to_array(inputscale_initializer) - # quantize bias if self.weight_qType == onnx.TensorProto.FLOAT8E4M3FN: data = np.asarray(bias_data) @@ -293,22 +259,16 @@ def quantize_bias_static(self, bias_name, input_name, weight_name, beta=1.0): packed_bias_zp_initializer = onnx.helper.make_tensor(quantized_bias_zp_name, tensor_type, [], [0]) self.model.initializer_extend([packed_bias_zp_initializer]) - assert bias_name not in self.quantized_value_map - quantized_value = QuantizedValue( - bias_name, + return ( quantized_bias_name, quantized_bias_scale_name, quantized_bias_zp_name, - QuantizedValueType.Initializer, - 0 if bias_scale_data.size > 1 else None, - node_type=node_type, - node_qtype=node_qtype, + bias_scale_data, + node_type, + node_qtype, ) - 
self.quantized_value_map[bias_name] = quantized_value - - return quantized_bias_name - def quantize_initializer(self, weight, qType, reduce_range=False, keep_float_weight=False): + def quantize_initializer_impl(self, weight, qType, reduce_range=False, keep_float_weight=False): """ :param weight: TensorProto initializer :param qType: type to quantize to @@ -316,22 +276,13 @@ def quantize_initializer(self, weight, qType, reduce_range=False, keep_float_wei If keep_float_weight is False, quantize the weight, or don't quantize the weight. :return: quantized weight name, zero point name, scale name """ - # Find if this input is already quantized - if weight.name in self.quantized_value_map: - quantized_value = self.quantized_value_map[weight.name] - return ( - quantized_value.q_name, - quantized_value.zp_name, - quantized_value.scale_name, - ) - q_weight_name = weight.name + TENSOR_NAME_QUANT_SUFFIX zp_name = weight.name + "_zero_point" scale_name = weight.name + "_scale" # Quantize weight data. Use quantization overrides if provided by the user. weight_data = tensor_proto_to_array(weight) - quant_overrides = self.get_per_tensor_quant_overrides(weight.name) + quant_overrides = self.tensor_quant_overrides.get_per_tensor_overrides(weight.name) if "quant_type" in quant_overrides: qType = quant_overrides["quant_type"].tensor_type # noqa: N806 @@ -392,19 +343,9 @@ def quantize_initializer(self, weight, qType, reduce_range=False, keep_float_wei q_weight_initializer = onnx.numpy_helper.from_array(q_weight_data, q_weight_name) self.model.initializer_extend([q_weight_initializer]) - # Log entry for this quantized weight - quantized_value = QuantizedValue( - weight.name, - q_weight_name, - scale_name, - zp_name, - QuantizedValueType.Initializer, - None, - ) - self.quantized_value_map[weight.name] = quantized_value return q_weight_name, zp_name, scale_name - def quantize_weight_per_channel( + def quantize_weight_per_channel_impl( self, weight_name, weight_qType, @@ -412,22 +353,13 @@ def quantize_weight_per_channel( reduce_range=True, keep_float_weight=False, ): - # Find if this input is already quantized - if weight_name in self.quantized_value_map: - quantized_value = self.quantized_value_map[weight_name] - return ( - quantized_value.q_name, - quantized_value.zp_name, - quantized_value.scale_name, - ) - initializer = find_by_name(weight_name, self.model.initializer()) if initializer is None: raise ValueError("{} is not an initializer", weight_name) weights = tensor_proto_to_array(initializer) channel_count = weights.shape[channel_axis] - quant_overrides_for_channels = self.get_per_channel_quant_overrides(weight_name, channel_count) + quant_overrides_for_channels = self.tensor_quant_overrides.get_per_channel_overrides(weight_name, channel_count) # If user provides per-channel quantization overrides, all channels must use the same quantization type. # So, just use the first channel's type. 
@@ -499,16 +431,6 @@ def quantize_weight_per_channel( zp_name = weight_name + "_zero_point" scale_name = weight_name + "_scale" - quantized_value = QuantizedValue( - weight_name, - q_weight_name, - scale_name, - zp_name, - QuantizedValueType.Initializer, - None, - ) - self.quantized_value_map[weight_name] = quantized_value - # Update packed weight, zero point, and scale initializers zero_scale_shape = [initializer.dims[channel_axis]] scale_initializer = onnx.helper.make_tensor( @@ -530,194 +452,25 @@ def quantize_weight_per_channel( return q_weight_name, zp_name, scale_name - def _get_and_check_tensor_quant_overrides(self): - """ - Get tensor quantization overrides and check correctness. - """ - tensor_quant_overrides = self.extra_options.get("TensorQuantOverrides", {}) - tensor_quant_override_types = set() - - # Validate that compatible/valid overrides are provided. - if tensor_quant_overrides: - initializer_names = self.model.get_initializer_name_set() - value_info_names = set(self.value_infos.keys()) - keys_unsupported_with_scale_zp = {"symmetric", "reduce_range", "rmax", "rmin"} - - for tensor_name, quant_overrides_list in tensor_quant_overrides.items(): - if tensor_name not in initializer_names and tensor_name not in value_info_names: - raise ValueError(f"Tensor '{tensor_name}' in TensorQuantOverrides is not present in the model") - - if not isinstance(quant_overrides_list, list): - raise ValueError(f"Tensor quantization overrides for '{tensor_name}' are not in a list") - - is_initializer = tensor_name in initializer_names - if not is_initializer and len(quant_overrides_list) > 1: - raise ValueError( - f"Tensor '{tensor_name}' has a list of per-channel overrides, but is not an initializer" - ) - - quant_type = None - for index, quant_overrides in enumerate(quant_overrides_list): - if not isinstance(quant_overrides, dict): - raise ValueError( - f"Tensor quantization overrides at index {index} for '{tensor_name}' are not in a dict" - ) - - # For per-channel quantization, all channels must use the same quantization type. - # Therefore, if the user tries to override the quant_type for a channel, it must match in all - # other channels. - if index == 0: - quant_type = quant_overrides.get("quant_type") - if quant_type: - tensor_quant_override_types.add(quant_type.tensor_type) - elif quant_type != quant_overrides.get("quant_type"): - raise ValueError( - "Channel quantization types for tensor '{tensor_name}' do not match at index {index}." - ) - - has_scale = "scale" in quant_overrides - has_zero_point = "zero_point" in quant_overrides - - if (has_scale and not has_zero_point) or (has_zero_point and not has_scale): - raise ValueError( - "Must provide both 'scale' and 'zero_point' if one of the overrides is provided" - ) - - if has_scale: - for key in keys_unsupported_with_scale_zp: - if key in quant_overrides: - raise ValueError( - f"Tensor override option '{key}' is invalid with 'scale' and 'zero_point'" - ) - - return tensor_quant_overrides, tensor_quant_override_types - - def get_per_tensor_quant_overrides(self, tensor_name): - quant_overrides_list = self.tensor_quant_overrides.get(tensor_name, [{}]) - num_overrides = len(quant_overrides_list) - if num_overrides > 1: - raise ValueError( - f"Expected tensor '{tensor_name}' to use per-tensor quantization overrides, " - f"but found {num_overrides} per-channel overrides." 
- ) - - return quant_overrides_list[0] if num_overrides > 0 else {} - - def get_per_channel_quant_overrides(self, tensor_name, num_channels): - quant_overrides_list = self.tensor_quant_overrides.get(tensor_name, [{} for i in range(num_channels)]) - - if len(quant_overrides_list) != num_channels: - raise ValueError( - f"Expected tensor '{tensor_name}' to have {num_channels} per-channel quantization overrides, " - f"but found {len(quant_overrides_list)} instead." - ) - - return quant_overrides_list - - def _get_quantization_params(self, param_name, use_scale=None, use_zeropoint=None): - """ - Create initializers and inputs in the graph for zero point and scale of output. - Zero point and scale values are obtained from self.quantization_params if specified. - parameter param_name: Name of the quantization parameter. - return: result, scale_name, zero_point_name, scale_shape, zero_point_shape. - """ - zero_point_type = self.activation_qType - - if use_scale is None or use_zeropoint is None: - if self.quantization_params is None or param_name not in self.quantization_params: - logging.info(f'Quantization parameters for tensor:"{param_name}" not specified') - return False, "", "", "", "" - - params = self.quantization_params[param_name] - if not isinstance(params, QuantizationParams): - raise TypeError(f"Unexpected type {type(params)} for {param_name!r}.") - if params is None or len(params) != 3: - raise ValueError( - "Quantization parameters should contain zero point, scale, quant type. " - f"Specified values for output {param_name}: {params}" - ) - - zero_point_values = np.array([params["zero_point"]]) - if not hasattr(params["scale"], "dtype") or params["scale"].dtype not in (np.float32, np.float16): - raise ValueError(f"Unexpected type {type(params['scale'])} and param_name={param_name!r}") - scale_values = np.array([params["scale"]]) - assert scale_values.dtype != np.float64 - zero_point_type = params["quant_type"] - else: - zero_point_values = np.array([use_zeropoint]) - scale_values = np.array([use_scale]) - params = self.quantization_params[param_name] - if "scale" in params: - dtype = params["scale"].dtype - scale_values = scale_values.astype(dtype) - assert scale_values.dtype != np.float64 - - zero_point_shape = [] - zero_point_name = param_name + "_zero_point" - scale_shape = [] - scale_name = param_name + "_scale" - - # Add initializers - init_zp = onnx.helper.make_tensor( - zero_point_name, zero_point_type, zero_point_shape, zero_point_values.ravel().tolist() - ) - self.model.add_initializer(init_zp) - if scale_values.dtype == np.float32: - scale_type = onnx.TensorProto.FLOAT - elif scale_values.dtype == np.float16: - scale_type = onnx.TensorProto.FLOAT16 - else: - raise ValueError(f"Unexpected dtype={scale_values.dtype} for param_name={param_name!r}") - init_scale = onnx.helper.make_tensor(scale_name, scale_type, scale_shape, scale_values.reshape((-1,)).tolist()) - self.model.add_initializer(init_scale) - - return True, scale_name, zero_point_name, scale_shape, zero_point_shape - - def calculate_quantization_params(self): + def adjust_tensor_ranges(self): if self.tensors_range is None: - return {} + return - # adjust tensor_ranges for input of Clip and Relu node for node in self.model.nodes(): - if node.op_type not in ["Clip", "Relu"]: - continue - if self.is_activation_symmetric: - continue - if not self.should_quantize_node(node): - continue - if len(self.model.input_name_to_nodes()[node.input[0]]) != 1: - continue - if node.input[0] not in self.tensors_range or node.output[0] 
not in self.tensors_range: - continue - td = self.tensors_range[node.output[0]] - if not isinstance(td, TensorData): - raise TypeError(f"Unexpected type {type(td)} for {node.output[0]!r}.") - self.tensors_range[node.input[0]] = td - - quantization_params = {} - for tensor_name in self.tensors_range: - td = self.tensors_range[tensor_name] - if not isinstance(td, TensorData): - raise TypeError(f"Unexpected type {type(td)} for {tensor_name!r}.") - - quant_overrides = self.get_per_tensor_quant_overrides(tensor_name) - - quant_type = self.activation_qType - if "quant_type" in quant_overrides: - quant_type = quant_overrides["quant_type"].tensor_type - - if "scale" in quant_overrides and "zero_point" in quant_overrides: - zero, scale = quant_overrides["zero_point"], quant_overrides["scale"] - elif quant_type == onnx.TensorProto.FLOAT8E4M3FN: - zero, scale = compute_scale_zp_float8(quant_type, td.avg_std[1]) - else: - rmin = quant_overrides.get("rmin", td.range_value[0]) - rmax = quant_overrides.get("rmax", td.range_value[1]) - symmetric = quant_overrides.get("symmetric", self.is_activation_symmetric) - reduce_range = quant_overrides.get("reduce_range", False) - qmin, qmax = get_qmin_qmax_for_qType(quant_type, reduce_range=reduce_range, symmetric=symmetric) - zero, scale = compute_scale_zp(rmin, rmax, qmin, qmax, symmetric, self.min_real_range) - - quantization_params[tensor_name] = QuantizationParams(zero_point=zero, scale=scale, quant_type=quant_type) - - return quantization_params + # adjust tensor_ranges for input of Clip and Relu node + if node.op_type in ["Clip", "Relu"]: + if self.is_activation_symmetric: + continue + if not self.should_quantize_node(node): + continue + if len(self.model.input_name_to_nodes()[node.input[0]]) != 1: + continue + if node.input[0] not in self.tensors_range or node.output[0] not in self.tensors_range: + continue + td = self.tensors_range[node.output[0]] + if not isinstance(td, TensorData): + raise TypeError(f"Unexpected type {type(td)} for {node.output[0]!r}.") + self.tensors_range[node.input[0]] = td + # Adjust Softmax to range from 0.0 to 1.0 + elif node.op_type == "Softmax": + self.tensors_range[node.output[0]] = TensorData(lowest=np.float32(0.0), highest=np.float32(1.0)) diff --git a/onnxruntime/python/tools/quantization/onnx_model.py b/onnxruntime/python/tools/quantization/onnx_model.py index 716dd1eacec6a..174bf5fd1509c 100644 --- a/onnxruntime/python/tools/quantization/onnx_model.py +++ b/onnxruntime/python/tools/quantization/onnx_model.py @@ -441,6 +441,11 @@ def replace_input_of_all_nodes(self, old_input_name, new_input_name): for node in self.model.graph.node: ONNXModel.replace_node_input(node, old_input_name, new_input_name) + def replace_input_of_nodes(self, old_input_name, new_input_name, node_names_set): + for node in self.model.graph.node: + if node.name in node_names_set: + ONNXModel.replace_node_input(node, old_input_name, new_input_name) + @staticmethod def replace_node_output(node, old_output_name, new_output_name): assert isinstance(old_output_name, str) and isinstance(new_output_name, str) @@ -452,6 +457,11 @@ def replace_output_of_all_nodes(self, old_output_name, new_output_name): for node in self.model.graph.node: ONNXModel.replace_node_output(node, old_output_name, new_output_name) + def replace_output_of_nodes(self, old_output_name, new_output_name, node_names_set): + for node in self.model.graph.node: + if node.name in node_names_set: + ONNXModel.replace_node_output(node, old_output_name, new_output_name) + def 
remove_unused_constant(self): input_name_to_nodes = self.input_name_to_nodes() diff --git a/onnxruntime/python/tools/quantization/onnx_quantizer.py b/onnxruntime/python/tools/quantization/onnx_quantizer.py index e2044db04303d..4b76de6ecf1cb 100644 --- a/onnxruntime/python/tools/quantization/onnx_quantizer.py +++ b/onnxruntime/python/tools/quantization/onnx_quantizer.py @@ -5,30 +5,31 @@ # -------------------------------------------------------------------------- import logging +import numpy as np import onnx import onnx.numpy_helper from onnx import onnx_pb as onnx_proto -try: - from onnx.reference.op_run import to_array_extended -except ImportError: - # old version of onnx. - to_array_extended = None - -from .base_quantizer import BaseQuantizer +from .base_quantizer import BaseQuantizer, QuantizationParams +from .calibrate import TensorData from .onnx_model import ONNXModel from .quant_utils import ( TENSOR_NAME_QUANT_SUFFIX, QuantizationMode, QuantizedValue, + QuantizedValueType, __producer__, __version__, add_infer_metadata, attribute_to_kwarg, + compute_scale_zp, + compute_scale_zp_float8, find_by_name, + get_qmin_qmax_for_qType, get_qrange_for_qType, ms_domain, save_and_reload_model_with_shape_infer, + tensor_proto_to_array, ) from .registry import CreateOpQuantizer @@ -77,6 +78,7 @@ def __init__( self.fuse_dynamic_quant = self.opset_version > 10 self.q_matmul_const_b_only = "MatMulConstBOnly" in self.extra_options and self.extra_options["MatMulConstBOnly"] + self.new_nodes = [] self.graph_scope = "/" # for human readable debug information self.tensor_names = {} # in case the shape inference not totally working @@ -88,6 +90,8 @@ def __init__( if self.mode not in QuantizationMode: raise ValueError(f"unsupported quantization mode {self.mode}") + self.quantization_params = self.calculate_quantization_params() + # QuantizeRange tensor name and zero tensor name for scale and zero point calculation. # Used when static is False self.fixed_qrange_uint8_name = "fixed_quantization_range_uint8" @@ -97,6 +101,8 @@ def __init__( # For int8 data-type, zero point is always zero (respresented by fixed_zero_point_name tensor) self.fixed_zero_zp_name = "fixed_zero_zp" + # Map of all original value names to quantized value names + self.quantized_value_map = {} # some output from nodes will be quantized, yet itself should be treat as existing so # no dequantized will be applied when needed later self.generated_value_names = self.model.get_non_initializer_inputs() @@ -494,6 +500,65 @@ def _get_dynamic_input_quantization_params_uint8(self, input_name, nodes_list, i return input_scale_name, input_zp_name, [], [] + def _get_quantization_params(self, param_name, use_scale=None, use_zeropoint=None): + """ + Create initializers and inputs in the graph for zero point and scale of output. + Zero point and scale values are obtained from self.quantization_params if specified. + parameter param_name: Name of the quantization parameter. + return: result, scale_name, zero_point_name, scale_shape, zero_point_shape. 
+ """ + zero_point_type = self.activation_qType + + if use_scale is None or use_zeropoint is None: + if self.quantization_params is None or param_name not in self.quantization_params: + logging.info(f'Quantization parameters for tensor:"{param_name}" not specified') + return False, "", "", "", "" + + params = self.quantization_params[param_name] + if not isinstance(params, QuantizationParams): + raise TypeError(f"Unexpected type {type(params)} for {param_name!r}.") + if params is None or len(params) != 3: + raise ValueError( + "Quantization parameters should contain zero point, scale, quant type. " + f"Specified values for output {param_name}: {params}" + ) + + zero_point_values = np.array([params["zero_point"]]) + if not hasattr(params["scale"], "dtype") or params["scale"].dtype not in (np.float32, np.float16): + raise ValueError(f"Unexpected type {type(params['scale'])} and param_name={param_name!r}") + scale_values = np.array([params["scale"]]) + assert scale_values.dtype != np.float64 + zero_point_type = params["quant_type"] + else: + zero_point_values = np.array([use_zeropoint]) + scale_values = np.array([use_scale]) + params = self.quantization_params[param_name] + if "scale" in params: + dtype = params["scale"].dtype + scale_values = scale_values.astype(dtype) + assert scale_values.dtype != np.float64 + + zero_point_shape = [] + zero_point_name = param_name + "_zero_point" + scale_shape = [] + scale_name = param_name + "_scale" + + # Add initializers + init_zp = onnx.helper.make_tensor( + zero_point_name, zero_point_type, zero_point_shape, zero_point_values.ravel().tolist() + ) + self.model.add_initializer(init_zp) + if scale_values.dtype == np.float32: + scale_type = onnx_proto.TensorProto.FLOAT + elif scale_values.dtype == np.float16: + scale_type = onnx_proto.TensorProto.FLOAT16 + else: + raise ValueError(f"Unexpected dtype={scale_values.dtype} for param_name={param_name!r}") + init_scale = onnx.helper.make_tensor(scale_name, scale_type, scale_shape, scale_values.reshape((-1,)).tolist()) + self.model.add_initializer(init_scale) + + return True, scale_name, zero_point_name, scale_shape, zero_point_shape + def _get_quantize_input_nodes(self, node, input_index, qType, given_scale_name=None, given_zp_name=None): """ Given an input for a node (which is not a initializer), this function @@ -564,6 +629,55 @@ def find_quantized_value(self, input_name): return self.parent.find_quantized_value(input_name) return None + def quantize_bias_static(self, bias_name, input_name, weight_name, beta=1.0): + """ + Quantized the bias. 
Zero Point == 0 and Scale == Input_Scale * Weight_Scale + """ + + # Handle case where bias already in quantization map + if bias_name in self.quantized_value_map: + return self.quantized_value_map[bias_name].q_name + + # get scale for weight + weight_scale_name = self.quantized_value_map[weight_name].scale_name + weight_initializer = find_by_name(weight_scale_name, self.model.initializer()) + weight_scale = tensor_proto_to_array(weight_initializer) + + # get scale for input + if input_name in self.quantized_value_map: + input_scale_name = self.quantized_value_map[input_name].scale_name + elif input_name in self.quantization_params: + _, input_scale_name, _, _, _ = self._get_quantization_params(input_name) + else: + raise ValueError(f"Expected {input_name} to be in quantized value map for static quantization") + + inputscale_initializer = find_by_name(input_scale_name, self.model.initializer()) + input_scale = tensor_proto_to_array(inputscale_initializer) + + ( + quantized_bias_name, + quantized_bias_scale_name, + quantized_bias_zp_name, + bias_scale_data, + node_type, + node_qtype, + ) = self.quantize_bias_static_impl(bias_name, input_scale, weight_scale, beta) + + assert bias_name not in self.quantized_value_map + quantized_value = QuantizedValue( + bias_name, + quantized_bias_name, + quantized_bias_scale_name, + quantized_bias_zp_name, + QuantizedValueType.Initializer, + 0 if bias_scale_data.size > 1 else None, + node_type=node_type, + node_qtype=node_qtype, + ) + self.quantized_value_map[bias_name] = quantized_value + + return quantized_bias_name + def contains_tensor(self, tensor_name): """ only check for value info and newly generated tensor names, initializers are checked separately @@ -721,6 +835,71 @@ def __quantize_inputs( return quantized_input_names, zero_point_names, scale_names, nodes + def quantize_initializer(self, weight, qType, reduce_range=False, keep_float_weight=False): + """ + :param weight: TensorProto initializer + :param qType: type to quantize to + :param keep_float_weight: Whether to quantize the weight. In some cases, we only want to qunatize scale and zero point. + If keep_float_weight is False, quantize the weight, or don't quantize the weight. 
+ :return: quantized weight name, zero point name, scale name + """ + # Find if this input is already quantized + if weight.name in self.quantized_value_map: + quantized_value = self.quantized_value_map[weight.name] + return ( + quantized_value.q_name, + quantized_value.zp_name, + quantized_value.scale_name, + ) + + q_weight_name, zp_name, scale_name = self.quantize_initializer_impl( + weight, qType, reduce_range, keep_float_weight + ) + + # Log entry for this quantized weight + quantized_value = QuantizedValue( + weight.name, + q_weight_name, + scale_name, + zp_name, + QuantizedValueType.Initializer, + None, + ) + self.quantized_value_map[weight.name] = quantized_value + return q_weight_name, zp_name, scale_name + + def quantize_weight_per_channel( + self, + weight_name, + weight_qType, + channel_axis, + reduce_range=True, + keep_float_weight=False, + ): + # Find if this input is already quantized + if weight_name in self.quantized_value_map: + quantized_value = self.quantized_value_map[weight_name] + return ( + quantized_value.q_name, + quantized_value.zp_name, + quantized_value.scale_name, + ) + + q_weight_name, zp_name, scale_name = self.quantize_weight_per_channel_impl( + weight_name, weight_qType, channel_axis, reduce_range, keep_float_weight + ) + quantized_value = QuantizedValue( + weight_name, + q_weight_name, + scale_name, + zp_name, + QuantizedValueType.Initializer, + None, + ) + self.quantized_value_map[weight_name] = quantized_value + + return q_weight_name, zp_name, scale_name + def _dequantize_value(self, value_name): """ Given a value (input/output) which is quantized, add a DequantizeLinear node to dequantize @@ -771,3 +950,37 @@ def _dequantize_outputs(self): dequantize_node = self._dequantize_value(output.name) if dequantize_node is not None: self.new_nodes.append(dequantize_node) + + def calculate_quantization_params(self): + if self.tensors_range is None: + return None + + self.adjust_tensor_ranges() + + quantization_params = {} + for tensor_name in self.tensors_range: + td = self.tensors_range[tensor_name] + if not isinstance(td, TensorData): + raise TypeError(f"Unexpected type {type(td)} for {tensor_name!r}.") + + quant_overrides = self.tensor_quant_overrides.get_per_tensor_overrides(tensor_name) + + quant_type = self.activation_qType + if "quant_type" in quant_overrides: + quant_type = quant_overrides["quant_type"].tensor_type + + if "scale" in quant_overrides and "zero_point" in quant_overrides: + zero, scale = quant_overrides["zero_point"], quant_overrides["scale"] + elif quant_type == onnx.TensorProto.FLOAT8E4M3FN: + zero, scale = compute_scale_zp_float8(quant_type, td.avg_std[1]) + else: + rmin = quant_overrides.get("rmin", td.range_value[0]) + rmax = quant_overrides.get("rmax", td.range_value[1]) + symmetric = quant_overrides.get("symmetric", self.is_activation_symmetric) + reduce_range = quant_overrides.get("reduce_range", False) + qmin, qmax = get_qmin_qmax_for_qType(quant_type, reduce_range=reduce_range, symmetric=symmetric) + zero, scale = compute_scale_zp(rmin, rmax, qmin, qmax, symmetric, self.min_real_range) + + quantization_params[tensor_name] = QuantizationParams(zero_point=zero, scale=scale, quant_type=quant_type) + + return quantization_params diff --git a/onnxruntime/python/tools/quantization/operators/conv.py b/onnxruntime/python/tools/quantization/operators/conv.py index 06204585ba1ca..7054173450569 100644 --- a/onnxruntime/python/tools/quantization/operators/conv.py +++ b/onnxruntime/python/tools/quantization/operators/conv.py @@ -252,4 +252,4 @@ 
def quantize(self): self.quantizer.quantize_weight_tensor(node.input[1]) if len(node.input) == 3: - self.quantizer.quantize_bias_tensor(node.input[2], node.input[0], node.input[1]) + self.quantizer.quantize_bias_tensor(node.name, node.input[2], node.input[0], node.input[1]) diff --git a/onnxruntime/python/tools/quantization/operators/direct_q8.py b/onnxruntime/python/tools/quantization/operators/direct_q8.py index c14532b96acbc..ae9679ae8ec7a 100644 --- a/onnxruntime/python/tools/quantization/operators/direct_q8.py +++ b/onnxruntime/python/tools/quantization/operators/direct_q8.py @@ -73,6 +73,6 @@ def quantize(self): if self.quantizer.force_quantize_no_input_check: self.quantizer.quantize_activation_tensor(self.node.input[0]) if not self.disable_qdq_for_node_output: - self.quantizer.quantize_activation_tensor(self.node.output[0], self.node.input[0]) + self.quantizer.quantize_output_same_as_input(self.node.output[0], self.node.input[0], self.node.name) elif self.quantizer.is_tensor_quantized(self.node.input[0]) and not self.disable_qdq_for_node_output: - self.quantizer.quantize_activation_tensor(self.node.output[0], self.node.input[0]) + self.quantizer.quantize_output_same_as_input(self.node.output[0], self.node.input[0], self.node.name) diff --git a/onnxruntime/python/tools/quantization/operators/gather.py b/onnxruntime/python/tools/quantization/operators/gather.py index f48725d1e428f..e390e874a2662 100644 --- a/onnxruntime/python/tools/quantization/operators/gather.py +++ b/onnxruntime/python/tools/quantization/operators/gather.py @@ -59,6 +59,6 @@ def quantize(self): if self.quantizer.is_valid_quantize_weight(node.input[0]) or self.quantizer.force_quantize_no_input_check: self.quantizer.quantize_activation_tensor(node.input[0]) - self.quantizer.quantize_activation_tensor(node.output[0], node.input[0]) + self.quantizer.quantize_output_same_as_input(node.output[0], node.input[0], node.name) elif self.quantizer.is_tensor_quantized(node.input[0]): - self.quantizer.quantize_activation_tensor(node.output[0], node.input[0]) + self.quantizer.quantize_output_same_as_input(node.output[0], node.input[0], node.name) diff --git a/onnxruntime/python/tools/quantization/operators/gemm.py b/onnxruntime/python/tools/quantization/operators/gemm.py index d269c8fb47bd1..df24e256aa7fc 100644 --- a/onnxruntime/python/tools/quantization/operators/gemm.py +++ b/onnxruntime/python/tools/quantization/operators/gemm.py @@ -153,7 +153,9 @@ def quantize(self): if len(node.input) == 3: if self.quantizer.is_input_a_initializer(node.input[2]): - self.quantizer.quantize_bias_tensor(node.input[2], node.input[0], node.input[1], get_beta(self.node)) + self.quantizer.quantize_bias_tensor( + node.name, node.input[2], node.input[0], node.input[1], get_beta(self.node) + ) set_default_beta(self.node) else: logging.warning( diff --git a/onnxruntime/python/tools/quantization/operators/norm.py b/onnxruntime/python/tools/quantization/operators/norm.py index e825fe6075601..3c14c926a7e75 100644 --- a/onnxruntime/python/tools/quantization/operators/norm.py +++ b/onnxruntime/python/tools/quantization/operators/norm.py @@ -29,7 +29,7 @@ def quantize(self): self.quantizer.quantize_activation_tensor(node.input[1]) # Bias - self.quantizer.quantize_bias_tensor(node.input[2], node.input[0], node.input[1]) + self.quantizer.quantize_bias_tensor(node.name, node.input[2], node.input[0], node.input[1]) # Output if not self.disable_qdq_for_node_output: diff --git a/onnxruntime/python/tools/quantization/operators/softmax.py 
b/onnxruntime/python/tools/quantization/operators/softmax.py index 61a69ab3649dd..4b39fae8ac063 100644 --- a/onnxruntime/python/tools/quantization/operators/softmax.py +++ b/onnxruntime/python/tools/quantization/operators/softmax.py @@ -1,18 +1,8 @@ -import numpy as np import onnx import onnx.helper -from ..quant_utils import ( - TENSOR_NAME_QUANT_SUFFIX, - QuantizedValue, - QuantizedValueType, - attribute_to_kwarg, - compute_scale_zp, - get_qmin_qmax_for_qType, - ms_domain, -) +from ..quant_utils import TENSOR_NAME_QUANT_SUFFIX, QuantizedValue, QuantizedValueType, attribute_to_kwarg, ms_domain from .base_operator import QuantOperatorBase -from .qdq_base_operator import QDQOperatorBase class QLinearSoftmax(QuantOperatorBase): @@ -82,29 +72,3 @@ def quantize(self): nodes.append(qnode) self.quantizer.new_nodes += nodes return None - - -class QDQSoftmax(QDQOperatorBase): - def quantize(self): - super().quantize() - output_name = self.node.output[0] - quant_overrides = self.quantizer.get_per_tensor_quant_overrides(output_name) - - quant_type = self.quantizer.activation_qType - if "quant_type" in quant_overrides: - quant_type = quant_overrides["quant_type"].tensor_type - - if "scale" in quant_overrides and "zero_point" in quant_overrides: - out_zero_point, out_scale = quant_overrides["zero_point"], quant_overrides["scale"] - else: - # Unless overridden by the user, force Softmax to range from 0.0 to 1.0 - qparams = self.quantizer.quantization_params[output_name] - dtype = qparams.data["scale"].dtype - rmin = quant_overrides.get("rmin", np.array(0, dtype=dtype)) - rmax = quant_overrides.get("rmax", np.array(1, dtype=dtype)) - symmetric = quant_overrides.get("symmetric", self.quantizer.is_activation_symmetric) - reduce_range = quant_overrides.get("reduce_range", False) - qmin, qmax = get_qmin_qmax_for_qType(quant_type, reduce_range=reduce_range, symmetric=symmetric) - out_zero_point, out_scale = compute_scale_zp(rmin, rmax, qmin, qmax, symmetric=symmetric) - - self.quantizer.set_quant_scale_zp(output_name, (out_scale, out_zero_point)) diff --git a/onnxruntime/python/tools/quantization/operators/split.py b/onnxruntime/python/tools/quantization/operators/split.py index c36b767f5abcc..74fc30cd075d2 100644 --- a/onnxruntime/python/tools/quantization/operators/split.py +++ b/onnxruntime/python/tools/quantization/operators/split.py @@ -60,4 +60,4 @@ def quantize(self): self.quantizer.quantize_activation_tensor(node.input[0]) if not self.disable_qdq_for_node_output: for output in node.output: - self.quantizer.quantize_activation_tensor(output, node.input[0]) + self.quantizer.quantize_output_same_as_input(output, node.input[0], node.name) diff --git a/onnxruntime/python/tools/quantization/qdq_quantizer.py b/onnxruntime/python/tools/quantization/qdq_quantizer.py index 1875c552fab9c..c323c6fec545a 100644 --- a/onnxruntime/python/tools/quantization/qdq_quantizer.py +++ b/onnxruntime/python/tools/quantization/qdq_quantizer.py @@ -3,15 +3,21 @@ # Licensed under the MIT License. See License.txt in the project root for # license information. 
# -------------------------------------------------------------------------- +from __future__ import annotations + import logging +from dataclasses import dataclass from enum import Enum +from typing import Any +import numpy as np import onnx import onnx.numpy_helper from onnx import TensorProto from onnx import onnx_pb as onnx_proto -from .base_quantizer import BaseQuantizer +from .base_quantizer import BaseQuantizer, QuantizationParams +from .calibrate import TensorData from .quant_utils import ( DEQUANT_OP_NAME, QUANT_OP_NAME, @@ -24,8 +30,12 @@ add_quant_input_suffix, add_quant_output_suffix, add_quant_suffix, + compute_scale_zp, + compute_scale_zp_float8, find_by_name, + get_qmin_qmax_for_qType, ms_domain, + tensor_proto_to_array, ) from .registry import CreateQDQQuantizer @@ -36,6 +46,17 @@ class QDQQuantTensorType(Enum): BIAS = 2 +# Holds the name of the node input from which a node output will share the +# same quantization param initializers (zero-point and scale initializers). +# Ex: A Transpose node's output will use the same quant param initializers used at the input. +@dataclass +class QDQQuantParamProvider: + input_name: str + node_name: str + + +# Holds information for tensors that have been marked for quantization by operator quantizers. +# Does not hold information for bias tensors. class QDQTensorQuantInfo: def __init__(self, tensor_type=QDQQuantTensorType.ACTIVATION, quant_para_provider=None, axis=None, data_type=None): self.tensor_type = tensor_type @@ -46,6 +67,64 @@ def __init__(self, tensor_type=QDQQuantTensorType.ACTIVATION, quant_para_provide self.data_type = data_type +# Holds information for bias tensors that have been marked for quantization by operator quantizers. +@dataclass +class QDQBiasQuantInfo: + node_name: str + input_name: str + weight_name: str + beta: float + + +# Holds quantization parameter values (scale, zp) for a tensor. +# A tensor typically has a one set of quantization parameters, unless the tensor is +# at a "mixed-precision" boundary where the activation quantization type changes (e.g., from uint8 to uint16). +@dataclass +class QDQTensorQuantParams: + original: QuantizationParams # Generated by producer node. + converted: QuantizationParams | None # Converted type consumed by some (or all/none) consumer nodes. + converted_recv_nodes: set[str] | None # The name of nodes that consume the converted type. + + +# Holds scale and zero_point initializer TensorProtos. +@dataclass +class QDQScaleZpInitializers: + scale: TensorProto + zero_point: TensorProto + + +# Holds all scale and zero-point initializers for a tensor. +# A tensor typically has a one set of quantization parameters, unless the tensor is +# at a "mixed-precision" boundary where the activation quantization type changes (e.g., from uint8 to uint16). +@dataclass +class QDQTensorScaleZpInitializers: + original: QDQScaleZpInitializers + converted: QDQScaleZpInitializers | None + converted_recv_nodes: set[str] | None + + +# Holds cached information of a tensor's quantized values (types, zp/scale initializer names, etc.). +# A tensor typically has a one set of quantization parameters, unless the tensor is +# at a "mixed-precision" boundary where the activation quantization type changes (e.g., from uint8 to uint16). 
+@dataclass +class QDQTensorQuantizedValue: + original: QuantizedValue + converted: QuantizedValue | None + converted_recv_nodes: set[str] | None + + def get_for_consumer(self, consumer_node_name) -> QuantizedValue: + if self.converted is None: # Quantized value is not converted, return original + return self.original + + if self.converted_recv_nodes is None: # All consumers receive the converted value + return self.converted + + # Check if consumer node name is in the list of nodes that + # receive the converted quantization value. If not, return the original value generated + # by the tensor's producer. + return self.converted if (consumer_node_name in self.converted_recv_nodes) else self.original + + class QDQQuantizer(BaseQuantizer): def __init__( self, @@ -74,7 +153,7 @@ def __init__( extra_options, ) self.tensors_to_quantize = {} - self.bias_to_quantize = [] + self.bias_to_quantize = {} self.nodes_to_remove = [] @@ -100,8 +179,7 @@ def __init__( # The default behavior is that multiple nodes can share a QDQ pair as their inputs. # In TRT, QDQ pair can`t be shared between nodes, so it will create dedicated QDQ pairs for each node. self.dedicated_qdq_pair = extra_options.get("DedicatedQDQPair", False) - if self.dedicated_qdq_pair: - self.tensor_to_its_receiving_nodes = {} + self.tensor_to_its_receiving_nodes = {} # Let user set channel axis for specific op type and it's effective only when per channel quantization is supported and per_channel is True. self.qdq_op_type_per_channel_support_to_axis = extra_options.get("QDQOpTypePerChannelSupportToAxis", {}) @@ -112,7 +190,7 @@ def __init__( # if the activation or weight types are 16-bit integers. # TODO: Remove this override (and use only the 'UseQDQContribOps' option) if/when ONNX adds 16-bit support. int16_types = (TensorProto.UINT16, TensorProto.INT16) - overrides_have_int16 = any(t in int16_types for t in self.tensor_quant_override_types) + overrides_have_int16 = any(t.tensor_type in int16_types for t in self.tensor_quant_override_qtypes) if not self.qdq_op_domain and ( self.activation_qType in int16_types or self.weight_qType in int16_types or overrides_have_int16 ): @@ -123,6 +201,11 @@ def __init__( ) self.qdq_op_domain = ms_domain + self.quantization_params = self.calc_graph_quant_params() + + # Map of all original value names to quantized value names + self.quantized_value_map = {} + def _get_tensor_type(self, tensor_name): """ Check if tensor can be quantized @@ -158,45 +241,71 @@ def _is_tensor_quantizable(self, tensor_name): return False - def __quantize_tensor(self, tensor_name, quant_sharing_param=None, tensor_type=QDQQuantTensorType.ACTIVATION): + def __quantize_tensor(self, tensor_name, quant_sharing_provider=None, tensor_type=QDQQuantTensorType.ACTIVATION): """ - Quantize tensors. If quant_param_tensor is not None, tensor with name tensor_name will be quantized with same - quantization parameters as tensor quant_param_tensor + Adds a tensor to the list (actually a dict) of tensors to quantize. Called indirectly by op quantizers that + want to quantize a tensor (i.e., "mark" a tensor for quantization). + + If quant_sharing_provider is not None, tensor with name tensor_name will be quantized with the same + quantization parameters as the node input specified in quant_sharing_provider. Ex: A Tranpose node's output + will typically use the same quantization parameter initializers used at the Transpose node's input. 
Args: tensor_name: name of the tensor to quantize - quant_sharing_param: name of the tensor that provides quantization parameter + quant_sharing_provider: name of the tensor and node that provides quantization parameter tensor_type: QDQQuantTensorType default ACTIVATION """ if self._is_tensor_quantizable(tensor_name): - if quant_sharing_param: + if quant_sharing_provider: + if not isinstance(quant_sharing_provider, QDQQuantParamProvider): + raise TypeError( + f"quant_sharing_provider must be of type QDQQuantParamProvider, not {type(quant_sharing_provider)}." + ) + data_type = self._get_tensor_type(tensor_name) self.tensors_to_quantize[tensor_name] = QDQTensorQuantInfo( - tensor_type=tensor_type, quant_para_provider=quant_sharing_param, data_type=data_type + tensor_type=tensor_type, quant_para_provider=quant_sharing_provider, data_type=data_type ) elif tensor_name not in self.tensors_to_quantize: data_type = self._get_tensor_type(tensor_name) self.tensors_to_quantize[tensor_name] = QDQTensorQuantInfo(tensor_type=tensor_type, data_type=data_type) - def quantize_activation_tensor(self, tensor_name, quant_sharing_param=None): + def quantize_activation_tensor(self, tensor_name: str): """ - Quantize Activation Tensor + Adds a tensor to the list of tensors to quantize. Called by op quantizers that + want to quantize a tensor (i.e., "mark" a tensor for quantization). + Args: tensor_name: name of the tensor to quantize - quant_sharing_param: name of the tensor that provides quantization parameter - """ - return self.__quantize_tensor(tensor_name, quant_sharing_param, QDQQuantTensorType.ACTIVATION) + return self.__quantize_tensor(tensor_name, None, QDQQuantTensorType.ACTIVATION) - def quantize_weight_tensor(self, tensor_name, quant_sharing_param=None): + def quantize_output_same_as_input(self, output_name: str, input_name: str, node_name: str): """ - Quantize Weight Tensor + Adds a tensor to the list of tensors to quantize. Called by op quantizers that + want to quantize an output tensor using the same quantization parameters as one of the node's inputs. + + Ex: A Tranpose node's output will typically use the same quantization parameter initializers used at + the Transpose node's input. + Args: - tensor_name: name of the tensor to quantize - quant_sharing_param: name of the tensor that provides quantization parameter + output_name: name of the node output to quantize so that it uses the same quantization params as an input. + input_name: name of the node input from which the output tensor will get its quantization params. + node_name: name of the node that consumes `input_name`. + """ + return self.__quantize_tensor( + output_name, QDQQuantParamProvider(input_name, node_name), QDQQuantTensorType.ACTIVATION + ) + def quantize_weight_tensor(self, tensor_name: str): """ - return self.__quantize_tensor(tensor_name, quant_sharing_param, QDQQuantTensorType.WEIGHT) + Adds a tensor to the list of weight tensors to quantize. Called by op quantizers that + want to quantize a weight (i.e., "mark" a weight for quantization). + + Args: + tensor_name: name of the weight to quantize + """ + return self.__quantize_tensor(tensor_name, None, QDQQuantTensorType.WEIGHT) def quantize_weight_tensor_per_channel(self, tensor_name, axis): weight = find_by_name(tensor_name, self.model.initializer()) @@ -208,7 +317,19 @@ def quantize_weight_tensor_per_channel(self, tensor_name, axis): else: logging.warning(f"only support per-channel quantization on weight. 
Tensor: {tensor_name} is not quantized.") - def quantize_bias_tensor(self, bias_name, input_name, weight_name, beta=1.0): + def quantize_bias_tensor(self, node_name, bias_name, input_name, weight_name, beta=1.0): + """ + Adds a bias tensor to the list of bias tensors to quantize. Called by op quantizers that + want to quantize a bias with bias_zero_point = 0 and bias_scale = input_scale * weight_scale * beta. + TODO: Explain the reasoning for using this formula. + + Args: + node_name: name of the node that consumes the bias, input, and weight tensors. + bias_name: name of the bias tensor to quantize. + input_name: name of the input tensor whose scale is used to compute the bias's scale. + weight_name: name of the weight tensor whose scale is used to compute the bias's scale. + beta: Multiplier used to compute the bias's scale. + """ # If the user provided quantization overrides for this tensor, treat it as a regular weight. if self.tensor_quant_overrides.get(bias_name): logging.info( @@ -223,7 +344,10 @@ def quantize_bias_tensor(self, bias_name, input_name, weight_name, beta=1.0): weight = find_by_name(bias_name, self.model.initializer()) if weight is not None: if weight.data_type in (onnx_proto.TensorProto.FLOAT, onnx_proto.TensorProto.FLOAT16): - self.bias_to_quantize.append((bias_name, input_name, weight_name, beta)) + if bias_name not in self.bias_to_quantize: + self.bias_to_quantize[bias_name] = QDQBiasQuantInfo(node_name, input_name, weight_name, beta) + else: + logging.warning(f"Bias {bias_name} has already been marked for quantization") else: logging.warning(f"Expected {bias_name} to be a weight") @@ -239,11 +363,10 @@ def quantize_model(self): op_quantizer = CreateQDQQuantizer(self, node) op_quantizer.quantize() - if self.dedicated_qdq_pair: - for tensor_name in node.input: - if tensor_name not in self.tensor_to_its_receiving_nodes: - self.tensor_to_its_receiving_nodes[tensor_name] = [] - self.tensor_to_its_receiving_nodes[tensor_name].append(node) + for tensor_name in node.input: + if tensor_name not in self.tensor_to_its_receiving_nodes: + self.tensor_to_its_receiving_nodes[tensor_name] = [] + self.tensor_to_its_receiving_nodes[tensor_name].append(node) self._quantize_normal_tensors() self._quantize_sharing_param_tensors() @@ -263,6 +386,8 @@ def quantize_model(self): def try_replacing_upstream_output(self, upstream_output_name, output_name): if ( output_name in self.quantization_params + and self.quantization_params[output_name].converted is None + and self.quantization_params[upstream_output_name].converted is None and len(self.model.input_name_to_nodes()[upstream_output_name]) == 1 and not self.model.is_graph_output(upstream_output_name) and not self.model.is_graph_input(upstream_output_name) @@ -273,6 +398,50 @@ def try_replacing_upstream_output(self, upstream_output_name, output_name): return True return False + def _create_q_node( + self, + q_input: str, + q_output: str, + quant_node_name: str, + scale_name: str, + zp_name: str, + axis: int | None = None, + ): + """ + Creates a QuantizeLinear node and adds it to the model. + """ + qlinear_node = onnx.helper.make_node( + QUANT_OP_NAME, + [q_input, scale_name, zp_name], + [q_output], + quant_node_name, + axis=axis, + domain=self.qdq_op_domain, + ) + self.model.add_nodes([qlinear_node]) + + def _create_dq_node( + self, + dq_input: str, + dq_output: str, + dequant_node_name: str, + scale_name: str, + zp_name: str, + axis: int | None = None, + ): + """ + Creates a DequantizeLinear node and adds it to the model. 
+ """ + dequant_node = onnx.helper.make_node( + DEQUANT_OP_NAME, + [dq_input, scale_name, zp_name], + [dq_output], + dequant_node_name, + axis=axis, + domain=self.qdq_op_domain, + ) + self.model.add_nodes([dequant_node]) + def _create_qdq_nodes( self, q_input, q_output, quant_node_name, dq_input, dq_output, dequant_node_name, scale_name, zp_name, axis=None ): @@ -383,7 +552,7 @@ def _add_qdq_pair_for_activation(self, tensor_name, scale_name, zp_name, data_ty QuantizedValueType.Input, scale_type=data_type, ) - self.quantized_value_map[tensor_name] = quantized_value + self.quantized_value_map[tensor_name] = QDQTensorQuantizedValue(quantized_value, None, None) else: q_input = tensor_name dq_output = add_dequant_output_suffix(tensor_name) @@ -413,9 +582,165 @@ def _add_qdq_pair_for_activation(self, tensor_name, scale_name, zp_name, data_ty QuantizedValueType.Input, scale_type=data_type, ) - self.quantized_value_map[tensor_name] = quantized_value + self.quantized_value_map[tensor_name] = QDQTensorQuantizedValue(quantized_value, None, None) + + def _add_qdq_ops_for_converted_activation( + self, + tensor_name, + first_scale_name, + first_zp_name, + scale_data_type, + convert_scale_name, + convert_zp_name, + convert_recv_nodes, + ): + """ + Adds Q and DQ ops to a tensor whose quantized data type is converted. That is, some consumers may use the + original data type from the producer, while other consumers use the converted data type. + This is generally done by adding a sequence of ops that convert from one data type (e.g., uint8) to another (e.g., uint16). + + T_float ---> Quant(to u8) ---> Convert(to u16) ---> Dequant(to float) ---> T_float' + where Convert(to u16) is equivalent to: ---> Dequant(to float) ---> Quant(to u16) ---> + + This function handles the following scenarios: + + 1) Tensor T is not a graph output; all consumers use the converted type + + ---> Q1 ---> DQ1 ---> Q2 ---> DQ2 ---> + + 2) Tensor T is not a graph output; some consumers use the original type, others use the converted type + + ---> Q1 -+-> DQ1 ---> + | + +-> DQ1' ---> Q2 ---> DQ2 ---> + + 3) Tensor T is a graph output; all consumers use the converted type + + ---> Q1 ---> DQ1 ---> Q2 ---> DQ2 -+-> + | + +-> + + 4) Tensor T is a graph output; some consumers use the original type, others use the converted type + + ---> Q1 -+-> DQ1 -+-> + | | + | +-> + | + +-> DQ1' ---> Q2 ---> DQ2 ---> + """ + tensor_recv_nodes = set([node.name for node in self.tensor_to_its_receiving_nodes[tensor_name]]) + + if ( + self.dedicated_qdq_pair + and tensor_name in self.tensor_to_its_receiving_nodes + and len(self.tensor_to_its_receiving_nodes[tensor_name]) > 1 + ): + # TODO: Add support for dedicated_qdq_pair if/when needed. + raise ValueError( + "Do not currently support converted quant_types in TensorQuantOverrides when the `dedicated_qdq_pair` extra_option is enabled" + ) + + # Determine which nodes consume the original quantized type and which nodes + # consume the converted quantized type. + original_recv_nodes = tensor_recv_nodes + if convert_recv_nodes is None: # In this case, all consumers receive the converted type. + convert_recv_nodes = tensor_recv_nodes + original_recv_nodes = set() + else: + original_recv_nodes = original_recv_nodes - convert_recv_nodes + + all_use_converted = len(convert_recv_nodes) == len(tensor_recv_nodes) + is_graph_output = self.model.is_graph_output(tensor_name) + + # Create first Q op. 
+ first_q_input = tensor_name + if is_graph_output: + first_q_input = add_quant_input_suffix(tensor_name) + self.model.replace_output_of_all_nodes(tensor_name, first_q_input) + + first_q_output = add_quant_output_suffix(tensor_name) + self._create_q_node( + first_q_input, first_q_output, add_quant_suffix(tensor_name), first_scale_name, first_zp_name + ) + + # Create first DQ op. + first_dq_output = add_dequant_output_suffix(tensor_name) + if is_graph_output and not all_use_converted: + first_dq_output = tensor_name + if original_recv_nodes and first_dq_output != tensor_name: + self.model.replace_input_of_nodes(tensor_name, first_dq_output, original_recv_nodes) + + self._create_dq_node( + first_q_output, first_dq_output, add_dequant_suffix(tensor_name), first_scale_name, first_zp_name + ) + + # Create parallel clone of first DQ op if _not all_ consumers use the converted type. + # --> DQ1' --> Q2 --> DQ2 --> + # + # This DQ clone would only have one consumer Q node (Q2) and could be potentially fused with + # it by some EPs (e.g., QNN) without breaking other "node units". + # Ex QNN fusion: + # --> Convert (fused) --> DQ2 --> + second_q_input = first_dq_output + if not all_use_converted: + second_q_input = add_quant_input_suffix(f"{tensor_name}_convert") + self._create_dq_node( + first_q_output, + second_q_input, + add_dequant_suffix(f"{tensor_name}_convert_clone"), + first_scale_name, + first_zp_name, + ) + + # Create second Q op. + second_q_output = add_quant_output_suffix(f"{tensor_name}_convert") + self._create_q_node( + second_q_input, + second_q_output, + add_quant_suffix(f"{tensor_name}_convert"), + convert_scale_name, + convert_zp_name, + ) + + # Create second DQ op. + second_dq_output = add_dequant_output_suffix(f"{tensor_name}_convert") + if is_graph_output and all_use_converted: + second_dq_output = tensor_name + if convert_recv_nodes and second_dq_output != tensor_name: + self.model.replace_input_of_nodes(tensor_name, second_dq_output, convert_recv_nodes) + self._create_dq_node( + second_q_output, + second_dq_output, + add_dequant_suffix(f"{tensor_name}_convert"), + convert_scale_name, + convert_zp_name, + ) + + # Store in quantized_value_map + original_quantized_value = QuantizedValue( + tensor_name, + first_dq_output, + first_scale_name, + first_zp_name, + QuantizedValueType.Input, + scale_type=scale_data_type, + ) + converted_quantized_value = QuantizedValue( + tensor_name, + second_dq_output, + convert_scale_name, + convert_zp_name, + QuantizedValueType.Input, + scale_type=scale_data_type, + ) + self.quantized_value_map[tensor_name] = QDQTensorQuantizedValue( + original_quantized_value, converted_quantized_value, convert_recv_nodes + ) def _quantize_normal_tensors(self): + """ + Adds Q/DQ ops to tensors (activations and weights) that have been marked for quantization by op quantizers. 
+ """ for tensor_name, tensor_info in self.tensors_to_quantize.copy().items(): if tensor_name in self.quantized_value_map: continue @@ -426,53 +751,105 @@ def _quantize_normal_tensors(self): if initializer: self._add_qdq_pair_for_initializer(initializer, tensor_info.tensor_type, tensor_info.axis) else: - used_scale, used_zp = self.find_quant_scale_zp(tensor_name) - if used_scale is not None and not hasattr(used_scale, "dtype"): - raise TypeError( - f"Unexpected type {type(used_scale)} for used_scale and tensor_name={tensor_name!r}" - ) - data_found, scale_name, zp_name, _, _ = self._get_quantization_params( - tensor_name, used_scale, used_zp - ) - - if not data_found: + tensor_qparam_initializers = self._make_tensor_scale_zp_initializers(tensor_name) + if not tensor_qparam_initializers: raise ValueError( f"Quantization parameters are not specified for param {tensor_name}. " "In static mode quantization params for inputs and outputs of nodes to be quantized are required." ) - self._add_qdq_pair_for_activation(tensor_name, scale_name, zp_name, data_type=tensor_info.data_type) + if tensor_qparam_initializers.converted is None: + # Normal case: --> Q --> DQ --> + self._add_qdq_pair_for_activation( + tensor_name, + tensor_qparam_initializers.original.scale.name, + tensor_qparam_initializers.original.zero_point.name, + data_type=tensor_info.data_type, + ) + else: + # Conversion case: ---> Q1 -+-> DQ1 --> + # | + # +-> DQ1' --> Q2 --> DQ2 --> + assert tensor_info.data_type == tensor_qparam_initializers.original.scale.data_type + self._add_qdq_ops_for_converted_activation( + tensor_name, + tensor_qparam_initializers.original.scale.name, + tensor_qparam_initializers.original.zero_point.name, + tensor_info.data_type, + tensor_qparam_initializers.converted.scale.name, + tensor_qparam_initializers.converted.zero_point.name, + tensor_qparam_initializers.converted_recv_nodes, + ) del self.tensors_to_quantize[tensor_name] def _quantize_sharing_param_tensors(self): + """ + Adds Q/DQ ops to tensors that have been marked for quantization by op quantizers. + Only operates on tensors that want to use the quantization parameter initializers from an upstream tensor. + For example, a Transpose node's output tensor will typically want to use the same quantization parameter + initializers as the Transpose node's input. + """ while self.tensors_to_quantize: for tensor_name, tensor_info in self.tensors_to_quantize.copy().items(): - tensor_provider_name = tensor_info.quant_para_provider - if tensor_provider_name in self.quantized_value_map: + quant_provider = tensor_info.quant_para_provider + if quant_provider and quant_provider.input_name in self.quantized_value_map: del self.tensors_to_quantize[tensor_name] - quantized_value = self.quantized_value_map[tensor_provider_name] - # Quantize the input - initializer = find_by_name(tensor_name, self.model.initializer()) - if initializer is not None: + quantized_value = self.quantized_value_map[quant_provider.input_name].get_for_consumer( + quant_provider.node_name + ) + if self.is_input_a_initializer(tensor_name): raise ValueError("Quantization parameter shared mode is not supported for weight yet") - self._add_qdq_pair_for_activation(tensor_name, quantized_value.scale_name, quantized_value.zp_name) + + # Need to check if this tensor's quant_type is converted for some consumers. + # If so, create new scale/zp initializers for these consumers. 
+ converted_qparam_inits = None + converted_recv_nodes = None + if tensor_name in self.quantization_params: + tensor_params = self.quantization_params[tensor_name] + if tensor_params.converted: + converted_qparam_inits = self._make_scale_zp_initializers( + tensor_name, tensor_params.converted, "_convert" + ) + converted_recv_nodes = tensor_params.converted_recv_nodes + + if converted_qparam_inits is None: + # Normal case: --> Q_shared --> DQ_shared --> + self._add_qdq_pair_for_activation( + tensor_name, quantized_value.scale_name, quantized_value.zp_name + ) + else: + # Conversion case: ---> Q_shared -+-> DQ_shared --> + # | + # +-> DQ_shared' --> Q2 --> DQ2 --> + self._add_qdq_ops_for_converted_activation( + tensor_name, + quantized_value.scale_name, + quantized_value.zp_name, + converted_qparam_inits.scale.data_type, + converted_qparam_inits.scale.name, + converted_qparam_inits.zero_point.name, + converted_recv_nodes, + ) def _quantize_bias_tensors(self): - for bias_name, input_name, weight_name, beta in self.bias_to_quantize: + """ + Adds DQ ops (or Cast) for bias tensors that have been marked for quantization by op quantizers. + """ + for bias_name, bias_info in self.bias_to_quantize.items(): if bias_name in self.quantized_value_map: continue # Quantize the input - self.quantize_bias_static(bias_name, input_name, weight_name, beta) + self.quantize_bias_static(bias_name, bias_info) init = find_by_name(bias_name, self.model.initializer()) self.model.remove_initializer(init) - quant_value = self.quantized_value_map[bias_name] + quant_value = self.quantized_value_map[bias_name].original if quant_value.node_type == "Cast": # simple cast to float 16 and not DequantizeLinear # cublasLtMatmul only supports (b)float16, float bias. if not isinstance(init.data_type, int): - raise TypeError(f"Unexpected type {type(init.data_type)} for input={input_name!r}") + raise TypeError(f"Unexpected type {type(init.data_type)} for input={bias_info.input_name!r}") node_name = add_dequant_suffix(bias_name) dequant_node = onnx.helper.make_node( "Cast", @@ -511,5 +888,233 @@ def _quantize_bias_tensors(self): raise RuntimeError(f"Unexpected operator type {quant_value.node_type!r}.") self.model.add_node(dequant_node) - def is_tensor_quantized(self, tensor_name): + def is_tensor_quantized(self, tensor_name: str): return tensor_name in self.tensors_to_quantize or tensor_name in self.bias_to_quantize + + def quantize_initializer( + self, + weight: onnx.TensorProto, + qType: onnx.TensorProto.DataType, + reduce_range: bool = False, + keep_float_weight: bool = False, + ) -> tuple[str, str, str]: + """ + :param weight: TensorProto initializer + :param qType: type to quantize to + :param keep_float_weight: Whether to quantize the weight. In some cases, we only want to qunatize scale and zero point. + If keep_float_weight is False, quantize the weight, or don't quantize the weight. 
+ :return: quantized weight name, zero point name, scale name + """ + # Find if this input is already quantized + if weight.name in self.quantized_value_map: + quantized_value = self.quantized_value_map[weight.name].original + return ( + quantized_value.q_name, + quantized_value.zp_name, + quantized_value.scale_name, + ) + + q_weight_name, zp_name, scale_name = self.quantize_initializer_impl( + weight, qType, reduce_range, keep_float_weight + ) + + # Log entry for this quantized weight + quantized_value = QuantizedValue( + weight.name, + q_weight_name, + scale_name, + zp_name, + QuantizedValueType.Initializer, + None, + ) + self.quantized_value_map[weight.name] = QDQTensorQuantizedValue(quantized_value, None, None) + return q_weight_name, zp_name, scale_name + + def quantize_weight_per_channel( + self, + weight_name: str, + weight_qType: onnx.TensorProto.DataType, + channel_axis: int, + reduce_range: bool = True, + keep_float_weight: bool = False, + ) -> tuple[str, str, str]: + # Find if this input is already quantized + if weight_name in self.quantized_value_map: + quantized_value = self.quantized_value_map[weight_name].original + return ( + quantized_value.q_name, + quantized_value.zp_name, + quantized_value.scale_name, + ) + + q_weight_name, zp_name, scale_name = self.quantize_weight_per_channel_impl( + weight_name, weight_qType, channel_axis, reduce_range, keep_float_weight + ) + quantized_value = QuantizedValue( + weight_name, + q_weight_name, + scale_name, + zp_name, + QuantizedValueType.Initializer, + None, + ) + self.quantized_value_map[weight_name] = QDQTensorQuantizedValue(quantized_value, None, None) + + return q_weight_name, zp_name, scale_name + + def quantize_bias_static(self, bias_name: str, bias_info: QDQBiasQuantInfo) -> str: + """ + Quantized the bias. Zero Point == 0 and Scale == Input_Scale * Weight_Scale + """ + + # Handle case where bias already in quantization map + if bias_name in self.quantized_value_map: + return self.quantized_value_map[bias_name].original.q_name + + # get scale for weight + weight_scale_name = self.quantized_value_map[bias_info.weight_name].original.scale_name + weight_initializer = find_by_name(weight_scale_name, self.model.initializer()) + weight_scale = tensor_proto_to_array(weight_initializer) + + # get scale for input + input_scale_name = ( + self.quantized_value_map[bias_info.input_name].get_for_consumer(bias_info.node_name).scale_name + ) + inputscale_initializer = find_by_name(input_scale_name, self.model.initializer()) + input_scale = tensor_proto_to_array(inputscale_initializer) + + ( + quantized_bias_name, + quantized_bias_scale_name, + quantized_bias_zp_name, + bias_scale_data, + node_type, + node_qtype, + ) = self.quantize_bias_static_impl(bias_name, input_scale, weight_scale, bias_info.beta) + + quantized_value = QuantizedValue( + bias_name, + quantized_bias_name, + quantized_bias_scale_name, + quantized_bias_zp_name, + QuantizedValueType.Initializer, + 0 if bias_scale_data.size > 1 else None, + node_type=node_type, + node_qtype=node_qtype, + ) + self.quantized_value_map[bias_name] = QDQTensorQuantizedValue(quantized_value, None, None) + + return quantized_bias_name + + def _make_scale_zp_initializers( + self, param_name: str, params: QuantizationParams, init_name_suffix: str = "" + ) -> QDQScaleZpInitializers: + """ + Creates and returns scale and zero-point initializers for the given quantization params. 
The initializers are + named: + - {param_name}_zero_point{init_name_suffix} + - {param_name}_scale{init_name_suffix} + """ + zero_point_values = np.array([params["zero_point"]]) + if not hasattr(params["scale"], "dtype") or params["scale"].dtype not in (np.float32, np.float16): + raise ValueError(f"Unexpected type {type(params['scale'])} and param_name={param_name!r}") + scale_values = np.array([params["scale"]]) + assert scale_values.dtype != np.float64 + zero_point_type = params.data.get("quant_type", self.activation_qType) + + zero_point_shape = [] + zero_point_name = param_name + "_zero_point" + init_name_suffix + scale_shape = [] + scale_name = param_name + "_scale" + init_name_suffix + + # Add initializers to model + init_zp = onnx.helper.make_tensor( + zero_point_name, zero_point_type, zero_point_shape, zero_point_values.ravel().tolist() + ) + self.model.add_initializer(init_zp) + + if scale_values.dtype == np.float32: + scale_type = onnx_proto.TensorProto.FLOAT + elif scale_values.dtype == np.float16: + scale_type = onnx_proto.TensorProto.FLOAT16 + else: + raise ValueError(f"Unexpected dtype={scale_values.dtype} for param_name={param_name!r}") + init_scale = onnx.helper.make_tensor(scale_name, scale_type, scale_shape, scale_values.reshape((-1,)).tolist()) + self.model.add_initializer(init_scale) + + return QDQScaleZpInitializers(init_scale, init_zp) + + def _make_tensor_scale_zp_initializers(self, tensor_name: str) -> QDQTensorScaleZpInitializers | None: + """ + Create and returns all scale/zero_point initializers for a given tensor. If the tensor is converted + to a different quantization type, this function creates two pairs of zp/scale initializers. Otherwise, + only one pair of zp/scale initializers is created. + """ + if self.quantization_params is None or tensor_name not in self.quantization_params: + logging.info(f'Quantization parameters for tensor:"{tensor_name}" not specified') + return None + + tensor_params = self.quantization_params[tensor_name] + if not isinstance(tensor_params, QDQTensorQuantParams): + raise TypeError(f"Unexpected type {type(tensor_params)} for {tensor_name!r}.") + + original_inits = self._make_scale_zp_initializers(tensor_name, tensor_params.original) + converted_inits = ( + self._make_scale_zp_initializers(tensor_name, tensor_params.converted, "_convert") + if tensor_params.converted + else None + ) + + return QDQTensorScaleZpInitializers(original_inits, converted_inits, tensor_params.converted_recv_nodes) + + def calc_quant_params(self, tensor_data: TensorData, quant_overrides: dict[str, Any]) -> QuantizationParams: + """ + Calculates quantization parameters (scale/zero-point) given a tensor's min/max range and optional + user-provided overrides. 
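+        Recognized override keys include "quant_type", "scale"/"zero_point", "rmin", "rmax", "symmetric",
+        and "reduce_range".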
+ """ + quant_type = self.activation_qType + if "quant_type" in quant_overrides: + quant_type = quant_overrides["quant_type"].tensor_type + + if "scale" in quant_overrides and "zero_point" in quant_overrides: + zero, scale = quant_overrides["zero_point"], quant_overrides["scale"] + elif quant_type == onnx.TensorProto.FLOAT8E4M3FN: + zero, scale = compute_scale_zp_float8(quant_type, tensor_data.avg_std[1]) + else: + rmin = quant_overrides.get("rmin", tensor_data.range_value[0]) + rmax = quant_overrides.get("rmax", tensor_data.range_value[1]) + symmetric = quant_overrides.get("symmetric", self.is_activation_symmetric) + reduce_range = quant_overrides.get("reduce_range", False) + qmin, qmax = get_qmin_qmax_for_qType(quant_type, reduce_range=reduce_range, symmetric=symmetric) + zero, scale = compute_scale_zp(rmin, rmax, qmin, qmax, symmetric, self.min_real_range) + + return QuantizationParams(zero_point=zero, scale=scale, quant_type=quant_type) + + def calc_graph_quant_params(self) -> dict[str, QDQTensorQuantParams]: + """ + Calculates quantization parameters (scale/zero-point) for all tensors in the graph using each tensor's min/max range + and optional user-provided overrides. + """ + if self.tensors_range is None: + return {} + + self.adjust_tensor_ranges() + + quantization_params = {} + for tensor_name in self.tensors_range: + td = self.tensors_range[tensor_name] + if not isinstance(td, TensorData): + raise TypeError(f"Unexpected type {type(td)} for {tensor_name!r}.") + + quant_overrides = self.tensor_quant_overrides.get_per_tensor_overrides(tensor_name) + original = self.calc_quant_params(td, quant_overrides) + converted = None + converted_recv_nodes = None + + if "convert" in quant_overrides: + converted = self.calc_quant_params(td, quant_overrides["convert"]) + converted_recv_nodes = quant_overrides["convert"].get("recv_nodes") + + quantization_params[tensor_name] = QDQTensorQuantParams(original, converted, converted_recv_nodes) + + return quantization_params diff --git a/onnxruntime/python/tools/quantization/registry.py b/onnxruntime/python/tools/quantization/registry.py index a693f4192bc2b..b00e830a2a366 100644 --- a/onnxruntime/python/tools/quantization/registry.py +++ b/onnxruntime/python/tools/quantization/registry.py @@ -18,7 +18,7 @@ from .operators.pooling import QLinearPool from .operators.qdq_base_operator import QDQOperatorBase from .operators.resize import QDQResize, QResize -from .operators.softmax import QDQSoftmax, QLinearSoftmax +from .operators.softmax import QLinearSoftmax from .operators.split import QDQSplit, QSplit from .operators.where import QDQWhere, QLinearWhere from .quant_utils import QuantizationMode @@ -79,7 +79,6 @@ "MatMul": QDQMatMul, "Split": QDQSplit, "Gather": QDQGather, - "Softmax": QDQSoftmax, "Where": QDQWhere, "InstanceNormalization": QDQNormalization, "LayerNormalization": QDQNormalization, diff --git a/onnxruntime/python/tools/quantization/tensor_quant_overrides.py b/onnxruntime/python/tools/quantization/tensor_quant_overrides.py new file mode 100644 index 0000000000000..610b96b9d7937 --- /dev/null +++ b/onnxruntime/python/tools/quantization/tensor_quant_overrides.py @@ -0,0 +1,214 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. 
+# -------------------------------------------------------------------------- +from __future__ import annotations + +import json +from collections.abc import MutableMapping +from typing import Any + +from .quant_utils import QuantType + + +class TensorQuantOverridesHelper(MutableMapping): + """ + Utility wrapper over the tensor quantization overrides passed via extra_options. + """ + + def __init__(self, raw_overrides: dict[str, list[dict[str, Any]]]): + self.overrides = raw_overrides + self.quant_types = None + + def get_per_tensor_overrides(self, tensor_name: str) -> dict[str, Any]: + overrides_list = self.overrides.get(tensor_name, [{}]) + num_overrides = len(overrides_list) + if num_overrides > 1: + raise ValueError( + f"Expected tensor '{tensor_name}' to use per-tensor quantization overrides, " + f"but found {num_overrides} per-channel overrides." + ) + + return overrides_list[0] if num_overrides > 0 else {} + + def get_per_channel_overrides( + self, + tensor_name: str, + num_channels: int, + ) -> list[dict[str, Any]]: + overrides_list = self.overrides.get(tensor_name, [{} for i in range(num_channels)]) + + if len(overrides_list) != num_channels: + raise ValueError( + f"Expected tensor '{tensor_name}' to have {num_channels} per-channel quantization overrides, " + f"but found {len(overrides_list)} instead." + ) + + return overrides_list + + def get_quant_types(self) -> set[QuantType]: + if self.quant_types is not None: + return self.quant_types + + self.quant_types = set() + + if self.overrides: + for quant_overrides_list in self.overrides.values(): + for quant_overrides in quant_overrides_list: + if "quant_type" in quant_overrides: + self.quant_types.add(quant_overrides["quant_type"]) + + if "convert" in quant_overrides and "quant_type" in quant_overrides["convert"]: + self.quant_types.add(quant_overrides["convert"]["quant_type"]) + + return self.quant_types + + def is_valid( + self, + initializer_names: set[str], + activation_names: set[str], + default_activation_qtype, + ) -> tuple[bool, str | None]: + self.quant_types = set() + + # Validate that compatible/valid overrides are provided. + if self.overrides: + keys_unsupported_with_scale_zp = {"symmetric", "reduce_range", "rmax", "rmin"} + + for tensor_name, quant_overrides_list in self.overrides.items(): + if tensor_name not in initializer_names and tensor_name not in activation_names: + return False, f"Tensor '{tensor_name}' in TensorQuantOverrides is not present in the model" + + if not isinstance(quant_overrides_list, list): + return False, f"Tensor quantization overrides for '{tensor_name}' are not in a list" + + is_initializer = tensor_name in initializer_names + if not is_initializer and len(quant_overrides_list) > 1: + return ( + False, + f"Tensor '{tensor_name}' has a list of per-channel overrides, but is not an initializer", + ) + + quant_type = None + for index, quant_overrides in enumerate(quant_overrides_list): + if not isinstance(quant_overrides, dict): + return ( + False, + f"Tensor quantization overrides at index {index} for '{tensor_name}' are not in a dict", + ) + + # For per-channel quantization, all channels must use the same quantization type. + # Therefore, if the user tries to override the quant_type for a channel, it must match in all + # other channels. 
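+                    # Ex (illustrative): per-channel overrides such as
+                    # [{"quant_type": QuantType.QUInt8}, {"quant_type": QuantType.QInt8}] fail this check because
+                    # channel 1's quant_type differs from channel 0's.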
+ if index == 0: + quant_type = quant_overrides.get("quant_type") + if quant_type: + self.quant_types.add(quant_type) + elif quant_type != quant_overrides.get("quant_type"): + return ( + False, + "Channel quantization types for tensor '{tensor_name}' do not match at index {index}.", + ) + + has_scale = "scale" in quant_overrides + has_zero_point = "zero_point" in quant_overrides + + if (has_scale and not has_zero_point) or (has_zero_point and not has_scale): + return ( + False, + "Must provide both 'scale' and 'zero_point' if one of the overrides is provided", + ) + + if has_scale: + for key in keys_unsupported_with_scale_zp: + if key in quant_overrides: + return ( + False, + f"Tensor override option '{key}' is invalid with 'scale' and 'zero_point'", + ) + + if "reduce_range" in quant_overrides and not is_initializer: + return ( + False, + f"Option 'reduce_range' is only supported for initializers, not for activation {tensor_name}", + ) + + if "convert" in quant_overrides: + if index > 0: + return ( + False, + f"Per-channel overrides (tensor '{tensor_name}') do not support 'convert'.", + ) + + if is_initializer: + return False, "Cannot use 'convert' override for initializers" + + if "quant_type" not in quant_overrides["convert"]: + return False, f"'convert' options (tensor '{tensor_name}') must specify a 'quant_type'" + + if "reduce_range" in quant_overrides["convert"]: + return ( + False, + f"Option 'reduce_range' is only supported for initializers, not for activation {tensor_name}", + ) + + convert_quant_type = quant_overrides["convert"]["quant_type"] + original_quant_type = quant_type if quant_type is not None else default_activation_qtype + if convert_quant_type == original_quant_type: + return ( + False, + f"'convert' quant_type must differ from original quant_type (tensor '{tensor_name}')", + ) + + convert_has_scale = "scale" in quant_overrides["convert"] + convert_has_zero_point = "zero_point" in quant_overrides["convert"] + + if (convert_has_scale and not convert_has_zero_point) or ( + convert_has_zero_point and not convert_has_scale + ): + return ( + False, + f"Must provide both 'scale' and 'zero_point' if one of the overrides is provided (tensor '{tensor_name}')", + ) + + if convert_has_scale: + for key in keys_unsupported_with_scale_zp: + if key in quant_overrides["convert"]: + return ( + False, + f"Tensor override option '{key}' is invalid with 'scale' and 'zero_point' (tensor '{tensor_name}')", + ) + + self.quant_types.add(convert_quant_type) + + return True, None + + def pprint_str(self, indent=None) -> str: + return json.dumps(self.overrides, default=str, indent=indent) + + def get_dict(self) -> dict[str, list[dict[str, Any]]]: + return self.overrides + + # Required implementations of abstract methods in collections.abc.MutableMapping + # so that this class can be used like a dict. 
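+    # Ex (illustrative): helper = TensorQuantOverridesHelper({}); helper["op_out"] = [{"quant_type": QuantType.QUInt16}]
+    # makes "op_out" in helper evaluate to True and len(helper) return 1.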
+ def __setitem__(self, key: str, value: list[dict]): + self.overrides[key] = value + + def __getitem__(self, key: str) -> list[dict]: + return self.overrides[key] + + def __delitem__(self, key: str): + del self.overrides[key] + + def __iter__(self): + return iter(self.overrides) + + def __len__(self): + return len(self.overrides) + + def __str__(self) -> str: + return str(self.overrides) + + def __repr__(self) -> str: + return f"{super().__repr__()}, TensorQuantOverridesHelper({self.overrides})" diff --git a/onnxruntime/test/python/quantization/test_qdq.py b/onnxruntime/test/python/quantization/test_qdq.py index 9e7a4a125121d..db4ab7e8a412c 100644 --- a/onnxruntime/test/python/quantization/test_qdq.py +++ b/onnxruntime/test/python/quantization/test_qdq.py @@ -4,7 +4,9 @@ # Licensed under the MIT License. See License.txt in the project root for # license information. # -------------------------------------------------------------------------- +from __future__ import annotations +import os import tempfile import unittest from pathlib import Path @@ -25,12 +27,12 @@ class TestQDQFormat(unittest.TestCase): - def input_feeds(self, n, name2shape): + def input_feeds(self, n, name2shape, np_float_type=np.float32): input_data_list = [] for _i in range(n): inputs = {} for name, shape in name2shape.items(): - inputs.update({name: np.random.randint(-1, 2, shape).astype(np.float32)}) + inputs.update({name: np.random.randint(-1, 2, shape).astype(np_float_type)}) input_data_list.extend([inputs]) dr = TestDataFeeds(input_data_list) return dr @@ -720,5 +722,593 @@ def test_activation_only(self): check_op_type_count(self, qdq_model_path, **qop_nodes) +class TestQDQMixedPrecision(TestQDQFormat): + @classmethod + def setUpClass(cls): + cls._tmp_model_dir = tempfile.TemporaryDirectory(prefix="ort.qdq.mixed_prec_") + + # Note: swap with the commented line if you want to see the models in local test dir. + cls._tmp_dir_path = cls._tmp_model_dir.name + # cls._tmp_dir_path = "." + + @classmethod + def tearDownClass(cls): + cls._tmp_model_dir.cleanup() + + def build_test_model_for_add_qdq_ops( + self, + num_consumers: int, + is_graph_output: bool, + float_type: onnx.TensorProto.DataType = onnx.TensorProto.FLOAT, + op0_transpose: bool = False, + ): + """ + Builds a float32 model with a single producer node and a configurable number of consumer nodes. + The tensor between the producer and consumers can be optionally made a graph output. + op_0 can optionally be made a Transpose node to test sharing qparams across the input and output. + + +-> op_0_out (optional graph output) + | + input_0 --> op_0 --+-> op_1 --> output_0 + | + +-> op_2 --> output_1 + | + ... 
+ | + +-> op_{n} --> output_{n-1} + """ + shape = (1, 2, 3) + shape_t = (1, 3, 2) + input_0 = onnx.helper.make_tensor_value_info("input_0", float_type, shape) + output_shape = shape if not op0_transpose else shape_t + + outputs = [] + for i in range(num_consumers): + outputs.append(onnx.helper.make_tensor_value_info(f"output_{i}", float_type, output_shape)) + + if is_graph_output: + outputs.append(onnx.helper.make_tensor_value_info("op_0_out", float_type, output_shape)) + + nodes = [] + if op0_transpose: + nodes.append(onnx.helper.make_node("Transpose", ["input_0"], ["op_0_out"], perm=[0, 2, 1], name="op_0")) + else: + nodes.append(onnx.helper.make_node("Sigmoid", ["input_0"], ["op_0_out"], name="op_0")) + + for i in range(num_consumers): + op_index = i + 1 + nodes.append(onnx.helper.make_node("Cos", ["op_0_out"], [f"output_{i}"], name=f"op_{op_index}")) + + graph = onnx.helper.make_graph( + nodes, + "test_add_qdq_ops_for_converted_activation", + [input_0], + outputs, + ) + opset_imports = [ + onnx.helper.make_opsetid("", 18), + ] + model = onnx.helper.make_model(graph, opset_imports=opset_imports) + return onnx.shape_inference.infer_shapes(model) + + def test_add_tensor_qdq_ops_case_1(self): + """ + Tensor T is not a graph output; all consumers use the converted type + ---> Q1 ---> DQ1 ---> Q2 ---> DQ2 ---> + """ + # Test configurations (qparam_sharing, float_type) + subtest_configs = [ + (False, onnx.TensorProto.FLOAT, np.float32), + (False, onnx.TensorProto.FLOAT16, np.float16), + (True, onnx.TensorProto.FLOAT, np.float32), + (True, onnx.TensorProto.FLOAT16, np.float16), + ] + for test_qparam_sharing, float_type, np_float_type in subtest_configs: + with self.subTest(test_qparam_sharing=test_qparam_sharing, float_type=float_type): + label = f"_share{test_qparam_sharing}_f{float_type}" + float_model_path = os.path.join(self._tmp_dir_path, f"case_1{label}.onnx") + qdq_model_path = os.path.join(self._tmp_dir_path, f"case_1{label}.qdq.onnx") + float_model = self.build_test_model_for_add_qdq_ops( + 2, False, float_type=float_type, op0_transpose=test_qparam_sharing + ) + onnx.save_model(float_model, float_model_path) + + data_reader = self.input_feeds(3, {"input_0": (1, 2, 3)}, np_float_type) + + mixed_prec_overrides = { + "op_0_out": [ + { + "quant_type": QuantType.QUInt8, + "convert": {"quant_type": QuantType.QUInt16, "recv_nodes": {"op_1", "op_2"}}, + } + ], + "output_0": [{"quant_type": QuantType.QUInt16}], + "output_1": [{"quant_type": QuantType.QUInt16}], + } + quantize_static( + float_model_path, + qdq_model_path, + data_reader, + quant_format=QuantFormat.QDQ, + activation_type=QuantType.QUInt8, + op_types_to_quantize=[node.op_type for node in float_model.graph.node], + extra_options={ + "TensorQuantOverrides": mixed_prec_overrides, + "ForceQuantizeNoInputCheck": test_qparam_sharing, # To ensure Transpose is wrapped in DQ/Q + }, + ) + + # Expect the following QDQ model: + # input_0 --> Q --> DQ --> op_0 --> Q_8 --> DQ_8 --> Q_16 --> DQ_16 -+-> op_1 --> Q --> DQ --> output_0 + # | + # +-> op_2 --> Q --> DQ --> output_1 + qdq_node_counts = {"QuantizeLinear": 5, "DequantizeLinear": 5} + check_op_type_count(self, qdq_model_path, **qdq_node_counts) + + qdq_model = onnx.load_model(qdq_model_path) + onnx.checker.check_model(qdq_model, True) + + initializers = {init.name: init for init in qdq_model.graph.initializer} + + # Check zero-point data types + orig_zp_init = None + if test_qparam_sharing: + # op_0_out_zero_point should not be in the model because the Transpose output is sharing + # 
qparams from the Transpose input. + self.assertNotIn("op_0_out_zero_point", initializers) + orig_zp_init = initializers["input_0_zero_point"] + else: + orig_zp_init = initializers["op_0_out_zero_point"] + + self.assertEqual(orig_zp_init.data_type, onnx.TensorProto.UINT8) + convert_zp_init = initializers["op_0_out_zero_point_convert"] + self.assertEqual(convert_zp_init.data_type, onnx.TensorProto.UINT16) + output_0_zp_init = initializers["output_0_zero_point"] + self.assertEqual(output_0_zp_init.data_type, onnx.TensorProto.UINT16) + output_1_zp_init = initializers["output_1_zero_point"] + self.assertEqual(output_1_zp_init.data_type, onnx.TensorProto.UINT16) + + # Check scale data types + orig_scale_init = None + if test_qparam_sharing: + self.assertNotIn("op_0_out_scale", initializers) + orig_scale_init = initializers["input_0_scale"] + else: + orig_scale_init = initializers["op_0_out_scale"] + + self.assertEqual(orig_scale_init.data_type, float_type) + convert_scale_init = initializers["op_0_out_scale_convert"] + self.assertEqual(convert_scale_init.data_type, float_type) + output_0_scale_init = initializers["output_0_scale"] + self.assertEqual(output_0_scale_init.data_type, float_type) + output_1_scale_init = initializers["output_1_scale"] + self.assertEqual(output_1_scale_init.data_type, float_type) + + def test_add_tensor_qdq_ops_case_2(self): + """ + Tensor T is not a graph output; some consumers use the original type, others use the converted type + ---> Q1 -+-> DQ1 ---> + | + +-> DQ1' ---> Q2 ---> DQ2 ---> + """ + # Test configurations (qparam_sharing, float_type) + subtest_configs = [ + (False, onnx.TensorProto.FLOAT, np.float32), + (False, onnx.TensorProto.FLOAT16, np.float16), + (True, onnx.TensorProto.FLOAT, np.float32), + (True, onnx.TensorProto.FLOAT16, np.float16), + ] + for test_qparam_sharing, float_type, np_float_type in subtest_configs: + with self.subTest(test_qparam_sharing=test_qparam_sharing, float_type=float_type): + label = f"_share{test_qparam_sharing}_f{float_type}" + float_model_path = os.path.join(self._tmp_dir_path, f"case_2{label}.onnx") + qdq_model_path = os.path.join(self._tmp_dir_path, f"case_2{label}.qdq.onnx") + float_model = self.build_test_model_for_add_qdq_ops( + 4, False, float_type=float_type, op0_transpose=test_qparam_sharing + ) + onnx.save_model(float_model, float_model_path) + + data_reader = self.input_feeds(3, {"input_0": (1, 2, 3)}, np_float_type) + + mixed_prec_overrides = { + "op_0_out": [ + { + "quant_type": QuantType.QUInt8, + "convert": {"quant_type": QuantType.QUInt16, "recv_nodes": {"op_3", "op_4"}}, + } + ], + "output_2": [{"quant_type": QuantType.QUInt16}], + "output_3": [{"quant_type": QuantType.QUInt16}], + } + quantize_static( + float_model_path, + qdq_model_path, + data_reader, + quant_format=QuantFormat.QDQ, + activation_type=QuantType.QUInt8, + op_types_to_quantize=[node.op_type for node in float_model.graph.node], + extra_options={ + "TensorQuantOverrides": mixed_prec_overrides, + "ForceQuantizeNoInputCheck": test_qparam_sharing, # To ensure Transpose is wrapped in DQ/Q + }, + ) + + # Expect the following QDQ model: + # input_0 --> Q --> DQ --> op_0 --> Q_8 -+-> DQ_8 -+-> op_1 --> Q --> DQ --> output_0 + # | | + # | +-> op_2 --> Q --> DQ --> output_1 + # | + # +-> DQ_8' --> Q_16 --> DQ_16 -+-> op_3 --> Q --> DQ --> output_2 + # | + # +-> op_4 --> Q --> DQ --> output_3 + qdq_node_counts = {"QuantizeLinear": 7, "DequantizeLinear": 8} + check_op_type_count(self, qdq_model_path, **qdq_node_counts) + + qdq_model = 
onnx.load_model(qdq_model_path) + onnx.checker.check_model(qdq_model, True) + + initializers = {init.name: init for init in qdq_model.graph.initializer} + + # Check zero-point data types + orig_zp_init = None + if test_qparam_sharing: + # op_0_out_zero_point should not be in the model because the Transpose output is sharing + # qparams from the Transpose input. + self.assertNotIn("op_0_out_zero_point", initializers) + orig_zp_init = initializers["input_0_zero_point"] + else: + orig_zp_init = initializers["op_0_out_zero_point"] + + self.assertEqual(orig_zp_init.data_type, onnx.TensorProto.UINT8) + convert_zp_init = initializers["op_0_out_zero_point_convert"] + self.assertEqual(convert_zp_init.data_type, onnx.TensorProto.UINT16) + output_0_zp_init = initializers["output_0_zero_point"] + self.assertEqual(output_0_zp_init.data_type, onnx.TensorProto.UINT8) + output_1_zp_init = initializers["output_1_zero_point"] + self.assertEqual(output_1_zp_init.data_type, onnx.TensorProto.UINT8) + output_2_zp_init = initializers["output_2_zero_point"] + self.assertEqual(output_2_zp_init.data_type, onnx.TensorProto.UINT16) + output_3_zp_init = initializers["output_3_zero_point"] + self.assertEqual(output_3_zp_init.data_type, onnx.TensorProto.UINT16) + + # Check scale data types + orig_scale_init = None + if test_qparam_sharing: + self.assertNotIn("op_0_out_scale", initializers) + orig_scale_init = initializers["input_0_scale"] + else: + orig_scale_init = initializers["op_0_out_scale"] + + self.assertEqual(orig_scale_init.data_type, float_type) + convert_scale_init = initializers["op_0_out_scale_convert"] + self.assertEqual(convert_scale_init.data_type, float_type) + output_0_scale_init = initializers["output_0_scale"] + self.assertEqual(output_0_scale_init.data_type, float_type) + output_1_scale_init = initializers["output_1_scale"] + self.assertEqual(output_1_scale_init.data_type, float_type) + output_2_scale_init = initializers["output_2_scale"] + self.assertEqual(output_2_scale_init.data_type, float_type) + output_3_scale_init = initializers["output_3_scale"] + self.assertEqual(output_3_scale_init.data_type, float_type) + + def test_add_tensor_qdq_ops_case_3(self): + """ + Tensor T is a graph output; all consumers use the converted type + ---> Q1 ---> DQ1 ---> Q2 ---> DQ2 -+-> + | + +-> + """ + # Test configurations (qparam_sharing, float_type) + subtest_configs = [ + (False, onnx.TensorProto.FLOAT, np.float32), + (False, onnx.TensorProto.FLOAT16, np.float16), + (True, onnx.TensorProto.FLOAT, np.float32), + (True, onnx.TensorProto.FLOAT16, np.float16), + ] + for test_qparam_sharing, float_type, np_float_type in subtest_configs: + with self.subTest(test_qparam_sharing=test_qparam_sharing, float_type=float_type): + label = f"_share{test_qparam_sharing}_f{float_type}" + float_model_path = os.path.join(self._tmp_dir_path, f"case_3{label}.onnx") + qdq_model_path = os.path.join(self._tmp_dir_path, f"case_3{label}.qdq.onnx") + float_model = self.build_test_model_for_add_qdq_ops( + 2, True, float_type=float_type, op0_transpose=test_qparam_sharing + ) + onnx.save_model(float_model, float_model_path) + + data_reader = self.input_feeds(3, {"input_0": (1, 2, 3)}, np_float_type) + + mixed_prec_overrides = { + "op_0_out": [ + { + "quant_type": QuantType.QUInt8, + "convert": {"quant_type": QuantType.QUInt16, "recv_nodes": {"op_1", "op_2"}}, + } + ], + "output_0": [{"quant_type": QuantType.QUInt16}], + "output_1": [{"quant_type": QuantType.QUInt16}], + } + quantize_static( + float_model_path, + qdq_model_path, + 
data_reader, + quant_format=QuantFormat.QDQ, + activation_type=QuantType.QUInt8, + op_types_to_quantize=[node.op_type for node in float_model.graph.node], + extra_options={ + "TensorQuantOverrides": mixed_prec_overrides, + "ForceQuantizeNoInputCheck": test_qparam_sharing, # To ensure Transpose is wrapped in DQ/Q + }, + ) + + # Expect the following QDQ model: + # input_0 --> Q --> DQ --> op_0 --> Q_8 --> DQ_8 --> Q_16 --> DQ_16 -+-> op_1 --> Q --> DQ --> output_0 + # | + # +-> op_2 --> Q --> DQ --> output_1 + # | + # +--> op_0_out (is graph output) + qdq_node_counts = {"QuantizeLinear": 5, "DequantizeLinear": 5} + check_op_type_count(self, qdq_model_path, **qdq_node_counts) + + qdq_model = onnx.load_model(qdq_model_path) + onnx.checker.check_model(qdq_model, True) + + initializers = {init.name: init for init in qdq_model.graph.initializer} + graph_outputs = {g_output.name: g_output for g_output in qdq_model.graph.output} + + # Check zero-point data types + orig_zp_init = None + if test_qparam_sharing: + # op_0_out_zero_point should not be in the model because the Transpose output is sharing + # qparams from the Transpose input. + self.assertNotIn("op_0_out_zero_point", initializers) + self.assertNotIn("op_0_out_scale", initializers) + orig_zp_init = initializers["input_0_zero_point"] + else: + orig_zp_init = initializers["op_0_out_zero_point"] + + self.assertEqual(orig_zp_init.data_type, onnx.TensorProto.UINT8) + convert_zp_init = initializers["op_0_out_zero_point_convert"] + self.assertEqual(convert_zp_init.data_type, onnx.TensorProto.UINT16) + output_0_zp_init = initializers["output_0_zero_point"] + self.assertEqual(output_0_zp_init.data_type, onnx.TensorProto.UINT16) + output_1_zp_init = initializers["output_1_zero_point"] + self.assertEqual(output_1_zp_init.data_type, onnx.TensorProto.UINT16) + + # Check scale data types + orig_scale_init = None + if test_qparam_sharing: + self.assertNotIn("op_0_out_scale", initializers) + orig_scale_init = initializers["input_0_scale"] + else: + orig_scale_init = initializers["op_0_out_scale"] + + self.assertEqual(orig_scale_init.data_type, float_type) + convert_scale_init = initializers["op_0_out_scale_convert"] + self.assertEqual(convert_scale_init.data_type, float_type) + output_0_scale_init = initializers["output_0_scale"] + self.assertEqual(output_0_scale_init.data_type, float_type) + output_1_scale_init = initializers["output_1_scale"] + self.assertEqual(output_1_scale_init.data_type, float_type) + + self.assertIn("op_0_out", graph_outputs) + + def test_add_tensor_qdq_ops_case_4(self): + """ + Tensor T is a graph output; some consumers use the original type, others use the converted type + ---> Q1 -+-> DQ1 -+-> + | | + | +-> + | + +-> DQ1' ---> Q2 ---> DQ2 ---> + """ + # Test configurations (qparam_sharing, float_type) + subtest_configs = [ + (False, onnx.TensorProto.FLOAT, np.float32), + (False, onnx.TensorProto.FLOAT16, np.float16), + (True, onnx.TensorProto.FLOAT, np.float32), + (True, onnx.TensorProto.FLOAT16, np.float16), + ] + for test_qparam_sharing, float_type, np_float_type in subtest_configs: + with self.subTest(test_qparam_sharing=test_qparam_sharing, float_type=float_type): + label = f"_share{test_qparam_sharing}_f{float_type}" + float_model_path = os.path.join(self._tmp_dir_path, f"case_4{label}.onnx") + qdq_model_path = os.path.join(self._tmp_dir_path, f"case_4{label}.qdq.onnx") + float_model = self.build_test_model_for_add_qdq_ops( + 4, True, float_type=float_type, op0_transpose=test_qparam_sharing + ) + 
onnx.save_model(float_model, float_model_path) + + data_reader = self.input_feeds(3, {"input_0": (1, 2, 3)}, np_float_type) + + mixed_prec_overrides = { + "op_0_out": [ + { + "quant_type": QuantType.QUInt8, + "convert": {"quant_type": QuantType.QUInt16, "recv_nodes": {"op_3", "op_4"}}, + } + ], + "output_2": [{"quant_type": QuantType.QUInt16}], + "output_3": [{"quant_type": QuantType.QUInt16}], + } + quantize_static( + float_model_path, + qdq_model_path, + data_reader, + quant_format=QuantFormat.QDQ, + activation_type=QuantType.QUInt8, + op_types_to_quantize=[node.op_type for node in float_model.graph.node], + extra_options={ + "TensorQuantOverrides": mixed_prec_overrides, + "ForceQuantizeNoInputCheck": test_qparam_sharing, # To ensure Transpose is wrapped in DQ/Q + }, + ) + + # Expect the following QDQ model: + # input_0 --> Q --> DQ --> op_0 --> Q_8 -+-> DQ_8 -+-> op_1 --> Q --> DQ --> output_0 + # | | + # | +-> op_2 --> Q --> DQ --> output_1 + # | | + # | +-> op_0_out (is graph output) + # | + # +-> DQ_8' --> Q_16 --> DQ_16 -+-> op_3 --> Q --> DQ --> output_2 + # | + # +-> op_4 --> Q --> DQ --> output_3 + qdq_node_counts = {"QuantizeLinear": 7, "DequantizeLinear": 8} + check_op_type_count(self, qdq_model_path, **qdq_node_counts) + + qdq_model = onnx.load_model(qdq_model_path) + onnx.checker.check_model(qdq_model, True) + + initializers = {init.name: init for init in qdq_model.graph.initializer} + graph_outputs = {g_output.name: g_output for g_output in qdq_model.graph.output} + + # Check zero-point data types + orig_zp_init = None + if test_qparam_sharing: + # op_0_out_zero_point should not be in the model because the Transpose output is sharing + # qparams from the Transpose input. + self.assertNotIn("op_0_out_zero_point", initializers) + orig_zp_init = initializers["input_0_zero_point"] + else: + orig_zp_init = initializers["op_0_out_zero_point"] + + self.assertEqual(orig_zp_init.data_type, onnx.TensorProto.UINT8) + convert_zp_init = initializers["op_0_out_zero_point_convert"] + self.assertEqual(convert_zp_init.data_type, onnx.TensorProto.UINT16) + output_0_zp_init = initializers["output_0_zero_point"] + self.assertEqual(output_0_zp_init.data_type, onnx.TensorProto.UINT8) + output_1_zp_init = initializers["output_1_zero_point"] + self.assertEqual(output_1_zp_init.data_type, onnx.TensorProto.UINT8) + output_2_zp_init = initializers["output_2_zero_point"] + self.assertEqual(output_2_zp_init.data_type, onnx.TensorProto.UINT16) + output_3_zp_init = initializers["output_3_zero_point"] + self.assertEqual(output_3_zp_init.data_type, onnx.TensorProto.UINT16) + + # Check scale data types + orig_scale_init = None + if test_qparam_sharing: + self.assertNotIn("op_0_out_scale", initializers) + orig_scale_init = initializers["input_0_scale"] + else: + orig_scale_init = initializers["op_0_out_scale"] + + self.assertEqual(orig_scale_init.data_type, float_type) + convert_scale_init = initializers["op_0_out_scale_convert"] + self.assertEqual(convert_scale_init.data_type, float_type) + output_0_scale_init = initializers["output_0_scale"] + self.assertEqual(output_0_scale_init.data_type, float_type) + output_1_scale_init = initializers["output_1_scale"] + self.assertEqual(output_1_scale_init.data_type, float_type) + output_2_scale_init = initializers["output_2_scale"] + self.assertEqual(output_2_scale_init.data_type, float_type) + output_3_scale_init = initializers["output_3_scale"] + self.assertEqual(output_3_scale_init.data_type, float_type) + + self.assertIn("op_0_out", graph_outputs) + + def 
build_test_model_1(self, shape): + """ + Returns the following float32 model. + + input_0 --> op1 --> op3 --> op5 --> op6 --> output_0 + ^ + | + input_1 --> op2 -+-> op4 ----+ + | + +-> op7 --> output_1 + | + +-> op8 --> output_2 + """ + input_0 = onnx.helper.make_tensor_value_info("input_0", onnx.TensorProto.FLOAT, shape) + input_1 = onnx.helper.make_tensor_value_info("input_1", onnx.TensorProto.FLOAT, shape) + output_0 = onnx.helper.make_tensor_value_info("output_0", onnx.TensorProto.FLOAT, shape) + output_1 = onnx.helper.make_tensor_value_info("output_1", onnx.TensorProto.FLOAT, shape) + output_2 = onnx.helper.make_tensor_value_info("output_2", onnx.TensorProto.FLOAT, shape) + + op1_node = onnx.helper.make_node("Sigmoid", ["input_0"], ["op1_out"], name="op1") + op2_node = onnx.helper.make_node("Cos", ["input_1"], ["op2_out"], name="op2") + op3_node = onnx.helper.make_node("Sin", ["op1_out"], ["op3_out"], name="op3") + op4_node = onnx.helper.make_node("Tanh", ["op2_out"], ["op4_out"], name="op4") + op5_node = onnx.helper.make_node("Mul", ["op3_out", "op4_out"], ["op5_out"], name="op5") + op6_node = onnx.helper.make_node("Relu", ["op5_out"], ["output_0"], name="op6") + op7_node = onnx.helper.make_node("Cos", ["op2_out"], ["output_1"], name="op7") + op8_node = onnx.helper.make_node("Sigmoid", ["op2_out"], ["output_2"], name="op8") + + graph = onnx.helper.make_graph( + [ + op1_node, + op2_node, + op3_node, + op4_node, + op5_node, + op6_node, + op7_node, + op8_node, + ], + "mixed_prec_test", + [input_0, input_1], + [output_0, output_1, output_2], + ) + opset_imports = [ + onnx.helper.make_opsetid("", 18), + ] + model = onnx.helper.make_model(graph, opset_imports=opset_imports) + return onnx.shape_inference.infer_shapes(model) + + def test_16bit_subgraph(self): + """ + Test correctness of a qdq model that uses a default 8-bit quantization type and contains + a subgraph that uses 16-bit activations. 
+ """ + shape = (1, 2, 3) + f32_model_path = os.path.join(self._tmp_dir_path, "model.onnx") + qdq_model_path = os.path.join(self._tmp_dir_path, "model.qdq.onnx") + qdq_mixed_model_path = os.path.join(self._tmp_dir_path, "model.mixed.qdq.onnx") + f32_model = self.build_test_model_1(shape) + onnx.save_model(f32_model, f32_model_path) + + data_reader = self.input_feeds(3, {"input_0": shape, "input_1": shape}) + + # Create pure 8-bit qdq model + quantize_static( + f32_model_path, + qdq_model_path, + data_reader, + quant_format=QuantFormat.QDQ, + activation_type=QuantType.QUInt8, + op_types_to_quantize=[node.op_type for node in f32_model.graph.node], + ) + + # Create mixed precision 8-bit/16-bit qdq model + mixed_prec_overrides = { + "op2_out": [ + {"quant_type": QuantType.QUInt8, "convert": {"quant_type": QuantType.QUInt16, "recv_nodes": {"op4"}}} + ], + "op3_out": [ + {"quant_type": QuantType.QUInt8, "convert": {"quant_type": QuantType.QUInt16, "recv_nodes": {"op5"}}} + ], + "op4_out": [{"quant_type": QuantType.QUInt16}], + "op5_out": [{"quant_type": QuantType.QUInt16}], + "output_0": [{"quant_type": QuantType.QUInt16}], + } + data_reader.rewind() + quantize_static( + f32_model_path, + qdq_mixed_model_path, + data_reader, + quant_format=QuantFormat.QDQ, + activation_type=QuantType.QUInt8, + op_types_to_quantize=[node.op_type for node in f32_model.graph.node], + extra_options={"TensorQuantOverrides": mixed_prec_overrides}, + ) + + qop_nodes = {"Relu": 0, "QuantizeLinear": 11, "DequantizeLinear": 12} + check_op_type_count(self, qdq_mixed_model_path, **qop_nodes) + data_reader.rewind() + check_model_correctness(self, f32_model_path, qdq_mixed_model_path, data_reader.get_next()) + data_reader.rewind() + check_model_correctness(self, f32_model_path, qdq_model_path, data_reader.get_next()) + + if __name__ == "__main__": unittest.main() From 5b64d7c32b29e1f97523f184a147107431d99611 Mon Sep 17 00:00:00 2001 From: Satya Kumar Jandhyala Date: Sat, 23 Mar 2024 11:19:14 -0700 Subject: [PATCH 04/11] [JS/WebGPU] Use non-matmul implementation for ConvTranspose in channel-first case. (#20022) ### Description Avoid using vec4 Matmul implementation for ConvTranspose with channel-last ### Motivation and Context --- .../ops/3rd-party/conv_backprop_mm_webgpu.ts | 11 +- js/web/test/data/ops/conv-transpose.jsonc | 262 ++++++++++++++++++ 2 files changed, 266 insertions(+), 7 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts index 11c8778b72335..080b24a2432aa 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts @@ -164,17 +164,14 @@ export const createConv2DTransposeMatMulProgramInfo = const outWidth = isChannelsLast ? outputShape[2] : outputShape[3]; const outHeight = isChannelsLast ? outputShape[1] : outputShape[2]; const outChannels = isChannelsLast ? outputShape[3] : outputShape[1]; - const isVec4 = - isChannelsLast ? inChannels % 4 === 0 && outChannels % 4 === 0 : outWidth % 4 === 0 && outChannels % 4 === 0; + // TODO: enable vec4 for NCHW + const isVec4 = isChannelsLast && (inChannels % 4 === 0 && inChannels % 3) && outChannels % 4 === 0; // TODO: fine tune size const dispatchX = isChannelsLast ? outChannels : outWidth * outHeight; const dispatchY = isChannelsLast ? outWidth * outHeight : outChannels; - const workGroupSize: [number, number, number] = isVec4 ? 
- [8, 8, 1] : - [(dispatchX <= 4 || dispatchY <= 4) ? 4 : 16, dispatchX > 4 && dispatchY <= 4 ? 4 : 16, 1]; - const elementsPerThread = - isVec4 ? [4, 4, 1] : [dispatchX <= 4 ? 1 : 4, dispatchX > 4 && dispatchY <= 4 ? 1 : 4, 1]; + const workGroupSize: [number, number, number] = [8, 8, 1]; + const elementsPerThread = dimAOuter <= 8 ? [4, 1, 1] : [4, 4, 1]; const dispatch = [ Math.ceil(dispatchX / workGroupSize[0] / elementsPerThread[0]), Math.ceil(dispatchY / workGroupSize[1] / elementsPerThread[1]), diff --git a/js/web/test/data/ops/conv-transpose.jsonc b/js/web/test/data/ops/conv-transpose.jsonc index 7038e2a4f8766..8ed48dd07e6f1 100644 --- a/js/web/test/data/ops/conv-transpose.jsonc +++ b/js/web/test/data/ops/conv-transpose.jsonc @@ -392,5 +392,267 @@ ] } ] + }, + { + "name": "ConvTranspose without bias addition C", + "operator": "ConvTranspose", + "attributes": [ + { "name": "kernel_shape", "data": [2, 2], "type": "ints" }, + { "name": "strides", "data": [2, 2], "type": "ints" } + ], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, + 26, 27, 28, 29, 30, 31, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, + 23, 24, 25, 26, 27, 28, 29, 30, 31, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, + 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, + 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, + 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, + 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 0, 1, 2, 3, 4, 5, 6, + 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 0, 1, 2, + 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, + 31, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, + 28, 29, 30, 31, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, + 25, 26, 27, 28, 29, 30, 31, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, + 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, + 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 0, 1, 2, 3, 4, 5, + 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 0, 1, + 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, + 27, 28, 29, 30, 31, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, + 24, 25, 26, 27, 28, 29, 30, 31, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, + 21, 
22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, + 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, + 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 0, 1, 2, 3, 4, 5, 6, 7, 8, + 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 0, 1, 2, 3, 4, + 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 0, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, + 27, 28, 29, 30, 31, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, + 24, 25, 26, 27, 28, 29, 30, 31, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, + 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, + 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, + 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31 + ], + "dims": [1, 4, 16, 16], + "type": "float32" + }, + { + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, + 15, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, + 14, 15 + ], + "dims": [4, 4, 2, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 0, 0, 0, 4, 0, 8, 0, 12, 0, 16, 0, 20, 0, 24, 0, 28, 0, 32, 0, 36, 0, 40, 0, 44, 0, 48, 0, 52, 0, 56, 0, + 60, 0, 0, 8, 12, 16, 24, 24, 36, 32, 48, 40, 60, 48, 72, 56, 84, 64, 96, 72, 108, 80, 120, 88, 132, 96, + 144, 104, 156, 112, 168, 120, 180, 0, 64, 0, 68, 0, 72, 0, 76, 0, 80, 0, 84, 0, 88, 0, 92, 0, 96, 0, 100, + 0, 104, 0, 108, 0, 112, 0, 116, 0, 120, 0, 124, 128, 192, 136, 204, 144, 216, 152, 228, 160, 240, 168, + 252, 176, 264, 184, 276, 192, 288, 200, 300, 208, 312, 216, 324, 224, 336, 232, 348, 240, 360, 248, 372, + 0, 0, 0, 4, 0, 8, 0, 12, 0, 16, 0, 20, 0, 24, 0, 28, 0, 32, 0, 36, 0, 40, 0, 44, 0, 48, 0, 52, 0, 56, 0, + 60, 0, 0, 8, 12, 16, 24, 24, 36, 32, 48, 40, 60, 48, 72, 56, 84, 64, 96, 72, 108, 80, 120, 88, 132, 96, + 144, 104, 156, 112, 168, 120, 180, 0, 64, 0, 68, 0, 72, 0, 76, 0, 80, 0, 84, 0, 88, 0, 92, 0, 96, 0, 100, + 0, 104, 0, 108, 0, 112, 0, 116, 0, 120, 0, 124, 128, 192, 136, 204, 144, 216, 152, 228, 160, 240, 168, + 252, 176, 264, 184, 276, 192, 288, 200, 300, 208, 312, 216, 324, 224, 336, 232, 348, 240, 360, 248, 372, + 0, 0, 0, 4, 0, 8, 0, 12, 0, 16, 0, 20, 0, 24, 0, 28, 0, 32, 0, 36, 0, 40, 0, 44, 0, 48, 0, 52, 0, 56, 0, + 60, 0, 0, 8, 12, 16, 24, 24, 36, 32, 48, 40, 60, 48, 72, 56, 84, 64, 96, 72, 108, 80, 120, 88, 132, 96, + 144, 104, 156, 112, 168, 120, 180, 0, 64, 0, 68, 0, 72, 0, 76, 0, 80, 0, 84, 0, 88, 0, 92, 0, 96, 0, 100, + 0, 104, 0, 108, 0, 112, 0, 116, 0, 120, 0, 124, 128, 192, 136, 204, 144, 216, 152, 228, 160, 240, 168, + 252, 176, 264, 184, 276, 192, 288, 200, 300, 208, 312, 216, 324, 224, 336, 232, 348, 240, 360, 248, 372, + 0, 0, 0, 4, 0, 8, 0, 12, 0, 16, 0, 20, 0, 24, 0, 28, 0, 32, 0, 36, 0, 40, 0, 44, 0, 48, 0, 52, 0, 56, 0, + 60, 0, 0, 8, 12, 16, 24, 24, 36, 32, 48, 40, 60, 48, 72, 56, 84, 64, 96, 72, 108, 80, 120, 88, 132, 96, 
+ 144, 104, 156, 112, 168, 120, 180, 0, 64, 0, 68, 0, 72, 0, 76, 0, 80, 0, 84, 0, 88, 0, 92, 0, 96, 0, 100, + 0, 104, 0, 108, 0, 112, 0, 116, 0, 120, 0, 124, 128, 192, 136, 204, 144, 216, 152, 228, 160, 240, 168, + 252, 176, 264, 184, 276, 192, 288, 200, 300, 208, 312, 216, 324, 224, 336, 232, 348, 240, 360, 248, 372, + 0, 0, 0, 4, 0, 8, 0, 12, 0, 16, 0, 20, 0, 24, 0, 28, 0, 32, 0, 36, 0, 40, 0, 44, 0, 48, 0, 52, 0, 56, 0, + 60, 0, 0, 8, 12, 16, 24, 24, 36, 32, 48, 40, 60, 48, 72, 56, 84, 64, 96, 72, 108, 80, 120, 88, 132, 96, + 144, 104, 156, 112, 168, 120, 180, 0, 64, 0, 68, 0, 72, 0, 76, 0, 80, 0, 84, 0, 88, 0, 92, 0, 96, 0, 100, + 0, 104, 0, 108, 0, 112, 0, 116, 0, 120, 0, 124, 128, 192, 136, 204, 144, 216, 152, 228, 160, 240, 168, + 252, 176, 264, 184, 276, 192, 288, 200, 300, 208, 312, 216, 324, 224, 336, 232, 348, 240, 360, 248, 372, + 0, 0, 0, 4, 0, 8, 0, 12, 0, 16, 0, 20, 0, 24, 0, 28, 0, 32, 0, 36, 0, 40, 0, 44, 0, 48, 0, 52, 0, 56, 0, + 60, 0, 0, 8, 12, 16, 24, 24, 36, 32, 48, 40, 60, 48, 72, 56, 84, 64, 96, 72, 108, 80, 120, 88, 132, 96, + 144, 104, 156, 112, 168, 120, 180, 0, 64, 0, 68, 0, 72, 0, 76, 0, 80, 0, 84, 0, 88, 0, 92, 0, 96, 0, 100, + 0, 104, 0, 108, 0, 112, 0, 116, 0, 120, 0, 124, 128, 192, 136, 204, 144, 216, 152, 228, 160, 240, 168, + 252, 176, 264, 184, 276, 192, 288, 200, 300, 208, 312, 216, 324, 224, 336, 232, 348, 240, 360, 248, 372, + 0, 0, 0, 4, 0, 8, 0, 12, 0, 16, 0, 20, 0, 24, 0, 28, 0, 32, 0, 36, 0, 40, 0, 44, 0, 48, 0, 52, 0, 56, 0, + 60, 0, 0, 8, 12, 16, 24, 24, 36, 32, 48, 40, 60, 48, 72, 56, 84, 64, 96, 72, 108, 80, 120, 88, 132, 96, + 144, 104, 156, 112, 168, 120, 180, 0, 64, 0, 68, 0, 72, 0, 76, 0, 80, 0, 84, 0, 88, 0, 92, 0, 96, 0, 100, + 0, 104, 0, 108, 0, 112, 0, 116, 0, 120, 0, 124, 128, 192, 136, 204, 144, 216, 152, 228, 160, 240, 168, + 252, 176, 264, 184, 276, 192, 288, 200, 300, 208, 312, 216, 324, 224, 336, 232, 348, 240, 360, 248, 372, + 0, 0, 0, 4, 0, 8, 0, 12, 0, 16, 0, 20, 0, 24, 0, 28, 0, 32, 0, 36, 0, 40, 0, 44, 0, 48, 0, 52, 0, 56, 0, + 60, 0, 0, 8, 12, 16, 24, 24, 36, 32, 48, 40, 60, 48, 72, 56, 84, 64, 96, 72, 108, 80, 120, 88, 132, 96, + 144, 104, 156, 112, 168, 120, 180, 0, 64, 0, 68, 0, 72, 0, 76, 0, 80, 0, 84, 0, 88, 0, 92, 0, 96, 0, 100, + 0, 104, 0, 108, 0, 112, 0, 116, 0, 120, 0, 124, 128, 192, 136, 204, 144, 216, 152, 228, 160, 240, 168, + 252, 176, 264, 184, 276, 192, 288, 200, 300, 208, 312, 216, 324, 224, 336, 232, 348, 240, 360, 248, 372, + 0, 0, 16, 20, 32, 40, 48, 60, 64, 80, 80, 100, 96, 120, 112, 140, 128, 160, 144, 180, 160, 200, 176, 220, + 192, 240, 208, 260, 224, 280, 240, 300, 0, 0, 24, 28, 48, 56, 72, 84, 96, 112, 120, 140, 144, 168, 168, + 196, 192, 224, 216, 252, 240, 280, 264, 308, 288, 336, 312, 364, 336, 392, 360, 420, 256, 320, 272, 340, + 288, 360, 304, 380, 320, 400, 336, 420, 352, 440, 368, 460, 384, 480, 400, 500, 416, 520, 432, 540, 448, + 560, 464, 580, 480, 600, 496, 620, 384, 448, 408, 476, 432, 504, 456, 532, 480, 560, 504, 588, 528, 616, + 552, 644, 576, 672, 600, 700, 624, 728, 648, 756, 672, 784, 696, 812, 720, 840, 744, 868, 0, 0, 16, 20, + 32, 40, 48, 60, 64, 80, 80, 100, 96, 120, 112, 140, 128, 160, 144, 180, 160, 200, 176, 220, 192, 240, 208, + 260, 224, 280, 240, 300, 0, 0, 24, 28, 48, 56, 72, 84, 96, 112, 120, 140, 144, 168, 168, 196, 192, 224, + 216, 252, 240, 280, 264, 308, 288, 336, 312, 364, 336, 392, 360, 420, 256, 320, 272, 340, 288, 360, 304, + 380, 320, 400, 336, 420, 352, 440, 368, 460, 384, 480, 400, 500, 416, 520, 432, 540, 448, 560, 464, 580, + 480, 600, 496, 620, 384, 448, 
408, 476, 432, 504, 456, 532, 480, 560, 504, 588, 528, 616, 552, 644, 576, + 672, 600, 700, 624, 728, 648, 756, 672, 784, 696, 812, 720, 840, 744, 868, 0, 0, 16, 20, 32, 40, 48, 60, + 64, 80, 80, 100, 96, 120, 112, 140, 128, 160, 144, 180, 160, 200, 176, 220, 192, 240, 208, 260, 224, 280, + 240, 300, 0, 0, 24, 28, 48, 56, 72, 84, 96, 112, 120, 140, 144, 168, 168, 196, 192, 224, 216, 252, 240, + 280, 264, 308, 288, 336, 312, 364, 336, 392, 360, 420, 256, 320, 272, 340, 288, 360, 304, 380, 320, 400, + 336, 420, 352, 440, 368, 460, 384, 480, 400, 500, 416, 520, 432, 540, 448, 560, 464, 580, 480, 600, 496, + 620, 384, 448, 408, 476, 432, 504, 456, 532, 480, 560, 504, 588, 528, 616, 552, 644, 576, 672, 600, 700, + 624, 728, 648, 756, 672, 784, 696, 812, 720, 840, 744, 868, 0, 0, 16, 20, 32, 40, 48, 60, 64, 80, 80, 100, + 96, 120, 112, 140, 128, 160, 144, 180, 160, 200, 176, 220, 192, 240, 208, 260, 224, 280, 240, 300, 0, 0, + 24, 28, 48, 56, 72, 84, 96, 112, 120, 140, 144, 168, 168, 196, 192, 224, 216, 252, 240, 280, 264, 308, + 288, 336, 312, 364, 336, 392, 360, 420, 256, 320, 272, 340, 288, 360, 304, 380, 320, 400, 336, 420, 352, + 440, 368, 460, 384, 480, 400, 500, 416, 520, 432, 540, 448, 560, 464, 580, 480, 600, 496, 620, 384, 448, + 408, 476, 432, 504, 456, 532, 480, 560, 504, 588, 528, 616, 552, 644, 576, 672, 600, 700, 624, 728, 648, + 756, 672, 784, 696, 812, 720, 840, 744, 868, 0, 0, 16, 20, 32, 40, 48, 60, 64, 80, 80, 100, 96, 120, 112, + 140, 128, 160, 144, 180, 160, 200, 176, 220, 192, 240, 208, 260, 224, 280, 240, 300, 0, 0, 24, 28, 48, 56, + 72, 84, 96, 112, 120, 140, 144, 168, 168, 196, 192, 224, 216, 252, 240, 280, 264, 308, 288, 336, 312, 364, + 336, 392, 360, 420, 256, 320, 272, 340, 288, 360, 304, 380, 320, 400, 336, 420, 352, 440, 368, 460, 384, + 480, 400, 500, 416, 520, 432, 540, 448, 560, 464, 580, 480, 600, 496, 620, 384, 448, 408, 476, 432, 504, + 456, 532, 480, 560, 504, 588, 528, 616, 552, 644, 576, 672, 600, 700, 624, 728, 648, 756, 672, 784, 696, + 812, 720, 840, 744, 868, 0, 0, 16, 20, 32, 40, 48, 60, 64, 80, 80, 100, 96, 120, 112, 140, 128, 160, 144, + 180, 160, 200, 176, 220, 192, 240, 208, 260, 224, 280, 240, 300, 0, 0, 24, 28, 48, 56, 72, 84, 96, 112, + 120, 140, 144, 168, 168, 196, 192, 224, 216, 252, 240, 280, 264, 308, 288, 336, 312, 364, 336, 392, 360, + 420, 256, 320, 272, 340, 288, 360, 304, 380, 320, 400, 336, 420, 352, 440, 368, 460, 384, 480, 400, 500, + 416, 520, 432, 540, 448, 560, 464, 580, 480, 600, 496, 620, 384, 448, 408, 476, 432, 504, 456, 532, 480, + 560, 504, 588, 528, 616, 552, 644, 576, 672, 600, 700, 624, 728, 648, 756, 672, 784, 696, 812, 720, 840, + 744, 868, 0, 0, 16, 20, 32, 40, 48, 60, 64, 80, 80, 100, 96, 120, 112, 140, 128, 160, 144, 180, 160, 200, + 176, 220, 192, 240, 208, 260, 224, 280, 240, 300, 0, 0, 24, 28, 48, 56, 72, 84, 96, 112, 120, 140, 144, + 168, 168, 196, 192, 224, 216, 252, 240, 280, 264, 308, 288, 336, 312, 364, 336, 392, 360, 420, 256, 320, + 272, 340, 288, 360, 304, 380, 320, 400, 336, 420, 352, 440, 368, 460, 384, 480, 400, 500, 416, 520, 432, + 540, 448, 560, 464, 580, 480, 600, 496, 620, 384, 448, 408, 476, 432, 504, 456, 532, 480, 560, 504, 588, + 528, 616, 552, 644, 576, 672, 600, 700, 624, 728, 648, 756, 672, 784, 696, 812, 720, 840, 744, 868, 0, 0, + 16, 20, 32, 40, 48, 60, 64, 80, 80, 100, 96, 120, 112, 140, 128, 160, 144, 180, 160, 200, 176, 220, 192, + 240, 208, 260, 224, 280, 240, 300, 0, 0, 24, 28, 48, 56, 72, 84, 96, 112, 120, 140, 144, 168, 168, 196, + 192, 224, 216, 252, 240, 280, 264, 308, 288, 
336, 312, 364, 336, 392, 360, 420, 256, 320, 272, 340, 288, + 360, 304, 380, 320, 400, 336, 420, 352, 440, 368, 460, 384, 480, 400, 500, 416, 520, 432, 540, 448, 560, + 464, 580, 480, 600, 496, 620, 384, 448, 408, 476, 432, 504, 456, 532, 480, 560, 504, 588, 528, 616, 552, + 644, 576, 672, 600, 700, 624, 728, 648, 756, 672, 784, 696, 812, 720, 840, 744, 868, 0, 0, 32, 36, 64, 72, + 96, 108, 128, 144, 160, 180, 192, 216, 224, 252, 256, 288, 288, 324, 320, 360, 352, 396, 384, 432, 416, + 468, 448, 504, 480, 540, 0, 0, 40, 44, 80, 88, 120, 132, 160, 176, 200, 220, 240, 264, 280, 308, 320, 352, + 360, 396, 400, 440, 440, 484, 480, 528, 520, 572, 560, 616, 600, 660, 512, 576, 544, 612, 576, 648, 608, + 684, 640, 720, 672, 756, 704, 792, 736, 828, 768, 864, 800, 900, 832, 936, 864, 972, 896, 1008, 928, 1044, + 960, 1080, 992, 1116, 640, 704, 680, 748, 720, 792, 760, 836, 800, 880, 840, 924, 880, 968, 920, 1012, + 960, 1056, 1000, 1100, 1040, 1144, 1080, 1188, 1120, 1232, 1160, 1276, 1200, 1320, 1240, 1364, 0, 0, 32, + 36, 64, 72, 96, 108, 128, 144, 160, 180, 192, 216, 224, 252, 256, 288, 288, 324, 320, 360, 352, 396, 384, + 432, 416, 468, 448, 504, 480, 540, 0, 0, 40, 44, 80, 88, 120, 132, 160, 176, 200, 220, 240, 264, 280, 308, + 320, 352, 360, 396, 400, 440, 440, 484, 480, 528, 520, 572, 560, 616, 600, 660, 512, 576, 544, 612, 576, + 648, 608, 684, 640, 720, 672, 756, 704, 792, 736, 828, 768, 864, 800, 900, 832, 936, 864, 972, 896, 1008, + 928, 1044, 960, 1080, 992, 1116, 640, 704, 680, 748, 720, 792, 760, 836, 800, 880, 840, 924, 880, 968, + 920, 1012, 960, 1056, 1000, 1100, 1040, 1144, 1080, 1188, 1120, 1232, 1160, 1276, 1200, 1320, 1240, 1364, + 0, 0, 32, 36, 64, 72, 96, 108, 128, 144, 160, 180, 192, 216, 224, 252, 256, 288, 288, 324, 320, 360, 352, + 396, 384, 432, 416, 468, 448, 504, 480, 540, 0, 0, 40, 44, 80, 88, 120, 132, 160, 176, 200, 220, 240, 264, + 280, 308, 320, 352, 360, 396, 400, 440, 440, 484, 480, 528, 520, 572, 560, 616, 600, 660, 512, 576, 544, + 612, 576, 648, 608, 684, 640, 720, 672, 756, 704, 792, 736, 828, 768, 864, 800, 900, 832, 936, 864, 972, + 896, 1008, 928, 1044, 960, 1080, 992, 1116, 640, 704, 680, 748, 720, 792, 760, 836, 800, 880, 840, 924, + 880, 968, 920, 1012, 960, 1056, 1000, 1100, 1040, 1144, 1080, 1188, 1120, 1232, 1160, 1276, 1200, 1320, + 1240, 1364, 0, 0, 32, 36, 64, 72, 96, 108, 128, 144, 160, 180, 192, 216, 224, 252, 256, 288, 288, 324, + 320, 360, 352, 396, 384, 432, 416, 468, 448, 504, 480, 540, 0, 0, 40, 44, 80, 88, 120, 132, 160, 176, 200, + 220, 240, 264, 280, 308, 320, 352, 360, 396, 400, 440, 440, 484, 480, 528, 520, 572, 560, 616, 600, 660, + 512, 576, 544, 612, 576, 648, 608, 684, 640, 720, 672, 756, 704, 792, 736, 828, 768, 864, 800, 900, 832, + 936, 864, 972, 896, 1008, 928, 1044, 960, 1080, 992, 1116, 640, 704, 680, 748, 720, 792, 760, 836, 800, + 880, 840, 924, 880, 968, 920, 1012, 960, 1056, 1000, 1100, 1040, 1144, 1080, 1188, 1120, 1232, 1160, 1276, + 1200, 1320, 1240, 1364, 0, 0, 32, 36, 64, 72, 96, 108, 128, 144, 160, 180, 192, 216, 224, 252, 256, 288, + 288, 324, 320, 360, 352, 396, 384, 432, 416, 468, 448, 504, 480, 540, 0, 0, 40, 44, 80, 88, 120, 132, 160, + 176, 200, 220, 240, 264, 280, 308, 320, 352, 360, 396, 400, 440, 440, 484, 480, 528, 520, 572, 560, 616, + 600, 660, 512, 576, 544, 612, 576, 648, 608, 684, 640, 720, 672, 756, 704, 792, 736, 828, 768, 864, 800, + 900, 832, 936, 864, 972, 896, 1008, 928, 1044, 960, 1080, 992, 1116, 640, 704, 680, 748, 720, 792, 760, + 836, 800, 880, 840, 924, 880, 968, 920, 1012, 960, 1056, 
1000, 1100, 1040, 1144, 1080, 1188, 1120, 1232, + 1160, 1276, 1200, 1320, 1240, 1364, 0, 0, 32, 36, 64, 72, 96, 108, 128, 144, 160, 180, 192, 216, 224, 252, + 256, 288, 288, 324, 320, 360, 352, 396, 384, 432, 416, 468, 448, 504, 480, 540, 0, 0, 40, 44, 80, 88, 120, + 132, 160, 176, 200, 220, 240, 264, 280, 308, 320, 352, 360, 396, 400, 440, 440, 484, 480, 528, 520, 572, + 560, 616, 600, 660, 512, 576, 544, 612, 576, 648, 608, 684, 640, 720, 672, 756, 704, 792, 736, 828, 768, + 864, 800, 900, 832, 936, 864, 972, 896, 1008, 928, 1044, 960, 1080, 992, 1116, 640, 704, 680, 748, 720, + 792, 760, 836, 800, 880, 840, 924, 880, 968, 920, 1012, 960, 1056, 1000, 1100, 1040, 1144, 1080, 1188, + 1120, 1232, 1160, 1276, 1200, 1320, 1240, 1364, 0, 0, 32, 36, 64, 72, 96, 108, 128, 144, 160, 180, 192, + 216, 224, 252, 256, 288, 288, 324, 320, 360, 352, 396, 384, 432, 416, 468, 448, 504, 480, 540, 0, 0, 40, + 44, 80, 88, 120, 132, 160, 176, 200, 220, 240, 264, 280, 308, 320, 352, 360, 396, 400, 440, 440, 484, 480, + 528, 520, 572, 560, 616, 600, 660, 512, 576, 544, 612, 576, 648, 608, 684, 640, 720, 672, 756, 704, 792, + 736, 828, 768, 864, 800, 900, 832, 936, 864, 972, 896, 1008, 928, 1044, 960, 1080, 992, 1116, 640, 704, + 680, 748, 720, 792, 760, 836, 800, 880, 840, 924, 880, 968, 920, 1012, 960, 1056, 1000, 1100, 1040, 1144, + 1080, 1188, 1120, 1232, 1160, 1276, 1200, 1320, 1240, 1364, 0, 0, 32, 36, 64, 72, 96, 108, 128, 144, 160, + 180, 192, 216, 224, 252, 256, 288, 288, 324, 320, 360, 352, 396, 384, 432, 416, 468, 448, 504, 480, 540, + 0, 0, 40, 44, 80, 88, 120, 132, 160, 176, 200, 220, 240, 264, 280, 308, 320, 352, 360, 396, 400, 440, 440, + 484, 480, 528, 520, 572, 560, 616, 600, 660, 512, 576, 544, 612, 576, 648, 608, 684, 640, 720, 672, 756, + 704, 792, 736, 828, 768, 864, 800, 900, 832, 936, 864, 972, 896, 1008, 928, 1044, 960, 1080, 992, 1116, + 640, 704, 680, 748, 720, 792, 760, 836, 800, 880, 840, 924, 880, 968, 920, 1012, 960, 1056, 1000, 1100, + 1040, 1144, 1080, 1188, 1120, 1232, 1160, 1276, 1200, 1320, 1240, 1364, 0, 0, 48, 52, 96, 104, 144, 156, + 192, 208, 240, 260, 288, 312, 336, 364, 384, 416, 432, 468, 480, 520, 528, 572, 576, 624, 624, 676, 672, + 728, 720, 780, 0, 0, 56, 60, 112, 120, 168, 180, 224, 240, 280, 300, 336, 360, 392, 420, 448, 480, 504, + 540, 560, 600, 616, 660, 672, 720, 728, 780, 784, 840, 840, 900, 768, 832, 816, 884, 864, 936, 912, 988, + 960, 1040, 1008, 1092, 1056, 1144, 1104, 1196, 1152, 1248, 1200, 1300, 1248, 1352, 1296, 1404, 1344, 1456, + 1392, 1508, 1440, 1560, 1488, 1612, 896, 960, 952, 1020, 1008, 1080, 1064, 1140, 1120, 1200, 1176, 1260, + 1232, 1320, 1288, 1380, 1344, 1440, 1400, 1500, 1456, 1560, 1512, 1620, 1568, 1680, 1624, 1740, 1680, + 1800, 1736, 1860, 0, 0, 48, 52, 96, 104, 144, 156, 192, 208, 240, 260, 288, 312, 336, 364, 384, 416, 432, + 468, 480, 520, 528, 572, 576, 624, 624, 676, 672, 728, 720, 780, 0, 0, 56, 60, 112, 120, 168, 180, 224, + 240, 280, 300, 336, 360, 392, 420, 448, 480, 504, 540, 560, 600, 616, 660, 672, 720, 728, 780, 784, 840, + 840, 900, 768, 832, 816, 884, 864, 936, 912, 988, 960, 1040, 1008, 1092, 1056, 1144, 1104, 1196, 1152, + 1248, 1200, 1300, 1248, 1352, 1296, 1404, 1344, 1456, 1392, 1508, 1440, 1560, 1488, 1612, 896, 960, 952, + 1020, 1008, 1080, 1064, 1140, 1120, 1200, 1176, 1260, 1232, 1320, 1288, 1380, 1344, 1440, 1400, 1500, + 1456, 1560, 1512, 1620, 1568, 1680, 1624, 1740, 1680, 1800, 1736, 1860, 0, 0, 48, 52, 96, 104, 144, 156, + 192, 208, 240, 260, 288, 312, 336, 364, 384, 416, 432, 468, 480, 520, 528, 572, 576, 
624, 624, 676, 672, + 728, 720, 780, 0, 0, 56, 60, 112, 120, 168, 180, 224, 240, 280, 300, 336, 360, 392, 420, 448, 480, 504, + 540, 560, 600, 616, 660, 672, 720, 728, 780, 784, 840, 840, 900, 768, 832, 816, 884, 864, 936, 912, 988, + 960, 1040, 1008, 1092, 1056, 1144, 1104, 1196, 1152, 1248, 1200, 1300, 1248, 1352, 1296, 1404, 1344, 1456, + 1392, 1508, 1440, 1560, 1488, 1612, 896, 960, 952, 1020, 1008, 1080, 1064, 1140, 1120, 1200, 1176, 1260, + 1232, 1320, 1288, 1380, 1344, 1440, 1400, 1500, 1456, 1560, 1512, 1620, 1568, 1680, 1624, 1740, 1680, + 1800, 1736, 1860, 0, 0, 48, 52, 96, 104, 144, 156, 192, 208, 240, 260, 288, 312, 336, 364, 384, 416, 432, + 468, 480, 520, 528, 572, 576, 624, 624, 676, 672, 728, 720, 780, 0, 0, 56, 60, 112, 120, 168, 180, 224, + 240, 280, 300, 336, 360, 392, 420, 448, 480, 504, 540, 560, 600, 616, 660, 672, 720, 728, 780, 784, 840, + 840, 900, 768, 832, 816, 884, 864, 936, 912, 988, 960, 1040, 1008, 1092, 1056, 1144, 1104, 1196, 1152, + 1248, 1200, 1300, 1248, 1352, 1296, 1404, 1344, 1456, 1392, 1508, 1440, 1560, 1488, 1612, 896, 960, 952, + 1020, 1008, 1080, 1064, 1140, 1120, 1200, 1176, 1260, 1232, 1320, 1288, 1380, 1344, 1440, 1400, 1500, + 1456, 1560, 1512, 1620, 1568, 1680, 1624, 1740, 1680, 1800, 1736, 1860, 0, 0, 48, 52, 96, 104, 144, 156, + 192, 208, 240, 260, 288, 312, 336, 364, 384, 416, 432, 468, 480, 520, 528, 572, 576, 624, 624, 676, 672, + 728, 720, 780, 0, 0, 56, 60, 112, 120, 168, 180, 224, 240, 280, 300, 336, 360, 392, 420, 448, 480, 504, + 540, 560, 600, 616, 660, 672, 720, 728, 780, 784, 840, 840, 900, 768, 832, 816, 884, 864, 936, 912, 988, + 960, 1040, 1008, 1092, 1056, 1144, 1104, 1196, 1152, 1248, 1200, 1300, 1248, 1352, 1296, 1404, 1344, 1456, + 1392, 1508, 1440, 1560, 1488, 1612, 896, 960, 952, 1020, 1008, 1080, 1064, 1140, 1120, 1200, 1176, 1260, + 1232, 1320, 1288, 1380, 1344, 1440, 1400, 1500, 1456, 1560, 1512, 1620, 1568, 1680, 1624, 1740, 1680, + 1800, 1736, 1860, 0, 0, 48, 52, 96, 104, 144, 156, 192, 208, 240, 260, 288, 312, 336, 364, 384, 416, 432, + 468, 480, 520, 528, 572, 576, 624, 624, 676, 672, 728, 720, 780, 0, 0, 56, 60, 112, 120, 168, 180, 224, + 240, 280, 300, 336, 360, 392, 420, 448, 480, 504, 540, 560, 600, 616, 660, 672, 720, 728, 780, 784, 840, + 840, 900, 768, 832, 816, 884, 864, 936, 912, 988, 960, 1040, 1008, 1092, 1056, 1144, 1104, 1196, 1152, + 1248, 1200, 1300, 1248, 1352, 1296, 1404, 1344, 1456, 1392, 1508, 1440, 1560, 1488, 1612, 896, 960, 952, + 1020, 1008, 1080, 1064, 1140, 1120, 1200, 1176, 1260, 1232, 1320, 1288, 1380, 1344, 1440, 1400, 1500, + 1456, 1560, 1512, 1620, 1568, 1680, 1624, 1740, 1680, 1800, 1736, 1860, 0, 0, 48, 52, 96, 104, 144, 156, + 192, 208, 240, 260, 288, 312, 336, 364, 384, 416, 432, 468, 480, 520, 528, 572, 576, 624, 624, 676, 672, + 728, 720, 780, 0, 0, 56, 60, 112, 120, 168, 180, 224, 240, 280, 300, 336, 360, 392, 420, 448, 480, 504, + 540, 560, 600, 616, 660, 672, 720, 728, 780, 784, 840, 840, 900, 768, 832, 816, 884, 864, 936, 912, 988, + 960, 1040, 1008, 1092, 1056, 1144, 1104, 1196, 1152, 1248, 1200, 1300, 1248, 1352, 1296, 1404, 1344, 1456, + 1392, 1508, 1440, 1560, 1488, 1612, 896, 960, 952, 1020, 1008, 1080, 1064, 1140, 1120, 1200, 1176, 1260, + 1232, 1320, 1288, 1380, 1344, 1440, 1400, 1500, 1456, 1560, 1512, 1620, 1568, 1680, 1624, 1740, 1680, + 1800, 1736, 1860, 0, 0, 48, 52, 96, 104, 144, 156, 192, 208, 240, 260, 288, 312, 336, 364, 384, 416, 432, + 468, 480, 520, 528, 572, 576, 624, 624, 676, 672, 728, 720, 780, 0, 0, 56, 60, 112, 120, 168, 180, 224, + 240, 280, 300, 
336, 360, 392, 420, 448, 480, 504, 540, 560, 600, 616, 660, 672, 720, 728, 780, 784, 840, + 840, 900, 768, 832, 816, 884, 864, 936, 912, 988, 960, 1040, 1008, 1092, 1056, 1144, 1104, 1196, 1152, + 1248, 1200, 1300, 1248, 1352, 1296, 1404, 1344, 1456, 1392, 1508, 1440, 1560, 1488, 1612, 896, 960, 952, + 1020, 1008, 1080, 1064, 1140, 1120, 1200, 1176, 1260, 1232, 1320, 1288, 1380, 1344, 1440, 1400, 1500, + 1456, 1560, 1512, 1620, 1568, 1680, 1624, 1740, 1680, 1800, 1736, 1860 + ], + "dims": [1, 4, 32, 32], + "type": "float32" + } + ] + } + ] } ] From f977be066318111f01d7f2e2373824566a7fe9c0 Mon Sep 17 00:00:00 2001 From: Hector Li Date: Sat, 23 Mar 2024 13:43:20 -0700 Subject: [PATCH 05/11] Fix issue that failed to load Conv node with external initializer (#20042) ### Description Fix issue that failed to load Conv node with external initializer. Root cause the model path is not provided while loading the weight and bias tensor for Conv. --- .../qnn/builder/opbuilder/base_op_builder.cc | 4 +++- .../test/providers/qnn/qnn_basic_test.cc | 20 ++++++++++++++++++ .../test/testdata/conv_qdq_external_ini.bin | Bin 0 -> 2000 bytes .../test/testdata/conv_qdq_external_ini.onnx | Bin 0 -> 2204 bytes 4 files changed, 23 insertions(+), 1 deletion(-) create mode 100644 onnxruntime/test/testdata/conv_qdq_external_ini.bin create mode 100644 onnxruntime/test/testdata/conv_qdq_external_ini.onnx diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.cc index 6d8c80bd2aaa1..08c9a8449cc33 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.cc @@ -244,7 +244,9 @@ Status BaseOpBuilder::TransposeInitializer(const QnnModelWrapper& qnn_model_wrap TensorShape new_tensor_shape(new_tensor_shape_dims); Tensor out_tensor = Tensor(tensor_dtype, new_tensor_shape, cpu_allocator); - ORT_RETURN_IF_ERROR(onnxruntime::utils::TensorProtoToTensor(Env::Default(), nullptr, initializer, in_tensor)); + onnxruntime::PathString model_path = qnn_model_wrapper.GetGraphViewer().ModelPath().ToPathString(); + const ORTCHAR_T* model_path_str = model_path.empty() ? nullptr : model_path.c_str(); + ORT_RETURN_IF_ERROR(onnxruntime::utils::TensorProtoToTensor(Env::Default(), model_path_str, initializer, in_tensor)); ORT_RETURN_IF_ERROR(Transpose::DoTranspose(permutations, in_tensor, out_tensor)); onnx::TensorProto new_tensor_proto = onnxruntime::utils::TensorToTensorProto(out_tensor, "test"); ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackInitializerData(new_tensor_proto, transposed_data)); diff --git a/onnxruntime/test/providers/qnn/qnn_basic_test.cc b/onnxruntime/test/providers/qnn/qnn_basic_test.cc index 4f294f899c170..7fd2441441dcf 100644 --- a/onnxruntime/test/providers/qnn/qnn_basic_test.cc +++ b/onnxruntime/test/providers/qnn/qnn_basic_test.cc @@ -168,6 +168,26 @@ TEST(QnnEP, TestDisableCPUFallback_ConflictingConfig) { } } +// Conv node `Conv` is not supported: GetFileLength for conv_qdq_external_ini.bin failed:open file conv_qdq_external_ini.bin fail, +// errcode = 2 - The system cannot find the file specified. 
+TEST_F(QnnHTPBackendTests, TestConvWithExternalData) { + Ort::SessionOptions so; + onnxruntime::ProviderOptions options; +#if defined(_WIN32) + options["backend_path"] = "QnnHtp.dll"; +#else + options["backend_path"] = "libQnnHtp.so"; +#endif + + so.AppendExecutionProvider("QNN", options); + + Ort::Status status(OrtSessionOptionsAppendExecutionProvider_CPU(so, 1)); + + const ORTCHAR_T* ort_model_path = ORT_MODEL_FOLDER "conv_qdq_external_ini.onnx"; + + Ort::Session session(*ort_env, ort_model_path, so); +} + // Helper function that runs an ONNX model with a NHWC Resize operator to test that // type/shape inference succeeds during layout transformation. // Refer to onnxruntime/core/graph/contrib_ops/nhwc_inference_context.h. diff --git a/onnxruntime/test/testdata/conv_qdq_external_ini.bin b/onnxruntime/test/testdata/conv_qdq_external_ini.bin new file mode 100644 index 0000000000000000000000000000000000000000..e749ab5af29c58c932a40d432f236cf3ae5d5be3 GIT binary patch literal 2000 zcmeH`>rY#C7{?pqB__t07!!T%i^VtFE2B}ehy&IE8_MR22yHD;N=Gk1PwD0K^z?SR z9(sDCJ?-f!)KfZWx%5I?29(l{%L;+XZf+(6;$>zkk*_xh~3$emfk0VBY+?((J zo`m%SwzfRN3-oB#iwg5v~S%!v~&neh6&Hr;uEZFw=i zY~j*p+H}J4&0ahni-bvs+x!rMEzTj*mvijsP3sZy_~#N_&is5fYP7J~IgMwqR5wv` zbw`L)V_}~>HGh+6bXw{zw|KTuqFT)^+9|_NZ~aJVP@D-= z&UKC@R);=uq4qFaYwmwk@HhgQxM@%tX!OEe0-sLfwd|D{D1{R+A=zIts%-C?I*gd8 z&faT@^=LXvK_f^S& zV6=5RGpX%YM0+zd3%X0` zgl~`ct`V@!;P+s z6Dth#RSFxRcX^vMX1r!VyJcn^>0O5@H0W0$aKZNF(wfshBOa1Q6UrU|`5viu2BI=r z_D@_qLUk$HCy%wNvM!>ecRPjRHx=T^Z?n+6Xo4Pga>!NdPc4Y1!?w^pQs!1ZxEO>&)vDS8 z-fr`15?8i)KCU=&0Nz_H7V8DTn}Fv5zk%6H!27_z13+Xw0saN>8ptmKUVi4of3E`G y0D$$u-+;s2bAZE}LtY2@Dezx_KcD#(F#8weMeqF&L1fh*_57&kM?L@F_xwL!Iyz|p literal 0 HcmV?d00001 diff --git a/onnxruntime/test/testdata/conv_qdq_external_ini.onnx b/onnxruntime/test/testdata/conv_qdq_external_ini.onnx new file mode 100644 index 0000000000000000000000000000000000000000..fad6074aea133f2d5c3ecbf5b597bae70bdde4ce GIT binary patch literal 2204 zcmbVOL2uJA6eexgg|~L3p=(PRw5r4<(n33?u>sPmzy*m5S1wVSx~;R+$r5)R;}>v% z8#fN{Z;&`3{slh+J89RXO}vwO)L` z2VVs!TQ0e5Yh7#x#thK|CCzFCmgyj$rlAN$dQmP4COv0l?$0CigzB@%%7q&3*#uRqK*K{SM=qJWojx6t| zOC9;Cq*|N1Nfb;V6}O495>YPsxWb;|+;6t%>65_vg@Rb5WyKv+wn9eItG9gBl>BS9 z$&6Mr$KR?9cYD||DRxQF5jw&TVOuT0*1&hiUXA@_)IBJiBcCA02!^JII@wN^gG!L| z%wEvRL=QJfupLkz_t0!6_ylW|Kw+vf+CQb~Uo3)qsctCg=A{^(LppFXZ$9Exs#E^O zcrM3HmZ)lF-ck{P+1xk8fagpPT*aua|mCcByqX zTjqc?XdJ9m?60<0DT{YuvRv`f|k73U}hqtG{P@BDiW-Y)E-MwZ7d(6_XKH z)jm97s`{cJ3Hq99c^jtF(?fgb#(#l(f6jRF`liZzJF?>ivJdrXv6CxM76n;?j3}^{ zWv#|rx50uHCi?=S!Fm+kx+b-TM!-%V);)*S5)}kMWcVCM*v~hV zo)|Nqj%y2>t)*Ad5T9B0#cr*teKCfF_x91c0<$n`?&{MOm0>~H%>X?k74<6Ix7?97 z!j|s_Zl7wz8cr;><8}c|tcAz=$dR&oji2KT7-;MO4l_$594YB#kY-5gn=IYPfQ{dW z+RW Date: Sat, 23 Mar 2024 14:30:35 -0700 Subject: [PATCH 06/11] Packed QKV and Rotary Embedding Support for sm<80 GQA (#20012) ### Description Add support for packed qkv input and rotary embedding with sm<80 using memory efficient attention kernel. ### Motivation and Context Allows lower-end gpus to run gqa with packed qkv input and rotary embedding. 
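For context, the packed QKV path added in this change assumes Q, K, and V are concatenated along the hidden dimension of a single `(batch_size, sequence_length, (num_heads + 2 * kv_num_heads) * head_size)` tensor; the new `UnpackQKV` kernel splits that back into BSNH Q/K/V buffers before the memory efficient attention kernel runs. A minimal NumPy sketch of that layout (not part of this change; the sizes below are arbitrary):

```python
import numpy as np

# Arbitrary sizes for illustration only.
batch_size, sequence_length = 2, 3
num_heads, kv_num_heads, head_size = 8, 4, 16
hidden = (num_heads + 2 * kv_num_heads) * head_size

packed_qkv = np.random.rand(batch_size, sequence_length, hidden).astype(np.float16)

q_size = num_heads * head_size
k_size = kv_num_heads * head_size

# Q occupies the first q_size elements of each token's hidden slice,
# followed by K and then V, mirroring the offset checks in UnpackQKV.
q = packed_qkv[..., :q_size].reshape(batch_size, sequence_length, num_heads, head_size)
k = packed_qkv[..., q_size:q_size + k_size].reshape(batch_size, sequence_length, kv_num_heads, head_size)
v = packed_qkv[..., q_size + k_size:].reshape(batch_size, sequence_length, kv_num_heads, head_size)
```

When rotary embedding is enabled, the change additionally derives per-token position ids from `seqlens_k` and applies the rotary embedding kernel to the unpacked Q and K before attention.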
--- .../cuda/bert/group_query_attention.cc | 23 ++- .../cuda/bert/group_query_attention_impl.cu | 160 ++++++++++++++++-- .../cuda/bert/group_query_attention_impl.h | 2 + .../python/transformers/test_flash_attn.py | 95 ++++++----- 4 files changed, 216 insertions(+), 64 deletions(-) diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc index 814aa1fb3c8f0..112f609d46598 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc @@ -159,8 +159,6 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { !use_flash_attention && !disable_memory_efficient_attention_ && local_window_size_ == -1 && - do_rotary_ == false && - key != nullptr && (parameters.head_size & 7) == 0 && parameters.sequence_length <= parameters.seqlen_past_kv_cache + parameters.sequence_length && (sizeof(T) == 2 || parameters.sequence_length >= attention::kMinSeqLenForMemoryEfficientAttentionFp32) && @@ -172,18 +170,31 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { if (use_memory_efficient_attention && needs_buff) { kv_buffer_bytes = (sizeof(T) * parameters.batch_size * parameters.num_heads * parameters.seqlen_present_kv_cache * parameters.head_size); } + size_t rotary_buffer_bytes = 0; + if (use_memory_efficient_attention && do_rotary_) { + rotary_buffer_bytes = 2 * sizeof(T) * parameters.batch_size * parameters.num_heads * parameters.sequence_length * parameters.head_size; + rotary_buffer_bytes += sizeof(int64_t) * parameters.batch_size * parameters.sequence_length; + } size_t fmha_buffer_bytes = 0; if (use_memory_efficient_attention && MemoryEfficientAttentionParams::need_workspace(parameters.head_size, sizeof(T) == sizeof(float))) { fmha_buffer_bytes = (parameters.batch_size * parameters.sequence_length * parameters.num_heads * parameters.head_size * sizeof(float)); } + size_t unpacked_qkv_bytes = 0; + if (use_memory_efficient_attention && parameters.is_packed_qkv) { + unpacked_qkv_bytes = (parameters.batch_size * parameters.sequence_length * (parameters.num_heads + 2 * parameters.kv_num_heads) * parameters.head_size * sizeof(T)); + } auto k_buffer = GetScratchBuffer(kv_buffer_bytes, context->GetComputeStream()); auto v_buffer = GetScratchBuffer(kv_buffer_bytes, context->GetComputeStream()); + auto rotary_buffer = GetScratchBuffer(rotary_buffer_bytes, context->GetComputeStream()); auto fmha_buffer = GetScratchBuffer(fmha_buffer_bytes, context->GetComputeStream()); + auto unpacked_qkv_buffer = GetScratchBuffer(unpacked_qkv_bytes, context->GetComputeStream()); #else constexpr bool use_memory_efficient_attention = false; auto k_buffer = GetScratchBuffer(0, context->GetComputeStream()); auto v_buffer = GetScratchBuffer(0, context->GetComputeStream()); + auto rotary_buffer = GetScratchBuffer(0, context->GetComputeStream()); auto fmha_buffer = GetScratchBuffer(0, context->GetComputeStream()); + auto unpacked_qkv_buffer = GetScratchBuffer(0, context->GetComputeStream()); #endif // seqlens_k buffer @@ -251,7 +262,13 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { if (fmha_buffer != nullptr) { data.fmha_buffer = reinterpret_cast(fmha_buffer.get()); } - // Rotary + if (unpacked_qkv_buffer != nullptr) { + data.unpacked_qkv_buffer = reinterpret_cast(unpacked_qkv_buffer.get()); + } + if (rotary_buffer != nullptr) { + data.rotary_buffer = reinterpret_cast(rotary_buffer.get()); + } + // Rotary 
Embedding if (parameters.do_rotary) { data.cos_cache = reinterpret_cast(cos_cache->Data()); data.sin_cache = reinterpret_cast(sin_cache->Data()); diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu index afba83be34e2d..f519be1c97149 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu @@ -42,6 +42,7 @@ limitations under the License. #include "contrib_ops/cuda/bert/group_query_attention_impl.h" #include "contrib_ops/cuda/bert/attention_impl.h" #include "core/providers/cuda/shared_inc/cuda_call.h" +#include "contrib_ops/cuda/bert/rotary_embedding_impl.h" #include using namespace onnxruntime::cuda; @@ -150,6 +151,8 @@ __global__ void ConcatNewToPastKVLarge(const int new_seqlen, template Status LaunchConcatNewToPastKV(contrib::GroupQueryAttentionParameters& parameters, GroupQueryAttentionData& data, + const void* new_key, + const void* new_value, cudaStream_t stream, const int max_threads_per_block, const bool past_only = false) { @@ -171,14 +174,14 @@ Status LaunchConcatNewToPastKV(contrib::GroupQueryAttentionParameters& parameter ConcatNewToPastKV<<>>(kv_sequence_length, past_sequence_length, reinterpret_cast(data.past_key), - reinterpret_cast(data.key), + reinterpret_cast(new_key), reinterpret_cast(data.present_key), seqlens_k, past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); ConcatNewToPastKV<<>>(kv_sequence_length, past_sequence_length, reinterpret_cast(data.past_value), - reinterpret_cast(data.value), + reinterpret_cast(new_value), reinterpret_cast(data.present_value), seqlens_k, past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); @@ -191,7 +194,7 @@ Status LaunchConcatNewToPastKV(contrib::GroupQueryAttentionParameters& parameter H, kv_num_heads, reinterpret_cast(data.past_key), - reinterpret_cast(data.key), + reinterpret_cast(new_key), reinterpret_cast(data.present_key), seqlens_k, past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); @@ -200,7 +203,7 @@ Status LaunchConcatNewToPastKV(contrib::GroupQueryAttentionParameters& parameter H, kv_num_heads, reinterpret_cast(data.past_value), - reinterpret_cast(data.value), + reinterpret_cast(new_value), reinterpret_cast(data.present_value), seqlens_k, past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); @@ -281,6 +284,8 @@ __global__ void ConcatKVInPlaceLarge(const int max_seqlen, template Status LaunchConcatKVInPlace(contrib::GroupQueryAttentionParameters& parameters, GroupQueryAttentionData& data, + const void* new_key, + const void* new_value, cudaStream_t stream, const int max_threads_per_block) { const int batch_size = parameters.batch_size; @@ -300,12 +305,12 @@ Status LaunchConcatKVInPlace(contrib::GroupQueryAttentionParameters& parameters, const dim3 block(H, kv_num_heads, 1); ConcatKVInPlace<<>>(present_sequence_length, reinterpret_cast(data.present_key), - reinterpret_cast(data.key), + reinterpret_cast(new_key), seqlens_k, past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); ConcatKVInPlace<<>>(present_sequence_length, reinterpret_cast(data.present_value), - reinterpret_cast(data.value), + reinterpret_cast(new_value), seqlens_k, past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); } else { @@ -316,14 +321,14 @@ Status LaunchConcatKVInPlace(contrib::GroupQueryAttentionParameters& parameters, H, kv_num_heads, reinterpret_cast(data.present_key), - reinterpret_cast(data.key), + reinterpret_cast(new_key), seqlens_k, past_kv_format == 
AttentionQkvFormat::Q_K_V_BSNH); ConcatKVInPlaceLarge<<>>(present_sequence_length, H, kv_num_heads, reinterpret_cast(data.present_value), - reinterpret_cast(data.value), + reinterpret_cast(new_value), seqlens_k, past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); } @@ -468,6 +473,83 @@ Status LaunchGetSeqlenBuff(contrib::GroupQueryAttentionParameters& parameters, i return CUDA_CALL(cudaGetLastError()); } +// Kernel to unpack qkv from packed qkv +template +__global__ void UnpackQKV(const T* packed_qkv, T* unpacked_q, T* unpacked_k, T* unpacked_v, const int num_heads, + const int kv_num_heads, const int head_size, const int sequence_length, + const int batch_size) { + const int tid = threadIdx.x + blockIdx.x * blockDim.x; + int d = (num_heads + 2 * kv_num_heads) * head_size; + const int qkv_size = batch_size * sequence_length * d; + const int q_size = num_heads * head_size; + const int k_size = kv_num_heads * head_size; + if (tid < qkv_size) { + int batch = tid / (d * sequence_length); + int sequence = (tid % (d * sequence_length)) / d; + int offset = tid % d; + if (offset < q_size) { + int unpacked_i = batch * sequence_length * num_heads * head_size + sequence * num_heads * head_size + offset; + unpacked_q[unpacked_i] = packed_qkv[tid]; + } else if (offset < q_size + k_size) { + int unpacked_i = batch * sequence_length * kv_num_heads * head_size + sequence * kv_num_heads * head_size + (offset - q_size); + unpacked_k[unpacked_i] = packed_qkv[tid]; + } else { + int unpacked_i = batch * sequence_length * kv_num_heads * head_size + sequence * kv_num_heads * head_size + (offset - q_size - k_size); + unpacked_v[unpacked_i] = packed_qkv[tid]; + } + } +} + +// Unpack packed qkv +template +Status LaunchUnpackQKV(const T* packed_qkv, T* unpacked_q, T* unpacked_k, T* unpacked_v, const int num_heads, + const int kv_num_heads, const int head_size, const int sequence_length, const int batch_size, + cudaStream_t stream, const int max_threads_per_block) { + const int threads = max_threads_per_block; + const int blocks = (batch_size * sequence_length * (num_heads + 2 * kv_num_heads) * head_size + threads - 1) / threads; + UnpackQKV<<>>(packed_qkv, unpacked_q, unpacked_k, unpacked_v, num_heads, kv_num_heads, + head_size, sequence_length, batch_size); + return CUDA_CALL(cudaGetLastError()); +} + +// Kernel to convert seqlens_k to position_ids +__global__ void SeqlensToPosIdsPrompt(int32_t* seqlens_k, int64_t* position_ids, const int seqlen, + const int batch_size) { + int tid = blockDim.x * blockIdx.x + threadIdx.x; + int b = tid / seqlen; + int s = tid % seqlen; + if (b < batch_size) { + if (s < seqlens_k[b] + 1) { + position_ids[tid] = s; + } else { + position_ids[tid] = 1; + } + } +} + +// Kernel to convert seqlens_k to position_ids +__global__ void SeqlensToPosIdsToken(int32_t* seqlens_k, int64_t* position_ids, const int batch_size) { + int tid = blockDim.x * blockIdx.x + threadIdx.x; + if (tid < batch_size) { + position_ids[tid] = seqlens_k[tid]; + } +} + +// Convert seqlens_k to position_ids +Status LaunchSeqlensToPosIds(contrib::GroupQueryAttentionParameters& parameters, int32_t* seqlens_k, + int64_t* position_ids, cudaStream_t stream, const int max_threads_per_block) { + const int seqlen = parameters.sequence_length; + const int batch_size = parameters.batch_size; + const int threads = max_threads_per_block; + const int blocks = (batch_size * seqlen + threads - 1) / threads; + if (parameters.is_prompt) { + SeqlensToPosIdsPrompt<<>>(seqlens_k, position_ids, seqlen, batch_size); + } else { + 
SeqlensToPosIdsToken<<>>(seqlens_k, position_ids, batch_size); + } + return CUDA_CALL(cudaGetLastError()); +} + ////////// Launch Kernels #if USE_FLASH_ATTENTION @@ -517,7 +599,8 @@ Status FlashAttention( seqlens_k = data.seqlens_k_total; } } else if (!parameters.kv_share_buffer) { // copy past kv to present kv - ORT_RETURN_IF_ERROR(LaunchConcatNewToPastKV(parameters, data, stream, max_threads_per_block, true)); + ORT_RETURN_IF_ERROR(LaunchConcatNewToPastKV(parameters, data, nullptr, nullptr, stream, max_threads_per_block, + true)); } void* present_key = reinterpret_cast(const_cast(data.present_key)); @@ -563,15 +646,62 @@ Status EfficientAttention( const int head_size = parameters.head_size; AttentionQkvFormat past_kv_format = parameters.past_kv_format; - const void* query = reinterpret_cast(data.query); - const void* key = reinterpret_cast(data.key); - const void* value = reinterpret_cast(data.value); + const void* query; + const void* key; + const void* value; + + if (!parameters.is_packed_qkv) { + query = reinterpret_cast(data.query); + key = reinterpret_cast(data.key); + value = reinterpret_cast(data.value); + } else { + size_t q_size = static_cast(batch_size * sequence_length * num_heads * head_size); + size_t k_size = static_cast(batch_size * sequence_length * kv_num_heads * head_size); + auto q = reinterpret_cast(data.unpacked_qkv_buffer); + auto k = reinterpret_cast(data.unpacked_qkv_buffer + q_size); + auto v = reinterpret_cast(data.unpacked_qkv_buffer + q_size + k_size); + ORT_RETURN_IF_ERROR(LaunchUnpackQKV(reinterpret_cast(data.query), q, k, v, num_heads, kv_num_heads, + head_size, sequence_length, batch_size, stream, max_threads_per_block)); + query = reinterpret_cast(q); + key = reinterpret_cast(k); + value = reinterpret_cast(v); + } + + if (parameters.do_rotary) { + size_t q_size = static_cast(batch_size * sequence_length * num_heads * head_size); + size_t k_size = static_cast(batch_size * sequence_length * kv_num_heads * head_size); + auto q_buffer = reinterpret_cast(data.rotary_buffer); + auto k_buffer = q_buffer + q_size; + auto position_ids_buff = reinterpret_cast(k_buffer + k_size); + ORT_RETURN_IF_ERROR(LaunchSeqlensToPosIds(parameters, data.seqlens_k, position_ids_buff, stream, + max_threads_per_block)); + DUMP_TENSOR_INIT(); + DUMP_TENSOR("position_ids", position_ids_buff, batch_size, sequence_length); + // Launch rotary embedding kernel + ORT_RETURN_IF_ERROR(LaunchRotaryEmbeddingKernel(stream, q_buffer, reinterpret_cast(query), + position_ids_buff, data.cos_cache, data.sin_cache, + parameters.batch_size, parameters.sequence_length, + parameters.num_heads, parameters.head_size, + parameters.rotary_dim, parameters.seqlen_present_kv_cache, + /*position_ids_format*/ 1, parameters.rotary_interleaved, + device_prop.maxThreadsPerBlock, /*transposed*/ false)); + ORT_RETURN_IF_ERROR(LaunchRotaryEmbeddingKernel(stream, k_buffer, reinterpret_cast(key), + position_ids_buff, data.cos_cache, data.sin_cache, + parameters.batch_size, parameters.sequence_length, + parameters.kv_num_heads, parameters.head_size, + parameters.rotary_dim, parameters.seqlen_present_kv_cache, + /*position_ids_format*/ 1, parameters.rotary_interleaved, + device_prop.maxThreadsPerBlock, /*transposed*/ false)); + query = reinterpret_cast(q_buffer); + key = reinterpret_cast(k_buffer); + } if (parameters.is_prompt) { // Launch kernel to copy seqlen constexpr int thr_per_blk = 256; int blk_in_grid = (batch_size + thr_per_blk - 1) / thr_per_blk; - repeat_seqlen<<>>(data.seqlens_k_total, 
parameters.sequence_length, batch_size); + repeat_seqlen<<>>(data.seqlens_k_total, parameters.sequence_length, + batch_size); } else { ORT_RETURN_IF_ERROR(LaunchGetSeqlenBuff(parameters, data.seqlens_k, data.seqlens_k_total, true, stream, 256)); } @@ -583,7 +713,7 @@ Status EfficientAttention( "Past and present kv shall share the same tensor when kv_share_buffer is on."); } // Concatenate new kv in place - ORT_RETURN_IF_ERROR(LaunchConcatKVInPlace(parameters, data, stream, max_threads_per_block)); + ORT_RETURN_IF_ERROR(LaunchConcatKVInPlace(parameters, data, key, value, stream, max_threads_per_block)); } else { // Not share buffer case if (data.past_key != nullptr && data.past_key == data.present_key) { @@ -591,7 +721,7 @@ Status EfficientAttention( "Past and present kv share the same tensor but kv_share_buffer is not on."); } // Copy past and concat new KV to present buffer - ORT_RETURN_IF_ERROR(LaunchConcatNewToPastKV(parameters, data, stream, max_threads_per_block)); + ORT_RETURN_IF_ERROR(LaunchConcatNewToPastKV(parameters, data, key, value, stream, max_threads_per_block)); } // Ungroup if grouped, otherwise use present kv directly diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h index 1bf91f9c875eb..32341afa0e3fa 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h @@ -30,6 +30,8 @@ struct GroupQueryAttentionData { int* seqlens_k_total = nullptr; // Memory Efficient buffers T* fmha_buffer = nullptr; + T* unpacked_qkv_buffer = nullptr; + T* rotary_buffer = nullptr; T* k = nullptr; T* v = nullptr; // Output Tensors diff --git a/onnxruntime/test/python/transformers/test_flash_attn.py b/onnxruntime/test/python/transformers/test_flash_attn.py index b784c83329c76..183d6218567a7 100644 --- a/onnxruntime/test/python/transformers/test_flash_attn.py +++ b/onnxruntime/test/python/transformers/test_flash_attn.py @@ -1216,8 +1216,6 @@ def parity_check_gqa_prompt( dtype=torch.float16, requires_grad=False, ) - # print(k.shape) - # print(new_k.shape) window_size = (-1, -1) left_window_size = -1 @@ -1328,10 +1326,6 @@ def parity_check_gqa_prompt( out = torch.reshape(out, (config.batch_size, config.q_sequence_length, config.num_heads, config.head_size)) out = out.detach().cpu().numpy() - # print(cache_seqlens[0]) - # print((present_k - k_cache_ref.detach().cpu().numpy())[0, 0, :, 0]) - # print((out - out_ref)[0, :, 0, 0]) - # Make sure past-present buffer updating correctly assert numpy.allclose(present_k, k_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True) assert numpy.allclose(present_v, v_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True) @@ -1724,9 +1718,6 @@ def parity_check_gqa_past( out = torch.reshape(out, (config.batch_size, config.sequence_length, config.num_heads, config.head_size)) out = out.detach().cpu().numpy() - # print(cache_seqlens[0]) - # print((present_k - k_cache_ref.detach().cpu().numpy())[0, 0, cache_seqlens[0], :]) - # Make sure past-present buffer updating correctly assert numpy.allclose(present_k, k_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True) assert numpy.allclose(present_v, v_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True) @@ -1939,18 +1930,6 @@ def parity_check_gqa_past_no_buff( out = torch.reshape(out, (config.batch_size, config.sequence_length, config.num_heads, config.head_size)) out = 
out.detach().cpu().numpy() - # print(cache_seqlens[0]) - # print((out - out_ref)[0]) - # print((present_k - k_cache_ref.detach().cpu().numpy())[0, 0, :, 0]) - - # Make sure past-present buffer updating correctly - # assert numpy.allclose( - # present_k[:, :, :-1, :], k_cache_ref.detach().cpu().numpy()[:, :, :-1, :], rtol=rtol, atol=atol, equal_nan=True - # ) - # assert numpy.allclose( - # present_v[:, :, :-1, :], v_cache_ref.detach().cpu().numpy()[:, :, :-1, :], rtol=rtol, atol=atol, equal_nan=True - # ) - # Compare results print( "NO buff", @@ -2078,10 +2057,27 @@ def test_gqa_no_past(self): for sq, skv in seqs: for n, n2 in num_h: for h in h_sizes: - for past_kv_format in [Formats.BNSH]: - config = PromptConfig(b, sq, skv, sq + skv + 8, n, n2, h) - parity_check_gqa_prompt(config, past_format=past_kv_format) - parity_check_gqa_prompt_no_buff(config, past_format=past_kv_format) + for rotary, rotary_interleaved in [(True, False), (True, True), (False, False)]: + for packed in [False, True]: + config = PromptConfig(b, sq, skv, sq + skv + 8, n, n2, h) + parity_check_gqa_prompt( + config, + rtol=2e-3, + atol=2e-3, + past_format=Formats.BNSH, + rotary=rotary, + rotary_interleaved=rotary_interleaved, + packed=packed, + ) + parity_check_gqa_prompt_no_buff( + config, + rtol=2e-3, + atol=2e-3, + past_format=Formats.BNSH, + rotary=rotary, + rotary_interleaved=rotary_interleaved, + packed=packed, + ) if major < 8 or platform.system() != "Linux": return print("------- FLASH ATTENTION (PROMPT CASE) --------") @@ -2092,12 +2088,12 @@ def test_gqa_no_past(self): for h in h_sizes: for local in [False, True]: for rotary, rotary_interleaved in [(True, False), (True, True), (False, False)]: - for past_kv_format, packed in [(Formats.BNSH, False), (Formats.BNSH, True)]: + for packed in [False, True]: config = PromptConfig(b, sq, skv, sq + skv + 8, n, n2, h) parity_check_gqa_prompt( config, local=local, - past_format=past_kv_format, + past_format=Formats.BNSH, rotary=rotary, rotary_interleaved=rotary_interleaved, packed=packed, @@ -2105,7 +2101,7 @@ def test_gqa_no_past(self): parity_check_gqa_prompt_no_buff( config, local=local, - past_format=past_kv_format, + past_format=Formats.BNSH, rotary=rotary, rotary_interleaved=rotary_interleaved, packed=packed, @@ -2145,21 +2141,28 @@ def test_gqa_past(self): for s, s2 in seqs: for n, n2 in num_h: for h in h_sizes: - for past_kv_format in [Formats.BNSH]: - sp = random.randint(1, s2 - s) if s2 - s > 0 else 0 - config = Config(b, s, s2, sp, n, n2, h) - parity_check_gqa_past( - config, - past_format=past_kv_format, - rtol=1e-3, - atol=1e-3, - ) - parity_check_gqa_past_no_buff( - config, - past_format=past_kv_format, - rtol=1e-3, - atol=1e-3, - ) + for rotary, rotary_interleaved in [(True, False), (True, True), (False, False)]: + for packed in [False, True]: + sp = random.randint(1, s2 - s) if s2 - s > 0 else 0 + config = Config(b, s, s2, sp, n, n2, h) + parity_check_gqa_past( + config, + past_format=Formats.BNSH, + rtol=1e-3, + atol=1e-3, + rotary=rotary, + rotary_interleaved=rotary_interleaved, + packed=packed, + ) + parity_check_gqa_past_no_buff( + config, + past_format=Formats.BNSH, + rtol=1e-3, + atol=1e-3, + rotary=rotary, + rotary_interleaved=rotary_interleaved, + packed=packed, + ) if major < 8 or platform.system() != "Linux": return print("------- FLASH ATTENTION (TOKEN GEN) -------") @@ -2170,13 +2173,13 @@ def test_gqa_past(self): for h in h_sizes: for local in [False, True]: for rotary, rotary_interleaved in [(True, False), (True, True), (False, False)]: - for 
past_kv_format, packed in [(Formats.BNSH, False), (Formats.BNSH, True)]: + for packed in [False, True]: sp = random.randint(1, s2 - s) if s2 - s > 0 else 0 config = Config(b, s, s2, sp, n, n2, h) parity_check_gqa_past( config, local=local, - past_format=past_kv_format, + past_format=Formats.BNSH, rtol=1e-3, atol=1e-3, rotary=rotary, @@ -2186,7 +2189,7 @@ def test_gqa_past(self): parity_check_gqa_past_no_buff( config, local=local, - past_format=past_kv_format, + past_format=Formats.BNSH, rtol=1e-3, atol=1e-3, rotary=rotary, From d30c81d270894f41ccce7b102b1d4aedd9e628b1 Mon Sep 17 00:00:00 2001 From: Vincent Wang Date: Mon, 25 Mar 2024 15:05:02 +0800 Subject: [PATCH 07/11] Add Symbolic Shape Hint to Triton Codegen Config (#20056) Add symbolic shape hint to Triton codegen config so that we can avoid unnecessary recompile when input shapes are keeping changing. Below screenshot shows that with proper configuration, we can speed up the training a lot by reducing unnecessary recompile. ![image](https://github.com/microsoft/onnxruntime/assets/11661208/699944d2-81cd-4c22-84e7-73a4fa0d2a28) --- .../python/training/ort_triton/_cache.py | 2 + .../training/ort_triton/triton_op_executor.py | 57 +++++++++++++++---- 2 files changed, 49 insertions(+), 10 deletions(-) diff --git a/orttraining/orttraining/python/training/ort_triton/_cache.py b/orttraining/orttraining/python/training/ort_triton/_cache.py index ede9cd86a9da5..b70064377abfc 100644 --- a/orttraining/orttraining/python/training/ort_triton/_cache.py +++ b/orttraining/orttraining/python/training/ort_triton/_cache.py @@ -9,6 +9,7 @@ import getpass import hashlib import os +import sys import tempfile from types import ModuleType from typing import Tuple @@ -61,6 +62,7 @@ def load(cls, source_code) -> ModuleType: mod.__file__ = path mod.key = key exec(code, mod.__dict__, mod.__dict__) + sys.modules[mod.__name__] = mod # another thread might set this first cls.cache.setdefault(key, mod) return cls.cache[key] diff --git a/orttraining/orttraining/python/training/ort_triton/triton_op_executor.py b/orttraining/orttraining/python/training/ort_triton/triton_op_executor.py index e104ea13c59a3..14bc2779aa05b 100644 --- a/orttraining/orttraining/python/training/ort_triton/triton_op_executor.py +++ b/orttraining/orttraining/python/training/ort_triton/triton_op_executor.py @@ -6,11 +6,13 @@ import functools import json import os +import re import sys from types import ModuleType from typing import List, Tuple, Union import onnx +from onnx import ModelProto from torch._C import _from_dlpack from torch.utils.dlpack import to_dlpack @@ -41,18 +43,39 @@ class _ShapeCache: """ cache = dict() # noqa: RUF012 + symbolic_shape_hint = None + min_symbolic_shape = 0 clear = staticmethod(cache.clear) @classmethod - def get_shape(cls, onnx_key: int, shapes: List[List[int]]) -> List[List[Union[int, str]]]: + def set_symbolic_shape_hint(cls, symbolic_shape_hint_config): + for k, v in symbolic_shape_hint_config.items(): + if k == "*": + cls.min_symbolic_shape = v + else: + if cls.symbolic_shape_hint is None: + cls.symbolic_shape_hint = dict() + cls.symbolic_shape_hint[k] = v + + @classmethod + def get_shape(cls, onnx_key: int, model: ModelProto, shapes: List[List[int]]) -> List[List[Union[int, str]]]: if onnx_key not in cls.cache: + if cls.symbolic_shape_hint is not None: + for i, input in enumerate(model.graph.input): + if input.type.tensor_type.HasField("shape"): + for j, dim in enumerate(input.type.tensor_type.shape.dim): + if dim.dim_param: + for k, v in 
cls.symbolic_shape_hint.items(): + if re.fullmatch(k, dim.dim_param): + shapes[i][j] = f"i{i}_dim{j}_{v}" + break cls.cache[onnx_key] = shapes else: changed = False for i, shape in enumerate(shapes): for j, dim in enumerate(shape): - if dim != cls.cache[onnx_key][i][j] and isinstance(cls.cache[onnx_key][i][j], int): - max_dim = max(dim, cls.cache[onnx_key][i][j]) + if isinstance(cls.cache[onnx_key][i][j], int) and dim != cls.cache[onnx_key][i][j]: + max_dim = max(dim, cls.cache[onnx_key][i][j], cls.min_symbolic_shape) shape[j] = f"i{i}_dim{j}_{next_power_of_2(max_dim)}" changed = True elif isinstance(cls.cache[onnx_key][i][j], str): @@ -67,13 +90,12 @@ def get_shape(cls, onnx_key: int, shapes: List[List[int]]) -> List[List[Union[in return cls.cache[onnx_key] -def _gen_key(onnx_key: int, onnx_str: bytes, shapes: List[List[Union[int, str]]]) -> int: +def _gen_key(onnx_key: int, model: ModelProto, shapes: List[List[Union[int, str]]]) -> int: # pylint: disable=unused-argument return hash(f"{onnx_key}|{str(shapes).replace(' ', '')}") -def _gen_module(onnx_key: int, onnx_str: bytes, shapes: List[List[Union[int, str]]]) -> Tuple[str, ModuleType]: - model = onnx.load_model_from_string(onnx_str) +def _gen_module(onnx_key: int, model: ModelProto, shapes: List[List[Union[int, str]]]) -> Tuple[str, ModuleType]: sorted_graph = SortedGraph(model, [parse_shape(shape) for shape in shapes]) if _DEBUG_MODE: os.makedirs(os.path.dirname("triton_debug/"), exist_ok=True) @@ -96,14 +118,28 @@ def get_config() -> str: "scalar": only related scalar initializers will be added to subgraphs. "all": all related initializers will be added to subgraphs. The min_nodes is used to control the minimum number of non-no-op nodes in a subgraph. + User can also specify symbolic_shape_hint in the config, which is a dict to control the symbolic shape hint. + Each entry is a regex pattern to match the dim_param in ONNX model and the value is the power of 2 for the symbolic + shape. Each dim_param will be replaced by i{input_index}_dim{dim_index}_{power_of_2} in the symbolic shape. 
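For concreteness, a minimal sketch of a Triton config file that exercises the `symbolic_shape_hint` entry described above. The file name, the `dim_param` patterns, and the hint values are illustrative assumptions; only the config keys and the `ORTMODULE_TRITON_CONFIG_FILE` environment variable come from this patch.

```python
# Illustrative sketch: the dim_param patterns ("batch", "seq.*") and values are assumptions.
import json
import os

triton_config = {
    "initializer": "scalar",
    "min_nodes": 2,
    "symbolic_shape_hint": {
        "batch": 16,    # dims whose dim_param fully matches "batch" get this hint value
        "seq.*": 1024,  # regex: any dim_param starting with "seq"
        "*": 32,        # floor applied when a changing concrete dim is promoted to symbolic
    },
}

with open("triton_config.json", "w", encoding="UTF-8") as f:
    json.dump(triton_config, f)

# get_config() picks the file up through this environment variable.
os.environ["ORTMODULE_TRITON_CONFIG_FILE"] = os.path.abspath("triton_config.json")
```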
""" + config = dict() config_file = os.getenv("ORTMODULE_TRITON_CONFIG_FILE", "") if config_file and os.path.exists(config_file): with open(config_file, encoding="UTF-8") as f: - return f.read() + config = json.load(f) + + if "ops" not in config: + config["ops"] = get_supported_ops() + if "initializer" not in config: + config["initializer"] = "scalar" + if "min_nodes" not in config: + config["min_nodes"] = 2 + + if "symbolic_shape_hint" in config and len(config["symbolic_shape_hint"]) > 0: + _ShapeCache.set_symbolic_shape_hint(config["symbolic_shape_hint"]) + del config["symbolic_shape_hint"] - config = {"ops": get_supported_ops(), "initializer": "scalar", "min_nodes": 2} return json.dumps(config) @@ -136,8 +172,9 @@ def call_triton_by_onnx(onnx_key: int, onnx_str: bytes, *tensors): assert all(tensor is not None for tensor in tensors) torch_tensors = [_from_dlpack(tensor) for tensor in tensors] concrete_shapes = [list(tensor.size()) for tensor in torch_tensors] - shapes = _ShapeCache.get_shape(onnx_key, concrete_shapes) - func_name, mod = ModuleCache.load(_gen_key, _gen_module, onnx_key, onnx_str, shapes) + model = onnx.load_model_from_string(onnx_str) + shapes = _ShapeCache.get_shape(onnx_key, model, concrete_shapes) + func_name, mod = ModuleCache.load(_gen_key, _gen_module, onnx_key, model, shapes) func = getattr(mod, func_name) output = func(*torch_tensors) if isinstance(output, tuple): From 7d976cf72098639b0c8427629bef797861497ea9 Mon Sep 17 00:00:00 2001 From: Adrian Lizarraga Date: Mon, 25 Mar 2024 14:41:14 -0700 Subject: [PATCH 08/11] [QNN QDQ Quant] Utils to generate mixed-precision quant overrides (#20028) ### Description - Adds a utility to the QNN quantization scripts that "fixes" an initial set of tensor quantization overrides for mixed-precision QDQ models. Follow-up to https://github.com/microsoft/onnxruntime/pull/19925 - Moves existing overrides for QNN compatibility (matmul, layernorm, sigmoid, tanh) to separate functions. PR adds missing unit tests for these. - Adds `weight_symmetric=None` parameter to the `get_qnn_qdq_config()` function to enable user specification (instead of always using default behavior). - If weight_symmetric is set to `None`, it will be set to `weight_symmetric = weight_type in (QUInt8, QUInt16)`. - Otherwise, the user's value is used. #### Example Float model: ``` input_0 --> Op1 --> Op3 --> Op5 --> Op6 --> output_0 ^ | input_1 --> Op2 -+-> Op4 ----+ | +-> Op7 --> output_1 | +-> Op8 --> output_2 ``` If we'd like to quantize this model to uint8 precision, but would like to make sure tensor "Op4_out" is quantized to 16-bit, then we would specify the following initial tensor quantization overrides: ```python # Op4_out could be an inaccurate tensor that should be upgraded to 16bit initial_overrides = {"Op4_out": [{"quant_type": QuantType.QUInt16}]} ``` These initial overrides may not create a valid model because Op4 and Op5 may require both the input and output to be the same type (e.g., uint16). 
This helper fixes the overrides so that input/output data types are valid: ```python qnn_config = get_qnn_qdq_config( float_model_path, data_reader, activation_type=QuantType.QUInt8, weight_type=QuantType.QUInt8, init_overrides=initial_overrides, # These initial overrides will be "fixed" ) ``` The above snippet generates the following "fixed" overrides (get via `qnn_config.extra_options["TensorQuantOverrides"]`): ```python { "Op2_out": [{"quant_type": QUInt8, "convert": {"quant_type": QUInt16, "recv_nodes": {"Op4"}}}], "Op3_out": [{"quant_type": QUInt8, "convert": {"quant_type": QUInt16, "recv_nodes": {"Op5"}}}], "Op4_out": [{"quant_type": QUInt16}], "Op5_out": [{"quant_type": QUInt16, "convert": {"quant_type": QUInt8, "recv_nodes": {"Op6"}}}] } ``` How to interpret the fixed overrides: - Op2's output is consumed by Op4, Op7, and Op8. Op4 consumes the converted u16 type, but Op7 and Op8 consume the original u8 type. - Op3's output is converted from u8 to u16. Op5 consumes the converted u16 type. - Op4's output is just u16 (not converted). All consumers of Op4_out get the u16 type. - Op5's output is converted from u16 to u8. Op6 consumes the u8 type. ### Motivation and Context Generating mixed-precision quantization overrides is currently a manual process. This PR adds an utility that helps generate valid overrides. --- .../qnn/mixed_precision_overrides_utils.py | 413 +++++++++++++++ .../execution_providers/qnn/quant_config.py | 248 +++++++-- .../quantization/tensor_quant_overrides.py | 131 +++++ .../test_mixed_prec_quant_overrides_fixer.py | 171 +++++++ .../test_tensor_quant_overrides_option.py | 470 +++++++++++++++++- 5 files changed, 1369 insertions(+), 64 deletions(-) create mode 100644 onnxruntime/python/tools/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py create mode 100644 onnxruntime/test/python/quantization/test_mixed_prec_quant_overrides_fixer.py diff --git a/onnxruntime/python/tools/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py b/onnxruntime/python/tools/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py new file mode 100644 index 0000000000000..d59a0ec74ca7c --- /dev/null +++ b/onnxruntime/python/tools/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py @@ -0,0 +1,413 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +from __future__ import annotations + +import logging +from dataclasses import dataclass + +import onnx + +from ...quant_utils import QuantType +from ...tensor_quant_overrides import QuantTypeInfo, TensorQuantOverridesHelper + + +@dataclass +class TensorTypeRequest: + """ + Bundles desired quantization type requests for a tensor. A distinction is made between the + produced type and the consumed type. + """ + + # The tensor's quant type at the producer end. If None, assumed to be the default activation quant type. + producer: QuantTypeInfo | None + + # The tensor's quant type received by a set of consumer nodes. + # If None, assumed to be the default activation quant type for all consumers. + # consumers[1] is a set of consumer node names. 
+ consumers: tuple[QuantTypeInfo, set[str]] | None + + +class MixedPrecisionTensorQuantOverridesFixer: + """ + Helper that generates tensor quantization overrides for mixed-precision QDQ models. + + Specifically, this helper fixes an initial set of quantization overrides that assign a non-default + activation quantization type to one or more tensors by doing the following: + - Inferring which other tensors need to be overridden to the non-default activation quantization type. + - Inserting quantization data type conversions. + + Example: + -------- + + Float model: + + input_0 --> Op1 --> Op3 --> Op5 --> Op6 --> output_0 + ^ + | + input_1 --> Op2 -+-> Op4 ----+ + | + +-> Op7 --> output_1 + | + +-> Op8 --> output_2 + + If we'd like to quantize this model to uint8 precision, but would like to make sure tensor "Op4_out" + is quantized to 16-bit, then we would specify the following initial tensor quantization overrides: + + ``` + init_overrides = {"Op4_out": [{"quant_type": QuantType.QUInt16}]} + ``` + + These initial overrides may not create a valid model because Op4 and Op5 may require both the input and output + to be the same type (e.g., uint16). This helper fixes the overrides so that input/output data types + are valid: + + ``` + overrides = TensorQuantOverridesHelper(init_overrides) + + fixer = MixedPrecisionTensorQuantOverridesFixer.create_from_model(overrides, model, QuantType.QUInt8) + fixer.apply( + default_activation_qtype=QuantType.QUInt8, + default_activation_symmetric=False, + ) + ``` + + The above snippet generates the following "fixed" overrides (get via overrides.get_dict()): + + { + "Op2_out": [{"quant_type": QUInt8, "convert": {"quant_type": QUInt16, "recv_nodes": {"Op4"}}}], + "Op3_out": [{"quant_type": QUInt8, "convert": {"quant_type": QUInt16, "recv_nodes": {"Op5"}}}], + "Op4_out": [{"quant_type": QUInt16}], + "Op5_out": [{"quant_type": QUInt16, "convert": {"quant_type": QUInt8, "recv_nodes": {"Op6"}}}] + } + + How to interpret the fixed overrides: + - Op2's output is consumed by Op4, Op7, and Op8. Op4 consumes the converted u16 type, + but Op7 and Op8 consume the original u8 type. + - Op3's output is converted from u8 to u16. Op5 consumes the converted u16 type. + - Op4's output is just u16 (not converted). All consumers of Op4_out get the u16 type. + - Op5's output is converted from u16 to u8. Op6 consumes the u8 type. + """ + + def __init__( + self, + overrides: TensorQuantOverridesHelper, + producers: dict[str, onnx.NodeProto], + consumers: dict[str, list[onnx.NodeProto]], + value_infos: dict[str, onnx.ValueInfoProto], + initializers: dict[str, onnx.TensorProto], + ): + """ + Params: + overrides: The initial tensor quantization overrides to fix. + producers: Dictionary that maps a tensor name to the producer node that generates the tensor. + consumers: Dictionary that maps a tensor name to the consumer nodes that take the tensor as input. + value_infos: Dictionary that maps a tensor name to its onnx.ValueInfoProto. + initializers: Dictionary that maps an initializer name to its onnx.TensorProto. + """ + self.overrides = overrides + self.consumers = consumers + self.producers = producers + self.value_infos = value_infos + self.initializers = initializers + + @staticmethod + def create_from_model( + overrides: TensorQuantOverridesHelper, model: onnx.ModelProto, default_activation_qtype: QuantType + ) -> MixedPrecisionTensorQuantOverridesFixer: + """ + Helper function that creates an instance of this class from a loaded ONNX model. 
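An editorial aside on the example in the class docstring: before emitting the fixed overrides, the fixer first collects per-tensor `TensorTypeRequest` objects (see `get_desired_tensor_types` further below). The sketch shows what those intermediate requests could look like for the documented `Op4_out` override; it is derived by reading the code, not captured from a run, and assumes the patch is installed under the module paths used by its unit tests.

```python
from onnxruntime.quantization import QuantType
from onnxruntime.quantization.execution_providers.qnn.mixed_precision_overrides_utils import (
    TensorTypeRequest,
)
from onnxruntime.quantization.tensor_quant_overrides import QuantTypeInfo

u16 = QuantTypeInfo(QuantType.QUInt16)

# Intermediate requests derived from init_overrides = {"Op4_out": [{"quant_type": QUInt16}]}:
type_requests = {
    "Op2_out": TensorTypeRequest(producer=None, consumers=(u16, {"Op4"})),  # Op4 wants a u16 input
    "Op3_out": TensorTypeRequest(producer=None, consumers=(u16, {"Op5"})),  # Op5 wants a u16 input
    "Op4_out": TensorTypeRequest(producer=u16, consumers=(u16, {"Op5"})),   # produced and consumed as u16
    "Op5_out": TensorTypeRequest(producer=u16, consumers=None),             # produced as u16; consumers keep the default
}
```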
+ + Params: + overrides: The initial tensor quantization overrides to fix. + model: Loaded ONNX model + default_activation_qtype: The intended default activation quantization type. + Used to validate the initial overrides. + + Returns: + Initialized MixedPrecisionTensorQuantOverridesFixer object + """ + model = onnx.shape_inference.infer_shapes(model) # Need to infer shapes to get value_infos + + # Build dictionaries that enable convenient lookups of initializers and value_infos by name. + initializers = {initializer.name: initializer for initializer in model.graph.initializer} + value_infos = {vi.name: vi for vi in model.graph.value_info} + value_infos.update({ot.name: ot for ot in model.graph.output}) + value_infos.update({it.name: it for it in model.graph.input}) + + # Ensure that the user-provided initial overrides are actually valid. + valid, err = overrides.is_valid(set(initializers), set(value_infos), default_activation_qtype) + if not valid: + pprint_overrides = overrides.pprint_str(indent=4) + logging.error(f"Provided invalid tensor quantization overrides:\n{pprint_overrides}") + raise ValueError(err) + + consumers = {} + producers = {} + + # Build dictionaries that map a tensor name to the consumer or producer nodes. + for node in model.graph.node: + for input_name in node.input: + if input_name: + if input_name not in consumers: + consumers[input_name] = [] + + consumers[input_name].append(node) + + for output_name in node.output: + producers[output_name] = node + + return MixedPrecisionTensorQuantOverridesFixer(overrides, producers, consumers, value_infos, initializers) + + def apply( + self, + default_activation_qtype: QuantType, + default_activation_symmetric: bool, + ): + """ + Fixes the initial tensor quantization overrides (in-place) for use in mixed-precision QDQ models. + + Params: + default_activation_qtype: The intended default activation quantization type. + default_activation_symmetric: The intended default symmetry used to quantize activations. + """ + type_requests = self.get_desired_tensor_types(default_activation_qtype, default_activation_symmetric) + + # Use type requests to "fix" tensor quantization overrides by adding + # quantization type conversions where necessary. + for tensor_name, type_req in type_requests.items(): + all_consumers = set([node.name for node in self.consumers.get(tensor_name, [])]) + has_producer_req = type_req.producer is not None + has_consumer_req = bool(type_req.consumers) + + # Only producer type: Add conversion back to default activation type + if has_producer_req and not has_consumer_req: + self._update_converted_tensor( + tensor_name, type_req.producer, QuantTypeInfo(default_activation_qtype), all_consumers + ) + # Only consumers + elif not has_producer_req and has_consumer_req: + prod_type_info = self.overrides.get_node_output_qtype_info(tensor_name, default_activation_qtype) + consumer_type_info = type_req.consumers[0] + + if prod_type_info != consumer_type_info: + self._update_converted_tensor( + tensor_name, prod_type_info, consumer_type_info, type_req.consumers[1] + ) + else: + if not self._check_nodes_are_not_convert_consumers(tensor_name, type_req.consumers[1]): + raise ValueError( + f"Tensor override for '{tensor_name}' converts the type for consumers that need the original type." 
+ ) + # Both producer and consumers + elif has_producer_req and has_consumer_req: + prod_type_info = type_req.producer + consumer_type_info = type_req.consumers[0] + + if prod_type_info != consumer_type_info: + self._update_converted_tensor( + tensor_name, prod_type_info, consumer_type_info, type_req.consumers[1] + ) + else: + consumers_for_original_type = all_consumers.difference(type_req.consumers[1]) + + if len(consumers_for_original_type) == 0: + # All consumers want the overridden type, so no need for convert nodes! + # Just add the override to the new new if not already present. + if tensor_name not in self.overrides: + self.overrides[tensor_name] = [{}] + prod_type_info.save_to_dict(self.overrides[tensor_name][0]) + + assert "convert" not in self.overrides[tensor_name][0] + else: + # Some consumers don't want the overridden type. + self._update_converted_tensor( + tensor_name, + prod_type_info, + QuantTypeInfo(default_activation_qtype), + consumers_for_original_type, + ) + else: + raise ValueError(f"TypeRequest for tensor {tensor_name} has no producer or consumers.") + + # Done. Check if the overrides are valid. + valid, err = self.overrides.is_valid(set(self.initializers), set(self.value_infos), default_activation_qtype) + if not valid: + pprint_overrides = self.overrides.pprint_str(indent=4) + logging.error( + f"Generated invalid tensor quantization overrides for mixed-precision QDQ model:\n{pprint_overrides}" + ) + raise ValueError(err) + + def get_desired_tensor_types( + self, + default_activation_qtype: QuantType, + default_activation_symmetric: bool, + ) -> dict[str, TensorTypeRequest]: + """ + Iterates through the initial tensor quantization overrides and builds a set of TensorTypeRequests objects + that describe the quantization types required at each tensor. These TensorTypeRequests objects are ultimately + used to generated the "fixed" overrides. + + Params: + default_activation_qtype: The intended default activation quantization type. + default_activation_symmetric: The intended default symmetry used to quantize activations. + + Returns: + TensorTypeRequest objects as a dict that maps a tensor name to its requested types. + """ + type_requests = {} + default_activation_type_info = QuantTypeInfo(default_activation_qtype, default_activation_symmetric) + + # Scan tensor overrides for type conversion requests. + for tensor_name, override_list in self.overrides.items(): + if not self.__is_tensor_quantizable(tensor_name): + continue # Skip non-quantizable tensors (e.g., not a float) + + if tensor_name in self.initializers: + continue # Skip initializers + + if not override_list or len(override_list) > 1: + continue # Skip per-channel stuff + + override_dict = override_list[0] + quant_type_info = QuantTypeInfo.load_from_dict(override_dict, default_activation_type_info.quant_type) + producer_node = self.producers.get(tensor_name) # None if this is a model input + + if quant_type_info != default_activation_type_info and "convert" not in override_dict: + if producer_node is not None: + self._add_type_requests_for_node(type_requests, quant_type_info, producer_node) + + # Find all consumer nodes of `tensor_name` and update their inputs/outputs to the new type. 
+ for consumer_node in self.consumers.get(tensor_name, []): + self._add_type_requests_for_node(type_requests, quant_type_info, consumer_node) + + return type_requests + + def _add_type_requests_for_node( + self, + type_requests: dict[str, TensorTypeRequest], + quant_type_info: QuantTypeInfo, + node: onnx.NodeProto, + ): + """ + Adds TensorTypeRequest objects for a given node, assuming that we want all its inputs and outputs + to have the same quantization type (as specified by the `quant_type_info` parameter). + + Params: + type_requests: Dictionary of type requests to append to for this node. + quant_type_info: The quantization type to use for inputs and outputs. + node: The node for which the TensorTypeRequest objects are created and added to type_requests. + """ + # Add output side + for output_name in node.output: + if not self.__is_tensor_quantizable(output_name): + continue + + if output_name not in type_requests: + type_requests[output_name] = TensorTypeRequest(quant_type_info, None) + else: + if ( + type_requests[output_name].producer is not None + and type_requests[output_name].producer != quant_type_info + ): + raise ValueError(f"Tensor {output_name} has multiple types.") + + type_requests[output_name].producer = quant_type_info + + # Add the consumer side + for input_name in node.input: + if input_name and input_name not in self.initializers and self.__is_tensor_quantizable(input_name): + if input_name not in type_requests: + type_requests[input_name] = TensorTypeRequest(None, None) + + if type_requests[input_name].consumers is None: + type_requests[input_name].consumers = (quant_type_info, set()) + + if type_requests[input_name].consumers[0] != quant_type_info: + raise ValueError(f"Tensor {input_name} has consumers requesting different types.") + + if not node.name: + raise ValueError( + f"Node of type {node.op_type} with output 0 {node.output[0]} does not have a name!" + ) + + type_requests[input_name].consumers[1].add(node.name) + + def _update_converted_tensor( + self, + tensor_name: str, + producer_type_info: QuantTypeInfo, + consumer_type_info: QuantTypeInfo, + consumer_names: set[str], + ): + """ + Updates the tensor quantization overrides for a tensor that is converted from one type to another. + + Params: + tensor_name: The name of the tensor for which to update overrides. + producer_type_info: Info for the tensor's produced type. + consumer_type_info: Info for the tensor's consumed (i.e., converted) type. + consumer_names: Nodes names of consumers that consume the converted type. 
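For reference, the per-tensor override entry that this method maintains has the shape used throughout the patch: the producer's type at the top level, and the converted type plus the receiving node names under `convert`. A hedged sketch of one such entry, mirroring the examples in the PR description and tests:

```python
from onnxruntime.quantization import QuantType

# One override entry for a tensor whose producer emits uint8 but whose
# consumer "Op4" receives a converted uint16 value.
override_entry = {
    "quant_type": QuantType.QUInt8,        # type at the producer end
    "convert": {
        "quant_type": QuantType.QUInt16,   # type after conversion
        "recv_nodes": {"Op4"},             # consumers that receive the converted type
    },
}
tensor_overrides = {"Op2_out": [override_entry]}  # one dict per channel; a single dict for per-tensor
```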
+ """ + if tensor_name not in self.overrides or not self.overrides[tensor_name]: + self.overrides[tensor_name] = [{}] + producer_type_info.save_to_dict(self.overrides[tensor_name][0]) + + overrides = self.overrides[tensor_name][0] + if producer_type_info != QuantTypeInfo.load_from_dict(overrides): + raise ValueError(f"Desired producer quant_type for {tensor_name} doesn't match existing type.") + + if consumer_names: + if "convert" not in overrides: + overrides["convert"] = {} + consumer_type_info.save_to_dict(overrides["convert"]) + + convert_dict = overrides["convert"] + if consumer_type_info != QuantTypeInfo.load_from_dict(convert_dict): + raise ValueError(f"Desired consumer quant_type for {tensor_name} doesn't match existing type.") + + if "recv_nodes" not in convert_dict: + convert_dict["recv_nodes"] = set() + + convert_dict["recv_nodes"].update(consumer_names) + + def _check_nodes_are_not_convert_consumers(self, tensor_name: str, node_names: set[str]): + """ + Returns true if the given nodes do not consume/receive a converted quantization type. + + Params: + tensor_name: The name of the tensor to check. + node_names: Set of node names that should not be consumers of the converted type. + """ + if tensor_name not in self.overrides or not self.overrides[tensor_name]: + return True + + overrides = self.overrides[tensor_name][0] + + if "convert" not in overrides: + return True + + convert_dict = overrides["convert"] + + if "recv_nodes" not in convert_dict: + return False + + return not convert_dict["recv_nodes"].intersection(node_names) + + def __is_tensor_quantizable(self, tensor_name): + weight = self.initializers.get(tensor_name) + if weight is not None: + if weight.data_type in (onnx.TensorProto.FLOAT, onnx.TensorProto.FLOAT16): + return True + elif tensor_name in self.value_infos: + vi = self.value_infos[tensor_name] + if vi.type.HasField("tensor_type") and vi.type.tensor_type.elem_type in ( + onnx.TensorProto.FLOAT, + onnx.TensorProto.FLOAT16, + ): + return True + + return False diff --git a/onnxruntime/python/tools/quantization/execution_providers/qnn/quant_config.py b/onnxruntime/python/tools/quantization/execution_providers/qnn/quant_config.py index e9affae7ac263..479eaf5b0c542 100644 --- a/onnxruntime/python/tools/quantization/execution_providers/qnn/quant_config.py +++ b/onnxruntime/python/tools/quantization/execution_providers/qnn/quant_config.py @@ -3,6 +3,10 @@ # Licensed under the MIT License. See License.txt in the project root for # license information. # -------------------------------------------------------------------------- +from __future__ import annotations + +import copy +import logging from pathlib import Path import numpy as np @@ -11,6 +15,8 @@ from ...calibrate import CalibrationDataReader, CalibrationMethod from ...quant_utils import QuantType from ...quantize import StaticQuantConfig +from ...tensor_quant_overrides import TensorQuantOverridesHelper +from .mixed_precision_overrides_utils import MixedPrecisionTensorQuantOverridesFixer Q16_TYPES = {QuantType.QInt16, QuantType.QUInt16} Q8_TYPES = {QuantType.QInt8, QuantType.QUInt8} @@ -18,6 +24,20 @@ MODEL_SIZE_THRESHOLD = 2147483648 # Quant model should use external data if >= 2GB +def warn_unable_to_override( + node: onnx.NodeProto, + what_str: str, + tensor_name: str, + io_kind: str, +): + logging.warning( + f"Unable to override {what_str} for {node.op_type} node's {io_kind} " + "because it has already been overridden! 
Check the initial quantization overrides provided " + "to get_qnn_qdq_config() if the generated QDQ model does not run on QNN EP. " + f"Node name: {node.name}, {io_kind} name: {tensor_name}" + ) + + def get_qnn_qdq_config( model_input: Path, calibration_data_reader: CalibrationDataReader, @@ -25,14 +45,20 @@ def get_qnn_qdq_config( activation_type=QuantType.QUInt8, weight_type=QuantType.QUInt8, per_channel=False, + init_overrides=None, + add_qtype_converts=True, + activation_symmetric=False, + weight_symmetric=None, ): if per_channel: raise ValueError("QNN EP does not yet support per-channel quantization.") + if weight_symmetric is None: + weight_symmetric = weight_type in {QuantType.QInt8, QuantType.QInt16} + model = onnx.load_model(model_input, load_external_data=False) op_types = set() - tensor_quant_overrides = {} model_has_external_data = False name_to_initializer = {} @@ -43,52 +69,40 @@ def get_qnn_qdq_config( if onnx.external_data_helper.uses_external_data(initializer): model_has_external_data = True - # Setup quantization overrides for specific operator types - for node in model.graph.node: - op_types.add(node.op_type) + overrides_helper = TensorQuantOverridesHelper(copy.deepcopy(init_overrides) if init_overrides else {}) - if node.op_type == "MatMul" and activation_type in Q16_TYPES and weight_type in Q8_TYPES: - weight_symmetric = weight_type == QuantType.QInt8 + if not overrides_helper.empty() and add_qtype_converts: + # Fix mixed-precision overrides. + overrides_fixer = MixedPrecisionTensorQuantOverridesFixer.create_from_model( + overrides_helper, model, activation_type + ) + overrides_fixer.apply(activation_type, activation_symmetric) - # Override initializers to use the weight_type - for input_name in node.input: - if input_name in name_to_initializer: - tensor_quant_overrides[input_name] = [{"quant_type": weight_type, "symmetric": weight_symmetric}] - elif node.op_type == "LayerNormalization" and activation_type in Q16_TYPES and weight_type in Q8_TYPES: - weight_symmetric = weight_type == QuantType.QInt8 + # Setup quantization overrides for specific operator types to ensure compatibility with QNN EP. + qnn_compat = QnnCompatibilityOverrides( + activation_type, + weight_type, + activation_symmetric, + weight_symmetric, + overrides_helper, + name_to_initializer, + ) - # Override initializers to use the weight_type. Don't override the bias input. 
- for i in range(2): - input_name = node.input[i] - if input_name in name_to_initializer: - tensor_quant_overrides[input_name] = [{"quant_type": weight_type, "symmetric": weight_symmetric}] - elif node.op_type == "Sigmoid": - if activation_type == QuantType.QUInt16: - tensor_quant_overrides[node.output[0]] = [ - {"scale": np.array(1.0 / 65536.0, dtype=np.float32), "zero_point": np.array(0, dtype=np.uint16)} - ] - elif activation_type == QuantType.QInt16: - tensor_quant_overrides[node.output[0]] = [ - {"scale": np.array(1.0 / 32768.0, dtype=np.float32), "zero_point": np.array(0, dtype=np.int16)} - ] - elif node.op_type == "Tanh": - if activation_type == QuantType.QUInt16: - tensor_quant_overrides[node.output[0]] = [ - {"scale": np.array(1.0 / 32768.0, dtype=np.float32), "zero_point": np.array(32768, dtype=np.uint16)} - ] - elif activation_type == QuantType.QInt16: - tensor_quant_overrides[node.output[0]] = [ - {"scale": np.array(1.0 / 32768.0, dtype=np.float32), "zero_point": np.array(0, dtype=np.int16)} - ] + for node in model.graph.node: + op_types.add(node.op_type) + qnn_compat.process_node(node) extra_options = { "MinimumRealRange": 0.0001, "DedicatedQDQPair": False, # Let ORT optimizer duplicate DQ nodes - "TensorQuantOverrides": tensor_quant_overrides, + "TensorQuantOverrides": overrides_helper.get_dict(), + "ActivationSymmetric": activation_symmetric, + "WeightSymmetric": weight_symmetric, } # TODO: Remove this extra option once ORT uses an ONNX version that supports 16-bit Q/DQ ops. - if activation_type in Q16_TYPES or weight_type in Q16_TYPES: + overrides_have_int16 = any(t in Q16_TYPES for t in overrides_helper.get_quant_types()) + if activation_type in Q16_TYPES or weight_type in Q16_TYPES or overrides_have_int16: extra_options["UseQDQContribOps"] = True return StaticQuantConfig( @@ -100,3 +114,163 @@ def get_qnn_qdq_config( use_external_data_format=(model_has_external_data or model.ByteSize() >= MODEL_SIZE_THRESHOLD), extra_options=extra_options, ) + + +class QnnCompatibilityOverrides: + """ + Helper that processes nodes to generate quantization overrides that make the resulting QDQ model + compatible with QNN EP. 
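Putting the new `get_qnn_qdq_config` knobs from this hunk together, a usage sketch follows. The model path, the no-op data reader, and the override values are placeholders, not part of the patch.

```python
from onnxruntime.quantization import CalibrationDataReader, QuantType
from onnxruntime.quantization.execution_providers.qnn import get_qnn_qdq_config


class _NullReader(CalibrationDataReader):  # placeholder reader; real calibration data goes here
    def get_next(self):
        return None


qnn_config = get_qnn_qdq_config(
    "model.onnx",                   # placeholder path to a float ONNX model
    _NullReader(),
    activation_type=QuantType.QUInt8,
    weight_type=QuantType.QUInt8,
    init_overrides={"Op4_out": [{"quant_type": QuantType.QUInt16}]},  # fixed up automatically
    add_qtype_converts=True,        # insert data type conversions where needed
    activation_symmetric=False,
    weight_symmetric=None,          # None -> True only for QInt8/QInt16 weight types
)
print(qnn_config.extra_options["TensorQuantOverrides"])
```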
+ """ + + def __init__( + self, + default_activation_qtype: QuantType, + default_weight_qtype: QuantType, + activation_symmetric: bool, + weight_symmetric: bool, + overrides: TensorQuantOverridesHelper, + initializers: dict[str, onnx.TensorProto], + ): + self.default_activation_qtype = default_activation_qtype + self.default_weight_qtype = default_weight_qtype + self.activation_symmetric = activation_symmetric + self.weight_symmetric = weight_symmetric + self.overrides = overrides + self.initializers = initializers + + self.process_fns = { + "MatMul": self._process_matmul, + "LayerNormalization": self._process_layernorm, + "Sigmoid": self._process_sigmoid, + "Tanh": self._process_tanh, + } + + def process_node(self, node: onnx.NodeProto): + process_fn = self.process_fns.get(node.op_type) + + if process_fn is not None: + process_fn(node) + + def _process_matmul(self, node: onnx.NodeProto): + """ + Overrides MatMul's initializer input(s) to use the default weight type if: + - The default weight type is 8-bit + - One of the inputs is a 16-bit activation + """ + assert node.op_type == "MatMul", f"Expected MatMul, but got {node.op_type}" + if self.default_weight_qtype not in Q8_TYPES: + return + + input_16bit_act = None + input_wgt = None + + for input_name in node.input: + if input_name and input_name not in self.initializers: + qtype = self.overrides.get_node_input_qtype_info( + input_name, node.name, self.default_activation_qtype + ).quant_type + if qtype in Q16_TYPES: + input_16bit_act = input_name + else: + input_wgt = input_name + + # Override initializer to use the default weight type. + if input_16bit_act and input_wgt: + did_update = self.overrides.update_tensor_overrides( + input_wgt, + {"quant_type": self.default_weight_qtype, "symmetric": self.weight_symmetric}, + overwrite=False, + ) + + if not did_update: + warn_unable_to_override(node, "quant_type/symmetric", input_wgt, "input weight") + + def _process_layernorm(self, node: onnx.NodeProto): + """ + Overrides LayerNormalization's initializer input(s), except for bias, to use the default weight type if: + - The default weight type is 8-bit + - One of the inputs is a 16-bit activation + """ + assert node.op_type == "LayerNormalization", f"Expected LayerNormalization, but got {node.op_type}" + if self.default_weight_qtype not in Q8_TYPES: + return + + has_q16_activation = False + for input_name in node.input: + if input_name and input_name not in self.initializers: + qtype = self.overrides.get_node_input_qtype_info( + input_name, node.name, self.default_activation_qtype + ).quant_type + if qtype in Q16_TYPES: + has_q16_activation = True + break + + # Override initializers to use the self.default_weight_qtype. Don't override the bias input. + if has_q16_activation: + for i in range(2): + input_name = node.input[i] + if input_name and input_name in self.initializers: + did_update = self.overrides.update_tensor_overrides( + input_name, + {"quant_type": self.default_weight_qtype, "symmetric": self.weight_symmetric}, + overwrite=False, + ) + + if not did_update: + warn_unable_to_override(node, "quant_type/symmetric", input_name, "input weight") + + def _process_sigmoid(self, node: onnx.NodeProto): + """ + Overrides 16-bit Sigmoid's output scale and zero-point as per QNN requirements. 
+ """ + assert node.op_type == "Sigmoid", f"Expected Sigmoid, but got {node.op_type}" + output_type = self.overrides.get_node_output_qtype_info( + node.output[0], self.default_activation_qtype + ).quant_type + + if output_type == QuantType.QUInt16: + self.overrides.update_tensor_overrides( + node.output[0], + { + "quant_type": output_type, + "scale": np.array(1.0 / 65536.0, dtype=np.float32), + "zero_point": np.array(0, dtype=np.uint16), + }, + ) + elif output_type == QuantType.QInt16: + self.overrides.update_tensor_overrides( + node.output[0], + { + "quant_type": output_type, + "scale": np.array(1.0 / 32768.0, dtype=np.float32), + "zero_point": np.array(0, dtype=np.int16), + }, + ) + + def _process_tanh(self, node: onnx.NodeProto): + """ + Overrides 16-bit Tanh's output scale and zero-point as per QNN requirements. + """ + assert node.op_type == "Tanh", f"Expected Tanh, but got {node.op_type}" + output_type = self.overrides.get_node_output_qtype_info( + node.output[0], self.default_activation_qtype + ).quant_type + + if output_type == QuantType.QUInt16: + self.overrides.update_tensor_overrides( + node.output[0], + { + "quant_type": output_type, + "scale": np.array(1.0 / 32768.0, dtype=np.float32), + "zero_point": np.array(32768, dtype=np.uint16), + }, + ) + elif output_type == QuantType.QInt16: + self.overrides.update_tensor_overrides( + node.output[0], + { + "quant_type": output_type, + "scale": np.array(1.0 / 32768.0, dtype=np.float32), + "zero_point": np.array(0, dtype=np.int16), + }, + ) diff --git a/onnxruntime/python/tools/quantization/tensor_quant_overrides.py b/onnxruntime/python/tools/quantization/tensor_quant_overrides.py index 610b96b9d7937..793d58cbc4e3e 100644 --- a/onnxruntime/python/tools/quantization/tensor_quant_overrides.py +++ b/onnxruntime/python/tools/quantization/tensor_quant_overrides.py @@ -7,11 +7,52 @@ import json from collections.abc import MutableMapping +from dataclasses import dataclass from typing import Any from .quant_utils import QuantType +@dataclass +class QuantTypeInfo: + """ + The quantization type information for a tensor override. + """ + + quant_type: QuantType + symmetric: bool | None = None # If None, assumes default is used. + reduce_range: bool | None = None # If None, assumes default is used. + + def __eq__(self, other: object): + if isinstance(other, QuantTypeInfo): + return ( + self.quant_type == other.quant_type + and (self.symmetric is None or other.symmetric is None or self.symmetric == other.symmetric) + and (self.reduce_range is None or other.reduce_range is None or self.reduce_range == other.reduce_range) + ) + return NotImplemented + + @staticmethod + def load_from_dict( + raw_dict: dict[str, Any], + default_activation_qtype: QuantType | None = None, + default_activation_symmetric: bool | None = None, + default_activation_reduce_range: bool | None = None, + ) -> QuantTypeInfo: + return QuantTypeInfo( + raw_dict.get("quant_type", default_activation_qtype), + raw_dict.get("symmetric", default_activation_symmetric), + raw_dict.get("reduce_range", default_activation_reduce_range), + ) + + def save_to_dict(self, raw_dict: dict[str, Any]): + raw_dict["quant_type"] = self.quant_type + if self.symmetric is not None: + raw_dict["symmetric"] = self.symmetric + if self.reduce_range is not None: + raw_dict["reduce_range"] = self.reduce_range + + class TensorQuantOverridesHelper(MutableMapping): """ Utility wrapper over the tensor quantization overrides passed via extra_options. 
@@ -184,9 +225,99 @@ def is_valid( return True, None + def update_tensor_overrides( + self, + tensor_name: str, + new_vals: dict[str, Any], + channels: list[int] | None = None, + overwrite: bool = True, + ) -> bool: + if not new_vals: + return False + + channels = set(channels) if channels is not None else None + have_overrides = self.overrides.get(tensor_name) + + # If `overwrite` is False, check if we would overwrite anything. + do_update = True + if not overwrite and have_overrides: + for channel, overrides in enumerate(self.overrides[tensor_name]): + if channels is not None and channel not in channels: + continue + if set(new_vals).intersection(set(overrides)): + do_update = False + break + + # Do the update if `overwrite` is True or if nothing is overwritten (do not want partial overwrites). + if do_update: + if not have_overrides: + self.overrides[tensor_name] = [{}] + + for channel, overrides in enumerate(self.overrides[tensor_name]): + if channels is not None and channel not in channels: + continue + overrides.update(new_vals) + + return do_update + + def get_node_output_qtype_info( + self, + output_name: str, + default_qtype: QuantType | None, + default_symmetric: bool | None = None, + ) -> QuantTypeInfo: + if output_name not in self.overrides: + return QuantTypeInfo(default_qtype, default_symmetric) + + # Get the first overrides dict in the list. This works for both per-tensor and per-channel + # quantization because all channels must use the same quant type. + tensor_overrides = self.overrides[output_name][0] + + return QuantTypeInfo( + tensor_overrides.get("quant_type", default_qtype), + tensor_overrides.get("symmetric", default_symmetric), + ) + + def get_node_input_qtype_info( + self, + input_name: str, + node_name: str, + default_qtype: QuantType | None, + default_symmetric: bool | None = None, + default_reduce_range: bool | None = None, + ) -> QuantTypeInfo: + if input_name not in self.overrides or not self.overrides[input_name]: + return QuantTypeInfo(default_qtype, default_symmetric, default_reduce_range) + + # Get the first overrides dict in the list. This works for both per-tensor and per-channel + # quantization because all channels must use the same quant type. + tensor_overrides = self.overrides[input_name][0] + producer_type = tensor_overrides.get("quant_type", default_qtype) + + if "convert" not in tensor_overrides: + return QuantTypeInfo(producer_type, default_symmetric, default_reduce_range) + + # This tensor is converted. Check if the node gets the original qtype or the converted qtype. + convert_dict = tensor_overrides["convert"] + qtype_info = QuantTypeInfo( + producer_type, + convert_dict.get("symmetric", default_symmetric), + convert_dict.get("reduce_range", default_reduce_range), + ) + + # Check if all nodes receive the converted type (i.e., recv_nodes is None) or this node + # is in the list of consumers (recv_nodes). 
+ if ("recv_nodes" not in convert_dict) or (node_name in convert_dict["recv_nodes"]): + qtype_info.quant_type = convert_dict["quant_type"] + + return qtype_info + def pprint_str(self, indent=None) -> str: return json.dumps(self.overrides, default=str, indent=indent) + def empty(self) -> bool: + return not self.overrides + def get_dict(self) -> dict[str, list[dict[str, Any]]]: return self.overrides diff --git a/onnxruntime/test/python/quantization/test_mixed_prec_quant_overrides_fixer.py b/onnxruntime/test/python/quantization/test_mixed_prec_quant_overrides_fixer.py new file mode 100644 index 0000000000000..96277056adee0 --- /dev/null +++ b/onnxruntime/test/python/quantization/test_mixed_prec_quant_overrides_fixer.py @@ -0,0 +1,171 @@ +#!/usr/bin/env python +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import unittest + +import onnx + +from onnxruntime.quantization import QuantType +from onnxruntime.quantization.execution_providers.qnn.mixed_precision_overrides_utils import ( + MixedPrecisionTensorQuantOverridesFixer, +) +from onnxruntime.quantization.tensor_quant_overrides import TensorQuantOverridesHelper + + +class TestMixedPrecisionQuantOverridesFixer(unittest.TestCase): + def build_test_model_1(self, shape): + input_0 = onnx.helper.make_tensor_value_info("input_0", onnx.TensorProto.FLOAT, shape) + input_1 = onnx.helper.make_tensor_value_info("input_1", onnx.TensorProto.FLOAT, shape) + output_0 = onnx.helper.make_tensor_value_info("output_0", onnx.TensorProto.FLOAT, shape) + output_1 = onnx.helper.make_tensor_value_info("output_1", onnx.TensorProto.FLOAT, shape) + output_2 = onnx.helper.make_tensor_value_info("output_2", onnx.TensorProto.FLOAT, shape) + + op1_node = onnx.helper.make_node("Sigmoid", ["input_0"], ["op1_out"], name="op1") + op2_node = onnx.helper.make_node("Cos", ["input_1"], ["op2_out"], name="op2") + op3_node = onnx.helper.make_node("Sin", ["op1_out"], ["op3_out"], name="op3") + op4_node = onnx.helper.make_node("Tanh", ["op2_out"], ["op4_out"], name="op4") + op5_node = onnx.helper.make_node("Mul", ["op3_out", "op4_out"], ["op5_out"], name="op5") + op6_node = onnx.helper.make_node("Relu", ["op5_out"], ["output_0"], name="op6") + op7_node = onnx.helper.make_node("Cos", ["op2_out"], ["output_1"], name="op7") + op8_node = onnx.helper.make_node("Sigmoid", ["op2_out"], ["output_2"], name="op8") + + graph = onnx.helper.make_graph( + [ + op1_node, + op2_node, + op3_node, + op4_node, + op5_node, + op6_node, + op7_node, + op8_node, + ], + "mixed_prec_test", + [input_0, input_1], + [output_0, output_1, output_2], + ) + opset_imports = [ + onnx.helper.make_opsetid("", 18), + ] + model = onnx.helper.make_model(graph, opset_imports=opset_imports) + return onnx.shape_inference.infer_shapes(model) + + def test_fixer_1(self): + shape = (1, 2, 3) + model = self.build_test_model_1(shape) + onnx.save_model(model, "model.onnx") + + default_act_qtype = QuantType.QUInt8 + raw_overrides = {"op4_out": [{"quant_type": QuantType.QUInt16}]} + overrides = TensorQuantOverridesHelper(raw_overrides) + fixer = MixedPrecisionTensorQuantOverridesFixer.create_from_model(overrides, model, default_act_qtype) + fixer.apply(default_act_qtype, default_activation_symmetric=False) + + expected = { + "op2_out": [ + {"quant_type": QuantType.QUInt8, 
"convert": {"quant_type": QuantType.QUInt16, "recv_nodes": {"op4"}}} + ], + "op3_out": [ + {"quant_type": QuantType.QUInt8, "convert": {"quant_type": QuantType.QUInt16, "recv_nodes": {"op5"}}} + ], + "op4_out": [{"quant_type": QuantType.QUInt16}], + "op5_out": [ + {"quant_type": QuantType.QUInt16, "convert": {"quant_type": QuantType.QUInt8, "recv_nodes": {"op6"}}} + ], + } + self.assertDictEqual(overrides.get_dict(), expected) + + def test_fixer_with_symmetric(self): + shape = (1, 2, 3) + model = self.build_test_model_1(shape) + onnx.save_model(model, "model.onnx") + + default_act_qtype = QuantType.QInt8 + raw_overrides = {"op4_out": [{"quant_type": QuantType.QInt16, "symmetric": True}]} + overrides = TensorQuantOverridesHelper(raw_overrides) + fixer = MixedPrecisionTensorQuantOverridesFixer.create_from_model(overrides, model, default_act_qtype) + fixer.apply(default_act_qtype, default_activation_symmetric=False) + + expected = { + "op2_out": [ + { + "quant_type": QuantType.QInt8, + "convert": {"quant_type": QuantType.QInt16, "symmetric": True, "recv_nodes": {"op4"}}, + } + ], + "op3_out": [ + { + "quant_type": QuantType.QInt8, + "convert": {"quant_type": QuantType.QInt16, "symmetric": True, "recv_nodes": {"op5"}}, + } + ], + "op4_out": [{"quant_type": QuantType.QInt16, "symmetric": True}], + "op5_out": [ + { + "quant_type": QuantType.QInt16, + "symmetric": True, + "convert": {"quant_type": QuantType.QInt8, "recv_nodes": {"op6"}}, + } + ], + } + self.assertDictEqual(overrides.get_dict(), expected) + + def test_fixer_upgrade_output(self): + shape = (1, 2, 3) + model = self.build_test_model_1(shape) + onnx.save_model(model, "model.onnx") + + default_act_qtype = QuantType.QUInt8 + raw_overrides = { + "op4_out": [{"quant_type": QuantType.QUInt16}], + "output_0": [{"quant_type": QuantType.QUInt16}], + } + overrides = TensorQuantOverridesHelper(raw_overrides) + fixer = MixedPrecisionTensorQuantOverridesFixer.create_from_model(overrides, model, default_act_qtype) + fixer.apply(default_act_qtype, default_activation_symmetric=False) + + expected = { + "op2_out": [ + {"quant_type": QuantType.QUInt8, "convert": {"quant_type": QuantType.QUInt16, "recv_nodes": {"op4"}}} + ], + "op3_out": [ + {"quant_type": QuantType.QUInt8, "convert": {"quant_type": QuantType.QUInt16, "recv_nodes": {"op5"}}} + ], + "op4_out": [{"quant_type": QuantType.QUInt16}], + "op5_out": [{"quant_type": QuantType.QUInt16}], + "output_0": [{"quant_type": QuantType.QUInt16}], + } + self.assertDictEqual(overrides.get_dict(), expected) + + def test_fixer_upgrade_input(self): + shape = (1, 2, 3) + model = self.build_test_model_1(shape) + onnx.save_model(model, "model.onnx") + + default_act_qtype = QuantType.QUInt8 + raw_overrides = {"op4_out": [{"quant_type": QuantType.QUInt16}], "input_0": [{"quant_type": QuantType.QUInt16}]} + overrides = TensorQuantOverridesHelper(raw_overrides) + fixer = MixedPrecisionTensorQuantOverridesFixer.create_from_model(overrides, model, default_act_qtype) + fixer.apply(default_act_qtype, default_activation_symmetric=False) + + expected = { + "input_0": [{"quant_type": QuantType.QUInt16}], + "op1_out": [ + {"quant_type": QuantType.QUInt16, "convert": {"quant_type": QuantType.QUInt8, "recv_nodes": {"op3"}}} + ], + "op2_out": [ + {"quant_type": QuantType.QUInt8, "convert": {"quant_type": QuantType.QUInt16, "recv_nodes": {"op4"}}} + ], + "op3_out": [ + {"quant_type": QuantType.QUInt8, "convert": {"quant_type": QuantType.QUInt16, "recv_nodes": {"op5"}}} + ], + "op4_out": [{"quant_type": QuantType.QUInt16}], + 
"op5_out": [ + {"quant_type": QuantType.QUInt16, "convert": {"quant_type": QuantType.QUInt8, "recv_nodes": {"op6"}}} + ], + } + self.assertDictEqual(overrides.get_dict(), expected) diff --git a/onnxruntime/test/python/quantization/test_tensor_quant_overrides_option.py b/onnxruntime/test/python/quantization/test_tensor_quant_overrides_option.py index 9ea4719f3c595..77f20b3caed96 100644 --- a/onnxruntime/test/python/quantization/test_tensor_quant_overrides_option.py +++ b/onnxruntime/test/python/quantization/test_tensor_quant_overrides_option.py @@ -11,12 +11,12 @@ import numpy as np import onnx -from onnxruntime import quantization +from onnxruntime.quantization import CalibrationDataReader, QuantFormat, QuantType, quantize_static from onnxruntime.quantization.execution_providers.qnn import get_qnn_qdq_config from onnxruntime.quantization.quant_utils import compute_scale_zp, get_qmin_qmax_for_qType, ms_domain -class DummyDataReader(quantization.CalibrationDataReader): +class DummyDataReader(CalibrationDataReader): def __init__(self, activations): self.iterator = ({"INP": act} for act in activations) @@ -81,11 +81,11 @@ def perform_qdq_quantization(self, output_model_name, extra_options=None, per_ch if activation_type is None: activation_type = self.default_act_qtype - quantization.quantize_static( + quantize_static( model_input="model.onnx", model_output=output_model_name, calibration_data_reader=DummyDataReader(self.activations), - quant_format=quantization.QuantFormat.QDQ, + quant_format=QuantFormat.QDQ, activation_type=activation_type, weight_type=self.default_wgt_qtype, per_channel=per_channel, @@ -223,8 +223,8 @@ def test_qdq_overrides1(self): "SIG_OUT": [ {"scale": np.array(1.0, dtype=np.float32), "zero_point": np.array(127, dtype=np.uint8)} ], - "WGT": [{"quant_type": quantization.QuantType.QInt8, "symmetric": True, "reduce_range": True}], - "BIAS": [{"quant_type": quantization.QuantType.QInt8, "symmetric": True, "reduce_range": True}], + "WGT": [{"quant_type": QuantType.QInt8, "symmetric": True, "reduce_range": True}], + "BIAS": [{"quant_type": QuantType.QInt8, "symmetric": True, "reduce_range": True}], } }, ) @@ -240,7 +240,7 @@ def test_qdq_overrides1(self): self.assertEqual(sig_out_sc.float_data[0], np.float32(1.0)) # Weight should have different type, zero_point, and scale - self.assertEqual(wgt_zp.data_type, quantization.QuantType.QInt8.tensor_type) + self.assertEqual(wgt_zp.data_type, QuantType.QInt8.tensor_type) wgt_qmin, wgt_qmax = get_qmin_qmax_for_qType(wgt_zp.data_type, reduce_range=True, symmetric=True) wgt_rmin, wgt_rmax = np.min(self.weight), np.max(self.weight) @@ -249,7 +249,7 @@ def test_qdq_overrides1(self): self.assertEqual(wgt_sc.float_data[0], np.float32(new_wgt_sc)) # Bias should now be treated as a weight and should have different type, zero_point, and scale - self.assertEqual(bias_zp.data_type, quantization.QuantType.QInt8.tensor_type) + self.assertEqual(bias_zp.data_type, QuantType.QInt8.tensor_type) bias_qmin, bias_qmax = get_qmin_qmax_for_qType(bias_zp.data_type, reduce_range=True, symmetric=True) bias_rmin, bias_rmax = np.min(self.bias), np.max(self.bias) @@ -375,7 +375,7 @@ def test_qdq_overrides_per_channel2(self): """ rmin_vals = [0.0, 0.2] rmax_vals = [1.0, 0.8] - quant_type = quantization.QuantType.QUInt8 + quant_type = QuantType.QUInt8 reduce_ranges = [True, False] ( _, @@ -434,8 +434,8 @@ def test_16bit_overrides_set_ms_domain(self): activation_type=onnx.TensorProto.UINT8, # Default to 8bit activations extra_options={ "TensorQuantOverrides": { - 
"INP": [{"quant_type": quantization.QuantType.QUInt16}], - "SIG_OUT": [{"quant_type": quantization.QuantType.QUInt16}], + "INP": [{"quant_type": QuantType.QUInt16}], + "SIG_OUT": [{"quant_type": QuantType.QUInt16}], } }, ) @@ -559,31 +559,446 @@ def test_override_validation_bad_combination(self): self.assertIn("option 'reduce_range' is invalid with 'scale' and 'zero_point'", str(context.exception)) - def test_get_qnn_qdq_config(self): + def test_get_qnn_qdq_config_sigmoid(self): """ - Test that the QNN-specific configs override the scale and zero-point of Sigmoid. + Test that the QNN-specific configs override the scale and zero-point of 16-bit Sigmoid. + """ + # Create float model with a Abs --> Sigmoid + graph = onnx.helper.make_graph( + [ + onnx.helper.make_node("Abs", ["input_0"], ["abs_out"], name="Abs_0"), + onnx.helper.make_node("Sigmoid", ["abs_out"], ["output_0"], name="Sigmoid_0"), + ], + "sigmoid_graph", + [onnx.helper.make_tensor_value_info("input_0", onnx.TensorProto.FLOAT, (1, 2, 3))], + [onnx.helper.make_tensor_value_info("output_0", onnx.TensorProto.FLOAT, (1, 2, 3))], + ) + opset_imports = [ + onnx.helper.make_opsetid("", 18), + ] + model = onnx.helper.make_model(graph, opset_imports=opset_imports) + model = onnx.shape_inference.infer_shapes(model) + float_model_path = "model.onnx" + onnx.save_model(model, float_model_path) + + other_override_0 = {"abs_out": [{"symmetric": True}]} + other_override_1 = { + "abs_out": [ + { + "quant_type": QuantType.QUInt8, + "convert": {"quant_type": QuantType.QUInt16, "recv_nodes": {"Sigmoid_0"}}, + } + ] + } + other_override_2 = { + "abs_out": [ + { + "quant_type": QuantType.QInt8, + "convert": {"quant_type": QuantType.QInt16, "recv_nodes": {"Sigmoid_0"}}, + } + ] + } + + # Enumerate subtests (default_act_qtype, sigmoid_out_qtype, other_override) + subtest_configs = [ + (QuantType.QUInt16, None, {}), # Sigmoid gets new scale/zp + (QuantType.QUInt16, None, other_override_0), # Sigmoid gets new scale/zp + (QuantType.QInt16, None, {}), # Sigmoid gets new scale/zp + (QuantType.QInt16, None, other_override_0), # Sigmoid gets new scale/zp + (QuantType.QUInt8, QuantType.QUInt16, other_override_1), # Sigmoid gets new scale/zp + (QuantType.QInt8, QuantType.QInt16, other_override_2), # Sigmoid gets new scale/zp + (QuantType.QUInt8, None, other_override_0), # Sigmoid DOES NOT gets new scale/zp + (QuantType.QInt8, None, {}), # Sigmoid DOES NOT gets new scale/zp + (QuantType.QInt8, QuantType.QInt8, {}), # Sigmoid DOES NOT gets new scale/zp + ] + + # Test that Sigmoid's output scale and zp should be overridden for 16-bit Sigmoid. 
+ for default_act_qtype, sigmoid_out_qtype, abs_override in subtest_configs: + with self.subTest( + default_act_qtype=default_act_qtype, sigmoid_out_qtype=sigmoid_out_qtype, abs_override=abs_override + ): + init_overrides = {} + init_overrides.update(abs_override) + + if sigmoid_out_qtype is not None: + init_overrides["output_0"] = [{"quant_type": sigmoid_out_qtype}] + + qnn_config = get_qnn_qdq_config( + float_model_path, + DummyDataReader([]), + activation_type=default_act_qtype, + init_overrides=(init_overrides if init_overrides else None), + add_qtype_converts=False, + ) + + self.assertEqual(set(qnn_config.op_types_to_quantize), {"Abs", "Sigmoid"}) + + if default_act_qtype == QuantType.QUInt16 or sigmoid_out_qtype == QuantType.QUInt16: + self.assertIn("TensorQuantOverrides", qnn_config.extra_options) + self.assertIn("output_0", qnn_config.extra_options["TensorQuantOverrides"]) + self.assertEqual( + qnn_config.extra_options["TensorQuantOverrides"]["output_0"], + [ + { + "quant_type": QuantType.QUInt16, + "scale": np.array(1.0 / 65536.0, dtype=np.float32), + "zero_point": np.array(0, dtype=np.uint16), + } + ], + ) + elif default_act_qtype == QuantType.QInt16 or sigmoid_out_qtype == QuantType.QInt16: + self.assertIn("TensorQuantOverrides", qnn_config.extra_options) + self.assertIn("output_0", qnn_config.extra_options["TensorQuantOverrides"]) + self.assertEqual( + qnn_config.extra_options["TensorQuantOverrides"]["output_0"], + [ + { + "quant_type": QuantType.QInt16, + "scale": np.array(1.0 / 32768.0, dtype=np.float32), + "zero_point": np.array(0, dtype=np.int16), + } + ], + ) + + def test_get_qnn_qdq_config_tanh(self): + """ + Test that the QNN-specific configs override the scale and zero-point of 16-bit Tanh. """ - self.build_float32_model() - qnn_config = get_qnn_qdq_config( - "model.onnx", DummyDataReader(self.activations), activation_type=quantization.QuantType.QUInt16 + # Create float model with a Abs --> Tanh + graph = onnx.helper.make_graph( + [ + onnx.helper.make_node("Abs", ["input_0"], ["abs_out"], name="Abs_0"), + onnx.helper.make_node("Tanh", ["abs_out"], ["output_0"], name="Tanh_0"), + ], + "tanh_graph", + [onnx.helper.make_tensor_value_info("input_0", onnx.TensorProto.FLOAT, (1, 2, 3))], + [onnx.helper.make_tensor_value_info("output_0", onnx.TensorProto.FLOAT, (1, 2, 3))], ) + opset_imports = [ + onnx.helper.make_opsetid("", 18), + ] + model = onnx.helper.make_model(graph, opset_imports=opset_imports) + model = onnx.shape_inference.infer_shapes(model) + float_model_path = "model.onnx" + onnx.save_model(model, float_model_path) + + other_override_0 = {"abs_out": [{"symmetric": True}]} + other_override_1 = { + "abs_out": [ + {"quant_type": QuantType.QUInt8, "convert": {"quant_type": QuantType.QUInt16, "recv_nodes": {"Tanh_0"}}} + ] + } + other_override_2 = { + "abs_out": [ + {"quant_type": QuantType.QInt8, "convert": {"quant_type": QuantType.QInt16, "recv_nodes": {"Tanh_0"}}} + ] + } - self.assertEqual(qnn_config.extra_options["MinimumRealRange"], 0.0001) + # Enumerate subtests (default_act_qtype, tanh_out_qtype, other_override) + subtest_configs = [ + (QuantType.QUInt16, None, {}), # Tanh gets new scale/zp + (QuantType.QUInt16, None, other_override_0), # Tanh gets new scale/zp + (QuantType.QInt16, None, {}), # Tanh gets new scale/zp + (QuantType.QInt16, None, other_override_0), # Tanh gets new scale/zp + (QuantType.QUInt8, QuantType.QUInt16, other_override_1), # Tanh gets new scale/zp + (QuantType.QInt8, QuantType.QInt16, other_override_2), # Tanh gets new scale/zp + 
(QuantType.QUInt8, None, other_override_0), # Tanh DOES NOT gets new scale/zp + (QuantType.QInt8, None, {}), # Tanh DOES NOT gets new scale/zp + (QuantType.QInt8, QuantType.QInt8, {}), # Tanh DOES NOT gets new scale/zp + ] - inp_zp, inp_sc, sig_out_zp, sig_out_sc, _, _, _, _, _, _ = self.perform_qdq_quantization( - "model_qnn_quant_overrides.onnx", - extra_options=qnn_config.extra_options, - activation_type=quantization.QuantType.QUInt16, + # Test that Tanh's output scale and zp should be overridden for 16-bit Tanh. + for default_act_qtype, tanh_out_qtype, abs_override in subtest_configs: + with self.subTest( + default_act_qtype=default_act_qtype, tanh_out_qtype=tanh_out_qtype, abs_override=abs_override + ): + init_overrides = {} + init_overrides.update(abs_override) + + if tanh_out_qtype is not None: + init_overrides["output_0"] = [{"quant_type": tanh_out_qtype}] + + qnn_config = get_qnn_qdq_config( + float_model_path, + DummyDataReader([]), + activation_type=default_act_qtype, + init_overrides=(init_overrides if init_overrides else None), + add_qtype_converts=False, + ) + + self.assertEqual(set(qnn_config.op_types_to_quantize), {"Abs", "Tanh"}) + + if default_act_qtype == QuantType.QUInt16 or tanh_out_qtype == QuantType.QUInt16: + self.assertIn("TensorQuantOverrides", qnn_config.extra_options) + self.assertIn("output_0", qnn_config.extra_options["TensorQuantOverrides"]) + self.assertEqual( + qnn_config.extra_options["TensorQuantOverrides"]["output_0"], + [ + { + "quant_type": QuantType.QUInt16, + "scale": np.array(1.0 / 32768.0, dtype=np.float32), + "zero_point": np.array(32768, dtype=np.uint16), + } + ], + ) + elif default_act_qtype == QuantType.QInt16 or tanh_out_qtype == QuantType.QInt16: + self.assertIn("TensorQuantOverrides", qnn_config.extra_options) + self.assertIn("output_0", qnn_config.extra_options["TensorQuantOverrides"]) + self.assertEqual( + qnn_config.extra_options["TensorQuantOverrides"]["output_0"], + [ + { + "quant_type": QuantType.QInt16, + "scale": np.array(1.0 / 32768.0, dtype=np.float32), + "zero_point": np.array(0, dtype=np.int16), + } + ], + ) + + def test_get_qnn_qdq_config_matmul(self): + """ + Test that the QNN-specific configs override MatMul's initializer input type to 8-bit if + the other input is 16-bit and the default weight type is 8-bit. 
+ """ + # Create float model with a Abs --> MatMul + graph = onnx.helper.make_graph( + [ + onnx.helper.make_node("Abs", ["input_0"], ["abs_0_out"], name="Abs_0"), + onnx.helper.make_node("MatMul", ["abs_0_out", "weight"], ["matmul_0_out"], name="MatMul_0"), + onnx.helper.make_node("Abs", ["matmul_0_out"], ["output_0"], name="Abs_1"), + ], + "matmul_graph", + [onnx.helper.make_tensor_value_info("input_0", onnx.TensorProto.FLOAT, (2, 3))], + [onnx.helper.make_tensor_value_info("output_0", onnx.TensorProto.FLOAT, (2, 2))], + initializer=[onnx.numpy_helper.from_array(np.random.random((3, 2)).astype(np.float32), "weight")], ) + opset_imports = [ + onnx.helper.make_opsetid("", 18), + ] + model = onnx.helper.make_model(graph, opset_imports=opset_imports) + model = onnx.shape_inference.infer_shapes(model) + float_model_path = "model.onnx" + onnx.save_model(model, float_model_path) + + q16_qtypes = {QuantType.QUInt16, QuantType.QInt16} + q8_qtypes = {QuantType.QUInt8, QuantType.QInt8} + symmetric_wgt_qtypes = {QuantType.QInt8, QuantType.QInt16} + + other_override_0 = {"output_0": [{"symmetric": True}]} + other_override_1 = { + "matmul_0_out": [ + { + "quant_type": QuantType.QUInt16, + "convert": {"quant_type": QuantType.QUInt8, "recv_nodes": {"Abs_1"}}, + } + ] + } + other_override_2 = { + "matmul_0_out": [ + { + "quant_type": QuantType.QInt16, + "convert": {"quant_type": QuantType.QInt8, "recv_nodes": {"Abs_1"}}, + } + ] + } + convert_matmul_input = { + "abs_0_out": [ + { + "quant_type": QuantType.QUInt8, + "convert": {"quant_type": QuantType.QUInt16, "recv_nodes": {"MatMul_0"}}, + } + ] + } - # Input should have uint16 quant type - self.assertEqual(inp_zp.data_type, onnx.TensorProto.UINT16) + # Enumerate subtests (default_act_qtype, default_wgt_qtype, matmul_in_qtype, other_override) + subtest_configs = [ + (QuantType.QUInt8, QuantType.QUInt8, None, {}), + (QuantType.QUInt8, QuantType.QUInt8, QuantType.QUInt16, {}), + (QuantType.QUInt8, QuantType.QUInt8, QuantType.QUInt16, other_override_0), + (QuantType.QUInt8, QuantType.QUInt8, QuantType.QUInt16, other_override_1), + (QuantType.QInt8, QuantType.QInt8, QuantType.QInt16, other_override_2), + (QuantType.QUInt16, QuantType.QUInt8, None, other_override_0), + (QuantType.QInt16, QuantType.QInt8, None, {}), + (QuantType.QUInt16, QuantType.QUInt16, None, other_override_0), + (QuantType.QInt16, QuantType.QInt16, None, {}), + (QuantType.QUInt8, QuantType.QUInt8, None, {}), + (QuantType.QUInt8, QuantType.QUInt8, None, convert_matmul_input), + ] - # Sigmoid output should have overridden scale/zp - self.assertEqual(sig_out_zp.int32_data[0], 0) - self.assertEqual(sig_out_zp.data_type, onnx.TensorProto.UINT16) - self.assertEqual(sig_out_sc.float_data[0], np.float32(1.0 / 65536.0)) + # Test if MatMul's weight input is overridden. 
+ for default_act_qtype, default_wgt_qtype, matmul_input_qtype, other_override in subtest_configs: + with self.subTest( + default_act_qtype=default_act_qtype, + default_wgt_qtype=default_wgt_qtype, + matmul_input_qtype=matmul_input_qtype, + other_override=other_override, + ): + init_overrides = {} + init_overrides.update(other_override) + + if matmul_input_qtype is not None: + init_overrides["abs_0_out"] = [{"quant_type": matmul_input_qtype}] + + qnn_config = get_qnn_qdq_config( + float_model_path, + DummyDataReader([]), + activation_type=default_act_qtype, + weight_type=default_wgt_qtype, + init_overrides=(init_overrides if init_overrides else None), + add_qtype_converts=False, + ) + + self.assertEqual(set(qnn_config.op_types_to_quantize), {"Abs", "MatMul"}) + input_is_16bit = ( + (default_act_qtype in q16_qtypes) + or (matmul_input_qtype in q16_qtypes) + or (other_override == convert_matmul_input) + ) + weight_is_symmetric = default_wgt_qtype in symmetric_wgt_qtypes + + if input_is_16bit and default_wgt_qtype in q8_qtypes: + self.assertIn("TensorQuantOverrides", qnn_config.extra_options) + self.assertIn("weight", qnn_config.extra_options["TensorQuantOverrides"]) + self.assertEqual( + qnn_config.extra_options["TensorQuantOverrides"]["weight"], + [ + { + "quant_type": default_wgt_qtype, + "symmetric": weight_is_symmetric, + } + ], + ) + elif init_overrides: + self.assertIn("TensorQuantOverrides", qnn_config.extra_options) + self.assertNotIn("weight", qnn_config.extra_options["TensorQuantOverrides"]) + + self.assertEqual(weight_is_symmetric, qnn_config.extra_options["WeightSymmetric"]) + + def test_get_qnn_qdq_config_layernorm(self): + """ + Test that the QNN-specific configs override LayerNorm's initializer input type to 8-bit if + the other input is 16-bit and the default weight type is 8-bit. 
+ """ + # Create float model with a Abs --> LayerNormalization + graph = onnx.helper.make_graph( + [ + onnx.helper.make_node("Abs", ["input_0"], ["abs_0_out"], name="Abs_0"), + onnx.helper.make_node( + "LayerNormalization", ["abs_0_out", "weight", "bias"], ["layernorm_0_out"], name="LayerNorm_0" + ), + onnx.helper.make_node("Abs", ["layernorm_0_out"], ["output_0"], name="Abs_1"), + ], + "layernorm_graph", + [onnx.helper.make_tensor_value_info("input_0", onnx.TensorProto.FLOAT, (2, 3))], + [onnx.helper.make_tensor_value_info("output_0", onnx.TensorProto.FLOAT, (2, 3))], + initializer=[ + onnx.numpy_helper.from_array(np.random.random((2, 3)).astype(np.float32), "weight"), + onnx.numpy_helper.from_array(np.random.random((2, 3)).astype(np.float32), "bias"), + ], + ) + opset_imports = [ + onnx.helper.make_opsetid("", 18), + ] + model = onnx.helper.make_model(graph, opset_imports=opset_imports) + model = onnx.shape_inference.infer_shapes(model) + float_model_path = "model.onnx" + onnx.save_model(model, float_model_path) + + q16_qtypes = {QuantType.QUInt16, QuantType.QInt16} + q8_qtypes = {QuantType.QUInt8, QuantType.QInt8} + symmetric_wgt_qtypes = {QuantType.QInt8, QuantType.QInt16} + + other_override_0 = {"output_0": [{"symmetric": True}]} + other_override_1 = { + "layernorm_0_out": [ + { + "quant_type": QuantType.QUInt16, + "convert": {"quant_type": QuantType.QUInt8, "recv_nodes": {"Abs_1"}}, + } + ] + } + other_override_2 = { + "layernorm_0_out": [ + { + "quant_type": QuantType.QInt16, + "convert": {"quant_type": QuantType.QInt8, "recv_nodes": {"Abs_1"}}, + } + ] + } + convert_layernorm_input = { + "abs_0_out": [ + { + "quant_type": QuantType.QUInt8, + "convert": {"quant_type": QuantType.QUInt16, "recv_nodes": {"LayerNorm_0"}}, + } + ] + } + + # Enumerate subtests (default_act_qtype, default_wgt_qtype, layernorm_in_qtype, other_override) + subtest_configs = [ + (QuantType.QUInt8, QuantType.QUInt8, None, {}), + (QuantType.QUInt8, QuantType.QUInt8, QuantType.QUInt16, {}), + (QuantType.QUInt8, QuantType.QUInt8, QuantType.QUInt16, other_override_0), + (QuantType.QUInt8, QuantType.QUInt8, QuantType.QUInt16, other_override_1), + (QuantType.QInt8, QuantType.QInt8, QuantType.QInt16, other_override_2), + (QuantType.QUInt16, QuantType.QUInt8, None, other_override_0), + (QuantType.QInt16, QuantType.QInt8, None, {}), + (QuantType.QUInt16, QuantType.QUInt16, None, other_override_0), + (QuantType.QInt16, QuantType.QInt16, None, {}), + (QuantType.QUInt8, QuantType.QUInt8, None, {}), + (QuantType.QUInt8, QuantType.QUInt8, None, convert_layernorm_input), + ] + + # Test if LayerNorm's weight input is overridden. 
+ for default_act_qtype, default_wgt_qtype, layernorm_input_qtype, other_override in subtest_configs: + with self.subTest( + default_act_qtype=default_act_qtype, + default_wgt_qtype=default_wgt_qtype, + layernorm_input_qtype=layernorm_input_qtype, + other_override=other_override, + ): + init_overrides = {} + init_overrides.update(other_override) + + if layernorm_input_qtype is not None: + init_overrides["abs_0_out"] = [{"quant_type": layernorm_input_qtype}] + + qnn_config = get_qnn_qdq_config( + float_model_path, + DummyDataReader([]), + activation_type=default_act_qtype, + weight_type=default_wgt_qtype, + init_overrides=(init_overrides if init_overrides else None), + add_qtype_converts=False, + ) + + self.assertEqual(set(qnn_config.op_types_to_quantize), {"Abs", "LayerNormalization"}) + input_is_16bit = ( + (default_act_qtype in q16_qtypes) + or (layernorm_input_qtype in q16_qtypes) + or (other_override == convert_layernorm_input) + ) + weight_is_symmetric = default_wgt_qtype in symmetric_wgt_qtypes + + if input_is_16bit and default_wgt_qtype in q8_qtypes: + self.assertIn("TensorQuantOverrides", qnn_config.extra_options) + self.assertIn("weight", qnn_config.extra_options["TensorQuantOverrides"]) + self.assertEqual( + qnn_config.extra_options["TensorQuantOverrides"]["weight"], + [ + { + "quant_type": default_wgt_qtype, + "symmetric": weight_is_symmetric, + } + ], + ) + elif init_overrides: + self.assertIn("TensorQuantOverrides", qnn_config.extra_options) + self.assertNotIn("weight", qnn_config.extra_options["TensorQuantOverrides"]) + + self.assertEqual(weight_is_symmetric, qnn_config.extra_options["WeightSymmetric"]) + self.assertNotIn("bias", qnn_config.extra_options["TensorQuantOverrides"]) def test_get_qnn_qdq_config_ext_data(self): """ @@ -613,6 +1028,7 @@ def test_get_qnn_qdq_config_ext_data(self): ) qnn_config = get_qnn_qdq_config("add_ext_data.onnx", DummyDataReader(self.activations)) + self.assertEqual(set(qnn_config.op_types_to_quantize), {"Add"}) self.assertTrue(qnn_config.use_external_data_format) From 1a0ba3f69f5075754ecae9c92abce9360861a7a5 Mon Sep 17 00:00:00 2001 From: pengwa Date: Tue, 26 Mar 2024 13:09:20 +0800 Subject: [PATCH 09/11] Fix softmax export (#20057) ### Description Why we need to define softmax export logic here? For the usage `nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32)` in the model, https://github.com/huggingface/transformers/blob/76a33a10923ccc1074917f6b6a1e719e626b7dc9/src/transformers/models/mistral/modeling_mistral.py#L302 If dtype is specified, the input tensor is casted to dtype before the operation is performed. This is useful for preventing data type overflows. While existing ONNX exporter do the cast after the operation, which is not correct. (https://github.com/pytorch/pytorch/blob/cf06189a2d2785ac493bcd0d55e520af5a0e3b97/torch/onnx/symbolic_opset13.py#L27). This override can be a workaround before PyTorch fix the issues in coming releases. (TODO: pengwa - add PyTorch versions when the issue is fixed). @thiagocrepaldi We may need a fix in PyTorch repo as well. 
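For reference, a minimal PyTorch sketch of the semantics described above (illustrative only, not part of this patch; it only uses the standard `torch.nn.functional.softmax` API):

```python
import torch

x = torch.randn(2, 4, dtype=torch.float16)

# PyTorch casts the input to `dtype` first, then applies Softmax.
expected = torch.nn.functional.softmax(x.to(torch.float32), dim=-1)
actual = torch.nn.functional.softmax(x, dim=-1, dtype=torch.float32)
assert torch.allclose(expected, actual)

# An export that emits Softmax followed by Cast would instead run Softmax in float16 and only
# then cast the result, which is not equivalent and defeats the overflow protection.
```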
### Motivation and Context --- .../ortmodule/_custom_op_symbolic_registry.py | 33 ++++++++++-- .../python/orttraining_test_ortmodule_api.py | 54 +++++++++++++++++++ 2 files changed, 84 insertions(+), 3 deletions(-) diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py b/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py index f81aef5f6b9c4..dd7fea3ceda10 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py @@ -10,7 +10,7 @@ from packaging import version from packaging.version import Version from torch.onnx import register_custom_op_symbolic -from torch.onnx.symbolic_helper import _get_tensor_dim_size, _get_tensor_sizes, parse_args +from torch.onnx.symbolic_helper import parse_args from onnxruntime.training.utils import pytorch_type_to_onnx_dtype @@ -176,9 +176,9 @@ def embedding(g, weight, indices, padding_idx, scale_grad_by_freq, sparse): try: # Tolerant to the case when sizes of indices are not available or not usable (for example # when DeepSpeed stage3 enabled, all weights size is (0), this will fail.) - indices_shape = _get_tensor_sizes(indices) + indices_shape = sym_help._get_tensor_sizes(indices) if indices_shape is not None and hasattr(weight.type(), "with_sizes"): - output_type = weight.type().with_sizes([*indices_shape, _get_tensor_dim_size(weight, 1)]) + output_type = weight.type().with_sizes([*indices_shape, sym_help._get_tensor_dim_size(weight, 1)]) output.setType(output_type) except IndexError: output.setType(weight.type()) @@ -845,3 +845,30 @@ def layer_norm(g, input, normalized_shape, weight, bias, eps, cudnn_enable): ) return res + + +# Adapted from torch.onnx.symbolic_opset13.softmax - +# https://github.com/pytorch/pytorch/blob/cf06189a2d2785ac493bcd0d55e520af5a0e3b97/torch/onnx/symbolic_opset13.py#L27 +# We don't need overloads symbolic_opset9 because training support opsets >= 13. +# +# Why we need to define softmax export logic here? +# For the usage `nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32)` in the model, +# https://github.com/huggingface/transformers/blob/76a33a10923ccc1074917f6b6a1e719e626b7dc9/src/transformers/models/mistral/modeling_mistral.py#L302 +# If dtype is specified, the input tensor is casted to dtype before the operation is performed. +# This is useful for preventing data type overflows. While existing ONNX exporter do the cast after the operation. +# This override can be a workaround before PyTorch fix the issues in coming releases. +# (TODO: pengwa - add PyTorch versions when the issue is fixed). 
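+#
+# The override below therefore emits Cast(input, to=dtype) followed by Softmax(axis=dim),
+# i.e. the cast is applied to the Softmax input rather than to its output.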
+@register_symbolic("softmax") +@parse_args("v", "i", "none") +def softmax(g, input, dim, dtype=None): + from torch.onnx import _type_utils + + casted_input = input + need_cast_for_compute = dtype and dtype.node().kind() != "prim::Constant" + if need_cast_for_compute: + parsed_dtype = sym_help._get_const(dtype, "i", "dtype") + casted_input = g.op("Cast", input, to_i=_type_utils.JitScalarType(parsed_dtype).onnx_type()) + + softmax = g.op("Softmax", casted_input, axis_i=dim) + + return softmax diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py index 7afad9145ed27..d6f55e787c320 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py @@ -33,6 +33,7 @@ from onnxruntime.training.ortmodule import DebugOptions, LogLevel, ORTModule, _fallback, _io, _utils from onnxruntime.training.ortmodule._custom_gradient_registry import register_gradient from onnxruntime.training.ortmodule.options import _SkipCheck +from onnxruntime.training.utils import pytorch_type_to_onnx_dtype DEFAULT_OPSET = 17 @@ -6496,3 +6497,56 @@ def run_step(model, x, y, z): torch.cuda.synchronize() if original_val is not None: os.environ["ORTMODULE_MEMORY_OPT_LEVEL"] = original_val + + +@pytest.mark.parametrize("softmax_compute_type", [torch.float16, torch.float32]) +def test_overridden_softmax_export(softmax_compute_type): + class CustomSoftmaxExportTest(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, attn_weight): + return torch.nn.functional.softmax(attn_weight, dim=-1, dtype=softmax_compute_type) + + device = "cuda" + pt_model = CustomSoftmaxExportTest().to(device) + ort_model = ORTModule( + copy.deepcopy(pt_model), DebugOptions(save_onnx=True, onnx_prefix="overridden_softmax_export") + ) + + def run_step(model, attn_weight): + prediction = model(attn_weight) + prediction.sum().backward() + return prediction + + # reset manual seed to reset the generator + torch.manual_seed(2333) + attn_weight = torch.randn([20, 6, 10, 10], dtype=torch.float, device=device, requires_grad=True) + ort_attn_weight = copy.deepcopy(attn_weight) + pt_prediction = run_step(pt_model, attn_weight) + ort_prediction = run_step(ort_model, ort_attn_weight) + + _test_helpers.assert_values_are_close(ort_prediction, pt_prediction) + _test_helpers.assert_values_are_close(attn_weight.grad, ort_attn_weight.grad) + _test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model) + + # Check the ONNX Softmax is running in float32. 
+ execution_mgr = ort_model._torch_module._execution_manager._training_manager + from onnxruntime.training.ortmodule._onnx_models import _get_onnx_file_name + + # Keep the logic aligned with _graph_execution_manager.py + path = os.path.join( + execution_mgr._debug_options.save_onnx_models.path, + _get_onnx_file_name( + execution_mgr._debug_options.save_onnx_models.name_prefix, "torch_exported", execution_mgr._export_mode + ), + ) + + onnx_model = onnx.load(path) + onnx_nodes = [n for n in onnx_model.graph.node] + + assert onnx_nodes[0].op_type == "Cast" + to_attr = onnx_nodes[0].attribute[0] + assert to_attr.name == "to" + to_value = to_attr.i + assert to_value == pytorch_type_to_onnx_dtype(softmax_compute_type), "Cast to attribute is not as expected" From 0906c57c9e1ec60adaba2bee115eaf04748dee5e Mon Sep 17 00:00:00 2001 From: Yi Zhang Date: Tue, 26 Mar 2024 17:59:46 +0800 Subject: [PATCH 10/11] Pin Onnx Version (#20073) ### Description 1. change in build.py is to fix DML exception (https://dev.azure.com/onnxruntime/onnxruntime/_build?definitionId=10&_a=summary) 2. change in requirements.txt is to fix exception in python packaging pipeline. https://dev.azure.com/aiinfra/Lotus/_build/results?buildId=430433&view=results ### Motivation and Context --------- Co-authored-by: Yi Zhang --- onnxruntime/test/python/requirements.txt | 4 ++-- tools/ci_build/build.py | 4 +++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/onnxruntime/test/python/requirements.txt b/onnxruntime/test/python/requirements.txt index e33fe0e4daded..dc158e0eebd19 100644 --- a/onnxruntime/test/python/requirements.txt +++ b/onnxruntime/test/python/requirements.txt @@ -1,2 +1,2 @@ -onnx -pytest \ No newline at end of file +onnx==1.15.0 +pytest diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index 3c1bdfc54c12e..e1649ae251d88 100644 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -2087,7 +2087,9 @@ def run_onnxruntime_tests(args, source_dir, ctest_path, build_dir, configs): run_subprocess( [sys.executable, "-m", "pip", "uninstall", "--yes", "onnx"], cwd=cwd, dll_path=dll_path ) - run_subprocess([sys.executable, "-m", "pip", "install", "-q", "onnx"], cwd=cwd, dll_path=dll_path) + run_subprocess( + [sys.executable, "-m", "pip", "install", "-q", "onnx==1.15.0"], cwd=cwd, dll_path=dll_path + ) run_subprocess([sys.executable, "onnxruntime_test_python_iobinding.py"], cwd=cwd, dll_path=dll_path) if args.use_cuda: From dfa891a2d8fab78cff51e27289774baf967214fd Mon Sep 17 00:00:00 2001 From: pengwa Date: Tue, 26 Mar 2024 21:25:59 +0800 Subject: [PATCH 11/11] Fix memory stats printing (#20061) ### Fix memory stats printing The mmeory stats printing is failed when module is in eval mode, doing ORTModule wrap. At that time, runtime inspector for training manager should have training model being true, but got a false (because existing logic get the boolean from module.training). Runtime inspector as part of training manager or inference manager should know it is serving training or not explicitly, so we cannot depend on the stat of module.training during ORTModule initialization. 
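A minimal sketch of the failing setup (illustrative only; it mirrors the new test added below and assumes a CUDA device with onnxruntime-training installed; the `Linear` model is just a stand-in):

```python
import copy
import os

import torch

from onnxruntime.training.ortmodule import DebugOptions, LogLevel, ORTModule

os.environ["ORTMODULE_PRINT_MEMORY_STATS"] = "1"

pt_model = torch.nn.Linear(4, 4).to("cuda")
pt_model.eval()  # module is in eval mode when it is wrapped ...
ort_model = ORTModule(copy.deepcopy(pt_model), DebugOptions(log_level=LogLevel.INFO))

ort_model.train()  # ... but is then used for training
loss = ort_model(torch.randn(2, 4, device="cuda")).sum()
loss.backward()
# With the old code, the training manager's runtime inspector was initialized with
# training=False (read from module.training at wrap time), which broke memory stats printing.
```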
### Motivation and Context --- .../ortmodule/_graph_execution_manager.py | 24 +++++++++---- .../training/ortmodule/_inference_manager.py | 3 +- .../training/ortmodule/_runtime_inspector.py | 27 +++++++++++--- .../training/ortmodule/_training_manager.py | 4 +-- .../python/orttraining_test_ortmodule_api.py | 36 +++++++++++++++++++ 5 files changed, 78 insertions(+), 16 deletions(-) diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py index 568c92b71277f..5123594bff387 100755 --- a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py @@ -54,10 +54,20 @@ def __init__( self, module: _FlattenedModule, debug_options: DebugOptions, + export_mode: int, fallback_manager: _FallbackManager, logger: logging.Logger, ): - """Manages construction and execution of ONNX graphs""" + """Manages construction and execution of ONNX graphs. + + Args: + module: The flatten PyTorch module to be executed. + debug_options: Debug options for ORTModule. + export_mode: export mode, should be torch.onnx.TrainingMode.TRAINING or torch.onnx.TrainingMode.EVAL. + fallback_manager: Fallback manager to handle exceptions. + logger: Logger for ORTModule. + + """ super().__init__(module._original_module) @@ -88,16 +98,12 @@ def __init__( self._first_skip_check_warning = True - # Inspector for runtime information, for example input data, memory usage, etc. - self._runtime_inspector = RuntimeInspector(self._logger, self._original_module) - self._runtime_inspector.memory_ob.enable_memory_stats_by_step(self._runtime_options.print_memory_stat_by_step) - # Tracker for ORTModule model export, session creation overhead. self.time_tracker = _logger.TimeTracker() # Value can be either torch.onnx.TrainingMode.TRAINING or torch.onnx.TrainingMode.EVAL # To be instantiated in the concrete implementation of GraphExecutionManager - self._export_mode = None + self._export_mode = export_mode # Exporter can take extra arguments for ORTModule extensions # It cannot overlap with required/immutable arguments (validated in runtime) @@ -129,6 +135,12 @@ def __init__( # Re-export will be avoided if _skip_check is enabled. self._original_model_has_changed = False + # Inspector for runtime information, for example input data, memory usage, etc. + self._runtime_inspector = RuntimeInspector( + self._logger, self._original_module, self._export_mode == torch.onnx.TrainingMode.TRAINING + ) + self._runtime_inspector.memory_ob.enable_memory_stats_by_step(self._runtime_options.print_memory_stat_by_step) + # Load ATen operator executor extension. 
load_aten_op_executor_cpp_extension() diff --git a/orttraining/orttraining/python/training/ortmodule/_inference_manager.py b/orttraining/orttraining/python/training/ortmodule/_inference_manager.py index 6690af9b71bf1..13145c7c79091 100644 --- a/orttraining/orttraining/python/training/ortmodule/_inference_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_inference_manager.py @@ -28,8 +28,7 @@ class InferenceManager(GraphExecutionManager): """ def __init__(self, model, debug_options: DebugOptions, fallback_manager: _FallbackManager, logger: Logger): - super().__init__(model, debug_options, fallback_manager, logger) - self._export_mode = torch.onnx.TrainingMode.EVAL + super().__init__(model, debug_options, torch.onnx.TrainingMode.EVAL, fallback_manager, logger) @staticmethod def execution_session_run_forward( diff --git a/orttraining/orttraining/python/training/ortmodule/_runtime_inspector.py b/orttraining/orttraining/python/training/ortmodule/_runtime_inspector.py index d3fe132609a90..5c86070430e81 100644 --- a/orttraining/orttraining/python/training/ortmodule/_runtime_inspector.py +++ b/orttraining/orttraining/python/training/ortmodule/_runtime_inspector.py @@ -46,11 +46,18 @@ class RuntimeInspector: Runtime inspector for ORTModule. """ - def __init__(self, logger: Logger, module: torch.nn.Module): + def __init__(self, logger: Logger, module: torch.nn.Module, training: bool): + """Initialize runtime inspector. + + Args: + logger: Logger. + module: Torch module. + training: a boolean indicating whether the module is in training mode. + """ self._logger = logger self.input_density_ob: Union[InputDensityObserver, None] = None - self.memory_ob = MemoryObserver(module, self._logger) + self.memory_ob = MemoryObserver(module, self._logger, training) def enable_input_inspector(self, model: ModelProto, user_input_names: List[str]) -> None: """Initialize input inspector from the given ONNX model and user input names. @@ -479,7 +486,14 @@ class MemoryObserver: NORMALIZER_FACTOR = float(1024 * 1024) NORMALIZER_UNIT = "MiB" - def __init__(self, m: torch.nn.Module, logger: Logger): + def __init__(self, m: torch.nn.Module, logger: Logger, training: bool): + """Initialize memory observer. + + Args: + m: Torch module. + logger: Logger. + training: a boolean indicating whether the module is in training mode. + """ self._logger = logger self._is_enabled = True @@ -503,7 +517,10 @@ def __init__(self, m: torch.nn.Module, logger: Logger): self._rank_info = f"[{self._rank}/{self._world_size}]" self._pre_phase = Phase.INVALID - self._last_phase = Phase.POST_BACKWARD if m.training else Phase.POST_FORWARD + + # Cannot infer it is for training or inferencing purpose from module.training, + # because it probabbly is not set correctly when this happens. 
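+        # The flag is therefore passed in explicitly by the graph execution manager, which knows
+        # the export mode it was constructed with.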
+ self._last_phase = Phase.POST_BACKWARD if training else Phase.POST_FORWARD self._is_first_inspect = True @@ -721,7 +738,7 @@ def _get_user_config_without_freq(configs: str): notes.append(saving_recommendation) saving_recommendation = ( - "[Memory Optimizer] memory saving is calculated based on the 1st batch symbolic dim values:\n" + "[Memory Optimizer] Memory saving is calculated based on the 1st batch symbolic dim values:\n" ) for dim_param, dim_value in self.symbolic_dim_name_to_value_map.items(): saving_recommendation += f" {dim_param}={dim_value}," diff --git a/orttraining/orttraining/python/training/ortmodule/_training_manager.py b/orttraining/orttraining/python/training/ortmodule/_training_manager.py index 5fa332d12f01c..a7426bce38a40 100644 --- a/orttraining/orttraining/python/training/ortmodule/_training_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_training_manager.py @@ -38,9 +38,7 @@ def __init__( fallback_manager: _FallbackManager, logger: Logger, ): - super().__init__(model, debug_options, fallback_manager, logger) - - self._export_mode = torch.onnx.TrainingMode.TRAINING + super().__init__(model, debug_options, torch.onnx.TrainingMode.TRAINING, fallback_manager, logger) self._forward_class = self._create_autofunction_class() @staticmethod diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py index d6f55e787c320..da217eb76949c 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py @@ -6499,6 +6499,42 @@ def run_step(model, x, y, z): os.environ["ORTMODULE_MEMORY_OPT_LEVEL"] = original_val +def test_bert_memory_inspection(caplog): + original_val = os.environ.get("ORTMODULE_PRINT_MEMORY_STATS", None) + + # Create PyTorch model with dropout disabled. + pt_model = _get_bert_for_sequence_classification_model( + "cuda", is_training=True, hidden_dropout_prob=0.0, attention_probs_dropout_prob=0.0 + ) + + os.environ["ORTMODULE_PRINT_MEMORY_STATS"] = "1" + pt_model.eval() # Put it in evaluate mode by intention, in case some initialization in ORTModule use the module.is_training for its checks by mistake. + ort_model = ORTModule( + copy.deepcopy(pt_model), DebugOptions(log_level=LogLevel.INFO) # The logged memory info is in INFO level. + ) + + def run_step(model, x, y, z): + outputs = model(x, y, None, None, None, None, z) + loss = outputs[0] + loss.backward() + + ort_model.train() + for _ in range(32): + x, y, z = _get_bert_for_sequence_classification_sample_data_with_random_shapes("cuda") + run_step(ort_model, x, y, z) + + info_records = [ + record.message for record in caplog.records if record.levelname == "INFO" and "(MiB) | phase:" in record.message + ] + + assert len(info_records) == 4 * 11 + + # Make sure environment variable is restored to its original value after the run is completed. + torch.cuda.synchronize() + if original_val is not None: + os.environ["ORTMODULE_PRINT_MEMORY_STATS"] = original_val + + @pytest.mark.parametrize("softmax_compute_type", [torch.float16, torch.float32]) def test_overridden_softmax_export(softmax_compute_type): class CustomSoftmaxExportTest(torch.nn.Module):