From 63a8ea90083f20e1ec63dea3d9cfb0c8f07d7fc1 Mon Sep 17 00:00:00 2001
From: Xavier Dupre
Date: Mon, 25 Dec 2023 23:13:52 +0100
Subject: [PATCH] fix wrong types

Store the integer quantization ranges as numpy scalars carrying the
matching dtype, compute the zero point with the dtype of qmin instead of
int32, and only derive rmin/rmax from the data when no override is given.

---
 .../python/tools/quantization/quant_utils.py  | 39 ++++++++++---------
 .../test_tensor_quant_overrides_option.py     |  4 +-
 2 files changed, 22 insertions(+), 21 deletions(-)

diff --git a/onnxruntime/python/tools/quantization/quant_utils.py b/onnxruntime/python/tools/quantization/quant_utils.py
index 223039d3d6696..c8ace8a2b3a64 100644
--- a/onnxruntime/python/tools/quantization/quant_utils.py
+++ b/onnxruntime/python/tools/quantization/quant_utils.py
@@ -124,22 +124,22 @@ def from_string(format):
 }
 
 ONNX_INT_TYPE_RANGE = {
-    onnx_proto.TensorProto.UINT8: (0, 255),
-    onnx_proto.TensorProto.INT8: (-128, 127),
-    onnx_proto.TensorProto.UINT16: (0, 65535),
-    onnx_proto.TensorProto.INT16: (-32768, 32767),
+    onnx_proto.TensorProto.UINT8: (numpy.array(0, dtype=numpy.uint8), numpy.array(255, dtype=numpy.uint8)),
+    onnx_proto.TensorProto.INT8: (numpy.array(-128, dtype=numpy.int8), numpy.array(127, dtype=numpy.int8)),
+    onnx_proto.TensorProto.UINT16: (numpy.array(0, dtype=numpy.uint16), numpy.array(65535, dtype=numpy.uint16)),
+    onnx_proto.TensorProto.INT16: (numpy.array(-32768, dtype=numpy.int16), numpy.array(32767, dtype=numpy.int16)),
 }
 
 ONNX_INT_TYPE_SYMMETRIC_RANGE = {
-    onnx_proto.TensorProto.INT8: (-127, 127),
-    onnx_proto.TensorProto.INT16: (-32767, 32767),
+    onnx_proto.TensorProto.INT8: (numpy.array(-127, dtype=numpy.int8), numpy.array(127, dtype=numpy.int8)),
+    onnx_proto.TensorProto.INT16: (numpy.array(-32767, dtype=numpy.int16), numpy.array(32767, dtype=numpy.int16)),
 }
 
 ONNX_INT_TYPE_REDUCED_RANGE = {
-    onnx_proto.TensorProto.UINT8: (0, 127),
-    onnx_proto.TensorProto.INT8: (-64, 64),
-    onnx_proto.TensorProto.UINT16: (0, 32767),
-    onnx_proto.TensorProto.INT16: (-16384, 16384),
+    onnx_proto.TensorProto.UINT8: (numpy.array(0, dtype=numpy.uint8), numpy.array(127, dtype=numpy.uint8)),
+    onnx_proto.TensorProto.INT8: (numpy.array(-64, dtype=numpy.int8), numpy.array(64, dtype=numpy.int8)),
+    onnx_proto.TensorProto.UINT16: (numpy.array(0, dtype=numpy.uint16), numpy.array(32767, dtype=numpy.uint16)),
+    onnx_proto.TensorProto.INT16: (numpy.array(-16384, dtype=numpy.int16), numpy.array(16384, dtype=numpy.int16)),
 }
 
 
@@ -226,6 +226,7 @@ def compute_scale_zp(rmin, rmax, qmin, qmax, symmetric=False, min_real_range=Non
     :return: zero and scale [z, s]
 
     """
+    assert rmin <= rmax, f"rmin={rmin} > rmax={rmax}"
     if qmin > 0 or qmax < 0:
         raise ValueError(f"qmin and qmax must meet requirement: qmin <= 0 <= qmax while qmin:{qmin}, qmmax:{qmax}")
 
@@ -244,12 +245,17 @@ def compute_scale_zp(rmin, rmax, qmin, qmax, symmetric=False, min_real_range=Non
         rmin = -absmax
         rmax = +absmax
 
-    scale = numpy.array(numpy.array(rmax - rmin, dtype=numpy.float64) / numpy.array(qmax - qmin, dtype=numpy.float64))
+    assert rmin <= rmax, f"rmin={rmin} > rmax={rmax}"
+    assert qmin <= qmax, f"qmin={qmin} > qmax={qmax}"
+    dr = numpy.array(rmax - rmin, dtype=numpy.float64)
+    dq = numpy.array(qmax, dtype=numpy.float64) - numpy.array(qmin, dtype=numpy.float64)
+    scale = numpy.array(dr / dq)
+    assert scale >= 0, "scale issue"
     if scale < numpy.finfo(rmax.dtype).tiny:
         scale = numpy.array(1.0, dtype=rmax.dtype)
-        zero_point = numpy.array(0, dtype=numpy.int32)
+        zero_point = numpy.array(0, dtype=qmin.dtype)
     else:
-        zero_point = numpy.array(numpy.round(qmin - rmin / scale), dtype=numpy.int32)
+        zero_point = numpy.array(numpy.round(qmin - rmin / scale), dtype=qmin.dtype)
         scale = scale.astype(rmax.dtype)
 
     return [zero_point, scale]
@@ -325,20 +331,17 @@ def quantize_data(
     if rmin_override is not None:
         rmin = rmin_override
     else:
-        rmin = 0.0
+        rmin = data.min() if len(data) else 0.0
 
     if rmax_override is not None:
         rmax = rmax_override
     else:
-        rmax = 0.0
+        rmax = data.max() if len(data) else 0.0
 
     rmin = numpy.array(rmin, dtype=data.dtype)
     rmax = numpy.array(rmax, dtype=data.dtype)
     zero_point = 0
     scale = numpy.array(1.0, dtype=data.dtype)
-    if len(data):
-        rmin = data.min()
-        rmax = data.max()
 
     if qType == TensorProto.FLOAT8E4M3FN:
         if reduce_range:
diff --git a/onnxruntime/test/python/quantization/test_tensor_quant_overrides_option.py b/onnxruntime/test/python/quantization/test_tensor_quant_overrides_option.py
index 82af0319e2d34..9f0ee380cad15 100644
--- a/onnxruntime/test/python/quantization/test_tensor_quant_overrides_option.py
+++ b/onnxruntime/test/python/quantization/test_tensor_quant_overrides_option.py
@@ -292,8 +292,6 @@ def test_qdq_overrides3(self):
 
         wgt_qmin, wgt_qmax = get_qmin_qmax_for_qType(wgt_zp.data_type)
         new_wgt_zp, new_wgt_sc = compute_scale_zp(wgt_rmin, wgt_rmax, wgt_qmin, wgt_qmax)
-        print("****", [wgt_zp.data_type, wgt_rmin, wgt_rmax, wgt_qmin, wgt_qmax], [new_wgt_zp, new_wgt_sc])
-        # [2, array(0., dtype=float32), array(1., dtype=float32), 0, 255] [array(0, dtype=int32), array(0.00392157, dtype=float32)]
         self.assertEqual(wgt_zp.int32_data[0], new_wgt_zp)
         self.assertEqual(wgt_sc.float_data[0], np.float32(new_wgt_sc))
 
@@ -508,5 +506,5 @@ def test_override_validation_bad_combination(self):
 if __name__ == "__main__":
     t = TestTensorQuantOverridesOption()
     t.setUp()
-    t.test_qdq_overrides_per_channel1()
+    t.test_qdq_default_per_channel()
     unittest.main()