From 2bc29244b4b6992667d06446c839426917945a29 Mon Sep 17 00:00:00 2001 From: Baiju Meswani Date: Fri, 22 Mar 2024 10:28:44 -0700 Subject: [PATCH 1/7] Support model with multiple SCE loss nodes (#20016) --- .../orttraining/core/framework/gradient_graph_builder.cc | 5 +++++ .../orttraining/core/optimizer/insert_output_rewriter.cc | 2 +- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/orttraining/orttraining/core/framework/gradient_graph_builder.cc b/orttraining/orttraining/core/framework/gradient_graph_builder.cc index d66591318d5c7..2ee4b5e1a173d 100644 --- a/orttraining/orttraining/core/framework/gradient_graph_builder.cc +++ b/orttraining/orttraining/core/framework/gradient_graph_builder.cc @@ -210,6 +210,11 @@ NodeSet GradientGraphBuilder::ReverseBFSWithStopGradient(const NodeSet& nodes) c continue; } const NodeArg* node_arg = n->InputDefs()[edge_it->GetDstArgIndex()]; + if (!node_arg) { + LOGS(logger_, VERBOSE) << "Skip building gradient for input_" << edge_it->GetDstArgIndex() + << " of node: " << n->Name() << " because it is not found in the graph."; + continue; + } const auto [is_tensor_type, is_allowed_type_for_grad, type] = IsAllowedForGradient(graph_, node_arg); if (is_tensor_type) { if (!is_allowed_type_for_grad) { diff --git a/orttraining/orttraining/core/optimizer/insert_output_rewriter.cc b/orttraining/orttraining/core/optimizer/insert_output_rewriter.cc index 2aade8c9bc1f9..61fc8d5492c2b 100644 --- a/orttraining/orttraining/core/optimizer/insert_output_rewriter.cc +++ b/orttraining/orttraining/core/optimizer/insert_output_rewriter.cc @@ -44,7 +44,7 @@ Status InsertSoftmaxCrossEntropyLossOutput::Apply(Graph& graph, Node& node, Rewr t.mutable_tensor_type()->mutable_shape()->CopyFrom(*X->Shape()); // log probability should have the same shape as logits. } - NodeArg& node_arg = graph.GetOrCreateNodeArg(X->Name() + "_log_prob", &t); + NodeArg& node_arg = graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(X->Name() + "_log_prob"), &t); outputs.push_back(&node_arg); From 7e84ba0ea30f3642c75d8d3fce5626766ce5a20e Mon Sep 17 00:00:00 2001 From: Abhishek Jindal Date: Fri, 22 Mar 2024 10:39:19 -0700 Subject: [PATCH 2/7] remove const cast for DLManagedTensor (#20015) ### Description Removing const_cast as it might lead to undefined behavior. Specifying DLManagedTensor as const doesn't seem to be necessary, and I have tested this by running torch_ort.configure. Not sure what other tests need to be done.
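For reference, the DLPack ownership model is what makes a non-const parameter the natural choice: the consumer of a `DLManagedTensor` takes ownership and later invokes its deleter, so the handle is inherently mutable. A minimal Python sketch (illustrative only, not part of this change) using the standard `torch.utils.dlpack` API:

```python
# Illustrative sketch (not part of this change): the DLPack consumer takes
# ownership of the DLManagedTensor and triggers its deleter when done.
import torch
from torch.utils.dlpack import from_dlpack, to_dlpack

src = torch.arange(4, dtype=torch.float32)
capsule = to_dlpack(src)      # producer wraps src in a DLManagedTensor capsule
dst = from_dlpack(capsule)    # consumer takes ownership; deleter runs when dst is freed
assert torch.equal(src, dst)  # zero-copy: both tensors share the same storage
```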
Background can be found in this [PR](https://github.com/microsoft/onnxruntime/pull/19982) ### Motivation and Context --- .../torch_cpp_extensions/aten_op_executor/aten_op_executor.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime/python/torch_cpp_extensions/aten_op_executor/aten_op_executor.cc b/onnxruntime/python/torch_cpp_extensions/aten_op_executor/aten_op_executor.cc index 4148e63d58619..f4d2f68d4d8b5 100644 --- a/onnxruntime/python/torch_cpp_extensions/aten_op_executor/aten_op_executor.cc +++ b/onnxruntime/python/torch_cpp_extensions/aten_op_executor/aten_op_executor.cc @@ -36,7 +36,7 @@ struct ATenOperator { size_t return_size; std::vector<c10::TypeKind> ret_kinds; - c10::IValue ToIValueArgument(const DLManagedTensor* dlpack, size_t index) const { + c10::IValue ToIValueArgument(DLManagedTensor* dlpack, size_t index) const { TORCH_INTERNAL_ASSERT(index < argument_size); bool is_optional = is_optional_arguments[index]; TORCH_INTERNAL_ASSERT(dlpack || is_optional || default_values[index] || @@ -57,7 +57,7 @@ struct ATenOperator { c10::IValue i_value; // Create the torch tensor from this DLPack no matter we need it or not below, // so that the dlpack's deleter will be triggered when torch tensor is out of scope. - at::Tensor tensor = at::fromDLPack(const_cast<DLManagedTensor*>(dlpack)); + at::Tensor tensor = at::fromDLPack(dlpack); switch (elem_kinds[index]) { case c10::TypeKind::TensorType: { i_value = is_optional ? c10::IValue(c10::optional<at::Tensor>(tensor)) : c10::IValue(tensor); From f9cddd2cf5730bb330dc417ba461a684ea678444 Mon Sep 17 00:00:00 2001 From: kunal-vaishnavi <115581922+kunal-vaishnavi@users.noreply.github.com> Date: Fri, 22 Mar 2024 14:44:34 -0700 Subject: [PATCH 3/7] Remove early stopping from LLaMA end-to-end benchmarking (#20033) ### Description This PR removes early stopping from the end-to-end LLaMA-2 benchmark script. ### Motivation and Context This allows models to always generate the requested number of new tokens. --- .../python/tools/transformers/models/llama/benchmark_e2e.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/onnxruntime/python/tools/transformers/models/llama/benchmark_e2e.py b/onnxruntime/python/tools/transformers/models/llama/benchmark_e2e.py index 4d0d2e68e8983..47b7f35cbdd7c 100644 --- a/onnxruntime/python/tools/transformers/models/llama/benchmark_e2e.py +++ b/onnxruntime/python/tools/transformers/models/llama/benchmark_e2e.py @@ -400,11 +400,7 @@ def main(): sampling_times.append(sampling_end_time - sampling_start_time) all_token_ids = torch.cat([all_token_ids, tokens_to_add], dim=-1) - - # Return early if all batch entries have reached EOS token id current_length += 1 - if torch.all(has_eos) or current_length > max_length: - break # Update inputs for next inference run inputs["input_ids"] = tokens_to_add From 3076b569472d0cbdae5e3657e3c267a63830b2b3 Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Fri, 22 Mar 2024 16:17:47 -0700 Subject: [PATCH 4/7] Make MS Debug engine SymInitialize() called as needed. (#20036) ### Description Initialize the Symbol engine as needed with no duplicate calls. ### Motivation and Context Currently the abseil library may call SymInitialize more than once when shared libraries are involved. However, this can only be called once per process. Our debug_alloc may also call it when enabled. This change enables initialization to proceed only when needed, with no duplicate effort.
--- cmake/patches/abseil/absl_windows.patch | 98 ++++++++++++++++--- .../core/platform/windows/debug_alloc.cc | 64 ++++++++---- 2 files changed, 131 insertions(+), 31 deletions(-) diff --git a/cmake/patches/abseil/absl_windows.patch b/cmake/patches/abseil/absl_windows.patch index 66ef0c5125a74..584c49d612293 100644 --- a/cmake/patches/abseil/absl_windows.patch +++ b/cmake/patches/abseil/absl_windows.patch @@ -25,17 +25,91 @@ index a6efc98e..8c4de8e7 100644 "/wd4800", ] diff --git a/absl/copts/copts.py b/absl/copts/copts.py -index 0d6c1ec3..75fd935f 100644 +index e6e11949..0aa7d868 100644 --- a/absl/copts/copts.py +++ b/absl/copts/copts.py -@@ -132,10 +132,6 @@ COPT_VARS = { - "/wd4068", # unknown pragma - # qualifier applied to function type has no meaning; ignored - "/wd4180", -- # conversion from 'type1' to 'type2', possible loss of data -- "/wd4244", -- # conversion from 'size_t' to 'type', possible loss of data -- "/wd4267", - # The decorated name was longer than the compiler limit - "/wd4503", - # forcing value to bool 'true' or 'false' (performance warning) +@@ -115,10 +115,6 @@ MSVC_WARNING_FLAGS = [ + "/wd4068", # unknown pragma + # qualifier applied to function type has no meaning; ignored + "/wd4180", +- # conversion from 'type1' to 'type2', possible loss of data +- "/wd4244", +- # conversion from 'size_t' to 'type', possible loss of data +- "/wd4267", + # The decorated name was longer than the compiler limit + "/wd4503", + # forcing value to bool 'true' or 'false' (performance warning) +diff --git a/absl/debugging/symbolize_win32.inc b/absl/debugging/symbolize_win32.inc +index 53a099a1..34d210d6 100644 +--- a/absl/debugging/symbolize_win32.inc ++++ b/absl/debugging/symbolize_win32.inc +@@ -35,15 +35,15 @@ ABSL_NAMESPACE_BEGIN + + static HANDLE process = NULL; + +-void InitializeSymbolizer(const char*) { +- if (process != nullptr) { +- return; +- } ++namespace { ++void InitializeSymbolizerImpl() { ++ + process = GetCurrentProcess(); + + // Symbols are not loaded until a reference is made requiring the + // symbols be loaded. This is the fastest, most efficient way to use + // the symbol handler. ++ + SymSetOptions(SYMOPT_DEFERRED_LOADS | SYMOPT_UNDNAME); + if (!SymInitialize(process, nullptr, true)) { + // GetLastError() returns a Win32 DWORD, but we assign to +@@ -54,6 +54,36 @@ void InitializeSymbolizer(const char*) { + } + } + ++bool LookupAndInitialize(const void* pc, SYMBOL_INFO* symbol) { ++ auto hProcess = (process != NULL) ? 
process : GetCurrentProcess(); ++ if (SymFromAddr(hProcess, reinterpret_cast(pc), nullptr, symbol) != TRUE) { ++ if (GetLastError() == ERROR_INVALID_HANDLE && process == NULL) { ++ InitializeSymbolizerImpl(); ++ if (SymFromAddr(process, reinterpret_cast(pc), nullptr, symbol) != TRUE) { ++ return false; ++ } ++ } else { ++ return false; ++ } ++ return false; ++ } ++ return true; ++} ++} ++ ++void InitializeSymbolizer(const char*) { ++ if (process != nullptr) { ++ return; ++ } ++ ++ alignas(SYMBOL_INFO) char buf[sizeof(SYMBOL_INFO) + MAX_SYM_NAME]; ++ SYMBOL_INFO* symbol = reinterpret_cast(buf); ++ symbol->SizeOfStruct = sizeof(SYMBOL_INFO); ++ symbol->MaxNameLen = MAX_SYM_NAME; ++ ++ static_cast(LookupAndInitialize(reinterpret_cast(&InitializeSymbolizer), symbol)); ++} ++ + bool Symbolize(const void* pc, char* out, int out_size) { + if (out_size <= 0) { + return false; +@@ -62,9 +92,11 @@ bool Symbolize(const void* pc, char* out, int out_size) { + SYMBOL_INFO* symbol = reinterpret_cast(buf); + symbol->SizeOfStruct = sizeof(SYMBOL_INFO); + symbol->MaxNameLen = MAX_SYM_NAME; +- if (!SymFromAddr(process, reinterpret_cast(pc), nullptr, symbol)) { ++ ++ if(!LookupAndInitialize(pc, symbol)) { + return false; + } ++ + const size_t out_size_t = static_cast(out_size); + strncpy(out, symbol->Name, out_size_t); + if (out[out_size_t - 1] != '\0') { diff --git a/onnxruntime/core/platform/windows/debug_alloc.cc b/onnxruntime/core/platform/windows/debug_alloc.cc index ff6a059607367..f3520b4f7f7f5 100644 --- a/onnxruntime/core/platform/windows/debug_alloc.cc +++ b/onnxruntime/core/platform/windows/debug_alloc.cc @@ -55,41 +55,67 @@ struct MemoryBlock { }; struct SymbolHelper { - SymbolHelper() noexcept { - SymSetOptions(SymGetOptions() | SYMOPT_DEFERRED_LOADS); - SymInitialize(GetCurrentProcess(), nullptr, true); + HANDLE process_handle_ = GetCurrentProcess(); + bool initialized_ = false; + + bool InitializeWhenNeeded() { + // We try only once + if (!initialized_) { + SymSetOptions(SymGetOptions() | SYMOPT_DEFERRED_LOADS); + // We use GetCurrentProcess() because other libs are likely to use it + if (!SymInitialize(process_handle_, nullptr, true)) { + const unsigned long long error{GetLastError()}; + std::cerr << "SymInitialize() failed: " << error << std::endl; + return false; + } + initialized_ = true; + } + return true; + } + + SymbolHelper() = default; + + static constexpr size_t kInitialBufferSize = sizeof(SYMBOL_INFO) + MAX_SYM_NAME; + + bool LoookupSymAndInitialize(const ULONG_PTR address, char* buffer, size_t buffer_size, SYMBOL_INFO* symbol) { + if (SymFromAddr(process_handle_, address, 0, symbol) != TRUE) { + if (GetLastError() == ERROR_INVALID_HANDLE) { + // Try to initialize first + if (!InitializeWhenNeeded() || SymFromAddr(process_handle_, address, 0, symbol) != TRUE) { + _snprintf_s(buffer, buffer_size, _TRUNCATE, "0x%08IX (Unknown symbol)", address); + return false; + } + } else { + _snprintf_s(buffer, buffer_size, _TRUNCATE, "0x%08IX (Unknown symbol)", address); + return false; + } + } + return true; } void Lookup(std::string& string, const ULONG_PTR address) { - char buffer[2048] = {0}; - Symbol symbol; - if (SymFromAddr(GetCurrentProcess(), address, 0, &symbol) == false) { - _snprintf_s(buffer, _TRUNCATE, "0x%08IX (Unknown symbol)", address); + alignas(SYMBOL_INFO) char buffer[kInitialBufferSize] = {0}; + SYMBOL_INFO* symbol = reinterpret_cast(buffer); + symbol->SizeOfStruct = sizeof(SYMBOL_INFO); + symbol->MaxNameLen = MAX_SYM_NAME; + + if (!LoookupSymAndInitialize(address, buffer, 
kInitialBufferSize, symbol)) { string.append(buffer); return; } Line line; DWORD displacement; - if (SymGetLineFromAddr(GetCurrentProcess(), address, &displacement, &line) == false) { - _snprintf_s(buffer, _TRUNCATE, "(unknown file & line number): %s", symbol.Name); + if (SymGetLineFromAddr(process_handle_, address, &displacement, &line) == false) { + _snprintf_s(buffer, _TRUNCATE, "(unknown file & line number): %s", symbol->Name); string.append(buffer); return; } - _snprintf_s(buffer, _TRUNCATE, "%s(%d): %s", line.FileName, static_cast(line.LineNumber), symbol.Name); + _snprintf_s(buffer, _TRUNCATE, "%s(%d): %s", line.FileName, static_cast(line.LineNumber), symbol->Name); string.append(buffer); } - struct Symbol : SYMBOL_INFO { - Symbol() noexcept { - SizeOfStruct = sizeof(SYMBOL_INFO); - MaxNameLen = _countof(buffer); - } - - char buffer[1024] = {0}; - }; - struct Line : IMAGEHLP_LINE { Line() noexcept { SizeOfStruct = sizeof(IMAGEHLP_LINE); From 71551dacd510a9b85d6ef9fa12af319fa4687592 Mon Sep 17 00:00:00 2001 From: Xiaoyu <85524621+xiaoyu-work@users.noreply.github.com> Date: Fri, 22 Mar 2024 18:40:58 -0700 Subject: [PATCH 5/7] Add ModelProto support for transformers optimize_model (#19990) ### Description Add `ModelProto` support as an input to transformers `optimize_model` API. ### Motivation and Context Currently, the `optimize_model` API only accepts a model path as the input model. However, for large models, saving and loading from disk can be time-consuming. By adding `ModelProto` as an input option to the `optimize_model` API, significant time can be saved. --- .../python/tools/transformers/onnx_utils.py | 55 +++++++++++++++ .../python/tools/transformers/optimizer.py | 69 +++++++++++++------ .../python/transformers/test_onnx_utils.py | 38 ++++++++++ 3 files changed, 140 insertions(+), 22 deletions(-) create mode 100644 onnxruntime/python/tools/transformers/onnx_utils.py create mode 100644 onnxruntime/test/python/transformers/test_onnx_utils.py diff --git a/onnxruntime/python/tools/transformers/onnx_utils.py b/onnxruntime/python/tools/transformers/onnx_utils.py new file mode 100644 index 0000000000000..64fade9369395 --- /dev/null +++ b/onnxruntime/python/tools/transformers/onnx_utils.py @@ -0,0 +1,55 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +from fusion_utils import NumpyHelper +from onnx import ModelProto, TensorProto +from onnx.external_data_helper import set_external_data +from onnx_model import OnnxModel + +from onnxruntime import OrtValue + + +def extract_raw_data_from_model(model: ModelProto): + """ + Extract external data from model and return the external data as a list of tuples (name, value). + Note this function does not handle external data that is not loaded into the model as raw data. + + Args: + model (ModelProto): the model proto to extract external data from. + Returns: + (external_names, external_values): a tuple of two lists of external data names and values. 
+ """ + external_data = [] + onnx_model = OnnxModel(model) + for graph in onnx_model.graphs(): + for initializer in graph.initializer: + name = initializer.name + + if initializer.HasField("raw_data"): + numpy_tensor = NumpyHelper.to_array(initializer) + ort_value = OrtValue.ortvalue_from_numpy(numpy_tensor) + external_data.append((name, ort_value)) + # mimic set_external_data + set_external_data(initializer, location="foo.bin") + initializer.name = name + initializer.ClearField("raw_data") + + return zip(*external_data) + + +def has_external_data(model: ModelProto): + """ + Check if the model has external data. + + Args: + model (ModelProto): the model proto to check for external data. + Returns: + bool: True if the model has external data, False otherwise. + """ + onnx_model = OnnxModel(model) + for graph in onnx_model.graphs(): + for initializer in graph.initializer: + if initializer.HasField("data_location") and initializer.data_location == TensorProto.EXTERNAL: + return True + return False diff --git a/onnxruntime/python/tools/transformers/optimizer.py b/onnxruntime/python/tools/transformers/optimizer.py index ce0be6b3449ed..068ccefef7d97 100644 --- a/onnxruntime/python/tools/transformers/optimizer.py +++ b/onnxruntime/python/tools/transformers/optimizer.py @@ -21,11 +21,12 @@ import logging import os import tempfile -from typing import Dict, List, Optional +from pathlib import Path +from typing import Dict, List, Optional, Union import coloredlogs from fusion_options import FusionOptions -from onnx import ModelProto, TensorProto, load_model +from onnx import ModelProto, load_model from onnx_model import OnnxModel from onnx_model_bart import BartOnnxModel from onnx_model_bert import BertOnnxModel @@ -40,6 +41,9 @@ from onnx_model_unet import UnetOnnxModel from onnx_model_vae import VaeOnnxModel +import onnxruntime +from onnxruntime.transformers.onnx_utils import extract_raw_data_from_model, has_external_data + logger = logging.getLogger(__name__) # Map model type to tuple: optimizer class, export tools (pytorch, tf2onnx, keras2onnx), and default opt_level @@ -64,7 +68,7 @@ def optimize_by_onnxruntime( - onnx_model_path: str, + onnx_model: Union[str, ModelProto], use_gpu: bool = False, optimized_model_path: Optional[str] = None, opt_level: Optional[int] = 99, @@ -80,7 +84,7 @@ def optimize_by_onnxruntime( Use onnxruntime to optimize model. Args: - onnx_model_path (str): the path of input onnx model. + onnx_model (str | ModelProto): the path of input onnx model or ModelProto. use_gpu (bool): whether the optimized model is targeted to run in GPU. optimized_model_path (str or None): the path of optimized model. opt_level (int): graph optimization level. @@ -95,8 +99,6 @@ def optimize_by_onnxruntime( assert opt_level in [1, 2, 99] from torch import version as torch_version - import onnxruntime - if ( use_gpu and provider is None @@ -105,9 +107,13 @@ def optimize_by_onnxruntime( ) ): logger.error("There is no gpu for onnxruntime to do optimization.") - return onnx_model_path + return onnx_model - model = OnnxModel(load_model(onnx_model_path, load_external_data=False)) + model = ( + OnnxModel(load_model(onnx_model, load_external_data=False)) + if isinstance(onnx_model, str) + else OnnxModel(onnx_model) + ) if model.use_float16() and not use_gpu: logger.warning( "This model uses float16 in the graph, use_gpu=False might cause extra Cast nodes. 
" @@ -125,7 +131,10 @@ def optimize_by_onnxruntime( sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL if optimized_model_path is None: - path_prefix = onnx_model_path[:-5] # remove .onnx suffix + if isinstance(onnx_model, str): + path_prefix = str(Path(onnx_model).with_suffix("")) # remove .onnx suffix + else: + path_prefix = "optimized_model" optimized_model_path = "{}_o{}_{}.onnx".format(path_prefix, opt_level, "gpu" if use_gpu else "cpu") sess_options.optimized_model_filepath = optimized_model_path @@ -174,7 +183,20 @@ def optimize_by_onnxruntime( else: providers.append("CUDAExecutionProvider") - onnxruntime.InferenceSession(onnx_model_path, sess_options, providers=providers, **kwargs) + # For large model, extract external data from model and add to session options + if isinstance(onnx_model, ModelProto): + if has_external_data(onnx_model): + raise ValueError( + "ModelProto has external data not loaded into memory, ORT cannot create session. " + "Please load external data before calling this function. " + "See https://onnx.ai/onnx/repo-docs/ExternalData.html for more information." + ) + external_names, external_values = extract_raw_data_from_model(onnx_model) + sess_options.add_external_initializers(list(external_names), list(external_values)) + + # Inference session is only used to optimize the model. + onnx_model = onnx_model.SerializeToString() if isinstance(onnx_model, ModelProto) else onnx_model + onnxruntime.InferenceSession(onnx_model, sess_options, providers=providers, **kwargs) assert os.path.exists(optimized_model_path) and os.path.isfile(optimized_model_path) logger.debug("Save optimized model by onnxruntime to %s", optimized_model_path) @@ -187,7 +209,7 @@ def optimize_by_fusion( num_heads: int = 0, hidden_size: int = 0, optimization_options: Optional[FusionOptions] = None, -): +) -> OnnxModel: """Optimize Model by graph fusion logic. Note that ONNXRuntime graph optimizations (like constant folding) will not be applied. So it is better to enable @@ -241,7 +263,7 @@ def optimize_by_fusion( def optimize_model( - input: str, + input: Union[str, ModelProto], model_type: str = "bert", num_heads: int = 0, hidden_size: int = 0, @@ -252,7 +274,7 @@ def optimize_model( verbose: bool = False, *, provider: Optional[str] = None, -): +) -> OnnxModel: """Optimize Model by OnnxRuntime and/or python fusion logic. ONNX Runtime has graph optimizations (https://onnxruntime.ai/docs/performance/model-optimizations/graph-optimizations.html). @@ -275,7 +297,7 @@ def optimize_model( For BERT model, num_heads and hidden_size are optional. For other model types, you need specify these parameters. Args: - input (str): input model path. + input (str | ModelProto): input model path or ModelProto. model_type (str, optional): model type - like bert, bert_tf, bert_keras or gpt2. Defaults to 'bert'. num_heads (int, optional): number of attention heads. Defaults to 0. 0 allows detect the parameter from graph automatically. 
@@ -298,9 +320,9 @@ def optimize_model( if model_type not in MODEL_TYPES: logger.warning(f"Unsupported model type: {model_type} for optimization, directly return model.") - return OnnxModel(load_model(input)) + return OnnxModel(load_model(input)) if isinstance(input, str) else OnnxModel(input) - (optimizer_class, _producer, default_opt_level) = MODEL_TYPES[model_type] + (optimizer_class, _, default_opt_level) = MODEL_TYPES[model_type] if opt_level is None: opt_level = default_opt_level @@ -316,11 +338,9 @@ def optimize_model( # Auto detect if input model has external data has_external_data_file = False - original_model = load_model(input, load_external_data=False) - for initializer in original_model.graph.initializer: - if initializer.HasField("data_location") and initializer.data_location == TensorProto.EXTERNAL: - has_external_data_file = True - break + original_model = load_model(input, load_external_data=False) if isinstance(input, str) else input + if has_external_data(original_model): + has_external_data_file = True del original_model if opt_level > 1: @@ -365,7 +385,12 @@ def optimize_model( if only_onnxruntime and not temp_model_path: logger.warning("Please specify a positive value for opt_level when only_onnxruntime is True") - model = load_model(temp_model_path or input) + if temp_model_path is not None: + model = load_model(temp_model_path) + elif isinstance(input, str): + model = load_model(input) + else: + model = input if only_onnxruntime: optimizer = optimizer_class(model, num_heads, hidden_size) diff --git a/onnxruntime/test/python/transformers/test_onnx_utils.py b/onnxruntime/test/python/transformers/test_onnx_utils.py new file mode 100644 index 0000000000000..974991359795e --- /dev/null +++ b/onnxruntime/test/python/transformers/test_onnx_utils.py @@ -0,0 +1,38 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
+# -------------------------------------------------------------------------- +import unittest + +import numpy +from onnx import ModelProto, TensorProto, helper +from onnx.external_data_helper import set_external_data + +from onnxruntime.transformers.onnx_utils import extract_raw_data_from_model, has_external_data + + +class TestOnnxUtils(unittest.TestCase): + def test_extract_raw_data_from_model(self): + model = self._get_model_proto_with_raw_data(False) + external_names, external_values = extract_raw_data_from_model(model) + self.assertEqual(list(external_names), ["inputs"]) + self.assertEqual(len(external_values), 1) + self.assertEqual(external_values[0].numpy(), [0.0]) + + def test_has_external_data(self): + model = self._get_model_proto_with_raw_data() + self.assertTrue(has_external_data(model)) + + def test_has_external_data_with_no_external_data(self): + model = self._get_model_proto_with_raw_data(False) + self.assertFalse(has_external_data(model)) + + def _get_model_proto_with_raw_data(self, has_external_data: bool = True) -> ModelProto: + input = helper.make_tensor_value_info("inputs", TensorProto.FLOAT, [None]) + output = helper.make_tensor_value_info("outputs", TensorProto.FLOAT, [None]) + raw_data = numpy.array([0.0], dtype=numpy.float32).tobytes() + tensor = helper.make_tensor("inputs", TensorProto.FLOAT, [1], raw_data, True) + if has_external_data: + set_external_data(tensor, location="foo.bin") + node = helper.make_node("Identity", inputs=["inputs"], outputs=["outputs"]) + return helper.make_model(helper.make_graph([node], "graph", [input], [output], initializer=[tensor])) From 3b4b99b90b7de7848e5c1e817ad19b32bf598b27 Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Sat, 23 Mar 2024 08:53:50 -0700 Subject: [PATCH 6/7] Fix a bug in WASM's GEMM (#20023) ### Description Fix a bug in WASM's GEMM. The bug was found when running "ConvAddActivationFusionTests.ConvGemmDirect" unit test in a wasm build with address sanitizer enabled. When CountK=25, CountN=1, lda=25, ldc=1, the function I am modifying triggered a read out of bound error. The bug fix was provided by @fs-eire. --- onnxruntime/core/mlas/lib/wasm_simd/SgemmKernelWasmSimd.cpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/onnxruntime/core/mlas/lib/wasm_simd/SgemmKernelWasmSimd.cpp b/onnxruntime/core/mlas/lib/wasm_simd/SgemmKernelWasmSimd.cpp index 955b7c5deee9a..43a12b37e4ffa 100644 --- a/onnxruntime/core/mlas/lib/wasm_simd/SgemmKernelWasmSimd.cpp +++ b/onnxruntime/core/mlas/lib/wasm_simd/SgemmKernelWasmSimd.cpp @@ -171,11 +171,9 @@ Return Value: if (k > 0) { Row0AElements0 = a[0]; - Row0AElements1 = a[1]; if (ProcessTwoRows) { Row1AElements0 = a[lda]; - Row1AElements1 = a[lda + 1]; } BElements0 = MlasLoadFloat32x4(B + 0); From cdc5d72ba9dfcba38462d7fcfa7047fd6005fa5a Mon Sep 17 00:00:00 2001 From: Adrian Lizarraga Date: Sat, 23 Mar 2024 11:05:08 -0700 Subject: [PATCH 7/7] [QDQ Quant] Support mixed-precision integer quantization via overrides (#19925) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### Description Adds support for specifying mixed precision QDQ models via tensor quantization overrides. ### Motivation and Context This PR implements an approach for supported "mixed precision" models. The following figure demonstrates an example mixed precision model as defined in this PR. 
![image](https://github.com/microsoft/onnxruntime/assets/19691973/40ae3bf9-b21a-4ba5-a1cd-41c1e08c21e7) A mixed precision QDQ model consists of regions with different activation/weight quantization data types. The boundary between regions converts between activation quantization data types (e.g., uint8 to uint16) using a DQ to Q sequence. The ability to specify regions with different quantization data types enables exploring the tradeoffs between accuracy and latency. A higher integer precision may improve accuracy at the expense of latency, so selectively promoting certain regions to a higher precision can aid in achieving a desirable balance in key metrics. #### Current support By default, the ORT quantizer supports specifying default activation and weight quantization data types for the entire model. A recent PR added support for specifying basic quantization overrides at the tensor level via the `extra_options["TensorQuantOverrides"]` configuration: ``` TensorQuantOverrides = dictionary : Default is {}. Set tensor quantization overrides. The key is a tensor name and the value is a list of dictionaries. For per-tensor quantization, the list contains a single dictionary. For per-channel quantization, the list contains a dictionary for each channel in the tensor. Each dictionary contains optional overrides with the following keys and values. 'quant_type' = QuantType : The tensor's quantization data type. 'scale' = Float : The scale value to use. Must also specify `zero_point` if set. 'zero_point' = Int : The zero-point value to use. Must also specify `scale` is set. 'symmetric' = Bool : If the tensor should use symmetric quantization. Invalid if also set `scale` or `zero_point`. 'reduce_range' = Bool : If the quantization range should be reduced. Invalid if also set `scale` or `zero_point`. 'rmax' = Float : Override the maximum real tensor value in calibration data. Invalid if also set `scale` or `zero_point`. 'rmin' = Float : Override the minimum real tensor value in calibration data. Invalid if also set `scale` or `zero_point`. ``` The tensor-level overrides are currently used to override the quantization type for weights/initializers or to set specific scale/zero-point values for a tensor (e.g., QNN requires Sigmoid to use a specific scale/zero-point at its output). However, these overrides are not typically used to override activation quantization types due in large part to operator data type constraints. Consider, for example, that all inputs and outputs to an Add operator must be of the same data type. Consequently, using tensor-level overrides to promote the Add’s output to 16-bits would force the inputs to also be overridden to 16-bit. In turn, this would have a cascading effect on potentially the entire graph. The solution implemented by this PR is to allow the specification of tensor boundaries where the activation quantization data type changes. #### The approach The following figure shows a model with a region that has been promoted to 16-bit from the default 8-bit activation type. ![image](https://github.com/microsoft/onnxruntime/assets/19691973/5998c301-ae20-4ac9-8a43-37f335cfcf8b) Note the following observations: - Op2’s output is consumed by Op4, Op7, and Op8. Op4 consumes the converted u16 type, while Op7 and Op8 consume the original u8 type. - Op3’s output is converted from u8 to u16. Op5 consumes the converted u16 type. - Op4’s output is just u16 (not converted). - Op5’s output is converted from u16 to u8. Op6 consumes the u8 type. 
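As a concrete illustration of the boundary described above, the u8-to-u16 conversion of Op2's output amounts to a DequantizeLinear -> QuantizeLinear pair in the generated QDQ graph. The sketch below is not code from this PR; the tensor names and quantization parameter values are hypothetical.

```python
# Hypothetical sketch of a region boundary: a DequantizeLinear -> QuantizeLinear
# pair re-quantizes Op2's activation from uint8 to uint16. Names/values are made up.
import numpy as np
from onnx import helper, numpy_helper

u8_scale = numpy_helper.from_array(np.array(0.02, dtype=np.float32), "Op2_out_scale")
u8_zp = numpy_helper.from_array(np.array(128, dtype=np.uint8), "Op2_out_zero_point")
u16_scale = numpy_helper.from_array(np.array(0.0001, dtype=np.float32), "Op2_out_convert_scale")
u16_zp = numpy_helper.from_array(np.array(32768, dtype=np.uint16), "Op2_out_convert_zero_point")

dq = helper.make_node(
    "DequantizeLinear",
    ["Op2_out_quant_u8", "Op2_out_scale", "Op2_out_zero_point"],
    ["Op2_out_float"],
)
q = helper.make_node(
    "QuantizeLinear",
    ["Op2_out_float", "Op2_out_convert_scale", "Op2_out_convert_zero_point"],
    ["Op2_out_quant_u16"],  # read only by Op4's DequantizeLinear in the figure
)
```

In the resulting model, Op4 reads from the converted u16 side of this pair, while Op7 and Op8 keep consuming the original u8 quantization. Note that 16-bit Q/DQ requires an opset or implementation that supports uint16 quantization (e.g., ORT's `com.microsoft` Q/DQ contrib ops).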
The approach implemented by this PR uses the tensor-level quantization overrides to specify a tensor’s quantization type at both the producer and consumer ends. **The following shows the overrides necessary to create this mixed precision QDQ model.** ```python3 overrides = { “Op2_out”: [{“quant_type”: QUInt8, “convert”: {“quant_type”: QUInt16, “recv_nodes”: {“Op4”}}}], “Op3_out”: [{“quant_type”: QUInt8, “convert”: {“quant_type”: QUInt16, “recv_nodes”: {“Op5”}}}], “Op4_out”: [{“quant_type”: QUInt16}], “Op5_out”: [{“quant_type”: QUInt16, “convert”: {“quant_type”: QUInt8, “recv_nodes”: {“Op6”}}}] } ``` --- .../tools/quantization/base_quantizer.py | 323 +------- .../python/tools/quantization/onnx_model.py | 10 + .../tools/quantization/onnx_quantizer.py | 227 +++++- .../tools/quantization/operators/conv.py | 2 +- .../tools/quantization/operators/direct_q8.py | 4 +- .../tools/quantization/operators/gather.py | 4 +- .../tools/quantization/operators/gemm.py | 4 +- .../tools/quantization/operators/norm.py | 2 +- .../tools/quantization/operators/softmax.py | 38 +- .../tools/quantization/operators/split.py | 2 +- .../tools/quantization/qdq_quantizer.py | 711 ++++++++++++++++-- .../python/tools/quantization/registry.py | 3 +- .../quantization/tensor_quant_overrides.py | 214 ++++++ .../test/python/quantization/test_qdq.py | 594 ++++++++++++++- 14 files changed, 1744 insertions(+), 394 deletions(-) create mode 100644 onnxruntime/python/tools/quantization/tensor_quant_overrides.py diff --git a/onnxruntime/python/tools/quantization/base_quantizer.py b/onnxruntime/python/tools/quantization/base_quantizer.py index 667d7047c1fbd..80617b7b5edaa 100644 --- a/onnxruntime/python/tools/quantization/base_quantizer.py +++ b/onnxruntime/python/tools/quantization/base_quantizer.py @@ -21,19 +21,15 @@ from .quant_utils import ( ONNX_TYPE_TO_NP_TYPE, TENSOR_NAME_QUANT_SUFFIX, - QuantizedValue, - QuantizedValueType, QuantType, - compute_scale_zp, - compute_scale_zp_float8, find_by_name, - get_qmin_qmax_for_qType, model_has_infer_metadata, quantize_data, quantize_nparray, save_and_reload_model_with_shape_infer, tensor_proto_to_array, ) +from .tensor_quant_overrides import TensorQuantOverridesHelper class QuantizationParams: @@ -121,27 +117,17 @@ def __init__( self.opset_version = self.check_opset_version() - # Map of all original value names to quantized value names - self.quantized_value_map = {} + # Get tensor-level quantization overrides and ensure they are valid. 
+ self.tensor_quant_overrides = TensorQuantOverridesHelper(self.extra_options.get("TensorQuantOverrides", {})) - self.tensor_quant_overrides, self.tensor_quant_override_types = self._get_and_check_tensor_quant_overrides() - self.quantization_params = self.calculate_quantization_params() - - # to store specified scale and zeropoint instead of calculated value, tensor_name->(scale, zeropoint) - self.used_scale_zp_map = {} - - def set_quant_scale_zp(self, tensor_name, value): - assert isinstance(value, tuple) and len(value) == 2, "value must be scale(float or float16) and zeropoint" - assert hasattr(value[0], "dtype") - assert tensor_name not in self.used_scale_zp_map, f"{tensor_name} has been setted before" - self.used_scale_zp_map[tensor_name] = value + initializer_names = {initzer.name for initzer in self.model.initializer()} + overrides_valid, overrides_err = self.tensor_quant_overrides.is_valid( + initializer_names, self.value_infos.keys(), activation_qType + ) + if not overrides_valid: + raise ValueError(overrides_err) - def find_quant_scale_zp(self, input_name): - if input_name in self.used_scale_zp_map: - return self.used_scale_zp_map[input_name] - if self.parent is not None: - return self.parent.find_quantized_value(input_name) - return (None, None) + self.tensor_quant_override_qtypes = self.tensor_quant_overrides.get_quant_types() def quantize_model(self): raise NotImplementedError @@ -212,36 +198,16 @@ def check_opset_version(self): return opset_version - def quantize_bias_static(self, bias_name, input_name, weight_name, beta=1.0): + def quantize_bias_static_impl(self, bias_name, input_scale, weight_scale, beta=1.0): """ Quantized the bias. Zero Point == 0 and Scale == Input_Scale * Weight_Scale """ - # Handle case where bias already in quantization map - if bias_name in self.quantized_value_map: - return self.quantized_value_map[bias_name].q_name - - # get scale for weight - weight_scale_name = self.quantized_value_map[weight_name].scale_name - weight_initializer = find_by_name(weight_scale_name, self.model.initializer()) - weight_scale = tensor_proto_to_array(weight_initializer) - # get bias bias_initializer = find_by_name(bias_name, self.model.initializer()) bias_data = tensor_proto_to_array(bias_initializer) quantized_bias_name = bias_name + TENSOR_NAME_QUANT_SUFFIX - # get scale for input - if input_name in self.quantized_value_map: - input_scale_name = self.quantized_value_map[input_name].scale_name - elif input_name in self.quantization_params: - _, input_scale_name, _, _, _ = self._get_quantization_params(input_name) - else: - raise ValueError(f"Expected {input_name} to be in quantized value map for static quantization") - - inputscale_initializer = find_by_name(input_scale_name, self.model.initializer()) - input_scale = tensor_proto_to_array(inputscale_initializer) - # quantize bias if self.weight_qType == onnx.TensorProto.FLOAT8E4M3FN: data = np.asarray(bias_data) @@ -293,22 +259,16 @@ def quantize_bias_static(self, bias_name, input_name, weight_name, beta=1.0): packed_bias_zp_initializer = onnx.helper.make_tensor(quantized_bias_zp_name, tensor_type, [], [0]) self.model.initializer_extend([packed_bias_zp_initializer]) - assert bias_name not in self.quantized_value_map - quantized_value = QuantizedValue( - bias_name, + return ( quantized_bias_name, quantized_bias_scale_name, quantized_bias_zp_name, - QuantizedValueType.Initializer, - 0 if bias_scale_data.size > 1 else None, - node_type=node_type, - node_qtype=node_qtype, + bias_scale_data, + node_type, + node_qtype, ) - 
self.quantized_value_map[bias_name] = quantized_value - - return quantized_bias_name - def quantize_initializer(self, weight, qType, reduce_range=False, keep_float_weight=False): + def quantize_initializer_impl(self, weight, qType, reduce_range=False, keep_float_weight=False): """ :param weight: TensorProto initializer :param qType: type to quantize to @@ -316,22 +276,13 @@ def quantize_initializer(self, weight, qType, reduce_range=False, keep_float_wei If keep_float_weight is False, quantize the weight, or don't quantize the weight. :return: quantized weight name, zero point name, scale name """ - # Find if this input is already quantized - if weight.name in self.quantized_value_map: - quantized_value = self.quantized_value_map[weight.name] - return ( - quantized_value.q_name, - quantized_value.zp_name, - quantized_value.scale_name, - ) - q_weight_name = weight.name + TENSOR_NAME_QUANT_SUFFIX zp_name = weight.name + "_zero_point" scale_name = weight.name + "_scale" # Quantize weight data. Use quantization overrides if provided by the user. weight_data = tensor_proto_to_array(weight) - quant_overrides = self.get_per_tensor_quant_overrides(weight.name) + quant_overrides = self.tensor_quant_overrides.get_per_tensor_overrides(weight.name) if "quant_type" in quant_overrides: qType = quant_overrides["quant_type"].tensor_type # noqa: N806 @@ -392,19 +343,9 @@ def quantize_initializer(self, weight, qType, reduce_range=False, keep_float_wei q_weight_initializer = onnx.numpy_helper.from_array(q_weight_data, q_weight_name) self.model.initializer_extend([q_weight_initializer]) - # Log entry for this quantized weight - quantized_value = QuantizedValue( - weight.name, - q_weight_name, - scale_name, - zp_name, - QuantizedValueType.Initializer, - None, - ) - self.quantized_value_map[weight.name] = quantized_value return q_weight_name, zp_name, scale_name - def quantize_weight_per_channel( + def quantize_weight_per_channel_impl( self, weight_name, weight_qType, @@ -412,22 +353,13 @@ def quantize_weight_per_channel( reduce_range=True, keep_float_weight=False, ): - # Find if this input is already quantized - if weight_name in self.quantized_value_map: - quantized_value = self.quantized_value_map[weight_name] - return ( - quantized_value.q_name, - quantized_value.zp_name, - quantized_value.scale_name, - ) - initializer = find_by_name(weight_name, self.model.initializer()) if initializer is None: raise ValueError("{} is not an initializer", weight_name) weights = tensor_proto_to_array(initializer) channel_count = weights.shape[channel_axis] - quant_overrides_for_channels = self.get_per_channel_quant_overrides(weight_name, channel_count) + quant_overrides_for_channels = self.tensor_quant_overrides.get_per_channel_overrides(weight_name, channel_count) # If user provides per-channel quantization overrides, all channels must use the same quantization type. # So, just use the first channel's type. 
@@ -499,16 +431,6 @@ def quantize_weight_per_channel( zp_name = weight_name + "_zero_point" scale_name = weight_name + "_scale" - quantized_value = QuantizedValue( - weight_name, - q_weight_name, - scale_name, - zp_name, - QuantizedValueType.Initializer, - None, - ) - self.quantized_value_map[weight_name] = quantized_value - # Update packed weight, zero point, and scale initializers zero_scale_shape = [initializer.dims[channel_axis]] scale_initializer = onnx.helper.make_tensor( @@ -530,194 +452,25 @@ def quantize_weight_per_channel( return q_weight_name, zp_name, scale_name - def _get_and_check_tensor_quant_overrides(self): - """ - Get tensor quantization overrides and check correctness. - """ - tensor_quant_overrides = self.extra_options.get("TensorQuantOverrides", {}) - tensor_quant_override_types = set() - - # Validate that compatible/valid overrides are provided. - if tensor_quant_overrides: - initializer_names = self.model.get_initializer_name_set() - value_info_names = set(self.value_infos.keys()) - keys_unsupported_with_scale_zp = {"symmetric", "reduce_range", "rmax", "rmin"} - - for tensor_name, quant_overrides_list in tensor_quant_overrides.items(): - if tensor_name not in initializer_names and tensor_name not in value_info_names: - raise ValueError(f"Tensor '{tensor_name}' in TensorQuantOverrides is not present in the model") - - if not isinstance(quant_overrides_list, list): - raise ValueError(f"Tensor quantization overrides for '{tensor_name}' are not in a list") - - is_initializer = tensor_name in initializer_names - if not is_initializer and len(quant_overrides_list) > 1: - raise ValueError( - f"Tensor '{tensor_name}' has a list of per-channel overrides, but is not an initializer" - ) - - quant_type = None - for index, quant_overrides in enumerate(quant_overrides_list): - if not isinstance(quant_overrides, dict): - raise ValueError( - f"Tensor quantization overrides at index {index} for '{tensor_name}' are not in a dict" - ) - - # For per-channel quantization, all channels must use the same quantization type. - # Therefore, if the user tries to override the quant_type for a channel, it must match in all - # other channels. - if index == 0: - quant_type = quant_overrides.get("quant_type") - if quant_type: - tensor_quant_override_types.add(quant_type.tensor_type) - elif quant_type != quant_overrides.get("quant_type"): - raise ValueError( - "Channel quantization types for tensor '{tensor_name}' do not match at index {index}." - ) - - has_scale = "scale" in quant_overrides - has_zero_point = "zero_point" in quant_overrides - - if (has_scale and not has_zero_point) or (has_zero_point and not has_scale): - raise ValueError( - "Must provide both 'scale' and 'zero_point' if one of the overrides is provided" - ) - - if has_scale: - for key in keys_unsupported_with_scale_zp: - if key in quant_overrides: - raise ValueError( - f"Tensor override option '{key}' is invalid with 'scale' and 'zero_point'" - ) - - return tensor_quant_overrides, tensor_quant_override_types - - def get_per_tensor_quant_overrides(self, tensor_name): - quant_overrides_list = self.tensor_quant_overrides.get(tensor_name, [{}]) - num_overrides = len(quant_overrides_list) - if num_overrides > 1: - raise ValueError( - f"Expected tensor '{tensor_name}' to use per-tensor quantization overrides, " - f"but found {num_overrides} per-channel overrides." 
- ) - - return quant_overrides_list[0] if num_overrides > 0 else {} - - def get_per_channel_quant_overrides(self, tensor_name, num_channels): - quant_overrides_list = self.tensor_quant_overrides.get(tensor_name, [{} for i in range(num_channels)]) - - if len(quant_overrides_list) != num_channels: - raise ValueError( - f"Expected tensor '{tensor_name}' to have {num_channels} per-channel quantization overrides, " - f"but found {len(quant_overrides_list)} instead." - ) - - return quant_overrides_list - - def _get_quantization_params(self, param_name, use_scale=None, use_zeropoint=None): - """ - Create initializers and inputs in the graph for zero point and scale of output. - Zero point and scale values are obtained from self.quantization_params if specified. - parameter param_name: Name of the quantization parameter. - return: result, scale_name, zero_point_name, scale_shape, zero_point_shape. - """ - zero_point_type = self.activation_qType - - if use_scale is None or use_zeropoint is None: - if self.quantization_params is None or param_name not in self.quantization_params: - logging.info(f'Quantization parameters for tensor:"{param_name}" not specified') - return False, "", "", "", "" - - params = self.quantization_params[param_name] - if not isinstance(params, QuantizationParams): - raise TypeError(f"Unexpected type {type(params)} for {param_name!r}.") - if params is None or len(params) != 3: - raise ValueError( - "Quantization parameters should contain zero point, scale, quant type. " - f"Specified values for output {param_name}: {params}" - ) - - zero_point_values = np.array([params["zero_point"]]) - if not hasattr(params["scale"], "dtype") or params["scale"].dtype not in (np.float32, np.float16): - raise ValueError(f"Unexpected type {type(params['scale'])} and param_name={param_name!r}") - scale_values = np.array([params["scale"]]) - assert scale_values.dtype != np.float64 - zero_point_type = params["quant_type"] - else: - zero_point_values = np.array([use_zeropoint]) - scale_values = np.array([use_scale]) - params = self.quantization_params[param_name] - if "scale" in params: - dtype = params["scale"].dtype - scale_values = scale_values.astype(dtype) - assert scale_values.dtype != np.float64 - - zero_point_shape = [] - zero_point_name = param_name + "_zero_point" - scale_shape = [] - scale_name = param_name + "_scale" - - # Add initializers - init_zp = onnx.helper.make_tensor( - zero_point_name, zero_point_type, zero_point_shape, zero_point_values.ravel().tolist() - ) - self.model.add_initializer(init_zp) - if scale_values.dtype == np.float32: - scale_type = onnx.TensorProto.FLOAT - elif scale_values.dtype == np.float16: - scale_type = onnx.TensorProto.FLOAT16 - else: - raise ValueError(f"Unexpected dtype={scale_values.dtype} for param_name={param_name!r}") - init_scale = onnx.helper.make_tensor(scale_name, scale_type, scale_shape, scale_values.reshape((-1,)).tolist()) - self.model.add_initializer(init_scale) - - return True, scale_name, zero_point_name, scale_shape, zero_point_shape - - def calculate_quantization_params(self): + def adjust_tensor_ranges(self): if self.tensors_range is None: - return {} + return - # adjust tensor_ranges for input of Clip and Relu node for node in self.model.nodes(): - if node.op_type not in ["Clip", "Relu"]: - continue - if self.is_activation_symmetric: - continue - if not self.should_quantize_node(node): - continue - if len(self.model.input_name_to_nodes()[node.input[0]]) != 1: - continue - if node.input[0] not in self.tensors_range or node.output[0] 
not in self.tensors_range: - continue - td = self.tensors_range[node.output[0]] - if not isinstance(td, TensorData): - raise TypeError(f"Unexpected type {type(td)} for {node.output[0]!r}.") - self.tensors_range[node.input[0]] = td - - quantization_params = {} - for tensor_name in self.tensors_range: - td = self.tensors_range[tensor_name] - if not isinstance(td, TensorData): - raise TypeError(f"Unexpected type {type(td)} for {tensor_name!r}.") - - quant_overrides = self.get_per_tensor_quant_overrides(tensor_name) - - quant_type = self.activation_qType - if "quant_type" in quant_overrides: - quant_type = quant_overrides["quant_type"].tensor_type - - if "scale" in quant_overrides and "zero_point" in quant_overrides: - zero, scale = quant_overrides["zero_point"], quant_overrides["scale"] - elif quant_type == onnx.TensorProto.FLOAT8E4M3FN: - zero, scale = compute_scale_zp_float8(quant_type, td.avg_std[1]) - else: - rmin = quant_overrides.get("rmin", td.range_value[0]) - rmax = quant_overrides.get("rmax", td.range_value[1]) - symmetric = quant_overrides.get("symmetric", self.is_activation_symmetric) - reduce_range = quant_overrides.get("reduce_range", False) - qmin, qmax = get_qmin_qmax_for_qType(quant_type, reduce_range=reduce_range, symmetric=symmetric) - zero, scale = compute_scale_zp(rmin, rmax, qmin, qmax, symmetric, self.min_real_range) - - quantization_params[tensor_name] = QuantizationParams(zero_point=zero, scale=scale, quant_type=quant_type) - - return quantization_params + # adjust tensor_ranges for input of Clip and Relu node + if node.op_type in ["Clip", "Relu"]: + if self.is_activation_symmetric: + continue + if not self.should_quantize_node(node): + continue + if len(self.model.input_name_to_nodes()[node.input[0]]) != 1: + continue + if node.input[0] not in self.tensors_range or node.output[0] not in self.tensors_range: + continue + td = self.tensors_range[node.output[0]] + if not isinstance(td, TensorData): + raise TypeError(f"Unexpected type {type(td)} for {node.output[0]!r}.") + self.tensors_range[node.input[0]] = td + # Adjust Softmax to range from 0.0 to 1.0 + elif node.op_type == "Softmax": + self.tensors_range[node.output[0]] = TensorData(lowest=np.float32(0.0), highest=np.float32(1.0)) diff --git a/onnxruntime/python/tools/quantization/onnx_model.py b/onnxruntime/python/tools/quantization/onnx_model.py index 716dd1eacec6a..174bf5fd1509c 100644 --- a/onnxruntime/python/tools/quantization/onnx_model.py +++ b/onnxruntime/python/tools/quantization/onnx_model.py @@ -441,6 +441,11 @@ def replace_input_of_all_nodes(self, old_input_name, new_input_name): for node in self.model.graph.node: ONNXModel.replace_node_input(node, old_input_name, new_input_name) + def replace_input_of_nodes(self, old_input_name, new_input_name, node_names_set): + for node in self.model.graph.node: + if node.name in node_names_set: + ONNXModel.replace_node_input(node, old_input_name, new_input_name) + @staticmethod def replace_node_output(node, old_output_name, new_output_name): assert isinstance(old_output_name, str) and isinstance(new_output_name, str) @@ -452,6 +457,11 @@ def replace_output_of_all_nodes(self, old_output_name, new_output_name): for node in self.model.graph.node: ONNXModel.replace_node_output(node, old_output_name, new_output_name) + def replace_output_of_nodes(self, old_output_name, new_output_name, node_names_set): + for node in self.model.graph.node: + if node.name in node_names_set: + ONNXModel.replace_node_output(node, old_output_name, new_output_name) + def 
remove_unused_constant(self): input_name_to_nodes = self.input_name_to_nodes() diff --git a/onnxruntime/python/tools/quantization/onnx_quantizer.py b/onnxruntime/python/tools/quantization/onnx_quantizer.py index e2044db04303d..4b76de6ecf1cb 100644 --- a/onnxruntime/python/tools/quantization/onnx_quantizer.py +++ b/onnxruntime/python/tools/quantization/onnx_quantizer.py @@ -5,30 +5,31 @@ # -------------------------------------------------------------------------- import logging +import numpy as np import onnx import onnx.numpy_helper from onnx import onnx_pb as onnx_proto -try: - from onnx.reference.op_run import to_array_extended -except ImportError: - # old version of onnx. - to_array_extended = None - -from .base_quantizer import BaseQuantizer +from .base_quantizer import BaseQuantizer, QuantizationParams +from .calibrate import TensorData from .onnx_model import ONNXModel from .quant_utils import ( TENSOR_NAME_QUANT_SUFFIX, QuantizationMode, QuantizedValue, + QuantizedValueType, __producer__, __version__, add_infer_metadata, attribute_to_kwarg, + compute_scale_zp, + compute_scale_zp_float8, find_by_name, + get_qmin_qmax_for_qType, get_qrange_for_qType, ms_domain, save_and_reload_model_with_shape_infer, + tensor_proto_to_array, ) from .registry import CreateOpQuantizer @@ -77,6 +78,7 @@ def __init__( self.fuse_dynamic_quant = self.opset_version > 10 self.q_matmul_const_b_only = "MatMulConstBOnly" in self.extra_options and self.extra_options["MatMulConstBOnly"] + self.new_nodes = [] self.graph_scope = "/" # for human readable debug information self.tensor_names = {} # in case the shape inference not totally working @@ -88,6 +90,8 @@ def __init__( if self.mode not in QuantizationMode: raise ValueError(f"unsupported quantization mode {self.mode}") + self.quantization_params = self.calculate_quantization_params() + # QuantizeRange tensor name and zero tensor name for scale and zero point calculation. # Used when static is False self.fixed_qrange_uint8_name = "fixed_quantization_range_uint8" @@ -97,6 +101,8 @@ def __init__( # For int8 data-type, zero point is always zero (respresented by fixed_zero_point_name tensor) self.fixed_zero_zp_name = "fixed_zero_zp" + # Map of all original value names to quantized value names + self.quantized_value_map = {} # some output from nodes will be quantized, yet itself should be treat as existing so # no dequantized will be applied when needed later self.generated_value_names = self.model.get_non_initializer_inputs() @@ -494,6 +500,65 @@ def _get_dynamic_input_quantization_params_uint8(self, input_name, nodes_list, i return input_scale_name, input_zp_name, [], [] + def _get_quantization_params(self, param_name, use_scale=None, use_zeropoint=None): + """ + Create initializers and inputs in the graph for zero point and scale of output. + Zero point and scale values are obtained from self.quantization_params if specified. + parameter param_name: Name of the quantization parameter. + return: result, scale_name, zero_point_name, scale_shape, zero_point_shape. 
+ """ + zero_point_type = self.activation_qType + + if use_scale is None or use_zeropoint is None: + if self.quantization_params is None or param_name not in self.quantization_params: + logging.info(f'Quantization parameters for tensor:"{param_name}" not specified') + return False, "", "", "", "" + + params = self.quantization_params[param_name] + if not isinstance(params, QuantizationParams): + raise TypeError(f"Unexpected type {type(params)} for {param_name!r}.") + if params is None or len(params) != 3: + raise ValueError( + "Quantization parameters should contain zero point, scale, quant type. " + f"Specified values for output {param_name}: {params}" + ) + + zero_point_values = np.array([params["zero_point"]]) + if not hasattr(params["scale"], "dtype") or params["scale"].dtype not in (np.float32, np.float16): + raise ValueError(f"Unexpected type {type(params['scale'])} and param_name={param_name!r}") + scale_values = np.array([params["scale"]]) + assert scale_values.dtype != np.float64 + zero_point_type = params["quant_type"] + else: + zero_point_values = np.array([use_zeropoint]) + scale_values = np.array([use_scale]) + params = self.quantization_params[param_name] + if "scale" in params: + dtype = params["scale"].dtype + scale_values = scale_values.astype(dtype) + assert scale_values.dtype != np.float64 + + zero_point_shape = [] + zero_point_name = param_name + "_zero_point" + scale_shape = [] + scale_name = param_name + "_scale" + + # Add initializers + init_zp = onnx.helper.make_tensor( + zero_point_name, zero_point_type, zero_point_shape, zero_point_values.ravel().tolist() + ) + self.model.add_initializer(init_zp) + if scale_values.dtype == np.float32: + scale_type = onnx_proto.TensorProto.FLOAT + elif scale_values.dtype == np.float16: + scale_type = onnx_proto.TensorProto.FLOAT16 + else: + raise ValueError(f"Unexpected dtype={scale_values.dtype} for param_name={param_name!r}") + init_scale = onnx.helper.make_tensor(scale_name, scale_type, scale_shape, scale_values.reshape((-1,)).tolist()) + self.model.add_initializer(init_scale) + + return True, scale_name, zero_point_name, scale_shape, zero_point_shape + def _get_quantize_input_nodes(self, node, input_index, qType, given_scale_name=None, given_zp_name=None): """ Given an input for a node (which is not a initializer), this function @@ -564,6 +629,55 @@ def find_quantized_value(self, input_name): return self.parent.find_quantized_value(input_name) return None + def quantize_bias_static(self, bias_name, input_name, weight_name, beta=1.0): + """ + Quantized the bias. 
Zero Point == 0 and Scale == Input_Scale * Weight_Scale + """ + + # Handle case where bias already in quantization map + if bias_name in self.quantized_value_map: + return self.quantized_value_map[bias_name].q_name + + # get scale for weight + weight_scale_name = self.quantized_value_map[weight_name].scale_name + weight_initializer = find_by_name(weight_scale_name, self.model.initializer()) + weight_scale = tensor_proto_to_array(weight_initializer) + + # get scale for input + if input_name in self.quantized_value_map: + input_scale_name = self.quantized_value_map[input_name].scale_name + elif input_name in self.quantization_params: + _, input_scale_name, _, _, _ = self._get_quantization_params(input_name) + else: + raise ValueError(f"Expected {input_name} to be in quantized value map for static quantization") + + inputscale_initializer = find_by_name(input_scale_name, self.model.initializer()) + input_scale = tensor_proto_to_array(inputscale_initializer) + + ( + quantized_bias_name, + quantized_bias_scale_name, + quantized_bias_zp_name, + bias_scale_data, + node_type, + node_qtype, + ) = self.quantize_bias_static_impl(bias_name, input_scale, weight_scale, beta) + + assert bias_name not in self.quantized_value_map + quantized_value = QuantizedValue( + bias_name, + quantized_bias_name, + quantized_bias_scale_name, + quantized_bias_zp_name, + QuantizedValueType.Initializer, + 0 if bias_scale_data.size > 1 else None, + node_type=node_type, + node_qtype=node_qtype, + ) + self.quantized_value_map[bias_name] = quantized_value + + return quantized_bias_name + def contains_tensor(self, tensor_name): """ only check for value info and newly generated tensor names, initializers are checked separately @@ -721,6 +835,71 @@ def __quantize_inputs( return quantized_input_names, zero_point_names, scale_names, nodes + def quantize_initializer(self, weight, qType, reduce_range=False, keep_float_weight=False): + """ + :param weight: TensorProto initializer + :param qType: type to quantize to + :param keep_float_weight: Whether to quantize the weight. In some cases, we only want to qunatize scale and zero point. + If keep_float_weight is False, quantize the weight, or don't quantize the weight. 
+ :return: quantized weight name, zero point name, scale name + """ + # Find if this input is already quantized + if weight.name in self.quantized_value_map: + quantized_value = self.quantized_value_map[weight.name] + return ( + quantized_value.q_name, + quantized_value.zp_name, + quantized_value.scale_name, + ) + + q_weight_name, zp_name, scale_name = self.quantize_initializer_impl( + weight, qType, reduce_range, keep_float_weight + ) + + # Log entry for this quantized weight + quantized_value = QuantizedValue( + weight.name, + q_weight_name, + scale_name, + zp_name, + QuantizedValueType.Initializer, + None, + ) + self.quantized_value_map[weight.name] = quantized_value + return q_weight_name, zp_name, scale_name + + def quantize_weight_per_channel( + self, + weight_name, + weight_qType, + channel_axis, + reduce_range=True, + keep_float_weight=False, + ): + # Find if this input is already quantized + if weight_name in self.quantized_value_map: + quantized_value = self.quantized_value_map[weight_name] + return ( + quantized_value.q_name, + quantized_value.zp_name, + quantized_value.scale_name, + ) + + q_weight_name, zp_name, scale_name = self.quantize_weight_per_channel_impl( + weight_name, weight_qType, channel_axis, reduce_range, keep_float_weight + ) + quantized_value = QuantizedValue( + weight_name, + q_weight_name, + scale_name, + zp_name, + QuantizedValueType.Initializer, + None, + ) + self.quantized_value_map[weight_name] = quantized_value + + return q_weight_name, zp_name, scale_name + def _dequantize_value(self, value_name): """ Given a value (input/output) which is quantized, add a DequantizeLinear node to dequantize @@ -771,3 +950,37 @@ def _dequantize_outputs(self): dequantize_node = self._dequantize_value(output.name) if dequantize_node is not None: self.new_nodes.append(dequantize_node) + + def calculate_quantization_params(self): + if self.tensors_range is None: + return None + + self.adjust_tensor_ranges() + + quantization_params = {} + for tensor_name in self.tensors_range: + td = self.tensors_range[tensor_name] + if not isinstance(td, TensorData): + raise TypeError(f"Unexpected type {type(td)} for {tensor_name!r}.") + + quant_overrides = self.tensor_quant_overrides.get_per_tensor_overrides(tensor_name) + + quant_type = self.activation_qType + if "quant_type" in quant_overrides: + quant_type = quant_overrides["quant_type"].tensor_type + + if "scale" in quant_overrides and "zero_point" in quant_overrides: + zero, scale = quant_overrides["zero_point"], quant_overrides["scale"] + elif quant_type == onnx.TensorProto.FLOAT8E4M3FN: + zero, scale = compute_scale_zp_float8(quant_type, td.avg_std[1]) + else: + rmin = quant_overrides.get("rmin", td.range_value[0]) + rmax = quant_overrides.get("rmax", td.range_value[1]) + symmetric = quant_overrides.get("symmetric", self.is_activation_symmetric) + reduce_range = quant_overrides.get("reduce_range", False) + qmin, qmax = get_qmin_qmax_for_qType(quant_type, reduce_range=reduce_range, symmetric=symmetric) + zero, scale = compute_scale_zp(rmin, rmax, qmin, qmax, symmetric, self.min_real_range) + + quantization_params[tensor_name] = QuantizationParams(zero_point=zero, scale=scale, quant_type=quant_type) + + return quantization_params diff --git a/onnxruntime/python/tools/quantization/operators/conv.py b/onnxruntime/python/tools/quantization/operators/conv.py index 06204585ba1ca..7054173450569 100644 --- a/onnxruntime/python/tools/quantization/operators/conv.py +++ b/onnxruntime/python/tools/quantization/operators/conv.py @@ -252,4 +252,4 @@ 
def quantize(self): self.quantizer.quantize_weight_tensor(node.input[1]) if len(node.input) == 3: - self.quantizer.quantize_bias_tensor(node.input[2], node.input[0], node.input[1]) + self.quantizer.quantize_bias_tensor(node.name, node.input[2], node.input[0], node.input[1]) diff --git a/onnxruntime/python/tools/quantization/operators/direct_q8.py b/onnxruntime/python/tools/quantization/operators/direct_q8.py index c14532b96acbc..ae9679ae8ec7a 100644 --- a/onnxruntime/python/tools/quantization/operators/direct_q8.py +++ b/onnxruntime/python/tools/quantization/operators/direct_q8.py @@ -73,6 +73,6 @@ def quantize(self): if self.quantizer.force_quantize_no_input_check: self.quantizer.quantize_activation_tensor(self.node.input[0]) if not self.disable_qdq_for_node_output: - self.quantizer.quantize_activation_tensor(self.node.output[0], self.node.input[0]) + self.quantizer.quantize_output_same_as_input(self.node.output[0], self.node.input[0], self.node.name) elif self.quantizer.is_tensor_quantized(self.node.input[0]) and not self.disable_qdq_for_node_output: - self.quantizer.quantize_activation_tensor(self.node.output[0], self.node.input[0]) + self.quantizer.quantize_output_same_as_input(self.node.output[0], self.node.input[0], self.node.name) diff --git a/onnxruntime/python/tools/quantization/operators/gather.py b/onnxruntime/python/tools/quantization/operators/gather.py index f48725d1e428f..e390e874a2662 100644 --- a/onnxruntime/python/tools/quantization/operators/gather.py +++ b/onnxruntime/python/tools/quantization/operators/gather.py @@ -59,6 +59,6 @@ def quantize(self): if self.quantizer.is_valid_quantize_weight(node.input[0]) or self.quantizer.force_quantize_no_input_check: self.quantizer.quantize_activation_tensor(node.input[0]) - self.quantizer.quantize_activation_tensor(node.output[0], node.input[0]) + self.quantizer.quantize_output_same_as_input(node.output[0], node.input[0], node.name) elif self.quantizer.is_tensor_quantized(node.input[0]): - self.quantizer.quantize_activation_tensor(node.output[0], node.input[0]) + self.quantizer.quantize_output_same_as_input(node.output[0], node.input[0], node.name) diff --git a/onnxruntime/python/tools/quantization/operators/gemm.py b/onnxruntime/python/tools/quantization/operators/gemm.py index d269c8fb47bd1..df24e256aa7fc 100644 --- a/onnxruntime/python/tools/quantization/operators/gemm.py +++ b/onnxruntime/python/tools/quantization/operators/gemm.py @@ -153,7 +153,9 @@ def quantize(self): if len(node.input) == 3: if self.quantizer.is_input_a_initializer(node.input[2]): - self.quantizer.quantize_bias_tensor(node.input[2], node.input[0], node.input[1], get_beta(self.node)) + self.quantizer.quantize_bias_tensor( + node.name, node.input[2], node.input[0], node.input[1], get_beta(self.node) + ) set_default_beta(self.node) else: logging.warning( diff --git a/onnxruntime/python/tools/quantization/operators/norm.py b/onnxruntime/python/tools/quantization/operators/norm.py index e825fe6075601..3c14c926a7e75 100644 --- a/onnxruntime/python/tools/quantization/operators/norm.py +++ b/onnxruntime/python/tools/quantization/operators/norm.py @@ -29,7 +29,7 @@ def quantize(self): self.quantizer.quantize_activation_tensor(node.input[1]) # Bias - self.quantizer.quantize_bias_tensor(node.input[2], node.input[0], node.input[1]) + self.quantizer.quantize_bias_tensor(node.name, node.input[2], node.input[0], node.input[1]) # Output if not self.disable_qdq_for_node_output: diff --git a/onnxruntime/python/tools/quantization/operators/softmax.py 
b/onnxruntime/python/tools/quantization/operators/softmax.py index 61a69ab3649dd..4b39fae8ac063 100644 --- a/onnxruntime/python/tools/quantization/operators/softmax.py +++ b/onnxruntime/python/tools/quantization/operators/softmax.py @@ -1,18 +1,8 @@ -import numpy as np import onnx import onnx.helper -from ..quant_utils import ( - TENSOR_NAME_QUANT_SUFFIX, - QuantizedValue, - QuantizedValueType, - attribute_to_kwarg, - compute_scale_zp, - get_qmin_qmax_for_qType, - ms_domain, -) +from ..quant_utils import TENSOR_NAME_QUANT_SUFFIX, QuantizedValue, QuantizedValueType, attribute_to_kwarg, ms_domain from .base_operator import QuantOperatorBase -from .qdq_base_operator import QDQOperatorBase class QLinearSoftmax(QuantOperatorBase): @@ -82,29 +72,3 @@ def quantize(self): nodes.append(qnode) self.quantizer.new_nodes += nodes return None - - -class QDQSoftmax(QDQOperatorBase): - def quantize(self): - super().quantize() - output_name = self.node.output[0] - quant_overrides = self.quantizer.get_per_tensor_quant_overrides(output_name) - - quant_type = self.quantizer.activation_qType - if "quant_type" in quant_overrides: - quant_type = quant_overrides["quant_type"].tensor_type - - if "scale" in quant_overrides and "zero_point" in quant_overrides: - out_zero_point, out_scale = quant_overrides["zero_point"], quant_overrides["scale"] - else: - # Unless overridden by the user, force Softmax to range from 0.0 to 1.0 - qparams = self.quantizer.quantization_params[output_name] - dtype = qparams.data["scale"].dtype - rmin = quant_overrides.get("rmin", np.array(0, dtype=dtype)) - rmax = quant_overrides.get("rmax", np.array(1, dtype=dtype)) - symmetric = quant_overrides.get("symmetric", self.quantizer.is_activation_symmetric) - reduce_range = quant_overrides.get("reduce_range", False) - qmin, qmax = get_qmin_qmax_for_qType(quant_type, reduce_range=reduce_range, symmetric=symmetric) - out_zero_point, out_scale = compute_scale_zp(rmin, rmax, qmin, qmax, symmetric=symmetric) - - self.quantizer.set_quant_scale_zp(output_name, (out_scale, out_zero_point)) diff --git a/onnxruntime/python/tools/quantization/operators/split.py b/onnxruntime/python/tools/quantization/operators/split.py index c36b767f5abcc..74fc30cd075d2 100644 --- a/onnxruntime/python/tools/quantization/operators/split.py +++ b/onnxruntime/python/tools/quantization/operators/split.py @@ -60,4 +60,4 @@ def quantize(self): self.quantizer.quantize_activation_tensor(node.input[0]) if not self.disable_qdq_for_node_output: for output in node.output: - self.quantizer.quantize_activation_tensor(output, node.input[0]) + self.quantizer.quantize_output_same_as_input(output, node.input[0], node.name) diff --git a/onnxruntime/python/tools/quantization/qdq_quantizer.py b/onnxruntime/python/tools/quantization/qdq_quantizer.py index 1875c552fab9c..c323c6fec545a 100644 --- a/onnxruntime/python/tools/quantization/qdq_quantizer.py +++ b/onnxruntime/python/tools/quantization/qdq_quantizer.py @@ -3,15 +3,21 @@ # Licensed under the MIT License. See License.txt in the project root for # license information. 
# -------------------------------------------------------------------------- +from __future__ import annotations + import logging +from dataclasses import dataclass from enum import Enum +from typing import Any +import numpy as np import onnx import onnx.numpy_helper from onnx import TensorProto from onnx import onnx_pb as onnx_proto -from .base_quantizer import BaseQuantizer +from .base_quantizer import BaseQuantizer, QuantizationParams +from .calibrate import TensorData from .quant_utils import ( DEQUANT_OP_NAME, QUANT_OP_NAME, @@ -24,8 +30,12 @@ add_quant_input_suffix, add_quant_output_suffix, add_quant_suffix, + compute_scale_zp, + compute_scale_zp_float8, find_by_name, + get_qmin_qmax_for_qType, ms_domain, + tensor_proto_to_array, ) from .registry import CreateQDQQuantizer @@ -36,6 +46,17 @@ class QDQQuantTensorType(Enum): BIAS = 2 +# Holds the name of the node input from which a node output will share the +# same quantization param initializers (zero-point and scale initializers). +# Ex: A Transpose node's output will use the same quant param initializers used at the input. +@dataclass +class QDQQuantParamProvider: + input_name: str + node_name: str + + +# Holds information for tensors that have been marked for quantization by operator quantizers. +# Does not hold information for bias tensors. class QDQTensorQuantInfo: def __init__(self, tensor_type=QDQQuantTensorType.ACTIVATION, quant_para_provider=None, axis=None, data_type=None): self.tensor_type = tensor_type @@ -46,6 +67,64 @@ def __init__(self, tensor_type=QDQQuantTensorType.ACTIVATION, quant_para_provide self.data_type = data_type +# Holds information for bias tensors that have been marked for quantization by operator quantizers. +@dataclass +class QDQBiasQuantInfo: + node_name: str + input_name: str + weight_name: str + beta: float + + +# Holds quantization parameter values (scale, zp) for a tensor. +# A tensor typically has a one set of quantization parameters, unless the tensor is +# at a "mixed-precision" boundary where the activation quantization type changes (e.g., from uint8 to uint16). +@dataclass +class QDQTensorQuantParams: + original: QuantizationParams # Generated by producer node. + converted: QuantizationParams | None # Converted type consumed by some (or all/none) consumer nodes. + converted_recv_nodes: set[str] | None # The name of nodes that consume the converted type. + + +# Holds scale and zero_point initializer TensorProtos. +@dataclass +class QDQScaleZpInitializers: + scale: TensorProto + zero_point: TensorProto + + +# Holds all scale and zero-point initializers for a tensor. +# A tensor typically has a one set of quantization parameters, unless the tensor is +# at a "mixed-precision" boundary where the activation quantization type changes (e.g., from uint8 to uint16). +@dataclass +class QDQTensorScaleZpInitializers: + original: QDQScaleZpInitializers + converted: QDQScaleZpInitializers | None + converted_recv_nodes: set[str] | None + + +# Holds cached information of a tensor's quantized values (types, zp/scale initializer names, etc.). +# A tensor typically has a one set of quantization parameters, unless the tensor is +# at a "mixed-precision" boundary where the activation quantization type changes (e.g., from uint8 to uint16). 
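# Illustrative sketch, not from the ONNX Runtime sources: at such a "mixed-precision"
# boundary the same float range is simply re-expressed with a second integer type.
# A minimal example of what an original (uint8) and a converted (uint16) scale/zero-point
# pair could look like, assuming plain asymmetric linear quantization; the range values
# below are made up.
import numpy as np


def asym_scale_zp(rmin, rmax, qmin, qmax):
    # The representable range must include 0.0 so that zero maps exactly onto an integer.
    rmin, rmax = min(rmin, 0.0), max(rmax, 0.0)
    scale = np.float32((rmax - rmin) / (qmax - qmin))
    zero_point = int(round(qmin - rmin / scale))
    return scale, zero_point


original_u8_params = asym_scale_zp(-1.0, 2.0, 0, 255)      # produced once by the tensor's producer
converted_u16_params = asym_scale_zp(-1.0, 2.0, 0, 65535)  # used by the nodes listed in converted_recv_nodes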
+@dataclass +class QDQTensorQuantizedValue: + original: QuantizedValue + converted: QuantizedValue | None + converted_recv_nodes: set[str] | None + + def get_for_consumer(self, consumer_node_name) -> QuantizedValue: + if self.converted is None: # Quantized value is not converted, return original + return self.original + + if self.converted_recv_nodes is None: # All consumers receive the converted value + return self.converted + + # Check if consumer node name is in the list of nodes that + # receive the converted quantization value. If not, return the original value generated + # by the tensor's producer. + return self.converted if (consumer_node_name in self.converted_recv_nodes) else self.original + + class QDQQuantizer(BaseQuantizer): def __init__( self, @@ -74,7 +153,7 @@ def __init__( extra_options, ) self.tensors_to_quantize = {} - self.bias_to_quantize = [] + self.bias_to_quantize = {} self.nodes_to_remove = [] @@ -100,8 +179,7 @@ def __init__( # The default behavior is that multiple nodes can share a QDQ pair as their inputs. # In TRT, QDQ pair can`t be shared between nodes, so it will create dedicated QDQ pairs for each node. self.dedicated_qdq_pair = extra_options.get("DedicatedQDQPair", False) - if self.dedicated_qdq_pair: - self.tensor_to_its_receiving_nodes = {} + self.tensor_to_its_receiving_nodes = {} # Let user set channel axis for specific op type and it's effective only when per channel quantization is supported and per_channel is True. self.qdq_op_type_per_channel_support_to_axis = extra_options.get("QDQOpTypePerChannelSupportToAxis", {}) @@ -112,7 +190,7 @@ def __init__( # if the activation or weight types are 16-bit integers. # TODO: Remove this override (and use only the 'UseQDQContribOps' option) if/when ONNX adds 16-bit support. int16_types = (TensorProto.UINT16, TensorProto.INT16) - overrides_have_int16 = any(t in int16_types for t in self.tensor_quant_override_types) + overrides_have_int16 = any(t.tensor_type in int16_types for t in self.tensor_quant_override_qtypes) if not self.qdq_op_domain and ( self.activation_qType in int16_types or self.weight_qType in int16_types or overrides_have_int16 ): @@ -123,6 +201,11 @@ def __init__( ) self.qdq_op_domain = ms_domain + self.quantization_params = self.calc_graph_quant_params() + + # Map of all original value names to quantized value names + self.quantized_value_map = {} + def _get_tensor_type(self, tensor_name): """ Check if tensor can be quantized @@ -158,45 +241,71 @@ def _is_tensor_quantizable(self, tensor_name): return False - def __quantize_tensor(self, tensor_name, quant_sharing_param=None, tensor_type=QDQQuantTensorType.ACTIVATION): + def __quantize_tensor(self, tensor_name, quant_sharing_provider=None, tensor_type=QDQQuantTensorType.ACTIVATION): """ - Quantize tensors. If quant_param_tensor is not None, tensor with name tensor_name will be quantized with same - quantization parameters as tensor quant_param_tensor + Adds a tensor to the list (actually a dict) of tensors to quantize. Called indirectly by op quantizers that + want to quantize a tensor (i.e., "mark" a tensor for quantization). + + If quant_sharing_provider is not None, tensor with name tensor_name will be quantized with the same + quantization parameters as the node input specified in quant_sharing_provider. Ex: A Tranpose node's output + will typically use the same quantization parameter initializers used at the Transpose node's input. 
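# Illustrative sketch, not from the ONNX Runtime sources: a Transpose only permutes
# elements, so its output spans the same value range as its input and can safely reuse
# the input's scale/zero-point initializers. The params below are made up.
import numpy as np

x = np.array([[0.0, 0.5], [1.0, 1.5]], dtype=np.float32)
scale, zero_point = np.float32(1.5 / 255.0), 0

q_in = np.clip(np.rint(x / scale) + zero_point, 0, 255).astype(np.uint8)
dq_out = (q_in.T.astype(np.float32) - zero_point) * scale  # transpose in the integer domain
assert np.allclose(dq_out, x.T, atol=float(scale))         # matches transposing the floats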
Args: tensor_name: name of the tensor to quantize - quant_sharing_param: name of the tensor that provides quantization parameter + quant_sharing_provider: name of the tensor and node that provides quantization parameter tensor_type: QDQQuantTensorType default ACTIVATION """ if self._is_tensor_quantizable(tensor_name): - if quant_sharing_param: + if quant_sharing_provider: + if not isinstance(quant_sharing_provider, QDQQuantParamProvider): + raise TypeError( + f"quant_sharing_provider must be of type QDQQuantParamProvider, not {type(quant_sharing_provider)}." + ) + data_type = self._get_tensor_type(tensor_name) self.tensors_to_quantize[tensor_name] = QDQTensorQuantInfo( - tensor_type=tensor_type, quant_para_provider=quant_sharing_param, data_type=data_type + tensor_type=tensor_type, quant_para_provider=quant_sharing_provider, data_type=data_type ) elif tensor_name not in self.tensors_to_quantize: data_type = self._get_tensor_type(tensor_name) self.tensors_to_quantize[tensor_name] = QDQTensorQuantInfo(tensor_type=tensor_type, data_type=data_type) - def quantize_activation_tensor(self, tensor_name, quant_sharing_param=None): + def quantize_activation_tensor(self, tensor_name: str): """ - Quantize Activation Tensor + Adds a tensor to the list of tensors to quantize. Called by op quantizers that + want to quantize a tensor (i.e., "mark" a tensor for quantization). + Args: tensor_name: name of the tensor to quantize - quant_sharing_param: name of the tensor that provides quantization parameter - """ - return self.__quantize_tensor(tensor_name, quant_sharing_param, QDQQuantTensorType.ACTIVATION) + return self.__quantize_tensor(tensor_name, None, QDQQuantTensorType.ACTIVATION) - def quantize_weight_tensor(self, tensor_name, quant_sharing_param=None): + def quantize_output_same_as_input(self, output_name: str, input_name: str, node_name: str): """ - Quantize Weight Tensor + Adds a tensor to the list of tensors to quantize. Called by op quantizers that + want to quantize an output tensor using the same quantization parameters as one of the node's inputs. + + Ex: A Tranpose node's output will typically use the same quantization parameter initializers used at + the Transpose node's input. + Args: - tensor_name: name of the tensor to quantize - quant_sharing_param: name of the tensor that provides quantization parameter + output_name: name of the node output to quantize so that it uses the same quantization params as an input. + input_name: name of the node input from which the output tensor will get its quantization params. + node_name: name of the node that consumes `input_name`. + """ + return self.__quantize_tensor( + output_name, QDQQuantParamProvider(input_name, node_name), QDQQuantTensorType.ACTIVATION + ) + def quantize_weight_tensor(self, tensor_name: str): """ - return self.__quantize_tensor(tensor_name, quant_sharing_param, QDQQuantTensorType.WEIGHT) + Adds a tensor to the list of weight tensors to quantize. Called by op quantizers that + want to quantize a weight (i.e., "mark" a weight for quantization). + + Args: + tensor_name: name of the weight to quantize + """ + return self.__quantize_tensor(tensor_name, None, QDQQuantTensorType.WEIGHT) def quantize_weight_tensor_per_channel(self, tensor_name, axis): weight = find_by_name(tensor_name, self.model.initializer()) @@ -208,7 +317,19 @@ def quantize_weight_tensor_per_channel(self, tensor_name, axis): else: logging.warning(f"only support per-channel quantization on weight. 
Tensor: {tensor_name} is not quantized.") - def quantize_bias_tensor(self, bias_name, input_name, weight_name, beta=1.0): + def quantize_bias_tensor(self, node_name, bias_name, input_name, weight_name, beta=1.0): + """ + Adds a bias tensor to the list of bias tensors to quantize. Called by op quantizers that + want to quantize a bias with bias_zero_point = 0 and bias_scale = input_scale * weight_scale * beta. + TODO: Explain the reasoning for using this formula. + + Args: + node_name: name of the node that consumes the bias, input, and weight tensors. + bias_name: name of the bias tensor to quantize. + input_name: name of the input tensor whose scale is used to compute the bias's scale. + weight_name: name of the weight tensor whose scale is used to compute the bias's scale. + beta: Multiplier used to compute the bias's scale. + """ # If the user provided quantization overrides for this tensor, treat it as a regular weight. if self.tensor_quant_overrides.get(bias_name): logging.info( @@ -223,7 +344,10 @@ def quantize_bias_tensor(self, bias_name, input_name, weight_name, beta=1.0): weight = find_by_name(bias_name, self.model.initializer()) if weight is not None: if weight.data_type in (onnx_proto.TensorProto.FLOAT, onnx_proto.TensorProto.FLOAT16): - self.bias_to_quantize.append((bias_name, input_name, weight_name, beta)) + if bias_name not in self.bias_to_quantize: + self.bias_to_quantize[bias_name] = QDQBiasQuantInfo(node_name, input_name, weight_name, beta) + else: + logging.warning(f"Bias {bias_name} has already been marked for quantization") else: logging.warning(f"Expected {bias_name} to be a weight") @@ -239,11 +363,10 @@ def quantize_model(self): op_quantizer = CreateQDQQuantizer(self, node) op_quantizer.quantize() - if self.dedicated_qdq_pair: - for tensor_name in node.input: - if tensor_name not in self.tensor_to_its_receiving_nodes: - self.tensor_to_its_receiving_nodes[tensor_name] = [] - self.tensor_to_its_receiving_nodes[tensor_name].append(node) + for tensor_name in node.input: + if tensor_name not in self.tensor_to_its_receiving_nodes: + self.tensor_to_its_receiving_nodes[tensor_name] = [] + self.tensor_to_its_receiving_nodes[tensor_name].append(node) self._quantize_normal_tensors() self._quantize_sharing_param_tensors() @@ -263,6 +386,8 @@ def quantize_model(self): def try_replacing_upstream_output(self, upstream_output_name, output_name): if ( output_name in self.quantization_params + and self.quantization_params[output_name].converted is None + and self.quantization_params[upstream_output_name].converted is None and len(self.model.input_name_to_nodes()[upstream_output_name]) == 1 and not self.model.is_graph_output(upstream_output_name) and not self.model.is_graph_input(upstream_output_name) @@ -273,6 +398,50 @@ def try_replacing_upstream_output(self, upstream_output_name, output_name): return True return False + def _create_q_node( + self, + q_input: str, + q_output: str, + quant_node_name: str, + scale_name: str, + zp_name: str, + axis: int | None = None, + ): + """ + Creates a QuantizeLinear node and adds it to the model. + """ + qlinear_node = onnx.helper.make_node( + QUANT_OP_NAME, + [q_input, scale_name, zp_name], + [q_output], + quant_node_name, + axis=axis, + domain=self.qdq_op_domain, + ) + self.model.add_nodes([qlinear_node]) + + def _create_dq_node( + self, + dq_input: str, + dq_output: str, + dequant_node_name: str, + scale_name: str, + zp_name: str, + axis: int | None = None, + ): + """ + Creates a DequantizeLinear node and adds it to the model. 
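# Illustrative sketch, not from the ONNX Runtime sources: stripped of the axis and
# domain handling, the helpers above emit plain QuantizeLinear/DequantizeLinear nodes
# wired through shared scale and zero-point initializers. Tensor/node names are made up.
import onnx

q_node = onnx.helper.make_node(
    "QuantizeLinear",
    ["act", "act_scale", "act_zero_point"],  # float input plus its quant params
    ["act_quantized"],
    name="act_QuantizeLinear",
)
dq_node = onnx.helper.make_node(
    "DequantizeLinear",
    ["act_quantized", "act_scale", "act_zero_point"],
    ["act_dequantized"],  # float output consumed by downstream nodes
    name="act_DequantizeLinear",
)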
+ """ + dequant_node = onnx.helper.make_node( + DEQUANT_OP_NAME, + [dq_input, scale_name, zp_name], + [dq_output], + dequant_node_name, + axis=axis, + domain=self.qdq_op_domain, + ) + self.model.add_nodes([dequant_node]) + def _create_qdq_nodes( self, q_input, q_output, quant_node_name, dq_input, dq_output, dequant_node_name, scale_name, zp_name, axis=None ): @@ -383,7 +552,7 @@ def _add_qdq_pair_for_activation(self, tensor_name, scale_name, zp_name, data_ty QuantizedValueType.Input, scale_type=data_type, ) - self.quantized_value_map[tensor_name] = quantized_value + self.quantized_value_map[tensor_name] = QDQTensorQuantizedValue(quantized_value, None, None) else: q_input = tensor_name dq_output = add_dequant_output_suffix(tensor_name) @@ -413,9 +582,165 @@ def _add_qdq_pair_for_activation(self, tensor_name, scale_name, zp_name, data_ty QuantizedValueType.Input, scale_type=data_type, ) - self.quantized_value_map[tensor_name] = quantized_value + self.quantized_value_map[tensor_name] = QDQTensorQuantizedValue(quantized_value, None, None) + + def _add_qdq_ops_for_converted_activation( + self, + tensor_name, + first_scale_name, + first_zp_name, + scale_data_type, + convert_scale_name, + convert_zp_name, + convert_recv_nodes, + ): + """ + Adds Q and DQ ops to a tensor whose quantized data type is converted. That is, some consumers may use the + original data type from the producer, while other consumers use the converted data type. + This is generally done by adding a sequence of ops that convert from one data type (e.g., uint8) to another (e.g., uint16). + + T_float ---> Quant(to u8) ---> Convert(to u16) ---> Dequant(to float) ---> T_float' + where Convert(to u16) is equivalent to: ---> Dequant(to float) ---> Quant(to u16) ---> + + This function handles the following scenarios: + + 1) Tensor T is not a graph output; all consumers use the converted type + + ---> Q1 ---> DQ1 ---> Q2 ---> DQ2 ---> + + 2) Tensor T is not a graph output; some consumers use the original type, others use the converted type + + ---> Q1 -+-> DQ1 ---> + | + +-> DQ1' ---> Q2 ---> DQ2 ---> + + 3) Tensor T is a graph output; all consumers use the converted type + + ---> Q1 ---> DQ1 ---> Q2 ---> DQ2 -+-> + | + +-> + + 4) Tensor T is a graph output; some consumers use the original type, others use the converted type + + ---> Q1 -+-> DQ1 -+-> + | | + | +-> + | + +-> DQ1' ---> Q2 ---> DQ2 ---> + """ + tensor_recv_nodes = set([node.name for node in self.tensor_to_its_receiving_nodes[tensor_name]]) + + if ( + self.dedicated_qdq_pair + and tensor_name in self.tensor_to_its_receiving_nodes + and len(self.tensor_to_its_receiving_nodes[tensor_name]) > 1 + ): + # TODO: Add support for dedicated_qdq_pair if/when needed. + raise ValueError( + "Do not currently support converted quant_types in TensorQuantOverrides when the `dedicated_qdq_pair` extra_option is enabled" + ) + + # Determine which nodes consume the original quantized type and which nodes + # consume the converted quantized type. + original_recv_nodes = tensor_recv_nodes + if convert_recv_nodes is None: # In this case, all consumers receive the converted type. + convert_recv_nodes = tensor_recv_nodes + original_recv_nodes = set() + else: + original_recv_nodes = original_recv_nodes - convert_recv_nodes + + all_use_converted = len(convert_recv_nodes) == len(tensor_recv_nodes) + is_graph_output = self.model.is_graph_output(tensor_name) + + # Create first Q op. 
+ first_q_input = tensor_name + if is_graph_output: + first_q_input = add_quant_input_suffix(tensor_name) + self.model.replace_output_of_all_nodes(tensor_name, first_q_input) + + first_q_output = add_quant_output_suffix(tensor_name) + self._create_q_node( + first_q_input, first_q_output, add_quant_suffix(tensor_name), first_scale_name, first_zp_name + ) + + # Create first DQ op. + first_dq_output = add_dequant_output_suffix(tensor_name) + if is_graph_output and not all_use_converted: + first_dq_output = tensor_name + if original_recv_nodes and first_dq_output != tensor_name: + self.model.replace_input_of_nodes(tensor_name, first_dq_output, original_recv_nodes) + + self._create_dq_node( + first_q_output, first_dq_output, add_dequant_suffix(tensor_name), first_scale_name, first_zp_name + ) + + # Create parallel clone of first DQ op if _not all_ consumers use the converted type. + # --> DQ1' --> Q2 --> DQ2 --> + # + # This DQ clone would only have one consumer Q node (Q2) and could be potentially fused with + # it by some EPs (e.g., QNN) without breaking other "node units". + # Ex QNN fusion: + # --> Convert (fused) --> DQ2 --> + second_q_input = first_dq_output + if not all_use_converted: + second_q_input = add_quant_input_suffix(f"{tensor_name}_convert") + self._create_dq_node( + first_q_output, + second_q_input, + add_dequant_suffix(f"{tensor_name}_convert_clone"), + first_scale_name, + first_zp_name, + ) + + # Create second Q op. + second_q_output = add_quant_output_suffix(f"{tensor_name}_convert") + self._create_q_node( + second_q_input, + second_q_output, + add_quant_suffix(f"{tensor_name}_convert"), + convert_scale_name, + convert_zp_name, + ) + + # Create second DQ op. + second_dq_output = add_dequant_output_suffix(f"{tensor_name}_convert") + if is_graph_output and all_use_converted: + second_dq_output = tensor_name + if convert_recv_nodes and second_dq_output != tensor_name: + self.model.replace_input_of_nodes(tensor_name, second_dq_output, convert_recv_nodes) + self._create_dq_node( + second_q_output, + second_dq_output, + add_dequant_suffix(f"{tensor_name}_convert"), + convert_scale_name, + convert_zp_name, + ) + + # Store in quantized_value_map + original_quantized_value = QuantizedValue( + tensor_name, + first_dq_output, + first_scale_name, + first_zp_name, + QuantizedValueType.Input, + scale_type=scale_data_type, + ) + converted_quantized_value = QuantizedValue( + tensor_name, + second_dq_output, + convert_scale_name, + convert_zp_name, + QuantizedValueType.Input, + scale_type=scale_data_type, + ) + self.quantized_value_map[tensor_name] = QDQTensorQuantizedValue( + original_quantized_value, converted_quantized_value, convert_recv_nodes + ) def _quantize_normal_tensors(self): + """ + Adds Q/DQ ops to tensors (activations and weights) that have been marked for quantization by op quantizers. 
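# Illustrative sketch, not from the ONNX Runtime sources: numerically, the
# Q1 -> DQ1' -> Q2 -> DQ2 chain built above just re-expresses the same float values
# in the converted integer type. The scale/zero-point values below are made up.
import numpy as np

x = np.array([-0.5, 0.0, 1.25], dtype=np.float32)
s8, zp8 = np.float32(3.0 / 255.0), 85          # "original" uint8 params
s16, zp16 = np.float32(3.0 / 65535.0), 21845   # "converted" uint16 params

q8 = np.clip(np.rint(x / s8) + zp8, 0, 255).astype(np.uint8)          # Q1
dq8 = (q8.astype(np.float32) - zp8) * s8                              # DQ1'
q16 = np.clip(np.rint(dq8 / s16) + zp16, 0, 65535).astype(np.uint16)  # Q2
dq16 = (q16.astype(np.float32) - zp16) * s16                          # DQ2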
+ """ for tensor_name, tensor_info in self.tensors_to_quantize.copy().items(): if tensor_name in self.quantized_value_map: continue @@ -426,53 +751,105 @@ def _quantize_normal_tensors(self): if initializer: self._add_qdq_pair_for_initializer(initializer, tensor_info.tensor_type, tensor_info.axis) else: - used_scale, used_zp = self.find_quant_scale_zp(tensor_name) - if used_scale is not None and not hasattr(used_scale, "dtype"): - raise TypeError( - f"Unexpected type {type(used_scale)} for used_scale and tensor_name={tensor_name!r}" - ) - data_found, scale_name, zp_name, _, _ = self._get_quantization_params( - tensor_name, used_scale, used_zp - ) - - if not data_found: + tensor_qparam_initializers = self._make_tensor_scale_zp_initializers(tensor_name) + if not tensor_qparam_initializers: raise ValueError( f"Quantization parameters are not specified for param {tensor_name}. " "In static mode quantization params for inputs and outputs of nodes to be quantized are required." ) - self._add_qdq_pair_for_activation(tensor_name, scale_name, zp_name, data_type=tensor_info.data_type) + if tensor_qparam_initializers.converted is None: + # Normal case: --> Q --> DQ --> + self._add_qdq_pair_for_activation( + tensor_name, + tensor_qparam_initializers.original.scale.name, + tensor_qparam_initializers.original.zero_point.name, + data_type=tensor_info.data_type, + ) + else: + # Conversion case: ---> Q1 -+-> DQ1 --> + # | + # +-> DQ1' --> Q2 --> DQ2 --> + assert tensor_info.data_type == tensor_qparam_initializers.original.scale.data_type + self._add_qdq_ops_for_converted_activation( + tensor_name, + tensor_qparam_initializers.original.scale.name, + tensor_qparam_initializers.original.zero_point.name, + tensor_info.data_type, + tensor_qparam_initializers.converted.scale.name, + tensor_qparam_initializers.converted.zero_point.name, + tensor_qparam_initializers.converted_recv_nodes, + ) del self.tensors_to_quantize[tensor_name] def _quantize_sharing_param_tensors(self): + """ + Adds Q/DQ ops to tensors that have been marked for quantization by op quantizers. + Only operates on tensors that want to use the quantization parameter initializers from an upstream tensor. + For example, a Transpose node's output tensor will typically want to use the same quantization parameter + initializers as the Transpose node's input. + """ while self.tensors_to_quantize: for tensor_name, tensor_info in self.tensors_to_quantize.copy().items(): - tensor_provider_name = tensor_info.quant_para_provider - if tensor_provider_name in self.quantized_value_map: + quant_provider = tensor_info.quant_para_provider + if quant_provider and quant_provider.input_name in self.quantized_value_map: del self.tensors_to_quantize[tensor_name] - quantized_value = self.quantized_value_map[tensor_provider_name] - # Quantize the input - initializer = find_by_name(tensor_name, self.model.initializer()) - if initializer is not None: + quantized_value = self.quantized_value_map[quant_provider.input_name].get_for_consumer( + quant_provider.node_name + ) + if self.is_input_a_initializer(tensor_name): raise ValueError("Quantization parameter shared mode is not supported for weight yet") - self._add_qdq_pair_for_activation(tensor_name, quantized_value.scale_name, quantized_value.zp_name) + + # Need to check if this tensor's quant_type is converted for some consumers. + # If so, create new scale/zp initializers for these consumers. 
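# Illustrative sketch, not from the ONNX Runtime sources: the decision made here is
# essentially "does this consumer read the original or the converted quant params?",
# mirroring QDQTensorQuantizedValue.get_for_consumer above. Names are made up.
def params_for_consumer(original, converted, converted_recv_nodes, consumer_name):
    if converted is None:                 # no conversion recorded for this tensor
        return original
    if converted_recv_nodes is None:      # all consumers read the converted params
        return converted
    return converted if consumer_name in converted_recv_nodes else original


assert params_for_consumer("u8", "u16", {"op_3", "op_4"}, "op_3") == "u16"
assert params_for_consumer("u8", "u16", {"op_3", "op_4"}, "op_1") == "u8"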
+ converted_qparam_inits = None + converted_recv_nodes = None + if tensor_name in self.quantization_params: + tensor_params = self.quantization_params[tensor_name] + if tensor_params.converted: + converted_qparam_inits = self._make_scale_zp_initializers( + tensor_name, tensor_params.converted, "_convert" + ) + converted_recv_nodes = tensor_params.converted_recv_nodes + + if converted_qparam_inits is None: + # Normal case: --> Q_shared --> DQ_shared --> + self._add_qdq_pair_for_activation( + tensor_name, quantized_value.scale_name, quantized_value.zp_name + ) + else: + # Conversion case: ---> Q_shared -+-> DQ_shared --> + # | + # +-> DQ_shared' --> Q2 --> DQ2 --> + self._add_qdq_ops_for_converted_activation( + tensor_name, + quantized_value.scale_name, + quantized_value.zp_name, + converted_qparam_inits.scale.data_type, + converted_qparam_inits.scale.name, + converted_qparam_inits.zero_point.name, + converted_recv_nodes, + ) def _quantize_bias_tensors(self): - for bias_name, input_name, weight_name, beta in self.bias_to_quantize: + """ + Adds DQ ops (or Cast) for bias tensors that have been marked for quantization by op quantizers. + """ + for bias_name, bias_info in self.bias_to_quantize.items(): if bias_name in self.quantized_value_map: continue # Quantize the input - self.quantize_bias_static(bias_name, input_name, weight_name, beta) + self.quantize_bias_static(bias_name, bias_info) init = find_by_name(bias_name, self.model.initializer()) self.model.remove_initializer(init) - quant_value = self.quantized_value_map[bias_name] + quant_value = self.quantized_value_map[bias_name].original if quant_value.node_type == "Cast": # simple cast to float 16 and not DequantizeLinear # cublasLtMatmul only supports (b)float16, float bias. if not isinstance(init.data_type, int): - raise TypeError(f"Unexpected type {type(init.data_type)} for input={input_name!r}") + raise TypeError(f"Unexpected type {type(init.data_type)} for input={bias_info.input_name!r}") node_name = add_dequant_suffix(bias_name) dequant_node = onnx.helper.make_node( "Cast", @@ -511,5 +888,233 @@ def _quantize_bias_tensors(self): raise RuntimeError(f"Unexpected operator type {quant_value.node_type!r}.") self.model.add_node(dequant_node) - def is_tensor_quantized(self, tensor_name): + def is_tensor_quantized(self, tensor_name: str): return tensor_name in self.tensors_to_quantize or tensor_name in self.bias_to_quantize + + def quantize_initializer( + self, + weight: onnx.TensorProto, + qType: onnx.TensorProto.DataType, + reduce_range: bool = False, + keep_float_weight: bool = False, + ) -> tuple[str, str, str]: + """ + :param weight: TensorProto initializer + :param qType: type to quantize to + :param keep_float_weight: Whether to quantize the weight. In some cases, we only want to qunatize scale and zero point. + If keep_float_weight is False, quantize the weight, or don't quantize the weight. 
+ :return: quantized weight name, zero point name, scale name + """ + # Find if this input is already quantized + if weight.name in self.quantized_value_map: + quantized_value = self.quantized_value_map[weight.name].original + return ( + quantized_value.q_name, + quantized_value.zp_name, + quantized_value.scale_name, + ) + + q_weight_name, zp_name, scale_name = self.quantize_initializer_impl( + weight, qType, reduce_range, keep_float_weight + ) + + # Log entry for this quantized weight + quantized_value = QuantizedValue( + weight.name, + q_weight_name, + scale_name, + zp_name, + QuantizedValueType.Initializer, + None, + ) + self.quantized_value_map[weight.name] = QDQTensorQuantizedValue(quantized_value, None, None) + return q_weight_name, zp_name, scale_name + + def quantize_weight_per_channel( + self, + weight_name: str, + weight_qType: onnx.TensorProto.DataType, + channel_axis: int, + reduce_range: bool = True, + keep_float_weight: bool = False, + ) -> tuple[str, str, str]: + # Find if this input is already quantized + if weight_name in self.quantized_value_map: + quantized_value = self.quantized_value_map[weight_name].original + return ( + quantized_value.q_name, + quantized_value.zp_name, + quantized_value.scale_name, + ) + + q_weight_name, zp_name, scale_name = self.quantize_weight_per_channel_impl( + weight_name, weight_qType, channel_axis, reduce_range, keep_float_weight + ) + quantized_value = QuantizedValue( + weight_name, + q_weight_name, + scale_name, + zp_name, + QuantizedValueType.Initializer, + None, + ) + self.quantized_value_map[weight_name] = QDQTensorQuantizedValue(quantized_value, None, None) + + return q_weight_name, zp_name, scale_name + + def quantize_bias_static(self, bias_name: str, bias_info: QDQBiasQuantInfo) -> str: + """ + Quantized the bias. Zero Point == 0 and Scale == Input_Scale * Weight_Scale + """ + + # Handle case where bias already in quantization map + if bias_name in self.quantized_value_map: + return self.quantized_value_map[bias_name].original.q_name + + # get scale for weight + weight_scale_name = self.quantized_value_map[bias_info.weight_name].original.scale_name + weight_initializer = find_by_name(weight_scale_name, self.model.initializer()) + weight_scale = tensor_proto_to_array(weight_initializer) + + # get scale for input + input_scale_name = ( + self.quantized_value_map[bias_info.input_name].get_for_consumer(bias_info.node_name).scale_name + ) + inputscale_initializer = find_by_name(input_scale_name, self.model.initializer()) + input_scale = tensor_proto_to_array(inputscale_initializer) + + ( + quantized_bias_name, + quantized_bias_scale_name, + quantized_bias_zp_name, + bias_scale_data, + node_type, + node_qtype, + ) = self.quantize_bias_static_impl(bias_name, input_scale, weight_scale, bias_info.beta) + + quantized_value = QuantizedValue( + bias_name, + quantized_bias_name, + quantized_bias_scale_name, + quantized_bias_zp_name, + QuantizedValueType.Initializer, + 0 if bias_scale_data.size > 1 else None, + node_type=node_type, + node_qtype=node_qtype, + ) + self.quantized_value_map[bias_name] = QDQTensorQuantizedValue(quantized_value, None, None) + + return quantized_bias_name + + def _make_scale_zp_initializers( + self, param_name: str, params: QuantizationParams, init_name_suffix: str = "" + ) -> QDQScaleZpInitializers: + """ + Creates and returns scale and zero-point initializers for the given quantization params. 
The initializers are + named: + - {param_name}_zero_point{init_name_suffix} + - {param_name}_scale{init_name_suffix} + """ + zero_point_values = np.array([params["zero_point"]]) + if not hasattr(params["scale"], "dtype") or params["scale"].dtype not in (np.float32, np.float16): + raise ValueError(f"Unexpected type {type(params['scale'])} and param_name={param_name!r}") + scale_values = np.array([params["scale"]]) + assert scale_values.dtype != np.float64 + zero_point_type = params.data.get("quant_type", self.activation_qType) + + zero_point_shape = [] + zero_point_name = param_name + "_zero_point" + init_name_suffix + scale_shape = [] + scale_name = param_name + "_scale" + init_name_suffix + + # Add initializers to model + init_zp = onnx.helper.make_tensor( + zero_point_name, zero_point_type, zero_point_shape, zero_point_values.ravel().tolist() + ) + self.model.add_initializer(init_zp) + + if scale_values.dtype == np.float32: + scale_type = onnx_proto.TensorProto.FLOAT + elif scale_values.dtype == np.float16: + scale_type = onnx_proto.TensorProto.FLOAT16 + else: + raise ValueError(f"Unexpected dtype={scale_values.dtype} for param_name={param_name!r}") + init_scale = onnx.helper.make_tensor(scale_name, scale_type, scale_shape, scale_values.reshape((-1,)).tolist()) + self.model.add_initializer(init_scale) + + return QDQScaleZpInitializers(init_scale, init_zp) + + def _make_tensor_scale_zp_initializers(self, tensor_name: str) -> QDQTensorScaleZpInitializers | None: + """ + Create and returns all scale/zero_point initializers for a given tensor. If the tensor is converted + to a different quantization type, this function creates two pairs of zp/scale initializers. Otherwise, + only one pair of zp/scale initializers is created. + """ + if self.quantization_params is None or tensor_name not in self.quantization_params: + logging.info(f'Quantization parameters for tensor:"{tensor_name}" not specified') + return None + + tensor_params = self.quantization_params[tensor_name] + if not isinstance(tensor_params, QDQTensorQuantParams): + raise TypeError(f"Unexpected type {type(tensor_params)} for {tensor_name!r}.") + + original_inits = self._make_scale_zp_initializers(tensor_name, tensor_params.original) + converted_inits = ( + self._make_scale_zp_initializers(tensor_name, tensor_params.converted, "_convert") + if tensor_params.converted + else None + ) + + return QDQTensorScaleZpInitializers(original_inits, converted_inits, tensor_params.converted_recv_nodes) + + def calc_quant_params(self, tensor_data: TensorData, quant_overrides: dict[str, Any]) -> QuantizationParams: + """ + Calculates quantization parameters (scale/zero-point) given a tensor's min/max range and optional + user-provided overrides. 
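# Illustrative sketch, not from the ONNX Runtime sources: each scale/zero-point pair
# ends up as two scalar initializers named "<tensor>_scale" and "<tensor>_zero_point"
# (with a "_convert" suffix for the converted pair). The values below are made up.
import onnx

scale_init = onnx.helper.make_tensor("act_scale", onnx.TensorProto.FLOAT, [], [0.0118])
zp_init = onnx.helper.make_tensor("act_zero_point", onnx.TensorProto.UINT8, [], [85])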
+ """ + quant_type = self.activation_qType + if "quant_type" in quant_overrides: + quant_type = quant_overrides["quant_type"].tensor_type + + if "scale" in quant_overrides and "zero_point" in quant_overrides: + zero, scale = quant_overrides["zero_point"], quant_overrides["scale"] + elif quant_type == onnx.TensorProto.FLOAT8E4M3FN: + zero, scale = compute_scale_zp_float8(quant_type, tensor_data.avg_std[1]) + else: + rmin = quant_overrides.get("rmin", tensor_data.range_value[0]) + rmax = quant_overrides.get("rmax", tensor_data.range_value[1]) + symmetric = quant_overrides.get("symmetric", self.is_activation_symmetric) + reduce_range = quant_overrides.get("reduce_range", False) + qmin, qmax = get_qmin_qmax_for_qType(quant_type, reduce_range=reduce_range, symmetric=symmetric) + zero, scale = compute_scale_zp(rmin, rmax, qmin, qmax, symmetric, self.min_real_range) + + return QuantizationParams(zero_point=zero, scale=scale, quant_type=quant_type) + + def calc_graph_quant_params(self) -> dict[str, QDQTensorQuantParams]: + """ + Calculates quantization parameters (scale/zero-point) for all tensors in the graph using each tensor's min/max range + and optional user-provided overrides. + """ + if self.tensors_range is None: + return {} + + self.adjust_tensor_ranges() + + quantization_params = {} + for tensor_name in self.tensors_range: + td = self.tensors_range[tensor_name] + if not isinstance(td, TensorData): + raise TypeError(f"Unexpected type {type(td)} for {tensor_name!r}.") + + quant_overrides = self.tensor_quant_overrides.get_per_tensor_overrides(tensor_name) + original = self.calc_quant_params(td, quant_overrides) + converted = None + converted_recv_nodes = None + + if "convert" in quant_overrides: + converted = self.calc_quant_params(td, quant_overrides["convert"]) + converted_recv_nodes = quant_overrides["convert"].get("recv_nodes") + + quantization_params[tensor_name] = QDQTensorQuantParams(original, converted, converted_recv_nodes) + + return quantization_params diff --git a/onnxruntime/python/tools/quantization/registry.py b/onnxruntime/python/tools/quantization/registry.py index a693f4192bc2b..b00e830a2a366 100644 --- a/onnxruntime/python/tools/quantization/registry.py +++ b/onnxruntime/python/tools/quantization/registry.py @@ -18,7 +18,7 @@ from .operators.pooling import QLinearPool from .operators.qdq_base_operator import QDQOperatorBase from .operators.resize import QDQResize, QResize -from .operators.softmax import QDQSoftmax, QLinearSoftmax +from .operators.softmax import QLinearSoftmax from .operators.split import QDQSplit, QSplit from .operators.where import QDQWhere, QLinearWhere from .quant_utils import QuantizationMode @@ -79,7 +79,6 @@ "MatMul": QDQMatMul, "Split": QDQSplit, "Gather": QDQGather, - "Softmax": QDQSoftmax, "Where": QDQWhere, "InstanceNormalization": QDQNormalization, "LayerNormalization": QDQNormalization, diff --git a/onnxruntime/python/tools/quantization/tensor_quant_overrides.py b/onnxruntime/python/tools/quantization/tensor_quant_overrides.py new file mode 100644 index 0000000000000..610b96b9d7937 --- /dev/null +++ b/onnxruntime/python/tools/quantization/tensor_quant_overrides.py @@ -0,0 +1,214 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. 
+# -------------------------------------------------------------------------- +from __future__ import annotations + +import json +from collections.abc import MutableMapping +from typing import Any + +from .quant_utils import QuantType + + +class TensorQuantOverridesHelper(MutableMapping): + """ + Utility wrapper over the tensor quantization overrides passed via extra_options. + """ + + def __init__(self, raw_overrides: dict[str, list[dict[str, Any]]]): + self.overrides = raw_overrides + self.quant_types = None + + def get_per_tensor_overrides(self, tensor_name: str) -> dict[str, Any]: + overrides_list = self.overrides.get(tensor_name, [{}]) + num_overrides = len(overrides_list) + if num_overrides > 1: + raise ValueError( + f"Expected tensor '{tensor_name}' to use per-tensor quantization overrides, " + f"but found {num_overrides} per-channel overrides." + ) + + return overrides_list[0] if num_overrides > 0 else {} + + def get_per_channel_overrides( + self, + tensor_name: str, + num_channels: int, + ) -> list[dict[str, Any]]: + overrides_list = self.overrides.get(tensor_name, [{} for i in range(num_channels)]) + + if len(overrides_list) != num_channels: + raise ValueError( + f"Expected tensor '{tensor_name}' to have {num_channels} per-channel quantization overrides, " + f"but found {len(overrides_list)} instead." + ) + + return overrides_list + + def get_quant_types(self) -> set[QuantType]: + if self.quant_types is not None: + return self.quant_types + + self.quant_types = set() + + if self.overrides: + for quant_overrides_list in self.overrides.values(): + for quant_overrides in quant_overrides_list: + if "quant_type" in quant_overrides: + self.quant_types.add(quant_overrides["quant_type"]) + + if "convert" in quant_overrides and "quant_type" in quant_overrides["convert"]: + self.quant_types.add(quant_overrides["convert"]["quant_type"]) + + return self.quant_types + + def is_valid( + self, + initializer_names: set[str], + activation_names: set[str], + default_activation_qtype, + ) -> tuple[bool, str | None]: + self.quant_types = set() + + # Validate that compatible/valid overrides are provided. + if self.overrides: + keys_unsupported_with_scale_zp = {"symmetric", "reduce_range", "rmax", "rmin"} + + for tensor_name, quant_overrides_list in self.overrides.items(): + if tensor_name not in initializer_names and tensor_name not in activation_names: + return False, f"Tensor '{tensor_name}' in TensorQuantOverrides is not present in the model" + + if not isinstance(quant_overrides_list, list): + return False, f"Tensor quantization overrides for '{tensor_name}' are not in a list" + + is_initializer = tensor_name in initializer_names + if not is_initializer and len(quant_overrides_list) > 1: + return ( + False, + f"Tensor '{tensor_name}' has a list of per-channel overrides, but is not an initializer", + ) + + quant_type = None + for index, quant_overrides in enumerate(quant_overrides_list): + if not isinstance(quant_overrides, dict): + return ( + False, + f"Tensor quantization overrides at index {index} for '{tensor_name}' are not in a dict", + ) + + # For per-channel quantization, all channels must use the same quantization type. + # Therefore, if the user tries to override the quant_type for a channel, it must match in all + # other channels. 
+ if index == 0: + quant_type = quant_overrides.get("quant_type") + if quant_type: + self.quant_types.add(quant_type) + elif quant_type != quant_overrides.get("quant_type"): + return ( + False, + "Channel quantization types for tensor '{tensor_name}' do not match at index {index}.", + ) + + has_scale = "scale" in quant_overrides + has_zero_point = "zero_point" in quant_overrides + + if (has_scale and not has_zero_point) or (has_zero_point and not has_scale): + return ( + False, + "Must provide both 'scale' and 'zero_point' if one of the overrides is provided", + ) + + if has_scale: + for key in keys_unsupported_with_scale_zp: + if key in quant_overrides: + return ( + False, + f"Tensor override option '{key}' is invalid with 'scale' and 'zero_point'", + ) + + if "reduce_range" in quant_overrides and not is_initializer: + return ( + False, + f"Option 'reduce_range' is only supported for initializers, not for activation {tensor_name}", + ) + + if "convert" in quant_overrides: + if index > 0: + return ( + False, + f"Per-channel overrides (tensor '{tensor_name}') do not support 'convert'.", + ) + + if is_initializer: + return False, "Cannot use 'convert' override for initializers" + + if "quant_type" not in quant_overrides["convert"]: + return False, f"'convert' options (tensor '{tensor_name}') must specify a 'quant_type'" + + if "reduce_range" in quant_overrides["convert"]: + return ( + False, + f"Option 'reduce_range' is only supported for initializers, not for activation {tensor_name}", + ) + + convert_quant_type = quant_overrides["convert"]["quant_type"] + original_quant_type = quant_type if quant_type is not None else default_activation_qtype + if convert_quant_type == original_quant_type: + return ( + False, + f"'convert' quant_type must differ from original quant_type (tensor '{tensor_name}')", + ) + + convert_has_scale = "scale" in quant_overrides["convert"] + convert_has_zero_point = "zero_point" in quant_overrides["convert"] + + if (convert_has_scale and not convert_has_zero_point) or ( + convert_has_zero_point and not convert_has_scale + ): + return ( + False, + f"Must provide both 'scale' and 'zero_point' if one of the overrides is provided (tensor '{tensor_name}')", + ) + + if convert_has_scale: + for key in keys_unsupported_with_scale_zp: + if key in quant_overrides["convert"]: + return ( + False, + f"Tensor override option '{key}' is invalid with 'scale' and 'zero_point' (tensor '{tensor_name}')", + ) + + self.quant_types.add(convert_quant_type) + + return True, None + + def pprint_str(self, indent=None) -> str: + return json.dumps(self.overrides, default=str, indent=indent) + + def get_dict(self) -> dict[str, list[dict[str, Any]]]: + return self.overrides + + # Required implementations of abstract methods in collections.abc.MutableMapping + # so that this class can be used like a dict. 
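# Illustrative sketch, not from the ONNX Runtime sources: the raw structure wrapped by
# this helper is a dict mapping tensor names to a list of override dicts, passed to
# quantize_static via extra_options={"TensorQuantOverrides": ...} as in the tests below.
# Tensor and node names here are made up.
from onnxruntime.quantization import QuantType

example_overrides = {
    "op_0_out": [
        {
            "quant_type": QuantType.QUInt8,
            "convert": {"quant_type": QuantType.QUInt16, "recv_nodes": {"op_1", "op_2"}},
        }
    ],
    "output_0": [{"quant_type": QuantType.QUInt16}],
}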
+ def __setitem__(self, key: str, value: list[dict]): + self.overrides[key] = value + + def __getitem__(self, key: str) -> list[dict]: + return self.overrides[key] + + def __delitem__(self, key: str): + del self.overrides[key] + + def __iter__(self): + return iter(self.overrides) + + def __len__(self): + return len(self.overrides) + + def __str__(self) -> str: + return str(self.overrides) + + def __repr__(self) -> str: + return f"{super().__repr__()}, TensorQuantOverridesHelper({self.overrides})" diff --git a/onnxruntime/test/python/quantization/test_qdq.py b/onnxruntime/test/python/quantization/test_qdq.py index 9e7a4a125121d..db4ab7e8a412c 100644 --- a/onnxruntime/test/python/quantization/test_qdq.py +++ b/onnxruntime/test/python/quantization/test_qdq.py @@ -4,7 +4,9 @@ # Licensed under the MIT License. See License.txt in the project root for # license information. # -------------------------------------------------------------------------- +from __future__ import annotations +import os import tempfile import unittest from pathlib import Path @@ -25,12 +27,12 @@ class TestQDQFormat(unittest.TestCase): - def input_feeds(self, n, name2shape): + def input_feeds(self, n, name2shape, np_float_type=np.float32): input_data_list = [] for _i in range(n): inputs = {} for name, shape in name2shape.items(): - inputs.update({name: np.random.randint(-1, 2, shape).astype(np.float32)}) + inputs.update({name: np.random.randint(-1, 2, shape).astype(np_float_type)}) input_data_list.extend([inputs]) dr = TestDataFeeds(input_data_list) return dr @@ -720,5 +722,593 @@ def test_activation_only(self): check_op_type_count(self, qdq_model_path, **qop_nodes) +class TestQDQMixedPrecision(TestQDQFormat): + @classmethod + def setUpClass(cls): + cls._tmp_model_dir = tempfile.TemporaryDirectory(prefix="ort.qdq.mixed_prec_") + + # Note: swap with the commented line if you want to see the models in local test dir. + cls._tmp_dir_path = cls._tmp_model_dir.name + # cls._tmp_dir_path = "." + + @classmethod + def tearDownClass(cls): + cls._tmp_model_dir.cleanup() + + def build_test_model_for_add_qdq_ops( + self, + num_consumers: int, + is_graph_output: bool, + float_type: onnx.TensorProto.DataType = onnx.TensorProto.FLOAT, + op0_transpose: bool = False, + ): + """ + Builds a float32 model with a single producer node and a configurable number of consumer nodes. + The tensor between the producer and consumers can be optionally made a graph output. + op_0 can optionally be made a Transpose node to test sharing qparams across the input and output. + + +-> op_0_out (optional graph output) + | + input_0 --> op_0 --+-> op_1 --> output_0 + | + +-> op_2 --> output_1 + | + ... 
+ | + +-> op_{n} --> output_{n-1} + """ + shape = (1, 2, 3) + shape_t = (1, 3, 2) + input_0 = onnx.helper.make_tensor_value_info("input_0", float_type, shape) + output_shape = shape if not op0_transpose else shape_t + + outputs = [] + for i in range(num_consumers): + outputs.append(onnx.helper.make_tensor_value_info(f"output_{i}", float_type, output_shape)) + + if is_graph_output: + outputs.append(onnx.helper.make_tensor_value_info("op_0_out", float_type, output_shape)) + + nodes = [] + if op0_transpose: + nodes.append(onnx.helper.make_node("Transpose", ["input_0"], ["op_0_out"], perm=[0, 2, 1], name="op_0")) + else: + nodes.append(onnx.helper.make_node("Sigmoid", ["input_0"], ["op_0_out"], name="op_0")) + + for i in range(num_consumers): + op_index = i + 1 + nodes.append(onnx.helper.make_node("Cos", ["op_0_out"], [f"output_{i}"], name=f"op_{op_index}")) + + graph = onnx.helper.make_graph( + nodes, + "test_add_qdq_ops_for_converted_activation", + [input_0], + outputs, + ) + opset_imports = [ + onnx.helper.make_opsetid("", 18), + ] + model = onnx.helper.make_model(graph, opset_imports=opset_imports) + return onnx.shape_inference.infer_shapes(model) + + def test_add_tensor_qdq_ops_case_1(self): + """ + Tensor T is not a graph output; all consumers use the converted type + ---> Q1 ---> DQ1 ---> Q2 ---> DQ2 ---> + """ + # Test configurations (qparam_sharing, float_type) + subtest_configs = [ + (False, onnx.TensorProto.FLOAT, np.float32), + (False, onnx.TensorProto.FLOAT16, np.float16), + (True, onnx.TensorProto.FLOAT, np.float32), + (True, onnx.TensorProto.FLOAT16, np.float16), + ] + for test_qparam_sharing, float_type, np_float_type in subtest_configs: + with self.subTest(test_qparam_sharing=test_qparam_sharing, float_type=float_type): + label = f"_share{test_qparam_sharing}_f{float_type}" + float_model_path = os.path.join(self._tmp_dir_path, f"case_1{label}.onnx") + qdq_model_path = os.path.join(self._tmp_dir_path, f"case_1{label}.qdq.onnx") + float_model = self.build_test_model_for_add_qdq_ops( + 2, False, float_type=float_type, op0_transpose=test_qparam_sharing + ) + onnx.save_model(float_model, float_model_path) + + data_reader = self.input_feeds(3, {"input_0": (1, 2, 3)}, np_float_type) + + mixed_prec_overrides = { + "op_0_out": [ + { + "quant_type": QuantType.QUInt8, + "convert": {"quant_type": QuantType.QUInt16, "recv_nodes": {"op_1", "op_2"}}, + } + ], + "output_0": [{"quant_type": QuantType.QUInt16}], + "output_1": [{"quant_type": QuantType.QUInt16}], + } + quantize_static( + float_model_path, + qdq_model_path, + data_reader, + quant_format=QuantFormat.QDQ, + activation_type=QuantType.QUInt8, + op_types_to_quantize=[node.op_type for node in float_model.graph.node], + extra_options={ + "TensorQuantOverrides": mixed_prec_overrides, + "ForceQuantizeNoInputCheck": test_qparam_sharing, # To ensure Transpose is wrapped in DQ/Q + }, + ) + + # Expect the following QDQ model: + # input_0 --> Q --> DQ --> op_0 --> Q_8 --> DQ_8 --> Q_16 --> DQ_16 -+-> op_1 --> Q --> DQ --> output_0 + # | + # +-> op_2 --> Q --> DQ --> output_1 + qdq_node_counts = {"QuantizeLinear": 5, "DequantizeLinear": 5} + check_op_type_count(self, qdq_model_path, **qdq_node_counts) + + qdq_model = onnx.load_model(qdq_model_path) + onnx.checker.check_model(qdq_model, True) + + initializers = {init.name: init for init in qdq_model.graph.initializer} + + # Check zero-point data types + orig_zp_init = None + if test_qparam_sharing: + # op_0_out_zero_point should not be in the model because the Transpose output is sharing + # 
qparams from the Transpose input. + self.assertNotIn("op_0_out_zero_point", initializers) + orig_zp_init = initializers["input_0_zero_point"] + else: + orig_zp_init = initializers["op_0_out_zero_point"] + + self.assertEqual(orig_zp_init.data_type, onnx.TensorProto.UINT8) + convert_zp_init = initializers["op_0_out_zero_point_convert"] + self.assertEqual(convert_zp_init.data_type, onnx.TensorProto.UINT16) + output_0_zp_init = initializers["output_0_zero_point"] + self.assertEqual(output_0_zp_init.data_type, onnx.TensorProto.UINT16) + output_1_zp_init = initializers["output_1_zero_point"] + self.assertEqual(output_1_zp_init.data_type, onnx.TensorProto.UINT16) + + # Check scale data types + orig_scale_init = None + if test_qparam_sharing: + self.assertNotIn("op_0_out_scale", initializers) + orig_scale_init = initializers["input_0_scale"] + else: + orig_scale_init = initializers["op_0_out_scale"] + + self.assertEqual(orig_scale_init.data_type, float_type) + convert_scale_init = initializers["op_0_out_scale_convert"] + self.assertEqual(convert_scale_init.data_type, float_type) + output_0_scale_init = initializers["output_0_scale"] + self.assertEqual(output_0_scale_init.data_type, float_type) + output_1_scale_init = initializers["output_1_scale"] + self.assertEqual(output_1_scale_init.data_type, float_type) + + def test_add_tensor_qdq_ops_case_2(self): + """ + Tensor T is not a graph output; some consumers use the original type, others use the converted type + ---> Q1 -+-> DQ1 ---> + | + +-> DQ1' ---> Q2 ---> DQ2 ---> + """ + # Test configurations (qparam_sharing, float_type) + subtest_configs = [ + (False, onnx.TensorProto.FLOAT, np.float32), + (False, onnx.TensorProto.FLOAT16, np.float16), + (True, onnx.TensorProto.FLOAT, np.float32), + (True, onnx.TensorProto.FLOAT16, np.float16), + ] + for test_qparam_sharing, float_type, np_float_type in subtest_configs: + with self.subTest(test_qparam_sharing=test_qparam_sharing, float_type=float_type): + label = f"_share{test_qparam_sharing}_f{float_type}" + float_model_path = os.path.join(self._tmp_dir_path, f"case_2{label}.onnx") + qdq_model_path = os.path.join(self._tmp_dir_path, f"case_2{label}.qdq.onnx") + float_model = self.build_test_model_for_add_qdq_ops( + 4, False, float_type=float_type, op0_transpose=test_qparam_sharing + ) + onnx.save_model(float_model, float_model_path) + + data_reader = self.input_feeds(3, {"input_0": (1, 2, 3)}, np_float_type) + + mixed_prec_overrides = { + "op_0_out": [ + { + "quant_type": QuantType.QUInt8, + "convert": {"quant_type": QuantType.QUInt16, "recv_nodes": {"op_3", "op_4"}}, + } + ], + "output_2": [{"quant_type": QuantType.QUInt16}], + "output_3": [{"quant_type": QuantType.QUInt16}], + } + quantize_static( + float_model_path, + qdq_model_path, + data_reader, + quant_format=QuantFormat.QDQ, + activation_type=QuantType.QUInt8, + op_types_to_quantize=[node.op_type for node in float_model.graph.node], + extra_options={ + "TensorQuantOverrides": mixed_prec_overrides, + "ForceQuantizeNoInputCheck": test_qparam_sharing, # To ensure Transpose is wrapped in DQ/Q + }, + ) + + # Expect the following QDQ model: + # input_0 --> Q --> DQ --> op_0 --> Q_8 -+-> DQ_8 -+-> op_1 --> Q --> DQ --> output_0 + # | | + # | +-> op_2 --> Q --> DQ --> output_1 + # | + # +-> DQ_8' --> Q_16 --> DQ_16 -+-> op_3 --> Q --> DQ --> output_2 + # | + # +-> op_4 --> Q --> DQ --> output_3 + qdq_node_counts = {"QuantizeLinear": 7, "DequantizeLinear": 8} + check_op_type_count(self, qdq_model_path, **qdq_node_counts) + + qdq_model = 
onnx.load_model(qdq_model_path) + onnx.checker.check_model(qdq_model, True) + + initializers = {init.name: init for init in qdq_model.graph.initializer} + + # Check zero-point data types + orig_zp_init = None + if test_qparam_sharing: + # op_0_out_zero_point should not be in the model because the Transpose output is sharing + # qparams from the Transpose input. + self.assertNotIn("op_0_out_zero_point", initializers) + orig_zp_init = initializers["input_0_zero_point"] + else: + orig_zp_init = initializers["op_0_out_zero_point"] + + self.assertEqual(orig_zp_init.data_type, onnx.TensorProto.UINT8) + convert_zp_init = initializers["op_0_out_zero_point_convert"] + self.assertEqual(convert_zp_init.data_type, onnx.TensorProto.UINT16) + output_0_zp_init = initializers["output_0_zero_point"] + self.assertEqual(output_0_zp_init.data_type, onnx.TensorProto.UINT8) + output_1_zp_init = initializers["output_1_zero_point"] + self.assertEqual(output_1_zp_init.data_type, onnx.TensorProto.UINT8) + output_2_zp_init = initializers["output_2_zero_point"] + self.assertEqual(output_2_zp_init.data_type, onnx.TensorProto.UINT16) + output_3_zp_init = initializers["output_3_zero_point"] + self.assertEqual(output_3_zp_init.data_type, onnx.TensorProto.UINT16) + + # Check scale data types + orig_scale_init = None + if test_qparam_sharing: + self.assertNotIn("op_0_out_scale", initializers) + orig_scale_init = initializers["input_0_scale"] + else: + orig_scale_init = initializers["op_0_out_scale"] + + self.assertEqual(orig_scale_init.data_type, float_type) + convert_scale_init = initializers["op_0_out_scale_convert"] + self.assertEqual(convert_scale_init.data_type, float_type) + output_0_scale_init = initializers["output_0_scale"] + self.assertEqual(output_0_scale_init.data_type, float_type) + output_1_scale_init = initializers["output_1_scale"] + self.assertEqual(output_1_scale_init.data_type, float_type) + output_2_scale_init = initializers["output_2_scale"] + self.assertEqual(output_2_scale_init.data_type, float_type) + output_3_scale_init = initializers["output_3_scale"] + self.assertEqual(output_3_scale_init.data_type, float_type) + + def test_add_tensor_qdq_ops_case_3(self): + """ + Tensor T is a graph output; all consumers use the converted type + ---> Q1 ---> DQ1 ---> Q2 ---> DQ2 -+-> + | + +-> + """ + # Test configurations (qparam_sharing, float_type) + subtest_configs = [ + (False, onnx.TensorProto.FLOAT, np.float32), + (False, onnx.TensorProto.FLOAT16, np.float16), + (True, onnx.TensorProto.FLOAT, np.float32), + (True, onnx.TensorProto.FLOAT16, np.float16), + ] + for test_qparam_sharing, float_type, np_float_type in subtest_configs: + with self.subTest(test_qparam_sharing=test_qparam_sharing, float_type=float_type): + label = f"_share{test_qparam_sharing}_f{float_type}" + float_model_path = os.path.join(self._tmp_dir_path, f"case_3{label}.onnx") + qdq_model_path = os.path.join(self._tmp_dir_path, f"case_3{label}.qdq.onnx") + float_model = self.build_test_model_for_add_qdq_ops( + 2, True, float_type=float_type, op0_transpose=test_qparam_sharing + ) + onnx.save_model(float_model, float_model_path) + + data_reader = self.input_feeds(3, {"input_0": (1, 2, 3)}, np_float_type) + + mixed_prec_overrides = { + "op_0_out": [ + { + "quant_type": QuantType.QUInt8, + "convert": {"quant_type": QuantType.QUInt16, "recv_nodes": {"op_1", "op_2"}}, + } + ], + "output_0": [{"quant_type": QuantType.QUInt16}], + "output_1": [{"quant_type": QuantType.QUInt16}], + } + quantize_static( + float_model_path, + qdq_model_path, + 
data_reader, + quant_format=QuantFormat.QDQ, + activation_type=QuantType.QUInt8, + op_types_to_quantize=[node.op_type for node in float_model.graph.node], + extra_options={ + "TensorQuantOverrides": mixed_prec_overrides, + "ForceQuantizeNoInputCheck": test_qparam_sharing, # To ensure Transpose is wrapped in DQ/Q + }, + ) + + # Expect the following QDQ model: + # input_0 --> Q --> DQ --> op_0 --> Q_8 --> DQ_8 --> Q_16 --> DQ_16 -+-> op_1 --> Q --> DQ --> output_0 + # | + # +-> op_2 --> Q --> DQ --> output_1 + # | + # +--> op_0_out (is graph output) + qdq_node_counts = {"QuantizeLinear": 5, "DequantizeLinear": 5} + check_op_type_count(self, qdq_model_path, **qdq_node_counts) + + qdq_model = onnx.load_model(qdq_model_path) + onnx.checker.check_model(qdq_model, True) + + initializers = {init.name: init for init in qdq_model.graph.initializer} + graph_outputs = {g_output.name: g_output for g_output in qdq_model.graph.output} + + # Check zero-point data types + orig_zp_init = None + if test_qparam_sharing: + # op_0_out_zero_point should not be in the model because the Transpose output is sharing + # qparams from the Transpose input. + self.assertNotIn("op_0_out_zero_point", initializers) + self.assertNotIn("op_0_out_scale", initializers) + orig_zp_init = initializers["input_0_zero_point"] + else: + orig_zp_init = initializers["op_0_out_zero_point"] + + self.assertEqual(orig_zp_init.data_type, onnx.TensorProto.UINT8) + convert_zp_init = initializers["op_0_out_zero_point_convert"] + self.assertEqual(convert_zp_init.data_type, onnx.TensorProto.UINT16) + output_0_zp_init = initializers["output_0_zero_point"] + self.assertEqual(output_0_zp_init.data_type, onnx.TensorProto.UINT16) + output_1_zp_init = initializers["output_1_zero_point"] + self.assertEqual(output_1_zp_init.data_type, onnx.TensorProto.UINT16) + + # Check scale data types + orig_scale_init = None + if test_qparam_sharing: + self.assertNotIn("op_0_out_scale", initializers) + orig_scale_init = initializers["input_0_scale"] + else: + orig_scale_init = initializers["op_0_out_scale"] + + self.assertEqual(orig_scale_init.data_type, float_type) + convert_scale_init = initializers["op_0_out_scale_convert"] + self.assertEqual(convert_scale_init.data_type, float_type) + output_0_scale_init = initializers["output_0_scale"] + self.assertEqual(output_0_scale_init.data_type, float_type) + output_1_scale_init = initializers["output_1_scale"] + self.assertEqual(output_1_scale_init.data_type, float_type) + + self.assertIn("op_0_out", graph_outputs) + + def test_add_tensor_qdq_ops_case_4(self): + """ + Tensor T is a graph output; some consumers use the original type, others use the converted type + ---> Q1 -+-> DQ1 -+-> + | | + | +-> + | + +-> DQ1' ---> Q2 ---> DQ2 ---> + """ + # Test configurations (qparam_sharing, float_type) + subtest_configs = [ + (False, onnx.TensorProto.FLOAT, np.float32), + (False, onnx.TensorProto.FLOAT16, np.float16), + (True, onnx.TensorProto.FLOAT, np.float32), + (True, onnx.TensorProto.FLOAT16, np.float16), + ] + for test_qparam_sharing, float_type, np_float_type in subtest_configs: + with self.subTest(test_qparam_sharing=test_qparam_sharing, float_type=float_type): + label = f"_share{test_qparam_sharing}_f{float_type}" + float_model_path = os.path.join(self._tmp_dir_path, f"case_4{label}.onnx") + qdq_model_path = os.path.join(self._tmp_dir_path, f"case_4{label}.qdq.onnx") + float_model = self.build_test_model_for_add_qdq_ops( + 4, True, float_type=float_type, op0_transpose=test_qparam_sharing + ) + 
onnx.save_model(float_model, float_model_path) + + data_reader = self.input_feeds(3, {"input_0": (1, 2, 3)}, np_float_type) + + mixed_prec_overrides = { + "op_0_out": [ + { + "quant_type": QuantType.QUInt8, + "convert": {"quant_type": QuantType.QUInt16, "recv_nodes": {"op_3", "op_4"}}, + } + ], + "output_2": [{"quant_type": QuantType.QUInt16}], + "output_3": [{"quant_type": QuantType.QUInt16}], + } + quantize_static( + float_model_path, + qdq_model_path, + data_reader, + quant_format=QuantFormat.QDQ, + activation_type=QuantType.QUInt8, + op_types_to_quantize=[node.op_type for node in float_model.graph.node], + extra_options={ + "TensorQuantOverrides": mixed_prec_overrides, + "ForceQuantizeNoInputCheck": test_qparam_sharing, # To ensure Transpose is wrapped in DQ/Q + }, + ) + + # Expect the following QDQ model: + # input_0 --> Q --> DQ --> op_0 --> Q_8 -+-> DQ_8 -+-> op_1 --> Q --> DQ --> output_0 + # | | + # | +-> op_2 --> Q --> DQ --> output_1 + # | | + # | +-> op_0_out (is graph output) + # | + # +-> DQ_8' --> Q_16 --> DQ_16 -+-> op_3 --> Q --> DQ --> output_2 + # | + # +-> op_4 --> Q --> DQ --> output_3 + qdq_node_counts = {"QuantizeLinear": 7, "DequantizeLinear": 8} + check_op_type_count(self, qdq_model_path, **qdq_node_counts) + + qdq_model = onnx.load_model(qdq_model_path) + onnx.checker.check_model(qdq_model, True) + + initializers = {init.name: init for init in qdq_model.graph.initializer} + graph_outputs = {g_output.name: g_output for g_output in qdq_model.graph.output} + + # Check zero-point data types + orig_zp_init = None + if test_qparam_sharing: + # op_0_out_zero_point should not be in the model because the Transpose output is sharing + # qparams from the Transpose input. + self.assertNotIn("op_0_out_zero_point", initializers) + orig_zp_init = initializers["input_0_zero_point"] + else: + orig_zp_init = initializers["op_0_out_zero_point"] + + self.assertEqual(orig_zp_init.data_type, onnx.TensorProto.UINT8) + convert_zp_init = initializers["op_0_out_zero_point_convert"] + self.assertEqual(convert_zp_init.data_type, onnx.TensorProto.UINT16) + output_0_zp_init = initializers["output_0_zero_point"] + self.assertEqual(output_0_zp_init.data_type, onnx.TensorProto.UINT8) + output_1_zp_init = initializers["output_1_zero_point"] + self.assertEqual(output_1_zp_init.data_type, onnx.TensorProto.UINT8) + output_2_zp_init = initializers["output_2_zero_point"] + self.assertEqual(output_2_zp_init.data_type, onnx.TensorProto.UINT16) + output_3_zp_init = initializers["output_3_zero_point"] + self.assertEqual(output_3_zp_init.data_type, onnx.TensorProto.UINT16) + + # Check scale data types + orig_scale_init = None + if test_qparam_sharing: + self.assertNotIn("op_0_out_scale", initializers) + orig_scale_init = initializers["input_0_scale"] + else: + orig_scale_init = initializers["op_0_out_scale"] + + self.assertEqual(orig_scale_init.data_type, float_type) + convert_scale_init = initializers["op_0_out_scale_convert"] + self.assertEqual(convert_scale_init.data_type, float_type) + output_0_scale_init = initializers["output_0_scale"] + self.assertEqual(output_0_scale_init.data_type, float_type) + output_1_scale_init = initializers["output_1_scale"] + self.assertEqual(output_1_scale_init.data_type, float_type) + output_2_scale_init = initializers["output_2_scale"] + self.assertEqual(output_2_scale_init.data_type, float_type) + output_3_scale_init = initializers["output_3_scale"] + self.assertEqual(output_3_scale_init.data_type, float_type) + + self.assertIn("op_0_out", graph_outputs) + + def 
build_test_model_1(self, shape): + """ + Returns the following float32 model. + + input_0 --> op1 --> op3 --> op5 --> op6 --> output_0 + ^ + | + input_1 --> op2 -+-> op4 ----+ + | + +-> op7 --> output_1 + | + +-> op8 --> output_2 + """ + input_0 = onnx.helper.make_tensor_value_info("input_0", onnx.TensorProto.FLOAT, shape) + input_1 = onnx.helper.make_tensor_value_info("input_1", onnx.TensorProto.FLOAT, shape) + output_0 = onnx.helper.make_tensor_value_info("output_0", onnx.TensorProto.FLOAT, shape) + output_1 = onnx.helper.make_tensor_value_info("output_1", onnx.TensorProto.FLOAT, shape) + output_2 = onnx.helper.make_tensor_value_info("output_2", onnx.TensorProto.FLOAT, shape) + + op1_node = onnx.helper.make_node("Sigmoid", ["input_0"], ["op1_out"], name="op1") + op2_node = onnx.helper.make_node("Cos", ["input_1"], ["op2_out"], name="op2") + op3_node = onnx.helper.make_node("Sin", ["op1_out"], ["op3_out"], name="op3") + op4_node = onnx.helper.make_node("Tanh", ["op2_out"], ["op4_out"], name="op4") + op5_node = onnx.helper.make_node("Mul", ["op3_out", "op4_out"], ["op5_out"], name="op5") + op6_node = onnx.helper.make_node("Relu", ["op5_out"], ["output_0"], name="op6") + op7_node = onnx.helper.make_node("Cos", ["op2_out"], ["output_1"], name="op7") + op8_node = onnx.helper.make_node("Sigmoid", ["op2_out"], ["output_2"], name="op8") + + graph = onnx.helper.make_graph( + [ + op1_node, + op2_node, + op3_node, + op4_node, + op5_node, + op6_node, + op7_node, + op8_node, + ], + "mixed_prec_test", + [input_0, input_1], + [output_0, output_1, output_2], + ) + opset_imports = [ + onnx.helper.make_opsetid("", 18), + ] + model = onnx.helper.make_model(graph, opset_imports=opset_imports) + return onnx.shape_inference.infer_shapes(model) + + def test_16bit_subgraph(self): + """ + Test correctness of a qdq model that uses a default 8-bit quantization type and contains + a subgraph that uses 16-bit activations. 
+ """ + shape = (1, 2, 3) + f32_model_path = os.path.join(self._tmp_dir_path, "model.onnx") + qdq_model_path = os.path.join(self._tmp_dir_path, "model.qdq.onnx") + qdq_mixed_model_path = os.path.join(self._tmp_dir_path, "model.mixed.qdq.onnx") + f32_model = self.build_test_model_1(shape) + onnx.save_model(f32_model, f32_model_path) + + data_reader = self.input_feeds(3, {"input_0": shape, "input_1": shape}) + + # Create pure 8-bit qdq model + quantize_static( + f32_model_path, + qdq_model_path, + data_reader, + quant_format=QuantFormat.QDQ, + activation_type=QuantType.QUInt8, + op_types_to_quantize=[node.op_type for node in f32_model.graph.node], + ) + + # Create mixed precision 8-bit/16-bit qdq model + mixed_prec_overrides = { + "op2_out": [ + {"quant_type": QuantType.QUInt8, "convert": {"quant_type": QuantType.QUInt16, "recv_nodes": {"op4"}}} + ], + "op3_out": [ + {"quant_type": QuantType.QUInt8, "convert": {"quant_type": QuantType.QUInt16, "recv_nodes": {"op5"}}} + ], + "op4_out": [{"quant_type": QuantType.QUInt16}], + "op5_out": [{"quant_type": QuantType.QUInt16}], + "output_0": [{"quant_type": QuantType.QUInt16}], + } + data_reader.rewind() + quantize_static( + f32_model_path, + qdq_mixed_model_path, + data_reader, + quant_format=QuantFormat.QDQ, + activation_type=QuantType.QUInt8, + op_types_to_quantize=[node.op_type for node in f32_model.graph.node], + extra_options={"TensorQuantOverrides": mixed_prec_overrides}, + ) + + qop_nodes = {"Relu": 0, "QuantizeLinear": 11, "DequantizeLinear": 12} + check_op_type_count(self, qdq_mixed_model_path, **qop_nodes) + data_reader.rewind() + check_model_correctness(self, f32_model_path, qdq_mixed_model_path, data_reader.get_next()) + data_reader.rewind() + check_model_correctness(self, f32_model_path, qdq_model_path, data_reader.get_next()) + + if __name__ == "__main__": unittest.main()
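
For readers skimming these tests, the override pattern they exercise can be reduced to a minimal sketch. The snippet below is illustrative only: the model paths, the input name `input_0`, and the `act_out`/`consumer_node` names are placeholder assumptions, not identifiers from this patch; only `quantize_static`, `QuantFormat.QDQ`, `QuantType`, and the `TensorQuantOverrides`/`convert`/`recv_nodes` option keys come from the code above.

```python
# Minimal sketch: quantize a model to 8-bit QDQ, but convert one activation to
# 16-bit for a specific consumer via the "convert" entry in TensorQuantOverrides.
# Placeholder assumptions: file names, the input name "input_0", the tensor name
# "act_out", and the consumer node name "consumer_node".
import numpy as np

from onnxruntime.quantization import CalibrationDataReader, QuantFormat, QuantType, quantize_static


class RandomDataReader(CalibrationDataReader):
    """Feeds a handful of random float32 samples for calibration."""

    def __init__(self, num_samples=3, shape=(1, 2, 3)):
        self._samples = iter(
            [{"input_0": np.random.rand(*shape).astype(np.float32)} for _ in range(num_samples)]
        )

    def get_next(self):
        # Return None when exhausted, as the calibration loop expects.
        return next(self._samples, None)


mixed_prec_overrides = {
    # "act_out" is quantized to 8 bits, then re-quantized ("converted") to 16 bits
    # for the consumers listed in "recv_nodes"; any other consumer keeps reading
    # the 8-bit tensor.
    "act_out": [
        {
            "quant_type": QuantType.QUInt8,
            "convert": {"quant_type": QuantType.QUInt16, "recv_nodes": {"consumer_node"}},
        }
    ],
}

quantize_static(
    "model.onnx",        # float model to quantize (placeholder path)
    "model.qdq.onnx",    # output QDQ model (placeholder path)
    RandomDataReader(),
    quant_format=QuantFormat.QDQ,
    activation_type=QuantType.QUInt8,  # default activation type for the rest of the graph
    extra_options={"TensorQuantOverrides": mixed_prec_overrides},
)
```

As in the test cases above, the default activation type stays 8-bit; the `convert` entry is what produces the extra Q/DQ pair (Q_8 --> DQ_8 --> Q_16 --> DQ_16) in front of the consumers named in `recv_nodes`.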