From c8676ffbff5218e226a25651b5b0c981dc5da798 Mon Sep 17 00:00:00 2001
From: Xiaoyu <85524621+xiaoyu-work@users.noreply.github.com>
Date: Wed, 27 Mar 2024 10:40:08 -0700
Subject: [PATCH] Add ModelProto support for quantize api (#20018)

### Description

Add ModelProto support for the `quantize` API.

### Motivation and Context

Currently, the `quantize` API only accepts a model path as the input model. However, for large models, saving and loading from disk can be time-consuming. By adding `ModelProto` as an input option to the `quantize` API, significant time can be saved.

---
 .../execution_providers/qnn/preprocess.py    |  8 +--
 .../execution_providers/qnn/quant_config.py  |  8 ++-
 .../python/tools/quantization/quantize.py    | 45 +++++++++++----
 .../tools/quantization/shape_inference.py    | 55 ++++++++++++++-----
 4 files changed, 83 insertions(+), 33 deletions(-)

diff --git a/onnxruntime/python/tools/quantization/execution_providers/qnn/preprocess.py b/onnxruntime/python/tools/quantization/execution_providers/qnn/preprocess.py
index e584a65574520..85f5d967f9ee3 100644
--- a/onnxruntime/python/tools/quantization/execution_providers/qnn/preprocess.py
+++ b/onnxruntime/python/tools/quantization/execution_providers/qnn/preprocess.py
@@ -16,8 +16,8 @@


 def qnn_preprocess_model(
-    model_input: Path,
-    model_output: Path,
+    model_input: str | Path | onnx.ModelProto,
+    model_output: str | Path,
     fuse_layernorm: bool = False,
     save_as_external_data: bool = False,
     all_tensors_to_one_file: bool = False,
@@ -37,7 +37,7 @@ def qnn_preprocess_model(
     - (Optional) Fuse ReduceMean sequence into a single LayerNormalization node.

     Args:
-        model_input: Path to the input model file.
+        model_input: Path to the input model file or ModelProto.
         model_output: Path the output model file, which is only created if this method returns True.
         fuse_layernorm: True if ReduceMean sequences should be fused into LayerNormalization nodes.
             Defaults to False.
@@ -82,7 +82,7 @@ def qnn_preprocess_model(
            to cancel out.
""" modified = False - model = onnx.load_model(model_input) + model = model_input if isinstance(model_input, onnx.ModelProto) else onnx.load_model(model_input) onnx_model = ONNXModel(model) # Fuse Erf sequence into a single Gelu diff --git a/onnxruntime/python/tools/quantization/execution_providers/qnn/quant_config.py b/onnxruntime/python/tools/quantization/execution_providers/qnn/quant_config.py index 479eaf5b0c542..3a217fdfaaffd 100644 --- a/onnxruntime/python/tools/quantization/execution_providers/qnn/quant_config.py +++ b/onnxruntime/python/tools/quantization/execution_providers/qnn/quant_config.py @@ -39,7 +39,7 @@ def warn_unable_to_override( def get_qnn_qdq_config( - model_input: Path, + model_input: str | Path | onnx.ModelProto, calibration_data_reader: CalibrationDataReader, calibrate_method=CalibrationMethod.MinMax, activation_type=QuantType.QUInt8, @@ -56,7 +56,11 @@ def get_qnn_qdq_config( if weight_symmetric is None: weight_symmetric = weight_type in {QuantType.QInt8, QuantType.QInt16} - model = onnx.load_model(model_input, load_external_data=False) + model = ( + model_input + if isinstance(model_input, onnx.ModelProto) + else onnx.load_model(model_input, load_external_data=False) + ) op_types = set() model_has_external_data = False diff --git a/onnxruntime/python/tools/quantization/quantize.py b/onnxruntime/python/tools/quantization/quantize.py index 9b0c15e4b4dde..9ebd7bf3c408a 100644 --- a/onnxruntime/python/tools/quantization/quantize.py +++ b/onnxruntime/python/tools/quantization/quantize.py @@ -6,6 +6,9 @@ import logging import tempfile from pathlib import Path +from typing import Union + +import onnx from .calibrate import CalibrationDataReader, CalibrationMethod, TensorsData, create_calibrator from .onnx_quantizer import ONNXQuantizer @@ -16,6 +19,7 @@ QuantType, load_model_with_shape_infer, model_has_pre_process_metadata, + save_and_reload_model_with_shape_infer, ) from .registry import IntegerOpsRegistry, QDQRegistry, QLinearOpsRegistry @@ -280,8 +284,8 @@ def check_static_quant_arguments(quant_format: QuantFormat, activation_type: Qua def quantize_static( - model_input, - model_output, + model_input: Union[str, Path, onnx.ModelProto], + model_output: Union[str, Path], calibration_data_reader: CalibrationDataReader, quant_format=QuantFormat.QDQ, op_types_to_quantize=None, @@ -304,7 +308,7 @@ def quantize_static( Args: - model_input: file path of model to quantize + model_input: file path of model or ModelProto to quantize model_output: file path of quantized model calibration_data_reader: a calibration data reader. 
            It enumerates calibration data and generates inputs for the
@@ -435,7 +439,11 @@ def quantize_static(
         qdq_ops = list(QDQRegistry.keys())
         op_types_to_quantize = list(set(q_linear_ops + qdq_ops))

-    model = load_model_with_shape_infer(Path(model_input))
+    model = (
+        save_and_reload_model_with_shape_infer(model_input)
+        if isinstance(model_input, onnx.ModelProto)
+        else load_model_with_shape_infer(Path(model_input))
+    )

     pre_processed: bool = model_has_pre_process_metadata(model)
     if not pre_processed:
@@ -485,6 +493,15 @@ def inc_dataloader():
         model = load_model_with_shape_infer(Path(model_input))  # use smooth quant model for calibration

     with tempfile.TemporaryDirectory(prefix="ort.quant.") as quant_tmp_dir:
+        if isinstance(model_input, onnx.ModelProto):
+            output_path = str(Path(quant_tmp_dir) / "model_input.onnx")
+            onnx.save_model(
+                model_input,
+                output_path,
+                save_as_external_data=True,
+            )
+            model_input = output_path
+
         calibrator = create_calibrator(
             Path(model_input),
             op_types_to_quantize,
@@ -546,8 +563,8 @@ def inc_dataloader():


 def quantize_dynamic(
-    model_input: Path,
-    model_output: Path,
+    model_input: Union[str, Path, onnx.ModelProto],
+    model_output: Union[str, Path],
     op_types_to_quantize=None,
     per_channel=False,
     reduce_range=False,
@@ -560,7 +577,7 @@ def quantize_dynamic(
     """Given an onnx model, create a quantized onnx model and save it into a file

     Args:
-        model_input: file path of model to quantize
+        model_input: file path of model or ModelProto to quantize
         model_output: file path of quantized model
         op_types_to_quantize:
             specify the types of operators to quantize, like ['Conv'] to quantize Conv only.
@@ -609,7 +626,11 @@ def quantize_dynamic(
     if not op_types_to_quantize or len(op_types_to_quantize) == 0:
         op_types_to_quantize = list(IntegerOpsRegistry.keys())

-    model = load_model_with_shape_infer(Path(model_input))
+    model = (
+        save_and_reload_model_with_shape_infer(model_input)
+        if isinstance(model_input, onnx.ModelProto)
+        else load_model_with_shape_infer(Path(model_input))
+    )

     pre_processed: bool = model_has_pre_process_metadata(model)
     if not pre_processed:
@@ -642,15 +663,15 @@ def quantize_dynamic(


 def quantize(
-    model_input: Path,
-    model_output: Path,
+    model_input: Union[str, Path, onnx.ModelProto],
+    model_output: Union[str, Path],
     quant_config: QuantConfig,
 ):
     """Quantize a model with QuantConfig.

     Args:
-        model_input (Path): Path to the model to quantize.
-        model_output (Path): Path to save the quantized model.
+        model_input (str | Path | ModelProto): Path to the model or ModelProto to quantize.
+        model_output (str | Path): Path to save the quantized model.
         quant_config (QuantConfig): Quantization Configuration.
""" diff --git a/onnxruntime/python/tools/quantization/shape_inference.py b/onnxruntime/python/tools/quantization/shape_inference.py index b7d4726610387..7368304837a96 100644 --- a/onnxruntime/python/tools/quantization/shape_inference.py +++ b/onnxruntime/python/tools/quantization/shape_inference.py @@ -9,12 +9,13 @@ import tempfile import traceback from pathlib import Path -from typing import Optional +from typing import Optional, Union import onnx import onnxruntime from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference +from onnxruntime.transformers.onnx_utils import extract_raw_data_from_model, has_external_data from .quant_utils import add_pre_process_metadata @@ -22,8 +23,8 @@ def quant_pre_process( - input_model_path: str, - output_model_path: str, + input_model: Union[str, Path, onnx.ModelProto], + output_model_path: Union[str, Path], skip_optimization: bool = False, skip_onnx_shape: bool = False, skip_symbolic_shape: bool = False, @@ -39,7 +40,7 @@ def quant_pre_process( """Shape inference and model optimization, in preparation for quantization. Args: - input_model_path: Path to the input model file") + input_model: Path to the input model file or ModelProto output_model_path: Path to the output model file skip_optimization: Skip model optimization step if true. This may result in ONNX shape inference failure for some models. @@ -68,8 +69,9 @@ def quant_pre_process( if not skip_symbolic_shape: logger.info("Performing symbolic shape inference...") + loaded_model = input_model if isinstance(input_model, onnx.ModelProto) else onnx.load(input_model) model = SymbolicShapeInference.infer_shapes( - onnx.load(input_model_path), + loaded_model, int_max, auto_merge, guess_output_rank, @@ -80,18 +82,18 @@ def quant_pre_process( # Use ORT optimizers (native code) to optimize model if not skip_symbolic_shape: # Need to save the inferenced model to file so as to run the optimizer - input_model_path = str(temp_path / "symbolic_shape_inferred.onnx") + input_model = str(temp_path / "symbolic_shape_inferred.onnx") if save_as_external_data: onnx.save_model( model, - input_model_path, + input_model, save_as_external_data=True, all_tensors_to_one_file=all_tensors_to_one_file, size_threshold=external_data_size_threshold, convert_attribute=False, ) else: - onnx.save(model, input_model_path) + onnx.save(model, input_model) model = None opt_model_path = str(temp_path / "optimized.onnx") @@ -99,7 +101,19 @@ def quant_pre_process( sess_option = onnxruntime.SessionOptions() sess_option.optimized_model_filepath = opt_model_path sess_option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_BASIC - sess = onnxruntime.InferenceSession(input_model_path, sess_option, providers=["CPUExecutionProvider"]) + # For large model, extract external data from model and add to session options + if isinstance(input_model, onnx.ModelProto): + if has_external_data(input_model): + raise ValueError( + "ModelProto has external data not loaded into memory, ORT cannot create session. " + "Please load external data before calling this function. " + "See https://onnx.ai/onnx/repo-docs/ExternalData.html for more information." 
+                        )
+                    external_names, external_values = extract_raw_data_from_model(input_model)
+                    sess_option.add_external_initializers(list(external_names), list(external_values))
+                    input_model = input_model.SerializeToString()
+
+                sess = onnxruntime.InferenceSession(input_model, sess_option, providers=["CPUExecutionProvider"])
                 # Close the session to avoid the cleanup error on Windows for temp folders
                 # https://github.com/microsoft/onnxruntime/issues/17627
                 del sess
@@ -109,7 +123,7 @@
                 )
                 logger.error(traceback.format_exc())

-            input_model_path = opt_model_path
+            input_model = opt_model_path

         if not skip_onnx_shape:
             # ONNX shape inference.
@@ -117,26 +131,37 @@
             # If the skip optimization is specified, we could be dealing with a
             # large model. So be on the safe side, save the model
             if model is not None:
-                input_model_path = str(temp_path / "symbolic_shape_inferred.onnx")
+                input_model = str(temp_path / "symbolic_shape_inferred.onnx")
                 if save_as_external_data:
                     onnx.save_model(
                         model,
-                        input_model_path,
+                        input_model,
                         save_as_external_data=True,
                         all_tensors_to_one_file=all_tensors_to_one_file,
                         size_threshold=external_data_size_threshold,
                         convert_attribute=False,
                     )
                 else:
-                    onnx.save(model, input_model_path)
+                    onnx.save(model, input_model)
                 model = None

+            if isinstance(input_model, onnx.ModelProto):
+                input_model = str(Path(quant_tmp_dir) / "model_input.onnx")
+                onnx.save_model(
+                    model,
+                    input_model,
+                    save_as_external_data=True,
+                    all_tensors_to_one_file=all_tensors_to_one_file,
+                    size_threshold=external_data_size_threshold,
+                    convert_attribute=False,
+                )
+
             inferred_model_path = str(temp_path / "onnx_shape_inferred.onnx")
-            onnx.shape_inference.infer_shapes_path(input_model_path, inferred_model_path)
+            onnx.shape_inference.infer_shapes_path(input_model, inferred_model_path)
             model = onnx.load(inferred_model_path)

     if model is None:
-        model = onnx.load(input_model_path)
+        model = input_model if isinstance(input_model, onnx.ModelProto) else onnx.load(input_model)

     add_pre_process_metadata(model)
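
Below is a minimal usage sketch of the new in-memory input path (illustrative only, not part of the patch). It assumes an onnxruntime build that includes this change; the file names `model.onnx`, `model_preprocessed.onnx`, and `model_quant.onnx` are placeholders.

```python
# Illustrative sketch: start quantization from an in-memory ModelProto instead
# of a file path. Only the input side accepts a ModelProto; outputs are still
# written to the given file paths.
import onnx

from onnxruntime.quantization import quantize_dynamic
from onnxruntime.quantization.shape_inference import quant_pre_process

# Load (or build) the model in memory; "model.onnx" is a placeholder path.
model_proto = onnx.load("model.onnx")

# Pre-processing (shape inference + optimization) now accepts a ModelProto.
quant_pre_process(model_proto, "model_preprocessed.onnx")

# quantize_dynamic / quantize_static likewise accept a ModelProto as model_input.
quantize_dynamic(
    onnx.load("model_preprocessed.onnx"),
    "model_quant.onnx",
)
```

Note that `model_output` remains a file path in all of these APIs; this change only removes the need to serialize the input model to disk before quantizing.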