
[Performance] MatMulNBits Performance #23004

Open
DakeQQ opened this issue Dec 4, 2024 · 8 comments
Labels
  performance: issues related to performance regressions
  platform:mobile: issues related to ONNX Runtime mobile; typically submitted using template
  quantization: issues related to quantization

Comments

@DakeQQ commented Dec 4, 2024

Describe the issue

Why do MatMulNBits operators with int4/uint4 quant types (with both fp32 and fp16 dtypes) run at least 10x slower than the MatMulIntegerToFloat/DynamicQuantizeMatMul operators produced by dynamic quantization (int8/uint8) on CPUs across platforms?

  • Linux+Intel
  • Windows+Intel
  • Android+Qualcomm+Arm64-v8.6
  • Mac+M3

The test models are small LLMs (0.5B to 3B parameters), including Llama, Phi, Qwen, and Gemma, and all show the same performance gap. Any insights?
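
One way to confirm where the time goes (a minimal sketch; the model path and the inference call are placeholders) is to enable ONNX Runtime's built-in profiler and inspect the per-node durations in the JSON trace it writes:

import onnxruntime

session_opts = onnxruntime.SessionOptions()
session_opts.enable_profiling = True             # write a JSON trace with per-node timings

sess = onnxruntime.InferenceSession(
    "path/to/quantized/model.onnx",              # placeholder path
    sess_options=session_opts,
    providers=["CPUExecutionProvider"],
)

# ... run the usual inference calls here ...

profile_file = sess.end_profiling()              # stops profiling and returns the trace file name
print("Per-node timings written to:", profile_file)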

To reproduce

INT4:

from onnxruntime.quantization import (
    matmul_4bits_quantizer,
    quant_utils,
)
from pathlib import Path

model_fp32_path = "path/to/original/model.onnx"
model_int4_path = "path/to/save/quantized/model.onnx"

quant_config = matmul_4bits_quantizer.DefaultWeightOnlyQuantConfig(
    block_size=256,                                  # power of 2 and >= 16
    is_symmetric=True,                               # if True, quantize to int4; otherwise, quantize to uint4
    accuracy_level=4,                                # used by MatMulNBits, see https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#attributes-35
    quant_format=quant_utils.QuantFormat.QOperator,
    op_types_to_quantize=("MatMul", "Gather"),       # specify which op types to quantize
    quant_axes=(("MatMul", 0), ("Gather", 1),),      # specify which axis to quantize for an op type
)

model = quant_utils.load_model_with_shape_infer(Path(model_fp32_path))
quant = matmul_4bits_quantizer.MatMul4BitsQuantizer(
    model,
    nodes_to_exclude=None,          # specify a list of nodes to exclude from quantization
    nodes_to_include=None,          # specify a list of nodes to force-include in quantization
    algo_config=quant_config,
)
quant.process()
quant.model.save_model_to_file(
    model_int4_path,
    True,                           # save data to an external file
)

INT8:

from onnxruntime.quantization import quantize_dynamic, QuantType

model_path = "path/to/original/model.onnx"
quanted_model_path = "path/to/save/quantized/model.onnx"

quantize_dynamic(
    model_input=model_path,
    model_output=quanted_model_path,
    per_channel=True,
    reduce_range=False,
    weight_type=QuantType.QUInt8,
    extra_options={
        'ActivationSymmetric': True,
        'WeightSymmetric': True,
        'EnableSubgraph': True,
        'ForceQuantizeNoInputCheck': False,
        'MatMulConstBOnly': True,
    },
    nodes_to_exclude=None,
    use_external_data_format=False,
)

Urgency

None

Platform

Linux

OS Version

22.04

ONNX Runtime Installation

Released Package

ONNX Runtime Version or Commit ID

1.20.1

ONNX Runtime API

Python

Architecture

X64

Execution Provider

Default CPU

Execution Provider Library Version

1.20.1

Model File

No response

Is this a quantized model?

Yes

@DakeQQ added the performance label Dec 4, 2024
@github-actions bot added the platform:mobile label Dec 4, 2024
@skottmckay added the quantization label Dec 4, 2024
@skottmckay (Contributor)

For int4, I believe things are more optimized for accuracy_level == 4, which can be specified when calling matmul_4bits_quantizer.MatMul4BitsQuantizer:

.Attr("accuracy_level",
"The minimum accuracy level of input A, can be: 0(unset), 1(fp32), 2(fp16), 3(bf16), or 4(int8) "
"(default unset). It is used to control how input A is quantized or downcast internally while "
"doing computation, for example: 0 means input A will not be quantized or downcast while doing "
"computation. 4 means input A can be quantized with the same block_size to int8 internally from "
"type T1.",

typedef enum {
Level0, /*!< input fp32, accumulator fp32 */
Level1, /*!< input fp32, accumulator fp32 */
Level2, /*!< input fp16, accumulator fp16 */
Level3, /*!< input bf16, accumulator fp32 */
Level4, /*!< input int8, accumulator int32 */
} ACCURACY_LEVEL;
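
As a quick sanity check, the accuracy_level that actually landed in a quantized model can be read back from the MatMulNBits node attributes (a minimal sketch; the model path is a placeholder):

import onnx
from onnx import helper

m = onnx.load("path/to/quantized/model.onnx")    # placeholder path
for node in m.graph.node:
    if node.op_type == "MatMulNBits":
        attrs = {a.name: helper.get_attribute_value(a) for a in node.attribute}
        # 0 = unset, 1 = fp32, 2 = fp16, 3 = bf16, 4 = int8 (per the schema above)
        print(node.name, "accuracy_level =", attrs.get("accuracy_level", 0))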

@DakeQQ (Author) commented Dec 4, 2024

Thank you for your reply. We quantized the model with accuracy_level=4 and set the following config:

session_opts.add_session_config_entry("session.qdq_matmulnbits_accuracy_level", "4")

However, performance remains noticeably slower than with quantize_dynamic() int8, on both the x64 Python API and the Android arm64 C API. Additionally, we noticed that Gen-AI uses MatMulNBits for the Phi model with almost no performance issues.

We would like to know whether the slower performance we're experiencing is due to incorrect settings or processes on our end, or whether MatMulNBits is inherently slower than quantize_dynamic() int8.

@skottmckay (Contributor)

@fajin-corp could you please take a look?

@fajin-corp (Contributor)

@DakeQQ in your int4 quantization script, you quantized both MatMul and Gather to int4. However, in your int8 script, you only quantized MatMul. Please do not quantize Gather to int4 and rerun the benchmark.
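
For concreteness, a minimal sketch of the int4 config with Gather left in fp32 (only op_types_to_quantize and quant_axes change relative to the script above):

from onnxruntime.quantization import matmul_4bits_quantizer, quant_utils

quant_config = matmul_4bits_quantizer.DefaultWeightOnlyQuantConfig(
    block_size=256,
    is_symmetric=True,
    accuracy_level=4,
    quant_format=quant_utils.QuantFormat.QOperator,
    op_types_to_quantize=("MatMul",),   # quantize MatMul only; Gather stays fp32
    quant_axes=(("MatMul", 0),),
)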

@DakeQQ (Author) commented Dec 5, 2024

@fajin-corp
Thank you for the suggestion. We followed it and created a benchmark test. The test shows that MatMulNBits with accuracy_level=4 is clearly slower than quantize_dynamic() across platforms. Although the simple QKV test model is only about 1.6x slower, the real LLM decodes roughly 10x slower. We provide the benchmark Python script below so you can try it.

The expected output should look like this:

[Benchmark] (DynamicQuantizeLinear + MatMulIntegerToFloat)
  Total Time: 5.096 seconds

[Benchmark] MatMulNBits
  Total Time: 8.287 seconds

Test_Script.py

import time
import torch
import torch.nn as nn
import onnxruntime
from onnxruntime.transformers.optimizer import optimize_model
from onnxruntime.quantization import (
    matmul_4bits_quantizer,
    quant_utils,
    quantize_dynamic,
    QuantType,
)
from pathlib import Path

# Define file paths for saving models
save_path = "model.onnx"                                     # Path to save the original ONNX model
quantize_dynamic_model_path = "model_quantize_dynamic.onnx"  # Path for dynamically quantized model
matmulnbit_model_path = "model_matmulnbit.onnx"              # Path for MatMul-N-bit quantized model


# Define a simple model for demonstration
class SimpleModel(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(SimpleModel, self).__init__()
        # Linear projection layers for Q, K, and V
        self.q_proj = nn.Linear(input_dim, output_dim, bias=False)
        self.k_proj = nn.Linear(input_dim, output_dim, bias=False)
        self.v_proj = nn.Linear(input_dim, output_dim, bias=False)

    def forward(self, x):
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)
        return q, k, v


# Set model dimensions
input_dim = 4096
output_dim = 4096

# Instantiate the model
model = SimpleModel(input_dim, output_dim)

# Generate a random input tensor
input_tensor = torch.randn(1, 2048, input_dim, dtype=torch.float32)

# Export the PyTorch model to ONNX
torch.onnx.export(
    model,
    (input_tensor,),
    save_path,
    input_names=['input_tensor'],   # Define input tensor name
    output_names=['q', 'k', 'v'],   # Define output tensor names
    do_constant_folding=True,       # Enable constant folding for optimization
    opset_version=17                # Use ONNX opset version 17
)

# Step 1: Apply dynamic quantization to the model
quantize_dynamic(
    model_input=save_path,                      # Path to the original ONNX model
    model_output=quantize_dynamic_model_path,   # Path to save the quantized model
    per_channel=False,
    reduce_range=False,
    weight_type=QuantType.QInt8,                # Use 8-bit integer for weights
    extra_options={
        'ActivationSymmetric': True,
        'WeightSymmetric': True,
        'EnableSubgraph': True,
        'ForceQuantizeNoInputCheck': False,
        'MatMulConstBOnly': True
    },
    nodes_to_exclude=None,
    use_external_data_format=False
)

# Step 2: Optimize the dynamically quantized model
model = optimize_model(
    quantize_dynamic_model_path,    # Input quantized model
    use_gpu=False,                  # Optimization will run on CPU
    opt_level=2,                    # Optimization level
    num_heads=1,
    hidden_size=input_dim,
    provider='CPUExecutionProvider',
    verbose=False,
    model_type='bert'               # Assume BERT-like model for optimization
)
model.save_model_to_file(quantize_dynamic_model_path, use_external_data_format=False)

# Step 3: Perform 4-bit quantization on MatMul operations
quant_config = matmul_4bits_quantizer.DefaultWeightOnlyQuantConfig(
    block_size=256,                                  # Block size for quantization
    is_symmetric=True,
    accuracy_level=4,                                
    quant_format=quant_utils.QuantFormat.QOperator,
    op_types_to_quantize=("MatMul",),                # Apply quantization to MatMul operations
    quant_axes=(("MatMul", 0),)
)

# Load the original ONNX model
model = quant_utils.load_model_with_shape_infer(Path(save_path))

# Initialize MatMul-4-bit quantizer
quant = matmul_4bits_quantizer.MatMul4BitsQuantizer(
    model,
    nodes_to_exclude=None,
    nodes_to_include=None,
    algo_config=quant_config,
)

# Apply quantization and save the model
quant.process()
quant.model.save_model_to_file(
    matmulnbit_model_path,
    False
)

# Step 4: Configure ONNX Runtime settings
session_opts = onnxruntime.SessionOptions()
session_opts.log_severity_level = 3
session_opts.inter_op_num_threads = 0
session_opts.intra_op_num_threads = 0
session_opts.enable_cpu_mem_arena = True
session_opts.execution_mode = onnxruntime.ExecutionMode.ORT_SEQUENTIAL
session_opts.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
session_opts.add_session_config_entry("session.intra_op.allow_spinning", "1")
session_opts.add_session_config_entry("session.inter_op.allow_spinning", "1")
session_opts.add_session_config_entry("session.qdq_matmulnbits_accuracy_level", "4")

# Load ONNX Runtime sessions for both models
ort_session_A = onnxruntime.InferenceSession(quantize_dynamic_model_path, sess_options=session_opts, providers=['CPUExecutionProvider'])
in_name_A = ort_session_A.get_inputs()[0].name
out_name_A0 = ort_session_A.get_outputs()[0].name
out_name_A1 = ort_session_A.get_outputs()[1].name
out_name_A2 = ort_session_A.get_outputs()[2].name

ort_session_B = onnxruntime.InferenceSession(matmulnbit_model_path, sess_options=session_opts, providers=['CPUExecutionProvider'])
in_name_B = ort_session_B.get_inputs()[0].name
out_name_B0 = ort_session_B.get_outputs()[0].name
out_name_B1 = ort_session_B.get_outputs()[1].name
out_name_B2 = ort_session_B.get_outputs()[2].name

# Prepare the input tensor for ONNX inference
input_tensor = input_tensor.numpy()

# Step 5: Run Benchmark
print("\n[Benchmark] (DynamicQuantizeLinear + MatMulIntegerToFloat)")
start = time.time()
for i in range(10):
    q, k, v = ort_session_A.run(
        [out_name_A0, out_name_A1, out_name_A2],
        {in_name_A: input_tensor}
    )
end = time.time()
print(f"  Total Time: {(end - start):.3f} seconds")

print("\n[Benchmark] MatMulNBits")
start = time.time()
for i in range(10):
    q, k, v = ort_session_B.run(
        [out_name_B0, out_name_B1, out_name_B2],
        {in_name_B: input_tensor}
    )
end = time.time()
print(f"  Total Time: {(end - start):.3f} seconds")

@fajin-corp (Contributor)

@DakeQQ int4 being 1.6x slower than int8 is expected, as int4 needs to be converted to int8 during computation. The int4 model is roughly half the size of the int8 model; this size reduction comes at the cost of speed.

The 10x slowdown in the LLM result does not make sense given the benchmark result. I suspect it is due to how your comparison is set up. I suggest looking deeper into your code.
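
As a rough illustration of the size tradeoff, here is a back-of-the-envelope estimate for a single 4096 x 4096 MatMul weight (a sketch assuming block_size=256, fp32 scales, and ignoring the int8 model's own scale/zero-point overhead):

rows, cols, block = 4096, 4096, 256

int8_bytes = rows * cols                          # 1 byte per weight  -> 16.00 MiB
int4_bytes = rows * cols // 2                     # 4 bits per weight  ->  8.00 MiB
int4_scale_bytes = (rows * cols // block) * 4     # one fp32 scale per block -> 0.25 MiB

print(f"int8 weight: {int8_bytes / 2**20:.2f} MiB")
print(f"int4 weight: {(int4_bytes + int4_scale_bytes) / 2**20:.2f} MiB (including scales)")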

@DakeQQ (Author) commented Dec 6, 2024

@fajin-corp
Sure, I've revised the Qwen 0.5B model information to present a side-by-side comparison of the two models, making it easier to identify differences in operator usage. Notably, the number of (DynamicQuantizeMatMul + MatMulIntegerToFloat) operations matches the number of MatMulNBits exactly, while the other operators remain identical. Although each MatMulNBits is only slightly slower than its (DynamicQuantizeMatMul + MatMulIntegerToFloat) counterpart individually, the cumulative effect is roughly 10x slower decoding overall. It may sound surprising, but even small inefficiencies can add up to a massive impact over time.

The benchmark for these inferences is identical; the only difference is the model call path, and we have ensured the correct model is invoked during benchmarking.

How do Gen-AI libraries work? Why doesn’t the performance impact occur in their models when using MatMulNBits? I’m curious if there’s any special optimization or 'magic' behind Gen-AI.

Comparison of Model Operators between Dynamic_Quant.onnx and MatMulNBits.onnx

Operator                 Qwen-Dynamic_Quant.onnx   Qwen-MatMulNBits.onnx
Add                      171                       171
ArgMax                   1                         1
Cast                     55                        55
Concat                   102                       102
ConstantOfShape          1                         1
DequantizeLinear         2                         -
Div                      49                        49
DynamicQuantizeLinear    48                        -
DynamicQuantizeMatMul    49                        -
Expand                   48                        48
Gather                   4                         4
MatMul                   48                        48
MatMulIntegerToFloat     120                       -
MatMulNBits              -                         169
Mul                      244                       244
Neg                      48                        48
QuickGelu                24                        24
ReduceMean               49                        49
Reshape                  144                       145
Slice                    7                         7
Softmax                  24                        24
Split                    50                        50
Sqrt                     49                        49
Squeeze                  48                        48
Sub                      1                         1
Transpose                120                       120
Unsqueeze                96                        97

@fajin-corp (Contributor)

@DakeQQ if MatMul takes most of the decoding time, and in your benchmark MatMulNBits is 1.6x slower than (DynamicQuantizeMatMul + MatMulIntegerToFloat), then the decoding time of the MatMulNBits LLM should be about 1.6x slower than that of the int8 model.

Gen-AI is using the exact same script as you used to quantize the model. The default setting of the script in Gen-AI is https://github.com/microsoft/onnxruntime-genai/blob/bde55ad4f950f62c9ce8b4b892ecb86747e377b7/src/python/py/models/builder.py#L293C1-L298C102.

I suggest you compare the int8 model with the int4 model to see what the difference in model structure is.
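
One way to do that comparison (a minimal sketch reusing the model file names from the benchmark script above) is to diff the operator counts of the two graphs:

from collections import Counter
import onnx

def op_counts(path):
    model = onnx.load(path)
    return Counter(node.op_type for node in model.graph.node)

int8_counts = op_counts("model_quantize_dynamic.onnx")
int4_counts = op_counts("model_matmulnbit.onnx")

print(f"{'Operator':<24}{'int8':>8}{'int4':>8}")
for op in sorted(set(int8_counts) | set(int4_counts)):
    print(f"{op:<24}{int8_counts.get(op, 0):>8}{int4_counts.get(op, 0):>8}")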
