diff --git a/onnxruntime/python/tools/quantization/quantize_weight_only.py b/onnxruntime/python/tools/quantization/quantize_weight_only.py
index 6c7fc4d63a00c..2ad4c819270c3 100644
--- a/onnxruntime/python/tools/quantization/quantize_weight_only.py
+++ b/onnxruntime/python/tools/quantization/quantize_weight_only.py
@@ -2,6 +2,7 @@
 import logging
 import importlib
 from pathlib import Path
+from packaging import version
 
 from .calibrate import CalibrationDataReader
 from .quant_utils import load_model_with_shape_infer
@@ -42,6 +43,7 @@ def __init__(
         self,
         group_size=32,
         scheme="sym",
+        ratios={},
         use_external_data_format=False,
     ):
         """
@@ -63,6 +65,7 @@ def __init__(
             scheme=scheme,
             use_external_data_format=use_external_data_format
         )
+        self.ratios = ratios
 
 class GPTQWeightOnlyQuantConfig(WeightOnlyQuantConfig):
     def __init__(
@@ -131,9 +134,9 @@ def _generate_weight_only_node_config(model, group_size, scheme):
         dict: weight only quant configuration for nodes.
     """
     weight_only_node_config = {}
-    template_config = {'weight': {"bits": 4, "group_size": group_size, "scheme": scheme}}
+    template_config = {"bits": 4, "group_size": group_size, "scheme": scheme}
     for node in model.graph.node:
-        if node.op_type in ["MatMul"]:  # TODO: enable Gemm op support
+        if node.op_type in ["MatMul"]:
             weight_only_node_config[node.name] = template_config
     return weight_only_node_config
 
@@ -156,11 +159,15 @@ def quantize_weight_only(
         RuntimeError: Raise RuntimeError if neural-compressor is not correctly installed.
     """
     try:
-        importlib.import_module("neural_compressor.adaptor.ox_utils.weight_only")
+        importlib.import_module("neural_compressor")
     except Exception as e:
         logging.error(f"{e}.")
         raise RuntimeError("neural-compressor is not correctly installed. Please check your environment.") from e
+    import neural_compressor
+    assert version.parse(neural_compressor.__version__) >= version.parse("2.3.0"), \
+        "Require neural-compressor >= 2.3.0 to support weight only quantization!"
+
     def inc_dataloader():
         data_reader = copy.deepcopy(weight_only_config.calibration_data_reader)
         for data in data_reader:
             yield data, None
@@ -174,8 +181,11 @@ def inc_dataloader():
     algorithm = weight_only_config.algorithm
     if algorithm == "RTN":
         from neural_compressor.adaptor.ox_utils.weight_only import rtn_quantize
+        ratios = weight_only_config.ratios
+
         model = rtn_quantize(model=model_input,
-                             tune_cfg=weight_only_node_config)
+                             weight_config=weight_only_node_config,
+                             ratios=ratios)
     elif algorithm == "GPTQ":
         from neural_compressor.adaptor.ox_utils.weight_only import gptq_quantize
         percdamp = weight_only_config.percdamp
@@ -186,7 +196,7 @@ def inc_dataloader():
         dataloader = inc_dataloader()
 
         model = gptq_quantize(model=model_input,
-                              tune_cfg=weight_only_node_config,
+                              weight_config=weight_only_node_config,
                               dataloader=dataloader,
                               n_samples=-1,
                               percdamp=percdamp,
diff --git a/onnxruntime/test/python/quantization/test_quantize_weight_only.py b/onnxruntime/test/python/quantization/test_quantize_weight_only.py
index 5fb1201788623..e2146a98b0f43 100644
--- a/onnxruntime/test/python/quantization/test_quantize_weight_only.py
+++ b/onnxruntime/test/python/quantization/test_quantize_weight_only.py
@@ -33,12 +33,12 @@ def construct_model(output_model_path):
     initializers = []
 
     # make mul node
-    mul_data = np.random.normal(0, 0.1, [1, 10]).astype(np.float32)
+    mul_data = np.random.normal(0, 0.1, [1, 32]).astype(np.float32)
     initializers.append(onnx.numpy_helper.from_array(mul_data, name="mul.data"))
     mul_node = onnx.helper.make_node("Mul", ["input", "mul.data"], ["mul.output"], "Mul_0")
 
     # make matmul node
-    matmul_weight = np.random.normal(0, 0.1, [10, 1]).astype(np.float32)
+    matmul_weight = np.random.normal(0, 0.1, [32, 1]).astype(np.float32)
     initializers.append(onnx.numpy_helper.from_array(matmul_weight, name="matmul.weight"))
     matmul_node = onnx.helper.make_node("MatMul",
                                         ["mul.output", "matmul.weight"],
@@ -46,7 +46,7 @@ def construct_model(output_model_path):
                                         "MatMul_1")
 
     # make graph
-    input_tensor = helper.make_tensor_value_info("input", TensorProto.FLOAT, [1, 10])
+    input_tensor = helper.make_tensor_value_info("input", TensorProto.FLOAT, [1, 32])
     output_tensor = helper.make_tensor_value_info("output", TensorProto.FLOAT, [1, 1])
     graph_name = "weight_only_quant_test"
     graph = helper.make_graph(
@@ -91,7 +91,7 @@ def test_quantize_weight_only_rtn(self):
             self,
             self._model_fp32_path,
             self._model_weight_only_path,
-            {"input": np.random.rand(1, 10).astype(np.float32)},
+            {"input": np.random.rand(1, 32).astype(np.float32)},
         )
 
         model_fp32 = ONNXModel(onnx.load(self._model_fp32_path))
@@ -108,14 +108,14 @@ def test_quantize_weight_only_gptq(self):
         if not find_spec("neural_compressor"):
             self.skipTest("skip test_quantize_weight_only_gptq since neural_compressor is not installed")
 
-        data_reader = input_feeds_neg_one_zero_one(10, {"input": [1, 10]})
+        data_reader = input_feeds_neg_one_zero_one(10, {"input": [1, 32]})
         weight_only_config = GPTQWeightOnlyQuantConfig(data_reader)
         quantize_weight_only(self._model_fp32_path, self._model_weight_only_path, weight_only_config)
         check_model_correctness(
             self,
             self._model_fp32_path,
             self._model_weight_only_path,
-            {"input": np.random.rand(1, 10).astype(np.float32)},
+            {"input": np.random.rand(1, 32).astype(np.float32)},
         )
 
         model_fp32 = ONNXModel(onnx.load(self._model_fp32_path))
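Usage sketch (not part of the diff): the snippet below shows how the updated API could be driven after this change. The import paths, the `RTNWeightOnlyQuantConfig` class name (only the GPTQ config class appears in the hunks above), the model file names, and the example `ratios` value are assumptions; what the diff does confirm is that `ratios` is forwarded verbatim to neural-compressor's `rtn_quantize`, and that GPTQ consumes a `CalibrationDataReader`.

```python
# Sketch only -- names that do not appear in the diff are assumptions.
import numpy as np

from onnxruntime.quantization.calibrate import CalibrationDataReader
from onnxruntime.quantization.quantize_weight_only import (
    GPTQWeightOnlyQuantConfig,
    RTNWeightOnlyQuantConfig,  # assumed name of the RTN config class
    quantize_weight_only,
)

# RTN is data-free; `ratios` (initializer name -> clip ratio, illustrative
# values) is passed straight through to neural-compressor's rtn_quantize.
rtn_config = RTNWeightOnlyQuantConfig(group_size=32, scheme="sym", ratios={"matmul.weight": 0.8})
quantize_weight_only("model_fp32.onnx", "model_rtn_int4.onnx", rtn_config)


class RandomDataReader(CalibrationDataReader):
    """Toy calibration reader; input name/shape copied from the unit test model."""

    def __init__(self, num_samples=10):
        # Keep concrete feeds in a list so the reader survives the deepcopy
        # performed by inc_dataloader().
        self._feeds = [{"input": np.random.rand(1, 32).astype(np.float32)} for _ in range(num_samples)]
        self._pos = 0

    def get_next(self):
        if self._pos >= len(self._feeds):
            return None
        feed = self._feeds[self._pos]
        self._pos += 1
        return feed


# GPTQ needs calibration data to estimate its second-order statistics.
gptq_config = GPTQWeightOnlyQuantConfig(RandomDataReader())
quantize_weight_only("model_fp32.onnx", "model_gptq_int4.onnx", gptq_config)
```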