onnxruntime shape mismatch during quantization of yolov8 models #21048
Comments
Can you share the full reproducer that quantizes the model, so I can reproduce it?
Sure @yihonglyu, here's a minimal example I was able to reproduce it with.

Data reader (imported as `utils.data_reader` in the quantization script below):

```python
import numpy as np
import onnxruntime
import os
import random
from PIL import Image
from tqdm import tqdm
from onnxruntime.quantization import CalibrationDataReader


class RandomCalibrationDataReader(CalibrationDataReader):
    def __init__(self, model_path, none1, limit=10):
        self.model_path = model_path
        self.limit = limit
        self.index = 0
        # Initialize an ONNX Runtime session to get the input shape.
        self.session = onnxruntime.InferenceSession(model_path, providers=['CPUExecutionProvider'])
        self.input_shape = self.session.get_inputs()[0].shape
        self.target_size = (640, 640)  # Assuming the target size
        self.datasize = limit

    def get_next(self):
        if self.index < self.datasize:
            self.index += 1
            return {self.session.get_inputs()[0].name: np.random.random(self.input_shape).astype(np.float32)}
        # Falls through and returns None once the calibration data is exhausted.

    def rewind(self):
        self.index = 0
```

Quantization:

```python
from ultralytics import YOLO
import sys
import os
import numpy as np  # needed for the error computation below
import onnxruntime as ort
from onnxruntime.quantization import QuantType, quantize
from onnxruntime.quantization.execution_providers.qnn import get_qnn_qdq_config, qnn_preprocess_model
from onnxruntime.quantization.shape_inference import quant_pre_process
from utils.data_reader import CalibrationDataReader, RandomCalibrationDataReader

model_name = 'yolov8x.pt'
model = YOLO(model_name)
model.export(format='onnx')
input_model_path = model_name.replace('.pt', '.onnx')

# Quantization
data_reader = RandomCalibrationDataReader(input_model_path, '.', limit=200)

preproc_model_path = 'model.preproc.onnx'
quant_pre_process(input_model_path, preproc_model_path, skip_optimization=False)

model_changed = qnn_preprocess_model(preproc_model_path, preproc_model_path)
print(f'Model changed? {model_changed}')

model_to_quantize = preproc_model_path if model_changed else input_model_path
print(f'Model to quantize: {model_to_quantize}')

qnn_config = get_qnn_qdq_config(model_to_quantize,
                                data_reader,
                                activation_type=QuantType.QUInt8,
                                weight_type=QuantType.QUInt8,
                                per_channel=False,
                                activation_symmetric=True,
                                weight_symmetric=True)

output_model_path = 'model.qdq.onnx'
quantize(model_to_quantize, output_model_path, qnn_config)


def test_model(model_path, input_data):
    print(input_data['images'].shape)
    session = ort.InferenceSession(model_path, providers=['CPUExecutionProvider'])
    outputs = session.run(None, input_data)
    return outputs[0]


# Initialize the data reader for the validation dataset
# (CalibrationDataReader here is a helper from utils.data_reader, not shown above).
validation_data_reader = CalibrationDataReader(input_model_path, '.', limit=10)

# Accumulate errors
errors = []

# Loop through all data provided by the data reader
while True:
    input_data = validation_data_reader.get_next()
    if input_data is None:
        break  # End of data
    orig_outputs = test_model(input_model_path, input_data)
    quant_outputs = test_model(output_model_path, input_data)
    # Compute the absolute error for the current batch and store it
    batch_error = np.abs(orig_outputs - quant_outputs)
    errors.append(batch_error)

# Compute the mean of all errors
if errors:
    avg_abs_error = np.mean(np.concatenate(errors))  # Concatenate to handle multiple batches
    print(f'Average absolute error per output: {avg_abs_error}')
else:
    print("No data available to compute error.")
```

The error happens during inference after quantization:
Let me know if you are able to reproduce or have issues running this!
Could you share the model for the reproducer, too? Thanks
When you said "regression in onnxruntime functionality", do you mean it used to work before?
Yes, I have a model that I previously quantized with ORT successfully, but I don't remember which versions of ultralytics/ort/onnx I used. I'm trying to reproduce it now.
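To pin the environment down precisely, a quick snippet to record the versions in both the working and failing setups (all three packages expose `__version__`):

```python
import onnx
import onnxruntime
import ultralytics

# Print the exact package versions used for a given run.
print('onnx        :', onnx.__version__)
print('onnxruntime :', onnxruntime.__version__)
print('ultralytics :', ultralytics.__version__)
```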
@HectorSVC @yihonglyu OK, I've been able to reproduce. This is the issue I get with the latest versions of ORT:

With ORT 1.17 it runs fine; with more recent versions I get the error. Potentially related (but a different op?): #16462. Note that this happens when I exclude the last Conv layer from quantization.
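For reference, a rough sketch of how the last Conv node could be excluded from quantization. This assumes the config object returned by `get_qnn_qdq_config` accepts a `nodes_to_exclude` list (the exact attribute may vary across ORT versions) and reuses the `data_reader` from the reproducer above:

```python
import onnx
from onnxruntime.quantization import QuantType
from onnxruntime.quantization.execution_providers.qnn import get_qnn_qdq_config

# Collect the names of all Conv nodes in the preprocessed model.
model = onnx.load('model.preproc.onnx')
conv_names = [n.name for n in model.graph.node if n.op_type == 'Conv']

qnn_config = get_qnn_qdq_config('model.preproc.onnx',
                                data_reader,  # RandomCalibrationDataReader from the reproducer
                                activation_type=QuantType.QUInt8,
                                weight_type=QuantType.QUInt8)

# Assumed attribute: exclude the last Conv node from quantization.
if conv_names:
    qnn_config.nodes_to_exclude = [conv_names[-1]]
```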
Describe the issue
When trying to quantize a YOLOv8 model (exported with `yolo export model=yolov8x.pt format=onnx`) with `onnxruntime`, I get the following error:

To reproduce
`yolo export model=yolov8x.pt format=onnx`
Urgency
This is blocking for the project I'm working on and seems like a regression in `onnxruntime` functionality.

Platform
Linux
OS Version
Ubuntu 22.04
ONNX Runtime Installation
Released Package
ONNX Runtime Version or Commit ID
1.18.0
ONNX Runtime API
Python
Architecture
X64
Execution Provider
Other / Unknown
Execution Provider Library Version
QNN