Skip to content

Commit

Permalink
Update to allow large models to be checked for mobile support. (#18357)
Browse files Browse the repository at this point in the history
### Description
<!-- Describe your changes. -->
Update usability checker and related infrastructure to support checking
models > 2GB.
- Add ability to set flag to keep initializers as external data
- we optimize the model as part of the checking so need to write out a
new copy.
- Handle issue with ONNX shape inferencing silently failing
- use API that supports large models but requires writing the model to a
new file
  - automate cleanup of that copy of the model

### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
Allow analysis of LLMs to determine gaps for mobile usage.

---------

Co-authored-by: Edward Chen <[email protected]>
  • Loading branch information
skottmckay and edgchen1 authored Nov 16, 2023
1 parent b6b9aff commit e7a524f
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,8 @@
import sys

import onnx
from onnx import shape_inference

from ..onnx_model_utils import get_opsets_imported
from ..onnx_model_utils import ModelProtoWithShapeInfo, get_opsets_imported
from ..reduced_build_config_parser import parse_config

cpp_to_tensorproto_type = {
Expand Down Expand Up @@ -265,15 +264,13 @@ def run_check(model_path: pathlib.Path, mobile_pkg_build_config: pathlib.Path, l
)

model_file = model_path.resolve(strict=True)
model = onnx.load(str(model_file))

# we need to run shape inferencing to populate that type info for node outputs.
# we will get warnings if the model uses ORT contrib ops (ONNX does not have shape inferencing for those),
# and shape inferencing will be lost downstream of those.
# TODO: add support for checking ORT format model as it will have full type/shape info for all nodes
model_with_type_info = shape_inference.infer_shapes(model)

return run_check_with_model(model_with_type_info, mobile_pkg_build_config, logger)
model_wrapper = ModelProtoWithShapeInfo(model_file)
return run_check_with_model(model_wrapper.model_with_shape_info, mobile_pkg_build_config, logger)


def main():
Expand Down
11 changes: 6 additions & 5 deletions tools/python/util/mobile_helpers/usability_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import onnx

from ..onnx_model_utils import (
ModelProtoWithShapeInfo,
get_producer_consumer_maps,
is_fixed_size_tensor,
iterate_graph_per_graph_func,
Expand Down Expand Up @@ -464,9 +465,9 @@ def check_shapes(graph: onnx.GraphProto, logger: Optional[logging.Logger] = None
return dynamic_inputs, num_dynamic_values


def checker(model_path, logger: logging.Logger):
model = onnx.load(model_path)
model_with_shape_info = onnx.shape_inference.infer_shapes(model)
def checker(model_path: pathlib.Path, logger: logging.Logger):
model_with_shape_info_wrapper = ModelProtoWithShapeInfo(model_path)
model_with_shape_info = model_with_shape_info_wrapper.model_with_shape_info

# create lookup map for efficiency
value_to_shape = {}
Expand Down Expand Up @@ -541,10 +542,10 @@ def analyze_model(model_path: pathlib.Path, skip_optimize: bool = False, logger:
with tempfile.TemporaryDirectory() as tmp:
if not skip_optimize:
tmp_path = pathlib.Path(tmp) / model_path.name
optimize_model(model_path, tmp_path)
optimize_model(model_path, tmp_path, use_external_initializers=True)
model_path = tmp_path

try_eps = checker(str(model_path.resolve(strict=True)), logger)
try_eps = checker(model_path.resolve(strict=True), logger)

return try_eps

Expand Down
45 changes: 45 additions & 0 deletions tools/python/util/onnx_model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ def optimize_model(
output_path: pathlib.Path,
level: ort.GraphOptimizationLevel = ort.GraphOptimizationLevel.ORT_ENABLE_BASIC,
log_level: int = 3,
use_external_initializers: bool = False,
):
"""
Optimize an ONNX model using ONNX Runtime to the specified level
Expand All @@ -103,12 +104,25 @@ def optimize_model(
:param level: onnxruntime.GraphOptimizationLevel to use. Default is ORT_ENABLE_BASIC.
:param log_level: Log level. Defaults to Error (3) so we don't get output about unused initializers being removed.
Warning (2) or Info (1) may be desirable in some scenarios.
:param use_external_initializers: Set flag to write initializers to an external file. Required if model > 2GB.
Requires onnxruntime 1.17+
"""
so = ort.SessionOptions()
so.optimized_model_filepath = str(output_path.resolve())
so.graph_optimization_level = level
so.log_severity_level = log_level

# save using external initializers so models > 2 GB are handled
if use_external_initializers:
major, minor, rest = ort.__version__.split(".", 3)
if (int(major), int(minor)) >= (1, 17):
so.add_session_config_entry("session.optimized_model_external_initializers_file_name", "external_data.pb")
else:
raise ValueError(
"ONNX Runtime 1.17 or higher required to save initializers as external data when optimizing model. "
f"Current ONNX Runtime version is {ort.__version__}"
)

# create session to optimize. this will write the updated model to output_path
_ = ort.InferenceSession(str(model_path.resolve(strict=True)), so, providers=["CPUExecutionProvider"])

Expand Down Expand Up @@ -366,3 +380,34 @@ def get_optimization_level(level):
return ort.GraphOptimizationLevel.ORT_ENABLE_ALL

raise ValueError("Invalid optimization level of " + level)


class ModelProtoWithShapeInfo:
"""
Class to load an ONNX model and run shape inferencing on it to populate the ValueInfo.
The model_with_shape_info property will contain the updated model.
If the model is > 2GB and uses external data a temporary file is required to run shape inferencing successfully.
This helper class handles automatic removal of the temporary file.
"""

def __init__(self, model_path: pathlib.Path):
"""
:param model_path: Path to ONNX model to load and run shape inferencing on.
"""

self.model_path = model_path

model = onnx.load(str(model_path))
self.model_with_shape_info = onnx.shape_inference.infer_shapes(model, strict_mode=True)

# ONNX has a silent failure from the call to infer_shapes when the model is > 2GB.
# We detect that by checking the nodes in the returned model.
self._tmp_model_path = None
if len(model.graph.node) > 0 and len(self.model_with_shape_info.graph.node) == 0:
self._tmp_model_path = pathlib.Path(model_path).with_suffix(".temp_with_shapeinf.onnx")
onnx.shape_inference.infer_shapes_path(str(model_path), str(self._tmp_model_path), strict_mode=True)
self.model_with_shape_info = onnx.load(str(self._tmp_model_path))

def __del__(self):
if self._tmp_model_path:
self._tmp_model_path.unlink(missing_ok=True)

0 comments on commit e7a524f

Please sign in to comment.