Add ModelProto support for transformers optimize_model #19990

Merged: 11 commits, Mar 23, 2024
onnxruntime/python/tools/transformers/optimizer.py (64 changes: 52 additions & 12 deletions)
@@ -21,7 +21,7 @@
import logging
import os
import tempfile
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, Union

import coloredlogs
from fusion_options import FusionOptions
@@ -64,7 +64,7 @@


def optimize_by_onnxruntime(
-    onnx_model_path: str,
+    onnx_model: Union[str, ModelProto],
    use_gpu: bool = False,
    optimized_model_path: Optional[str] = None,
    opt_level: Optional[int] = 99,
@@ -80,7 +80,7 @@
    Use onnxruntime to optimize model.

    Args:
-        onnx_model_path (str): the path of input onnx model.
+        onnx_model (str | ModelProto): the path of input onnx model or ModelProto.
        use_gpu (bool): whether the optimized model is targeted to run in GPU.
        optimized_model_path (str or None): the path of optimized model.
        opt_level (int): graph optimization level.
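
With this change, `optimize_by_onnxruntime` accepts either a file path or an in-memory ModelProto. A minimal usage sketch of both call styles (the file name, the opt_level value, and the plain `from optimizer import ...` form assume the script is run from the transformers tools directory; they are illustrative, not part of this PR):

```python
import onnx

from optimizer import optimize_by_onnxruntime  # onnxruntime.transformers.optimizer in an installed package

# Existing style: pass the model path.
optimized_path = optimize_by_onnxruntime("model.onnx", use_gpu=False, opt_level=1)

# New style: pass a ModelProto that already lives in memory.
model_proto = onnx.load("model.onnx")
optimized_path = optimize_by_onnxruntime(model_proto, use_gpu=False, opt_level=1)
```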
@@ -105,9 +105,13 @@
        )
    ):
        logger.error("There is no gpu for onnxruntime to do optimization.")
-        return onnx_model_path
+        return onnx_model

-    model = OnnxModel(load_model(onnx_model_path, load_external_data=False))
+    model = (
+        OnnxModel(load_model(onnx_model, load_external_data=False))
+        if isinstance(onnx_model, str)
+        else OnnxModel(onnx_model)
+    )
    if model.use_float16() and not use_gpu:
        logger.warning(
            "This model uses float16 in the graph, use_gpu=False might cause extra Cast nodes. "
@@ -125,7 +129,10 @@
        sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL

    if optimized_model_path is None:
-        path_prefix = onnx_model_path[:-5]  # remove .onnx suffix
+        if isinstance(onnx_model, str):
+            path_prefix = onnx_model[:-5]  # remove .onnx suffix
+        else:
+            path_prefix = "optimized_model"
        optimized_model_path = "{}_o{}_{}.onnx".format(path_prefix, opt_level, "gpu" if use_gpu else "cpu")

    sess_options.optimized_model_filepath = optimized_model_path
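
Note that when a ModelProto is passed and optimized_model_path is left as None, the optimized file lands in the working directory under the generic prefix above, e.g. optimized_model_o99_cpu.onnx with the default opt_level of 99 on CPU. A small sketch of pinning the output location instead (path is illustrative):

```python
# Passing optimized_model_path avoids the generic default file name.
optimize_by_onnxruntime(model_proto, optimized_model_path="/tmp/model_opt.onnx")
```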
@@ -174,13 +181,41 @@
        else:
            providers.append("CUDAExecutionProvider")

-    onnxruntime.InferenceSession(onnx_model_path, sess_options, providers=providers, **kwargs)
+    if isinstance(onnx_model, str):
+        onnxruntime.InferenceSession(onnx_model, sess_options, providers=providers, **kwargs)
+    elif isinstance(onnx_model, ModelProto):
+        load_infer_session_from_modelproto(onnx_model, save_as_external_data, sess_options, providers, kwargs)

    assert os.path.exists(optimized_model_path) and os.path.isfile(optimized_model_path)
    logger.debug("Save optimized model by onnxruntime to %s", optimized_model_path)
    return optimized_model_path


+def load_infer_session_from_modelproto(model, has_external_data, sess_options, providers, kwargs) -> None:
+    import onnxruntime
+    from onnxruntime.python.onnxruntime_inference_collection import OrtValue
+    from onnxruntime.python.tools.transformers.fusion_utils import NumpyHelper
+
+    external_data = []
+    for tensor in model.graph.initializer:
+        name = tensor.name
+
+        if has_external_data:
+            raise ValueError("Model has external data, model path is required to load the inference session.")
+
+        logger.info("externalizing tensor: %s", name)
+        if tensor.HasField("raw_data"):
+            npt = NumpyHelper.to_array(tensor)
+            orv = OrtValue.ortvalue_from_numpy(npt)
+            external_data.append((name, orv))
+            tensor.name = name
+            tensor.ClearField("raw_data")
+
+    external_names, external_values = zip(*external_data)
+    sess_options.add_external_initializers(list(external_names), list(external_values))
+    onnxruntime.InferenceSession(model.SerializeToString(), sess_options=sess_options, providers=providers, **kwargs)
+
+
def optimize_by_fusion(
    model: ModelProto,
    model_type: str = "bert",
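
The new helper moves every raw-data initializer out of the ModelProto into OrtValue objects, registers them through SessionOptions.add_external_initializers, and then creates the InferenceSession from the now weight-free SerializeToString() bytes, which also keeps the serialized proto small for large models. A standalone sketch of the same pattern using the public onnx.numpy_helper API instead of the repo's NumpyHelper wrapper (the model path and provider list are illustrative):

```python
import onnx
import onnxruntime as ort
from onnx import numpy_helper

model = onnx.load("model.onnx")  # illustrative path
so = ort.SessionOptions()

names, values = [], []
for tensor in model.graph.initializer:
    if tensor.HasField("raw_data"):
        array = numpy_helper.to_array(tensor)
        names.append(tensor.name)
        values.append(ort.OrtValue.ortvalue_from_numpy(array))  # keep the OrtValues alive until the session exists
        tensor.ClearField("raw_data")  # the weights now live only in the OrtValues

so.add_external_initializers(names, values)
session = ort.InferenceSession(model.SerializeToString(), sess_options=so, providers=["CPUExecutionProvider"])
```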
@@ -241,7 +276,7 @@


def optimize_model(
-    input: str,
+    input: Union[str, ModelProto],
    model_type: str = "bert",
    num_heads: int = 0,
    hidden_size: int = 0,
@@ -275,7 +310,7 @@
    For BERT model, num_heads and hidden_size are optional. For other model types, you need specify these parameters.

    Args:
-        input (str): input model path.
+        input (str | ModelProto): input model path or ModelProto.
        model_type (str, optional): model type - like bert, bert_tf, bert_keras or gpt2. Defaults to 'bert'.
        num_heads (int, optional): number of attention heads. Defaults to 0.
                                   0 allows detect the parameter from graph automatically.
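
A minimal end-to-end sketch of optimize_model with a ModelProto input (the file names and the BERT-base head/hidden values are illustrative, not from this PR):

```python
import onnx

from optimizer import optimize_model  # onnxruntime.transformers.optimizer in an installed package

model_proto = onnx.load("bert.onnx")  # could equally be a ModelProto produced in memory by an exporter
bert_optimizer = optimize_model(model_proto, model_type="bert", num_heads=12, hidden_size=768)
bert_optimizer.save_model_to_file("bert_opt.onnx")
```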
@@ -298,7 +333,7 @@

    if model_type not in MODEL_TYPES:
        logger.warning(f"Unsupported model type: {model_type} for optimization, directly return model.")
-        return OnnxModel(load_model(input))
+        return OnnxModel(load_model(input)) if isinstance(input, str) else OnnxModel(input)

    (optimizer_class, _producer, default_opt_level) = MODEL_TYPES[model_type]

@@ -316,7 +351,7 @@

    # Auto detect if input model has external data
    has_external_data_file = False
-    original_model = load_model(input, load_external_data=False)
+    original_model = load_model(input, load_external_data=False) if isinstance(input, str) else input
    for initializer in original_model.graph.initializer:
        if initializer.HasField("data_location") and initializer.data_location == TensorProto.EXTERNAL:
            has_external_data_file = True
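
The auto-detection above only inspects top-level graph initializers for the EXTERNAL data_location flag. The same check can be reproduced on any model with plain onnx APIs (path illustrative):

```python
from onnx import TensorProto, load_model

m = load_model("model.onnx", load_external_data=False)  # leave external weights on disk
has_external = any(
    init.HasField("data_location") and init.data_location == TensorProto.EXTERNAL
    for init in m.graph.initializer
)
```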
@@ -365,7 +400,12 @@
    if only_onnxruntime and not temp_model_path:
        logger.warning("Please specify a positive value for opt_level when only_onnxruntime is True")

-    model = load_model(temp_model_path or input)
+    if temp_model_path is not None:
+        model = load_model(temp_model_path)
+    elif isinstance(input, str):
+        model = load_model(input)
+    else:
+        model = input

    if only_onnxruntime:
        optimizer = optimizer_class(model, num_heads, hidden_size)