Static quantization crashes with "TypeError: iteration over a 0-d array" #19529

Open
onnxruntime-user opened this issue Feb 15, 2024 · 2 comments
Labels: quantization (issues related to quantization)

Describe the issue

During static quantization with percentile calibration, the following crash happens with onnxruntime-gpu 1.17.0:

Collecting tensor data and making histogram ...
Traceback (most recent call last):
  File ".\new_onnx_quantize_static.py", line 21, in <module>
    oq.quantize_static(preprocessed_path, quantized_path, Mock(),
  File "D:\Tools\Venv\NewestOnnx\lib\site-packages\onnxruntime\quantization\quantize.py", line 496, in quantize_static
    calibrator.collect_data(calibration_data_reader)
  File "D:\Tools\Venv\NewestOnnx\lib\site-packages\onnxruntime\quantization\calibrate.py", line 546, in collect_data
    self.collector.collect(clean_merged_dict)
  File "D:\Tools\Venv\NewestOnnx\lib\site-packages\onnxruntime\quantization\calibrate.py", line 724, in collect
    return self.collect_absolute_value(name_to_arr)
  File "D:\Tools\Venv\NewestOnnx\lib\site-packages\onnxruntime\quantization\calibrate.py", line 739, in collect_absolute_value
    dtypes = set(a.dtype for a in arr)
TypeError: iteration over a 0-d array

Older onnxruntime versions did not have this issue.
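
For readers unfamiliar with the error message, here is a minimal sketch of what NumPy means by it. It uses only NumPy and is not taken from onnxruntime; the variable names are made up for illustration. A 0-d array holds a single scalar and has an empty shape, so it cannot be iterated the way a list of arrays can.

```python
import numpy as np

# A 0-d array wraps one scalar: its shape is () and its ndim is 0.
scalar_like = np.array(3.5, dtype=np.float32)
# A list of arrays, which is what the failing expression expects to iterate.
array_list = [np.zeros((2, 2), dtype=np.float32), np.ones(3, dtype=np.float32)]

# Same expression as in the traceback, applied to a list: works.
list_dtypes = set(a.dtype for a in array_list)

# Same expression applied to a 0-d array: NumPy refuses to iterate.
try:
    set(a.dtype for a in scalar_like)
    error_message = None
except TypeError as exc:
    error_message = str(exc)

print(scalar_like.shape, scalar_like.ndim)
print(error_message)
```

This suggests that `arr` on line 739 was a single scalar-shaped array at the time of the crash, not a collection of arrays.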

To reproduce

The model is attached as minimal_model.zip. Code to reproduce:

import onnxruntime.quantization as oq
import numpy as np


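class Mock:
    """Minimal calibration data reader that returns random uint8 inputs."""

    def __init__(self):
        self.i = 0  # number of samples handed out so far

    def get_next(self):
        # Returning None signals that the calibration data is exhausted.
        if self.i > 10:
            return None
        self.i += 1
        return {"input": np.random.randint(0, 255, size=(1, 3, 32, 32), dtype=np.uint8)}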

if __name__ == "__main__":
    onnx_path = "minimal_model.onnx"
    # Strip the ".onnx" suffix to derive the output file names.
    preprocessed_path = onnx_path[:-5] + "_preprocessed.onnx"
    quantized_path = onnx_path[:-5] + "_quantized.onnx"
    oq.quant_pre_process(onnx_path, preprocessed_path, skip_symbolic_shape=True)
    # The crash occurs inside this call, during percentile calibration.
    oq.quantize_static(preprocessed_path, quantized_path, Mock(),
                       calibrate_method=oq.CalibrationMethod.Percentile,
                       op_types_to_quantize=["Conv", "Mul", "Gemm"])

Urgency

Urgent, since this is a regression from previous versions of onnxruntime-gpu.

Platform

Windows

OS Version

10

ONNX Runtime Installation

Released Package

ONNX Runtime Version or Commit ID

1.17.0

ONNX Runtime API

Python

Architecture

X64

Execution Provider

CUDA

Execution Provider Library Version

CUDA 11.6

skottmckay added the quantization (issues related to quantization) label Feb 20, 2024
xadupre added a commit to xadupre/onnxruntime that referenced this issue Feb 21, 2024
xadupre (Member) commented Feb 21, 2024

PR #19591 fixes this issue. May I use your model in a unit test to make sure the bug does not show up again?

onnxruntime-user (Author) commented Feb 22, 2024 via email

xadupre added a commit that referenced this issue Apr 24, 2024
### Description
Fix issue #19529: the code was using a loop variable outside its loop.
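
The kind of bug this description refers to can be sketched as follows. This is a hypothetical reconstruction for illustration only, not the actual `calibrate.py` code: the function names `dtypes_buggy` and `dtypes_fixed` and the sample data are assumptions. In Python a `for` loop variable stays bound after the loop ends, so using it afterwards silently refers to the last element instead of the whole collection.

```python
import numpy as np


def dtypes_buggy(samples):
    """Hypothetical buggy version: reuses the loop variable after the loop."""
    for sample in samples:
        assert isinstance(sample, np.ndarray)
    # Bug: `sample` is now the LAST array in `samples`, not the list itself.
    # If that array is 0-d, iterating over it raises TypeError.
    return set(a.dtype for a in sample)


def dtypes_fixed(samples):
    """Hypothetical fixed version: iterates over the list, as intended."""
    for sample in samples:
        assert isinstance(sample, np.ndarray)
    return set(a.dtype for a in samples)


# Assumed calibration samples for a tensor that holds a single scalar.
scalar_samples = [np.array(0.5, dtype=np.float32), np.array(1.5, dtype=np.float32)]

fixed_result = dtypes_fixed(scalar_samples)

try:
    dtypes_buggy(scalar_samples)
    buggy_error = None
except TypeError as exc:
    buggy_error = str(exc)

print(fixed_result)
print(buggy_error)
```

Under these assumptions the buggy variant only fails when the last collected array is 0-d, which would explain why the crash depends on the model being calibrated.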
TedThemistokleous pushed a commit to TedThemistokleous/onnxruntime that referenced this issue May 7, 2024
### Description
Fix issue microsoft#19529: the code was using a loop variable outside its loop.
rexlee8776 pushed a commit to Deep-Spark/DeepSparkInference that referenced this issue Jan 3, 2025