Add MatMulNBits accuracy_level parameter to quantization utilities. (#19015)

Allow the MatMulNBits `accuracy_level` attribute (added in #17669) to be set to a particular value when the model is quantized.
edgchen1 authored Jan 5, 2024
1 parent efdcefc commit 4190c29
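
With this change, a caller can pass the desired accuracy level straight through the quantizer API. Below is a minimal sketch of the new usage; the model paths are placeholders, and the value 4 is an illustrative choice (per the MatMulNBits contrib op documentation, level 4 requests int8 computation):

import onnx
from onnxruntime.quantization.matmul_4bits_quantizer import MatMul4BitsQuantizer

model = onnx.load("model_fp32.onnx")  # placeholder input path
quant = MatMul4BitsQuantizer(
    model=model,
    block_size=32,
    is_symmetric=True,
    accuracy_level=4,  # optional; omit to leave the attribute unset
)
quant.process()
quant.model.save_model_to_file("model_int4.onnx", True)  # second arg: use external data format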
Showing 4 changed files with 64 additions and 21 deletions.
37 changes: 31 additions & 6 deletions onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py
@@ -4,10 +4,11 @@
# license information.
# --------------------------------------------------------------------------

+from __future__ import annotations

import argparse
import logging
import os
-from typing import List, Tuple

import numpy as np
import numpy.typing as npt
@@ -26,16 +27,24 @@
class MatMul4BitsQuantizer:
"""Perform 4b quantization of constant MatMul weights"""

-def __init__(self, model: ModelProto, block_size: int, is_symmetric: bool, nodes_to_exclude=None):
+def __init__(
+    self,
+    model: ModelProto,
+    block_size: int,
+    is_symmetric: bool,
+    accuracy_level: int | None = None,
+    nodes_to_exclude: list[str] | None = None,
+):
if nodes_to_exclude is None:
nodes_to_exclude = []
self.model = ONNXModel(model)
self.block_size = block_size
self.is_symmetric = is_symmetric
+self.accuracy_level = accuracy_level
self.nodes_to_exclude = set(nodes_to_exclude)

@staticmethod
-def __get_initializer(name, graph_path: List[GraphProto]) -> Tuple[TensorProto, GraphProto]:
+def __get_initializer(name, graph_path: list[GraphProto]) -> tuple[TensorProto, GraphProto]:
for gid in range(len(graph_path) - 1, -1, -1):
graph = graph_path[gid]
for tensor in graph.initializer:
@@ -66,7 +75,7 @@ def int4_block_quant(self, fp32weight: npt.ArrayLike) -> np.ndarray:

return (packed, scales, zero_point)

-def _q4_matmul_node_weight(self, node: NodeProto, graph_stack: List[GraphProto]) -> NodeProto:
+def _q4_matmul_node_weight(self, node: NodeProto, graph_stack: list[GraphProto]) -> NodeProto:
"""If the node is MatMul with fp32 const weight, quantize the weight with int4, and return the new node"""

if node.op_type != "MatMul":
@@ -113,6 +122,8 @@ def _q4_matmul_node_weight(self, node: NodeProto, graph_stack: List[GraphProto])
kwargs["N"] = cols
kwargs["bits"] = 4
kwargs["block_size"] = self.block_size
+if self.accuracy_level is not None:
+    kwargs["accuracy_level"] = self.accuracy_level

matmul_q4_node = onnx.helper.make_node(
"MatMulNBits",
@@ -127,7 +138,7 @@ def _q4_matmul_node_weight(self, node: NodeProto, graph_stack: List[GraphProto])

return matmul_q4_node

-def _process_subgraph(self, graph_stack: List[GraphProto]):
+def _process_subgraph(self, graph_stack: list[GraphProto]):
new_nodes = []
graph = graph_stack[-1]

@@ -201,6 +212,14 @@ def parse_args():
type=bool,
help="Indicate whether to quantize the model symmetrically",
)
+parser.add_argument(
+    "--accuracy_level",
+    required=False,
+    type=int,
+    help="Accuracy level of the 4-bit quantized MatMul computation. "
+    "Refer to the MatMulNBits contrib op's 'accuracy_level' attribute for details "
+    "(https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#commicrosoftmatmulnbits).",
+)
parser.add_argument("-v", "--verbose", required=False, action="store_true")
parser.set_defaults(verbose=False)
parser.add_argument(
@@ -228,6 +247,12 @@ def parse_args():
raise Exception(f"file {output_model_path} already exists")

model = onnx.load(input_model_path)
-quant = MatMul4BitsQuantizer(model, args.block_size, args.symmetric, nodes_to_exclude=args.nodes_to_exclude)
+quant = MatMul4BitsQuantizer(
+    model=model,
+    block_size=args.block_size,
+    is_symmetric=args.symmetric,
+    accuracy_level=args.accuracy_level,
+    nodes_to_exclude=args.nodes_to_exclude,
+)
quant.process()
quant.model.save_model_to_file(output_model_path, True)
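
The script can then be invoked with the new flag on the command line. A hedged sketch of such an invocation (the --input_model/--output_model flag names are assumptions inferred from the input_model_path/output_model_path variables above, not confirmed by this diff; 4 is an illustrative accuracy level):

python -m onnxruntime.quantization.matmul_4bits_quantizer --input_model model_fp32.onnx --output_model model_int4.onnx --block_size 32 --accuracy_level 4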
@@ -1,9 +1,10 @@
+from __future__ import annotations

import argparse
import logging
import os
import shutil
from itertools import chain
-from typing import List

import onnx
import torch
@@ -21,11 +22,12 @@
from onnxruntime import quantization as ort_quantization
from onnxruntime.quantization.matmul_4bits_quantizer import MatMul4BitsQuantizer

+torch_export_onnx_opset_version = 14
logger = logging.getLogger("")
init_dist()


-def get_model_dynamic_axes(input_names: List[str], output_names: List[str]):
+def get_model_dynamic_axes(input_names: list[str], output_names: list[str]):
dynamic_axes = {}
for name in input_names + output_names:
if name in input_names:
@@ -42,7 +44,7 @@ def get_model_dynamic_axes(input_names: List[str], output_names: List[str]):
return dynamic_axes


-def get_model_with_past_kv_dynamic_axes(input_names: List[str], output_names: List[str]):
+def get_model_with_past_kv_dynamic_axes(input_names: list[str], output_names: list[str]):
dynamic_axes = {}
for name in input_names + output_names:
if name in {"input_ids", "position_ids"}:
@@ -65,7 +67,7 @@ def get_model_with_past_kv_dynamic_axes(input_names: List[str], output_names: List[str]):
return dynamic_axes


-def get_merged_model_dynamic_axes(input_names: List[str], output_names: List[str]):
+def get_merged_model_dynamic_axes(input_names: list[str], output_names: list[str]):
dynamic_axes = {}
for name in input_names + output_names:
if name in {"input_ids", "position_ids"}:
@@ -229,7 +231,7 @@ def run_torchscript_separate_export(
input_names=input_names,
output_names=output_names,
dynamic_axes=dynamic_axes,
-opset_version=13,
+opset_version=torch_export_onnx_opset_version,
do_constant_folding=True,
verbose=args.verbose,
)
@@ -288,7 +290,7 @@ def run_torchscript_separate_export(
input_names=input_names,
output_names=output_names,
dynamic_axes=dynamic_axes,
-opset_version=13,
+opset_version=torch_export_onnx_opset_version,
do_constant_folding=True,
verbose=args.verbose,
)
@@ -368,7 +370,7 @@ def run_torchscript_merged_export(
input_names=input_names,
output_names=output_names,
dynamic_axes=dynamic_axes,
-opset_version=13,
+opset_version=torch_export_onnx_opset_version,
do_constant_folding=True,
verbose=args.verbose,
)
@@ -412,7 +414,7 @@ def optimize_export(config: AutoConfig, input_path: str, output_path: str, remov


def convert_to_float16(
-args: argparse.Namespace, config: AutoConfig, old_paths: List[str], rank: int = 0, world_size: int = 1
+args: argparse.Namespace, config: AutoConfig, old_paths: list[str], rank: int = 0, world_size: int = 1
):
decoder_model_fp16_path = os.path.join(args.output, f"rank_{rank}_{args.model_name}_decoder_model_fp16.onnx")
decoder_with_past_model_fp16_path = os.path.join(
@@ -635,7 +637,7 @@ def get_args():
help="Run a specific quantization algorithm (blockwise for int4, smooth_quant for int8, quantize_dynamic for int8). Blockwise is recommended. Need to install extra packages in `requirements-quant.txt` for SmoothQuant.",
)

blockwise_group = parser.add_argument_group("4-bit quantization")
blockwise_group = parser.add_argument_group("blockwise (4-bit quantization)")

blockwise_group.add_argument(
"--block_size",
@@ -645,6 +647,15 @@
help="Block size to quantize with. See https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py for details.",
)

+blockwise_group.add_argument(
+    "--int4_accuracy_level",
+    required=False,
+    type=int,
+    help="Accuracy level of the 4-bit quantized MatMul computation. "
+    "Refer to the MatMulNBits contrib op's 'accuracy_level' attribute for details "
+    "(https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#commicrosoftmatmulnbits).",
+)

smooth_quant_group = parser.add_argument_group("smooth_quant (8-bit quantization)")

smooth_quant_group.add_argument(
@@ -937,7 +948,13 @@ def main():
for fp_path, int4_path in zip(old_paths, new_paths):
if os.path.exists(fp_path):
model = onnx.load_model(fp_path, load_external_data=True)
-quant = MatMul4BitsQuantizer(model, args.block_size, is_symmetric=True, nodes_to_exclude=[])
+quant = MatMul4BitsQuantizer(
+    model=model,
+    block_size=args.block_size,
+    is_symmetric=True,
+    accuracy_level=args.int4_accuracy_level,
+    nodes_to_exclude=[],
+)
quant.process()
quant.model.save_model_to_file(int4_path, use_external_data_format=True)
del model
@@ -1,4 +1,4 @@
-from typing import List, Tuple
+from __future__ import annotations

import numpy as np
import torch
@@ -235,7 +235,7 @@ def get_past_kv_inputs(config: AutoConfig, batch_size: int, past_seq_len: int, u


# Convert list of past_key_values to dict of past_key and past_value
-def flatten_past_kv_inputs(past_key_values: List[Tuple[torch.Tensor, torch.Tensor]]):
+def flatten_past_kv_inputs(past_key_values: list[tuple[torch.Tensor, torch.Tensor]]):
past_kv = {}
for i, (past_k, past_v) in enumerate(past_key_values):
past_kv[f"past_key_values.{i}.key"] = past_k.detach().cpu().numpy()
@@ -1,8 +1,9 @@
+from __future__ import annotations

import argparse
import logging
import os
import time
-from typing import List

import numpy as np
import torch
@@ -139,7 +140,7 @@ def verify_parity(
return kv_cache_ortvalues


-def get_args(argv: List[str]):
+def get_args(argv: list[str]):
parser = argparse.ArgumentParser()

parser.add_argument(
@@ -232,7 +233,7 @@ def get_args(argv: List[str]):
return args


-def main(argv: List[str] = []): # noqa: B006
+def main(argv: list[str] = []): # noqa: B006
args = get_args(argv)
setup_logger(args.verbose)
logger.info(f"Arguments: {args}")

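For illustration, the MatMulNBits node that the quantizer emits now carries the attribute directly. A rough sketch of building an equivalent node by hand with onnx.helper.make_node, mirroring the kwargs set in _q4_matmul_node_weight above (tensor names and the K/N values are placeholders, not taken from this commit):

import onnx

matmul_q4_node = onnx.helper.make_node(
    "MatMulNBits",
    inputs=["A", "B_quant", "scales", "zero_points"],  # placeholder tensor names
    outputs=["Y"],
    name="MatMul_Q4_example",
    domain="com.microsoft",
    K=4096,  # placeholder: input feature dimension
    N=4096,  # placeholder: output feature dimension
    bits=4,
    block_size=32,
    accuracy_level=4,  # the attribute this commit makes configurable
)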