Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Fix quantization tools for issue #19529 #19591

Merged
merged 9 commits into from
Apr 24, 2024
12 changes: 5 additions & 7 deletions onnxruntime/python/tools/quantization/calibrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -734,13 +734,11 @@ def collect_absolute_value(self, name_to_arr):
for tensor, data_arr in name_to_arr.items():
if isinstance(data_arr, list):
for arr in data_arr:
if not isinstance(arr, np.ndarray):
raise ValueError(f"Unexpected type {type(arr)} for tensor={tensor!r}")
dtypes = set(a.dtype for a in arr)
if len(dtypes) != 1:
raise ValueError(
f"The calibration expects only one element type but got {dtypes} for tensor={tensor!r}"
)
assert isinstance(arr, np.ndarray), f"Unexpected type {type(arr)} for tensor={tensor!r}"
dtypes = set(a.dtype for a in data_arr)
assert (
len(dtypes) == 1
), f"The calibration expects only one element type but got {dtypes} for tensor={tensor!r}"
data_arr_np = np.asarray(data_arr)
elif not isinstance(data_arr, np.ndarray):
raise ValueError(f"Unexpected type {type(data_arr)} for tensor={tensor!r}")
Expand Down
71 changes: 71 additions & 0 deletions onnxruntime/test/python/quantization/test_quant_issues.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
# -------------------------------------------------------------------------
Fixed Show fixed Hide fixed
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
import os
import tempfile
import unittest
import warnings


def ignore_warnings(warns):
    """
    Builds a decorator which silences warnings while the decorated test runs.

    :param warns: warning category handed to
        ``warnings.simplefilter("ignore", warns)``; must not be None
    :return: a decorator; applying it returns a wrapper which calls the
        original function inside ``warnings.catch_warnings()``
    :raises AssertionError: when the decorator is applied and *warns* is None
    """
    import functools

    def wrapper(fct):
        # Validated when the decorator is applied, so a misuse fails when the
        # test class is defined rather than when the test runs.
        if warns is None:
            raise AssertionError(f"warns cannot be None for '{fct}'.")

        # functools.wraps keeps the name and docstring of the decorated test,
        # so reports and introspection show the real function, not 'call_f'.
        @functools.wraps(fct)
        def call_f(self, *args, **kwargs):
            with warnings.catch_warnings():
                # The filter only lives for the duration of this call.
                warnings.simplefilter("ignore", warns)
                return fct(self, *args, **kwargs)

        return call_f

    return wrapper


class TestQuantIssues(unittest.TestCase):
    """Regression tests for problems reported against the quantization tools."""

    @ignore_warnings(DeprecationWarning)
    def test_minimal_model(self):
        """
        Pre-processes then statically quantizes ``qdq_minimal_model.onnx``
        with the Percentile calibration method and checks both output
        files were written.

        The test is skipped when the model file is not found in the
        ``testdata`` folder located two levels above this file.
        """
        folder = os.path.join(os.path.dirname(__file__), "..", "..", "testdata")
        onnx_path = os.path.join(folder, "qdq_minimal_model.onnx")
        if not os.path.exists(onnx_path):
            # unittest.SkipTest is the exception unittest recognizes as a skip.
            raise unittest.SkipTest(f"file {onnx_path!r} is missing")

        # Imported lazily so that the skip above does not require these packages.
        import numpy as np

        import onnxruntime.quantization as oq

        class Mock:
            """Minimal calibration data reader returning random uint8 inputs."""

            def __init__(self):
                # Number of samples already produced.
                self.i = 0

            def get_next(self):
                """Returns one random feed for "input", or None once 11 were produced."""
                if self.i > 10:
                    return None
                self.i += 1
                return {"input": np.random.randint(0, 255, size=(1, 3, 32, 32), dtype=np.uint8)}

        with tempfile.TemporaryDirectory() as temp:
            preprocessed_path = os.path.join(temp, "preprocessed.onnx")
            quantized_path = os.path.join(temp, "quantized.onnx")
            oq.quant_pre_process(onnx_path, preprocessed_path, skip_symbolic_shape=True)
            oq.quantize_static(
                preprocessed_path,
                quantized_path,
                Mock(),
                calibrate_method=oq.CalibrationMethod.Percentile,
                op_types_to_quantize=["Conv", "Mul", "Gemm"],
            )
            # Checked inside the ``with`` block: the directory is deleted on exit.
            # assertTrue is used because bare asserts are stripped under ``python -O``.
            self.assertTrue(os.path.exists(preprocessed_path), f"missing output {preprocessed_path!r}")
            self.assertTrue(os.path.exists(quantized_path), f"missing output {quantized_path!r}")


# Run the tests with verbose output when this file is executed as a script.
if __name__ == "__main__":
    unittest.main(verbosity=2)
Loading