[Performance] MatMulNBits Performance #23004
Comments
For int4 I believe things are more optimized for accuracy_level == 4, which can be specified when calling the quantizer. See:
- onnxruntime/onnxruntime/core/graph/contrib_ops/contrib_defs.cc, lines 3457 to 3462 (commit 9b9f881)
- onnxruntime/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc, lines 35 to 41 (commit 9b9f881)
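For illustration, a minimal sketch of passing an accuracy level to the int4 weight-only quantizer, assuming the MatMul4BitsQuantizer API under onnxruntime.quantization (module path, defaults, and argument names may differ between releases; the model paths are placeholders):

```python
# Sketch: int4 weight-only quantization with accuracy_level=4 (int8 compute path).
# The module path and constructor arguments are assumed from recent releases;
# verify against the installed onnxruntime.quantization version.
import onnx
from onnxruntime.quantization.matmul_4bits_quantizer import MatMul4BitsQuantizer

model = onnx.load("model_fp32.onnx")   # placeholder input model
quantizer = MatMul4BitsQuantizer(
    model,
    block_size=32,        # quantization block size along the K dimension
    is_symmetric=True,    # symmetric int4 weights
    accuracy_level=4,     # 4 selects the int8 compute kernels, usually fastest on CPU
)
quantizer.process()
quantizer.model.save_model_to_file("model_int4.onnx", use_external_data_format=True)
```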
Thank you for your reply. We quantized the model with accuracy_level=4 as suggested. However, the performance remains noticeably slower compared to using the int8 path (DynamicQuantizeMatMul / MatMulIntegerToFloat). We would like to know if the slower performance we're experiencing is due to incorrect settings or processes on our end, or if this is expected behavior for MatMulNBits.
@fajin-corp could you please take a look?
@DakeQQ in your int4 quantization script, you quantized both MatMul and Gather to int4. However, in your int8 script, you only quantized MatMul. Please do not quantize Gather to int4 and rerun the benchmark.
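A hedged sketch of limiting the int4 pass to MatMul only; the op_types_to_quantize argument is an assumption about this release's quantizer, so check the constructor signature of your installed version (older releases only quantize MatMul and need no such argument):

```python
# Sketch: quantize only MatMul weights to int4, leaving Gather (embedding
# lookups) untouched. op_types_to_quantize is assumed to exist in this
# release's MatMul4BitsQuantizer; confirm before relying on it.
import onnx
from onnxruntime.quantization.matmul_4bits_quantizer import MatMul4BitsQuantizer

model = onnx.load("model_fp32.onnx")                 # placeholder path
quantizer = MatMul4BitsQuantizer(
    model,
    block_size=32,
    is_symmetric=True,
    accuracy_level=4,
    op_types_to_quantize=("MatMul",),                # skip Gather
)
quantizer.process()
quantizer.model.save_model_to_file("model_int4_matmul_only.onnx", True)
```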
@fajin-corp The expected output should look like this:
Test_Script.py
@DakeQQ int4 being 1.6x slower than int8 is expected, as int4 needs to be converted to int8 during the calculation. The int4 model is roughly half the size of the int8 model; this size reduction comes at the cost of some speed. The 10x slowdown in the LLM result does not make sense given the benchmark result. I suspect it is due to how you set up the comparison. I suggest looking deeper into your code.
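To rule out harness differences, a minimal timing sketch that runs both quantized models through identical CPU sessions; the model paths and the single dummy feed are placeholders (a real decoder would also need attention masks and past key/value inputs):

```python
# Sketch: time the int4 and int8 models under identical conditions.
import time
import numpy as np
import onnxruntime as ort

def bench(path, feeds, warmup=5, iters=50):
    sess = ort.InferenceSession(path, providers=["CPUExecutionProvider"])
    for _ in range(warmup):           # warm up thread pools and weight prepacking
        sess.run(None, feeds)
    start = time.perf_counter()
    for _ in range(iters):
        sess.run(None, feeds)
    return (time.perf_counter() - start) / iters

feeds = {"input_ids": np.ones((1, 128), dtype=np.int64)}   # placeholder feed
print("int4 avg latency:", bench("model_int4.onnx", feeds))
print("int8 avg latency:", bench("model_int8.onnx", feeds))
```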
@fajin-corp The benchmark for these inferences is identical; the only difference is the model call path, and we have ensured the correct model is invoked during benchmarking. How do the Gen-AI libraries work? Why doesn't the performance impact occur in their models when using MatMulNBits?

[Comparison of Model Operators between the two models]
@DakeQQ if MatMul takes most of the decoding time, and in your benchmark MatMulNBits is 1.6x slower than (DynamicQuantizeMatMul + MatMulIntegerToFloat), then the decoding time of the MatMulNBits LLM should be about 1.6x slower than the int8 model. Gen-AI uses the exact same script you used to quantize the model. The default setting of the script in Gen-AI is https://github.com/microsoft/onnxruntime-genai/blob/bde55ad4f950f62c9ce8b4b892ecb86747e377b7/src/python/py/models/builder.py#L293C1-L298C102. I suggest you compare the int8 model with the int4 model to see what the difference in model structure is.
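One way to do that structural comparison is to diff the operator histograms of the two graphs; a small sketch (model paths are placeholders):

```python
# Sketch: compare operator-type counts of the int4 and int8 models.
from collections import Counter
import onnx

def op_histogram(path):
    model = onnx.load(path, load_external_data=False)   # weights not needed
    return Counter(node.op_type for node in model.graph.node)

int4_ops = op_histogram("model_int4.onnx")   # placeholder paths
int8_ops = op_histogram("model_int8.onnx")
for op in sorted(set(int4_ops) | set(int8_ops)):
    print(f"{op:>25}  int4={int4_ops.get(op, 0):5d}  int8={int8_ops.get(op, 0):5d}")
```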
Describe the issue
Why do MatMulNBits operators with quant types int4/uint4 (both f32 and f16 as dtypes) perform at least 10x slower than MatMulIntegerToFloat/DynamicQuantizeMatMul in the dynamic quantization process (int8/uint8) on CPUs across platforms? The test models range from small LLMs (0.5B to 3B parameters), including Llama, Phi, Qwen, and Gemma, all showing the same performance results. Any insights?
To reproduce
INT4:
INT8:
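The original INT4 and INT8 reproduction scripts are not included in this excerpt. For the int4 path, a sketch along the lines of the MatMul4BitsQuantizer example earlier in the thread applies; for the int8 path, a minimal dynamic-quantization sketch is shown below (paths are placeholders, and the MatMulIntegerToFloat/DynamicQuantizeMatMul contrib ops typically appear after onnxruntime's graph optimizations fuse the quantized MatMul pattern):

```python
# Sketch: int8 dynamic quantization of the fp32 model. Paths are placeholders.
from onnxruntime.quantization import quantize_dynamic, QuantType

quantize_dynamic(
    model_input="model_fp32.onnx",
    model_output="model_int8.onnx",
    weight_type=QuantType.QUInt8,   # or QuantType.QInt8, depending on the recipe used
)
```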
Urgency
None
Platform
Linux
OS Version
22.04
ONNX Runtime Installation
Released Package
ONNX Runtime Version or Commit ID
1.20.1
ONNX Runtime API
Python
Architecture
X64
Execution Provider
Default CPU
Execution Provider Library Version
1.20.1
Model File
No response
Is this a quantized model?
Yes