
Commit

add weight only quantize
Signed-off-by: yuwenzho <[email protected]>
yuwenzho committed Sep 1, 2023
1 parent 8b98eca commit 3274b34
Showing 3 changed files with 329 additions and 1 deletion.
5 changes: 4 additions & 1 deletion onnxruntime/python/tools/quantization/__init__.py
@@ -14,4 +14,7 @@
from .quantize import quantize # noqa: F401
from .quantize import quantize_dynamic # noqa: F401
from .quantize import quantize_static # noqa: F401
-from .shape_inference import quant_pre_process  # noqa: F401
+from .quantize_weight_only import RTNWeightOnlyQuantConfig  # noqa: F401
+from .quantize_weight_only import GPTQWeightOnlyQuantConfig  # noqa: F401
+from .quantize_weight_only import quantize_weight_only  # noqa: F401
+from .shape_inference import quant_pre_process  # noqa: F401
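
With these exports, the weight-only API becomes available directly from onnxruntime.quantization. A minimal RTN usage sketch (the model paths here are illustrative, and neural-compressor must be installed at runtime):

from onnxruntime.quantization import RTNWeightOnlyQuantConfig, quantize_weight_only

# Quantize all MatMul weights to 4 bits with round-to-nearest.
config = RTNWeightOnlyQuantConfig(group_size=32, scheme="sym")
quantize_weight_only("fp32.onnx", "fp32.weight_only_quant.onnx", config)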
198 changes: 198 additions & 0 deletions onnxruntime/python/tools/quantization/quantize_weight_only.py
@@ -0,0 +1,198 @@
import copy
import importlib
import logging
from pathlib import Path

from .calibrate import CalibrationDataReader
from .quant_utils import load_model_with_shape_infer

class WeightOnlyQuantConfig:
    def __init__(
        self,
        algorithm,
        group_size=32,
        scheme="sym",
        use_external_data_format=False,
    ):
        """This is the base class for weight-only quant configuration.

        Args:
            algorithm:
                weight-only quantization algorithm name.
            group_size:
                how many elements share one scale/zero-point. -1 indicates
                per-channel quantization (one group per output channel).
            scheme:
                quantization scheme for weights: "sym" (symmetric) or "asym" (asymmetric).
            use_external_data_format:
                option used for large (>2GB) models. Set to False by default.
        """
        self.algorithm = algorithm
        self.group_size = group_size
        self.scheme = scheme
        self.use_external_data_format = use_external_data_format

class RTNWeightOnlyQuantConfig(WeightOnlyQuantConfig):
    def __init__(
        self,
        group_size=32,
        scheme="sym",
        use_external_data_format=False,
    ):
        """
        This is a class for round-to-nearest (RTN) algorithm weight-only quant configuration.
        RTN is the most straightforward way to quantize weights using scale maps.

        Args:
            group_size:
                how many elements share one scale/zero-point. -1 indicates
                per-channel quantization (one group per output channel).
            scheme:
                quantization scheme for weights: "sym" (symmetric) or "asym" (asymmetric).
            use_external_data_format:
                option used for large (>2GB) models. Set to False by default.
        """
        super().__init__(
            algorithm="RTN",
            group_size=group_size,
            scheme=scheme,
            use_external_data_format=use_external_data_format,
        )

class GPTQWeightOnlyQuantConfig(WeightOnlyQuantConfig):
    def __init__(
        self,
        calibration_data_reader: CalibrationDataReader,
        group_size=32,
        scheme="asym",
        percdamp=0.01,
        blocksize=128,
        actorder=False,
        mse=False,
        perchannel=True,
        use_external_data_format=False,
    ):
        """
        This is a class for GPTQ algorithm weight-only quant configuration.
        The GPTQ algorithm provides more accurate quantization but requires more computational resources.

        Args:
            calibration_data_reader:
                a calibration data reader. It enumerates calibration data and
                generates inputs for the original model.
            group_size:
                how many elements share one scale/zero-point. -1 indicates
                per-channel quantization (one group per output channel).
            scheme:
                quantization scheme for weights: "sym" (symmetric) or "asym" (asymmetric).
            percdamp:
                percent of the average Hessian diagonal to use for dampening.
            blocksize (int, optional):
                number of channels in one block for a GPTQ quantization iteration.
            actorder (bool, optional):
                whether to reorder the Hessian matrix by its diagonal values.
            mse (bool, optional):
                whether to choose scale and zero point by minimizing MSE.
            perchannel (bool, optional):
                whether to quantize weights per channel.
            use_external_data_format:
                option used for large (>2GB) models. Set to False by default.
        """
        super().__init__(
            algorithm="GPTQ",
            group_size=group_size,
            scheme=scheme,
            use_external_data_format=use_external_data_format,
        )
        self.calibration_data_reader = calibration_data_reader
        self.percdamp = percdamp
        self.blocksize = blocksize
        self.actorder = actorder
        self.mse = mse
        self.perchannel = perchannel

def _generate_weight_only_node_config(model, group_size, scheme):
    """Generate a weight-only quant configuration for nodes.

    Args:
        model:
            onnx.ModelProto.
        group_size:
            how many elements share one scale/zero-point. -1 indicates
            per-channel quantization (one group per output channel).
        scheme:
            quantization scheme for weights: "sym" (symmetric) or "asym" (asymmetric).

    Returns:
        dict: weight-only quant configuration for nodes.
    """
    weight_only_node_config = {}
    template_config = {"weight": {"bits": 4, "group_size": group_size, "scheme": scheme}}
    for node in model.graph.node:
        if node.op_type in ["MatMul"]:  # TODO: enable Gemm op support
            weight_only_node_config[node.name] = template_config
    return weight_only_node_config
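

# Example (illustrative): with the defaults group_size=32 and scheme="sym", a model
# containing a single MatMul node named "MatMul_1" yields
#     {"MatMul_1": {"weight": {"bits": 4, "group_size": 32, "scheme": "sym"}}}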


def quantize_weight_only(
    model_input: Path,
    model_output: Path,
    weight_only_config: WeightOnlyQuantConfig,
):
    """Weight-only quantize a model with a WeightOnlyQuantConfig. Please refer to
    https://github.com/intel/neural-compressor/blob/master/docs/source/quantization_weight_only.md
    for more details on weight-only quantization.

    Args:
        model_input (Path): Path to the model to weight-only quantize.
        model_output (Path): Path to save the quantized model.
        weight_only_config (WeightOnlyQuantConfig): Weight-only quantization configuration.

    Raises:
        RuntimeError: Raise RuntimeError if neural-compressor is not correctly installed.
    """
    try:
        importlib.import_module("neural_compressor.adaptor.ox_utils.weight_only")
    except Exception as e:
        logging.error(f"{e}.")
        raise RuntimeError("neural-compressor is not correctly installed. Please check your environment.") from e

    # Only used by the GPTQ branch; RTN does not need calibration data.
    def inc_dataloader():
        data_reader = copy.deepcopy(weight_only_config.calibration_data_reader)
        for data in data_reader:
            yield data, None

    # The shape-inferred model is only used to collect quantizable nodes;
    # the model path itself is handed to neural-compressor below.
    model = load_model_with_shape_infer(Path(model_input))
    scheme = weight_only_config.scheme
    group_size = weight_only_config.group_size
    weight_only_node_config = _generate_weight_only_node_config(model, group_size, scheme)

    algorithm = weight_only_config.algorithm
    if algorithm == "RTN":
        from neural_compressor.adaptor.ox_utils.weight_only import rtn_quantize

        model = rtn_quantize(model=model_input, tune_cfg=weight_only_node_config)
    elif algorithm == "GPTQ":
        from neural_compressor.adaptor.ox_utils.weight_only import gptq_quantize

        percdamp = weight_only_config.percdamp
        blocksize = weight_only_config.blocksize
        actorder = weight_only_config.actorder
        mse = weight_only_config.mse
        perchannel = weight_only_config.perchannel
        dataloader = inc_dataloader()

        model = gptq_quantize(
            model=model_input,
            tune_cfg=weight_only_node_config,
            dataloader=dataloader,
            n_samples=-1,
            percdamp=percdamp,
            blocksize=blocksize,
            actorder=actorder,
            mse=mse,
            perchannel=perchannel,
        )

    model.save_model_to_file(model_output, weight_only_config.use_external_data_format)
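
The GPTQ path additionally needs a calibration data reader. A minimal sketch, assuming a model with a single input named "input" of shape [1, 10]; the reader class and model paths below are illustrative:

import numpy as np
from onnxruntime.quantization import (
    CalibrationDataReader,
    GPTQWeightOnlyQuantConfig,
    quantize_weight_only,
)

class RandomDataReader(CalibrationDataReader):
    # Feeds a fixed number of random samples as calibration data.
    def __init__(self, num_samples=10):
        self._samples = iter(
            [{"input": np.random.rand(1, 10).astype(np.float32)} for _ in range(num_samples)]
        )

    def get_next(self):
        return next(self._samples, None)

config = GPTQWeightOnlyQuantConfig(RandomDataReader())
quantize_weight_only("fp32.onnx", "fp32.weight_only_quant.onnx", config)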
127 changes: 127 additions & 0 deletions onnxruntime/test/python/quantization/test_quantize_weight_only.py
@@ -0,0 +1,127 @@
#!/usr/bin/env python
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------

import tempfile
import unittest
from importlib.util import find_spec
from pathlib import Path

import numpy as np
import onnx
from onnx import TensorProto, helper
from onnxruntime.quantization.onnx_model import ONNXModel
from op_test_utils import check_model_correctness, input_feeds_neg_one_zero_one

from onnxruntime.quantization import (
    GPTQWeightOnlyQuantConfig,
    RTNWeightOnlyQuantConfig,
    quantize_weight_only,
)

def construct_model(output_model_path):
    #    (input)
    #       |
    #      Mul
    #       |
    #    MatMul
    #       |
    #    (output)
    initializers = []

    # make mul node
    mul_data = np.random.normal(0, 0.1, [1, 10]).astype(np.float32)
    initializers.append(onnx.numpy_helper.from_array(mul_data, name="mul.data"))
    mul_node = onnx.helper.make_node("Mul", ["input", "mul.data"], ["mul.output"], "Mul_0")

    # make matmul node
    matmul_weight = np.random.normal(0, 0.1, [10, 1]).astype(np.float32)
    initializers.append(onnx.numpy_helper.from_array(matmul_weight, name="matmul.weight"))
    matmul_node = onnx.helper.make_node("MatMul", ["mul.output", "matmul.weight"], ["output"], "MatMul_1")

    # make graph
    input_tensor = helper.make_tensor_value_info("input", TensorProto.FLOAT, [1, 10])
    output_tensor = helper.make_tensor_value_info("output", TensorProto.FLOAT, [1, 1])
    graph_name = "weight_only_quant_test"
    graph = helper.make_graph(
        [mul_node, matmul_node],
        graph_name,
        [input_tensor],
        [output_tensor],
        initializer=initializers,
    )
    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
    model.ir_version = onnx.IR_VERSION

    onnx.save(model, output_model_path)

class TestWeightOnlyQuantization(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        # TODO: there will be a refactor to handle all those temporary directories.
        cls._tmp_model_dir = tempfile.TemporaryDirectory(prefix="ort.quant.save.as.external")
        cls._model_fp32_path = str(Path(cls._tmp_model_dir.name) / "fp32.onnx")
        cls._model_weight_only_path = str(Path(cls._tmp_model_dir.name) / "fp32.weight_only_quant.onnx")
        np.random.seed(1)
        construct_model(cls._model_fp32_path)

    @classmethod
    def tearDownClass(cls):
        cls._tmp_model_dir.cleanup()

    @unittest.skip(
        "Skip failed test in Python Packaging Test Pipeline. "
        "During importing neural_compressor, pycocotools throws ValueError: numpy.ndarray size changed"
    )
    def test_quantize_weight_only_rtn(self):
        if not find_spec("neural_compressor"):
            self.skipTest("skip test_quantize_weight_only_rtn since neural_compressor is not installed")

        weight_only_config = RTNWeightOnlyQuantConfig()
        quantize_weight_only(self._model_fp32_path, self._model_weight_only_path, weight_only_config)
        check_model_correctness(
            self,
            self._model_fp32_path,
            self._model_weight_only_path,
            {"input": np.random.rand(1, 10).astype(np.float32)},
        )

        model_fp32 = ONNXModel(onnx.load(self._model_fp32_path))
        model_weight_only = ONNXModel(onnx.load(self._model_weight_only_path))
        self.assertNotEqual(
            model_fp32.get_initializer("matmul.weight"),
            model_weight_only.get_initializer("matmul.weight"),
        )

    @unittest.skip(
        "Skip failed test in Python Packaging Test Pipeline. "
        "During importing neural_compressor, pycocotools throws ValueError: numpy.ndarray size changed"
    )
    def test_quantize_weight_only_gptq(self):
        if not find_spec("neural_compressor"):
            self.skipTest("skip test_quantize_weight_only_gptq since neural_compressor is not installed")

        data_reader = input_feeds_neg_one_zero_one(10, {"input": [1, 10]})
        weight_only_config = GPTQWeightOnlyQuantConfig(data_reader)
        quantize_weight_only(self._model_fp32_path, self._model_weight_only_path, weight_only_config)
        check_model_correctness(
            self,
            self._model_fp32_path,
            self._model_weight_only_path,
            {"input": np.random.rand(1, 10).astype(np.float32)},
        )

        model_fp32 = ONNXModel(onnx.load(self._model_fp32_path))
        model_weight_only = ONNXModel(onnx.load(self._model_weight_only_path))
        self.assertNotEqual(
            model_fp32.get_initializer("matmul.weight"),
            model_weight_only.get_initializer("matmul.weight"),
        )


if __name__ == "__main__":
    unittest.main()
