Add ModelProto support for transformers optimize_model (#19990)
### Description
Add `ModelProto` support as an input to the transformers `optimize_model`
API.

### Motivation and Context
Currently, the `optimize_model` API accepts only a model path as the
input model. For large models, however, saving to and loading from disk
can be time-consuming. Accepting a `ModelProto` directly avoids this
round trip and can save significant time.
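As a hedged sketch of the new call pattern (the model path and the `num_heads`/`hidden_size` values below are placeholders, not part of this change):

```python
# Sketch of the usage enabled by this change; "model.onnx", num_heads,
# and hidden_size are placeholder values for illustration.
import onnx
from onnxruntime.transformers.optimizer import optimize_model

# Load once and keep the model in memory instead of passing a file path.
model_proto = onnx.load("model.onnx")

# optimize_model now accepts the in-memory ModelProto directly.
optimized = optimize_model(model_proto, model_type="bert", num_heads=12, hidden_size=768)
optimized.save_model_to_file("model_optimized.onnx")
```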
xiaoyu-work authored Mar 23, 2024
1 parent 3076b56 commit 71551da
Showing 3 changed files with 140 additions and 22 deletions.
55 changes: 55 additions & 0 deletions onnxruntime/python/tools/transformers/onnx_utils.py
@@ -0,0 +1,55 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from fusion_utils import NumpyHelper
from onnx import ModelProto, TensorProto
from onnx.external_data_helper import set_external_data
from onnx_model import OnnxModel

from onnxruntime import OrtValue


def extract_raw_data_from_model(model: ModelProto):
    """
    Extract raw tensor data from the model and return it as paired sequences of names and values.
    Note this function does not handle external data that is not loaded into the model as raw data.

    Args:
        model (ModelProto): the model proto to extract raw data from.
    Returns:
        (external_names, external_values): a pair of sequences of initializer names and OrtValues.
    """
    external_data = []
    onnx_model = OnnxModel(model)
    for graph in onnx_model.graphs():
        for initializer in graph.initializer:
            name = initializer.name

            if initializer.HasField("raw_data"):
                numpy_tensor = NumpyHelper.to_array(initializer)
                ort_value = OrtValue.ortvalue_from_numpy(numpy_tensor)
                external_data.append((name, ort_value))
                # Mark the initializer as external (the location is a dummy) and drop its
                # raw data; the tensor contents are supplied to ORT separately as
                # external initializers.
                set_external_data(initializer, location="foo.bin")
                initializer.name = name
                initializer.ClearField("raw_data")

    return zip(*external_data)


def has_external_data(model: ModelProto):
    """
    Check if the model has external data.

    Args:
        model (ModelProto): the model proto to check for external data.
    Returns:
        bool: True if the model has external data, False otherwise.
    """
    onnx_model = OnnxModel(model)
    for graph in onnx_model.graphs():
        for initializer in graph.initializer:
            if initializer.HasField("data_location") and initializer.data_location == TensorProto.EXTERNAL:
                return True
    return False
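These helpers are consumed together in `optimize_by_onnxruntime` below; a minimal sketch of the intended flow, assuming a model whose tensor data is fully loaded in memory (the path is a placeholder):

```python
# Minimal usage sketch for the two helpers above; "model.onnx" is a placeholder.
import onnx
import onnxruntime
from onnxruntime.transformers.onnx_utils import extract_raw_data_from_model, has_external_data

model = onnx.load("model.onnx")  # loads any external data into the proto by default

if has_external_data(model):
    # Raw bytes still live on disk; ORT cannot build a session from the proto alone.
    raise ValueError("Load external data into the model before optimizing.")

# Move raw initializer bytes out of the proto into OrtValues...
external_names, external_values = extract_raw_data_from_model(model)

# ...and hand them to the session separately, so the serialized proto stays small.
sess_options = onnxruntime.SessionOptions()
sess_options.add_external_initializers(list(external_names), list(external_values))
```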
69 changes: 47 additions & 22 deletions onnxruntime/python/tools/transformers/optimizer.py
@@ -21,11 +21,12 @@
import logging
import os
import tempfile
from typing import Dict, List, Optional
from pathlib import Path
from typing import Dict, List, Optional, Union

import coloredlogs
from fusion_options import FusionOptions
from onnx import ModelProto, TensorProto, load_model
from onnx import ModelProto, load_model
from onnx_model import OnnxModel
from onnx_model_bart import BartOnnxModel
from onnx_model_bert import BertOnnxModel
@@ -40,6 +41,9 @@
from onnx_model_unet import UnetOnnxModel
from onnx_model_vae import VaeOnnxModel

import onnxruntime
from onnxruntime.transformers.onnx_utils import extract_raw_data_from_model, has_external_data

logger = logging.getLogger(__name__)

# Map model type to tuple: optimizer class, export tools (pytorch, tf2onnx, keras2onnx), and default opt_level
@@ -64,7 +68,7 @@


def optimize_by_onnxruntime(
    onnx_model_path: str,
    onnx_model: Union[str, ModelProto],
    use_gpu: bool = False,
    optimized_model_path: Optional[str] = None,
    opt_level: Optional[int] = 99,
@@ -80,7 +84,7 @@ def optimize_by_onnxruntime(
    Use onnxruntime to optimize model.

    Args:
        onnx_model_path (str): the path of input onnx model.
        onnx_model (str | ModelProto): the path of input onnx model or ModelProto.
        use_gpu (bool): whether the optimized model is targeted to run in GPU.
        optimized_model_path (str or None): the path of optimized model.
        opt_level (int): graph optimization level.
@@ -95,8 +99,6 @@ def optimize_by_onnxruntime(
    assert opt_level in [1, 2, 99]
    from torch import version as torch_version

    import onnxruntime

    if (
        use_gpu
        and provider is None
@@ -105,9 +107,13 @@
        )
    ):
        logger.error("There is no gpu for onnxruntime to do optimization.")
        return onnx_model_path
        return onnx_model

    model = OnnxModel(load_model(onnx_model_path, load_external_data=False))
    model = (
        OnnxModel(load_model(onnx_model, load_external_data=False))
        if isinstance(onnx_model, str)
        else OnnxModel(onnx_model)
    )
    if model.use_float16() and not use_gpu:
        logger.warning(
            "This model uses float16 in the graph, use_gpu=False might cause extra Cast nodes. "
@@ -125,7 +131,10 @@
        sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL

    if optimized_model_path is None:
        path_prefix = onnx_model_path[:-5]  # remove .onnx suffix
        if isinstance(onnx_model, str):
            path_prefix = str(Path(onnx_model).with_suffix(""))  # remove .onnx suffix
        else:
            path_prefix = "optimized_model"
        optimized_model_path = "{}_o{}_{}.onnx".format(path_prefix, opt_level, "gpu" if use_gpu else "cpu")

    sess_options.optimized_model_filepath = optimized_model_path
@@ -174,7 +183,20 @@ def optimize_by_onnxruntime(
        else:
            providers.append("CUDAExecutionProvider")

    onnxruntime.InferenceSession(onnx_model_path, sess_options, providers=providers, **kwargs)
    # For large models, extract external data from the model and add it to the session options
    if isinstance(onnx_model, ModelProto):
        if has_external_data(onnx_model):
            raise ValueError(
                "ModelProto has external data not loaded into memory, ORT cannot create session. "
                "Please load external data before calling this function. "
                "See https://onnx.ai/onnx/repo-docs/ExternalData.html for more information."
            )
        external_names, external_values = extract_raw_data_from_model(onnx_model)
        sess_options.add_external_initializers(list(external_names), list(external_values))

    # The inference session is created only to produce the optimized model file.
    onnx_model = onnx_model.SerializeToString() if isinstance(onnx_model, ModelProto) else onnx_model
    onnxruntime.InferenceSession(onnx_model, sess_options, providers=providers, **kwargs)

    assert os.path.exists(optimized_model_path) and os.path.isfile(optimized_model_path)
    logger.debug("Save optimized model by onnxruntime to %s", optimized_model_path)
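Taken together, the hunks above implement a single pattern: strip the weights out of the proto, register them with the session as external initializers, then create the session from the serialized bytes. A condensed, hedged sketch (variable names follow the diff; the output path is a placeholder); keeping the weights out of the proto matters because a serialized `ModelProto` is capped at 2 GB:

```python
# Condensed sketch of the session-creation pattern above; the session is
# created only for its side effect of writing the optimized model to disk.
import onnxruntime
from onnxruntime.transformers.onnx_utils import extract_raw_data_from_model

# `onnx_model` is assumed to be a ModelProto already checked for unloaded external data.
external_names, external_values = extract_raw_data_from_model(onnx_model)

sess_options = onnxruntime.SessionOptions()
sess_options.optimized_model_filepath = "optimized_model_o1_cpu.onnx"  # placeholder
sess_options.add_external_initializers(list(external_names), list(external_values))

# Weights travel via the session options, so the serialized proto stays small.
onnxruntime.InferenceSession(
    onnx_model.SerializeToString(), sess_options, providers=["CPUExecutionProvider"]
)
```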
@@ -187,7 +209,7 @@ def optimize_by_fusion(
    num_heads: int = 0,
    hidden_size: int = 0,
    optimization_options: Optional[FusionOptions] = None,
):
) -> OnnxModel:
    """Optimize Model by graph fusion logic.

    Note that ONNXRuntime graph optimizations (like constant folding) will not be applied. So it is better to enable
@@ -241,7 +263,7 @@


def optimize_model(
    input: str,
    input: Union[str, ModelProto],
    model_type: str = "bert",
    num_heads: int = 0,
    hidden_size: int = 0,
@@ -252,7 +274,7 @@
    verbose: bool = False,
    *,
    provider: Optional[str] = None,
):
) -> OnnxModel:
    """Optimize Model by OnnxRuntime and/or python fusion logic.

    ONNX Runtime has graph optimizations (https://onnxruntime.ai/docs/performance/model-optimizations/graph-optimizations.html).
@@ -275,7 +297,7 @@
    For BERT model, num_heads and hidden_size are optional. For other model types, you need to specify these parameters.

    Args:
        input (str): input model path.
        input (str | ModelProto): input model path or ModelProto.
        model_type (str, optional): model type - like bert, bert_tf, bert_keras or gpt2. Defaults to 'bert'.
        num_heads (int, optional): number of attention heads. Defaults to 0.
            0 allows detecting the parameter from the graph automatically.
@@ -298,9 +320,9 @@

    if model_type not in MODEL_TYPES:
        logger.warning(f"Unsupported model type: {model_type} for optimization, directly return model.")
        return OnnxModel(load_model(input))
        return OnnxModel(load_model(input)) if isinstance(input, str) else OnnxModel(input)

    (optimizer_class, _producer, default_opt_level) = MODEL_TYPES[model_type]
    (optimizer_class, _, default_opt_level) = MODEL_TYPES[model_type]

    if opt_level is None:
        opt_level = default_opt_level
@@ -316,11 +338,9 @@

    # Auto detect if input model has external data
    has_external_data_file = False
    original_model = load_model(input, load_external_data=False)
    for initializer in original_model.graph.initializer:
        if initializer.HasField("data_location") and initializer.data_location == TensorProto.EXTERNAL:
            has_external_data_file = True
            break
    original_model = load_model(input, load_external_data=False) if isinstance(input, str) else input
    if has_external_data(original_model):
        has_external_data_file = True
    del original_model

    if opt_level > 1:
@@ -365,7 +385,12 @@
    if only_onnxruntime and not temp_model_path:
        logger.warning("Please specify a positive value for opt_level when only_onnxruntime is True")

    model = load_model(temp_model_path or input)
    if temp_model_path is not None:
        model = load_model(temp_model_path)
    elif isinstance(input, str):
        model = load_model(input)
    else:
        model = input

    if only_onnxruntime:
        optimizer = optimizer_class(model, num_heads, hidden_size)
38 changes: 38 additions & 0 deletions onnxruntime/test/python/transformers/test_onnx_utils.py
@@ -0,0 +1,38 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import unittest

import numpy
from onnx import ModelProto, TensorProto, helper
from onnx.external_data_helper import set_external_data

from onnxruntime.transformers.onnx_utils import extract_raw_data_from_model, has_external_data


class TestOnnxUtils(unittest.TestCase):
    def test_extract_raw_data_from_model(self):
        model = self._get_model_proto_with_raw_data(False)
        external_names, external_values = extract_raw_data_from_model(model)
        self.assertEqual(list(external_names), ["inputs"])
        self.assertEqual(len(external_values), 1)
        self.assertEqual(external_values[0].numpy(), [0.0])

    def test_has_external_data(self):
        model = self._get_model_proto_with_raw_data()
        self.assertTrue(has_external_data(model))

    def test_has_external_data_with_no_external_data(self):
        model = self._get_model_proto_with_raw_data(False)
        self.assertFalse(has_external_data(model))

    def _get_model_proto_with_raw_data(self, has_external_data: bool = True) -> ModelProto:
        input = helper.make_tensor_value_info("inputs", TensorProto.FLOAT, [None])
        output = helper.make_tensor_value_info("outputs", TensorProto.FLOAT, [None])
        raw_data = numpy.array([0.0], dtype=numpy.float32).tobytes()
        tensor = helper.make_tensor("inputs", TensorProto.FLOAT, [1], raw_data, True)
        if has_external_data:
            set_external_data(tensor, location="foo.bin")
        node = helper.make_node("Identity", inputs=["inputs"], outputs=["outputs"])
        return helper.make_model(helper.make_graph([node], "graph", [input], [output], initializer=[tensor]))
