From 5df4ddd1c3d5686fb8a57e439b7162a447fcfdce Mon Sep 17 00:00:00 2001
From: Jing Fang <126209182+fajin-corp@users.noreply.github.com>
Date: Tue, 16 Jul 2024 10:34:19 -0700
Subject: [PATCH] matmul 4bit tool chain support qdq (#21362)

### Description
This is a partial change ported from the fajin/qdqmatmulnbitstoolchain branch; that branch has issues resolving the web CI.

MatMulNBits is a heavily optimized matmul operation. Currently a MatMul can be converted to MatMulNBits to speed up model inference. However, MatMulNBits is an ORT-only op. To keep the graph compatible with standard ONNX ops while still taking advantage of MatMulNBits, we introduce Q/DQ support for MatMulNBits.

To convert the MatMul ops in a model to MatMulNBits, use matmul_4bits_quantizer.py in QDQ mode to convert each MatMul to DQ + MatMul. In the ORT session, DQ + MatMul is then fused into MatMulNBits (a usage sketch is appended after the diff).

#### Note
MatMulNBits assumes the B weight is uint4. When no zero point is provided, it defaults to 8, which differs from DQ: DQ defaults the zero point to 0 when none is provided, and DQ also supports int4. Therefore some conversions are introduced during the DQ + MatMul --> MatMulNBits fusion step.

#### Perf
Using the QDQ format increases model initialization time and memory consumption. With the current implementation, model init time increased from ~4s to ~9s, and memory consumption increased from ~2.8GB to ~4.8GB. The memory increase is due to:
1. In the optimizer, after transposing the B weight, an in-memory tensor proto is created using protobuf's arena.
2. In the finalize step, when saving initializers and prepacking, the ORT arena is used to create buffers for the initializers.

The memory allocated by these arenas cannot be fully deallocated. If ORT arena memory allocation is disabled, the memory consumption of both the QDQ format and the original format is ~2.2GB. The time increase is mainly due to multiple memory copies and can be further optimized.

### Motivation and Context
Please see the description for details.
---
 .../quantization/matmul_4bits_quantizer.py | 277 +++++++++++++-----
 .../quantization/test_op_matmul_4bits.py   |  54 +++-
 2 files changed, 246 insertions(+), 85 deletions(-)

diff --git a/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py b/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py index 11a830dc6d7f5..40a4a4d26dc1c 100644 --- a/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py +++ b/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py @@ -18,31 +18,36 @@ from onnx.onnx_pb import GraphProto, ModelProto, NodeProto, TensorProto from packaging import version -from onnxruntime.capi._pybind_state import quantize_matmul_4bits +from onnxruntime.capi._pybind_state import quantize_matmul_4bits, quantize_qdq_matmul_4bits from .calibrate import CalibrationDataReader from .onnx_model import ONNXModel -from .quant_utils import attribute_to_kwarg +from .quant_utils import QuantFormat, attribute_to_kwarg logging.basicConfig(format="%(asctime)s %(name)s [%(levelname)s] - %(message)s", level=logging.INFO) logger = logging.getLogger(__name__) class WeightOnlyQuantConfig: - def __init__(self, algorithm): + def __init__(self, algorithm, quant_format): """This is the Base class for Weight Only Quant Configuration. Args: algorithm: weight only quantize algorithm name. + quant_format: QuantFormat{QOperator, QDQ}. + QOperator format quantizes the model with quantized operators directly. + QDQ format quantize the model by inserting QuantizeLinear/DeQuantizeLinear on the tensor. 
""" self.algorithm = algorithm + self.quant_format = quant_format class RTNWeightOnlyQuantConfig(WeightOnlyQuantConfig): def __init__( self, ratios=None, + quant_format=QuantFormat.QOperator, ): """ This is a class for round-to-nearest (RTN) algorithm Weight Only Quant Configuration. @@ -51,11 +56,18 @@ def __init__( Args: ratios: percentile of clip. Defaults to {}. + quant_format (QuantFormat{QOperator, QDQ}, optional): + QOperator format quantizes the model with quantized operators directly. + QDQ format quantize the model by inserting QuantizeLinear/DeQuantizeLinear on the tensor. + Defaults to QuantFormat.QOperator. """ + assert quant_format == QuantFormat.QOperator, "RTN only supports QOperator format" + if ratios is None: ratios = {} super().__init__( algorithm="RTN", + quant_format=quant_format, ) self.ratios = ratios @@ -69,6 +81,7 @@ def __init__( actorder=False, mse=False, perchannel=True, + quant_format=QuantFormat.QOperator, ): """ This is a class for GPTQ algorithm Weight Only Quant Configuration. @@ -87,9 +100,16 @@ def __init__( whether get scale and zero point with mse error. perchannel (bool, optional): whether quantize weight per-channel. + quant_format (QuantFormat{QOperator, QDQ}, optional): + QOperator format quantizes the model with quantized operators directly. + QDQ format quantize the model by inserting QuantizeLinear/DeQuantizeLinear on the tensor. + Defaults to QuantFormat.QOperator. """ + assert quant_format == QuantFormat.QOperator, "GPTQ only supports QOperator format" + super().__init__( algorithm="GPTQ", + quant_format=quant_format, ) self.calibration_data_reader = calibration_data_reader self.percdamp = percdamp @@ -105,6 +125,7 @@ def __init__( block_size=128, bits=4, axis=1, + quant_format=QuantFormat.QOperator, ): """ This is a class for HQQ algorithm Weight Only Quant Configuration. @@ -112,14 +133,21 @@ def __init__( Args: block_size (int, optional): - channel number in one block to execute a GPTQ quantization iteration. + channel number in one block to execute a HQQ quantization iteration. bits (int, optional): how many bits to represent weight. axis (int, optional): 0 or 1. which axis to quantize. https://arxiv.org/pdf/2309.15531.pdf + quant_format (QuantFormat{QOperator, QDQ}, optional): + QOperator format quantizes the model with quantized operators directly. + QDQ format quantize the model by inserting QuantizeLinear/DeQuantizeLinear on the tensor. + Defaults to QuantFormat.QOperator. """ + assert quant_format == QuantFormat.QOperator, "HQQ only supports QOperator format" + super().__init__( algorithm="HQQ", + quant_format=quant_format, ) self.block_size = block_size self.bits = bits @@ -132,8 +160,26 @@ def __init__( block_size: int = 128, is_symmetric: bool = False, accuracy_level: int | None = None, + quant_format=QuantFormat.QOperator, ): - super().__init__(algorithm="DEFAULT") + """ + This is a class for weight only affine quantization configuration. + + Args: + block_size (int, optional): + channel number in one block to execute an affine quantization iteration. + is_symmetric (bool, optional): + whether quantize weight symmetrically. + accuracy_level (int, optional): + Accuracy level of the 4-bit quantized MatMul computation. + Refer to the MatMulNBits contrib op's 'accuracy_level' attribute for details. + (https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#commicrosoftmatmulnbits) + quant_format (QuantFormat{QOperator, QDQ}, optional): + QOperator format quantizes the model with quantized operators directly. 
+ QDQ format quantize the model by inserting QuantizeLinear/DeQuantizeLinear on the tensor. + Defaults to QuantFormat.QOperator. + """ + super().__init__(algorithm="DEFAULT", quant_format=quant_format) self.block_size = block_size self.is_symmetric = is_symmetric self.bits = 4 @@ -287,23 +333,26 @@ def quantize_internal( return w_q, scale.to(tensor.dtype), zero.to(tensor.dtype) - def quantize(self, node: NodeProto, graph_stack: list[GraphProto]): - """If the node is MatMul with fp32 const weight, quantize the weight with int4, and return the new node""" + def quantize(self, node: NodeProto, graph_stack: list[GraphProto]) -> list[NodeProto]: + """ + If the node is MatMul with fp32 const weight, quantize the weight with int4, and return the new node. + If QOperator format, return MatMulNbits. If QDQ format, return DeQuantizeLinear + MatMul. + """ if node.op_type != "MatMul": - return node # only care about MatMul for now + return [node] # only care about MatMul for now import torch logger.info(f"start to quantize {node.name} ...") - inputB = node.input[1] # noqa: N806 - b_pb, bs_graph = get_initializer(inputB, graph_stack) + input_b = node.input[1] + b_pb, bs_graph = get_initializer(input_b, graph_stack) if b_pb is None: logger.info("MatMul doesn't have const weight. Skip to quantize") - return node # only care about constant weight + return [node] # only care about constant weight b_array = onnx.numpy_helper.to_array(b_pb) if len(b_array.shape) != 2: logger.info("MatMul weight is not 2D. Skip to quantize") - return node # can only process 2-D matrix + return [node] # can only process 2-D matrix b_array_torch = torch.from_numpy(b_array) if torch.cuda.is_available(): b_array_torch = b_array_torch.cuda() @@ -334,7 +383,7 @@ def quantize(self, node: NodeProto, graph_stack: list[GraphProto]): b_quant = onnx.numpy_helper.from_array(packed_torch.cpu().numpy()) b_quant.name = b_pb.name + "_Q4" for input in bs_graph.input: - if input.name == inputB: + if input.name == input_b: bs_graph.input.remove(input) break @@ -366,7 +415,7 @@ def quantize(self, node: NodeProto, graph_stack: list[GraphProto]): logger.info(f"complete quantization of {node.name} ...") - return matmul_q4_node + return [matmul_q4_node] def get_initializer(name, graph_path: list[GraphProto]) -> tuple[TensorProto, GraphProto]: @@ -382,7 +431,7 @@ class DefaultWeightOnlyQuantizer: def __init__(self, config: DefaultWeightOnlyQuantConfig): self.config = config - def int4_block_quant(self, fp32weight: npt.ArrayLike) -> np.ndarray: + def int4_block_quant(self, fp32weight: npt.ArrayLike) -> tuple[np.ndarray, np.ndarray, np.ndarray]: """4b quantize fp32 weight to a blob""" if len(fp32weight.shape) != 2: @@ -390,83 +439,136 @@ def int4_block_quant(self, fp32weight: npt.ArrayLike) -> np.ndarray: rows, cols = fp32weight.shape block_size = self.config.block_size - blob_size = block_size // 2 k_blocks = (rows + block_size - 1) // block_size - padded_rows = k_blocks * block_size - pad_len = padded_rows - rows - if pad_len > 0: - fp32weight = np.pad(fp32weight, ((0, pad_len), (0, 0)), "constant") - # block wise quantization, each block comes from a single column - packed = np.zeros((cols, k_blocks, blob_size), dtype="uint8") - scales = np.zeros((cols * k_blocks), dtype=fp32weight.dtype) - zero_point = np.zeros(cols * ((k_blocks + 1) // 2), dtype="uint8") - quantize_matmul_4bits(packed, fp32weight, scales, zero_point, block_size, cols, rows, self.config.is_symmetric) + if self.config.quant_format == QuantFormat.QOperator: + blob_size = block_size // 
2 + padded_rows = k_blocks * block_size + pad_len = padded_rows - rows + if pad_len > 0: + fp32weight = np.pad(fp32weight, ((0, pad_len), (0, 0)), "constant") + + # block wise quantization, each block comes from a single column + packed = np.zeros((cols, k_blocks, blob_size), dtype="uint8") + zero_point = np.zeros(cols * ((k_blocks + 1) // 2), dtype="uint8") + scales = np.zeros((cols * k_blocks), dtype=fp32weight.dtype) + quantize_matmul_4bits( + packed, fp32weight, scales, zero_point, block_size, cols, rows, self.config.is_symmetric + ) + else: + packed = np.zeros((rows * cols + 1) // 2, dtype="uint8") + zero_point = np.zeros((cols * k_blocks + 1) // 2, dtype="uint8") + scales = np.zeros((k_blocks, cols), dtype=fp32weight.dtype) + quantize_qdq_matmul_4bits( + packed, fp32weight, scales, zero_point, block_size, cols, rows, self.config.is_symmetric + ) return (packed, scales, zero_point) - def quantize(self, node: NodeProto, graph_stack: list[GraphProto]) -> NodeProto: - """If the node is MatMul with fp32 const weight, quantize the weight with int4, and return the new node""" + def quantize(self, node: NodeProto, graph_stack: list[GraphProto]) -> list[NodeProto]: + """ + If the node is MatMul with fp32 const weight, quantize the weight with int4, and return the new node. + If QOperator format, return MatMulNbits. If QDQ format, return DeQuantizeLinear + MatMul. + """ if node.op_type != "MatMul": - return node # only care about MatMul for now + return [node] # only care about MatMul for now logger.info(f"start to quantize {node.name} ...") - inputB = node.input[1] # noqa: N806 - B, Bs_graph = get_initializer(inputB, graph_stack) # noqa: N806 - if B is None: + qtype = TensorProto.INT4 if self.config.is_symmetric else TensorProto.UINT4 + input_b = node.input[1] + b_tensor, b_graph = get_initializer(input_b, graph_stack) + if b_tensor is None: logger.info("MatMul doesn't have const weight. Skip to quantize") - return node # only care about constant weight + return [node] # only care about constant weight - B_array = onnx.numpy_helper.to_array(B) # noqa: N806 - if len(B_array.shape) != 2: + b_ndarray = onnx.numpy_helper.to_array(b_tensor) + if len(b_ndarray.shape) != 2: logger.info("MatMul weight is not 2D. 
Skip to quantize") - return node # can only process 2-D matrix - - packed, scales, zero_points = self.int4_block_quant(B_array) - B_quant = onnx.numpy_helper.from_array(packed) # noqa: N806 - B_quant.name = B.name + "_Q4" - for input in Bs_graph.input: - if input.name == inputB: - Bs_graph.input.remove(input) - break + return [node] # can only process 2-D matrix - scales_tensor = onnx.numpy_helper.from_array(scales) - scales_tensor.name = B.name + "_scales" - Bs_graph.initializer.extend([B_quant, scales_tensor]) + packed, scales, zero_points = self.int4_block_quant(b_ndarray) - input_names = [node.input[0], B_quant.name, scales_tensor.name] - if not self.config.is_symmetric: - zp_tensor = onnx.numpy_helper.from_array(zero_points) - zp_tensor.name = B.name + "_zero_points" - Bs_graph.initializer.extend([zp_tensor]) - input_names.append(zp_tensor.name) + if self.config.quant_format == QuantFormat.QOperator: + b_quant = onnx.numpy_helper.from_array(packed, b_tensor.name + "_Q4") + scales_tensor = onnx.numpy_helper.from_array(scales, b_tensor.name + "_scales") + else: + b_quant = onnx.helper.make_tensor(b_tensor.name + "_DQ_Q4", qtype, b_ndarray.shape, packed.tobytes(), True) + scales_tensor = onnx.numpy_helper.from_array(scales, b_tensor.name + "_DQ_scales") - kwargs = {} - rows, cols = B_array.shape - kwargs["K"] = rows - kwargs["N"] = cols - kwargs["bits"] = 4 - kwargs["block_size"] = self.config.block_size - if self.config.accuracy_level is not None: - kwargs["accuracy_level"] = self.config.accuracy_level + for input in b_graph.input: + if input.name == input_b: + b_graph.input.remove(input) + break - matmul_q4_node = onnx.helper.make_node( - "MatMulNBits", - inputs=input_names, - outputs=[node.output[0]], - name=node.name + "_Q4" if node.name else "", - domain="com.microsoft", - **kwargs, - ) + b_graph.initializer.extend([b_quant, scales_tensor]) + + output_nodes = [] + + if self.config.quant_format == QuantFormat.QOperator: + input_names = [node.input[0], b_quant.name, scales_tensor.name] + if not self.config.is_symmetric: + zp_tensor = onnx.numpy_helper.from_array(zero_points, b_tensor.name + "_zero_points") + input_names.append(zp_tensor.name) + b_graph.initializer.extend([zp_tensor]) + kwargs = {} + rows, cols = b_ndarray.shape + kwargs["K"] = rows + kwargs["N"] = cols + kwargs["bits"] = 4 + kwargs["block_size"] = self.config.block_size + if self.config.accuracy_level is not None: + kwargs["accuracy_level"] = self.config.accuracy_level + + matmul_q4_node = onnx.helper.make_node( + "MatMulNBits", + inputs=input_names, + outputs=[node.output[0]], + name=node.name + "_Q4" if node.name else "", + domain="com.microsoft", + **kwargs, + ) - logger.info(f"complete quantization of {node.name} ...") + output_nodes.append(matmul_q4_node) + else: + dq_input_names = [b_quant.name, scales_tensor.name] + dq_output_names = [b_quant.name + "_output"] + matmul_input_names = [node.input[0], dq_output_names[0]] + matmul_output_names = [node.output[0]] + if not self.config.is_symmetric: + zp_tensor = onnx.helper.make_tensor( + b_tensor.name + "_DQ_zero_points", qtype, scales.shape, zero_points.tobytes(), True + ) + dq_input_names.append(zp_tensor.name) + b_graph.initializer.extend([zp_tensor]) + dq_kwargs = {"axis": 0, "block_size": self.config.block_size} + dq_node = onnx.helper.make_node( + "DequantizeLinear", + inputs=dq_input_names, + outputs=dq_output_names, + name=node.name + "_DQ_Q4" if node.name else "", + **dq_kwargs, + ) + matmul_node = onnx.helper.make_node( + "MatMul", + 
inputs=matmul_input_names, + outputs=matmul_output_names, + name=node.name + "_matmul_Q4" if node.name else "", + ) + output_nodes.extend([dq_node, matmul_node]) - return matmul_q4_node + logger.info(f"complete quantization of {node.name} ...") + return output_nodes class MatMul4BitsQuantizer: - """Perform 4b quantization of constant MatMul weights""" + """ + Perform 4b quantization of constant MatMul weights. + If algo_config.quant_format is QOperator, the quantized weight is stored in a MatMulNBits node, which relaces the + MatMul node. + If algo_config.quant_format is QDQ, the quantized weight is stored in a DeQuantizeLinear node. The MatMul node is + replaced by the DequantizeLinear + MatMul nodes. + """ def __init__( self, @@ -475,7 +577,8 @@ def __init__( is_symmetric: bool = False, accuracy_level: int | None = None, nodes_to_exclude=None, - algo_config: WeightOnlyQuantConfig = None, + quant_format=QuantFormat.QOperator, + algo_config: WeightOnlyQuantConfig | None = None, ): if nodes_to_exclude is None: nodes_to_exclude = [] @@ -488,7 +591,10 @@ def __init__( self.node_quantizer = None if algo_config is None: algo_config = DefaultWeightOnlyQuantConfig( - block_size=block_size, is_symmetric=is_symmetric, accuracy_level=accuracy_level + block_size=block_size, + is_symmetric=is_symmetric, + accuracy_level=accuracy_level, + quant_format=quant_format, ) self.algo_config = algo_config if algo_config.algorithm == "HQQ": @@ -526,15 +632,15 @@ def _process_subgraph(self, graph_stack: list[GraphProto]): node = onnx.helper.make_node( # noqa: PLW2901 node.op_type, node.input, node.output, name=node.name, **kwargs ) - out_node = None + out_nodes = [] if node.name in self.nodes_to_exclude: logger.info(f"exclude to quantize {node.name} as specified by nodes_to_exclude...") - out_node = node + out_nodes = [node] elif self.algo_config is not None and self.algo_config.algorithm == "HQQ": - out_node = self.node_quantizer.quantize(node, graph_stack) + out_nodes = self.node_quantizer.quantize(node, graph_stack) else: - out_node = self.node_quantizer.quantize(node, graph_stack) - new_nodes.append(out_node) + out_nodes = self.node_quantizer.quantize(node, graph_stack) + new_nodes.extend(out_nodes) graph.ClearField("node") graph.node.extend(new_nodes) @@ -688,6 +794,15 @@ def parse_args(): default=[], help="Specify the nodes to be excluded from quantization with node names", ) + parser.add_argument( + "--quant_format", + default="QOperator", + type=QuantFormat, + choices=list(QuantFormat), + help="QuantFormat {QOperator, QDQ}" + "QOperator format quantizes the model with quantized operators directly." 
+ "QDQ format quantize the model by inserting DeQuantizeLinear before the MatMul.", + ) return parser.parse_args() @@ -699,6 +814,7 @@ def parse_args(): input_model_path = args.input_model output_model_path = args.output_model + quant_format = args.quant_format if os.path.exists(output_model_path): logger.error(f"file {output_model_path} already exists") @@ -713,7 +829,10 @@ def parse_args(): quant_config = HQQWeightOnlyQuantConfig(block_size=args.block_size, bits=args.bits) elif args.quant_method == "default": quant_config = DefaultWeightOnlyQuantConfig( - block_size=args.block_size, is_symmetric=args.symmetric, accuracy_level=args.accuracy_level + block_size=args.block_size, + is_symmetric=args.symmetric, + accuracy_level=args.accuracy_level, + quant_format=quant_format, ) elif args.quant_method == "rtn": quant_config = RTNWeightOnlyQuantConfig() diff --git a/onnxruntime/test/python/quantization/test_op_matmul_4bits.py b/onnxruntime/test/python/quantization/test_op_matmul_4bits.py index 88e5052db4e2e..4cc8a0c151d14 100644 --- a/onnxruntime/test/python/quantization/test_op_matmul_4bits.py +++ b/onnxruntime/test/python/quantization/test_op_matmul_4bits.py @@ -14,7 +14,7 @@ import numpy as np import onnx from onnx import TensorProto, helper -from op_test_utils import TestDataFeeds, check_model_correctness, check_op_type_count +from op_test_utils import TestDataFeeds, check_model_correctness, check_op_type_count, check_qtype_by_node_type from onnxruntime.quantization import quant_utils @@ -105,8 +105,9 @@ def make_matmul( [output_tensor], initializer=initializers, ) - model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) - model.ir_version = 7 # use stable onnx ir version + # blocked quantization requires DQ op set >= 21 + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 21)]) + model.ir_version = 10 # use stable onnx ir version onnx.save(model, output_model_path) @@ -116,9 +117,12 @@ def quant_test( data_reader: TestDataFeeds, block_size: int, is_symmetric: bool, + quant_format: quant_utils.QuantFormat = quant_utils.QuantFormat.QOperator, ): + use_qdq = quant_format == quant_utils.QuantFormat.QDQ + name_prefix = "DQ_MatMul" if use_qdq else "MatMulNBits" model_int4_path = str( - Path(self._tmp_model_dir.name).joinpath(f"MatMulNBits_{block_size}_{is_symmetric}.onnx").absolute() + Path(self._tmp_model_dir.name).joinpath(f"{name_prefix}_{block_size}_{is_symmetric}.onnx").absolute() ) # Quantize fp32 model to int4 model @@ -126,15 +130,33 @@ def quant_test( model = quant_utils.load_model_with_shape_infer(Path(model_fp32_path)) quant_config = matmul_4bits_quantizer.DefaultWeightOnlyQuantConfig( - block_size=block_size, is_symmetric=is_symmetric + block_size=block_size, is_symmetric=is_symmetric, quant_format=quant_format ) quant = matmul_4bits_quantizer.MatMul4BitsQuantizer(model, algo_config=quant_config) quant.process() quant.model.save_model_to_file(model_int4_path, False) - quant_nodes = {"MatMulNBits": 1} + quant_nodes = {"DequantizeLinear": 1, "MatMul": 1} if use_qdq else {"MatMulNBits": 1} check_op_type_count(self, model_int4_path, **quant_nodes) + if use_qdq: + dq_qtype = TensorProto.INT4 if is_symmetric else TensorProto.UINT4 + dqnode_io_qtypes = ( + { + "DequantizeLinear": [ + ["i", 0, dq_qtype], + ] + } + if is_symmetric + else { + "DequantizeLinear": [ + ["i", 0, dq_qtype], + ["i", 2, dq_qtype], + ] + } + ) + check_qtype_by_node_type(self, model_int4_path, dqnode_io_qtypes) + data_reader.rewind() try: @@ -211,6 +233,26 @@ def 
test_quantize_matmul_int4_offsets(self): data_reader = self.input_feeds(1, {"input": [100, 52]}) self.quant_test(model_fp32_path, data_reader, 32, False) + @unittest.skipIf( + find_spec("onnxruntime.training"), "Skip because training package doesn't has quantize_matmul_4bits" + ) + def test_quantize_matmul_int4_symmetric_qdq(self): + np.random.seed(13) + + model_fp32_path = str(Path(self._tmp_model_dir.name).joinpath("matmul_fp32_symmetric.onnx").absolute()) + self.construct_model_matmul(model_fp32_path, symmetric=True) + data_reader = self.input_feeds(1, {"input": [100, 52]}) + self.quant_test(model_fp32_path, data_reader, 32, True, quant_utils.QuantFormat.QDQ) + + @unittest.skipIf( + find_spec("onnxruntime.training"), "Skip because training package doesn't has quantize_matmul_4bits" + ) + def test_quantize_matmul_int4_offsets_qdq(self): + model_fp32_path = str(Path(self._tmp_model_dir.name).joinpath("matmul_fp32_offset.onnx").absolute()) + self.construct_model_matmul(model_fp32_path, symmetric=False) + data_reader = self.input_feeds(1, {"input": [100, 52]}) + self.quant_test(model_fp32_path, data_reader, 32, False, quant_utils.QuantFormat.QDQ) + @unittest.skipIf( find_spec("onnxruntime.training"), "Skip because training package doesn't has quantize_matmul_4bits" )
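#### Usage sketch (QDQ mode)

A minimal sketch of the QDQ workflow described in the description above, mirroring the calls exercised in test_op_matmul_4bits.py. The model paths and block size are illustrative assumptions, not part of this change.

```python
from pathlib import Path

from onnxruntime.quantization import matmul_4bits_quantizer, quant_utils

# Hypothetical input/output paths.
model_fp32_path = "model_fp32.onnx"
model_int4_path = "model_int4_qdq.onnx"

# Load the fp32 model with shape inference.
model = quant_utils.load_model_with_shape_infer(Path(model_fp32_path))

# Configure 4-bit weight-only quantization in QDQ format: each eligible MatMul
# weight becomes an int4/uint4 initializer consumed by a DequantizeLinear node
# feeding a plain MatMul, which the ORT session can later fuse into MatMulNBits.
quant_config = matmul_4bits_quantizer.DefaultWeightOnlyQuantConfig(
    block_size=32,
    is_symmetric=True,
    quant_format=quant_utils.QuantFormat.QDQ,
)
quant = matmul_4bits_quantizer.MatMul4BitsQuantizer(model, algo_config=quant_config)
quant.process()
quant.model.save_model_to_file(model_int4_path, False)
```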