Add ModelProto support for quantize api
xiaoyu-work committed Mar 21, 2024
1 parent a4ac727 commit 5c234a5
Showing 1 changed file with 18 additions and 15 deletions.
33 changes: 18 additions & 15 deletions onnxruntime/python/tools/quantization/quantize.py
@@ -6,6 +6,9 @@
 import logging
 import tempfile
 from pathlib import Path
+from typing import Union
 
+from onnx import ModelProto
+
 from .calibrate import CalibrationDataReader, CalibrationMethod, TensorsData, create_calibrator
 from .onnx_quantizer import ONNXQuantizer
@@ -280,8 +283,8 @@ def check_static_quant_arguments(quant_format: QuantFormat, activation_type: Qua
 
 
 def quantize_static(
-    model_input,
-    model_output,
+    model_input: Union[str, Path, ModelProto],
+    model_output: Union[str, Path],
     calibration_data_reader: CalibrationDataReader,
     quant_format=QuantFormat.QDQ,
     op_types_to_quantize=None,
@@ -304,7 +307,7 @@ def quantize_static(
     Args:
-        model_input: file path of model to quantize
+        model_input: file path of model or ModelProto to quantize
         model_output: file path of quantized model
         calibration_data_reader: a calibration data reader. It
             enumerates calibration data and generates inputs for the
@@ -435,7 +438,7 @@ def quantize_static(
         qdq_ops = list(QDQRegistry.keys())
         op_types_to_quantize = list(set(q_linear_ops + qdq_ops))
 
-    model = load_model_with_shape_infer(Path(model_input))
+    model = model_input if isinstance(model_input, ModelProto) else load_model_with_shape_infer(Path(model_input))
 
     pre_processed: bool = model_has_pre_process_metadata(model)
     if not pre_processed:
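With this hunk, quantize_static can take an in-memory ModelProto and skip the load-from-disk step. A minimal sketch of the call pattern this enables follows; the model file, the input name "input", the tensor shape, and the toy calibration reader are illustrative assumptions, not part of this commit:

```python
# Sketch only: quantize_static with a ModelProto instead of a file path.
import numpy as np
import onnx
from onnxruntime.quantization import CalibrationDataReader, QuantFormat, quantize_static


class ToyDataReader(CalibrationDataReader):
    """Yields a few random calibration batches, then None to signal the end."""

    def __init__(self, input_name, shape, num_batches=4):
        self._batches = iter(
            [{input_name: np.random.rand(*shape).astype(np.float32)} for _ in range(num_batches)]
        )

    def get_next(self):
        return next(self._batches, None)


model_proto = onnx.load("model.onnx")  # placeholder path; any ModelProto works
quantize_static(
    model_proto,  # previously this argument had to be a file path
    "model.quant.onnx",
    ToyDataReader("input", (1, 3, 224, 224)),  # assumed input name and shape
    quant_format=QuantFormat.QDQ,
)
```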
@@ -479,14 +482,14 @@ def inc_dataloader():
         del dataloader
         model = sq.transform(extra_options.get("SmoothQuantAlpha", 0.5), extra_options.get("SmoothQuantFolding", True))
         sq_path = tempfile.TemporaryDirectory(prefix="ort.quant.")
-        model_input = Path(sq_path.name).joinpath("sq_model.onnx").as_posix()
-        model.save(model_input)
+        sq_model_path = Path(sq_path.name).joinpath("sq_model.onnx").as_posix()
+        model.save(sq_model_path)
         nodes_to_exclude.extend([i.name for i in model.model.graph.node if i.name not in orig_nodes])
-        model = load_model_with_shape_infer(Path(model_input))  # use smooth quant model for calibration
+        model = load_model_with_shape_infer(Path(sq_model_path))  # use smooth quant model for calibration
 
     with tempfile.TemporaryDirectory(prefix="ort.quant.") as quant_tmp_dir:
         calibrator = create_calibrator(
-            Path(model_input),
+            Path(sq_model_path),
             op_types_to_quantize,
             augmented_model_path=Path(quant_tmp_dir).joinpath("augmented_model.onnx").as_posix(),
             calibrate_method=calibrate_method,
@@ -546,8 +549,8 @@
 
 
 def quantize_dynamic(
-    model_input: Path,
-    model_output: Path,
+    model_input: Union[str, Path, ModelProto],
+    model_output: Union[str, Path],
     op_types_to_quantize=None,
     per_channel=False,
    reduce_range=False,
@@ -560,7 +563,7 @@ def quantize_dynamic(
     """Given an onnx model, create a quantized onnx model and save it into a file
     Args:
-        model_input: file path of model to quantize
+        model_input: file path of model or ModelProto to quantize
         model_output: file path of quantized model
         op_types_to_quantize:
             specify the types of operators to quantize, like ['Conv'] to quantize Conv only.
@@ -609,7 +612,7 @@ def quantize_dynamic(
     if not op_types_to_quantize or len(op_types_to_quantize) == 0:
         op_types_to_quantize = list(IntegerOpsRegistry.keys())
 
-    model = load_model_with_shape_infer(Path(model_input))
+    model = model_input if isinstance(model_input, ModelProto) else load_model_with_shape_infer(Path(model_input))
 
     pre_processed: bool = model_has_pre_process_metadata(model)
     if not pre_processed:
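quantize_dynamic gets the same treatment, so a model that exists only in memory can be quantized without first writing it to disk. A minimal sketch of the new call pattern; the file names are placeholders, not part of this commit:

```python
# Sketch only: quantize_dynamic with a ModelProto instead of a file path.
import onnx
from onnxruntime.quantization import QuantType, quantize_dynamic

model_proto = onnx.load("model.onnx")  # placeholder path; any ModelProto works
quantize_dynamic(
    model_proto,           # previously this argument had to be a file path
    "model.quant.onnx",    # the quantized model is still written to a file path
    weight_type=QuantType.QInt8,
)
```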
@@ -642,14 +645,14 @@
 
 
 def quantize(
-    model_input: Path,
-    model_output: Path,
+    model_input: Union[str, Path, ModelProto],
+    model_output: Union[str, Path],
     quant_config: QuantConfig,
 ):
     """Quantize a model with QuantConfig.
     Args:
-        model_input (Path): Path to the model to quantize.
+        model_input (Path): Path to the model or ModelProto to quantize.
         model_output (Path): Path to save the quantized model.
         quant_config (QuantConfig): Quantization Configuration.
     """