Skip to content

Commit

Permalink
Update method name & add subgraph check
Browse files Browse the repository at this point in the history
  • Loading branch information
xiaoyu-work committed Mar 21, 2024
1 parent 41bf9cd commit e00fd13
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 10 deletions.
4 changes: 2 additions & 2 deletions onnxruntime/python/tools/transformers/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
from onnx_model_vae import VaeOnnxModel

import onnxruntime

Check notice

Code scanning / CodeQL

Module is imported with 'import' and 'import from' Note

Module 'onnxruntime' is imported with both 'import' and 'import from'.
from onnxruntime.transformers.optimizer_utils import extract_external_data_from_model, has_external_data
from onnxruntime.transformers.optimizer_utils import extract_raw_data_from_model, has_external_data

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -191,7 +191,7 @@ def optimize_by_onnxruntime(
"Please load external data before calling this function. "
"See https://onnx.ai/onnx/repo-docs/ExternalData.html for more information."
)
external_names, external_values = extract_external_data_from_model(onnx_model)
external_names, external_values = extract_raw_data_from_model(onnx_model)
sess_options.add_external_initializers(list(external_names), list(external_values))

# Inference session is only used to optimize the model.
Expand Down
13 changes: 8 additions & 5 deletions onnxruntime/python/tools/transformers/optimizer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,10 @@
from onnxruntime import OrtValue


def extract_external_data_from_model(model: ModelProto):
def extract_raw_data_from_model(model: ModelProto):
"""
Extract external data from model and return the external data as a list of tuples (name, value).
Note this function does not handle external data that is not loaded into the model as raw data.
Args:
model (ModelProto): the model proto to extract external data from.
Expand Down Expand Up @@ -46,7 +47,9 @@ def has_external_data(model: ModelProto):
Returns:
bool: True if the model has external data, False otherwise.
"""
return any(
initializer.HasField("data_location") and initializer.data_location == TensorProto.EXTERNAL
for initializer in model.graph.initializer
)
onnx_model = OnnxModel(model)
for graph in onnx_model.graphs():
for initializer in graph.initializer:
if initializer.HasField("data_location") and initializer.data_location == TensorProto.EXTERNAL:
return True
return False
6 changes: 3 additions & 3 deletions onnxruntime/test/python/transformers/test_optimizer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@
from onnx import ModelProto, TensorProto, helper
from onnx.external_data_helper import set_external_data

from onnxruntime.transformers.optimizer_utils import extract_external_data_from_model, has_external_data
from onnxruntime.transformers.optimizer_utils import extract_raw_data_from_model, has_external_data


class TestOptimizerUtils(unittest.TestCase):
def test_extract_external_data_from_model(self):
def test_extract_raw_data_from_model(self):
model = self._get_model_proto_with_raw_data()
external_names, external_values = extract_external_data_from_model(model)
external_names, external_values = extract_raw_data_from_model(model)
self.assertEqual(list(external_names), ["inputs"])
self.assertEqual(len(external_values), 1)
self.assertEqual(external_values[0].numpy(), [0.0])
Expand Down

0 comments on commit e00fd13

Please sign in to comment.