Add ModelProto support for transformers optimize_model #19990

Merged · 11 commits · merged Mar 23, 2024
52 changes: 17 additions & 35 deletions onnxruntime/python/tools/transformers/optimizer.py
@@ -21,6 +21,7 @@
 import logging
 import os
 import tempfile
+from pathlib import Path
 from typing import Dict, List, Optional, Union

 import coloredlogs
@@ -40,6 +41,9 @@
 from onnx_model_unet import UnetOnnxModel
 from onnx_model_vae import VaeOnnxModel

+import onnxruntime
+from onnxruntime.python.tools.transformers.optimizer_utils import extract_external_data_from_model
+
 logger = logging.getLogger(__name__)

 # Map model type to tuple: optimizer class, export tools (pytorch, tf2onnx, keras2onnx), and default opt_level

Code scanning / CodeQL notice on the imports above: Module 'onnxruntime' is imported with both 'import' and 'import from'.
@@ -95,8 +99,6 @@
     assert opt_level in [1, 2, 99]
     from torch import version as torch_version

-    import onnxruntime
-
     if (
         use_gpu
         and provider is None
@@ -130,7 +132,7 @@
     if optimized_model_path is None:
         if isinstance(onnx_model, str):
-            path_prefix = onnx_model[:-5]  # remove .onnx suffix
+            path_prefix = str(Path(onnx_model).with_suffix(""))  # remove .onnx suffix
         else:
             path_prefix = "optimized_model"
         optimized_model_path = "{}_o{}_{}.onnx".format(path_prefix, opt_level, "gpu" if use_gpu else "cpu")
@@ -181,48 +183,28 @@
     else:
         providers.append("CUDAExecutionProvider")

-    if isinstance(onnx_model, str):
-        onnxruntime.InferenceSession(onnx_model, sess_options, providers=providers, **kwargs)
-    elif isinstance(onnx_model, ModelProto):
-        _load_infer_session_from_modelproto(onnx_model, save_as_external_data, sess_options, providers, kwargs)
+    # For ModelProto, we need to extract external data and add them to the session options.
+    if isinstance(onnx_model, ModelProto):
+        if save_as_external_data:
+            raise ValueError("Model has external data, model path is required to load the inference session.")
+        external_names, external_values = extract_external_data_from_model(onnx_model)
+        sess_options.add_external_initializers(list(external_names), list(external_values))
+
+    # Inference session is only used to optimize the model.
+    onnxruntime.InferenceSession(onnx_model, sess_options, providers=providers, **kwargs)

     assert os.path.exists(optimized_model_path) and os.path.isfile(optimized_model_path)
     logger.debug("Save optimized model by onnxruntime to %s", optimized_model_path)
     return optimized_model_path


-def _load_infer_session_from_modelproto(model, has_external_data, sess_options, providers, kwargs) -> None:
-    import onnxruntime
-    from onnxruntime_inference_collection import OrtValue
-    from fusion_utils import NumpyHelper
-
-    external_data = []
-    for tensor in model.graph.initializer:
-        name = tensor.name
-
-        if has_external_data:
-            raise ValueError("Model has external data, model path is required to load the inference session.")
-
-        logger.info("externalizing tensor: %s", name)
-        if tensor.HasField("raw_data"):
-            npt = NumpyHelper.to_array(tensor)
-            orv = OrtValue.ortvalue_from_numpy(npt)
-            external_data.append((name, orv))
-            tensor.name = name
-            tensor.ClearField("raw_data")
-
-    external_names, external_values = zip(*external_data)
-    sess_options.add_external_initializers(list(external_names), list(external_values))
-    onnxruntime.InferenceSession(model.SerializeToString(), sess_options=sess_options, providers=providers, **kwargs)
-
-
 def optimize_by_fusion(
     model: ModelProto,
     model_type: str = "bert",
     num_heads: int = 0,
     hidden_size: int = 0,
     optimization_options: Optional[FusionOptions] = None,
-):
+) -> OnnxModel:
     """Optimize Model by graph fusion logic.

     Note that ONNXRuntime graph optimizations (like constant folding) will not be applied. So it is better to enable
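The comment "Inference session is only used to optimize the model" deserves a note: creating a session with optimized_model_filepath set makes ONNX Runtime write the optimized graph to disk as a side effect, and the session object itself is discarded. A minimal standalone sketch of that trick (file names illustrative, current onnxruntime Python API assumed):

import onnxruntime

sess_options = onnxruntime.SessionOptions()
sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_BASIC
sess_options.optimized_model_filepath = "model_o1_cpu.onnx"
# Constructing the session triggers graph optimization and saves the result to disk.
onnxruntime.InferenceSession("model.onnx", sess_options, providers=["CPUExecutionProvider"])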
@@ -287,7 +269,7 @@
     verbose: bool = False,
     *,
     provider: Optional[str] = None,
-):
+) -> OnnxModel:
     """Optimize Model by OnnxRuntime and/or python fusion logic.

     ONNX Runtime has graph optimizations (https://onnxruntime.ai/docs/performance/model-optimizations/graph-optimizations.html).
@@ -335,7 +317,7 @@
         logger.warning(f"Unsupported model type: {model_type} for optimization, directly return model.")
         return OnnxModel(load_model(input)) if isinstance(input, str) else OnnxModel(input)

-    (optimizer_class, _producer, default_opt_level) = MODEL_TYPES[model_type]
+    (optimizer_class, _, default_opt_level) = MODEL_TYPES[model_type]

     if opt_level is None:
         opt_level = default_opt_level
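Taken together, the optimizer.py changes let callers pass an in-memory ModelProto end to end. A minimal sketch of the new call pattern (the import path and file names are assumptions based on this diff, not verified against the repo):

import onnx
from optimizer import optimize_model  # onnxruntime/python/tools/transformers

model_proto = onnx.load("model.onnx")  # an in-memory ModelProto with no external data

# Previously only a file path was accepted here; with this PR a ModelProto
# works too, as long as its tensors are not stored as external data.
optimized = optimize_model(model_proto, model_type="bert", num_heads=12, hidden_size=768)
optimized.save_model_to_file("model_opt.onnx")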
28 changes: 28 additions & 0 deletions onnxruntime/python/tools/transformers/optimizer_utils.py
@@ -0,0 +1,28 @@
+from fusion_utils import NumpyHelper
+from onnxruntime import OrtValue
+from onnx.external_data_helper import set_external_data
+from onnx import ModelProto
+
+def extract_external_data_from_model(model: ModelProto):
+    """
+    Extract external data (initializer tensors with raw_data) from the model.
+
+    Args:
+        model (ModelProto): the model proto to extract external data from.
+    Returns:
+        (external_names, external_values): parallel sequences of tensor names and OrtValues.
+    """
+    external_data = []
+    for tensor in model.graph.initializer:
+        name = tensor.name
+
+        if tensor.HasField("raw_data"):
+            numpy_tensor = NumpyHelper.to_array(tensor)
+            ort_value = OrtValue.ortvalue_from_numpy(numpy_tensor)
+            external_data.append((name, ort_value))
+            # Mark the tensor as externally stored (placeholder location) so raw_data can be cleared.
+            set_external_data(tensor, location="foo.bin")
+            tensor.name = name
+            tensor.ClearField("raw_data")
+
+    return zip(*external_data)
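To show how this helper is meant to be consumed, here is a hedged sketch mirroring the optimizer.py change above; the file name is illustrative and the import path assumes optimizer_utils.py is on sys.path:

import onnx
import onnxruntime
from optimizer_utils import extract_external_data_from_model

model = onnx.load("large_model.onnx")

sess_options = onnxruntime.SessionOptions()
external_names, external_values = extract_external_data_from_model(model)
# add_external_initializers expects parallel lists of names and OrtValues.
sess_options.add_external_initializers(list(external_names), list(external_values))

# The raw_data fields were cleared, so the serialized proto stays small;
# the session resolves those initializers from the supplied OrtValues instead.
session = onnxruntime.InferenceSession(
    model.SerializeToString(), sess_options, providers=["CPUExecutionProvider"]
)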
22 changes: 22 additions & 0 deletions onnxruntime/test/python/transformers/test_optimizer_utils.py
@@ -0,0 +1,22 @@
+import unittest
+
+import numpy
+from onnx import ModelProto, TensorProto, helper
+from onnxruntime.python.tools.transformers.optimizer_utils import extract_external_data_from_model
+
+
+class TestOptimizerUtils(unittest.TestCase):
+    def test_extract_external_data_from_model(self):
+        model = self._get_model_proto_with_raw_data()
+        external_names, external_values = extract_external_data_from_model(model)
+        self.assertEqual(list(external_names), ["inputs"])
+        self.assertEqual(len(external_values), 1)
+        self.assertEqual(external_values[0].numpy(), [0.0])
+
+    def _get_model_proto_with_raw_data(self) -> ModelProto:
+        input = helper.make_tensor_value_info("inputs", TensorProto.FLOAT, [None])
+        output = helper.make_tensor_value_info("outputs", TensorProto.FLOAT, [None])
+        raw_data = numpy.array([0.0], dtype=numpy.float32).tobytes()
+        tensor = helper.make_tensor("inputs", TensorProto.FLOAT, [1], raw_data, True)
+        node = helper.make_node("Identity", inputs=["inputs"], outputs=["outputs"])
+        return helper.make_model(helper.make_graph([node], "graph", [input], [output], initializer=[tensor]))
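As an aside, the unpacking the test relies on is plain zip transposition; a standalone illustration (note that a model with no raw-data initializers would make the unpack raise ValueError, since zip(*[]) yields nothing):

pairs = [("inputs", 0.0)]    # (name, value) tuples, as built by extract_external_data_from_model
names, values = zip(*pairs)  # transpose into parallel tuples
assert names == ("inputs",)
assert values == (0.0,)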