From 3aaa7839ed6ae5d61ff8c87e7025f84e66f22871 Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Tue, 27 Feb 2024 17:14:36 -0800 Subject: [PATCH] Ensure 16bit tensor quant overrides set MS domain --- .../tools/quantization/onnx_quantizer.py | 11 ++++--- .../tools/quantization/qdq_quantizer.py | 5 ++- .../test_tensor_quant_overrides_option.py | 32 ++++++++++++++++++- 3 files changed, 42 insertions(+), 6 deletions(-) diff --git a/onnxruntime/python/tools/quantization/onnx_quantizer.py b/onnxruntime/python/tools/quantization/onnx_quantizer.py index 9450426f12444..19a72e38dea33 100644 --- a/onnxruntime/python/tools/quantization/onnx_quantizer.py +++ b/onnxruntime/python/tools/quantization/onnx_quantizer.py @@ -154,7 +154,7 @@ def __init__( if self.mode not in QuantizationMode: raise ValueError(f"unsupported quantization mode {self.mode}") - self.tensor_quant_overrides = self._get_and_check_tensor_quant_overrides() + self.tensor_quant_overrides, self.tensor_quant_override_types = self._get_and_check_tensor_quant_overrides() self.quantization_params = self.calculate_quantization_params() # QuantizeRange tensor name and zero tensor name for scale and zero point calculation. @@ -177,8 +177,10 @@ def __init__( def _get_and_check_tensor_quant_overrides(self): """ Get tensor quantization overrides and check correctness. + Also returns a set of quantization types (as TensorProto) specified across all overrides. """ tensor_quant_overrides = self.extra_options.get("TensorQuantOverrides", {}) + tensor_quant_override_types = set() # Validate that compatible/valid overrides are provided. if tensor_quant_overrides: @@ -211,6 +213,8 @@ def _get_and_check_tensor_quant_overrides(self): # other channels. if index == 0: quant_type = quant_overrides.get("quant_type") + if quant_type is not None: + tensor_quant_override_types.add(quant_type.tensor_type) elif quant_type != quant_overrides.get("quant_type"): raise ValueError( "Channel quantization types for tensor '{tensor_name}' do not match at index {index}." @@ -231,7 +235,7 @@ def _get_and_check_tensor_quant_overrides(self): f"Tensor override option '{key}' is invalid with 'scale' and 'zero_point'" ) - return tensor_quant_overrides + return tensor_quant_overrides, tensor_quant_override_types def get_per_tensor_quant_overrides(self, tensor_name): quant_overrides_list = self.tensor_quant_overrides.get(tensor_name, [{}]) @@ -747,8 +751,7 @@ def _get_quantization_params(self, param_name, use_scale=None, use_zeropoint=Non raise ValueError(f"Unexpected type {type(params['scale'])} and param_name={param_name!r}") scale_values = np.array([params["scale"]]) assert scale_values.dtype != np.float64 - # zero_point_type = params["quant_type"] - assert zero_point_type == params["quant_type"] + zero_point_type = params["quant_type"] else: zero_point_values = np.array([use_zeropoint]) scale_values = np.array([use_scale]) diff --git a/onnxruntime/python/tools/quantization/qdq_quantizer.py b/onnxruntime/python/tools/quantization/qdq_quantizer.py index 775a3e8b8b588..76cd0d21fca37 100644 --- a/onnxruntime/python/tools/quantization/qdq_quantizer.py +++ b/onnxruntime/python/tools/quantization/qdq_quantizer.py @@ -116,7 +116,10 @@ def __init__( # if the activation or weight types are 16-bit integers. # TODO: Remove this override (and use only the 'UseQDQContribOps' option) if/when ONNX adds 16-bit support. int16_types = (TensorProto.UINT16, TensorProto.INT16) - if not self.qdq_op_domain and (self.activation_qType in int16_types or self.weight_qType in int16_types): + overrides_have_int16 = any(t in int16_types for t in self.tensor_quant_override_types) + if not self.qdq_op_domain and ( + self.activation_qType in int16_types or self.weight_qType in int16_types or overrides_have_int16 + ): logging.warning( "ONNX QuantizeLinear and DequantizeLinear operators do not support 16-bit integer quantization types. " f"The domain of QuantizeLinear and DequantizeLinear operators will be set to '{ms_domain}' to " diff --git a/onnxruntime/test/python/quantization/test_tensor_quant_overrides_option.py b/onnxruntime/test/python/quantization/test_tensor_quant_overrides_option.py index 0470953e385b6..1d9325e3c1301 100644 --- a/onnxruntime/test/python/quantization/test_tensor_quant_overrides_option.py +++ b/onnxruntime/test/python/quantization/test_tensor_quant_overrides_option.py @@ -13,7 +13,7 @@ from onnxruntime import quantization from onnxruntime.quantization.execution_providers.qnn import get_qnn_qdq_config -from onnxruntime.quantization.quant_utils import compute_scale_zp, get_qmin_qmax_for_qType +from onnxruntime.quantization.quant_utils import compute_scale_zp, get_qmin_qmax_for_qType, ms_domain class DummyDataReader(quantization.CalibrationDataReader): @@ -423,6 +423,36 @@ def test_qdq_overrides_per_channel2(self): self.assertEqual(zp, expected_zp) self.assertEqual(scale, np.float32(expected_scale)) + def test_16bit_overrides_set_ms_domain(self): + """ + Test that overriding a tensor to 16bit (when default is 8bit) automatically sets the 'com.microsoft' + domain on DQ and Q ops. + """ + qdq_model_name = "model_quant_overrides_to_16bit.onnx" + inp_zp, _, sig_out_zp, _, _, _, _, _, out_zp, _ = self.perform_qdq_quantization( + qdq_model_name, + activation_type=onnx.TensorProto.UINT8, # Default to 8bit activations + extra_options={ + "TensorQuantOverrides": { + "INP": [{"quant_type": quantization.QuantType.QUInt16}], + "SIG_OUT": [{"quant_type": quantization.QuantType.QUInt16}], + } + }, + ) + + # Input and Sigmoid's output should be overridden to 16bit + self.assertEqual(inp_zp.data_type, onnx.TensorProto.UINT16) + self.assertEqual(sig_out_zp.data_type, onnx.TensorProto.UINT16) + + # Output should the default uint8 type + self.assertEqual(out_zp.data_type, onnx.TensorProto.UINT8) + + # Q/DQ ops should all have the 'com.microsoft' domain + qdq_model = onnx.load_model(qdq_model_name) + for node in qdq_model.graph.node: + if node.op_type in {"QuantizeLinear", "DequantizeLinear"}: + self.assertEqual(node.domain, ms_domain) + def test_override_validation_nonexisting_tensor(self): """ Test that specifying a non-existing tensor should fail.