Add ModelProto support for transformers optimize_model #19990

Merged · 11 commits · Mar 23, 2024
55 changes: 55 additions & 0 deletions onnxruntime/python/tools/transformers/onnx_utils.py
@@ -0,0 +1,55 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from fusion_utils import NumpyHelper
from onnx import ModelProto, TensorProto
from onnx.external_data_helper import set_external_data
from onnx_model import OnnxModel

from onnxruntime import OrtValue


def extract_raw_data_from_model(model: ModelProto):
    """
    Extract raw tensor data from the model and return it as parallel sequences of
    initializer names and OrtValues. Note: this function only handles tensor data that is
    already loaded into the model as raw_data; it does not load data from external files.

    Args:
        model (ModelProto): the model proto to extract raw data from.
    Returns:
        (external_names, external_values): a pair of sequences of initializer names and OrtValues.
    """
    external_data = []
    onnx_model = OnnxModel(model)
    for graph in onnx_model.graphs():
        for initializer in graph.initializer:
            name = initializer.name

            if initializer.HasField("raw_data"):
                numpy_tensor = NumpyHelper.to_array(initializer)
                ort_value = OrtValue.ortvalue_from_numpy(numpy_tensor)
                external_data.append((name, ort_value))
                # Mark the initializer as external (placeholder location) and drop its raw data,
                # so the weights can be supplied to the session as external initializers instead.
                set_external_data(initializer, location="foo.bin")
                initializer.name = name
                initializer.ClearField("raw_data")

    return zip(*external_data)


def has_external_data(model: ModelProto):
    """
    Check if the model has external data.

    Args:
        model (ModelProto): the model proto to check for external data.
    Returns:
        bool: True if the model has external data, False otherwise.
    """
    onnx_model = OnnxModel(model)
    for graph in onnx_model.graphs():
        for initializer in graph.initializer:
            if initializer.HasField("data_location") and initializer.data_location == TensorProto.EXTERNAL:
                return True
    return False
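
These two helpers are what optimizer.py (below) uses to create an inference session directly from an in-memory ModelProto. A minimal sketch of that pattern, assuming a ModelProto whose tensor data is already loaded in memory; the function name create_session_from_model_proto is illustrative, not part of this PR:

import onnxruntime
from onnxruntime.transformers.onnx_utils import extract_raw_data_from_model, has_external_data


def create_session_from_model_proto(model):
    # Refuse models whose weights still live in external files; ORT cannot resolve
    # those from serialized bytes alone.
    if has_external_data(model):
        raise ValueError("Load external data into the ModelProto before creating a session.")

    sess_options = onnxruntime.SessionOptions()
    # Hand the weights to ORT as external initializers instead of serializing them
    # into the model bytes, which keeps large weights out of the protobuf.
    external_names, external_values = extract_raw_data_from_model(model)
    sess_options.add_external_initializers(list(external_names), list(external_values))
    return onnxruntime.InferenceSession(model.SerializeToString(), sess_options)
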
69 changes: 47 additions & 22 deletions onnxruntime/python/tools/transformers/optimizer.py
@@ -21,11 +21,12 @@
import logging
import os
import tempfile
from typing import Dict, List, Optional
from pathlib import Path
from typing import Dict, List, Optional, Union

import coloredlogs
from fusion_options import FusionOptions
from onnx import ModelProto, TensorProto, load_model
from onnx import ModelProto, load_model
from onnx_model import OnnxModel
from onnx_model_bart import BartOnnxModel
from onnx_model_bert import BertOnnxModel
@@ -40,6 +41,9 @@
from onnx_model_unet import UnetOnnxModel
from onnx_model_vae import VaeOnnxModel

import onnxruntime
from onnxruntime.transformers.onnx_utils import extract_raw_data_from_model, has_external_data

logger = logging.getLogger(__name__)

# Map model type to tuple: optimizer class, export tools (pytorch, tf2onnx, keras2onnx), and default opt_level
@@ -64,7 +68,7 @@


def optimize_by_onnxruntime(
onnx_model_path: str,
onnx_model: Union[str, ModelProto],
use_gpu: bool = False,
optimized_model_path: Optional[str] = None,
opt_level: Optional[int] = 99,
@@ -80,7 +84,7 @@
Use onnxruntime to optimize model.

Args:
onnx_model_path (str): the path of input onnx model.
onnx_model (str | ModelProto): the path of input onnx model or ModelProto.
use_gpu (bool): whether the optimized model is targeted to run in GPU.
optimized_model_path (str or None): the path of optimized model.
opt_level (int): graph optimization level.
@@ -95,8 +99,6 @@
assert opt_level in [1, 2, 99]
from torch import version as torch_version

import onnxruntime

if (
use_gpu
and provider is None
@@ -105,9 +107,13 @@
)
):
logger.error("There is no gpu for onnxruntime to do optimization.")
return onnx_model_path
return onnx_model

model = OnnxModel(load_model(onnx_model_path, load_external_data=False))
model = (
OnnxModel(load_model(onnx_model, load_external_data=False))
if isinstance(onnx_model, str)
else OnnxModel(onnx_model)
)
if model.use_float16() and not use_gpu:
logger.warning(
"This model uses float16 in the graph, use_gpu=False might cause extra Cast nodes. "
@@ -125,7 +131,10 @@
sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL

if optimized_model_path is None:
path_prefix = onnx_model_path[:-5] # remove .onnx suffix
if isinstance(onnx_model, str):
path_prefix = str(Path(onnx_model).with_suffix("")) # remove .onnx suffix
else:
path_prefix = "optimized_model"
optimized_model_path = "{}_o{}_{}.onnx".format(path_prefix, opt_level, "gpu" if use_gpu else "cpu")

sess_options.optimized_model_filepath = optimized_model_path
@@ -174,7 +183,20 @@
else:
providers.append("CUDAExecutionProvider")

onnxruntime.InferenceSession(onnx_model_path, sess_options, providers=providers, **kwargs)
# For large model, extract external data from model and add to session options
if isinstance(onnx_model, ModelProto):
if has_external_data(onnx_model):
raise ValueError(
"ModelProto has external data not loaded into memory, ORT cannot create session. "
"Please load external data before calling this function. "
"See https://onnx.ai/onnx/repo-docs/ExternalData.html for more information."
)
external_names, external_values = extract_raw_data_from_model(onnx_model)
sess_options.add_external_initializers(list(external_names), list(external_values))

# Inference session is only used to optimize the model.
onnx_model = onnx_model.SerializeToString() if isinstance(onnx_model, ModelProto) else onnx_model
onnxruntime.InferenceSession(onnx_model, sess_options, providers=providers, **kwargs)

assert os.path.exists(optimized_model_path) and os.path.isfile(optimized_model_path)
logger.debug("Save optimized model by onnxruntime to %s", optimized_model_path)
@@ -187,7 +209,7 @@
num_heads: int = 0,
hidden_size: int = 0,
optimization_options: Optional[FusionOptions] = None,
):
) -> OnnxModel:
"""Optimize Model by graph fusion logic.

Note that ONNXRuntime graph optimizations (like constant folding) will not be applied. So it is better to enable
@@ -241,7 +263,7 @@


def optimize_model(
input: str,
input: Union[str, ModelProto],
model_type: str = "bert",
num_heads: int = 0,
hidden_size: int = 0,
@@ -252,7 +274,7 @@
verbose: bool = False,
*,
provider: Optional[str] = None,
):
) -> OnnxModel:
"""Optimize Model by OnnxRuntime and/or python fusion logic.

ONNX Runtime has graph optimizations (https://onnxruntime.ai/docs/performance/model-optimizations/graph-optimizations.html).
@@ -275,7 +297,7 @@
For BERT model, num_heads and hidden_size are optional. For other model types, you need specify these parameters.

Args:
input (str): input model path.
input (str | ModelProto): input model path or ModelProto.
model_type (str, optional): model type - like bert, bert_tf, bert_keras or gpt2. Defaults to 'bert'.
num_heads (int, optional): number of attention heads. Defaults to 0.
0 allows detect the parameter from graph automatically.
@@ -298,9 +320,9 @@

if model_type not in MODEL_TYPES:
logger.warning(f"Unsupported model type: {model_type} for optimization, directly return model.")
return OnnxModel(load_model(input))
return OnnxModel(load_model(input)) if isinstance(input, str) else OnnxModel(input)

(optimizer_class, _producer, default_opt_level) = MODEL_TYPES[model_type]
(optimizer_class, _, default_opt_level) = MODEL_TYPES[model_type]

if opt_level is None:
opt_level = default_opt_level
@@ -316,11 +338,9 @@

# Auto detect if input model has external data
has_external_data_file = False
original_model = load_model(input, load_external_data=False)
for initializer in original_model.graph.initializer:
if initializer.HasField("data_location") and initializer.data_location == TensorProto.EXTERNAL:
has_external_data_file = True
break
original_model = load_model(input, load_external_data=False) if isinstance(input, str) else input
if has_external_data(original_model):
has_external_data_file = True
del original_model

if opt_level > 1:
@@ -365,7 +385,12 @@
if only_onnxruntime and not temp_model_path:
logger.warning("Please specify a positive value for opt_level when only_onnxruntime is True")

model = load_model(temp_model_path or input)
if temp_model_path is not None:
model = load_model(temp_model_path)
elif isinstance(input, str):
model = load_model(input)
else:
model = input

if only_onnxruntime:
optimizer = optimizer_class(model, num_heads, hidden_size)
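
With these changes, optimize_model accepts either a path or a ModelProto. A minimal sketch of the new in-memory calling pattern; the file names and BERT hyperparameters below are placeholders, not part of this PR:

import onnx
from onnxruntime.transformers.optimizer import optimize_model

# Build or load a ModelProto in memory; any external data must be loaded into it first.
model_proto = onnx.load("bert_base.onnx")

optimized = optimize_model(
    model_proto,          # previously only a str path was accepted here
    model_type="bert",
    num_heads=12,
    hidden_size=768,
    opt_level=0,          # 0 applies only the python fusion logic, skipping ORT offline optimization
)
optimized.save_model_to_file("bert_base_optimized.onnx")
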
38 changes: 38 additions & 0 deletions onnxruntime/test/python/transformers/test_onnx_utils.py
@@ -0,0 +1,38 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import unittest

import numpy
from onnx import ModelProto, TensorProto, helper
from onnx.external_data_helper import set_external_data

from onnxruntime.transformers.onnx_utils import extract_raw_data_from_model, has_external_data


class TestOnnxUtils(unittest.TestCase):
    def test_extract_raw_data_from_model(self):
        model = self._get_model_proto_with_raw_data(False)
        external_names, external_values = extract_raw_data_from_model(model)
        self.assertEqual(list(external_names), ["inputs"])
        self.assertEqual(len(external_values), 1)
        self.assertEqual(external_values[0].numpy(), [0.0])

    def test_has_external_data(self):
        model = self._get_model_proto_with_raw_data()
        self.assertTrue(has_external_data(model))

    def test_has_external_data_with_no_external_data(self):
        model = self._get_model_proto_with_raw_data(False)
        self.assertFalse(has_external_data(model))

    def _get_model_proto_with_raw_data(self, has_external_data: bool = True) -> ModelProto:
        input = helper.make_tensor_value_info("inputs", TensorProto.FLOAT, [None])
        output = helper.make_tensor_value_info("outputs", TensorProto.FLOAT, [None])
        raw_data = numpy.array([0.0], dtype=numpy.float32).tobytes()
        tensor = helper.make_tensor("inputs", TensorProto.FLOAT, [1], raw_data, True)
        if has_external_data:
            set_external_data(tensor, location="foo.bin")
        node = helper.make_node("Identity", inputs=["inputs"], outputs=["outputs"])
        return helper.make_model(helper.make_graph([node], "graph", [input], [output], initializer=[tensor]))