-
Notifications
You must be signed in to change notification settings - Fork 3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add ModelProto support for transformers optimize_model (#19990)
### Description Add `ModelProto` support as an input to transformers `optimize_model` API. ### Motivation and Context Currently, the `optimize_model` API only accepts a model path as the input model. However, for large models, saving and loading from disk can be time-consuming. By adding `ModelProto` as an input option to the `optimize_model` API, significant time can be saved.
- Loading branch information
1 parent
3076b56
commit 71551da
Showing
3 changed files
with
140 additions
and
22 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
# ------------------------------------------------------------------------- | ||
# Copyright (c) Microsoft Corporation. All rights reserved. | ||
# Licensed under the MIT License. | ||
# -------------------------------------------------------------------------- | ||
from fusion_utils import NumpyHelper | ||
from onnx import ModelProto, TensorProto | ||
from onnx.external_data_helper import set_external_data | ||
from onnx_model import OnnxModel | ||
|
||
from onnxruntime import OrtValue | ||
|
||
|
||
def extract_raw_data_from_model(model: ModelProto): | ||
""" | ||
Extract external data from model and return the external data as a list of tuples (name, value). | ||
Note this function does not handle external data that is not loaded into the model as raw data. | ||
Args: | ||
model (ModelProto): the model proto to extract external data from. | ||
Returns: | ||
(external_names, external_values): a tuple of two lists of external data names and values. | ||
""" | ||
external_data = [] | ||
onnx_model = OnnxModel(model) | ||
for graph in onnx_model.graphs(): | ||
for initializer in graph.initializer: | ||
name = initializer.name | ||
|
||
if initializer.HasField("raw_data"): | ||
numpy_tensor = NumpyHelper.to_array(initializer) | ||
ort_value = OrtValue.ortvalue_from_numpy(numpy_tensor) | ||
external_data.append((name, ort_value)) | ||
# mimic set_external_data | ||
set_external_data(initializer, location="foo.bin") | ||
initializer.name = name | ||
initializer.ClearField("raw_data") | ||
|
||
return zip(*external_data) | ||
|
||
|
||
def has_external_data(model: ModelProto): | ||
""" | ||
Check if the model has external data. | ||
Args: | ||
model (ModelProto): the model proto to check for external data. | ||
Returns: | ||
bool: True if the model has external data, False otherwise. | ||
""" | ||
onnx_model = OnnxModel(model) | ||
for graph in onnx_model.graphs(): | ||
for initializer in graph.initializer: | ||
if initializer.HasField("data_location") and initializer.data_location == TensorProto.EXTERNAL: | ||
return True | ||
return False |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
# ------------------------------------------------------------------------- | ||
# Copyright (c) Microsoft Corporation. All rights reserved. | ||
# Licensed under the MIT License. | ||
# -------------------------------------------------------------------------- | ||
import unittest | ||
|
||
import numpy | ||
from onnx import ModelProto, TensorProto, helper | ||
from onnx.external_data_helper import set_external_data | ||
|
||
from onnxruntime.transformers.onnx_utils import extract_raw_data_from_model, has_external_data | ||
|
||
|
||
class TestOnnxUtils(unittest.TestCase): | ||
def test_extract_raw_data_from_model(self): | ||
model = self._get_model_proto_with_raw_data(False) | ||
external_names, external_values = extract_raw_data_from_model(model) | ||
self.assertEqual(list(external_names), ["inputs"]) | ||
self.assertEqual(len(external_values), 1) | ||
self.assertEqual(external_values[0].numpy(), [0.0]) | ||
|
||
def test_has_external_data(self): | ||
model = self._get_model_proto_with_raw_data() | ||
self.assertTrue(has_external_data(model)) | ||
|
||
def test_has_external_data_with_no_external_data(self): | ||
model = self._get_model_proto_with_raw_data(False) | ||
self.assertFalse(has_external_data(model)) | ||
|
||
def _get_model_proto_with_raw_data(self, has_external_data: bool = True) -> ModelProto: | ||
input = helper.make_tensor_value_info("inputs", TensorProto.FLOAT, [None]) | ||
output = helper.make_tensor_value_info("outputs", TensorProto.FLOAT, [None]) | ||
raw_data = numpy.array([0.0], dtype=numpy.float32).tobytes() | ||
tensor = helper.make_tensor("inputs", TensorProto.FLOAT, [1], raw_data, True) | ||
if has_external_data: | ||
set_external_data(tensor, location="foo.bin") | ||
node = helper.make_node("Identity", inputs=["inputs"], outputs=["outputs"]) | ||
return helper.make_model(helper.make_graph([node], "graph", [input], [output], initializer=[tensor])) |