From 3274b34d361dbc5568bf43215fa4f915030f2f32 Mon Sep 17 00:00:00 2001 From: yuwenzho Date: Fri, 1 Sep 2023 15:37:17 +0800 Subject: [PATCH 01/19] add weight only quantize Signed-off-by: yuwenzho --- .../python/tools/quantization/__init__.py | 5 +- .../quantization/quantize_weight_only.py | 198 ++++++++++++++++++ .../quantization/test_quantize_weight_only.py | 127 +++++++++++ 3 files changed, 329 insertions(+), 1 deletion(-) create mode 100644 onnxruntime/python/tools/quantization/quantize_weight_only.py create mode 100644 onnxruntime/test/python/quantization/test_quantize_weight_only.py diff --git a/onnxruntime/python/tools/quantization/__init__.py b/onnxruntime/python/tools/quantization/__init__.py index 170c0928fee23..223268350e2c5 100644 --- a/onnxruntime/python/tools/quantization/__init__.py +++ b/onnxruntime/python/tools/quantization/__init__.py @@ -14,4 +14,7 @@ from .quantize import quantize # noqa: F401 from .quantize import quantize_dynamic # noqa: F401 from .quantize import quantize_static # noqa: F401 -from .shape_inference import quant_pre_process # noqa: F401 +from .quantize_weight_only import RTNWeightOnlyQuantConfig # noqa: F401 +from .quantize_weight_only import GPTQWeightOnlyQuantConfig # noqa: F401 +from .quantize_weight_only import quantize_weight_only # noqa: F401 +from .shape_inference import quant_pre_process # noqa: F401 \ No newline at end of file diff --git a/onnxruntime/python/tools/quantization/quantize_weight_only.py b/onnxruntime/python/tools/quantization/quantize_weight_only.py new file mode 100644 index 0000000000000..6c7fc4d63a00c --- /dev/null +++ b/onnxruntime/python/tools/quantization/quantize_weight_only.py @@ -0,0 +1,198 @@ +import copy +import logging +import importlib +from pathlib import Path +from .calibrate import CalibrationDataReader +from .quant_utils import load_model_with_shape_infer + +class WeightOnlyQuantConfig: + def __init__( + self, + algorithm, + group_size=32, + scheme="sym", + use_external_data_format=False, + ): + """This is the Base class for Weight Only Quant Configuration. + + Args: + algorithm: + weight only quantize algorithm name. + group_size: + how many elements share one scale/zp. -1 indicates the per-channel + quantization per output channel. + scheme: + symmetrize or asymmetric calibration data for weights. + use_external_data_format: + option used for large size (>2GB) model. Set to False by default. + """ + """This is the Base class for Weight Only Quant Configuration. + + Args: + algorithm: + weight only quantize algorithm name. + """ + self.algorithm = algorithm + self.group_size = group_size + self.scheme = scheme + self.use_external_data_format = use_external_data_format + +class RTNWeightOnlyQuantConfig(WeightOnlyQuantConfig): + def __init__( + self, + group_size=32, + scheme="sym", + use_external_data_format=False, + ): + """ + This is a class for round-to-nearest (RTN) algorithm Weight Only Quant Configuration. + RTN is the most straightforward way to quantize weight using scale maps. + + Args: + group_size: + how many elements share one scale/zp. -1 indicates the per-channel + quantization per output channel. + scheme: + symmetrize or asymmetric calibration data for weights. + use_external_data_format: + option used for large size (>2GB) model. Set to False by default. 
+ """ + super().__init__( + algorithm="RTN", + group_size=group_size, + scheme=scheme, + use_external_data_format=use_external_data_format + ) + +class GPTQWeightOnlyQuantConfig(WeightOnlyQuantConfig): + def __init__( + self, + calibration_data_reader: CalibrationDataReader, + group_size=32, + scheme="asym", + percdamp=.01, + blocksize=128, + actorder=False, + mse=False, + perchannel=True, + use_external_data_format=False, + ): + """ + This is a class for GPTQ algorithm Weight Only Quant Configuration. + GPTQ algorithm provides more accurate quantization but requires more computational resources. + + Args: + calibration_data_reader: + a calibration data reader. It enumerates calibration data and generates inputs for the original model. + group_size: + how many elements share one scale/zp. -1 indicates the per-channel + quantization per output channel. + scheme: + symmetrize or asymmetric calibration data for weights. + percdamp: + percent of the average Hessian diagonal to use for dampening. + blocksize (int, optional): + channel number in one block to execute a GPTQ quantization iteration. + actorder (bool, optional): + whether rearrange Hessian matrix considering the diag's value. + mse (bool, optional): + whether get scale and zero point with mse error. + perchannel (bool, optional): + whether quantize weight per-channel. + use_external_data_format: + option used for large size (>2GB) model. Set to False by default. + """ + super().__init__( + algorithm="GPTQ", + group_size=group_size, + scheme=scheme, + use_external_data_format=use_external_data_format, + ) + self.calibration_data_reader = calibration_data_reader + self.percdamp = percdamp + self.blocksize = blocksize + self.actorder = actorder + self.mse = mse + self.perchannel = perchannel + +def _generate_weight_only_node_config(model, group_size, scheme): + """Generate weight only quant configuration for nodes. + + Args: + model: + onnx.ModelProto. + group_size: + how many elements share one scale/zp. -1 indicates the per-channel + quantization per output channel. + scheme: + symmetrize or asymmetric calibration data for weights. + + Returns: + dict: weight only quant configuration for nodes. + """ + weight_only_node_config = {} + template_config = {'weight': {"bits": 4, "group_size": group_size, "scheme": scheme}} + for node in model.graph.node: + if node.op_type in ["MatMul"]: # TODO: enable Gemm op support + weight_only_node_config[node.name] = template_config + return weight_only_node_config + + +def quantize_weight_only( + model_input: Path, + model_output: Path, + weight_only_config: WeightOnlyQuantConfig, +): + """Weight Only Quantize a model with WeightOnlyQuantConfig. Please refer to + https://github.com/intel/neural-compressor/blob/master/docs/source/quantization_weight_only.md + for more details on weight only quantization. + + Args: + model_input (Path): Path to the model to weight only quantize. + model_output (Path): Path to save the quantized model. + weight_only_config (WeightOnlyQuantConfig): Weight Only Quantization Configuration. + + Raises: + RuntimeError: Raise RuntimeError if neural-compressor is not correctly installed. + """ + try: + importlib.import_module("neural_compressor.adaptor.ox_utils.weight_only") + except Exception as e: + logging.error(f"{e}.") + raise RuntimeError("neural-compressor is not correctly installed. 
Please check your environment.") from e + + def inc_dataloader(): + data_reader = copy.deepcopy(weight_only_config.calibration_data_reader) + for data in data_reader: + yield data, None + + model = load_model_with_shape_infer(Path(model_input)) + scheme = weight_only_config.scheme + group_size = weight_only_config.group_size + weight_only_node_config = _generate_weight_only_node_config(model, group_size, scheme) + + algorithm = weight_only_config.algorithm + if algorithm == "RTN": + from neural_compressor.adaptor.ox_utils.weight_only import rtn_quantize + model = rtn_quantize(model=model_input, + tune_cfg=weight_only_node_config) + elif algorithm == "GPTQ": + from neural_compressor.adaptor.ox_utils.weight_only import gptq_quantize + percdamp = weight_only_config.percdamp + blocksize = weight_only_config.blocksize + actorder = weight_only_config.actorder + mse = weight_only_config.mse + perchannel = weight_only_config.perchannel + dataloader = inc_dataloader() + + model = gptq_quantize(model=model_input, + tune_cfg=weight_only_node_config, + dataloader=dataloader, + n_samples=-1, + percdamp=percdamp, + blocksize=blocksize, + actorder=actorder, + mse=mse, + perchannel=perchannel) + + model.save_model_to_file(model_output, weight_only_config.use_external_data_format) diff --git a/onnxruntime/test/python/quantization/test_quantize_weight_only.py b/onnxruntime/test/python/quantization/test_quantize_weight_only.py new file mode 100644 index 0000000000000..5fb1201788623 --- /dev/null +++ b/onnxruntime/test/python/quantization/test_quantize_weight_only.py @@ -0,0 +1,127 @@ +#!/usr/bin/env python +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. 
+# -------------------------------------------------------------------------- + +import tempfile +import unittest +from importlib.util import find_spec +from pathlib import Path + +import numpy as np +import onnx +from onnx import TensorProto, helper +from onnxruntime.quantization.onnx_model import ONNXModel +from op_test_utils import check_model_correctness, input_feeds_neg_one_zero_one + +from onnxruntime.quantization import ( + RTNWeightOnlyQuantConfig, + GPTQWeightOnlyQuantConfig, + quantize_weight_only +) + +def construct_model(output_model_path): + # (input) + # | + # Mul + # | + # MatMul + # | + # (output) + initializers = [] + + # make mul node + mul_data = np.random.normal(0, 0.1, [1, 10]).astype(np.float32) + initializers.append(onnx.numpy_helper.from_array(mul_data, name="mul.data")) + mul_node = onnx.helper.make_node("Mul", ["input", "mul.data"], ["mul.output"], "Mul_0") + + # make matmul node + matmul_weight = np.random.normal(0, 0.1, [10, 1]).astype(np.float32) + initializers.append(onnx.numpy_helper.from_array(matmul_weight, name="matmul.weight")) + matmul_node = onnx.helper.make_node("MatMul", + ["mul.output", "matmul.weight"], + ["output"], + "MatMul_1") + + # make graph + input_tensor = helper.make_tensor_value_info("input", TensorProto.FLOAT, [1, 10]) + output_tensor = helper.make_tensor_value_info("output", TensorProto.FLOAT, [1, 1]) + graph_name = "weight_only_quant_test" + graph = helper.make_graph( + [mul_node, matmul_node], + graph_name, + [input_tensor], + [output_tensor], + initializer=initializers, + ) + model = helper.make_model( + graph, opset_imports=[helper.make_opsetid("", 13)] + ) + model.ir_version = onnx.IR_VERSION + + onnx.save(model, output_model_path) + +class TestWeightOnlyQuantization(unittest.TestCase): + @classmethod + def setUpClass(cls): + # TODO: there will be a refactor to handle all those temporary directories. + cls._tmp_model_dir = tempfile.TemporaryDirectory(prefix="ort.quant.save.as.external") + cls._model_fp32_path = str(Path(cls._tmp_model_dir.name) / "fp32.onnx") + cls._model_weight_only_path = str(Path(cls._tmp_model_dir.name) / "fp32.weight_only_quant.onnx") + np.random.seed(1) + construct_model(cls._model_fp32_path) + + @classmethod + def tearDownClass(cls): + cls._tmp_model_dir.cleanup() + + @unittest.skip( + "Skip failed test in Python Packaging Test Pipeline." + "During importing neural_compressor, pycocotools throws ValueError: numpy.ndarray size changed" + ) + def test_quantize_weight_only_rtn(self): + if not find_spec("neural_compressor"): + self.skipTest("skip test_quantize_weight_only_rtn since neural_compressor is not installed") + + weight_only_config = RTNWeightOnlyQuantConfig() + quantize_weight_only(self._model_fp32_path, self._model_weight_only_path, weight_only_config) + check_model_correctness( + self, + self._model_fp32_path, + self._model_weight_only_path, + {"input": np.random.rand(1, 10).astype(np.float32)}, + ) + + model_fp32 = ONNXModel(onnx.load(self._model_fp32_path)) + model_weight_only = ONNXModel(onnx.load(self._model_weight_only_path)) + self.assertNotEqual(model_fp32.get_initializer("matmul.weight"), + model_weight_only.get_initializer("matmul.weight")) + + + @unittest.skip( + "Skip failed test in Python Packaging Test Pipeline." 
+ "During importing neural_compressor, pycocotools throws ValueError: numpy.ndarray size changed" + ) + def test_quantize_weight_only_gptq(self): + if not find_spec("neural_compressor"): + self.skipTest("skip test_quantize_weight_only_gptq since neural_compressor is not installed") + + data_reader = input_feeds_neg_one_zero_one(10, {"input": [1, 10]}) + weight_only_config = GPTQWeightOnlyQuantConfig(data_reader) + quantize_weight_only(self._model_fp32_path, self._model_weight_only_path, weight_only_config) + check_model_correctness( + self, + self._model_fp32_path, + self._model_weight_only_path, + {"input": np.random.rand(1, 10).astype(np.float32)}, + ) + + model_fp32 = ONNXModel(onnx.load(self._model_fp32_path)) + model_weight_only = ONNXModel(onnx.load(self._model_weight_only_path)) + self.assertNotEqual(model_fp32.get_initializer("matmul.weight"), + model_weight_only.get_initializer("matmul.weight")) + +if __name__ == '__main__': + unittest.main() From 2f867e2657636af5ef592a8f5c26a06e4b293601 Mon Sep 17 00:00:00 2001 From: yuwenzho Date: Wed, 20 Sep 2023 17:35:21 +0800 Subject: [PATCH 02/19] update inc API usage Signed-off-by: yuwenzho --- .../quantization/quantize_weight_only.py | 20 ++++++++++++++----- .../quantization/test_quantize_weight_only.py | 12 +++++------ 2 files changed, 21 insertions(+), 11 deletions(-) diff --git a/onnxruntime/python/tools/quantization/quantize_weight_only.py b/onnxruntime/python/tools/quantization/quantize_weight_only.py index 6c7fc4d63a00c..2ad4c819270c3 100644 --- a/onnxruntime/python/tools/quantization/quantize_weight_only.py +++ b/onnxruntime/python/tools/quantization/quantize_weight_only.py @@ -2,6 +2,7 @@ import logging import importlib from pathlib import Path +from packaging import version from .calibrate import CalibrationDataReader from .quant_utils import load_model_with_shape_infer @@ -42,6 +43,7 @@ def __init__( self, group_size=32, scheme="sym", + ratios={}, use_external_data_format=False, ): """ @@ -63,6 +65,7 @@ def __init__( scheme=scheme, use_external_data_format=use_external_data_format ) + self.ratios = ratios class GPTQWeightOnlyQuantConfig(WeightOnlyQuantConfig): def __init__( @@ -131,9 +134,9 @@ def _generate_weight_only_node_config(model, group_size, scheme): dict: weight only quant configuration for nodes. """ weight_only_node_config = {} - template_config = {'weight': {"bits": 4, "group_size": group_size, "scheme": scheme}} + template_config = {"bits": 4, "group_size": group_size, "scheme": scheme} for node in model.graph.node: - if node.op_type in ["MatMul"]: # TODO: enable Gemm op support + if node.op_type in ["MatMul"]: weight_only_node_config[node.name] = template_config return weight_only_node_config @@ -156,11 +159,15 @@ def quantize_weight_only( RuntimeError: Raise RuntimeError if neural-compressor is not correctly installed. """ try: - importlib.import_module("neural_compressor.adaptor.ox_utils.weight_only") + importlib.import_module("neural_compressor") except Exception as e: logging.error(f"{e}.") raise RuntimeError("neural-compressor is not correctly installed. Please check your environment.") from e + import neural_compressor + assert version.parse(neural_compressor.__version__) >= version.parse("2.3.0"), \ + "Require neural-compressor >= 2.3.0 to support weight only quantization!" 
+ def inc_dataloader(): data_reader = copy.deepcopy(weight_only_config.calibration_data_reader) for data in data_reader: @@ -174,8 +181,11 @@ def inc_dataloader(): algorithm = weight_only_config.algorithm if algorithm == "RTN": from neural_compressor.adaptor.ox_utils.weight_only import rtn_quantize + ratios = weight_only_config.ratios + model = rtn_quantize(model=model_input, - tune_cfg=weight_only_node_config) + weight_config=weight_only_node_config, + ratios=ratios) elif algorithm == "GPTQ": from neural_compressor.adaptor.ox_utils.weight_only import gptq_quantize percdamp = weight_only_config.percdamp @@ -186,7 +196,7 @@ def inc_dataloader(): dataloader = inc_dataloader() model = gptq_quantize(model=model_input, - tune_cfg=weight_only_node_config, + weight_config=weight_only_node_config, dataloader=dataloader, n_samples=-1, percdamp=percdamp, diff --git a/onnxruntime/test/python/quantization/test_quantize_weight_only.py b/onnxruntime/test/python/quantization/test_quantize_weight_only.py index 5fb1201788623..e2146a98b0f43 100644 --- a/onnxruntime/test/python/quantization/test_quantize_weight_only.py +++ b/onnxruntime/test/python/quantization/test_quantize_weight_only.py @@ -33,12 +33,12 @@ def construct_model(output_model_path): initializers = [] # make mul node - mul_data = np.random.normal(0, 0.1, [1, 10]).astype(np.float32) + mul_data = np.random.normal(0, 0.1, [1, 32]).astype(np.float32) initializers.append(onnx.numpy_helper.from_array(mul_data, name="mul.data")) mul_node = onnx.helper.make_node("Mul", ["input", "mul.data"], ["mul.output"], "Mul_0") # make matmul node - matmul_weight = np.random.normal(0, 0.1, [10, 1]).astype(np.float32) + matmul_weight = np.random.normal(0, 0.1, [32, 1]).astype(np.float32) initializers.append(onnx.numpy_helper.from_array(matmul_weight, name="matmul.weight")) matmul_node = onnx.helper.make_node("MatMul", ["mul.output", "matmul.weight"], @@ -46,7 +46,7 @@ def construct_model(output_model_path): "MatMul_1") # make graph - input_tensor = helper.make_tensor_value_info("input", TensorProto.FLOAT, [1, 10]) + input_tensor = helper.make_tensor_value_info("input", TensorProto.FLOAT, [1, 32]) output_tensor = helper.make_tensor_value_info("output", TensorProto.FLOAT, [1, 1]) graph_name = "weight_only_quant_test" graph = helper.make_graph( @@ -91,7 +91,7 @@ def test_quantize_weight_only_rtn(self): self, self._model_fp32_path, self._model_weight_only_path, - {"input": np.random.rand(1, 10).astype(np.float32)}, + {"input": np.random.rand(1, 32).astype(np.float32)}, ) model_fp32 = ONNXModel(onnx.load(self._model_fp32_path)) @@ -108,14 +108,14 @@ def test_quantize_weight_only_gptq(self): if not find_spec("neural_compressor"): self.skipTest("skip test_quantize_weight_only_gptq since neural_compressor is not installed") - data_reader = input_feeds_neg_one_zero_one(10, {"input": [1, 10]}) + data_reader = input_feeds_neg_one_zero_one(10, {"input": [1, 32]}) weight_only_config = GPTQWeightOnlyQuantConfig(data_reader) quantize_weight_only(self._model_fp32_path, self._model_weight_only_path, weight_only_config) check_model_correctness( self, self._model_fp32_path, self._model_weight_only_path, - {"input": np.random.rand(1, 10).astype(np.float32)}, + {"input": np.random.rand(1, 32).astype(np.float32)}, ) model_fp32 = ONNXModel(onnx.load(self._model_fp32_path)) From 32f7eae9d714af8c772b0a1fdfedde9926e10fba Mon Sep 17 00:00:00 2001 From: yuwenzho Date: Tue, 10 Oct 2023 15:25:59 +0800 Subject: [PATCH 03/19] update format Signed-off-by: yuwenzho --- 
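For reference, a minimal sketch of how the entry point added in PATCH 01-02 is meant to be driven (PATCH 03 only reformats these files, and PATCH 05 later folds the RTN/GPTQ paths into MatMul4BitsQuantizer). The model paths and the data_reader object are placeholders, and neural-compressor >= 2.3.0 is assumed to be installed:

    from onnxruntime.quantization import (
        GPTQWeightOnlyQuantConfig,
        RTNWeightOnlyQuantConfig,
        quantize_weight_only,
    )

    # Round-to-nearest (RTN): no calibration data is required.
    rtn_config = RTNWeightOnlyQuantConfig(group_size=32, scheme="sym")
    quantize_weight_only("model_fp32.onnx", "model_rtn_quant.onnx", rtn_config)

    # GPTQ: needs a CalibrationDataReader that yields input feeds for the fp32 model.
    # `data_reader` stands in for such a reader instance (placeholder).
    gptq_config = GPTQWeightOnlyQuantConfig(calibration_data_reader=data_reader)
    quantize_weight_only("model_fp32.onnx", "model_gptq_quant.onnx", gptq_config)
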
.../python/tools/quantization/__init__.py | 8 +- .../quantization/quantize_weight_only.py | 89 ++++++++++--------- .../quantization/test_quantize_weight_only.py | 39 ++++---- 3 files changed, 70 insertions(+), 66 deletions(-) diff --git a/onnxruntime/python/tools/quantization/__init__.py b/onnxruntime/python/tools/quantization/__init__.py index 223268350e2c5..2766509187194 100644 --- a/onnxruntime/python/tools/quantization/__init__.py +++ b/onnxruntime/python/tools/quantization/__init__.py @@ -14,7 +14,7 @@ from .quantize import quantize # noqa: F401 from .quantize import quantize_dynamic # noqa: F401 from .quantize import quantize_static # noqa: F401 -from .quantize_weight_only import RTNWeightOnlyQuantConfig # noqa: F401 -from .quantize_weight_only import GPTQWeightOnlyQuantConfig # noqa: F401 -from .quantize_weight_only import quantize_weight_only # noqa: F401 -from .shape_inference import quant_pre_process # noqa: F401 \ No newline at end of file +from .quantize_weight_only import GPTQWeightOnlyQuantConfig # noqa: F401 +from .quantize_weight_only import RTNWeightOnlyQuantConfig # noqa: F401 +from .quantize_weight_only import quantize_weight_only # noqa: F401 +from .shape_inference import quant_pre_process # noqa: F401 diff --git a/onnxruntime/python/tools/quantization/quantize_weight_only.py b/onnxruntime/python/tools/quantization/quantize_weight_only.py index 2ad4c819270c3..516ad79731740 100644 --- a/onnxruntime/python/tools/quantization/quantize_weight_only.py +++ b/onnxruntime/python/tools/quantization/quantize_weight_only.py @@ -1,16 +1,19 @@ import copy -import logging import importlib +import logging from pathlib import Path + from packaging import version + from .calibrate import CalibrationDataReader from .quant_utils import load_model_with_shape_infer + class WeightOnlyQuantConfig: def __init__( - self, - algorithm, - group_size=32, + self, + algorithm, + group_size=32, scheme="sym", use_external_data_format=False, ): @@ -20,7 +23,7 @@ def __init__( algorithm: weight only quantize algorithm name. group_size: - how many elements share one scale/zp. -1 indicates the per-channel + how many elements share one scale/zp. -1 indicates the per-channel quantization per output channel. scheme: symmetrize or asymmetric calibration data for weights. @@ -30,7 +33,7 @@ def __init__( """This is the Base class for Weight Only Quant Configuration. Args: - algorithm: + algorithm: weight only quantize algorithm name. """ self.algorithm = algorithm @@ -38,42 +41,43 @@ def __init__( self.scheme = scheme self.use_external_data_format = use_external_data_format + class RTNWeightOnlyQuantConfig(WeightOnlyQuantConfig): def __init__( self, group_size=32, scheme="sym", - ratios={}, + ratios=None, use_external_data_format=False, ): """ This is a class for round-to-nearest (RTN) algorithm Weight Only Quant Configuration. - RTN is the most straightforward way to quantize weight using scale maps. + RTN is the most straightforward way to quantize weight using scale maps. Args: group_size: - how many elements share one scale/zp. -1 indicates the per-channel + how many elements share one scale/zp. -1 indicates the per-channel quantization per output channel. scheme: symmetrize or asymmetric calibration data for weights. use_external_data_format: option used for large size (>2GB) model. Set to False by default. 
""" + if ratios is None: + ratios = {} super().__init__( - algorithm="RTN", - group_size=group_size, - scheme=scheme, - use_external_data_format=use_external_data_format + algorithm="RTN", group_size=group_size, scheme=scheme, use_external_data_format=use_external_data_format ) self.ratios = ratios + class GPTQWeightOnlyQuantConfig(WeightOnlyQuantConfig): def __init__( - self, + self, calibration_data_reader: CalibrationDataReader, group_size=32, scheme="asym", - percdamp=.01, + percdamp=0.01, blocksize=128, actorder=False, mse=False, @@ -81,18 +85,18 @@ def __init__( use_external_data_format=False, ): """ - This is a class for GPTQ algorithm Weight Only Quant Configuration. + This is a class for GPTQ algorithm Weight Only Quant Configuration. GPTQ algorithm provides more accurate quantization but requires more computational resources. Args: - calibration_data_reader: + calibration_data_reader: a calibration data reader. It enumerates calibration data and generates inputs for the original model. group_size: - how many elements share one scale/zp. -1 indicates the per-channel + how many elements share one scale/zp. -1 indicates the per-channel quantization per output channel. scheme: symmetrize or asymmetric calibration data for weights. - percdamp: + percdamp: percent of the average Hessian diagonal to use for dampening. blocksize (int, optional): channel number in one block to execute a GPTQ quantization iteration. @@ -106,7 +110,7 @@ def __init__( option used for large size (>2GB) model. Set to False by default. """ super().__init__( - algorithm="GPTQ", + algorithm="GPTQ", group_size=group_size, scheme=scheme, use_external_data_format=use_external_data_format, @@ -118,6 +122,7 @@ def __init__( self.mse = mse self.perchannel = perchannel + def _generate_weight_only_node_config(model, group_size, scheme): """Generate weight only quant configuration for nodes. @@ -125,10 +130,10 @@ def _generate_weight_only_node_config(model, group_size, scheme): model: onnx.ModelProto. group_size: - how many elements share one scale/zp. -1 indicates the per-channel + how many elements share one scale/zp. -1 indicates the per-channel quantization per output channel. scheme: - symmetrize or asymmetric calibration data for weights. + symmetrize or asymmetric calibration data for weights. Returns: dict: weight only quant configuration for nodes. @@ -147,7 +152,7 @@ def quantize_weight_only( weight_only_config: WeightOnlyQuantConfig, ): """Weight Only Quantize a model with WeightOnlyQuantConfig. Please refer to - https://github.com/intel/neural-compressor/blob/master/docs/source/quantization_weight_only.md + https://github.com/intel/neural-compressor/blob/master/docs/source/quantization_weight_only.md for more details on weight only quantization. Args: @@ -163,11 +168,13 @@ def quantize_weight_only( except Exception as e: logging.error(f"{e}.") raise RuntimeError("neural-compressor is not correctly installed. Please check your environment.") from e - + import neural_compressor - assert version.parse(neural_compressor.__version__) >= version.parse("2.3.0"), \ - "Require neural-compressor >= 2.3.0 to support weight only quantization!" - + + assert version.parse(neural_compressor.__version__) >= version.parse( + "2.3.0" + ), "Require neural-compressor >= 2.3.0 to support weight only quantization!" 
+ def inc_dataloader(): data_reader = copy.deepcopy(weight_only_config.calibration_data_reader) for data in data_reader: @@ -181,13 +188,13 @@ def inc_dataloader(): algorithm = weight_only_config.algorithm if algorithm == "RTN": from neural_compressor.adaptor.ox_utils.weight_only import rtn_quantize + ratios = weight_only_config.ratios - model = rtn_quantize(model=model_input, - weight_config=weight_only_node_config, - ratios=ratios) + model = rtn_quantize(model=model_input, weight_config=weight_only_node_config, ratios=ratios) elif algorithm == "GPTQ": from neural_compressor.adaptor.ox_utils.weight_only import gptq_quantize + percdamp = weight_only_config.percdamp blocksize = weight_only_config.blocksize actorder = weight_only_config.actorder @@ -195,14 +202,16 @@ def inc_dataloader(): perchannel = weight_only_config.perchannel dataloader = inc_dataloader() - model = gptq_quantize(model=model_input, - weight_config=weight_only_node_config, - dataloader=dataloader, - n_samples=-1, - percdamp=percdamp, - blocksize=blocksize, - actorder=actorder, - mse=mse, - perchannel=perchannel) - + model = gptq_quantize( + model=model_input, + weight_config=weight_only_node_config, + dataloader=dataloader, + n_samples=-1, + percdamp=percdamp, + blocksize=blocksize, + actorder=actorder, + mse=mse, + perchannel=perchannel, + ) + model.save_model_to_file(model_output, weight_only_config.use_external_data_format) diff --git a/onnxruntime/test/python/quantization/test_quantize_weight_only.py b/onnxruntime/test/python/quantization/test_quantize_weight_only.py index e2146a98b0f43..88e057bc58557 100644 --- a/onnxruntime/test/python/quantization/test_quantize_weight_only.py +++ b/onnxruntime/test/python/quantization/test_quantize_weight_only.py @@ -13,14 +13,11 @@ import numpy as np import onnx from onnx import TensorProto, helper -from onnxruntime.quantization.onnx_model import ONNXModel from op_test_utils import check_model_correctness, input_feeds_neg_one_zero_one -from onnxruntime.quantization import ( - RTNWeightOnlyQuantConfig, - GPTQWeightOnlyQuantConfig, - quantize_weight_only -) +from onnxruntime.quantization import GPTQWeightOnlyQuantConfig, RTNWeightOnlyQuantConfig, quantize_weight_only +from onnxruntime.quantization.onnx_model import ONNXModel + def construct_model(output_model_path): # (input) @@ -31,7 +28,7 @@ def construct_model(output_model_path): # | # (output) initializers = [] - + # make mul node mul_data = np.random.normal(0, 0.1, [1, 32]).astype(np.float32) initializers.append(onnx.numpy_helper.from_array(mul_data, name="mul.data")) @@ -40,10 +37,7 @@ def construct_model(output_model_path): # make matmul node matmul_weight = np.random.normal(0, 0.1, [32, 1]).astype(np.float32) initializers.append(onnx.numpy_helper.from_array(matmul_weight, name="matmul.weight")) - matmul_node = onnx.helper.make_node("MatMul", - ["mul.output", "matmul.weight"], - ["output"], - "MatMul_1") + matmul_node = onnx.helper.make_node("MatMul", ["mul.output", "matmul.weight"], ["output"], "MatMul_1") # make graph input_tensor = helper.make_tensor_value_info("input", TensorProto.FLOAT, [1, 32]) @@ -56,13 +50,12 @@ def construct_model(output_model_path): [output_tensor], initializer=initializers, ) - model = helper.make_model( - graph, opset_imports=[helper.make_opsetid("", 13)] - ) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) model.ir_version = onnx.IR_VERSION onnx.save(model, output_model_path) + class TestWeightOnlyQuantization(unittest.TestCase): @classmethod def setUpClass(cls): @@ 
-84,7 +77,7 @@ def tearDownClass(cls): def test_quantize_weight_only_rtn(self): if not find_spec("neural_compressor"): self.skipTest("skip test_quantize_weight_only_rtn since neural_compressor is not installed") - + weight_only_config = RTNWeightOnlyQuantConfig() quantize_weight_only(self._model_fp32_path, self._model_weight_only_path, weight_only_config) check_model_correctness( @@ -96,9 +89,9 @@ def test_quantize_weight_only_rtn(self): model_fp32 = ONNXModel(onnx.load(self._model_fp32_path)) model_weight_only = ONNXModel(onnx.load(self._model_weight_only_path)) - self.assertNotEqual(model_fp32.get_initializer("matmul.weight"), - model_weight_only.get_initializer("matmul.weight")) - + self.assertNotEqual( + model_fp32.get_initializer("matmul.weight"), model_weight_only.get_initializer("matmul.weight") + ) @unittest.skip( "Skip failed test in Python Packaging Test Pipeline." @@ -107,7 +100,7 @@ def test_quantize_weight_only_rtn(self): def test_quantize_weight_only_gptq(self): if not find_spec("neural_compressor"): self.skipTest("skip test_quantize_weight_only_gptq since neural_compressor is not installed") - + data_reader = input_feeds_neg_one_zero_one(10, {"input": [1, 32]}) weight_only_config = GPTQWeightOnlyQuantConfig(data_reader) quantize_weight_only(self._model_fp32_path, self._model_weight_only_path, weight_only_config) @@ -120,8 +113,10 @@ def test_quantize_weight_only_gptq(self): model_fp32 = ONNXModel(onnx.load(self._model_fp32_path)) model_weight_only = ONNXModel(onnx.load(self._model_weight_only_path)) - self.assertNotEqual(model_fp32.get_initializer("matmul.weight"), - model_weight_only.get_initializer("matmul.weight")) + self.assertNotEqual( + model_fp32.get_initializer("matmul.weight"), model_weight_only.get_initializer("matmul.weight") + ) + -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() From a8382acc2d1e848ec36c2dcbb07857a6cdee7d47 Mon Sep 17 00:00:00 2001 From: "Wang, Mengni" Date: Fri, 10 Nov 2023 15:05:57 +0800 Subject: [PATCH 04/19] add accuracy_level attr --- .../quantization/quantize_weight_only.py | 26 +++++++++++++++++-- 1 file changed, 24 insertions(+), 2 deletions(-) diff --git a/onnxruntime/python/tools/quantization/quantize_weight_only.py b/onnxruntime/python/tools/quantization/quantize_weight_only.py index 516ad79731740..2638f31befdc0 100644 --- a/onnxruntime/python/tools/quantization/quantize_weight_only.py +++ b/onnxruntime/python/tools/quantization/quantize_weight_only.py @@ -15,6 +15,7 @@ def __init__( algorithm, group_size=32, scheme="sym", + accuracy_level=0, use_external_data_format=False, ): """This is the Base class for Weight Only Quant Configuration. @@ -27,6 +28,8 @@ def __init__( quantization per output channel. scheme: symmetrize or asymmetric calibration data for weights. + accuracy_level: + support 0 (default fp32), 1 (optimized fp32 for intel CPU), 2 (fp16), 3 (bf16), 4 (int8). Set to 0 by default. use_external_data_format: option used for large size (>2GB) model. Set to False by default. """ @@ -40,6 +43,7 @@ def __init__( self.group_size = group_size self.scheme = scheme self.use_external_data_format = use_external_data_format + self.accuracy_level = accuracy_level class RTNWeightOnlyQuantConfig(WeightOnlyQuantConfig): @@ -47,6 +51,7 @@ def __init__( self, group_size=32, scheme="sym", + accuracy_level=0, ratios=None, use_external_data_format=False, ): @@ -60,13 +65,19 @@ def __init__( quantization per output channel. scheme: symmetrize or asymmetric calibration data for weights. 
+ accuracy_level: + support 0 (default fp32), 1 (optimized fp32 for intel CPU), 2 (fp16), 3 (bf16), 4 (int8). Set to 0 by default. use_external_data_format: option used for large size (>2GB) model. Set to False by default. """ if ratios is None: ratios = {} super().__init__( - algorithm="RTN", group_size=group_size, scheme=scheme, use_external_data_format=use_external_data_format + algorithm="RTN", + group_size=group_size, + scheme=scheme, + accuracy_level=accuracy_level, + use_external_data_format=use_external_data_format ) self.ratios = ratios @@ -82,6 +93,7 @@ def __init__( actorder=False, mse=False, perchannel=True, + accuracy_level=0, use_external_data_format=False, ): """ @@ -106,6 +118,8 @@ def __init__( whether get scale and zero point with mse error. perchannel (bool, optional): whether quantize weight per-channel. + accuracy_level: + support 0 (default fp32), 1 (optimized fp32 for intel CPU), 2 (fp16), 3 (bf16), 4 (int8). Set to 0 by default. use_external_data_format: option used for large size (>2GB) model. Set to False by default. """ @@ -113,6 +127,7 @@ def __init__( algorithm="GPTQ", group_size=group_size, scheme=scheme, + accuracy_level=accuracy_level, use_external_data_format=use_external_data_format, ) self.calibration_data_reader = calibration_data_reader @@ -183,6 +198,7 @@ def inc_dataloader(): model = load_model_with_shape_infer(Path(model_input)) scheme = weight_only_config.scheme group_size = weight_only_config.group_size + accuracy_level = weight_only_config.accuracy_level weight_only_node_config = _generate_weight_only_node_config(model, group_size, scheme) algorithm = weight_only_config.algorithm @@ -191,7 +207,12 @@ def inc_dataloader(): ratios = weight_only_config.ratios - model = rtn_quantize(model=model_input, weight_config=weight_only_node_config, ratios=ratios) + model = rtn_quantize( + model=model_input, + weight_config=weight_only_node_config, + ratios=ratios, + accuracy_level=accuracy_level, + ) elif algorithm == "GPTQ": from neural_compressor.adaptor.ox_utils.weight_only import gptq_quantize @@ -212,6 +233,7 @@ def inc_dataloader(): actorder=actorder, mse=mse, perchannel=perchannel, + accuracy_level=accuracy_level, ) model.save_model_to_file(model_output, weight_only_config.use_external_data_format) From ee73812f1775087ee67d924fe4f8039c2b73f7b4 Mon Sep 17 00:00:00 2001 From: yuwenzho Date: Wed, 22 Nov 2023 20:17:22 +0800 Subject: [PATCH 05/19] update usage of RTN & GPTQ algorithm Signed-off-by: yuwenzho --- .../python/tools/quantization/__init__.py | 5 +- .../quantization/matmul_4bits_quantizer.py | 214 +++++++++++++++- .../quantization/quantize_weight_only.py | 239 ------------------ .../quantization/test_op_matmul_4bits.py | 76 +++++- .../quantization/test_quantize_weight_only.py | 122 --------- 5 files changed, 277 insertions(+), 379 deletions(-) delete mode 100644 onnxruntime/python/tools/quantization/quantize_weight_only.py delete mode 100644 onnxruntime/test/python/quantization/test_quantize_weight_only.py diff --git a/onnxruntime/python/tools/quantization/__init__.py b/onnxruntime/python/tools/quantization/__init__.py index 2766509187194..3bc055d1f063a 100644 --- a/onnxruntime/python/tools/quantization/__init__.py +++ b/onnxruntime/python/tools/quantization/__init__.py @@ -5,6 +5,8 @@ MinMaxCalibrater, create_calibrator, ) +from .matmul_4bits_quantizer import GPTQWeightOnlyQuantConfig # noqa: F401 +from .matmul_4bits_quantizer import RTNWeightOnlyQuantConfig # noqa: F401 from .matmul_weight4_quantizer import MatMulWeight4Quantizer # noqa: F401 
from .qdq_quantizer import QDQQuantizer # noqa: F401 from .quant_utils import QuantFormat, QuantType, write_calibration_table # noqa: F401 @@ -14,7 +16,4 @@ from .quantize import quantize # noqa: F401 from .quantize import quantize_dynamic # noqa: F401 from .quantize import quantize_static # noqa: F401 -from .quantize_weight_only import GPTQWeightOnlyQuantConfig # noqa: F401 -from .quantize_weight_only import RTNWeightOnlyQuantConfig # noqa: F401 -from .quantize_weight_only import quantize_weight_only # noqa: F401 from .shape_inference import quant_pre_process # noqa: F401 diff --git a/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py b/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py index 1c3c212b54fa4..fdcc73746e591 100644 --- a/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py +++ b/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py @@ -5,6 +5,8 @@ # -------------------------------------------------------------------------- import argparse +import copy +import importlib import logging import os from typing import List, Tuple @@ -13,9 +15,11 @@ import numpy.typing as npt import onnx from onnx.onnx_pb import GraphProto, ModelProto, NodeProto, TensorProto +from packaging import version from onnxruntime.capi._pybind_state import quantize_matmul_4bits +from .calibrate import CalibrationDataReader from .onnx_model import ONNXModel from .quant_utils import attribute_to_kwarg @@ -23,16 +27,122 @@ logger = logging.getLogger(__name__) +class WeightOnlyQuantConfig: + def __init__( + self, + algorithm, + model_path, + accuracy_level=0, + ): + """This is the Base class for Weight Only Quant Configuration. + + Args: + algorithm: + weight only quantize algorithm name. + model_path: + path of the model to do 4b quantization. + accuracy_level: + support 0 (default fp32), 1 (optimized fp32 for intel CPU), 2 (fp16), 3 (bf16), 4 (int8). Set to 0 by default. + """ + self.algorithm = algorithm + self.model_path = model_path + self.accuracy_level = accuracy_level + + +class RTNWeightOnlyQuantConfig(WeightOnlyQuantConfig): + def __init__( + self, + model_path, + accuracy_level=0, + ratios=None, + ): + """ + This is a class for round-to-nearest (RTN) algorithm Weight Only Quant Configuration. + RTN is the most straightforward way to quantize weight using scale maps. + + Args: + model_path: + path of the model to do 4b quantization. + accuracy_level: + support 0 (default fp32), 1 (optimized fp32 for intel CPU), 2 (fp16), 3 (bf16), 4 (int8). Set to 0 by default. + ratios: + percentile of clip. Defaults to {}. + """ + if ratios is None: + ratios = {} + super().__init__( + algorithm="RTN", + model_path=model_path, + accuracy_level=accuracy_level, + ) + self.ratios = ratios + + +class GPTQWeightOnlyQuantConfig(WeightOnlyQuantConfig): + def __init__( + self, + model_path, + calibration_data_reader: CalibrationDataReader, + percdamp=0.01, + blocksize=128, + actorder=False, + mse=False, + perchannel=True, + accuracy_level=0, + ): + """ + This is a class for GPTQ algorithm Weight Only Quant Configuration. + GPTQ algorithm provides more accurate quantization but requires more computational resources. + + Args: + model_path: + path of the model to do 4b quantization. + calibration_data_reader: + a calibration data reader. It enumerates calibration data and generates inputs for the original model. + percdamp: + percent of the average Hessian diagonal to use for dampening. 
+ blocksize (int, optional): + channel number in one block to execute a GPTQ quantization iteration. + actorder (bool, optional): + whether rearrange Hessian matrix considering the diag's value. + mse (bool, optional): + whether get scale and zero point with mse error. + perchannel (bool, optional): + whether quantize weight per-channel. + accuracy_level: + support 0 (default fp32), 1 (optimized fp32 for intel CPU), 2 (fp16), 3 (bf16), 4 (int8). Set to 0 by default. + """ + super().__init__( + algorithm="GPTQ", + model_path=model_path, + accuracy_level=accuracy_level, + ) + self.calibration_data_reader = calibration_data_reader + self.percdamp = percdamp + self.blocksize = blocksize + self.actorder = actorder + self.mse = mse + self.perchannel = perchannel + + class MatMul4BitsQuantizer: """Perform 4b quantization of constant MatMul weights""" - def __init__(self, model: ModelProto, block_size: int, is_symmetric: bool, nodes_to_exclude=None): + def __init__( + self, + model: ModelProto, + block_size: int, + is_symmetric: bool, + nodes_to_exclude=None, + algo_config: WeightOnlyQuantConfig = None, + ): if nodes_to_exclude is None: nodes_to_exclude = [] self.model = ONNXModel(model) self.block_size = block_size self.is_symmetric = is_symmetric self.nodes_to_exclude = set(nodes_to_exclude) + self.algo_config = algo_config @staticmethod def __get_initializer(name, graph_path: List[GraphProto]) -> Tuple[TensorProto, GraphProto]: @@ -165,20 +275,96 @@ def _process_subgraph(self, graph_stack: List[GraphProto]): graph_stack.pop() return graph + def _generate_q4_node_config(self): + """Generate weight only quant configuration for nodes.""" + q4_node_config = {} + template_config = {"bits": 4, "group_size": self.block_size, "scheme": "sym" if self.is_symmetric else "asym"} + for node in self.model.model.graph.node: + if node.op_type in ["MatMul"]: + q4_node_config[node.name] = template_config + return q4_node_config + + def int4_quant_algo(self): + """4b quantize a model with RTN or GPTQ algorithm. Please refer to + https://github.com/intel/neural-compressor/blob/master/docs/source/quantization_weight_only.md + for more details on weight only quantization using IntelĀ® Neural Compressor. 
+ """ + + def inc_dataloader(): + data_reader = copy.deepcopy(self.algo_config.calibration_data_reader) + for data in data_reader: + yield data, None + + accuracy_level = self.algo_config.accuracy_level + weight_only_node_config = self._generate_q4_node_config() + + algorithm = self.algo_config.algorithm + if algorithm == "RTN": + from neural_compressor.adaptor.ox_utils.weight_only import rtn_quantize + + ratios = self.algo_config.ratios + + self.model = rtn_quantize( + model=self.algo_config.model_path, + weight_config=weight_only_node_config, + ratios=ratios, + accuracy_level=accuracy_level, + ) + elif algorithm == "GPTQ": + from neural_compressor.adaptor.ox_utils.weight_only import gptq_quantize + + percdamp = self.algo_config.percdamp + blocksize = self.algo_config.blocksize + actorder = self.algo_config.actorder + mse = self.algo_config.mse + perchannel = self.algo_config.perchannel + dataloader = inc_dataloader() + + self.model = gptq_quantize( + model=self.algo_config.model_path, + weight_config=weight_only_node_config, + dataloader=dataloader, + n_samples=-1, + percdamp=percdamp, + blocksize=blocksize, + actorder=actorder, + mse=mse, + perchannel=perchannel, + accuracy_level=accuracy_level, + ) + def process(self): - # use a stack to keep track of sub-graphs - graph_stack = [self.model.graph()] - opset_import = self.model.opset_import() - - has_ms_domain = False - for opset in opset_import: - if opset.domain == "com.microsoft": - has_ms_domain = True - if not has_ms_domain: - opset_import.extend([onnx.helper.make_opsetid("com.microsoft", 1)]) - - self._process_subgraph(graph_stack) - self.model.clean_initializers() + if self.algo_config is None: + # use a stack to keep track of sub-graphs + graph_stack = [self.model.graph()] + opset_import = self.model.opset_import() + + has_ms_domain = False + for opset in opset_import: + if opset.domain == "com.microsoft": + has_ms_domain = True + if not has_ms_domain: + opset_import.extend([onnx.helper.make_opsetid("com.microsoft", 1)]) + + self._process_subgraph(graph_stack) + self.model.clean_initializers() + else: + # use IntelĀ® Neural Compressor for RTN or GPTQ weight-only quantize algorithm + try: + importlib.import_module("neural_compressor") + except Exception as e: + logging.error(f"{e}.") + raise RuntimeError( + "neural-compressor is not correctly installed. Please check your environment." + ) from e + + import neural_compressor + + assert version.parse(neural_compressor.__version__) >= version.parse( + "2.3.0" + ), "Require neural-compressor >= 2.3.0 to support weight only quantization!" + + self.int4_quant_algo() def parse_args(): diff --git a/onnxruntime/python/tools/quantization/quantize_weight_only.py b/onnxruntime/python/tools/quantization/quantize_weight_only.py deleted file mode 100644 index 2638f31befdc0..0000000000000 --- a/onnxruntime/python/tools/quantization/quantize_weight_only.py +++ /dev/null @@ -1,239 +0,0 @@ -import copy -import importlib -import logging -from pathlib import Path - -from packaging import version - -from .calibrate import CalibrationDataReader -from .quant_utils import load_model_with_shape_infer - - -class WeightOnlyQuantConfig: - def __init__( - self, - algorithm, - group_size=32, - scheme="sym", - accuracy_level=0, - use_external_data_format=False, - ): - """This is the Base class for Weight Only Quant Configuration. - - Args: - algorithm: - weight only quantize algorithm name. - group_size: - how many elements share one scale/zp. -1 indicates the per-channel - quantization per output channel. 
- scheme: - symmetrize or asymmetric calibration data for weights. - accuracy_level: - support 0 (default fp32), 1 (optimized fp32 for intel CPU), 2 (fp16), 3 (bf16), 4 (int8). Set to 0 by default. - use_external_data_format: - option used for large size (>2GB) model. Set to False by default. - """ - """This is the Base class for Weight Only Quant Configuration. - - Args: - algorithm: - weight only quantize algorithm name. - """ - self.algorithm = algorithm - self.group_size = group_size - self.scheme = scheme - self.use_external_data_format = use_external_data_format - self.accuracy_level = accuracy_level - - -class RTNWeightOnlyQuantConfig(WeightOnlyQuantConfig): - def __init__( - self, - group_size=32, - scheme="sym", - accuracy_level=0, - ratios=None, - use_external_data_format=False, - ): - """ - This is a class for round-to-nearest (RTN) algorithm Weight Only Quant Configuration. - RTN is the most straightforward way to quantize weight using scale maps. - - Args: - group_size: - how many elements share one scale/zp. -1 indicates the per-channel - quantization per output channel. - scheme: - symmetrize or asymmetric calibration data for weights. - accuracy_level: - support 0 (default fp32), 1 (optimized fp32 for intel CPU), 2 (fp16), 3 (bf16), 4 (int8). Set to 0 by default. - use_external_data_format: - option used for large size (>2GB) model. Set to False by default. - """ - if ratios is None: - ratios = {} - super().__init__( - algorithm="RTN", - group_size=group_size, - scheme=scheme, - accuracy_level=accuracy_level, - use_external_data_format=use_external_data_format - ) - self.ratios = ratios - - -class GPTQWeightOnlyQuantConfig(WeightOnlyQuantConfig): - def __init__( - self, - calibration_data_reader: CalibrationDataReader, - group_size=32, - scheme="asym", - percdamp=0.01, - blocksize=128, - actorder=False, - mse=False, - perchannel=True, - accuracy_level=0, - use_external_data_format=False, - ): - """ - This is a class for GPTQ algorithm Weight Only Quant Configuration. - GPTQ algorithm provides more accurate quantization but requires more computational resources. - - Args: - calibration_data_reader: - a calibration data reader. It enumerates calibration data and generates inputs for the original model. - group_size: - how many elements share one scale/zp. -1 indicates the per-channel - quantization per output channel. - scheme: - symmetrize or asymmetric calibration data for weights. - percdamp: - percent of the average Hessian diagonal to use for dampening. - blocksize (int, optional): - channel number in one block to execute a GPTQ quantization iteration. - actorder (bool, optional): - whether rearrange Hessian matrix considering the diag's value. - mse (bool, optional): - whether get scale and zero point with mse error. - perchannel (bool, optional): - whether quantize weight per-channel. - accuracy_level: - support 0 (default fp32), 1 (optimized fp32 for intel CPU), 2 (fp16), 3 (bf16), 4 (int8). Set to 0 by default. - use_external_data_format: - option used for large size (>2GB) model. Set to False by default. 
- """ - super().__init__( - algorithm="GPTQ", - group_size=group_size, - scheme=scheme, - accuracy_level=accuracy_level, - use_external_data_format=use_external_data_format, - ) - self.calibration_data_reader = calibration_data_reader - self.percdamp = percdamp - self.blocksize = blocksize - self.actorder = actorder - self.mse = mse - self.perchannel = perchannel - - -def _generate_weight_only_node_config(model, group_size, scheme): - """Generate weight only quant configuration for nodes. - - Args: - model: - onnx.ModelProto. - group_size: - how many elements share one scale/zp. -1 indicates the per-channel - quantization per output channel. - scheme: - symmetrize or asymmetric calibration data for weights. - - Returns: - dict: weight only quant configuration for nodes. - """ - weight_only_node_config = {} - template_config = {"bits": 4, "group_size": group_size, "scheme": scheme} - for node in model.graph.node: - if node.op_type in ["MatMul"]: - weight_only_node_config[node.name] = template_config - return weight_only_node_config - - -def quantize_weight_only( - model_input: Path, - model_output: Path, - weight_only_config: WeightOnlyQuantConfig, -): - """Weight Only Quantize a model with WeightOnlyQuantConfig. Please refer to - https://github.com/intel/neural-compressor/blob/master/docs/source/quantization_weight_only.md - for more details on weight only quantization. - - Args: - model_input (Path): Path to the model to weight only quantize. - model_output (Path): Path to save the quantized model. - weight_only_config (WeightOnlyQuantConfig): Weight Only Quantization Configuration. - - Raises: - RuntimeError: Raise RuntimeError if neural-compressor is not correctly installed. - """ - try: - importlib.import_module("neural_compressor") - except Exception as e: - logging.error(f"{e}.") - raise RuntimeError("neural-compressor is not correctly installed. Please check your environment.") from e - - import neural_compressor - - assert version.parse(neural_compressor.__version__) >= version.parse( - "2.3.0" - ), "Require neural-compressor >= 2.3.0 to support weight only quantization!" 
- - def inc_dataloader(): - data_reader = copy.deepcopy(weight_only_config.calibration_data_reader) - for data in data_reader: - yield data, None - - model = load_model_with_shape_infer(Path(model_input)) - scheme = weight_only_config.scheme - group_size = weight_only_config.group_size - accuracy_level = weight_only_config.accuracy_level - weight_only_node_config = _generate_weight_only_node_config(model, group_size, scheme) - - algorithm = weight_only_config.algorithm - if algorithm == "RTN": - from neural_compressor.adaptor.ox_utils.weight_only import rtn_quantize - - ratios = weight_only_config.ratios - - model = rtn_quantize( - model=model_input, - weight_config=weight_only_node_config, - ratios=ratios, - accuracy_level=accuracy_level, - ) - elif algorithm == "GPTQ": - from neural_compressor.adaptor.ox_utils.weight_only import gptq_quantize - - percdamp = weight_only_config.percdamp - blocksize = weight_only_config.blocksize - actorder = weight_only_config.actorder - mse = weight_only_config.mse - perchannel = weight_only_config.perchannel - dataloader = inc_dataloader() - - model = gptq_quantize( - model=model_input, - weight_config=weight_only_node_config, - dataloader=dataloader, - n_samples=-1, - percdamp=percdamp, - blocksize=blocksize, - actorder=actorder, - mse=mse, - perchannel=perchannel, - accuracy_level=accuracy_level, - ) - - model.save_model_to_file(model_output, weight_only_config.use_external_data_format) diff --git a/onnxruntime/test/python/quantization/test_op_matmul_4bits.py b/onnxruntime/test/python/quantization/test_op_matmul_4bits.py index 02f51cc4fa809..2552f32eee608 100644 --- a/onnxruntime/test/python/quantization/test_op_matmul_4bits.py +++ b/onnxruntime/test/python/quantization/test_op_matmul_4bits.py @@ -71,13 +71,16 @@ def construct_model_matmul(self, output_model_path: str, symmetric: bool) -> Non output_name = "output" initializers = [] - def make_matmul(input_name, weight_shape: Union[int, Tuple[int, ...]], weight_name: str, output_name: str): + def make_matmul( + input_name, weight_shape: Union[int, Tuple[int, ...]], weight_name: str, output_name: str, node_name: str + ): weight_data = self.fill_int4_data(weight_shape, symmetric).astype(np.float32) initializers.append(onnx.numpy_helper.from_array(weight_data, name=weight_name)) return onnx.helper.make_node( "MatMul", [input_name, weight_name], [output_name], + node_name, ) in_features = 52 @@ -88,6 +91,7 @@ def make_matmul(input_name, weight_shape: Union[int, Tuple[int, ...]], weight_na [in_features, out_features], "linear1.weight", output_name, + "MatMul_0", ) # make graph @@ -139,6 +143,52 @@ def quant_test( else: raise exception + def quant_test_with_algo( + self, + algorithm: str, + model_fp32_path: str, + data_reader: TestDataFeeds, + block_size: int, + is_symmetric: bool, + ): + model_int4_path = str( + Path(self._tmp_model_dir.name).joinpath(f"MatMulNBits_{block_size}_{is_symmetric}.onnx").absolute() + ) + + # Quantize fp32 model to int4 model + from onnxruntime.quantization import matmul_4bits_quantizer + + if algorithm == "RTN": + # test RTN algorithm + from onnxruntime.quantization import RTNWeightOnlyQuantConfig + + algo_config = RTNWeightOnlyQuantConfig(model_path=model_fp32_path) + elif algorithm == "GPTQ": + # test GPTQ algorithm + print("=" * 50) + from onnxruntime.quantization import GPTQWeightOnlyQuantConfig + + algo_config = GPTQWeightOnlyQuantConfig(model_path=model_fp32_path, calibration_data_reader=data_reader) + + model = 
quant_utils.load_model_with_shape_infer(Path(model_fp32_path)) + quant = matmul_4bits_quantizer.MatMul4BitsQuantizer(model, block_size, is_symmetric, algo_config=algo_config) + quant.process() + quant.model.save_model_to_file(model_int4_path, False) + + quant_nodes = {"MatMulNBits": 1} + check_op_type_count(self, model_int4_path, **quant_nodes) + + data_reader.rewind() + + try: + check_model_correctness(self, model_fp32_path, model_int4_path, data_reader.get_next()) + except Exception as exception: + if "4b quantization not yet supported on this hardware platform!" in exception.args[0]: + # Currently we don't have int4 quantization support on all platforms, has to tolerate this exception + pass + else: + raise exception + @unittest.skipIf( find_spec("onnxruntime.training"), "Skip because training package doesn't has quantize_matmul_4bits" ) @@ -159,6 +209,30 @@ def test_quantize_matmul_int4_offsets(self): data_reader = self.input_feeds(1, {"input": [100, 52]}) self.quant_test(model_fp32_path, data_reader, 32, False) + @unittest.skip( + "Skip failed test in Python Packaging Test Pipeline." + "During importing neural_compressor, pycocotools throws ValueError: numpy.ndarray size changed" + ) + def test_quantize_matmul_int4_using_rtn_algo(self): + if not find_spec("neural_compressor"): + self.skipTest("skip test_smooth_quant since neural_compressor is not installed") + model_fp32_path = str(Path(self._tmp_model_dir.name).joinpath("matmul_fp32_offset.onnx").absolute()) + self.construct_model_matmul(model_fp32_path, symmetric=False) + data_reader = self.input_feeds(1, {"input": [100, 52]}) + self.quant_test_with_algo("RTN", model_fp32_path, data_reader, 32, False) + + @unittest.skip( + "Skip failed test in Python Packaging Test Pipeline." + "During importing neural_compressor, pycocotools throws ValueError: numpy.ndarray size changed" + ) + def test_quantize_matmul_int4_using_gptq_algo(self): + if not find_spec("neural_compressor"): + self.skipTest("skip test_smooth_quant since neural_compressor is not installed") + model_fp32_path = str(Path(self._tmp_model_dir.name).joinpath("matmul_fp32_offset.onnx").absolute()) + self.construct_model_matmul(model_fp32_path, symmetric=False) + data_reader = self.input_feeds(1, {"input": [100, 52]}) + self.quant_test_with_algo("GPTQ", model_fp32_path, data_reader, 32, False) + if __name__ == "__main__": unittest.main() diff --git a/onnxruntime/test/python/quantization/test_quantize_weight_only.py b/onnxruntime/test/python/quantization/test_quantize_weight_only.py deleted file mode 100644 index 88e057bc58557..0000000000000 --- a/onnxruntime/test/python/quantization/test_quantize_weight_only.py +++ /dev/null @@ -1,122 +0,0 @@ -#!/usr/bin/env python -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. 
-# -------------------------------------------------------------------------- - -import tempfile -import unittest -from importlib.util import find_spec -from pathlib import Path - -import numpy as np -import onnx -from onnx import TensorProto, helper -from op_test_utils import check_model_correctness, input_feeds_neg_one_zero_one - -from onnxruntime.quantization import GPTQWeightOnlyQuantConfig, RTNWeightOnlyQuantConfig, quantize_weight_only -from onnxruntime.quantization.onnx_model import ONNXModel - - -def construct_model(output_model_path): - # (input) - # | - # Mul - # | - # MatMul - # | - # (output) - initializers = [] - - # make mul node - mul_data = np.random.normal(0, 0.1, [1, 32]).astype(np.float32) - initializers.append(onnx.numpy_helper.from_array(mul_data, name="mul.data")) - mul_node = onnx.helper.make_node("Mul", ["input", "mul.data"], ["mul.output"], "Mul_0") - - # make matmul node - matmul_weight = np.random.normal(0, 0.1, [32, 1]).astype(np.float32) - initializers.append(onnx.numpy_helper.from_array(matmul_weight, name="matmul.weight")) - matmul_node = onnx.helper.make_node("MatMul", ["mul.output", "matmul.weight"], ["output"], "MatMul_1") - - # make graph - input_tensor = helper.make_tensor_value_info("input", TensorProto.FLOAT, [1, 32]) - output_tensor = helper.make_tensor_value_info("output", TensorProto.FLOAT, [1, 1]) - graph_name = "weight_only_quant_test" - graph = helper.make_graph( - [mul_node, matmul_node], - graph_name, - [input_tensor], - [output_tensor], - initializer=initializers, - ) - model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) - model.ir_version = onnx.IR_VERSION - - onnx.save(model, output_model_path) - - -class TestWeightOnlyQuantization(unittest.TestCase): - @classmethod - def setUpClass(cls): - # TODO: there will be a refactor to handle all those temporary directories. - cls._tmp_model_dir = tempfile.TemporaryDirectory(prefix="ort.quant.save.as.external") - cls._model_fp32_path = str(Path(cls._tmp_model_dir.name) / "fp32.onnx") - cls._model_weight_only_path = str(Path(cls._tmp_model_dir.name) / "fp32.weight_only_quant.onnx") - np.random.seed(1) - construct_model(cls._model_fp32_path) - - @classmethod - def tearDownClass(cls): - cls._tmp_model_dir.cleanup() - - @unittest.skip( - "Skip failed test in Python Packaging Test Pipeline." - "During importing neural_compressor, pycocotools throws ValueError: numpy.ndarray size changed" - ) - def test_quantize_weight_only_rtn(self): - if not find_spec("neural_compressor"): - self.skipTest("skip test_quantize_weight_only_rtn since neural_compressor is not installed") - - weight_only_config = RTNWeightOnlyQuantConfig() - quantize_weight_only(self._model_fp32_path, self._model_weight_only_path, weight_only_config) - check_model_correctness( - self, - self._model_fp32_path, - self._model_weight_only_path, - {"input": np.random.rand(1, 32).astype(np.float32)}, - ) - - model_fp32 = ONNXModel(onnx.load(self._model_fp32_path)) - model_weight_only = ONNXModel(onnx.load(self._model_weight_only_path)) - self.assertNotEqual( - model_fp32.get_initializer("matmul.weight"), model_weight_only.get_initializer("matmul.weight") - ) - - @unittest.skip( - "Skip failed test in Python Packaging Test Pipeline." 
- "During importing neural_compressor, pycocotools throws ValueError: numpy.ndarray size changed" - ) - def test_quantize_weight_only_gptq(self): - if not find_spec("neural_compressor"): - self.skipTest("skip test_quantize_weight_only_gptq since neural_compressor is not installed") - - data_reader = input_feeds_neg_one_zero_one(10, {"input": [1, 32]}) - weight_only_config = GPTQWeightOnlyQuantConfig(data_reader) - quantize_weight_only(self._model_fp32_path, self._model_weight_only_path, weight_only_config) - check_model_correctness( - self, - self._model_fp32_path, - self._model_weight_only_path, - {"input": np.random.rand(1, 32).astype(np.float32)}, - ) - - model_fp32 = ONNXModel(onnx.load(self._model_fp32_path)) - model_weight_only = ONNXModel(onnx.load(self._model_weight_only_path)) - self.assertNotEqual( - model_fp32.get_initializer("matmul.weight"), model_weight_only.get_initializer("matmul.weight") - ) - - -if __name__ == "__main__": - unittest.main() From 21fc0c4cb7e8f0ae7d84e1623089d76382a95d1f Mon Sep 17 00:00:00 2001 From: yuwenzho Date: Wed, 22 Nov 2023 20:19:01 +0800 Subject: [PATCH 06/19] fix typo Signed-off-by: yuwenzho --- onnxruntime/test/python/quantization/test_op_matmul_4bits.py | 1 - 1 file changed, 1 deletion(-) diff --git a/onnxruntime/test/python/quantization/test_op_matmul_4bits.py b/onnxruntime/test/python/quantization/test_op_matmul_4bits.py index 2552f32eee608..accb0847930af 100644 --- a/onnxruntime/test/python/quantization/test_op_matmul_4bits.py +++ b/onnxruntime/test/python/quantization/test_op_matmul_4bits.py @@ -165,7 +165,6 @@ def quant_test_with_algo( algo_config = RTNWeightOnlyQuantConfig(model_path=model_fp32_path) elif algorithm == "GPTQ": # test GPTQ algorithm - print("=" * 50) from onnxruntime.quantization import GPTQWeightOnlyQuantConfig algo_config = GPTQWeightOnlyQuantConfig(model_path=model_fp32_path, calibration_data_reader=data_reader) From b2b9d6636d9c8c554651cb696f9297de4546e3af Mon Sep 17 00:00:00 2001 From: yuwenzho Date: Thu, 23 Nov 2023 15:39:34 +0800 Subject: [PATCH 07/19] update usage of RTN & GPTQ algorithm Signed-off-by: yuwenzho --- .../quantization/matmul_4bits_quantizer.py | 24 +++++++------------ .../quantization/test_op_matmul_4bits.py | 4 ++-- 2 files changed, 10 insertions(+), 18 deletions(-) diff --git a/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py b/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py index fdcc73746e591..f7684c215711d 100644 --- a/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py +++ b/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py @@ -31,7 +31,6 @@ class WeightOnlyQuantConfig: def __init__( self, algorithm, - model_path, accuracy_level=0, ): """This is the Base class for Weight Only Quant Configuration. @@ -39,20 +38,16 @@ def __init__( Args: algorithm: weight only quantize algorithm name. - model_path: - path of the model to do 4b quantization. accuracy_level: support 0 (default fp32), 1 (optimized fp32 for intel CPU), 2 (fp16), 3 (bf16), 4 (int8). Set to 0 by default. """ self.algorithm = algorithm - self.model_path = model_path self.accuracy_level = accuracy_level class RTNWeightOnlyQuantConfig(WeightOnlyQuantConfig): def __init__( self, - model_path, accuracy_level=0, ratios=None, ): @@ -61,8 +56,6 @@ def __init__( RTN is the most straightforward way to quantize weight using scale maps. Args: - model_path: - path of the model to do 4b quantization. 
accuracy_level: support 0 (default fp32), 1 (optimized fp32 for intel CPU), 2 (fp16), 3 (bf16), 4 (int8). Set to 0 by default. ratios: @@ -72,7 +65,6 @@ def __init__( ratios = {} super().__init__( algorithm="RTN", - model_path=model_path, accuracy_level=accuracy_level, ) self.ratios = ratios @@ -81,7 +73,6 @@ def __init__( class GPTQWeightOnlyQuantConfig(WeightOnlyQuantConfig): def __init__( self, - model_path, calibration_data_reader: CalibrationDataReader, percdamp=0.01, blocksize=128, @@ -95,8 +86,6 @@ def __init__( GPTQ algorithm provides more accurate quantization but requires more computational resources. Args: - model_path: - path of the model to do 4b quantization. calibration_data_reader: a calibration data reader. It enumerates calibration data and generates inputs for the original model. percdamp: @@ -114,7 +103,6 @@ def __init__( """ super().__init__( algorithm="GPTQ", - model_path=model_path, accuracy_level=accuracy_level, ) self.calibration_data_reader = calibration_data_reader @@ -278,10 +266,14 @@ def _process_subgraph(self, graph_stack: List[GraphProto]): def _generate_q4_node_config(self): """Generate weight only quant configuration for nodes.""" q4_node_config = {} - template_config = {"bits": 4, "group_size": self.block_size, "scheme": "sym" if self.is_symmetric else "asym"} + template_config_q4 = {"bits": 4, "group_size": self.block_size, "scheme": "sym" if self.is_symmetric else "asym"} + template_config_fp32 = 'fp32' for node in self.model.model.graph.node: if node.op_type in ["MatMul"]: - q4_node_config[node.name] = template_config + if not all([self.model.get_initializer(i) is None for i in node.input]): + q4_node_config[node.name] = template_config_q4 + else: + q4_node_config[node.name] = template_config_fp32 return q4_node_config def int4_quant_algo(self): @@ -305,7 +297,7 @@ def inc_dataloader(): ratios = self.algo_config.ratios self.model = rtn_quantize( - model=self.algo_config.model_path, + model=self.model.model, weight_config=weight_only_node_config, ratios=ratios, accuracy_level=accuracy_level, @@ -321,7 +313,7 @@ def inc_dataloader(): dataloader = inc_dataloader() self.model = gptq_quantize( - model=self.algo_config.model_path, + model=self.model.model, weight_config=weight_only_node_config, dataloader=dataloader, n_samples=-1, diff --git a/onnxruntime/test/python/quantization/test_op_matmul_4bits.py b/onnxruntime/test/python/quantization/test_op_matmul_4bits.py index accb0847930af..6ec9903362531 100644 --- a/onnxruntime/test/python/quantization/test_op_matmul_4bits.py +++ b/onnxruntime/test/python/quantization/test_op_matmul_4bits.py @@ -162,12 +162,12 @@ def quant_test_with_algo( # test RTN algorithm from onnxruntime.quantization import RTNWeightOnlyQuantConfig - algo_config = RTNWeightOnlyQuantConfig(model_path=model_fp32_path) + algo_config = RTNWeightOnlyQuantConfig() elif algorithm == "GPTQ": # test GPTQ algorithm from onnxruntime.quantization import GPTQWeightOnlyQuantConfig - algo_config = GPTQWeightOnlyQuantConfig(model_path=model_fp32_path, calibration_data_reader=data_reader) + algo_config = GPTQWeightOnlyQuantConfig(calibration_data_reader=data_reader) model = quant_utils.load_model_with_shape_infer(Path(model_fp32_path)) quant = matmul_4bits_quantizer.MatMul4BitsQuantizer(model, block_size, is_symmetric, algo_config=algo_config) From f3a91cad00f9edc48a869cc194e47908a97bc94a Mon Sep 17 00:00:00 2001 From: "Wang, Mengni" Date: Thu, 23 Nov 2023 16:09:27 +0800 Subject: [PATCH 08/19] Update matmul_4bits_quantizer.py --- 
.../python/tools/quantization/matmul_4bits_quantizer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py b/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py index f7684c215711d..9b537162cb381 100644 --- a/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py +++ b/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py @@ -353,8 +353,8 @@ def process(self): import neural_compressor assert version.parse(neural_compressor.__version__) >= version.parse( - "2.3.0" - ), "Require neural-compressor >= 2.3.0 to support weight only quantization!" + "2.3.2" + ), "Require neural-compressor >= 2.3.2 to support weight only quantization!" self.int4_quant_algo() From fd02c615d0644f1669262af996534db2455c6a47 Mon Sep 17 00:00:00 2001 From: yuwenzho Date: Thu, 23 Nov 2023 16:13:36 +0800 Subject: [PATCH 09/19] fix for code scan Signed-off-by: yuwenzho --- onnxruntime/test/python/quantization/test_op_matmul_4bits.py | 1 + 1 file changed, 1 insertion(+) diff --git a/onnxruntime/test/python/quantization/test_op_matmul_4bits.py b/onnxruntime/test/python/quantization/test_op_matmul_4bits.py index 6ec9903362531..fbd91c71451cc 100644 --- a/onnxruntime/test/python/quantization/test_op_matmul_4bits.py +++ b/onnxruntime/test/python/quantization/test_op_matmul_4bits.py @@ -158,6 +158,7 @@ def quant_test_with_algo( # Quantize fp32 model to int4 model from onnxruntime.quantization import matmul_4bits_quantizer + algo_config = None if algorithm == "RTN": # test RTN algorithm from onnxruntime.quantization import RTNWeightOnlyQuantConfig From 673d1d4596c712d7450d6ce9fe07fd38a9184a35 Mon Sep 17 00:00:00 2001 From: yuwenzho Date: Mon, 11 Dec 2023 15:46:40 +0800 Subject: [PATCH 10/19] update MatMul4BitsQuantizer args Signed-off-by: yuwenzho --- .../tools/quantization/matmul_4bits_quantizer.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py b/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py index 9b537162cb381..93573c024adf4 100644 --- a/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py +++ b/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py @@ -9,7 +9,7 @@ import importlib import logging import os -from typing import List, Tuple +from typing import List, Tuple, Union import numpy as np import numpy.typing as npt @@ -118,7 +118,7 @@ class MatMul4BitsQuantizer: def __init__( self, - model: ModelProto, + model: Union[ModelProto, str], block_size: int, is_symmetric: bool, nodes_to_exclude=None, @@ -126,7 +126,8 @@ def __init__( ): if nodes_to_exclude is None: nodes_to_exclude = [] - self.model = ONNXModel(model) + self.model = ONNXModel(onnx.load(model)) if isinstance(model, str) else ONNXModel(model) + self.model_path = model if isinstance(model, str) else None self.block_size = block_size self.is_symmetric = is_symmetric self.nodes_to_exclude = set(nodes_to_exclude) @@ -267,7 +268,7 @@ def _generate_q4_node_config(self): """Generate weight only quant configuration for nodes.""" q4_node_config = {} template_config_q4 = {"bits": 4, "group_size": self.block_size, "scheme": "sym" if self.is_symmetric else "asym"} - template_config_fp32 = 'fp32' + template_config_fp32 = "fp32" for node in self.model.model.graph.node: if node.op_type in ["MatMul"]: if not all([self.model.get_initializer(i) is None for i in node.input]): @@ -297,7 +298,7 @@ def inc_dataloader(): ratios = 
self.algo_config.ratios self.model = rtn_quantize( - model=self.model.model, + model=self.model_path if self.model_path is not None else self.model.model, weight_config=weight_only_node_config, ratios=ratios, accuracy_level=accuracy_level, @@ -313,7 +314,7 @@ def inc_dataloader(): dataloader = inc_dataloader() self.model = gptq_quantize( - model=self.model.model, + model=self.model_path if self.model_path is not None else self.model.model, weight_config=weight_only_node_config, dataloader=dataloader, n_samples=-1, From 4aa9318d4622f86c3236742f02ba47985d9c077f Mon Sep 17 00:00:00 2001 From: yuwenzho Date: Fri, 15 Dec 2023 15:45:57 +0800 Subject: [PATCH 11/19] add log for woq Signed-off-by: yuwenzho --- onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py b/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py index 93573c024adf4..1dd5805f23199 100644 --- a/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py +++ b/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py @@ -292,6 +292,7 @@ def inc_dataloader(): weight_only_node_config = self._generate_q4_node_config() algorithm = self.algo_config.algorithm + logger.info(f"start to quantize model with {algorithm} algorithm...") if algorithm == "RTN": from neural_compressor.adaptor.ox_utils.weight_only import rtn_quantize @@ -325,6 +326,7 @@ def inc_dataloader(): perchannel=perchannel, accuracy_level=accuracy_level, ) + logger.info(f"complete quantization of model with {algorithm} algorithm.") def process(self): if self.algo_config is None: From 61c9bbca88e2f8b252e23b281b9ed593302d0110 Mon Sep 17 00:00:00 2001 From: yuwenzho Date: Fri, 15 Dec 2023 15:46:11 +0800 Subject: [PATCH 12/19] fix bug in sq Signed-off-by: yuwenzho --- onnxruntime/python/tools/quantization/quantize.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/onnxruntime/python/tools/quantization/quantize.py b/onnxruntime/python/tools/quantization/quantize.py index aed46563c2764..36c366b3fc406 100644 --- a/onnxruntime/python/tools/quantization/quantize.py +++ b/onnxruntime/python/tools/quantization/quantize.py @@ -479,12 +479,11 @@ def inc_dataloader(): sq = ORTSmoothQuant(model_input, dataloader, reduce_range) del dataloader model = sq.transform( - extra_options.get("SmoothQuantAlpha", 0.5), extra_options.get("SmoothQuantFolding", True) - ).model - nodes_to_exclude.extend([i.name for i in model.graph.node if i.name not in orig_nodes]) + extra_options.get("SmoothQuantAlpha", 0.5), extra_options.get("SmoothQuantFolding", True)) sq_path = tempfile.TemporaryDirectory(prefix="ort.quant.") - model_input = Path(sq_path.name).joinpath("sq_model.onnx").as_posix() - onnx.save_model(model, model_input, save_as_external_data=True) + model_input = Path(sq_path).joinpath("sq_model.onnx").as_posix() + model.save(model_input) + nodes_to_exclude.extend([i.name for i in model.model.graph.node if i.name not in orig_nodes]) model = load_model_with_shape_infer(Path(model_input)) # use smooth quant model for calibration with tempfile.TemporaryDirectory(prefix="ort.quant.") as quant_tmp_dir: From 0b468452ab714e292e8492c4eeb5592a62da6165 Mon Sep 17 00:00:00 2001 From: "Wang, Mengni" Date: Mon, 25 Dec 2023 10:46:18 +0800 Subject: [PATCH 13/19] Update matmul_4bits_quantizer.py --- .../python/tools/quantization/matmul_4bits_quantizer.py | 3 --- 1 file changed, 3 deletions(-) diff --git 
a/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py b/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py index 1dd5805f23199..9b445ab98866f 100644 --- a/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py +++ b/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py @@ -268,13 +268,10 @@ def _generate_q4_node_config(self): """Generate weight only quant configuration for nodes.""" q4_node_config = {} template_config_q4 = {"bits": 4, "group_size": self.block_size, "scheme": "sym" if self.is_symmetric else "asym"} - template_config_fp32 = "fp32" for node in self.model.model.graph.node: if node.op_type in ["MatMul"]: if not all([self.model.get_initializer(i) is None for i in node.input]): q4_node_config[node.name] = template_config_q4 - else: - q4_node_config[node.name] = template_config_fp32 return q4_node_config def int4_quant_algo(self): From 4f30e8b0efdf33fec12cc53a9391da83edda95c2 Mon Sep 17 00:00:00 2001 From: "Wang, Mengni" Date: Mon, 25 Dec 2023 10:48:08 +0800 Subject: [PATCH 14/19] Update test_op_matmul_4bits.py --- .../test/python/quantization/test_op_matmul_4bits.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/onnxruntime/test/python/quantization/test_op_matmul_4bits.py b/onnxruntime/test/python/quantization/test_op_matmul_4bits.py index fbd91c71451cc..b19d3fac43abe 100644 --- a/onnxruntime/test/python/quantization/test_op_matmul_4bits.py +++ b/onnxruntime/test/python/quantization/test_op_matmul_4bits.py @@ -209,9 +209,8 @@ def test_quantize_matmul_int4_offsets(self): data_reader = self.input_feeds(1, {"input": [100, 52]}) self.quant_test(model_fp32_path, data_reader, 32, False) - @unittest.skip( - "Skip failed test in Python Packaging Test Pipeline." - "During importing neural_compressor, pycocotools throws ValueError: numpy.ndarray size changed" + @unittest.skipIf( + find_spec("onnxruntime.training"), "Skip because training package doesn't has quantize_matmul_4bits" ) def test_quantize_matmul_int4_using_rtn_algo(self): if not find_spec("neural_compressor"): @@ -221,9 +220,8 @@ def test_quantize_matmul_int4_using_rtn_algo(self): data_reader = self.input_feeds(1, {"input": [100, 52]}) self.quant_test_with_algo("RTN", model_fp32_path, data_reader, 32, False) - @unittest.skip( - "Skip failed test in Python Packaging Test Pipeline." 
- "During importing neural_compressor, pycocotools throws ValueError: numpy.ndarray size changed" + @unittest.skipIf( + find_spec("onnxruntime.training"), "Skip because training package doesn't has quantize_matmul_4bits" ) def test_quantize_matmul_int4_using_gptq_algo(self): if not find_spec("neural_compressor"): From 6cc85571f18b5c6bf85e55cbb62994fe0cd921cd Mon Sep 17 00:00:00 2001 From: "Wang, Mengni" Date: Sat, 30 Dec 2023 16:35:10 +0800 Subject: [PATCH 15/19] Update matmul_4bits_quantizer.py --- .../python/tools/quantization/matmul_4bits_quantizer.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py b/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py index b174c3794403e..d18748c22da4a 100644 --- a/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py +++ b/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py @@ -267,7 +267,11 @@ def _process_subgraph(self, graph_stack: List[GraphProto]): def _generate_q4_node_config(self): """Generate weight only quant configuration for nodes.""" q4_node_config = {} - template_config_q4 = {"bits": 4, "group_size": self.block_size, "scheme": "sym" if self.is_symmetric else "asym"} + template_config_q4 = { + "bits": 4, + "group_size": self.block_size, + "scheme": "sym" if self.is_symmetric else "asym" + } for node in self.model.model.graph.node: if node.op_type in ["MatMul"]: if not all([self.model.get_initializer(i) is None for i in node.input]): From a141bc38fd2674f7ce42dfa6937438a907bb63f9 Mon Sep 17 00:00:00 2001 From: "Wang, Mengni" Date: Sat, 30 Dec 2023 16:37:10 +0800 Subject: [PATCH 16/19] Update quantize.py --- onnxruntime/python/tools/quantization/quantize.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/onnxruntime/python/tools/quantization/quantize.py b/onnxruntime/python/tools/quantization/quantize.py index 36c366b3fc406..1bd2ef42151d0 100644 --- a/onnxruntime/python/tools/quantization/quantize.py +++ b/onnxruntime/python/tools/quantization/quantize.py @@ -466,7 +466,6 @@ def quantize_static( import copy - import onnx from neural_compressor.adaptor.ox_utils.smooth_quant import ORTSmoothQuant def inc_dataloader(): @@ -478,8 +477,7 @@ def inc_dataloader(): dataloader = inc_dataloader() sq = ORTSmoothQuant(model_input, dataloader, reduce_range) del dataloader - model = sq.transform( - extra_options.get("SmoothQuantAlpha", 0.5), extra_options.get("SmoothQuantFolding", True)) + model = sq.transform(extra_options.get("SmoothQuantAlpha", 0.5), extra_options.get("SmoothQuantFolding", True)) sq_path = tempfile.TemporaryDirectory(prefix="ort.quant.") model_input = Path(sq_path).joinpath("sq_model.onnx").as_posix() model.save(model_input) From 0c035c4f1348c5ef30bdf6aeb9d363ce5f94c66d Mon Sep 17 00:00:00 2001 From: yuwenzho Date: Tue, 9 Jan 2024 09:15:01 +0800 Subject: [PATCH 17/19] fix for lint Signed-off-by: yuwenzho --- .../python/tools/quantization/matmul_4bits_quantizer.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py b/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py index d6635bf84d4d5..82fe6887f43b1 100644 --- a/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py +++ b/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py @@ -11,7 +11,6 @@ import importlib import logging import os -from typing import Union import numpy as np import numpy.typing as npt @@ -108,7 +107,7 @@ class 
MatMul4BitsQuantizer: def __init__( self, - model: Union[ModelProto, str], + model: ModelProto | str, block_size: int, is_symmetric: bool, accuracy_level: int | None = None, From 83b3ed791a3615f798b7984f262dc1293a2c55f5 Mon Sep 17 00:00:00 2001 From: yuwenzho Date: Wed, 10 Jan 2024 08:57:42 +0800 Subject: [PATCH 18/19] fix for lint Signed-off-by: yuwenzho --- .../python/tools/quantization/matmul_4bits_quantizer.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py b/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py index 82fe6887f43b1..3e9f9a6544a71 100644 --- a/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py +++ b/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py @@ -29,10 +29,7 @@ class WeightOnlyQuantConfig: - def __init__( - self, - algorithm - ): + def __init__(self, algorithm): """This is the Base class for Weight Only Quant Configuration. Args: @@ -263,7 +260,7 @@ def _generate_q4_node_config(self): template_config_q4 = { "bits": 4, "group_size": self.block_size, - "scheme": "sym" if self.is_symmetric else "asym" + "scheme": "sym" if self.is_symmetric else "asym", } for node in self.model.model.graph.node: if node.op_type in ["MatMul"]: From 81390c04836561728e792c9f8a85e9ef84173b94 Mon Sep 17 00:00:00 2001 From: yuwenzho Date: Wed, 10 Jan 2024 15:11:51 +0800 Subject: [PATCH 19/19] fix import Signed-off-by: yuwenzho --- onnxruntime/python/tools/quantization/__init__.py | 2 -- .../test/python/quantization/test_op_matmul_4bits.py | 8 ++------ 2 files changed, 2 insertions(+), 8 deletions(-) diff --git a/onnxruntime/python/tools/quantization/__init__.py b/onnxruntime/python/tools/quantization/__init__.py index 3bc055d1f063a..170c0928fee23 100644 --- a/onnxruntime/python/tools/quantization/__init__.py +++ b/onnxruntime/python/tools/quantization/__init__.py @@ -5,8 +5,6 @@ MinMaxCalibrater, create_calibrator, ) -from .matmul_4bits_quantizer import GPTQWeightOnlyQuantConfig # noqa: F401 -from .matmul_4bits_quantizer import RTNWeightOnlyQuantConfig # noqa: F401 from .matmul_weight4_quantizer import MatMulWeight4Quantizer # noqa: F401 from .qdq_quantizer import QDQQuantizer # noqa: F401 from .quant_utils import QuantFormat, QuantType, write_calibration_table # noqa: F401 diff --git a/onnxruntime/test/python/quantization/test_op_matmul_4bits.py b/onnxruntime/test/python/quantization/test_op_matmul_4bits.py index b19d3fac43abe..73dae08af8ece 100644 --- a/onnxruntime/test/python/quantization/test_op_matmul_4bits.py +++ b/onnxruntime/test/python/quantization/test_op_matmul_4bits.py @@ -161,14 +161,10 @@ def quant_test_with_algo( algo_config = None if algorithm == "RTN": # test RTN algorithm - from onnxruntime.quantization import RTNWeightOnlyQuantConfig - - algo_config = RTNWeightOnlyQuantConfig() + algo_config = matmul_4bits_quantizer.RTNWeightOnlyQuantConfig() elif algorithm == "GPTQ": # test GPTQ algorithm - from onnxruntime.quantization import GPTQWeightOnlyQuantConfig - - algo_config = GPTQWeightOnlyQuantConfig(calibration_data_reader=data_reader) + algo_config = matmul_4bits_quantizer.GPTQWeightOnlyQuantConfig(calibration_data_reader=data_reader) model = quant_utils.load_model_with_shape_infer(Path(model_fp32_path)) quant = matmul_4bits_quantizer.MatMul4BitsQuantizer(model, block_size, is_symmetric, algo_config=algo_config)
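
Editor's note: for readers following the series end to end, the snippet below is a minimal usage sketch of the 4-bit weight-only path as it stands after patch 19/19. It mirrors quant_test_with_algo in test_op_matmul_4bits.py; the model paths, block size of 32, and asymmetric setting are placeholders taken from the tests, and the GPTQ variant assumes a CalibrationDataReader instance is available. It is an illustration, not part of the patches.

    from pathlib import Path

    from onnxruntime.quantization import matmul_4bits_quantizer
    from onnxruntime.quantization.quant_utils import load_model_with_shape_infer

    model_fp32_path = "matmul_fp32.onnx"   # placeholder input path
    model_int4_path = "matmul_int4.onnx"   # placeholder output path

    # RTN needs no calibration data; for GPTQ pass a CalibrationDataReader instead, e.g.
    #   algo_config = matmul_4bits_quantizer.GPTQWeightOnlyQuantConfig(calibration_data_reader=data_reader)
    algo_config = matmul_4bits_quantizer.RTNWeightOnlyQuantConfig()

    # After patch 10/19 a path string may be passed in place of the loaded ModelProto.
    model = load_model_with_shape_infer(Path(model_fp32_path))
    quant = matmul_4bits_quantizer.MatMul4BitsQuantizer(model, 32, False, algo_config=algo_config)
    quant.process()   # the RTN/GPTQ branches require neural-compressor >= 2.3.2
    quant.model.save_model_to_file(model_int4_path, False)   # False: keep weights inline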
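
Editor's note: to clarify what the quantizer hands to neural-compressor after patches 13 and 15, the sketch below shows the shape of the per-node configuration produced by _generate_q4_node_config() for block_size=32 and is_symmetric=False. The node name is made up; a MatMul whose weight is not an initializer simply gets no entry (it is left in fp32), replacing the earlier explicit "fp32" template.

    # Hypothetical output of MatMul4BitsQuantizer._generate_q4_node_config().
    q4_node_config = {
        "MatMul_1": {"bits": 4, "group_size": 32, "scheme": "asym"},
        # A MatMul fed only by activations (no initializer input) gets no entry,
        # so the RTN/GPTQ algorithms leave it untouched.
    }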
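
Editor's note: patches 12 and 16 adjust the SmoothQuant branch inside quantize_static. The sketch below shows how that branch is reached; the "SmoothQuant" switch name and the calibration reader are assumptions about quantize_static's extra_options rather than something these patches introduce, and the paths are placeholders.

    from onnxruntime.quantization import QuantFormat, quantize_static

    quantize_static(
        model_input="model_fp32.onnx",           # placeholder paths
        model_output="model_int8.onnx",
        calibration_data_reader=data_reader,     # a CalibrationDataReader instance (assumed available)
        quant_format=QuantFormat.QDQ,
        extra_options={
            "SmoothQuant": True,                 # assumed switch enabling the ORTSmoothQuant transform
            "SmoothQuantAlpha": 0.5,             # defaults shown in the patched code
            "SmoothQuantFolding": True,
        },
    )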