From aee75fff1a7775a0b669144cb9757c0f7c3c9c52 Mon Sep 17 00:00:00 2001 From: Xavier Dupre Date: Mon, 25 Dec 2023 22:11:36 +0100 Subject: [PATCH] fix type issue --- .../python/tools/quantization/quant_utils.py | 21 +++++++++++++++---- .../test_tensor_quant_overrides_option.py | 2 ++ 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/onnxruntime/python/tools/quantization/quant_utils.py b/onnxruntime/python/tools/quantization/quant_utils.py index cb5c6fa961ace..223039d3d6696 100644 --- a/onnxruntime/python/tools/quantization/quant_utils.py +++ b/onnxruntime/python/tools/quantization/quant_utils.py @@ -143,7 +143,7 @@ def from_string(format): } -def _check_type(*args): +def _check_type(*args, zero_point_index=-1): new_args = [] for i, a in enumerate(args): if numpy.issubdtype(type(a), numpy.number): @@ -152,6 +152,10 @@ def _check_type(*args): new_args.append(a) else: raise TypeError(f"arg {i} is not an array: {a}") + if i == zero_point_index: + v = new_args[-1] + if v.dtype == numpy.float32 or v.dtype == numpy.float16: + raise TypeError(f"zero_point cannot be {v.dtype}") return tuple(new_args) if len(new_args) > 1 else new_args[0] @@ -261,10 +265,13 @@ def compute_scale_zp_float8(element_type, std): More details in notebook `quantization_fp8.ipynb `_. """ + zp_dtype = None if element_type not in FLOAT8_DISTRIBUTIONS: if element_type == TensorProto.FLOAT8E4M3FN: from onnx.numpy_helper import float8e4m3_to_float32 + from onnx.reference.custom_element_types import float8e4m3fn + zp_dtype = float8e4m3fn all_values = [float8e4m3_to_float32(i) for i in range(0, 256)] values = numpy.array( [f for f in all_values if not numpy.isnan(f) and not numpy.isinf(f)], dtype=numpy.float32 @@ -272,9 +279,15 @@ def compute_scale_zp_float8(element_type, std): else: raise ValueError(f"Quantization to element_type={element_type} not implemented.") FLOAT8_DISTRIBUTIONS[element_type] = values + elif element_type == TensorProto.FLOAT8E4M3FN: + from onnx.reference.custom_element_types import float8e4m3fn + zp_dtype = float8e4m3fn + + if zp_dtype is None: + raise TypeError(f"Unexpected element_type {element_type}.") std_f8 = numpy.std(FLOAT8_DISTRIBUTIONS[element_type]) - zero = numpy.array(0, dtype=std.dtype) + zero = numpy.array(0, dtype=zp_dtype) scale = numpy.array(std / std_f8, dtype=std.dtype) return [zero, scale] @@ -339,14 +352,14 @@ def quantize_data( f"One of the quantized value is NaN data in [{np_data.min()}, {np_data.max()}], " f"quantized_data in [{quantized_data.min()}, {quantized_data.max()}]." ) - return _check_type(rmin, rmax, zero_point, scale, quantized_data) + return _check_type(rmin, rmax, zero_point, scale, quantized_data, zero_point_index=2) if qType in (TensorProto.INT8, TensorProto.UINT8, TensorProto.INT16, TensorProto.UINT16): if len(data): qmin, qmax = get_qmin_qmax_for_qType(qType, reduce_range, symmetric=symmetric) zero_point, scale = compute_scale_zp(rmin, rmax, qmin, qmax, symmetric, min_real_range) quantized_data = quantize_nparray(qType, data, scale, zero_point) - return _check_type(rmin, rmax, zero_point, scale, quantized_data) + return _check_type(rmin, rmax, zero_point, scale, quantized_data, zero_point_index=2) raise ValueError(f"Unexpected value for qType={qType}.") diff --git a/onnxruntime/test/python/quantization/test_tensor_quant_overrides_option.py b/onnxruntime/test/python/quantization/test_tensor_quant_overrides_option.py index 7ace20feb6a09..82af0319e2d34 100644 --- a/onnxruntime/test/python/quantization/test_tensor_quant_overrides_option.py +++ b/onnxruntime/test/python/quantization/test_tensor_quant_overrides_option.py @@ -292,6 +292,8 @@ def test_qdq_overrides3(self): wgt_qmin, wgt_qmax = get_qmin_qmax_for_qType(wgt_zp.data_type) new_wgt_zp, new_wgt_sc = compute_scale_zp(wgt_rmin, wgt_rmax, wgt_qmin, wgt_qmax) + print("****", [wgt_zp.data_type, wgt_rmin, wgt_rmax, wgt_qmin, wgt_qmax], [new_wgt_zp, new_wgt_sc]) + # [2, array(0., dtype=float32), array(1., dtype=float32), 0, 255] [array(0, dtype=int32), array(0.00392157, dtype=float32)] self.assertEqual(wgt_zp.int32_data[0], new_wgt_zp) self.assertEqual(wgt_sc.float_data[0], np.float32(new_wgt_sc))