diff --git a/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py b/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py
index 9f90196e301e5..6293bcbbf95bd 100644
--- a/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py
+++ b/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py
@@ -4,10 +4,11 @@
 # license information.
 # --------------------------------------------------------------------------
 
+from __future__ import annotations
+
 import argparse
 import logging
 import os
-from typing import List, Tuple
 
 import numpy as np
 import numpy.typing as npt
@@ -26,16 +27,24 @@
 class MatMul4BitsQuantizer:
     """Perform 4b quantization of constant MatMul weights"""
 
-    def __init__(self, model: ModelProto, block_size: int, is_symmetric: bool, nodes_to_exclude=None):
+    def __init__(
+        self,
+        model: ModelProto,
+        block_size: int,
+        is_symmetric: bool,
+        accuracy_level: int | None = None,
+        nodes_to_exclude: list[str] | None = None,
+    ):
         if nodes_to_exclude is None:
             nodes_to_exclude = []
         self.model = ONNXModel(model)
         self.block_size = block_size
         self.is_symmetric = is_symmetric
+        self.accuracy_level = accuracy_level
         self.nodes_to_exclude = set(nodes_to_exclude)
 
     @staticmethod
-    def __get_initializer(name, graph_path: List[GraphProto]) -> Tuple[TensorProto, GraphProto]:
+    def __get_initializer(name, graph_path: list[GraphProto]) -> tuple[TensorProto, GraphProto]:
         for gid in range(len(graph_path) - 1, -1, -1):
             graph = graph_path[gid]
             for tensor in graph.initializer:
@@ -66,7 +75,7 @@ def int4_block_quant(self, fp32weight: npt.ArrayLike) -> np.ndarray:
 
         return (packed, scales, zero_point)
 
-    def _q4_matmul_node_weight(self, node: NodeProto, graph_stack: List[GraphProto]) -> NodeProto:
+    def _q4_matmul_node_weight(self, node: NodeProto, graph_stack: list[GraphProto]) -> NodeProto:
         """If the node is MatMul with fp32 const weight, quantize the weight with int4, and return the new node"""
 
         if node.op_type != "MatMul":
@@ -113,6 +122,8 @@ def _q4_matmul_node_weight(self, node: NodeProto, graph_stack: List[GraphProto])
         kwargs["N"] = cols
         kwargs["bits"] = 4
         kwargs["block_size"] = self.block_size
+        if self.accuracy_level is not None:
+            kwargs["accuracy_level"] = self.accuracy_level
 
         matmul_q4_node = onnx.helper.make_node(
             "MatMulNBits",
@@ -127,7 +138,7 @@ def _q4_matmul_node_weight(self, node: NodeProto, graph_stack: List[GraphProto])
 
         return matmul_q4_node
 
-    def _process_subgraph(self, graph_stack: List[GraphProto]):
+    def _process_subgraph(self, graph_stack: list[GraphProto]):
         new_nodes = []
         graph = graph_stack[-1]
 
@@ -201,6 +212,14 @@ def parse_args():
         type=bool,
         help="Indicate whether to quantize the model symmetrically",
     )
+    parser.add_argument(
+        "--accuracy_level",
+        required=False,
+        type=int,
+        help="Accuracy level of the 4-bit quantized MatMul computation. "
+        "Refer to the MatMulNBits contrib op's 'accuracy_level' attribute for details "
+        "(https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#commicrosoftmatmulnbits).",
+    )
    parser.add_argument("-v", "--verbose", required=False, action="store_true")
     parser.set_defaults(verbose=False)
     parser.add_argument(
@@ -228,6 +247,12 @@ def parse_args():
         raise Exception(f"file {output_model_path} already exists")
 
     model = onnx.load(input_model_path)
-    quant = MatMul4BitsQuantizer(model, args.block_size, args.symmetric, nodes_to_exclude=args.nodes_to_exclude)
+    quant = MatMul4BitsQuantizer(
+        model=model,
+        block_size=args.block_size,
+        is_symmetric=args.symmetric,
+        accuracy_level=args.accuracy_level,
+        nodes_to_exclude=args.nodes_to_exclude,
+    )
     quant.process()
     quant.model.save_model_to_file(output_model_path, True)
diff --git a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py
index e694b5050cc8c..bc09b52574a27 100644
--- a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py
+++ b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py
@@ -1,9 +1,10 @@
+from __future__ import annotations
+
 import argparse
 import logging
 import os
 import shutil
 from itertools import chain
-from typing import List
 
 import onnx
 import torch
@@ -21,11 +22,12 @@
 from onnxruntime import quantization as ort_quantization
 from onnxruntime.quantization.matmul_4bits_quantizer import MatMul4BitsQuantizer
 
+torch_export_onnx_opset_version = 14
 logger = logging.getLogger("")
 init_dist()
 
 
-def get_model_dynamic_axes(input_names: List[str], output_names: List[str]):
+def get_model_dynamic_axes(input_names: list[str], output_names: list[str]):
     dynamic_axes = {}
     for name in input_names + output_names:
         if name in input_names:
@@ -42,7 +44,7 @@ def get_model_dynamic_axes(input_names: List[str], output_names: List[str]):
     return dynamic_axes
 
 
-def get_model_with_past_kv_dynamic_axes(input_names: List[str], output_names: List[str]):
+def get_model_with_past_kv_dynamic_axes(input_names: list[str], output_names: list[str]):
     dynamic_axes = {}
     for name in input_names + output_names:
         if name in {"input_ids", "position_ids"}:
@@ -65,7 +67,7 @@ def get_model_with_past_kv_dynamic_axes(input_names: List[str], output_names: Li
     return dynamic_axes
 
 
-def get_merged_model_dynamic_axes(input_names: List[str], output_names: List[str]):
+def get_merged_model_dynamic_axes(input_names: list[str], output_names: list[str]):
     dynamic_axes = {}
     for name in input_names + output_names:
         if name in {"input_ids", "position_ids"}:
@@ -229,7 +231,7 @@ def run_torchscript_separate_export(
         input_names=input_names,
         output_names=output_names,
         dynamic_axes=dynamic_axes,
-        opset_version=13,
+        opset_version=torch_export_onnx_opset_version,
         do_constant_folding=True,
         verbose=args.verbose,
     )
@@ -288,7 +290,7 @@ def run_torchscript_separate_export(
         input_names=input_names,
         output_names=output_names,
         dynamic_axes=dynamic_axes,
-        opset_version=13,
+        opset_version=torch_export_onnx_opset_version,
         do_constant_folding=True,
         verbose=args.verbose,
     )
@@ -368,7 +370,7 @@ def run_torchscript_merged_export(
         input_names=input_names,
         output_names=output_names,
         dynamic_axes=dynamic_axes,
-        opset_version=13,
+        opset_version=torch_export_onnx_opset_version,
         do_constant_folding=True,
         verbose=args.verbose,
     )
@@ -412,7 +414,7 @@ def optimize_export(config: AutoConfig, input_path: str, output_path: str, remov
 
 
 def convert_to_float16(
-    args: argparse.Namespace, config: AutoConfig, old_paths: List[str], rank: int = 0, world_size: int = 1
+    args: argparse.Namespace, config: AutoConfig, old_paths: list[str], rank: int = 0, world_size: int = 1
 ):
     decoder_model_fp16_path = os.path.join(args.output, f"rank_{rank}_{args.model_name}_decoder_model_fp16.onnx")
     decoder_with_past_model_fp16_path = os.path.join(
@@ -635,7 +637,7 @@ def get_args():
         help="Run a specific quantization algorithm (blockwise for int4, smooth_quant for int8, quantize_dynamic for int8). Blockwise is recommended. Need to install extra packages in `requirements-quant.txt` for SmoothQuant.",
     )
 
-    blockwise_group = parser.add_argument_group("4-bit quantization")
+    blockwise_group = parser.add_argument_group("blockwise (4-bit quantization)")
 
     blockwise_group.add_argument(
         "--block_size",
@@ -645,6 +647,15 @@ def get_args():
         help="Block size to quantize with. See https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py for details.",
     )
 
+    blockwise_group.add_argument(
+        "--int4_accuracy_level",
+        required=False,
+        type=int,
+        help="Accuracy level of the 4-bit quantized MatMul computation. "
+        "Refer to the MatMulNBits contrib op's 'accuracy_level' attribute for details "
+        "(https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#commicrosoftmatmulnbits).",
+    )
+
     smooth_quant_group = parser.add_argument_group("smooth_quant (8-bit quantization)")
 
     smooth_quant_group.add_argument(
@@ -937,7 +948,13 @@ def main():
         for fp_path, int4_path in zip(old_paths, new_paths):
             if os.path.exists(fp_path):
                 model = onnx.load_model(fp_path, load_external_data=True)
-                quant = MatMul4BitsQuantizer(model, args.block_size, is_symmetric=True, nodes_to_exclude=[])
+                quant = MatMul4BitsQuantizer(
+                    model=model,
+                    block_size=args.block_size,
+                    is_symmetric=True,
+                    accuracy_level=args.int4_accuracy_level,
+                    nodes_to_exclude=[],
+                )
                 quant.process()
                 quant.model.save_model_to_file(int4_path, use_external_data_format=True)
                 del model
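Worth noting: `--int4_accuracy_level` is declared without a default, so `args.int4_accuracy_level` is `None` when the flag is omitted, and the quantizer then leaves `accuracy_level` off the generated MatMulNBits nodes. A standalone sketch of that argparse behavior (the parser here is a hypothetical stand-in, not the script's own):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--int4_accuracy_level", required=False, type=int)

# Flag omitted: argparse falls back to None, which disables the attribute.
assert parser.parse_args([]).int4_accuracy_level is None

# Flag given: the value flows into MatMul4BitsQuantizer(accuracy_level=...).
assert parser.parse_args(["--int4_accuracy_level", "4"]).int4_accuracy_level == 4
```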
diff --git a/onnxruntime/python/tools/transformers/models/llama/llama_inputs.py b/onnxruntime/python/tools/transformers/models/llama/llama_inputs.py
index bae1ae82e8f7e..a329b73259dda 100644
--- a/onnxruntime/python/tools/transformers/models/llama/llama_inputs.py
+++ b/onnxruntime/python/tools/transformers/models/llama/llama_inputs.py
@@ -1,4 +1,4 @@
-from typing import List, Tuple
+from __future__ import annotations
 
 import numpy as np
 import torch
@@ -235,7 +235,7 @@ def get_past_kv_inputs(config: AutoConfig, batch_size: int, past_seq_len: int, u
 
 
 # Convert list of past_key_values to dict of past_key and past_value
-def flatten_past_kv_inputs(past_key_values: List[Tuple[torch.Tensor, torch.Tensor]]):
+def flatten_past_kv_inputs(past_key_values: list[tuple[torch.Tensor, torch.Tensor]]):
     past_kv = {}
     for i, (past_k, past_v) in enumerate(past_key_values):
         past_kv[f"past_key_values.{i}.key"] = past_k.detach().cpu().numpy()
diff --git a/onnxruntime/python/tools/transformers/models/llama/llama_parity.py b/onnxruntime/python/tools/transformers/models/llama/llama_parity.py
index 418a65325c8f0..25d7519769604 100644
--- a/onnxruntime/python/tools/transformers/models/llama/llama_parity.py
+++ b/onnxruntime/python/tools/transformers/models/llama/llama_parity.py
@@ -1,8 +1,9 @@
+from __future__ import annotations
+
 import argparse
 import logging
 import os
 import time
-from typing import List
 
 import numpy as np
 import torch
@@ -139,7 +140,7 @@ def verify_parity(
     return kv_cache_ortvalues
 
 
-def get_args(argv: List[str]):
+def get_args(argv: list[str]):
     parser = argparse.ArgumentParser()
 
     parser.add_argument(
@@ -232,7 +233,7 @@ def get_args(argv: List[str]):
     return args
 
 
-def main(argv: List[str] = []):  # noqa: B006
+def main(argv: list[str] = []):  # noqa: B006
     args = get_args(argv)
     setup_logger(args.verbose)
     logger.info(f"Arguments: {args}")
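All four files also pick up `from __future__ import annotations`, which is what lets the modernized `list[str]` / `int | None` annotations above parse on interpreters predating PEP 604 (Python 3.10): annotations are stored as strings instead of being evaluated at definition time. A standalone illustration of the mechanism (not from this diff):

```python
from __future__ import annotations


def f(x: int | None = None) -> list[str]:
    # The `int | None` expression is never evaluated at runtime here,
    # so this module imports cleanly even on older Pythons.
    return [] if x is None else [str(x)]


print(f.__annotations__)  # {'x': 'int | None', 'return': 'list[str]'}
```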