From 8d5ecc4dae0686d032a81c3633fdaf213572a722 Mon Sep 17 00:00:00 2001
From: Adrian Lizarraga
Date: Tue, 28 Nov 2023 09:46:47 -0800
Subject: [PATCH] [Quantization] Fix scale/zero-point for 16-bit QDQ Softmax
 (#18589)

### Description
Sets the appropriate scale and zero-point values for 16-bit QDQ Softmax. Previously, the scale/zp were set to fixed values that were specific to 8-bit quantization.

### Motivation and Context
Generate more accurate 16-bit QDQ models that contain Softmax.
---
 .../tools/quantization/operators/softmax.py   | 28 +++---
 .../test/python/quantization/op_test_utils.py |  3 +
 .../python/quantization/test_op_softmax.py    | 96 ++++++++++++++-----
 3 files changed, 93 insertions(+), 34 deletions(-)

diff --git a/onnxruntime/python/tools/quantization/operators/softmax.py b/onnxruntime/python/tools/quantization/operators/softmax.py
index 1e380d7764952..bd09b05ddd9ff 100644
--- a/onnxruntime/python/tools/quantization/operators/softmax.py
+++ b/onnxruntime/python/tools/quantization/operators/softmax.py
@@ -1,6 +1,14 @@
 import onnx
 
-from ..quant_utils import TENSOR_NAME_QUANT_SUFFIX, QuantizedValue, QuantizedValueType, attribute_to_kwarg, ms_domain
+from ..quant_utils import (
+    TENSOR_NAME_QUANT_SUFFIX,
+    QuantizedValue,
+    QuantizedValueType,
+    attribute_to_kwarg,
+    compute_scale_zp,
+    get_qmin_qmax_for_qType,
+    ms_domain,
+)
 from .base_operator import QuantOperatorBase
 from .qdq_base_operator import QDQOperatorBase
 
@@ -77,15 +85,11 @@ def quantize(self):
 class QDQSoftmax(QDQOperatorBase):
     def quantize(self):
         super().quantize()
-        if self.quantizer.activation_qType == onnx.onnx_pb.TensorProto.UINT8:
-            out_scale = 1 / 256.0
-            out_zero_point = 0
-        elif self.quantizer.is_activation_symmetric:
-            # results are all greater or equal to 0, so we can only use
-            # half of the range
-            out_scale = 1 / 127.0
-            out_zero_point = 0
-        else:
-            out_scale = 1 / 256.0
-            out_zero_point = -128
+        symmetric = self.quantizer.is_activation_symmetric
+
+        # Enforce Softmax range: 0.0 to 1.0
+        rmin, rmax = 0.0, 1.0
+        qmin, qmax = get_qmin_qmax_for_qType(self.quantizer.activation_qType, symmetric=symmetric)
+        out_zero_point, out_scale = compute_scale_zp(rmin, rmax, qmin, qmax, symmetric=symmetric)
+
         self.quantizer.set_quant_scale_zp(self.node.output[0], (out_scale, out_zero_point))
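For reference, the fix replaces hard-coded 8-bit constants with a scale/zero-point derived from Softmax's fixed output range of [0.0, 1.0], which generalizes to 16-bit types. The sketch below is illustrative only (the `softmax_scale_zp` helper is not part of the quantizer); it mirrors the affine mapping that `get_qmin_qmax_for_qType` and `compute_scale_zp` implement for the cases exercised here:

```python
# Illustrative sketch, not ONNX Runtime code: derive scale/zero-point for a
# tensor whose real range is fixed to [0.0, 1.0] (the Softmax output range).
def softmax_scale_zp(qmin: int, qmax: int, symmetric: bool):
    rmin, rmax = 0.0, 1.0
    if symmetric:
        # Symmetric quantization centers the range on zero, so only the upper
        # half of the quantized range can represent the non-negative output.
        scale = (2.0 * max(abs(rmin), abs(rmax))) / (qmax - qmin)
        zero_point = 0
    else:
        scale = (rmax - rmin) / (qmax - qmin)
        zero_point = round(qmin - rmin / scale)
    return scale, zero_point

print(softmax_scale_zp(0, 255, symmetric=False))        # uint8:  (1/255, 0)
print(softmax_scale_zp(0, 65535, symmetric=False))      # uint16: (1/65535, 0)
print(softmax_scale_zp(-32767, 32767, symmetric=True))  # int16 symmetric: (1/32767, 0)
```

Compared with the old hard-coded values (e.g., 1/256 with zero point -128 for asymmetric int8), this derivation extends cleanly to 16-bit activation types.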
diff --git a/onnxruntime/test/python/quantization/op_test_utils.py b/onnxruntime/test/python/quantization/op_test_utils.py
index f26b6297cdbda..eede1be05f85f 100644
--- a/onnxruntime/test/python/quantization/op_test_utils.py
+++ b/onnxruntime/test/python/quantization/op_test_utils.py
@@ -393,6 +393,9 @@ def check_qtype_by_node_type(testcase, model_to_check, check_list):
         model = onnx.load(model_to_check)
     elif isinstance(model_to_check, onnx.ModelProto):
         model = model_to_check
+    # NOTE: ONNX shape inference does not work on MS domain nodes.
+    # Therefore, this function cannot currently be used for graphs that contain ops such as
+    # com.microsoft.QuantizeLinear, which support 16-bit quantization.
     model = onnx.shape_inference.infer_shapes(model)
     value_infos = {vi.name: vi for vi in model.graph.value_info}
     value_infos.update({ot.name: ot for ot in model.graph.output})
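The tests below exercise 16-bit QDQ through the `UseQDQContribOps` extra option: the standard ONNX QuantizeLinear/DequantizeLinear did not support 16-bit types at the time of this change, so the quantizer must emit the `com.microsoft` contrib variants instead. A minimal usage sketch, assuming an FP32 model at a placeholder path and a stub calibration reader:

```python
import numpy as np
from onnxruntime.quantization import CalibrationDataReader, QuantFormat, QuantType, quantize_static

class StubDataReader(CalibrationDataReader):
    """Feeds a fixed list of input dicts; a real reader should yield representative data."""
    def __init__(self, inputs):
        self._iter = iter(inputs)

    def get_next(self):
        return next(self._iter, None)

# Placeholder paths; the input shape matches the test model used below.
reader = StubDataReader([{"input": np.random.rand(1, 2, 26, 42).astype(np.float32)}])
quantize_static(
    "softmax_fp32.onnx",     # input FP32 model (placeholder path)
    "softmax_qdq_u16.onnx",  # output QDQ model (placeholder path)
    reader,
    quant_format=QuantFormat.QDQ,
    activation_type=QuantType.QUInt16,
    weight_type=QuantType.QUInt16,
    # Emit com.microsoft Q/DQ contrib ops, which support 16-bit types.
    extra_options={"UseQDQContribOps": True},
)
```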
diff --git a/onnxruntime/test/python/quantization/test_op_softmax.py b/onnxruntime/test/python/quantization/test_op_softmax.py
index 8e6e4d4100348..3416198450137 100644
--- a/onnxruntime/test/python/quantization/test_op_softmax.py
+++ b/onnxruntime/test/python/quantization/test_op_softmax.py
@@ -43,6 +43,7 @@ def construct_model_conv_softmax(
         softmax_input_shape,
         softmax_attributes,
         output_shape,
+        add_ms_domain_opset=False,
     ):
         # (input)
         #    \
@@ -74,11 +75,16 @@ def construct_model_conv_softmax(
             [identity_out, output_tensor],
             initializer=initializers,
         )
-        model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
+
+        opset_imports = [helper.make_opsetid("", 13)]
+        if add_ms_domain_opset:
+            opset_imports.append(helper.make_opsetid("com.microsoft", 1))
+
+        model = helper.make_model(graph, opset_imports=opset_imports)
         model.ir_version = 7  # use stable onnx ir version
         onnx.save(model, output_model_path)
 
-    def quantize_softmax_test(self, activation_type, weight_type, extra_options={}):  # noqa: B006
+    def quantize_softmax_test_qop(self, activation_type, weight_type, extra_options={}):  # noqa: B006
         np.random.seed(1)
         model_fp32_path = "softmax_fp32.onnx"
         self.construct_model_conv_softmax(
@@ -91,11 +97,10 @@ def quantize_softmax_test(self, activation_type, weight_type, extra_options={}):
         )
         data_reader = self.input_feeds(1, {"input": [1, 2, 26, 42]})
 
-        activation_proto_qtype = TensorProto.UINT8 if activation_type == QuantType.QUInt8 else TensorProto.INT8
-        activation_type_str = "u8" if (activation_type == QuantType.QUInt8) else "s8"
-        weight_type_str = "u8" if (weight_type == QuantType.QUInt8) else "s8"
+        activation_proto_qtype = activation_type.tensor_type
+        activation_type_str = str(activation_type)
+        weight_type_str = str(weight_type)
         model_q8_path = f"softmax_{activation_type_str}{weight_type_str}.onnx"
-        model_q8_qdq_path = f"softmax_qdq_{activation_type_str}{weight_type_str}.onnx"
 
         # Verify QOperator mode
         data_reader.rewind()
@@ -138,11 +143,30 @@ def quantize_softmax_test(self, activation_type, weight_type, extra_options={}):
         data_reader.rewind()
         check_model_correctness(self, model_fp32_path, model_q8_path, data_reader.get_next())
 
+    def quantize_softmax_test_qdq(self, activation_type, weight_type, extra_options={}):  # noqa: B006
+        np.random.seed(1)
+        model_fp32_path = "softmax_fp32.onnx"
+        self.construct_model_conv_softmax(
+            model_fp32_path,
+            [1, 2, 26, 42],
+            [3, 2, 3, 3],
+            [1, 3, 24, 40],
+            {"axis": -2},
+            [1, 3, 24, 40],
+            add_ms_domain_opset=extra_options.get("UseQDQContribOps", False),
+        )
+        data_reader = self.input_feeds(1, {"input": [1, 2, 26, 42]})
+
+        activation_proto_qtype = activation_type.tensor_type
+        activation_type_str = str(activation_type)
+        weight_type_str = str(weight_type)
+        model_qdq_path = f"softmax_qdq_{activation_type_str}{weight_type_str}.onnx"
+
         # Verify QDQ mode
         data_reader.rewind()
         quantize_static(
             model_fp32_path,
-            model_q8_qdq_path,
+            model_qdq_path,
             data_reader,
             quant_format=QuantFormat.QDQ,
             activation_type=activation_type,
@@ -150,7 +174,7 @@ def quantize_softmax_test(self, activation_type, weight_type, extra_options={}):
             extra_options=extra_options,
         )
 
-        result_model = onnx.load(Path(model_q8_qdq_path))
+        result_model = onnx.load(Path(model_qdq_path))
         qnode_cnt = 0
         dqnode_cnt = 0
         softmax_cnt = 0
@@ -166,9 +190,15 @@ def quantize_softmax_test(self, activation_type, weight_type, extra_options={}):
         self.assertEqual(3, qnode_cnt, f"Expected 3 QuantizeLinear nodes, found {qnode_cnt}")
         self.assertEqual(4, dqnode_cnt, f"Expected 4 DequantizeLinear nodes, found {dqnode_cnt}")
         self.assertEqual(1, softmax_cnt, f"Expected 1 Softmax node, found {softmax_cnt}")
-        if extra_options.get("ActivationSymmetric", False):
-            for tensor in result_model.graph.initializer:
-                if tensor.name in qnode_zeropoints:
+        for tensor in result_model.graph.initializer:
+            if tensor.name in qnode_zeropoints:
+                self.assertEqual(
+                    tensor.data_type,
+                    activation_proto_qtype,
+                    f"QuantizeLinear zero-point must be of proto type {activation_proto_qtype}, "
+                    f"but found {tensor.data_type} instead.",
+                )
+                if extra_options.get("ActivationSymmetric", False):
                     np_value = numpy_helper.to_array(tensor)
                     self.assertEqual(
                         0,
@@ -176,30 +206,52 @@ def quantize_softmax_test(self, activation_type, weight_type, extra_options={}):
                         f"QuantizeLinear node zero point value must be 0, found {np_value} instead!",
                     )
 
-        qnode_io_qtypes = {
-            "QuantizeLinear": [
-                ["i", 2, activation_proto_qtype],
-                ["o", 0, activation_proto_qtype],
-            ]
-        }
-        check_qtype_by_node_type(self, model_q8_qdq_path, qnode_io_qtypes)
         data_reader.rewind()
-        check_model_correctness(self, model_fp32_path, model_q8_qdq_path, data_reader.get_next())
+        check_model_correctness(self, model_fp32_path, model_qdq_path, data_reader.get_next())
 
     def test_quantize_softmax(self):
-        self.quantize_softmax_test(QuantType.QUInt8, QuantType.QUInt8)
+        self.quantize_softmax_test_qop(QuantType.QUInt8, QuantType.QUInt8)
+        self.quantize_softmax_test_qdq(QuantType.QUInt8, QuantType.QUInt8)
 
     def test_quantize_softmax_s8s8(self):
-        self.quantize_softmax_test(
+        self.quantize_softmax_test_qop(
+            QuantType.QInt8,
+            QuantType.QInt8,
+        )
+        self.quantize_softmax_test_qdq(
+            QuantType.QInt8,
+            QuantType.QInt8,
+        )
+        self.quantize_softmax_test_qop(
             QuantType.QInt8,
             QuantType.QInt8,
+            extra_options={"ActivationSymmetric": True},
         )
-        self.quantize_softmax_test(
+        self.quantize_softmax_test_qdq(
             QuantType.QInt8,
             QuantType.QInt8,
             extra_options={"ActivationSymmetric": True},
         )
 
+    def test_quantize_softmax_qdq_u16u16(self):
+        self.quantize_softmax_test_qdq(
+            QuantType.QUInt16,
+            QuantType.QUInt16,
+            extra_options={"UseQDQContribOps": True},
+        )
+
+    def test_quantize_softmax_qdq_s16s16(self):
+        self.quantize_softmax_test_qdq(
+            QuantType.QInt16,
+            QuantType.QInt16,
+            extra_options={"UseQDQContribOps": True},
+        )
+        self.quantize_softmax_test_qdq(
+            QuantType.QInt16,
+            QuantType.QInt16,
+            extra_options={"UseQDQContribOps": True, "ActivationSymmetric": True},
+        )
+
 
 if __name__ == "__main__":
     unittest.main()
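As a companion to the zero-point assertions added above, here is a small hedged sketch (the helper name and example path are ours, not part of the test suite) for inspecting the zero-point initializers of any QDQ model outside the test harness:

```python
import onnx
from onnx import numpy_helper

def list_qdq_zero_points(model_path: str) -> None:
    """Print the data type and value of each Q/DQ zero-point initializer."""
    model = onnx.load(model_path)
    inits = {init.name: init for init in model.graph.initializer}
    for node in model.graph.node:
        # QuantizeLinear/DequantizeLinear take the zero point as their third input.
        if node.op_type in ("QuantizeLinear", "DequantizeLinear") and len(node.input) > 2:
            zp = inits.get(node.input[2])
            if zp is not None:
                dtype_name = onnx.TensorProto.DataType.Name(zp.data_type)
                print(f"{node.op_type} {node.name}: {dtype_name} = {numpy_helper.to_array(zp)}")

# Example (path is illustrative):
# list_qdq_zero_points("softmax_qdq_u16u16.onnx")
```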