From 76dd03518f41fbeec6c4a431f7bd8150f953a4ba Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Wed, 24 Apr 2024 19:16:27 +0200 Subject: [PATCH] Fix quantization tools for issue #19529 (#19591) ### Description Fix issue #19529, the code was using a variable loop outside a loop. --- .../python/tools/quantization/calibrate.py | 12 ++- .../python/quantization/test_quant_issues.py | 72 ++++++++++++++++++ .../test/testdata/qdq_minimal_model.onnx | Bin 0 -> 6372 bytes 3 files changed, 77 insertions(+), 7 deletions(-) create mode 100644 onnxruntime/test/python/quantization/test_quant_issues.py create mode 100644 onnxruntime/test/testdata/qdq_minimal_model.onnx diff --git a/onnxruntime/python/tools/quantization/calibrate.py b/onnxruntime/python/tools/quantization/calibrate.py index ef1ecd20a0d6f..fe37cf3c87880 100644 --- a/onnxruntime/python/tools/quantization/calibrate.py +++ b/onnxruntime/python/tools/quantization/calibrate.py @@ -733,13 +733,11 @@ def collect_absolute_value(self, name_to_arr): for tensor, data_arr in name_to_arr.items(): if isinstance(data_arr, list): for arr in data_arr: - if not isinstance(arr, np.ndarray): - raise ValueError(f"Unexpected type {type(arr)} for tensor={tensor!r}") - dtypes = set(a.dtype for a in arr) - if len(dtypes) != 1: - raise ValueError( - f"The calibration expects only one element type but got {dtypes} for tensor={tensor!r}" - ) + assert isinstance(arr, np.ndarray), f"Unexpected type {type(arr)} for tensor={tensor!r}" + dtypes = set(a.dtype for a in data_arr) + assert ( + len(dtypes) == 1 + ), f"The calibration expects only one element type but got {dtypes} for tensor={tensor!r}" data_arr_np = np.asarray(data_arr) elif not isinstance(data_arr, np.ndarray): raise ValueError(f"Unexpected type {type(data_arr)} for tensor={tensor!r}") diff --git a/onnxruntime/test/python/quantization/test_quant_issues.py b/onnxruntime/test/python/quantization/test_quant_issues.py new file mode 100644 index 0000000000000..66960978748ad --- /dev/null +++ b/onnxruntime/test/python/quantization/test_quant_issues.py @@ -0,0 +1,72 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +import os +import tempfile +import unittest +import warnings + + +def ignore_warnings(warns): + """ + Catches warnings. + + :param warns: warnings to ignore + """ + + def wrapper(fct): + if warns is None: + raise AssertionError(f"warns cannot be None for '{fct}'.") + + def call_f(self): + with warnings.catch_warnings(): + warnings.simplefilter("ignore", warns) + return fct(self) + + return call_f + + return wrapper + + +class TestQuantIssues(unittest.TestCase): + @ignore_warnings(DeprecationWarning) + def test_minimal_model(self): + folder = os.path.join(os.path.dirname(__file__), "..", "..", "testdata") + onnx_path = os.path.join(folder, "qdq_minimal_model.onnx") + if not os.path.exists(onnx_path): + # The file does seem to be the same location in every CI job. + raise unittest.SkipTest("unable to find {onnx_path!r}") + + import numpy as np + + import onnxruntime.quantization as oq + + class Mock: + def __init__(self): + self.i = 0 + + def get_next(self): + if self.i > 10: + return None + self.i += 1 + return {"input": np.random.randint(0, 255, size=(1, 3, 32, 32), dtype=np.uint8)} + + with tempfile.TemporaryDirectory() as temp: + preprocessed_path = os.path.join(temp, "preprocessed.onnx") + quantized_path = os.path.join(temp, "quantized.onnx") + oq.quant_pre_process(onnx_path, preprocessed_path, skip_symbolic_shape=True) + oq.quantize_static( + preprocessed_path, + quantized_path, + Mock(), + calibrate_method=oq.CalibrationMethod.Percentile, + op_types_to_quantize=["Conv", "Mul", "Gemm"], + ) + assert os.path.exists(preprocessed_path), f"missing output {preprocessed_path!r}" + assert os.path.exists(quantized_path), f"missing output {quantized_path!r}" + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/onnxruntime/test/testdata/qdq_minimal_model.onnx b/onnxruntime/test/testdata/qdq_minimal_model.onnx new file mode 100644 index 0000000000000000000000000000000000000000..04e71789e63565321daa7f3ab6b46d4a63ceceef GIT binary patch literal 6372 zcmbtY-EZ7P5YOduAK4~tFoi}&LX}k^aj3=i+Fm=TAxer$i$XtAUf{tv7hliS2|cv(Tj9)+#l*~ z>y~BTvNnFP$Oc)PjK`-N8#|M6KT1Q(bToCw)ychsc$_Bbi_kQ>@J$@DTF~A}o^wK5 zlZF&gS%|8N=vELJqP6GJR;Be2Ek(39Nh=v2O;ha>5?XKXL?>zJ7)#l2Z<#Enll8_~ zW03@SpRE(FpOZ8ir=e?f;k$R0w4O(!Y253sHSRAhEWF$~Yb=qQ1Vi9~t+kTF=xMCC zGL?BolmC!=fV4N=4;_TW5Yi1mTJJWxP5!^Z5m)|v^w;KZr+;jI^82rw1#-b3FmsN; zK7(}iRyOAwP5v`O&ZR-f!5fkP{_)x7_P6ooU@*v$O@Pc%@W9hL`I&D~y`3#uW^wEf z3Cx4v?2VFs9Qt~9Zx|iLp=BAZ{EKYMI|BU@SRk~xd%&QlgOVprc7s|g(H9Okn z!tMo%Wnw1WPtIlX&+Fo^j){etICGekJYLo@u`v_uUFvZf=r~7k_e!M^<#fu}lD7xS zvaMYylyf&qhw*V}QEr6X8kWb6kbQ9@D0>_%H}=(aFvE&!ghxc-LDdNdOr76L%hi`b zB;RNkpS??DDLPF~))%sSS`o{`Vqr&?#lp&bvsj)i787FGX1!Pp6keQzX)Cj+W1D&> zpSG=P-0^|1xGOf12!YLRG7fF4xKOrTTqp(7g}oUg+W|=WMo20_I>07}bW2DFAWP{A zd4RCkDhQ{e(Dw9pF>5sWFF8-v-E1!dS-m=00Wsex5OWO4NVd-*^GzIbI#MBpMe;oX zK>BZj4CX*mvxE!^$lGAc!rH#36;ANg_~g539RCo9)GC3gDFJsf*OtktSwn6Ex$H?f z`Q4iR6kQH_(a`yr)W}^kMHgW#rcl1a^S)vZ2f%JX!PD9OF3P8l#Ju(INQ5URbT&nYW{Vww>lp35jp>24|xdc5jg5Z0Ldh)UR_h9Ux(GF-*b=PT2Nj({&`1m zaYg3fyCNT*9aBb-bR2;!Qr-JxY+K-DZmp`Y(&67piLL%LUOqJ`*kw!JS z#O1j?-^t5~>~SFD){MhC!C}rTo2ESuyw+@HNtsHMDv&=oR5WjUVYPpGm4{Cwz%pR zCLT=rYs~XS%yWbKm{+3|hBJmYFJt@%i!n=43bV0NJPRwOPB>g}c>kjsa%(Kc)PjRa z1&1e9L(_^mJjEQ$u54lI7P%;4v!cXv>J4@W^jF4RhGVyP(fehKd|f(=F#$5d&WN?{*jYJAbWgalFQf{T?{&OoYlzAmU8ka0pDDz-~s3PWi3& zZvq@x*hL=;N>Lo(e8T4WI~=a2e-q$06~hIO8u!bV!>c+bIKQ)b$pcF=6*sR`oid)? z%wnA?>q?~^IFOZ}WF_a(^QqC~zrDB10=?TWpHx{PUTu||T6;20c^6wGW4}JO zYs-17KYW&?y=!ZYZK*YHXT9oI+r8^qn{-=k_dB!3Y41J0UXqq^^?n!Hs$YK$9rX{$ zjqaj$opiL;ew6lyWU<>MU9CNglc&QpoBlpINQe91kWzZ>5qZ#UXkW}Iu4-9C