Add ModelProto support for quantize api #20018

Merged on Mar 27, 2024 (8 commits)
@@ -16,8 +16,8 @@


 def qnn_preprocess_model(
-    model_input: Path,
-    model_output: Path,
+    model_input: str | Path | onnx.ModelProto,
+    model_output: str | Path,
     fuse_layernorm: bool = False,
     save_as_external_data: bool = False,
     all_tensors_to_one_file: bool = False,
@@ -37,7 +37,7 @@ def qnn_preprocess_model(
     - (Optional) Fuse ReduceMean sequence into a single LayerNormalization node.

     Args:
-        model_input: Path to the input model file.
+        model_input: Path to the input model file or ModelProto.
         model_output: Path the output model file, which is only created if this method returns True.
         fuse_layernorm: True if ReduceMean sequences should be fused into LayerNormalization nodes.
             Defaults to False.
@@ -82,7 +82,7 @@ def qnn_preprocess_model(
           to cancel out.
     """
     modified = False
-    model = onnx.load_model(model_input)
+    model = model_input if isinstance(model_input, onnx.ModelProto) else onnx.load_model(model_input)
     onnx_model = ONNXModel(model)

     # Fuse Erf sequence into a single Gelu
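Usage sketch (not part of the diff): with this change, `qnn_preprocess_model` accepts an in-memory `ModelProto` in addition to a file path. The import location, file names, and flags below are assumptions for illustration.

```python
import onnx

# Assumed public import location for the QNN preprocessing helper.
from onnxruntime.quantization.execution_providers.qnn import qnn_preprocess_model

# "model.onnx" is a placeholder path; any ModelProto built or loaded in memory works.
model_proto = onnx.load_model("model.onnx")

# Pass the ModelProto directly; the preprocessed model is written to the output
# path only if the function actually modified the graph (returns True).
modified = qnn_preprocess_model(model_proto, "model.qnn_preprocessed.onnx", fuse_layernorm=True)
print("model modified:", modified)
```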
@@ -39,7 +39,7 @@ def warn_unable_to_override(


 def get_qnn_qdq_config(
-    model_input: Path,
+    model_input: str | Path | onnx.ModelProto,
     calibration_data_reader: CalibrationDataReader,
     calibrate_method=CalibrationMethod.MinMax,
     activation_type=QuantType.QUInt8,
@@ -56,7 +56,11 @@ def get_qnn_qdq_config(
     if weight_symmetric is None:
         weight_symmetric = weight_type in {QuantType.QInt8, QuantType.QInt16}

-    model = onnx.load_model(model_input, load_external_data=False)
+    model = (
+        model_input
+        if isinstance(model_input, onnx.ModelProto)
+        else onnx.load_model(model_input, load_external_data=False)
+    )

     op_types = set()
     model_has_external_data = False
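Usage sketch (not part of the diff): `get_qnn_qdq_config` now takes the same `str | Path | ModelProto` union for its first argument. The reader class, input name, tensor shape, and import paths below are assumptions for illustration; the returned config can then be passed to `quantize()` together with the same `ModelProto`.

```python
import numpy as np
import onnx
from onnxruntime.quantization import CalibrationDataReader, QuantType
from onnxruntime.quantization.execution_providers.qnn import get_qnn_qdq_config


class SingleBatchReader(CalibrationDataReader):
    """Hypothetical reader yielding one zero-filled batch for an input named 'input'."""

    def __init__(self):
        self._batches = iter([{"input": np.zeros((1, 3, 224, 224), dtype=np.float32)}])

    def get_next(self):
        return next(self._batches, None)


model_proto = onnx.load_model("model.onnx")  # placeholder path

# The first argument can now be the ModelProto itself instead of a path.
qnn_config = get_qnn_qdq_config(
    model_proto,
    SingleBatchReader(),
    activation_type=QuantType.QUInt16,
    weight_type=QuantType.QUInt8,
)
```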
onnxruntime/python/tools/quantization/quantize.py (33 additions, 12 deletions)
@@ -6,6 +6,9 @@
 import logging
 import tempfile
 from pathlib import Path
+from typing import Union
+
+import onnx

 from .calibrate import CalibrationDataReader, CalibrationMethod, TensorsData, create_calibrator
 from .onnx_quantizer import ONNXQuantizer
@@ -16,6 +19,7 @@
     QuantType,
     load_model_with_shape_infer,
     model_has_pre_process_metadata,
+    save_and_reload_model_with_shape_infer,
 )
 from .registry import IntegerOpsRegistry, QDQRegistry, QLinearOpsRegistry

@@ -280,8 +284,8 @@ def check_static_quant_arguments(quant_format: QuantFormat, activation_type: Qua


 def quantize_static(
-    model_input,
-    model_output,
+    model_input: Union[str, Path, onnx.ModelProto],
+    model_output: Union[str, Path],
     calibration_data_reader: CalibrationDataReader,
     quant_format=QuantFormat.QDQ,
     op_types_to_quantize=None,
@@ -304,7 +308,7 @@ def quantize_static(

     Args:

-        model_input: file path of model to quantize
+        model_input: file path of model or ModelProto to quantize
         model_output: file path of quantized model
         calibration_data_reader: a calibration data reader. It
             enumerates calibration data and generates inputs for the
@@ -435,7 +439,11 @@ def quantize_static(
         qdq_ops = list(QDQRegistry.keys())
         op_types_to_quantize = list(set(q_linear_ops + qdq_ops))

-    model = load_model_with_shape_infer(Path(model_input))
+    model = (
+        save_and_reload_model_with_shape_infer(model_input)
+        if isinstance(model_input, onnx.ModelProto)
+        else load_model_with_shape_infer(Path(model_input))
+    )

     pre_processed: bool = model_has_pre_process_metadata(model)
     if not pre_processed:
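Usage sketch (not part of the diff): with this hunk, `quantize_static` accepts the `ModelProto` directly and routes it through `save_and_reload_model_with_shape_infer` instead of loading from disk. The reader, paths, input name, and shape below are placeholders, mirroring the reader sketched above.

```python
import numpy as np
import onnx
from onnxruntime.quantization import CalibrationDataReader, QuantFormat, QuantType, quantize_static


class OneShotReader(CalibrationDataReader):
    """Hypothetical single-batch calibration reader for an input named 'input'."""

    def __init__(self):
        self._it = iter([{"input": np.random.rand(1, 3, 224, 224).astype(np.float32)}])

    def get_next(self):
        return next(self._it, None)


model_proto = onnx.load_model("model.onnx")  # placeholder path

quantize_static(
    model_proto,               # ModelProto instead of a file path
    "model.quant.onnx",        # the quantized model is still written to disk
    OneShotReader(),
    quant_format=QuantFormat.QDQ,
    activation_type=QuantType.QUInt8,
    weight_type=QuantType.QInt8,
)
```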
@@ -485,6 +493,15 @@ def inc_dataloader():
         model = load_model_with_shape_infer(Path(model_input))  # use smooth quant model for calibration

     with tempfile.TemporaryDirectory(prefix="ort.quant.") as quant_tmp_dir:
+        if isinstance(model_input, onnx.ModelProto):
+            output_path = str(Path(quant_tmp_dir) / "model_input.onnx")
+            onnx.save_model(
+                model_input,
+                output_path,
+                save_as_external_data=True,
+            )
+            model_input = output_path
+
         calibrator = create_calibrator(
             Path(model_input),
             op_types_to_quantize,
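Because `create_calibrator` still consumes a file path, a `ModelProto` input is first serialized into the quantization temp directory with externalized tensor data. The same pattern in isolation, as a sketch with placeholder names:

```python
import tempfile
from pathlib import Path

import onnx

model_proto = onnx.load_model("model.onnx")  # placeholder path

with tempfile.TemporaryDirectory(prefix="ort.quant.") as quant_tmp_dir:
    # Write the in-memory model into the temp dir; save_as_external_data=True keeps
    # the serialized protobuf under the 2 GB limit by moving tensor data to a side file.
    model_input = str(Path(quant_tmp_dir) / "model_input.onnx")
    onnx.save_model(model_proto, model_input, save_as_external_data=True)
    # Downstream steps (e.g. calibration) can now consume model_input as a path.
```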
@@ -546,8 +563,8 @@


 def quantize_dynamic(
-    model_input: Path,
-    model_output: Path,
+    model_input: Union[str, Path, onnx.ModelProto],
+    model_output: Union[str, Path],
     op_types_to_quantize=None,
     per_channel=False,
     reduce_range=False,
@@ -560,7 +577,7 @@ def quantize_dynamic(
     """Given an onnx model, create a quantized onnx model and save it into a file

     Args:
-        model_input: file path of model to quantize
+        model_input: file path of model or ModelProto to quantize
         model_output: file path of quantized model
         op_types_to_quantize:
             specify the types of operators to quantize, like ['Conv'] to quantize Conv only.
@@ -609,7 +626,11 @@ def quantize_dynamic(
     if not op_types_to_quantize or len(op_types_to_quantize) == 0:
         op_types_to_quantize = list(IntegerOpsRegistry.keys())

-    model = load_model_with_shape_infer(Path(model_input))
+    model = (
+        save_and_reload_model_with_shape_infer(model_input)
+        if isinstance(model_input, onnx.ModelProto)
+        else load_model_with_shape_infer(Path(model_input))
+    )

     pre_processed: bool = model_has_pre_process_metadata(model)
     if not pre_processed:
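Usage sketch (not part of the diff): `quantize_dynamic` follows the same pattern and needs no calibration data, so the in-memory flow is short. Paths are placeholders.

```python
import onnx
from onnxruntime.quantization import quantize_dynamic

model_proto = onnx.load_model("model.onnx")  # placeholder path

# The ModelProto is quantized (weights only) and the result is written to the output path.
quantize_dynamic(model_proto, "model.dynamic_quant.onnx")
```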
@@ -642,15 +663,15 @@


 def quantize(
-    model_input: Path,
-    model_output: Path,
+    model_input: Union[str, Path, onnx.ModelProto],
+    model_output: Union[str, Path],
     quant_config: QuantConfig,
 ):
     """Quantize a model with QuantConfig.

     Args:
-        model_input (Path): Path to the model to quantize.
-        model_output (Path): Path to save the quantized model.
+        model_input (str | Path | ModelProto): Path to the model or ModelProto to quantize.
+        model_output (str | Path): Path to save the quantized model.
         quant_config (QuantConfig): Quantization Configuration.
     """

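Usage sketch (not part of the diff): the `quantize` entry point accepts the same union through a `QuantConfig`. The example below uses a dynamic config so no calibration reader is needed; the import path and file names are assumptions.

```python
import onnx
from onnxruntime.quantization.quantize import DynamicQuantConfig, quantize

model_proto = onnx.load_model("model.onnx")  # placeholder path

# DynamicQuantConfig carries the quantization settings; quantize() dispatches on its type.
config = DynamicQuantConfig(per_channel=False, reduce_range=False)
quantize(model_proto, "model.quant.onnx", config)
```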
onnxruntime/python/tools/quantization/shape_inference.py (40 additions, 15 deletions)
@@ -9,21 +9,22 @@
 import tempfile
 import traceback
 from pathlib import Path
-from typing import Optional
+from typing import Optional, Union

 import onnx

 import onnxruntime
 from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference
+from onnxruntime.transformers.onnx_utils import extract_raw_data_from_model, has_external_data

 from .quant_utils import add_pre_process_metadata

 logger = logging.getLogger(__name__)


 def quant_pre_process(
-    input_model_path: str,
-    output_model_path: str,
+    input_model: Union[str, Path, onnx.ModelProto],
+    output_model_path: Union[str, Path],
     skip_optimization: bool = False,
     skip_onnx_shape: bool = False,
     skip_symbolic_shape: bool = False,
@@ -39,7 +40,7 @@ def quant_pre_process(
     """Shape inference and model optimization, in preparation for quantization.

     Args:
-        input_model_path: Path to the input model file")
+        input_model: Path to the input model file or ModelProto
         output_model_path: Path to the output model file
         skip_optimization: Skip model optimization step if true. This may result in ONNX shape
             inference failure for some models.
@@ -68,8 +69,9 @@

         if not skip_symbolic_shape:
             logger.info("Performing symbolic shape inference...")
+            loaded_model = input_model if isinstance(input_model, onnx.ModelProto) else onnx.load(input_model)
             model = SymbolicShapeInference.infer_shapes(
-                onnx.load(input_model_path),
+                loaded_model,
                 int_max,
                 auto_merge,
                 guess_output_rank,
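The symbolic shape inference step now loads from disk only when a path was given; a `ModelProto` is forwarded as-is. As a standalone sketch (placeholder paths), `SymbolicShapeInference.infer_shapes` works on an in-memory model directly:

```python
import onnx
from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference

model_proto = onnx.load("model.onnx")  # placeholder path; any in-memory ModelProto works

# Returns a new ModelProto with symbolic shapes propagated through the graph.
inferred = SymbolicShapeInference.infer_shapes(model_proto, auto_merge=True)
onnx.save(inferred, "model.shape_inferred.onnx")
```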
@@ -80,26 +82,38 @@
         # Use ORT optimizers (native code) to optimize model
         if not skip_symbolic_shape:
             # Need to save the inferenced model to file so as to run the optimizer
-            input_model_path = str(temp_path / "symbolic_shape_inferred.onnx")
+            input_model = str(temp_path / "symbolic_shape_inferred.onnx")
             if save_as_external_data:
                 onnx.save_model(
                     model,
-                    input_model_path,
+                    input_model,
                     save_as_external_data=True,
                     all_tensors_to_one_file=all_tensors_to_one_file,
                     size_threshold=external_data_size_threshold,
                     convert_attribute=False,
                 )
             else:
-                onnx.save(model, input_model_path)
+                onnx.save(model, input_model)
             model = None

         opt_model_path = str(temp_path / "optimized.onnx")
         try:
             sess_option = onnxruntime.SessionOptions()
             sess_option.optimized_model_filepath = opt_model_path
             sess_option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_BASIC
-            sess = onnxruntime.InferenceSession(input_model_path, sess_option, providers=["CPUExecutionProvider"])
+            # For large model, extract external data from model and add to session options
+            if isinstance(input_model, onnx.ModelProto):
+                if has_external_data(input_model):
+                    raise ValueError(
+                        "ModelProto has external data not loaded into memory, ORT cannot create session. "
+                        "Please load external data before calling this function. "
+                        "See https://onnx.ai/onnx/repo-docs/ExternalData.html for more information."
+                    )
+                external_names, external_values = extract_raw_data_from_model(input_model)
+                sess_option.add_external_initializers(list(external_names), list(external_values))
+                input_model = input_model.SerializeToString()
+
+            sess = onnxruntime.InferenceSession(input_model, sess_option, providers=["CPUExecutionProvider"])
             # Close the session to avoid the cleanup error on Windows for temp folders
             # https://github.com/microsoft/onnxruntime/issues/17627
             del sess
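For a `ModelProto` whose initializers would push the serialized graph past the 2 GB protobuf limit, the diff extracts the raw tensor data and registers it through `SessionOptions.add_external_initializers` before creating the session from bytes. The same pattern in isolation (the model path is a placeholder; `extract_raw_data_from_model` and `has_external_data` come from `onnxruntime.transformers.onnx_utils`, as in the diff):

```python
import onnx
import onnxruntime
from onnxruntime.transformers.onnx_utils import extract_raw_data_from_model, has_external_data

model_proto = onnx.load_model("model.onnx")  # placeholder path

sess_option = onnxruntime.SessionOptions()
if has_external_data(model_proto):
    raise ValueError("Load external data into the ModelProto before creating a session from it.")

# Move initializer payloads out of the protobuf and register them with the session,
# so SerializeToString() stays small enough to pass as bytes.
names, values = extract_raw_data_from_model(model_proto)
sess_option.add_external_initializers(list(names), list(values))

sess = onnxruntime.InferenceSession(
    model_proto.SerializeToString(), sess_option, providers=["CPUExecutionProvider"]
)
```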
@@ -109,34 +123,45 @@
                 )
                 logger.error(traceback.format_exc())

-            input_model_path = opt_model_path
+            input_model = opt_model_path

         if not skip_onnx_shape:
             # ONNX shape inference.
             # According to docs, infer_shapes_path should be used for 2G+ models.
             # If the skip optimization is specified, we could be dealing with a
             # large model. So be on the safe side, save the model
             if model is not None:
-                input_model_path = str(temp_path / "symbolic_shape_inferred.onnx")
+                input_model = str(temp_path / "symbolic_shape_inferred.onnx")
                 if save_as_external_data:
                     onnx.save_model(
                         model,
-                        input_model_path,
+                        input_model,
                         save_as_external_data=True,
                         all_tensors_to_one_file=all_tensors_to_one_file,
                         size_threshold=external_data_size_threshold,
                         convert_attribute=False,
                     )
                 else:
-                    onnx.save(model, input_model_path)
+                    onnx.save(model, input_model)
                 model = None

+            if isinstance(input_model, onnx.ModelProto):
+                input_model = str(Path(quant_tmp_dir) / "model_input.onnx")
+                onnx.save_model(
+                    model,
+                    input_model,
+                    save_as_external_data=True,
+                    all_tensors_to_one_file=all_tensors_to_one_file,
+                    size_threshold=external_data_size_threshold,
+                    convert_attribute=False,
+                )
+
             inferred_model_path = str(temp_path / "onnx_shape_inferred.onnx")
-            onnx.shape_inference.infer_shapes_path(input_model_path, inferred_model_path)
+            onnx.shape_inference.infer_shapes_path(input_model, inferred_model_path)
             model = onnx.load(inferred_model_path)

         if model is None:
-            model = onnx.load(input_model_path)
+            model = input_model if isinstance(input_model, onnx.ModelProto) else onnx.load(input_model)

         add_pre_process_metadata(model)

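Usage sketch (not part of the diff): putting the shape_inference.py changes together, `quant_pre_process` can now be driven from an in-memory model end to end. Paths are placeholders and other keyword arguments keep their defaults from the signature above.

```python
import onnx
from onnxruntime.quantization.shape_inference import quant_pre_process

model_proto = onnx.load_model("model.onnx")  # placeholder path

# Runs symbolic shape inference, ORT graph optimization, and ONNX shape inference
# on the in-memory model, then writes the prepared model to the output path.
quant_pre_process(model_proto, "model.prepared.onnx", skip_optimization=False)
```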