diff --git a/onnxruntime/python/tools/transformers/onnx_utils.py b/onnxruntime/python/tools/transformers/onnx_utils.py new file mode 100644 index 0000000000000..64fade9369395 --- /dev/null +++ b/onnxruntime/python/tools/transformers/onnx_utils.py @@ -0,0 +1,55 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +from fusion_utils import NumpyHelper +from onnx import ModelProto, TensorProto +from onnx.external_data_helper import set_external_data +from onnx_model import OnnxModel + +from onnxruntime import OrtValue + + +def extract_raw_data_from_model(model: ModelProto): + """ + Extract external data from model and return the external data as a list of tuples (name, value). + Note this function does not handle external data that is not loaded into the model as raw data. + + Args: + model (ModelProto): the model proto to extract external data from. + Returns: + (external_names, external_values): a tuple of two lists of external data names and values. + """ + external_data = [] + onnx_model = OnnxModel(model) + for graph in onnx_model.graphs(): + for initializer in graph.initializer: + name = initializer.name + + if initializer.HasField("raw_data"): + numpy_tensor = NumpyHelper.to_array(initializer) + ort_value = OrtValue.ortvalue_from_numpy(numpy_tensor) + external_data.append((name, ort_value)) + # mimic set_external_data + set_external_data(initializer, location="foo.bin") + initializer.name = name + initializer.ClearField("raw_data") + + return zip(*external_data) + + +def has_external_data(model: ModelProto): + """ + Check if the model has external data. + + Args: + model (ModelProto): the model proto to check for external data. + Returns: + bool: True if the model has external data, False otherwise. + """ + onnx_model = OnnxModel(model) + for graph in onnx_model.graphs(): + for initializer in graph.initializer: + if initializer.HasField("data_location") and initializer.data_location == TensorProto.EXTERNAL: + return True + return False diff --git a/onnxruntime/python/tools/transformers/optimizer.py b/onnxruntime/python/tools/transformers/optimizer.py index ce0be6b3449ed..068ccefef7d97 100644 --- a/onnxruntime/python/tools/transformers/optimizer.py +++ b/onnxruntime/python/tools/transformers/optimizer.py @@ -21,11 +21,12 @@ import logging import os import tempfile -from typing import Dict, List, Optional +from pathlib import Path +from typing import Dict, List, Optional, Union import coloredlogs from fusion_options import FusionOptions -from onnx import ModelProto, TensorProto, load_model +from onnx import ModelProto, load_model from onnx_model import OnnxModel from onnx_model_bart import BartOnnxModel from onnx_model_bert import BertOnnxModel @@ -40,6 +41,9 @@ from onnx_model_unet import UnetOnnxModel from onnx_model_vae import VaeOnnxModel +import onnxruntime +from onnxruntime.transformers.onnx_utils import extract_raw_data_from_model, has_external_data + logger = logging.getLogger(__name__) # Map model type to tuple: optimizer class, export tools (pytorch, tf2onnx, keras2onnx), and default opt_level @@ -64,7 +68,7 @@ def optimize_by_onnxruntime( - onnx_model_path: str, + onnx_model: Union[str, ModelProto], use_gpu: bool = False, optimized_model_path: Optional[str] = None, opt_level: Optional[int] = 99, @@ -80,7 +84,7 @@ def optimize_by_onnxruntime( Use onnxruntime to optimize model. Args: - onnx_model_path (str): the path of input onnx model. + onnx_model (str | ModelProto): the path of input onnx model or ModelProto. use_gpu (bool): whether the optimized model is targeted to run in GPU. optimized_model_path (str or None): the path of optimized model. opt_level (int): graph optimization level. @@ -95,8 +99,6 @@ def optimize_by_onnxruntime( assert opt_level in [1, 2, 99] from torch import version as torch_version - import onnxruntime - if ( use_gpu and provider is None @@ -105,9 +107,13 @@ def optimize_by_onnxruntime( ) ): logger.error("There is no gpu for onnxruntime to do optimization.") - return onnx_model_path + return onnx_model - model = OnnxModel(load_model(onnx_model_path, load_external_data=False)) + model = ( + OnnxModel(load_model(onnx_model, load_external_data=False)) + if isinstance(onnx_model, str) + else OnnxModel(onnx_model) + ) if model.use_float16() and not use_gpu: logger.warning( "This model uses float16 in the graph, use_gpu=False might cause extra Cast nodes. " @@ -125,7 +131,10 @@ def optimize_by_onnxruntime( sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL if optimized_model_path is None: - path_prefix = onnx_model_path[:-5] # remove .onnx suffix + if isinstance(onnx_model, str): + path_prefix = str(Path(onnx_model).with_suffix("")) # remove .onnx suffix + else: + path_prefix = "optimized_model" optimized_model_path = "{}_o{}_{}.onnx".format(path_prefix, opt_level, "gpu" if use_gpu else "cpu") sess_options.optimized_model_filepath = optimized_model_path @@ -174,7 +183,20 @@ def optimize_by_onnxruntime( else: providers.append("CUDAExecutionProvider") - onnxruntime.InferenceSession(onnx_model_path, sess_options, providers=providers, **kwargs) + # For large model, extract external data from model and add to session options + if isinstance(onnx_model, ModelProto): + if has_external_data(onnx_model): + raise ValueError( + "ModelProto has external data not loaded into memory, ORT cannot create session. " + "Please load external data before calling this function. " + "See https://onnx.ai/onnx/repo-docs/ExternalData.html for more information." + ) + external_names, external_values = extract_raw_data_from_model(onnx_model) + sess_options.add_external_initializers(list(external_names), list(external_values)) + + # Inference session is only used to optimize the model. + onnx_model = onnx_model.SerializeToString() if isinstance(onnx_model, ModelProto) else onnx_model + onnxruntime.InferenceSession(onnx_model, sess_options, providers=providers, **kwargs) assert os.path.exists(optimized_model_path) and os.path.isfile(optimized_model_path) logger.debug("Save optimized model by onnxruntime to %s", optimized_model_path) @@ -187,7 +209,7 @@ def optimize_by_fusion( num_heads: int = 0, hidden_size: int = 0, optimization_options: Optional[FusionOptions] = None, -): +) -> OnnxModel: """Optimize Model by graph fusion logic. Note that ONNXRuntime graph optimizations (like constant folding) will not be applied. So it is better to enable @@ -241,7 +263,7 @@ def optimize_by_fusion( def optimize_model( - input: str, + input: Union[str, ModelProto], model_type: str = "bert", num_heads: int = 0, hidden_size: int = 0, @@ -252,7 +274,7 @@ def optimize_model( verbose: bool = False, *, provider: Optional[str] = None, -): +) -> OnnxModel: """Optimize Model by OnnxRuntime and/or python fusion logic. ONNX Runtime has graph optimizations (https://onnxruntime.ai/docs/performance/model-optimizations/graph-optimizations.html). @@ -275,7 +297,7 @@ def optimize_model( For BERT model, num_heads and hidden_size are optional. For other model types, you need specify these parameters. Args: - input (str): input model path. + input (str | ModelProto): input model path or ModelProto. model_type (str, optional): model type - like bert, bert_tf, bert_keras or gpt2. Defaults to 'bert'. num_heads (int, optional): number of attention heads. Defaults to 0. 0 allows detect the parameter from graph automatically. @@ -298,9 +320,9 @@ def optimize_model( if model_type not in MODEL_TYPES: logger.warning(f"Unsupported model type: {model_type} for optimization, directly return model.") - return OnnxModel(load_model(input)) + return OnnxModel(load_model(input)) if isinstance(input, str) else OnnxModel(input) - (optimizer_class, _producer, default_opt_level) = MODEL_TYPES[model_type] + (optimizer_class, _, default_opt_level) = MODEL_TYPES[model_type] if opt_level is None: opt_level = default_opt_level @@ -316,11 +338,9 @@ def optimize_model( # Auto detect if input model has external data has_external_data_file = False - original_model = load_model(input, load_external_data=False) - for initializer in original_model.graph.initializer: - if initializer.HasField("data_location") and initializer.data_location == TensorProto.EXTERNAL: - has_external_data_file = True - break + original_model = load_model(input, load_external_data=False) if isinstance(input, str) else input + if has_external_data(original_model): + has_external_data_file = True del original_model if opt_level > 1: @@ -365,7 +385,12 @@ def optimize_model( if only_onnxruntime and not temp_model_path: logger.warning("Please specify a positive value for opt_level when only_onnxruntime is True") - model = load_model(temp_model_path or input) + if temp_model_path is not None: + model = load_model(temp_model_path) + elif isinstance(input, str): + model = load_model(input) + else: + model = input if only_onnxruntime: optimizer = optimizer_class(model, num_heads, hidden_size) diff --git a/onnxruntime/test/python/transformers/test_onnx_utils.py b/onnxruntime/test/python/transformers/test_onnx_utils.py new file mode 100644 index 0000000000000..974991359795e --- /dev/null +++ b/onnxruntime/test/python/transformers/test_onnx_utils.py @@ -0,0 +1,38 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +import unittest + +import numpy +from onnx import ModelProto, TensorProto, helper +from onnx.external_data_helper import set_external_data + +from onnxruntime.transformers.onnx_utils import extract_raw_data_from_model, has_external_data + + +class TestOnnxUtils(unittest.TestCase): + def test_extract_raw_data_from_model(self): + model = self._get_model_proto_with_raw_data(False) + external_names, external_values = extract_raw_data_from_model(model) + self.assertEqual(list(external_names), ["inputs"]) + self.assertEqual(len(external_values), 1) + self.assertEqual(external_values[0].numpy(), [0.0]) + + def test_has_external_data(self): + model = self._get_model_proto_with_raw_data() + self.assertTrue(has_external_data(model)) + + def test_has_external_data_with_no_external_data(self): + model = self._get_model_proto_with_raw_data(False) + self.assertFalse(has_external_data(model)) + + def _get_model_proto_with_raw_data(self, has_external_data: bool = True) -> ModelProto: + input = helper.make_tensor_value_info("inputs", TensorProto.FLOAT, [None]) + output = helper.make_tensor_value_info("outputs", TensorProto.FLOAT, [None]) + raw_data = numpy.array([0.0], dtype=numpy.float32).tobytes() + tensor = helper.make_tensor("inputs", TensorProto.FLOAT, [1], raw_data, True) + if has_external_data: + set_external_data(tensor, location="foo.bin") + node = helper.make_node("Identity", inputs=["inputs"], outputs=["outputs"]) + return helper.make_model(helper.make_graph([node], "graph", [input], [output], initializer=[tensor]))