[Quantizing Stable Diffusion 2.1 to ONNX INT8] PyTorch to FP32 ONNX converts successfully, but FP32 ONNX to INT8 static quantization fails #19183
Comments
Which version of onnxruntime are you using? I can't find it in the files you attached to this issue. You are also using an old version of torch; is there any reason for that?
I tried both the new and the old versions, but I still get the same issue.
I installed it separately. I have also attached unet.txt, which is an .ipynb file.
@siddharth062022, are you able to run the fp32 version successfully with ONNX Runtime? And could you please provide instructions to reproduce the issue?
I successfully converted the PyTorch model to FP32 ONNX, but this problem occurs when I run static ONNX quantization.
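For context, a minimal sanity check of the FP32 model in ONNX Runtime might look like the sketch below; the model path and dummy input shapes are taken from the quantization script that follows, so treat them as assumptions rather than verified values.

import numpy as np
import onnxruntime

# Load the exported FP32 UNet (path as in the quantization script below)
session = onnxruntime.InferenceSession('/content/sd2.1/unet/unet.onnx',
                                       providers=['CPUExecutionProvider'])

# Dummy inputs with the same shapes used for calibration below
feeds = {
    'latent_model_input': np.random.randn(2, 4, 96, 96).astype(np.float32),
    't': np.array([1], dtype=np.float32),
    'encoder_hidden_states': np.random.rand(2, 77, 1024).astype(np.float32),
}

# If this runs without error, the FP32 export itself is sound
outputs = session.run(None, feeds)
print([o.shape for o in outputs])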
import gc
import onnx
import torch
import numpy as np
import onnxruntime
from onnxruntime.quantization import quantize_static, CalibrationDataReader, QuantType, CalibrationMethod, QuantFormat

# Define a custom CalibrationDataReader class
class UNetDataReader(CalibrationDataReader):
    def __init__(self, model_path: str):
        self.model_path = model_path
        self.input_names = None
        self.enum_data = None
        self.load_model()
        self.generate_calibration_data()

    def load_model(self):
        # Load the ONNX model and get the input tensor names
        session = onnxruntime.InferenceSession(self.model_path, providers=['CPUExecutionProvider'])
        self.input_names = [input.name for input in session.get_inputs()]

    def generate_calibration_data(self):
        # Generate random NCHW data for calibration with the correct input names
        self.calibration_data = {
            'latent_model_input': torch.randn(2, 4, 96, 96).numpy().astype(np.float32),
            't': np.array([1], dtype=np.float32),
            'encoder_hidden_states': np.random.rand(2, 77, 1024).astype(np.float32),
        }
        self.datasize = len(self.calibration_data)

    def get_next(self):
        # Yield the single calibration batch once, then None
        if self.enum_data is None:
            self.enum_data = iter([self.calibration_data])
        return next(self.enum_data, None)

    def rewind(self):
        self.enum_data = None

# Define paths for the input and quantized models
model_path = '/content/sd2.1/unet/unet.onnx'
# model_path = '/content/unet_fp16.onnx'
quantized_model_path = '/content/unetint8/unet.onnx'

# Create a calibration data reader
data_reader = UNetDataReader(model_path)
gc.collect()

# Perform static quantization
quantize_static(
    model_input=model_path,
    model_output=quantized_model_path,
    calibration_data_reader=data_reader,
    activation_type=QuantType.QInt8,
    weight_type=QuantType.QInt8,
    use_external_data_format=True,
    calibrate_method=CalibrationMethod.MinMax,
    quant_format=QuantFormat.QDQ,
)
gc.collect()
I have solved this problem myself. If anyone faces the same problem, feel free to ask.
Describe the issue
RUNTIME_EXCEPTION : Non-zero status code returned while running Mul node. Name:'/time_proj/Mul' Status Message: /onnxruntime_src/onnxruntime/core/providers/cpu/math/element_wise_ops.h:540 void onnxruntime::BroadcastIterator::Init(ptrdiff_t, ptrdiff_t) axis == 1 || axis == largest was false. Attempting to broadcast an axis by a dimension other than 1. 2 by 160
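For illustration, the same broadcast failure can be reproduced in plain NumPy, since shapes (2,) and (160,) are not broadcast-compatible: neither trailing dimension is 1. In the SD 2.1 UNet, /time_proj builds the sinusoidal timestep embedding, so the failing Mul presumably multiplies a batch-sized timestep tensor against a fixed frequency vector of length 160; that reading is inferred from the node name, not from inspecting the graph.

import numpy as np

t = np.ones(2, dtype=np.float32)        # e.g. a timestep tensor expanded to batch size 2
freqs = np.ones(160, dtype=np.float32)  # e.g. a length-160 frequency vector

try:
    t * freqs  # shapes (2,) and (160,): neither dim is 1, so broadcasting fails
except ValueError as e:
    print(e)   # operands could not be broadcast together with shapes (2,) (160,)

Reshaping the timestep to (2, 1), or passing it as a scalar, would make the two operands broadcast-compatible.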
To reproduce
requirements.5 (1).txt
unet (3).TXT
Urgency
I am only stuck on this last UNet part; please support.
ONNX Runtime Installation
Built from Source
ONNX Runtime Version or Commit ID
onnx 1.15.0, onnxruntime 1.16.3
PyTorch Version
torch 2.1.2
Execution Provider
Default CPU
Execution Provider Library Version
No response