From 1167ad7ef2901e61c702657b0d422d4553291e64 Mon Sep 17 00:00:00 2001 From: wejoncy Date: Mon, 26 Feb 2024 11:21:46 +0800 Subject: [PATCH] fix build warning && cpp lint --- docs/ContribOperators.md | 30 +++++++++---------- docs/OperatorKernels.md | 4 +-- .../cpu/quantization/matmul_nbits.cc | 8 +++-- .../cpu/quantization/matmul_nbits_impl.cc | 5 ++-- .../cuda/quantization/dequantize_blockwise.cu | 20 ++++++------- .../cuda/quantization/matmul_nbits.cc | 7 +++-- .../core/graph/contrib_ops/contrib_defs.cc | 2 +- .../quantization/matmul_4bits_quantizer.py | 26 +++++++++++----- .../test/contrib_ops/matmul_4bits_test.cc | 21 ++++++++----- .../quantization/test_op_matmul_4bits.py | 2 +- 10 files changed, 73 insertions(+), 52 deletions(-) diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index f523e97293427..7536923323060 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -2808,22 +2808,14 @@ This version of the operator has been available since version 1 of the 'com.micr And block_size is not an arbitrary number and must be a power of 2 and not smaller than 16, like 16, 32, 64, 128,.. 3. Input B's scale and zero point are specified by input scales and zero_points. - Input B is stored as uint8_t with shape: [N][n_blocks_per_col][blob_size] in which: - - n_blocks_per_col = (K + block_size - 1) / block_size - - blob_size = block_size / 8 * bits - - For a block blob. It is stored in format: - struct Blob { - uint8 one_bits[(bits & 0x1) * 1 * block_size / 8]; // highest 1 bit for 3, 5, 7 bits quantization - uint8 two_bits[(bits & 0x2) * 2 * block_size / 8]; // high 2 bits for 2, 6, 7 bits quantization - uint8 four_bits[(bits & 0x4) * 4 * block_size / 8]; // low 4 bits for 4, 5, 6 bits quantization - } + Input is stored as uint8_t with shape: [N][n_blocks_per_col][blob_size] in which: + - n_blocks_per_col = (K + block_size - 1) / block_size + - blob_size = block_size / 8 * bits Input scales is stored in same type as original type of B(float32, float16) with shape like: [N * n_blocks_per_col] Input zero_points is stored as uint8_t. If bits <= 4, two zero points are stored as one unit8_t. If bits > 4, one zero point is stored with one unit8_t. Thus, its shape is: - [(N * n_blocks_per_col + 1) / 2] if bits <=4 - [N * n_blocks_per_col] if bits > 4 - #### Version @@ -2844,17 +2836,19 @@ This version of the operator has been available since version 1 of the 'com.micr
size of each quantization group used for weight quantization (default 128). It needs to be a power of 2 and not smaller than 16.
-#### Inputs (3 - 4) +#### Inputs (3 - 5)
A : T1
The input tensor, not quantized
B : T2
-
1-dimensional data blob
+
1- or 2-dimensional data blob
scales : T1
quantization scale
-
zero_points (optional) : T2
+
zero_points (optional) : T3
quantization zero points
+
g_idx (optional) : T4
+
group_idx
#### Outputs @@ -2869,8 +2863,12 @@ This version of the operator has been available since version 1 of the 'com.micr
T1 : tensor(float), tensor(float16)
Constrain input and output types to float/half_float tensors.
-
T2 : tensor(uint8)
-
Constrain quantized weight types to uint8.
+
T2 : tensor(uint8), tensor(int32)
+
Constrain quantized weight types to uint8/int32.
+
T3 : tensor(uint8), tensor(int32), tensor(float16), tensor(float)
+
Constrain quantized zero point types to uint8/int32/float16/float.
+
T4 : tensor(int32)
+
the index tensor.
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index b0ed68d595c42..4cf65c4bd3fe4 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -470,7 +470,7 @@ Do not modify directly.* |MatMulFpQ4|*in* A:**T1**
*in* B:**T2**
*in* B_shape:**T3**
*out* Y:**T1**|1+|**T1** = tensor(float)
**T2** = tensor(uint8)
**T3** = tensor(int64)| |MatMulInteger16|*in* A:**T1**
*in* B:**T2**
*out* Y:**T3**|1+|**T1** = tensor(int16)
**T2** = tensor(int16)
**T3** = tensor(int32)| |MatMulIntegerToFloat|*in* A:**T1**
*in* B:**T2**
*in* a_scale:**T3**
*in* b_scale:**T3**
*in* a_zero_point:**T1**
*in* b_zero_point:**T2**
*in* bias:**T3**
*out* Y:**T3**|1+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(float)| -|MatMulNBits|*in* A:**T1**
*in* B:**T2**
*in* scales:**T1**
*in* zero_points:**T2**
*out* Y:**T1**|1+|**T1** = tensor(float)
**T2** = tensor(uint8)| +|MatMulNBits|*in* A:**T1**
*in* B:**T2**
*in* scales:**T1**
*in* zero_points:**T3**
*in* g_idx:**T4**
*out* Y:**T1**|1+|**T1** = tensor(float)
**T2** = tensor(uint8)| |MaxpoolWithMask|*in* X:**T**
*in* M:**tensor(int32)**
*out* Y:**T**|1+|**T** = tensor(float)| |MultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* bias:**T**
*in* key_padding_mask:**M**
*in* relative_position_bias:**T**
*in* past_key:**T**
*in* past_value:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**T** = tensor(float)| |MurmurHash3|*in* X:**T1**
*out* Y:**T2**|1+|**T1** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(string), tensor(uint32), tensor(uint64)
**T2** = tensor(int32), tensor(uint32)| @@ -854,7 +854,7 @@ Do not modify directly.* |Irfft|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |LongformerAttention|*in* input:**T**
*in* weight:**T**
*in* bias:**T**
*in* mask:**T**
*in* global_weight:**T**
*in* global_bias:**T**
*in* global:**G**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| |MatMulBnb4|*in* A:**T1**
*in* B:**T2**
*in* absmax:**T1**
*out* Y:**T1**|1+|**T1** = tensor(bfloat16), tensor(float), tensor(float16)
**T2** = tensor(uint8)| -|MatMulNBits|*in* A:**T1**
*in* B:**T2**
*in* scales:**T1**
*in* zero_points:**T2**
*out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(uint8)| +|MatMulNBits|*in* A:**T1**
*in* B:**T2**
*in* scales:**T1**
*in* zero_points:**T3**
*in* g_idx:**T4**
*out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(uint8)| |MoE|*in* input:**T**
*in* router_probs:**T**
*in* fc1_experts_weights:**T**
*in* fc2_experts_weights:**T**
*in* fc1_experts_bias:**T**
*in* fc2_experts_bias:**T**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| |MultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* bias:**T**
*in* key_padding_mask:**M**
*in* relative_position_bias:**T**
*in* past_key:**T**
*in* past_value:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**T** = tensor(float), tensor(float16)| |NGramRepeatBlock|*in* input_ids:**Tid**
*in* scores:**T**
*out* scores_out:**T**|1+|**T** = tensor(float)
**Tid** = tensor(int64)| diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc index aec65b93e37de..dd078de186713 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc @@ -12,7 +12,7 @@ #include "core/mlas/inc/mlas_q4.h" #include "core/providers/cpu/math/matmul_helper.h" #include "core/providers/common.h" -#include "matmul_nbits_impl.h" +#include "contrib_ops/cpu/quantization/matmul_nbits_impl.h" #ifdef ORT_NEURAL_SPEED #include "contrib_ops/cpu/quantization/neural_speed_gemm.h" @@ -288,8 +288,10 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { const size_t K = static_cast(helper.K()); const size_t lda = helper.Lda(false); - const bool has_single_b_matrix = (reorder_idx_data == nullptr) && - (!zero_points || !zero_points->IsDataType()) && std::all_of(helper.RightOffsets().begin(), helper.RightOffsets().end(), [](size_t offset) { return offset == 0; }); + const bool has_single_b_matrix = + (reorder_idx_data == nullptr) && + (!zero_points || !zero_points->IsDataType()) && + std::all_of(helper.RightOffsets().begin(), helper.RightOffsets().end(), [](size_t offset) { return offset == 0; }); if (has_single_b_matrix) { const auto compute_type = static_cast(accuracy_level_); diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.cc index 0321876b31b0d..d86b54d397341 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.cc @@ -1,5 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
+#include "contrib_ops/cpu/quantization/matmul_nbits_impl.h" #include #include #include @@ -54,7 +55,7 @@ void Dequantize4BitsKernelReOrder( if constexpr (std::is_same_v) { T zp_adjust = -scale * MLFloat16(zp_f); - output_i[i] = float((quant_value >> (4 * i)) & 0xF) * scale + zp_adjust; + output_i[i] = static_cast((quant_value >> (4 * i)) & 0xF) * scale + zp_adjust; } else { T zp_adjust = -scale * zp_f; output_i[i] = T((quant_value >> (4 * i)) & 0xF) * scale + zp_adjust; @@ -86,7 +87,7 @@ void DequantizeBlockwise( for (int j = 0; j < 256; j++) { Dequantize4BitsKernelReOrder(output, quant_data, scales_data, zero_points, reorder_idx, block_size, groups_per_threadblock, - total_groups, N, K, block_id, j); + total_groups, N, K, static_cast(block_id), j); } }); } diff --git a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu index ba8e8511fbb79..cd6593352008b 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu +++ b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu @@ -72,9 +72,9 @@ __global__ void Dequantize4BitsKernelReOrder( if (group_id >= total_groups) { return; } - //T __shared__ zero_points_after_reorder[];//K - //T __shared__ scales_after_reorder[]; // K - //const int num_r_per_thread = k / 256; + // T __shared__ zero_points_after_reorder[];//K + // T __shared__ scales_after_reorder[]; // K + // const int num_r_per_thread = k / 256; const int zero_point_shape_x = (groups_per_K + 1) / 2; const int scales_shape_x = groups_per_K; @@ -102,7 +102,7 @@ __global__ void Dequantize4BitsKernelReOrder( } } -template +template __global__ void Dequantize4BitsKernel( T* output, const uint8_t* quant_data, @@ -116,15 +116,15 @@ __global__ void Dequantize4BitsKernel( if (block_id >= total_groups) { return; } - const int zero_point_shape_x = (groups_per_K + 1) / 2; - const int scales_shape_x = groups_per_K; - int n_idx = block_id / scales_shape_x; - int kb_idx = block_id % scales_shape_x; int element_offset = block_id * block_size + ((threadIdx.x * 8) & (block_size - 1)); uint32_t quant_value = *(reinterpret_cast(quant_data + element_offset / 2)); T scale = *(scale_data + block_id); T zero_point_value; - if constexpr(std::is_same_v) { + if constexpr (std::is_same_v) { + const int scales_shape_x = groups_per_K; + const int zero_point_shape_x = (groups_per_K + 1) / 2; + int kb_idx = block_id % scales_shape_x; + int n_idx = block_id / scales_shape_x; uint8_t zp = 8; if (zero_points) { zp = zero_points[n_idx * zero_point_shape_x + kb_idx / 2]; @@ -168,7 +168,7 @@ Status Dequantize4Bits( groups_per_threadblock, total_groups); } else { - //static_assert(std::is_same_v, "ZeroT must be uint8_t"); + // static_assert(std::is_same_v, "ZeroT must be uint8_t"); Dequantize4BitsKernelReOrder<<>>( output, quant_data, diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc index fcecce8ce9999..d5f6d04771c57 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc @@ -7,7 +7,7 @@ // pre-packed and block-compacted into int4 // -#include "matmul_nbits.h" +#include "contrib_ops/cuda/quantization/matmul_nbits.h" #include #include "core/common/status.h" #include "core/framework/float16.h" @@ -67,6 +67,9 @@ Status MatMulNBits::ComputeInternal(OpKernelContext* ctx) const { IAllocatorUniquePtr b_data_ptr = GetScratchBuffer(N_ * K_padded, 
ctx->GetComputeStream()); auto* b_data = b_data_ptr.get(); if (column_wise_quant_blk_) { + if (reorder_idx) { + ORT_ENFORCE(K_padded == reorder_idx->Shape()[0], "K_padded != g_idx->Shape()[0]"); + } // column-wise block if ((zero_points && zero_points->IsDataType())) { ORT_RETURN_IF_ERROR(Dequantize4Bits( @@ -159,7 +162,7 @@ ONNX_OPERATOR_TYPED_KERNEL_EX( kCudaExecutionProvider, (*KernelDefBuilder::Create()) .TypeConstraint("T1", DataTypeImpl::GetTensorType()) - .TypeConstraint("T2", {DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}), + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), MatMulNBits); } // namespace cuda diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index 1716de58fa178..1b3b19e746e31 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -3375,7 +3375,7 @@ Input zero_points is stored as uint8_t. If bits <= 4, two zero points are stored .Input(4, "g_idx", "group_idx", "T4", OpSchema::Optional) .Output(0, "Y", "tensor. The output tensor has the same rank as the input. ", "T1") .TypeConstraint("T1", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float/half_float tensors.") - .TypeConstraint("T2", {"tensor(uint8)", "tensor(int32)"}, "Constrain quantized weight types to uint8/uint32/int32/float16.") + .TypeConstraint("T2", {"tensor(uint8)", "tensor(int32)"}, "Constrain quantized weight types to uint8/int32.") .TypeConstraint("T3", {"tensor(uint8)", "tensor(int32)", "tensor(float16)", "tensor(float)"}, "Constrain quantized zero point types to uint8/int32/float16/float.") .TypeConstraint("T4", {"tensor(int32)"}, "the index tensor.") .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { diff --git a/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py b/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py index e23af144b8c26..a1916e806c5c0 100644 --- a/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py +++ b/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py @@ -15,7 +15,6 @@ import numpy as np import numpy.typing as npt import onnx -import torch from onnx.onnx_pb import GraphProto, ModelProto, NodeProto, TensorProto from packaging import version @@ -155,14 +154,16 @@ def __init__( # Proximal solver || weight - dequantize(quantize(weight))||_p^p @staticmethod def optimize_weights( - tensor: torch.Tensor, - scale: torch.Tensor, - zero: torch.Tensor, + tensor, + scale, + zero, min_max: list[int], axis: int = 0, opt_params: dict = None, # noqa: RUF013 verbose=False, ): + import torch + opt_params = {"lp_norm": 0.7, "beta": 1e1, "kappa": 1.01, "iters": 20} if opt_params is None else opt_params lp_norm, beta, kappa, iters = ( opt_params["lp_norm"], @@ -214,7 +215,7 @@ def pack_on_row_fast_248bit(pack_tensor, ori_int_tensor, bits): ori_int_tensor = ori_int_tensor.T pack_tensor = pack_tensor.T if bits in [2, 4, 8]: - compress_ratio = pack_tensor.dtype.itemsize * 8 // bits + compress_ratio = pack_tensor.element_size() * 8 // bits for j in range(0, compress_ratio): pack_tensor[0:] |= ori_int_tensor[j::compress_ratio] << (bits * (j)) else: @@ -224,6 +225,8 @@ def pack_on_row_fast_248bit(pack_tensor, ori_int_tensor, bits): def quantize_internal( self, tensor, bits=4, channel_wise=True, group_size=64, optimize=True, round_zero=True, axis=1 ): + import torch + weight = tensor.float() ori_shape = weight.shape @@ -288,6 +291,7 @@ def quantize(self, 
node: NodeProto, graph_stack: list[GraphProto]): """If the node is MatMul with fp32 const weight, quantize the weight with int4, and return the new node""" if node.op_type != "MatMul": return node # only care about MatMul for now + import torch logger.info(f"start to quantize {node.name} ...") inputB = node.input[1] # noqa: N806 @@ -458,6 +462,8 @@ class MatMul4BitsQuantizer: def __init__( self, model: ModelProto | str, + block_size: int = 128, + is_symmetric: bool = False, accuracy_level: int | None = None, nodes_to_exclude=None, algo_config: WeightOnlyQuantConfig = None, @@ -466,11 +472,15 @@ def __init__( nodes_to_exclude = [] self.model = ONNXModel(onnx.load(model)) if isinstance(model, str) else ONNXModel(model) self.model_path = model if isinstance(model, str) else None + self.block_size = block_size + self.is_symmetric = is_symmetric self.accuracy_level = accuracy_level self.nodes_to_exclude = set(nodes_to_exclude) self.node_quantizer = None if algo_config is None: - algo_config = DefaultWeightOnlyQuantConfig(block_size=32, is_symmetric=False, accuracy_level=accuracy_level) + algo_config = DefaultWeightOnlyQuantConfig( + block_size=block_size, is_symmetric=is_symmetric, accuracy_level=accuracy_level + ) self.algo_config = algo_config if algo_config.algorithm == "HQQ": self.node_quantizer = HQQWeightOnlyQuantizer(self.algo_config) @@ -527,8 +537,8 @@ def _generate_q4_node_config(self): q4_node_config = {} template_config_q4 = { "bits": 4, - "group_size": self.algo_config.block_size, - "scheme": "sym" if self.algo_config.is_symmetric else "asym", + "group_size": self.block_size, + "scheme": "sym" if self.is_symmetric else "asym", } for node in self.model.model.graph.node: if node.op_type in ["MatMul"]: diff --git a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc index bd1bc1b8e224a..89768dbf682e4 100644 --- a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc @@ -68,6 +68,7 @@ void QuantizeDequantize(std::vector& raw_vals, void RunTest(int64_t M, int64_t N, int64_t K, int64_t block_size, int64_t accuracy_level, bool has_zeropoint, bool use_float16, bool has_g_idx = false, bool zp_is_4bit = true, float fp16_abs_error = 0.02f) { + std::cerr << M << " " << N << " " << K << " " << block_size << " " << has_zeropoint << " " << use_float16 << " " << has_g_idx << " " << zp_is_4bit << " " << std::endl; zp_is_4bit = zp_is_4bit | has_g_idx; RandomValueGenerator random{1234}; std::vector input0_vals(random.Gaussian(std::vector({M, K}), 0.0f, 0.25f)); @@ -115,6 +116,8 @@ void RunTest(int64_t M, int64_t N, int64_t K, int64_t block_size, int64_t accura test.AddAttribute("block_size", block_size); test.AddAttribute("bits", QBits); test.AddAttribute("accuracy_level", accuracy_level); + auto ceildiv = [](int64_t a, int64_t b) { return (a + b - 1) / b; }; + if (use_float16) { test.AddInput("A", {M, K}, ToFloat16(input0_vals), false); test.AddInput("B", {q_cols, q_rows}, input1_vals, true); @@ -137,15 +140,16 @@ void RunTest(int64_t M, int64_t N, int64_t K, int64_t block_size, int64_t accura test.AddInput("zero_points", {static_cast(q_scale_size)}, ToFloat16(zp_f), true); } - } else { + } else if(has_g_idx) { test.AddInput("", {0}, {}); } if (has_g_idx) { - std::vector g_idx(K); - for (int64_t i = 0; i < K; i++) { + int K_pad = gsl::narrow(ceildiv(K, block_size) * block_size); + std::vector g_idx(K_pad); + for (int64_t i = 0; i < K_pad; i++) { g_idx[i] = gsl::narrow(i / block_size); } - 
test.AddInput("g_idx", {static_cast(K)}, g_idx, true); + test.AddInput("g_idx", {static_cast(K_pad)}, g_idx, true); } test.AddOutput("Y", {M, N}, ToFloat16(expected_vals)); @@ -180,11 +184,12 @@ void RunTest(int64_t M, int64_t N, int64_t K, int64_t block_size, int64_t accura test.AddInput("", {0}, {}); } if (has_g_idx) { - std::vector g_idx(K); - for (int64_t i = 0; i < K; i++) { + int K_pad = gsl::narrow(ceildiv(K, block_size) * block_size); + std::vector g_idx(K_pad); + for (int64_t i = 0; i < K_pad; i++) { g_idx[i] = gsl::narrow(i / block_size); } - test.AddInput("g_idx", {static_cast(K)}, g_idx, true); + test.AddInput("g_idx", {static_cast(K_pad)}, g_idx, true); } test.AddOutput("Y", {M, N}, expected_vals); if (accuracy_level == 4) { @@ -209,7 +214,9 @@ TEST(MatMulNBits, Float32) { for (auto accuracy_level : {0}) { RunTest(M, N, K, block_size, accuracy_level, false, false); RunTest(M, N, K, block_size, accuracy_level, true, false); +#if !defined(DISABLE_OPTIONAL_TYPE) RunTest(M, N, K, block_size, accuracy_level, false, false, true); +#endif RunTest(M, N, K, block_size, accuracy_level, true, false, false, false); } #endif diff --git a/onnxruntime/test/python/quantization/test_op_matmul_4bits.py b/onnxruntime/test/python/quantization/test_op_matmul_4bits.py index 4664177ee146c..88e5052db4e2e 100644 --- a/onnxruntime/test/python/quantization/test_op_matmul_4bits.py +++ b/onnxruntime/test/python/quantization/test_op_matmul_4bits.py @@ -173,7 +173,7 @@ def quant_test_with_algo( algo_config = matmul_4bits_quantizer.HQQWeightOnlyQuantConfig(block_size=block_size) model = quant_utils.load_model_with_shape_infer(Path(model_fp32_path)) - quant = matmul_4bits_quantizer.MatMul4BitsQuantizer(model, algo_config=algo_config) + quant = matmul_4bits_quantizer.MatMul4BitsQuantizer(model, block_size, is_symmetric, algo_config=algo_config) quant.process() quant.model.save_model_to_file(model_int4_path, False)
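
For readers of this patch, here is a minimal NumPy sketch (not onnxruntime code) of the bits=4 blockwise layout described in the updated ContribOperators.md above: B is stored as uint8 with shape [N][n_blocks_per_col][blob_size], n_blocks_per_col = (K + block_size - 1) / block_size, blob_size = block_size / 8 * bits. The helper names are invented for illustration, the g_idx reorder path and per-column zero-point packing are ignored, and the nibble order follows the `(quant_value >> (4 * i)) & 0xF` unpacking used by the dequantize kernels touched in this patch.

```python
import numpy as np


def pack_4bit_column(w_q_col: np.ndarray, block_size: int) -> np.ndarray:
    """Pack one column of 4-bit quantized weights (values in [0, 15]).

    Returns uint8 of shape (n_blocks_per_col, blob_size) where
    n_blocks_per_col = ceil(K / block_size) and blob_size = block_size // 2.
    Element 2*j of a block lands in the low nibble of byte j, element 2*j+1
    in the high nibble.
    """
    K = w_q_col.size
    n_blocks_per_col = (K + block_size - 1) // block_size
    padded = np.zeros(n_blocks_per_col * block_size, dtype=np.uint8)
    padded[:K] = w_q_col.astype(np.uint8) & 0x0F
    lo, hi = padded[0::2], padded[1::2]
    return (lo | (hi << 4)).reshape(n_blocks_per_col, block_size // 2)


def dequantize_4bit_column(blob, scales, zero_points, block_size):
    """Per-block dequantization: value = (q - zero_point) * scale."""
    n_blocks_per_col = blob.shape[0]
    out = np.empty(n_blocks_per_col * block_size, dtype=np.float32)
    for b in range(n_blocks_per_col):
        q = np.empty(block_size, dtype=np.float32)
        q[0::2] = blob[b] & 0x0F   # low nibbles: even elements of the block
        q[1::2] = blob[b] >> 4     # high nibbles: odd elements of the block
        out[b * block_size:(b + 1) * block_size] = (q - zero_points[b]) * scales[b]
    return out


# Round-trip check for K = 64, block_size = 32 (two blocks per column).
rng = np.random.default_rng(0)
K, block_size = 64, 32
w_q = rng.integers(0, 16, size=K)
scales = rng.random(K // block_size).astype(np.float32)
zps = rng.integers(0, 16, size=K // block_size)
blob = pack_4bit_column(w_q, block_size)            # shape (2, 16)
w = dequantize_4bit_column(blob, scales, zps, block_size)
expected = ((w_q.reshape(-1, block_size) - zps[:, None]) * scales[:, None]).reshape(-1)
assert np.allclose(w, expected)
```

A usage sketch of the patched `MatMul4BitsQuantizer` constructor follows, mirroring the updated call in test_op_matmul_4bits.py; the file names are placeholders.

```python
# Hypothetical driver script; file names are placeholders.
from onnxruntime.quantization import matmul_4bits_quantizer

quant = matmul_4bits_quantizer.MatMul4BitsQuantizer(
    "model_fp32.onnx",   # ModelProto or path to the fp32 model
    block_size=128,      # power of 2, not smaller than 16
    is_symmetric=False,
    accuracy_level=None,
)
# With algo_config omitted, the default config now picks up block_size /
# is_symmetric from the constructor instead of the previously hard-coded values.
quant.process()
quant.model.save_model_to_file("model_int4.onnx", False)
```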