Support INT4 weight-only quantization, including two algorithms: RTN and GPTQ #17390

Merged 30 commits from yuwenzho/int4 into main on Jan 10, 2024. The diff below shows the changes from 4 of the 30 commits.
Commits:

3274b34 add weight only quantize (yuwenzho, Sep 1, 2023)
ad6b8f6 Merge branch 'main' into yuwenzho/int4 (yuwenzho, Sep 14, 2023)
b232f2e Merge branch 'main' into yuwenzho/int4 (yuwenzho, Sep 20, 2023)
2f867e2 update inc API usage (yuwenzho, Sep 20, 2023)
32f7eae update format (yuwenzho, Oct 10, 2023)
41e96a7 Merge branch 'main' into yuwenzho/int4 (yuwenzho, Oct 10, 2023)
a8382ac add accuracy_level attr (mengniwang95, Nov 10, 2023)
930ce53 Merge branch 'main' into yuwenzho/int4 (yuwenzho, Nov 21, 2023)
ee73812 update usage of RTN & GPTQ algorithm (yuwenzho, Nov 22, 2023)
21fc0c4 fix typo (yuwenzho, Nov 22, 2023)
8f6ea60 Merge branch 'main' into yuwenzho/int4 (yuwenzho, Nov 22, 2023)
b2b9d66 update usage of RTN & GPTQ algorithm (yuwenzho, Nov 23, 2023)
f3a91ca Update matmul_4bits_quantizer.py (mengniwang95, Nov 23, 2023)
fd02c61 fix for code scan (yuwenzho, Nov 23, 2023)
673d1d4 update MatMul4BitsQuantizer args (yuwenzho, Dec 11, 2023)
32fbf0a Merge branch 'main' into yuwenzho/int4 (yuwenzho, Dec 11, 2023)
4aa9318 add log for woq (yuwenzho, Dec 15, 2023)
61c9bbc fix bug in sq (yuwenzho, Dec 15, 2023)
87d825d Merge branch 'main' into yuwenzho/int4 (yuwenzho, Dec 15, 2023)
0b46845 Update matmul_4bits_quantizer.py (mengniwang95, Dec 25, 2023)
4f30e8b Update test_op_matmul_4bits.py (mengniwang95, Dec 25, 2023)
bb25597 Merge branch 'microsoft:main' into yuwenzho/int4 (mengniwang95, Dec 25, 2023)
6cc8557 Update matmul_4bits_quantizer.py (mengniwang95, Dec 30, 2023)
a141bc3 Update quantize.py (mengniwang95, Dec 30, 2023)
c16c001 fix conflict (yuwenzho, Jan 8, 2024)
0c035c4 fix for lint (yuwenzho, Jan 9, 2024)
475f388 Merge branch 'main' into yuwenzho/int4 (yuwenzho, Jan 9, 2024)
83b3ed7 fix for lint (yuwenzho, Jan 10, 2024)
81390c0 fix import (yuwenzho, Jan 10, 2024)
62438b8 Merge branch 'main' into yuwenzho/int4 (yuwenzho, Jan 10, 2024)
5 changes: 4 additions & 1 deletion onnxruntime/python/tools/quantization/__init__.py
@@ -14,4 +14,7 @@
 from .quantize import quantize  # noqa: F401
 from .quantize import quantize_dynamic  # noqa: F401
 from .quantize import quantize_static  # noqa: F401
-from .shape_inference import quant_pre_process  # noqa: F401
+from .quantize_weight_only import RTNWeightOnlyQuantConfig  # noqa: F401
+from .quantize_weight_only import GPTQWeightOnlyQuantConfig  # noqa: F401
+from .quantize_weight_only import quantize_weight_only  # noqa: F401
+from .shape_inference import quant_pre_process  # noqa: F401
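
With this hunk applied, the new weight-only quantization entry points become part of the package's public API. A minimal sketch of the imports a caller can now use (assuming a build that includes this change):

from onnxruntime.quantization import (
    GPTQWeightOnlyQuantConfig,
    RTNWeightOnlyQuantConfig,
    quantize_weight_only,
)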

208 changes: 208 additions & 0 deletions onnxruntime/python/tools/quantization/quantize_weight_only.py
@@ -0,0 +1,208 @@
import copy
import importlib
import logging
from pathlib import Path

from packaging import version

from .calibrate import CalibrationDataReader
from .quant_utils import load_model_with_shape_infer


class WeightOnlyQuantConfig:
    def __init__(
        self,
        algorithm,
        group_size=32,
        scheme="sym",
        use_external_data_format=False,
    ):
        """This is the Base class for Weight Only Quant Configuration.

        Args:
            algorithm:
                weight only quantize algorithm name.
            group_size:
                how many elements share one scale/zero point. -1 indicates per-channel
                quantization along the output channel.
            scheme:
                whether to use symmetric or asymmetric quantization for weights.
            use_external_data_format:
                option used for large (>2GB) models. Set to False by default.
        """
        self.algorithm = algorithm
        self.group_size = group_size
        self.scheme = scheme
        self.use_external_data_format = use_external_data_format

class RTNWeightOnlyQuantConfig(WeightOnlyQuantConfig):
    def __init__(
        self,
        group_size=32,
        scheme="sym",
        ratios=None,
        use_external_data_format=False,
    ):
        """
        This is a class for round-to-nearest (RTN) algorithm Weight Only Quant Configuration.
        RTN is the most straightforward way to quantize weights using scale maps.

        Args:
            group_size:
                how many elements share one scale/zero point. -1 indicates per-channel
                quantization along the output channel.
            scheme:
                whether to use symmetric or asymmetric quantization for weights.
            ratios:
                optional dict passed through to rtn_quantize. Defaults to an empty dict.
            use_external_data_format:
                option used for large (>2GB) models. Set to False by default.
        """
        super().__init__(
            algorithm="RTN",
            group_size=group_size,
            scheme=scheme,
            use_external_data_format=use_external_data_format,
        )
        self.ratios = ratios if ratios is not None else {}

class GPTQWeightOnlyQuantConfig(WeightOnlyQuantConfig):
    def __init__(
        self,
        calibration_data_reader: CalibrationDataReader,
        group_size=32,
        scheme="asym",
        percdamp=0.01,
        blocksize=128,
        actorder=False,
        mse=False,
        perchannel=True,
        use_external_data_format=False,
    ):
        """
        This is a class for GPTQ algorithm Weight Only Quant Configuration.
        The GPTQ algorithm provides more accurate quantization but requires more computational resources.

        Args:
            calibration_data_reader:
                a calibration data reader. It enumerates calibration data and generates inputs for the original model.
            group_size:
                how many elements share one scale/zero point. -1 indicates per-channel
                quantization along the output channel.
            scheme:
                whether to use symmetric or asymmetric quantization for weights.
            percdamp:
                percent of the average Hessian diagonal to use for dampening.
            blocksize (int, optional):
                number of channels processed in one GPTQ quantization iteration.
            actorder (bool, optional):
                whether to reorder the Hessian matrix by the magnitude of its diagonal values.
            mse (bool, optional):
                whether to pick the scale and zero point that minimize MSE.
            perchannel (bool, optional):
                whether to quantize weights per channel.
            use_external_data_format:
                option used for large (>2GB) models. Set to False by default.
        """
        super().__init__(
            algorithm="GPTQ",
            group_size=group_size,
            scheme=scheme,
            use_external_data_format=use_external_data_format,
        )
        self.calibration_data_reader = calibration_data_reader
        self.percdamp = percdamp
        self.blocksize = blocksize
        self.actorder = actorder
        self.mse = mse
        self.perchannel = perchannel
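
GPTQ needs calibration inputs, so callers must supply a CalibrationDataReader. A minimal sketch of such a reader, assuming a model with a single input named "input" of shape [1, 32] (the same shape the test below uses); RandomCalibrationDataReader is a hypothetical name for illustration:

import numpy as np

from onnxruntime.quantization import CalibrationDataReader


class RandomCalibrationDataReader(CalibrationDataReader):
    """Hypothetical reader that yields random calibration batches."""

    def __init__(self, num_samples=10):
        # Pre-generate feed dicts; each maps the model's input names to numpy arrays.
        self._feeds = iter(
            [{"input": np.random.rand(1, 32).astype(np.float32)} for _ in range(num_samples)]
        )

    def get_next(self):
        # Return the next feed dict, or None once calibration data is exhausted.
        return next(self._feeds, None)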

def _generate_weight_only_node_config(model, group_size, scheme):
    """Generate a weight-only quant configuration for nodes.

    Args:
        model:
            onnx.ModelProto.
        group_size:
            how many elements share one scale/zero point. -1 indicates per-channel
            quantization along the output channel.
        scheme:
            whether to use symmetric or asymmetric quantization for weights.

    Returns:
        dict: weight only quant configuration for nodes.
    """
    weight_only_node_config = {}
    template_config = {"bits": 4, "group_size": group_size, "scheme": scheme}
    for node in model.graph.node:
        if node.op_type in ["MatMul"]:
            weight_only_node_config[node.name] = template_config
    return weight_only_node_config
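
For illustration: on the two-node test model defined further down, whose only MatMul is named "MatMul_1", the helper above returns the following mapping with the defaults group_size=32 and scheme="sym":

{"MatMul_1": {"bits": 4, "group_size": 32, "scheme": "sym"}}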


def quantize_weight_only(
    model_input: Path,
    model_output: Path,
    weight_only_config: WeightOnlyQuantConfig,
):
    """Weight Only Quantize a model with WeightOnlyQuantConfig. Please refer to
    https://github.com/intel/neural-compressor/blob/master/docs/source/quantization_weight_only.md
    for more details on weight only quantization.

    Args:
        model_input (Path): Path to the model to weight only quantize.
        model_output (Path): Path to save the quantized model.
        weight_only_config (WeightOnlyQuantConfig): Weight Only Quantization Configuration.

    Raises:
        RuntimeError: Raise RuntimeError if neural-compressor is not correctly installed.
    """
    try:
        importlib.import_module("neural_compressor")
    except Exception as e:
        logging.error(f"{e}.")
        raise RuntimeError("neural-compressor is not correctly installed. Please check your environment.") from e

    import neural_compressor

    assert version.parse(neural_compressor.__version__) >= version.parse(
        "2.3.0"
    ), "Require neural-compressor >= 2.3.0 to support weight only quantization!"

    def inc_dataloader():
        data_reader = copy.deepcopy(weight_only_config.calibration_data_reader)
        for data in data_reader:
            yield data, None

    model = load_model_with_shape_infer(Path(model_input))
    scheme = weight_only_config.scheme
    group_size = weight_only_config.group_size
    weight_only_node_config = _generate_weight_only_node_config(model, group_size, scheme)

    algorithm = weight_only_config.algorithm
    if algorithm == "RTN":
        from neural_compressor.adaptor.ox_utils.weight_only import rtn_quantize

        ratios = weight_only_config.ratios

        model = rtn_quantize(
            model=model_input,
            weight_config=weight_only_node_config,
            ratios=ratios,
        )
    elif algorithm == "GPTQ":
        from neural_compressor.adaptor.ox_utils.weight_only import gptq_quantize

        percdamp = weight_only_config.percdamp
        blocksize = weight_only_config.blocksize
        actorder = weight_only_config.actorder
        mse = weight_only_config.mse
        perchannel = weight_only_config.perchannel
        dataloader = inc_dataloader()

        model = gptq_quantize(
            model=model_input,
            weight_config=weight_only_node_config,
            dataloader=dataloader,
            n_samples=-1,
            percdamp=percdamp,
            blocksize=blocksize,
            actorder=actorder,
            mse=mse,
            perchannel=perchannel,
        )

    model.save_model_to_file(model_output, weight_only_config.use_external_data_format)
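
Putting the pieces together, a minimal end-to-end sketch of both algorithms. The model paths are placeholders, RandomCalibrationDataReader is the hypothetical reader sketched earlier, and neural-compressor >= 2.3.0 must be installed:

from onnxruntime.quantization import (
    GPTQWeightOnlyQuantConfig,
    RTNWeightOnlyQuantConfig,
    quantize_weight_only,
)

# RTN: data-free round-to-nearest quantization of MatMul weights.
rtn_config = RTNWeightOnlyQuantConfig(group_size=32, scheme="sym")
quantize_weight_only("model_fp32.onnx", "model_int4_rtn.onnx", rtn_config)

# GPTQ: calibration-based quantization driven by a CalibrationDataReader.
gptq_config = GPTQWeightOnlyQuantConfig(
    calibration_data_reader=RandomCalibrationDataReader(num_samples=10),
    group_size=32,
    scheme="asym",
)
quantize_weight_only("model_fp32.onnx", "model_int4_gptq.onnx", gptq_config)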
127 changes: 127 additions & 0 deletions onnxruntime/test/python/quantization/test_quantize_weight_only.py
@@ -0,0 +1,127 @@
#!/usr/bin/env python

# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------

import tempfile
import unittest
from importlib.util import find_spec
from pathlib import Path

import numpy as np
import onnx
from onnx import TensorProto, helper
from op_test_utils import check_model_correctness, input_feeds_neg_one_zero_one

from onnxruntime.quantization import (
    GPTQWeightOnlyQuantConfig,
    RTNWeightOnlyQuantConfig,
    quantize_weight_only,
)
from onnxruntime.quantization.onnx_model import ONNXModel

def construct_model(output_model_path):
    #    (input)
    #       |
    #      Mul
    #       |
    #     MatMul
    #       |
    #    (output)
    initializers = []

    # make mul node
    mul_data = np.random.normal(0, 0.1, [1, 32]).astype(np.float32)
    initializers.append(onnx.numpy_helper.from_array(mul_data, name="mul.data"))
    mul_node = onnx.helper.make_node("Mul", ["input", "mul.data"], ["mul.output"], "Mul_0")

    # make matmul node
    matmul_weight = np.random.normal(0, 0.1, [32, 1]).astype(np.float32)
    initializers.append(onnx.numpy_helper.from_array(matmul_weight, name="matmul.weight"))
    matmul_node = onnx.helper.make_node("MatMul", ["mul.output", "matmul.weight"], ["output"], "MatMul_1")

    # make graph
    input_tensor = helper.make_tensor_value_info("input", TensorProto.FLOAT, [1, 32])
    output_tensor = helper.make_tensor_value_info("output", TensorProto.FLOAT, [1, 1])
    graph_name = "weight_only_quant_test"
    graph = helper.make_graph(
        [mul_node, matmul_node],
        graph_name,
        [input_tensor],
        [output_tensor],
        initializer=initializers,
    )
    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
    model.ir_version = onnx.IR_VERSION

    onnx.save(model, output_model_path)

class TestWeightOnlyQuantization(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        # TODO: there will be a refactor to handle all those temporary directories.
        cls._tmp_model_dir = tempfile.TemporaryDirectory(prefix="ort.quant.save.as.external")
        cls._model_fp32_path = str(Path(cls._tmp_model_dir.name) / "fp32.onnx")
        cls._model_weight_only_path = str(Path(cls._tmp_model_dir.name) / "fp32.weight_only_quant.onnx")
        np.random.seed(1)
        construct_model(cls._model_fp32_path)

    @classmethod
    def tearDownClass(cls):
        cls._tmp_model_dir.cleanup()

    @unittest.skip(
        "Skip failed test in Python Packaging Test Pipeline. "
        "During importing neural_compressor, pycocotools throws ValueError: numpy.ndarray size changed"
    )
    def test_quantize_weight_only_rtn(self):
        if not find_spec("neural_compressor"):
            self.skipTest("skip test_quantize_weight_only_rtn since neural_compressor is not installed")

        weight_only_config = RTNWeightOnlyQuantConfig()
        quantize_weight_only(self._model_fp32_path, self._model_weight_only_path, weight_only_config)
        check_model_correctness(
            self,
            self._model_fp32_path,
            self._model_weight_only_path,
            {"input": np.random.rand(1, 32).astype(np.float32)},
        )

        model_fp32 = ONNXModel(onnx.load(self._model_fp32_path))
        model_weight_only = ONNXModel(onnx.load(self._model_weight_only_path))
        self.assertNotEqual(
            model_fp32.get_initializer("matmul.weight"),
            model_weight_only.get_initializer("matmul.weight"),
        )

    @unittest.skip(
        "Skip failed test in Python Packaging Test Pipeline. "
        "During importing neural_compressor, pycocotools throws ValueError: numpy.ndarray size changed"
    )
    def test_quantize_weight_only_gptq(self):
        if not find_spec("neural_compressor"):
            self.skipTest("skip test_quantize_weight_only_gptq since neural_compressor is not installed")

        data_reader = input_feeds_neg_one_zero_one(10, {"input": [1, 32]})
        weight_only_config = GPTQWeightOnlyQuantConfig(data_reader)
        quantize_weight_only(self._model_fp32_path, self._model_weight_only_path, weight_only_config)
        check_model_correctness(
            self,
            self._model_fp32_path,
            self._model_weight_only_path,
            {"input": np.random.rand(1, 32).astype(np.float32)},
        )

        model_fp32 = ONNXModel(onnx.load(self._model_fp32_path))
        model_weight_only = ONNXModel(onnx.load(self._model_weight_only_path))
        self.assertNotEqual(
            model_fp32.get_initializer("matmul.weight"),
            model_weight_only.get_initializer("matmul.weight"),
        )


if __name__ == "__main__":
    unittest.main()
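
Assuming neural_compressor is installed and the skip decorators are removed, the suite runs directly (op_test_utils must be importable from the same test directory):

python test_quantize_weight_only.py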