Support INT4 weight-only quantization, including the RTN and GPTQ algorithms #17390

Merged
merged 30 commits on Jan 10, 2024

Changes from 25 commits

Commits (30)
3274b34
add weight only quantize
yuwenzho Sep 1, 2023
ad6b8f6
Merge branch 'main' into yuwenzho/int4
yuwenzho Sep 14, 2023
b232f2e
Merge branch 'main' into yuwenzho/int4
yuwenzho Sep 20, 2023
2f867e2
update inc API usage
yuwenzho Sep 20, 2023
32f7eae
update format
yuwenzho Oct 10, 2023
41e96a7
Merge branch 'main' into yuwenzho/int4
yuwenzho Oct 10, 2023
a8382ac
add accuracy_level attr
mengniwang95 Nov 10, 2023
930ce53
Merge branch 'main' into yuwenzho/int4
yuwenzho Nov 21, 2023
ee73812
update usage of RTN & GPTQ algorithm
yuwenzho Nov 22, 2023
21fc0c4
fix typo
yuwenzho Nov 22, 2023
8f6ea60
Merge branch 'main' into yuwenzho/int4
yuwenzho Nov 22, 2023
b2b9d66
update usage of RTN & GPTQ algorithm
yuwenzho Nov 23, 2023
f3a91ca
Update matmul_4bits_quantizer.py
mengniwang95 Nov 23, 2023
fd02c61
fix for code scan
yuwenzho Nov 23, 2023
673d1d4
update MatMul4BitsQuantizer args
yuwenzho Dec 11, 2023
32fbf0a
Merge branch 'main' into yuwenzho/int4
yuwenzho Dec 11, 2023
4aa9318
add log for woq
yuwenzho Dec 15, 2023
61c9bbc
fix bug in sq
yuwenzho Dec 15, 2023
87d825d
Merge branch 'main' into yuwenzho/int4
yuwenzho Dec 15, 2023
0b46845
Update matmul_4bits_quantizer.py
mengniwang95 Dec 25, 2023
4f30e8b
Update test_op_matmul_4bits.py
mengniwang95 Dec 25, 2023
bb25597
Merge branch 'microsoft:main' into yuwenzho/int4
mengniwang95 Dec 25, 2023
6cc8557
Update matmul_4bits_quantizer.py
mengniwang95 Dec 30, 2023
a141bc3
Update quantize.py
mengniwang95 Dec 30, 2023
c16c001
fix conflict
yuwenzho Jan 8, 2024
0c035c4
fix for lint
yuwenzho Jan 9, 2024
475f388
Merge branch 'main' into yuwenzho/int4
yuwenzho Jan 9, 2024
83b3ed7
fix for lint
yuwenzho Jan 10, 2024
81390c0
fix import
yuwenzho Jan 10, 2024
62438b8
Merge branch 'main' into yuwenzho/int4
yuwenzho Jan 10, 2024
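
For orientation, the following is a minimal usage sketch of the API this PR adds, assembled from the new test code further down (quant_test_with_algo). The model paths are placeholders, and the commented GPTQ line assumes a CalibrationDataReader you supply.

from pathlib import Path

from onnxruntime.quantization import GPTQWeightOnlyQuantConfig, RTNWeightOnlyQuantConfig, quant_utils
from onnxruntime.quantization.matmul_4bits_quantizer import MatMul4BitsQuantizer

# Load the fp32 model with shape inference, as the new tests do.
model = quant_utils.load_model_with_shape_infer(Path("model_fp32.onnx"))  # placeholder path

# RTN needs no calibration data; GPTQ additionally needs a CalibrationDataReader.
algo_config = RTNWeightOnlyQuantConfig()
# algo_config = GPTQWeightOnlyQuantConfig(calibration_data_reader=my_reader)  # my_reader is hypothetical

quant = MatMul4BitsQuantizer(model, block_size=32, is_symmetric=False, algo_config=algo_config)
quant.process()
quant.model.save_model_to_file("model_int4.onnx", False)  # placeholder output path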
2 changes: 2 additions & 0 deletions onnxruntime/python/tools/quantization/__init__.py
@@ -5,6 +5,8 @@
MinMaxCalibrater,
create_calibrator,
)
from .matmul_4bits_quantizer import GPTQWeightOnlyQuantConfig # noqa: F401
from .matmul_4bits_quantizer import RTNWeightOnlyQuantConfig # noqa: F401
from .matmul_weight4_quantizer import MatMulWeight4Quantizer # noqa: F401
from .qdq_quantizer import QDQQuantizer # noqa: F401
from .quant_utils import QuantFormat, QuantType, write_calibration_table # noqa: F401
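
With these re-exports in place, the two new config classes can be imported directly from the package, which is the import style the new tests below assume:

from onnxruntime.quantization import GPTQWeightOnlyQuantConfig, RTNWeightOnlyQuantConfig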
193 changes: 177 additions & 16 deletions onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py
@@ -7,41 +7,123 @@
from __future__ import annotations

import argparse
import copy
import importlib
import logging
import os
from typing import Union

import numpy as np
import numpy.typing as npt
import onnx
from onnx.onnx_pb import GraphProto, ModelProto, NodeProto, TensorProto
from packaging import version

from onnxruntime.capi._pybind_state import quantize_matmul_4bits

from .calibrate import CalibrationDataReader
from .onnx_model import ONNXModel
from .quant_utils import attribute_to_kwarg

logging.basicConfig(format="%(asctime)s %(name)s [%(levelname)s] - %(message)s", level=logging.INFO)
logger = logging.getLogger(__name__)


class WeightOnlyQuantConfig:
def __init__(
self,
algorithm
):
"""This is the Base class for Weight Only Quant Configuration.

Args:
algorithm:
weight only quantize algorithm name.
"""
self.algorithm = algorithm


class RTNWeightOnlyQuantConfig(WeightOnlyQuantConfig):
def __init__(
self,
ratios=None,
):
"""
This is a class for the round-to-nearest (RTN) algorithm weight-only quant configuration.
RTN is the most straightforward way to quantize weights using scale maps.

Args:
ratios:
percentile of clip. Defaults to {}.
"""
if ratios is None:
ratios = {}
super().__init__(
algorithm="RTN",
)
self.ratios = ratios


class GPTQWeightOnlyQuantConfig(WeightOnlyQuantConfig):
def __init__(
self,
calibration_data_reader: CalibrationDataReader,
percdamp=0.01,
blocksize=128,
actorder=False,
mse=False,
perchannel=True,
):
"""
This is a class for the GPTQ algorithm weight-only quant configuration.
The GPTQ algorithm provides more accurate quantization but requires more computational resources.

Args:
calibration_data_reader:
a calibration data reader. It enumerates calibration data and generates inputs for the original model.
percdamp:
percent of the average Hessian diagonal to use for dampening.
blocksize (int, optional):
number of channels in one block for each GPTQ quantization iteration.
actorder (bool, optional):
whether to rearrange the Hessian matrix columns according to the diagonal's values.
mse (bool, optional):
whether to pick scale and zero point by minimizing the MSE error.
perchannel (bool, optional):
whether to quantize the weight per channel.
"""
super().__init__(
algorithm="GPTQ",
)
self.calibration_data_reader = calibration_data_reader
self.percdamp = percdamp
self.blocksize = blocksize
self.actorder = actorder
self.mse = mse
self.perchannel = perchannel
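
To make the contrast between the two configurations concrete, here is a short sketch; the RandomDataReader below is a hypothetical reader written for illustration, not part of this PR.

import numpy as np

from onnxruntime.quantization import CalibrationDataReader, GPTQWeightOnlyQuantConfig, RTNWeightOnlyQuantConfig


class RandomDataReader(CalibrationDataReader):  # hypothetical reader for illustration
    """Yields a few random float32 batches shaped like the model's input."""

    def __init__(self, input_name="input", shape=(100, 52), num_batches=4):
        self._batches = iter(
            [{input_name: np.random.rand(*shape).astype(np.float32)} for _ in range(num_batches)]
        )

    def get_next(self):
        return next(self._batches, None)


rtn_config = RTNWeightOnlyQuantConfig()  # data-free round-to-nearest
gptq_config = GPTQWeightOnlyQuantConfig(  # needs calibration data, typically more accurate
    calibration_data_reader=RandomDataReader(),
    percdamp=0.01,
    blocksize=128,
)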


class MatMul4BitsQuantizer:
"""Perform 4b quantization of constant MatMul weights"""

def __init__(
self,
model: ModelProto,
model: Union[ModelProto, str],
block_size: int,
is_symmetric: bool,
accuracy_level: int | None = None,
nodes_to_exclude: list[str] | None = None,
nodes_to_exclude=None,
algo_config: WeightOnlyQuantConfig = None,
):
if nodes_to_exclude is None:
nodes_to_exclude = []
self.model = ONNXModel(model)
self.model = ONNXModel(onnx.load(model)) if isinstance(model, str) else ONNXModel(model)
self.model_path = model if isinstance(model, str) else None
self.block_size = block_size
self.is_symmetric = is_symmetric
self.accuracy_level = accuracy_level
self.nodes_to_exclude = set(nodes_to_exclude)
self.algo_config = algo_config

@staticmethod
def __get_initializer(name, graph_path: list[GraphProto]) -> tuple[TensorProto, GraphProto]:
@@ -176,20 +258,99 @@
graph_stack.pop()
return graph

def _generate_q4_node_config(self):
"""Generate weight only quant configuration for nodes."""
q4_node_config = {}
template_config_q4 = {
"bits": 4,
"group_size": self.block_size,
"scheme": "sym" if self.is_symmetric else "asym"
}
for node in self.model.model.graph.node:
if node.op_type in ["MatMul"]:
if not all([self.model.get_initializer(i) is None for i in node.input]):
q4_node_config[node.name] = template_config_q4
return q4_node_config

def int4_quant_algo(self):
"""4b quantize a model with RTN or GPTQ algorithm. Please refer to
https://github.com/intel/neural-compressor/blob/master/docs/source/quantization_weight_only.md
for more details on weight only quantization using Intel® Neural Compressor.
"""

def inc_dataloader():
data_reader = copy.deepcopy(self.algo_config.calibration_data_reader)
for data in data_reader:
yield data, None

kwargs = {}
if self.accuracy_level is not None:
kwargs["accuracy_level"] = self.accuracy_level
weight_only_node_config = self._generate_q4_node_config()

algorithm = self.algo_config.algorithm
logger.info(f"start to quantize model with {algorithm} algorithm...")
if algorithm == "RTN":
from neural_compressor.adaptor.ox_utils.weight_only import rtn_quantize

kwargs["ratios"] = self.algo_config.ratios

self.model = rtn_quantize(
model=self.model_path if self.model_path is not None else self.model.model,
weight_config=weight_only_node_config,
**kwargs,
)
elif algorithm == "GPTQ":
from neural_compressor.adaptor.ox_utils.weight_only import gptq_quantize

kwargs["percdamp"] = self.algo_config.percdamp
kwargs["blocksize"] = self.algo_config.blocksize
kwargs["actorder"] = self.algo_config.actorder
kwargs["mse"] = self.algo_config.mse
kwargs["perchannel"] = self.algo_config.perchannel
kwargs["n_samples"] = -1
dataloader = inc_dataloader()

self.model = gptq_quantize(
model=self.model_path if self.model_path is not None else self.model.model,
weight_config=weight_only_node_config,
dataloader=dataloader,
**kwargs,
)
logger.info(f"complete quantization of model with {algorithm} algorithm.")

def process(self):
# use a stack to keep track of sub-graphs
graph_stack = [self.model.graph()]
opset_import = self.model.opset_import()

has_ms_domain = False
for opset in opset_import:
if opset.domain == "com.microsoft":
has_ms_domain = True
if not has_ms_domain:
opset_import.extend([onnx.helper.make_opsetid("com.microsoft", 1)])

self._process_subgraph(graph_stack)
self.model.clean_initializers()
if self.algo_config is None:
# use a stack to keep track of sub-graphs
graph_stack = [self.model.graph()]
opset_import = self.model.opset_import()

has_ms_domain = False
for opset in opset_import:
if opset.domain == "com.microsoft":
has_ms_domain = True
if not has_ms_domain:
opset_import.extend([onnx.helper.make_opsetid("com.microsoft", 1)])

self._process_subgraph(graph_stack)
self.model.clean_initializers()
else:
# use Intel® Neural Compressor for RTN or GPTQ weight-only quantize algorithm
try:
importlib.import_module("neural_compressor")
except Exception as e:
logging.error(f"{e}.")
raise RuntimeError(
"neural-compressor is not correctly installed. Please check your environment."
) from e

import neural_compressor

assert version.parse(neural_compressor.__version__) >= version.parse(
"2.3.2"
), "Require neural-compressor >= 2.3.2 to support weight only quantization!"

self.int4_quant_algo()


def parse_args():
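As the new process() logic above shows, passing an algo_config routes quantization through Intel® Neural Compressor (version >= 2.3.2 required), while leaving it as None keeps the existing block-wise path. A rough sketch of a guard a caller might use before opting into the INC path follows; the helper itself is illustrative, not part of this PR.

import importlib

from packaging import version


def inc_available(min_version: str = "2.3.2") -> bool:
    """Return True if neural-compressor is installed and recent enough for RTN/GPTQ weight-only quantization."""
    try:
        neural_compressor = importlib.import_module("neural_compressor")
    except ImportError:
        return False
    return version.parse(neural_compressor.__version__) >= version.parse(min_version)


# Hypothetical fallback: use RTN via INC when available, otherwise the default 4-bit path.
# algo_config = RTNWeightOnlyQuantConfig() if inc_available() else None
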
11 changes: 4 additions & 7 deletions onnxruntime/python/tools/quantization/quantize.py
@@ -466,7 +466,6 @@ def quantize_static(

import copy

import onnx
from neural_compressor.adaptor.ox_utils.smooth_quant import ORTSmoothQuant

def inc_dataloader():
@@ -478,13 +477,11 @@ def inc_dataloader():
dataloader = inc_dataloader()
sq = ORTSmoothQuant(model_input, dataloader, reduce_range)
del dataloader
model = sq.transform(
extra_options.get("SmoothQuantAlpha", 0.5), extra_options.get("SmoothQuantFolding", True)
).model
nodes_to_exclude.extend([i.name for i in model.graph.node if i.name not in orig_nodes])
model = sq.transform(extra_options.get("SmoothQuantAlpha", 0.5), extra_options.get("SmoothQuantFolding", True))
sq_path = tempfile.TemporaryDirectory(prefix="ort.quant.")
model_input = Path(sq_path.name).joinpath("sq_model.onnx").as_posix()
onnx.save_model(model, model_input, save_as_external_data=True)
model_input = Path(sq_path).joinpath("sq_model.onnx").as_posix()
model.save(model_input)
nodes_to_exclude.extend([i.name for i in model.model.graph.node if i.name not in orig_nodes])
model = load_model_with_shape_infer(Path(model_input)) # use smooth quant model for calibration

with tempfile.TemporaryDirectory(prefix="ort.quant.") as quant_tmp_dir:
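For context, the SmoothQuant code modified here is reached through quantize_static's extra_options. A hedged sketch of such a call follows; SmoothQuantAlpha and SmoothQuantFolding are the keys read in this diff, while the enabling SmoothQuant switch and the reader class are assumptions for illustration.

import numpy as np

from onnxruntime.quantization import CalibrationDataReader, QuantFormat, quantize_static


class OneBatchReader(CalibrationDataReader):  # hypothetical minimal reader
    def __init__(self):
        self._batches = iter([{"input": np.random.rand(100, 52).astype(np.float32)}])

    def get_next(self):
        return next(self._batches, None)


quantize_static(
    "model_fp32.onnx",   # placeholder input path
    "model_int8.onnx",   # placeholder output path
    OneBatchReader(),
    quant_format=QuantFormat.QDQ,
    extra_options={"SmoothQuant": True, "SmoothQuantAlpha": 0.5, "SmoothQuantFolding": True},
)
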
74 changes: 73 additions & 1 deletion onnxruntime/test/python/quantization/test_op_matmul_4bits.py
@@ -71,13 +71,16 @@ def construct_model_matmul(self, output_model_path: str, symmetric: bool) -> None:
output_name = "output"
initializers = []

def make_matmul(input_name, weight_shape: Union[int, Tuple[int, ...]], weight_name: str, output_name: str):
def make_matmul(
input_name, weight_shape: Union[int, Tuple[int, ...]], weight_name: str, output_name: str, node_name: str
):
weight_data = self.fill_int4_data(weight_shape, symmetric).astype(np.float32)
initializers.append(onnx.numpy_helper.from_array(weight_data, name=weight_name))
return onnx.helper.make_node(
"MatMul",
[input_name, weight_name],
[output_name],
node_name,
)

in_features = 52
@@ -88,6 +91,7 @@ def make_matmul(…)
[in_features, out_features],
"linear1.weight",
output_name,
"MatMul_0",
)

# make graph
@@ -139,6 +143,52 @@ def quant_test(
else:
raise exception

def quant_test_with_algo(
self,
algorithm: str,
model_fp32_path: str,
data_reader: TestDataFeeds,
block_size: int,
is_symmetric: bool,
):
model_int4_path = str(
Path(self._tmp_model_dir.name).joinpath(f"MatMulNBits_{block_size}_{is_symmetric}.onnx").absolute()
)

# Quantize fp32 model to int4 model
from onnxruntime.quantization import matmul_4bits_quantizer

algo_config = None
if algorithm == "RTN":
# test RTN algorithm
from onnxruntime.quantization import RTNWeightOnlyQuantConfig

algo_config = RTNWeightOnlyQuantConfig()
elif algorithm == "GPTQ":
# test GPTQ algorithm
from onnxruntime.quantization import GPTQWeightOnlyQuantConfig

algo_config = GPTQWeightOnlyQuantConfig(calibration_data_reader=data_reader)

model = quant_utils.load_model_with_shape_infer(Path(model_fp32_path))
quant = matmul_4bits_quantizer.MatMul4BitsQuantizer(model, block_size, is_symmetric, algo_config=algo_config)
quant.process()
quant.model.save_model_to_file(model_int4_path, False)

quant_nodes = {"MatMulNBits": 1}
check_op_type_count(self, model_int4_path, **quant_nodes)

data_reader.rewind()

try:
check_model_correctness(self, model_fp32_path, model_int4_path, data_reader.get_next())
except Exception as exception:
if "4b quantization not yet supported on this hardware platform!" in exception.args[0]:
# Currently we don't have int4 quantization support on all platforms, has to tolerate this exception
pass
else:
raise exception

@unittest.skipIf(
find_spec("onnxruntime.training"), "Skip because training package doesn't has quantize_matmul_4bits"
)
@@ -159,6 +209,28 @@ def test_quantize_matmul_int4_offsets(self):
data_reader = self.input_feeds(1, {"input": [100, 52]})
self.quant_test(model_fp32_path, data_reader, 32, False)

@unittest.skipIf(
find_spec("onnxruntime.training"), "Skip because training package doesn't has quantize_matmul_4bits"
)
def test_quantize_matmul_int4_using_rtn_algo(self):
if not find_spec("neural_compressor"):
self.skipTest("skip test_smooth_quant since neural_compressor is not installed")
model_fp32_path = str(Path(self._tmp_model_dir.name).joinpath("matmul_fp32_offset.onnx").absolute())
self.construct_model_matmul(model_fp32_path, symmetric=False)
data_reader = self.input_feeds(1, {"input": [100, 52]})
self.quant_test_with_algo("RTN", model_fp32_path, data_reader, 32, False)

@unittest.skipIf(
find_spec("onnxruntime.training"), "Skip because training package doesn't has quantize_matmul_4bits"
)
def test_quantize_matmul_int4_using_gptq_algo(self):
if not find_spec("neural_compressor"):
self.skipTest("skip test_smooth_quant since neural_compressor is not installed")
model_fp32_path = str(Path(self._tmp_model_dir.name).joinpath("matmul_fp32_offset.onnx").absolute())
self.construct_model_matmul(model_fp32_path, symmetric=False)
data_reader = self.input_feeds(1, {"input": [100, 52]})
self.quant_test_with_algo("GPTQ", model_fp32_path, data_reader, 32, False)


if __name__ == "__main__":
unittest.main()