From a1b0065f0121a7de9b882178e69a2a2fc7fc81d3 Mon Sep 17 00:00:00 2001
From: xadupre
Date: Wed, 7 Feb 2024 11:02:23 -0800
Subject: [PATCH 1/5] add option DefaultTensorType to specify the default
 tensor type to quantize

---
 .../tools/quantization/onnx_quantizer.py      | 19 ++++++++++++++++---
 .../tools/transformers/quantize_helper.py     |  1 +
 2 files changed, 17 insertions(+), 3 deletions(-)

diff --git a/onnxruntime/python/tools/quantization/onnx_quantizer.py b/onnxruntime/python/tools/quantization/onnx_quantizer.py
index 898a5f70ac45e..9ded92fc6f73d 100644
--- a/onnxruntime/python/tools/quantization/onnx_quantizer.py
+++ b/onnxruntime/python/tools/quantization/onnx_quantizer.py
@@ -446,6 +446,19 @@ def is_valid_quantize_weight(self, weight_name):
             return False
         return self.parent.is_valid_quantize_weight(weight_name)
 
+    def _get_default_tensor_type(self, tensor_name):
+        if "DefaultTensorType" in self.extra_options:
+            logging.info(f"get_tensor_type returns DefaultTensorType for tensor name %r, use %d", tensor_name, self.extra_options["DefaultTensorType"])
+            return self.extra_options["DefaultTensorType"]
+        raise RuntimeError(
+            f"Unable to find data type for weight_name={tensor_name!r}. "
+            f"shape_inference failed to return a type: this node probably comes "
+            f"from a different domain or uses an input produced by such an operator. "
+            f"This may happen if you quantize a model that is already quantized. "
+            f"You may use the extra_options entry `DefaultTensorType` to indicate "
+            f"the default weight type, usually `onnx.TensorProto.FLOAT`."
+        )
+
     def get_tensor_type(self, tensor_name, mandatory=False):
         weight = find_by_name(tensor_name, self.model.initializer())
         if weight is not None:
@@ -454,11 +467,11 @@ def get_tensor_type(self, tensor_name, mandatory=False):
             vi = self.value_infos[tensor_name]
             if vi.type.HasField("tensor_type"):
                 if mandatory and vi.type.tensor_type.elem_type == 0:
-                    raise RuntimeError(f"Unable to find data type for weight_name={tensor_name!r}")
+                    return self._get_default_tensor_type(tensor_name)
                 return vi.type.tensor_type.elem_type
         if (not self.enable_subgraph_quantization) or (self.parent is None):
             if mandatory:
-                raise RuntimeError(f"Unable to find data type for weight_name={tensor_name!r}")
+                return self._get_default_tensor_type(tensor_name)
             return None
         otype = self.parent.is_valid_quantize_weight(tensor_name)
         if otype is not None:
@@ -468,7 +481,7 @@ def get_tensor_type(self, tensor_name, mandatory=False):
         if res is not None:
             return res
         if mandatory:
-            raise RuntimeError(f"Unable to find data type for weight_name={tensor_name!r}")
+            return self._get_default_tensor_type(tensor_name)
         return None
 
     def is_float_tensor(self, tensor_name):
diff --git a/onnxruntime/python/tools/transformers/quantize_helper.py b/onnxruntime/python/tools/transformers/quantize_helper.py
index a449e881ad361..93a3f228e1a66 100644
--- a/onnxruntime/python/tools/transformers/quantize_helper.py
+++ b/onnxruntime/python/tools/transformers/quantize_helper.py
@@ -69,6 +69,7 @@ def quantize_onnx_model(onnx_model_path, quantized_model_path, use_external_data
         onnx_model_path,
         quantized_model_path,
         use_external_data_format=use_external_data_format,
+        extra_options={"DefaultTensorType": onnx.TensorProto.FLOAT},
     )
     logger.info(f"quantized model saved to:{quantized_model_path}")
     # TODO: inlcude external data in total model size.
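Usage sketch (illustrative only, not part of the patch; the model paths are
placeholders): `quantize_dynamic` forwards `extra_options` to `ONNXQuantizer`,
so callers can opt into the fallback type the same way `quantize_helper.py`
does above.

    import onnx
    from onnxruntime.quantization import QuantType, quantize_dynamic

    # "model.onnx" / "model.quant.onnx" are placeholder paths.
    quantize_dynamic(
        "model.onnx",
        "model.quant.onnx",
        weight_type=QuantType.QInt8,
        # Fallback element type for tensors whose type shape_inference could
        # not resolve (e.g. outputs of com.microsoft operators).
        extra_options={"DefaultTensorType": onnx.TensorProto.FLOAT},
    )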
From 1910e3867a4578604313ab1162cdd656be67527c Mon Sep 17 00:00:00 2001 From: xadupre Date: Thu, 8 Feb 2024 08:12:21 -0800 Subject: [PATCH 2/5] add unit test to check option DefaultTensorType --- .../tools/quantization/onnx_quantizer.py | 8 +- .../python/tools/transformers/benchmark.py | 1 - .../test_quantizer_shape_inference.py | 91 +++++++++++++++++++ 3 files changed, 97 insertions(+), 3 deletions(-) create mode 100644 onnxruntime/test/python/quantization/test_quantizer_shape_inference.py diff --git a/onnxruntime/python/tools/quantization/onnx_quantizer.py b/onnxruntime/python/tools/quantization/onnx_quantizer.py index 9ded92fc6f73d..3bedc8da917b0 100644 --- a/onnxruntime/python/tools/quantization/onnx_quantizer.py +++ b/onnxruntime/python/tools/quantization/onnx_quantizer.py @@ -389,7 +389,7 @@ def add_new_nodes(self, nodes): def quantize_model(self): if self.has_QDQ_nodes(): logging.warning( - "Please check if the model is already quantized." + "Please check if the model is already quantized. " "Note you don't need to quantize a QAT model. OnnxRuntime support to run QAT model directly." ) @@ -448,7 +448,11 @@ def is_valid_quantize_weight(self, weight_name): def _get_default_tensor_type(self, tensor_name): if "DefaultTensorType" in self.extra_options: - logging.info(f"get_tensor_type returns DefaultTensorType for tensor name %r, use %d", tensor_name, self.extra_options["DefaultTensorType"]) + logging.info( + "get_tensor_type returns DefaultTensorType for tensor name %r, use %d", + tensor_name, + self.extra_options["DefaultTensorType"], + ) return self.extra_options["DefaultTensorType"] raise RuntimeError( f"Unable to find data type for weight_name={tensor_name!r}. " diff --git a/onnxruntime/python/tools/transformers/benchmark.py b/onnxruntime/python/tools/transformers/benchmark.py index f506516442b1e..2486548fc255f 100644 --- a/onnxruntime/python/tools/transformers/benchmark.py +++ b/onnxruntime/python/tools/transformers/benchmark.py @@ -39,7 +39,6 @@ It is recommended to use run_benchmark.sh to launch benchmark. """ - import argparse import logging import os diff --git a/onnxruntime/test/python/quantization/test_quantizer_shape_inference.py b/onnxruntime/test/python/quantization/test_quantizer_shape_inference.py new file mode 100644 index 0000000000000..77e8ea775170d --- /dev/null +++ b/onnxruntime/test/python/quantization/test_quantizer_shape_inference.py @@ -0,0 +1,91 @@ +#!/usr/bin/env python +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. 
+# --------------------------------------------------------------------------
+import unittest
+
+import numpy as np
+import onnx
+import onnx.helper as oh
+import onnx.numpy_helper as onh
+from onnxruntime.quantization.onnx_quantizer import ONNXQuantizer
+from onnxruntime.quantization.quant_utils import QuantType, QuantizationMode
+
+
+class TestQuantizerShapeInference(unittest.TestCase):
+    def test_com_microsoft(self):
+        model = oh.make_model(
+            oh.make_graph(
+                [
+                    oh.make_node("MatMul", ["X", "W1"], ["T1"]),
+                    oh.make_node("FusedMatMul", ["T1", "W2"], ["T2"], domain="com.microsoft"),
+                    oh.make_node("MatMul", ["T2", "W3"], ["T3"]),
+                    oh.make_node("MatMul", ["T3", "W4"], ["Y"]),
+                ],
+                "name",
+                [oh.make_tensor_value_info("X", onnx.TensorProto.FLOAT, [1, 4])],
+                [oh.make_tensor_value_info("Y", onnx.TensorProto.FLOAT, [1, 4])],
+                [
+                    onh.from_array(np.random.randn(4, 4).astype(np.float32), "W1"),
+                    onh.from_array(np.random.randn(4, 4).astype(np.float32), "W2"),
+                    onh.from_array(np.random.randn(4, 4).astype(np.float32), "W3"),
+                    onh.from_array(np.random.randn(4, 4).astype(np.float32), "W4"),
+                ],
+            ),
+            opset_imports=[oh.make_opsetid("", 18), oh.make_opsetid("com.microsoft", 1)],
+        )
+        model_shaped = onnx.shape_inference.infer_shapes(model)
+        shaped_results = set(t.name for t in model_shaped.graph.value_info)
+        # every result after T1 depends on T2, which comes from a com.microsoft node;
+        # shape_inference cannot go beyond this point
+        self.assertEqual(shaped_results, {"T1"})
+
+        # first try: check that quantization raises an exception without DefaultTensorType
+        quantizer = ONNXQuantizer(
+            model,
+            False,  # per_channel
+            False,  # reduce_range
+            QuantizationMode.IntegerOps,  # mode
+            False,  # static
+            QuantType.QInt8,  # weight_type
+            QuantType.QUInt8,  # dynamic activation only supports uint8
+            None,  # tensors_range
+            [],  # nodes_to_quantize
+            [],  # nodes_to_exclude
+            ["MatMul"],  # op_types_to_quantize
+            {"MatMulConstBOnly": True},  # extra_options
+            # {'DefaultTensorType': 1, }
+        )
+
+        with self.assertRaises(RuntimeError) as e:
+            quantizer.quantize_model()
+        self.assertIn("Unable to find data type for weight_name=", str(e.exception))
+
+        # second try: check that quantization succeeds once DefaultTensorType is set
+        quantizer = ONNXQuantizer(
+            model,
+            False,  # per_channel
+            False,  # reduce_range
+            QuantizationMode.IntegerOps,  # mode
+            False,  # static
+            QuantType.QInt8,  # weight_type
+            QuantType.QUInt8,  # dynamic activation only supports uint8
+            None,  # tensors_range
+            [],  # nodes_to_quantize
+            [],  # nodes_to_exclude
+            ["MatMul"],  # op_types_to_quantize
+            {
+                "MatMulConstBOnly": True,
+                "DefaultTensorType": 1,
+            },
+        )
+
+        model = quantizer.quantize_model()
+        ops = {n.op_type for n in model.graph.node}
+        self.assertEqual(ops, {"Cast", "FusedMatMul", "MatMulInteger", "DynamicQuantizeLinear", "Mul"})
+
+
+if __name__ == "__main__":
+    unittest.main(verbosity=2)

From 99b6a4b6bc063744274586a24440b28cc9f507c2 Mon Sep 17 00:00:00 2001
From: xadupre
Date: Thu, 8 Feb 2024 08:14:17 -0800
Subject: [PATCH 3/5] restore a blank line

---
 onnxruntime/python/tools/transformers/benchmark.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/onnxruntime/python/tools/transformers/benchmark.py b/onnxruntime/python/tools/transformers/benchmark.py
index 2486548fc255f..f506516442b1e 100644
--- a/onnxruntime/python/tools/transformers/benchmark.py
+++ b/onnxruntime/python/tools/transformers/benchmark.py
@@ -39,6 +39,7 @@
 
 It is recommended to use run_benchmark.sh to launch benchmark.
""" + import argparse import logging import os From eb41f47cb7273473610fad0e51b23d7a7a17a1f9 Mon Sep 17 00:00:00 2001 From: xadupre Date: Thu, 15 Feb 2024 02:07:14 -0800 Subject: [PATCH 4/5] lint issues --- onnxruntime/python/tools/transformers/quantize_helper.py | 2 +- .../test/python/quantization/test_quantizer_shape_inference.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime/python/tools/transformers/quantize_helper.py b/onnxruntime/python/tools/transformers/quantize_helper.py index 93a3f228e1a66..6a25196dbc24c 100644 --- a/onnxruntime/python/tools/transformers/quantize_helper.py +++ b/onnxruntime/python/tools/transformers/quantize_helper.py @@ -7,7 +7,7 @@ import logging import os -import onnx # noqa: F401 +import onnx import torch from transformers.modeling_utils import Conv1D diff --git a/onnxruntime/test/python/quantization/test_quantizer_shape_inference.py b/onnxruntime/test/python/quantization/test_quantizer_shape_inference.py index 77e8ea775170d..413ba1d93a426 100644 --- a/onnxruntime/test/python/quantization/test_quantizer_shape_inference.py +++ b/onnxruntime/test/python/quantization/test_quantizer_shape_inference.py @@ -11,7 +11,7 @@ import onnx.helper as oh import onnx.numpy_helper as onh from onnxruntime.quantization.onnx_quantizer import ONNXQuantizer -from onnxruntime.quantization.quant_utils import QuantType, QuantizationMode +from onnxruntime.quantization.quant_utils import QuantizationMode, QuantType class TestQuantizerShapeInference(unittest.TestCase): From deaaf7b41dff3846327a5b4eac75d0e2a4ea03a5 Mon Sep 17 00:00:00 2001 From: xadupre Date: Thu, 15 Feb 2024 03:29:42 -0800 Subject: [PATCH 5/5] lint --- .../test/python/quantization/test_quantizer_shape_inference.py | 1 + 1 file changed, 1 insertion(+) diff --git a/onnxruntime/test/python/quantization/test_quantizer_shape_inference.py b/onnxruntime/test/python/quantization/test_quantizer_shape_inference.py index 413ba1d93a426..2b5d1f36070e5 100644 --- a/onnxruntime/test/python/quantization/test_quantizer_shape_inference.py +++ b/onnxruntime/test/python/quantization/test_quantizer_shape_inference.py @@ -10,6 +10,7 @@ import onnx import onnx.helper as oh import onnx.numpy_helper as onh + from onnxruntime.quantization.onnx_quantizer import ONNXQuantizer from onnxruntime.quantization.quant_utils import QuantizationMode, QuantType