-
Notifications
You must be signed in to change notification settings - Fork 3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'main' of github.com:microsoft/onnxruntime into wangye/m…
…oe_q
- Loading branch information
Showing
5 changed files
with
271 additions
and
53 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
# ------------------------------------------------------------------------- | ||
# Copyright (c) Microsoft Corporation. All rights reserved. | ||
# Licensed under the MIT License. | ||
# -------------------------------------------------------------------------- | ||
from fusion_utils import NumpyHelper | ||
from onnx import ModelProto, TensorProto | ||
from onnx.external_data_helper import set_external_data | ||
from onnx_model import OnnxModel | ||
|
||
from onnxruntime import OrtValue | ||
|
||
|
||
def extract_raw_data_from_model(model: ModelProto): | ||
""" | ||
Extract external data from model and return the external data as a list of tuples (name, value). | ||
Note this function does not handle external data that is not loaded into the model as raw data. | ||
Args: | ||
model (ModelProto): the model proto to extract external data from. | ||
Returns: | ||
(external_names, external_values): a tuple of two lists of external data names and values. | ||
""" | ||
external_data = [] | ||
onnx_model = OnnxModel(model) | ||
for graph in onnx_model.graphs(): | ||
for initializer in graph.initializer: | ||
name = initializer.name | ||
|
||
if initializer.HasField("raw_data"): | ||
numpy_tensor = NumpyHelper.to_array(initializer) | ||
ort_value = OrtValue.ortvalue_from_numpy(numpy_tensor) | ||
external_data.append((name, ort_value)) | ||
# mimic set_external_data | ||
set_external_data(initializer, location="foo.bin") | ||
initializer.name = name | ||
initializer.ClearField("raw_data") | ||
|
||
return zip(*external_data) | ||
|
||
|
||
def has_external_data(model: ModelProto): | ||
""" | ||
Check if the model has external data. | ||
Args: | ||
model (ModelProto): the model proto to check for external data. | ||
Returns: | ||
bool: True if the model has external data, False otherwise. | ||
""" | ||
onnx_model = OnnxModel(model) | ||
for graph in onnx_model.graphs(): | ||
for initializer in graph.initializer: | ||
if initializer.HasField("data_location") and initializer.data_location == TensorProto.EXTERNAL: | ||
return True | ||
return False |
Oops, something went wrong.