Add ModelProto support for quantize api (#20018)
### Description
Add `ModelProto` support for the `quantize` API.



### Motivation and Context
Currently, the `quantize` API accepts only a model path as the input model. For large models, saving the model to disk and loading it back is time-consuming. Accepting a `ModelProto` as the input lets callers quantize an in-memory model directly and skip that round trip.
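
For illustration, a minimal sketch of the new call pattern (model paths are placeholders, not part of this change):

```python
import onnx
from onnxruntime.quantization import quantize_dynamic

# Build or modify the model in memory; any ModelProto works.
model = onnx.load("model.onnx")  # placeholder path

# model_input previously had to be a file path; with this change an in-memory
# ModelProto is accepted directly, so no intermediate save/load is needed.
quantize_dynamic(model, "model.int8.onnx")
```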
xiaoyu-work authored Mar 27, 2024
1 parent 47903e7 commit c8676ff
Showing 4 changed files with 83 additions and 33 deletions.
@@ -16,8 +16,8 @@


def qnn_preprocess_model(
model_input: Path,
model_output: Path,
model_input: str | Path | onnx.ModelProto,
model_output: str | Path,
fuse_layernorm: bool = False,
save_as_external_data: bool = False,
all_tensors_to_one_file: bool = False,
@@ -37,7 +37,7 @@ def qnn_preprocess_model(
- (Optional) Fuse ReduceMean sequence into a single LayerNormalization node.
Args:
model_input: Path to the input model file.
model_input: Path to the input model file or ModelProto.
model_output: Path the output model file, which is only created if this method returns True.
fuse_layernorm: True if ReduceMean sequences should be fused into LayerNormalization nodes.
Defaults to False.
@@ -82,7 +82,7 @@ def qnn_preprocess_model(
to cancel out.
"""
modified = False
model = onnx.load_model(model_input)
model = model_input if isinstance(model_input, onnx.ModelProto) else onnx.load_model(model_input)
onnx_model = ONNXModel(model)

# Fuse Erf sequence into a single Gelu
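The same pattern applies when calling the QNN preprocessing helper with an in-memory model. A sketch, assuming the usual `onnxruntime.quantization.execution_providers.qnn` import path and placeholder file names:

```python
import onnx
from onnxruntime.quantization.execution_providers.qnn import qnn_preprocess_model

model = onnx.load("model.onnx")  # placeholder path; any in-memory ModelProto works

# The output file is only written if the preprocessing pass modified the graph,
# matching the "returns True" contract in the docstring above.
modified = qnn_preprocess_model(model, "model.preproc.onnx", fuse_layernorm=True)
model_to_quantize = "model.preproc.onnx" if modified else model
```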
@@ -39,7 +39,7 @@ def warn_unable_to_override(


def get_qnn_qdq_config(
model_input: Path,
model_input: str | Path | onnx.ModelProto,
calibration_data_reader: CalibrationDataReader,
calibrate_method=CalibrationMethod.MinMax,
activation_type=QuantType.QUInt8,
@@ -56,7 +56,11 @@ def get_qnn_qdq_config(
if weight_symmetric is None:
weight_symmetric = weight_type in {QuantType.QInt8, QuantType.QInt16}

model = onnx.load_model(model_input, load_external_data=False)
model = (
model_input
if isinstance(model_input, onnx.ModelProto)
else onnx.load_model(model_input, load_external_data=False)
)

op_types = set()
model_has_external_data = False
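Similarly, `get_qnn_qdq_config` can now be built from a `ModelProto`. A sketch with a hypothetical single-batch calibration reader (import path, input name, and tensor shape are assumptions):

```python
import numpy as np
import onnx
from onnxruntime.quantization import CalibrationDataReader, QuantType
from onnxruntime.quantization.execution_providers.qnn import get_qnn_qdq_config

class SingleBatchReader(CalibrationDataReader):
    """Hypothetical reader yielding one random calibration batch."""
    def __init__(self):
        self._batches = iter([{"input": np.random.rand(1, 3, 224, 224).astype(np.float32)}])

    def get_next(self):
        return next(self._batches, None)

model = onnx.load("model.onnx")  # placeholder path
qnn_config = get_qnn_qdq_config(
    model,  # ModelProto accepted in addition to a str/Path
    SingleBatchReader(),
    activation_type=QuantType.QUInt16,
    weight_type=QuantType.QUInt8,
)
```

The returned config is then passed to `quantize` together with an output path.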
45 changes: 33 additions & 12 deletions onnxruntime/python/tools/quantization/quantize.py
@@ -6,6 +6,9 @@
import logging
import tempfile
from pathlib import Path
from typing import Union

import onnx

from .calibrate import CalibrationDataReader, CalibrationMethod, TensorsData, create_calibrator
from .onnx_quantizer import ONNXQuantizer
@@ -16,6 +19,7 @@
QuantType,
load_model_with_shape_infer,
model_has_pre_process_metadata,
save_and_reload_model_with_shape_infer,
)
from .registry import IntegerOpsRegistry, QDQRegistry, QLinearOpsRegistry

@@ -280,8 +284,8 @@ def check_static_quant_arguments(quant_format: QuantFormat, activation_type: Qua


def quantize_static(
model_input,
model_output,
model_input: Union[str, Path, onnx.ModelProto],
model_output: Union[str, Path],
calibration_data_reader: CalibrationDataReader,
quant_format=QuantFormat.QDQ,
op_types_to_quantize=None,
@@ -304,7 +308,7 @@
Args:
model_input: file path of model to quantize
model_input: file path of model or ModelProto to quantize
model_output: file path of quantized model
calibration_data_reader: a calibration data reader. It
enumerates calibration data and generates inputs for the
@@ -435,7 +439,11 @@ def quantize_static(
qdq_ops = list(QDQRegistry.keys())
op_types_to_quantize = list(set(q_linear_ops + qdq_ops))

model = load_model_with_shape_infer(Path(model_input))
model = (
save_and_reload_model_with_shape_infer(model_input)
if isinstance(model_input, onnx.ModelProto)
else load_model_with_shape_infer(Path(model_input))
)

pre_processed: bool = model_has_pre_process_metadata(model)
if not pre_processed:
@@ -485,6 +493,15 @@ def inc_dataloader():
model = load_model_with_shape_infer(Path(model_input)) # use smooth quant model for calibration

with tempfile.TemporaryDirectory(prefix="ort.quant.") as quant_tmp_dir:
if isinstance(model_input, onnx.ModelProto):
output_path = str(Path(quant_tmp_dir) / "model_input.onnx")
onnx.save_model(
model_input,
output_path,
save_as_external_data=True,
)
model_input = output_path

calibrator = create_calibrator(
Path(model_input),
op_types_to_quantize,
@@ -546,8 +563,8 @@


def quantize_dynamic(
model_input: Path,
model_output: Path,
model_input: Union[str, Path, onnx.ModelProto],
model_output: Union[str, Path],
op_types_to_quantize=None,
per_channel=False,
reduce_range=False,
@@ -560,7 +577,7 @@
"""Given an onnx model, create a quantized onnx model and save it into a file
Args:
model_input: file path of model to quantize
model_input: file path of model or ModelProto to quantize
model_output: file path of quantized model
op_types_to_quantize:
specify the types of operators to quantize, like ['Conv'] to quantize Conv only.
@@ -609,7 +626,11 @@ def quantize_dynamic(
if not op_types_to_quantize or len(op_types_to_quantize) == 0:
op_types_to_quantize = list(IntegerOpsRegistry.keys())

model = load_model_with_shape_infer(Path(model_input))
model = (
save_and_reload_model_with_shape_infer(model_input)
if isinstance(model_input, onnx.ModelProto)
else load_model_with_shape_infer(Path(model_input))
)

pre_processed: bool = model_has_pre_process_metadata(model)
if not pre_processed:
@@ -642,15 +663,15 @@


def quantize(
model_input: Path,
model_output: Path,
model_input: Union[str, Path, onnx.ModelProto],
model_output: Union[str, Path],
quant_config: QuantConfig,
):
"""Quantize a model with QuantConfig.
Args:
model_input (Path): Path to the model to quantize.
model_output (Path): Path to save the quantized model.
model_input (str | Path | ModelProto): Path to the model or ModelProto to quantize.
model_output (str | Path): Path to save the quantized model.
quant_config (QuantConfig): Quantization Configuration.
"""

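And static quantization from a `ModelProto`, again with a hypothetical calibration reader (input name and shape are assumptions):

```python
import numpy as np
import onnx
from onnxruntime.quantization import (
    CalibrationDataReader,
    QuantFormat,
    QuantType,
    quantize_static,
)

class SingleBatchReader(CalibrationDataReader):
    """Hypothetical reader yielding one random calibration batch."""
    def __init__(self):
        self._batches = iter([{"input": np.random.rand(1, 3, 224, 224).astype(np.float32)}])

    def get_next(self):
        return next(self._batches, None)

model = onnx.load("model.onnx")  # placeholder path; an in-memory ModelProto is now accepted

quantize_static(
    model,
    "model.qdq.onnx",
    SingleBatchReader(),
    quant_format=QuantFormat.QDQ,
    activation_type=QuantType.QUInt8,
    weight_type=QuantType.QInt8,
)
```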
55 changes: 40 additions & 15 deletions onnxruntime/python/tools/quantization/shape_inference.py
@@ -9,21 +9,22 @@
import tempfile
import traceback
from pathlib import Path
from typing import Optional
from typing import Optional, Union

import onnx

import onnxruntime
from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference
from onnxruntime.transformers.onnx_utils import extract_raw_data_from_model, has_external_data

from .quant_utils import add_pre_process_metadata

logger = logging.getLogger(__name__)


def quant_pre_process(
input_model_path: str,
output_model_path: str,
input_model: Union[str, Path, onnx.ModelProto],
output_model_path: Union[str, Path],
skip_optimization: bool = False,
skip_onnx_shape: bool = False,
skip_symbolic_shape: bool = False,
@@ -39,7 +40,7 @@ def quant_pre_process(
"""Shape inference and model optimization, in preparation for quantization.
Args:
input_model_path: Path to the input model file")
input_model: Path to the input model file or ModelProto
output_model_path: Path to the output model file
skip_optimization: Skip model optimization step if true. This may result in ONNX shape
inference failure for some models.
@@ -68,8 +69,9 @@

if not skip_symbolic_shape:
logger.info("Performing symbolic shape inference...")
loaded_model = input_model if isinstance(input_model, onnx.ModelProto) else onnx.load(input_model)
model = SymbolicShapeInference.infer_shapes(
onnx.load(input_model_path),
loaded_model,
int_max,
auto_merge,
guess_output_rank,
@@ -80,26 +82,38 @@
# Use ORT optimizers (native code) to optimize model
if not skip_symbolic_shape:
# Need to save the inferenced model to file so as to run the optimizer
input_model_path = str(temp_path / "symbolic_shape_inferred.onnx")
input_model = str(temp_path / "symbolic_shape_inferred.onnx")
if save_as_external_data:
onnx.save_model(
model,
input_model_path,
input_model,
save_as_external_data=True,
all_tensors_to_one_file=all_tensors_to_one_file,
size_threshold=external_data_size_threshold,
convert_attribute=False,
)
else:
onnx.save(model, input_model_path)
onnx.save(model, input_model)
model = None

opt_model_path = str(temp_path / "optimized.onnx")
try:
sess_option = onnxruntime.SessionOptions()
sess_option.optimized_model_filepath = opt_model_path
sess_option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_BASIC
sess = onnxruntime.InferenceSession(input_model_path, sess_option, providers=["CPUExecutionProvider"])
# For large model, extract external data from model and add to session options
if isinstance(input_model, onnx.ModelProto):
if has_external_data(input_model):
raise ValueError(
"ModelProto has external data not loaded into memory, ORT cannot create session. "
"Please load external data before calling this function. "
"See https://onnx.ai/onnx/repo-docs/ExternalData.html for more information."
)
external_names, external_values = extract_raw_data_from_model(input_model)
sess_option.add_external_initializers(list(external_names), list(external_values))
input_model = input_model.SerializeToString()

sess = onnxruntime.InferenceSession(input_model, sess_option, providers=["CPUExecutionProvider"])
# Close the session to avoid the cleanup error on Windows for temp folders
# https://github.com/microsoft/onnxruntime/issues/17627
del sess
@@ -109,34 +123,45 @@
)
logger.error(traceback.format_exc())

input_model_path = opt_model_path
input_model = opt_model_path

if not skip_onnx_shape:
# ONNX shape inference.
# According to docs, infer_shapes_path should be used for 2G+ models.
# If the skip optimization is specified, we could be dealing with a
# large model. So be on the safe side, save the model
if model is not None:
input_model_path = str(temp_path / "symbolic_shape_inferred.onnx")
input_model = str(temp_path / "symbolic_shape_inferred.onnx")
if save_as_external_data:
onnx.save_model(
model,
input_model_path,
input_model,
save_as_external_data=True,
all_tensors_to_one_file=all_tensors_to_one_file,
size_threshold=external_data_size_threshold,
convert_attribute=False,
)
else:
onnx.save(model, input_model_path)
onnx.save(model, input_model)
model = None

if isinstance(input_model, onnx.ModelProto):
input_model = str(Path(quant_tmp_dir) / "model_input.onnx")
onnx.save_model(
model,
input_model,
save_as_external_data=True,
all_tensors_to_one_file=all_tensors_to_one_file,
size_threshold=external_data_size_threshold,
convert_attribute=False,
)

inferred_model_path = str(temp_path / "onnx_shape_inferred.onnx")
onnx.shape_inference.infer_shapes_path(input_model_path, inferred_model_path)
onnx.shape_inference.infer_shapes_path(input_model, inferred_model_path)
model = onnx.load(inferred_model_path)

if model is None:
model = onnx.load(input_model_path)
model = input_model if isinstance(input_model, onnx.ModelProto) else onnx.load(input_model)

add_pre_process_metadata(model)

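The preprocessing helper can likewise start from an in-memory model. A sketch with placeholder file names (note the guard added in this diff: external data must already be loaded into the `ModelProto`):

```python
import onnx
from onnxruntime.quantization.shape_inference import quant_pre_process

model = onnx.load("model.onnx")  # placeholder path; onnx.load pulls external data into memory by default

# Symbolic shape inference, ORT optimization, and ONNX shape inference all run on the
# in-memory model; only the final preprocessed model is written to the output path.
quant_pre_process(model, "model.preproc.onnx")
```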

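The session-creation change above is itself a reusable pattern for running a large in-memory model without a temp-file round trip: register the raw initializers with the session options, then create the session from serialized bytes. A standalone sketch of that pattern (model path is a placeholder):

```python
import onnx
import onnxruntime
from onnxruntime.transformers.onnx_utils import extract_raw_data_from_model, has_external_data

model = onnx.load("model.onnx")  # placeholder path

if has_external_data(model):
    # Matches the guard added in this diff: external tensors must already be in memory.
    raise ValueError("Load external data into the ModelProto before creating the session.")

sess_options = onnxruntime.SessionOptions()

# Hand the large initializers to ORT by name/value so the serialized proto stays small
# enough to pass as bytes.
external_names, external_values = extract_raw_data_from_model(model)
sess_options.add_external_initializers(list(external_names), list(external_values))

session = onnxruntime.InferenceSession(
    model.SerializeToString(), sess_options, providers=["CPUExecutionProvider"]
)
```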