Skip to content

Commit

Permalink
Fix quantization tools for issue microsoft#19529 (microsoft#19591)
Browse files Browse the repository at this point in the history
### Description
Fix issue microsoft#19529: the code was using a loop variable outside of its loop.
  • Loading branch information
xadupre authored and Ted Themistokleous committed May 7, 2024
1 parent 123da24 commit 76dd035
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 7 deletions.
12 changes: 5 additions & 7 deletions onnxruntime/python/tools/quantization/calibrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -733,13 +733,11 @@ def collect_absolute_value(self, name_to_arr):
for tensor, data_arr in name_to_arr.items():
if isinstance(data_arr, list):
for arr in data_arr:
if not isinstance(arr, np.ndarray):
raise ValueError(f"Unexpected type {type(arr)} for tensor={tensor!r}")
dtypes = set(a.dtype for a in arr)
if len(dtypes) != 1:
raise ValueError(
f"The calibration expects only one element type but got {dtypes} for tensor={tensor!r}"
)
assert isinstance(arr, np.ndarray), f"Unexpected type {type(arr)} for tensor={tensor!r}"
dtypes = set(a.dtype for a in data_arr)
assert (
len(dtypes) == 1
), f"The calibration expects only one element type but got {dtypes} for tensor={tensor!r}"
data_arr_np = np.asarray(data_arr)
elif not isinstance(data_arr, np.ndarray):
raise ValueError(f"Unexpected type {type(data_arr)} for tensor={tensor!r}")
Expand Down
72 changes: 72 additions & 0 deletions onnxruntime/test/python/quantization/test_quant_issues.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
import functools
import os
import tempfile
import unittest
import warnings


def ignore_warnings(warns):
    """
    Builds a decorator which silences warnings while the decorated callable runs.

    :param warns: warning category to ignore, forwarded as the *category*
        argument of ``warnings.simplefilter("ignore", ...)``
    :return: a decorator; the callable it produces forwards every positional
        and keyword argument to the original one and returns its result
    :raises AssertionError: raised when the decorator is applied
        (not when the decorated callable is called) if *warns* is None
    """

    def wrapper(fct):
        # Fail at decoration time so a misuse is reported when the module is
        # imported rather than when the test eventually runs.
        if warns is None:
            raise AssertionError(f"warns cannot be None for '{fct}'.")

        # functools.wraps keeps __name__ / __doc__ of the decorated callable so
        # that test reports and introspection still show the original test.
        @functools.wraps(fct)
        def call_f(*args, **kwargs):
            # catch_warnings restores the previous filters on exit, so the
            # "ignore" filter only applies for the duration of this call.
            with warnings.catch_warnings():
                warnings.simplefilter("ignore", warns)
                return fct(*args, **kwargs)

        return call_f

    return wrapper


class TestQuantIssues(unittest.TestCase):
    """Regression tests for issues reported against the quantization tools."""

    @ignore_warnings(DeprecationWarning)
    def test_minimal_model(self):
        """
        Runs pre-processing then static quantization with the Percentile
        calibration method on ``qdq_minimal_model.onnx`` and checks that both
        output files are written.

        The test is skipped when the model file cannot be found.
        """
        folder = os.path.join(os.path.dirname(__file__), "..", "..", "testdata")
        onnx_path = os.path.join(folder, "qdq_minimal_model.onnx")
        if not os.path.exists(onnx_path):
            # The file may not be at the same location in every CI job.
            raise unittest.SkipTest(f"unable to find {onnx_path!r}")

        # Imported lazily so that the module can still be collected (and the
        # test skipped above) before these packages are needed.
        import numpy as np

        import onnxruntime.quantization as oq

        class Mock:
            """Minimal calibration data reader: only ``get_next`` is provided."""

            def __init__(self):
                # Number of batches already returned.
                self.i = 0

            def get_next(self):
                """Returns a random uint8 feed for "input" 11 times, then None."""
                if self.i > 10:
                    return None
                self.i += 1
                return {"input": np.random.randint(0, 255, size=(1, 3, 32, 32), dtype=np.uint8)}

        with tempfile.TemporaryDirectory() as temp:
            preprocessed_path = os.path.join(temp, "preprocessed.onnx")
            quantized_path = os.path.join(temp, "quantized.onnx")
            oq.quant_pre_process(onnx_path, preprocessed_path, skip_symbolic_shape=True)
            oq.quantize_static(
                preprocessed_path,
                quantized_path,
                Mock(),
                calibrate_method=oq.CalibrationMethod.Percentile,
                op_types_to_quantize=["Conv", "Mul", "Gemm"],
            )
            # Checked inside the ``with`` block: the directory and its content
            # are removed as soon as the block exits.
            self.assertTrue(os.path.exists(preprocessed_path), f"missing output {preprocessed_path!r}")
            self.assertTrue(os.path.exists(quantized_path), f"missing output {quantized_path!r}")


if __name__ == "__main__":
    # Run every test of this module with per-test output when executed as a script.
    unittest.main(verbosity=2)
Binary file added onnxruntime/test/testdata/qdq_minimal_model.onnx
Binary file not shown.

0 comments on commit 76dd035

Please sign in to comment.