Skip to content

Commit

Permalink
Remove check_model and shape_infer in graph building | feat(torchlib) (#1226)
Browse files Browse the repository at this point in the history

Remove check_model and shape_infer in graph building because
serialization adds overhead and impacts runtime when calling
to_model_proto.
  • Loading branch information
justinchuby authored Dec 15, 2023
1 parent c57b520 commit 0c0dc1b
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 22 deletions.
19 changes: 0 additions & 19 deletions onnxscript/function_libs/torch_lib/graph_building.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
"""Graph building functions for torchscript graph backend."""
from __future__ import annotations

import logging
import os
import tempfile
import typing
import warnings
from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Union

import numpy as np
Expand Down Expand Up @@ -1028,21 +1026,4 @@ def to_model_proto(
common_ops.common_opset.domain, common_ops.common_opset.version
)
)

try:
if not cache_model_to_disk:
# Only check the model if it is in memory.
# Otherwise the checker and shape_inference will fail because
# we cannot serialize the model.
onnx_model = onnx.shape_inference.infer_shapes(
onnx_model, check_type=True, strict_mode=False, data_prop=True
)
onnx.checker.check_model(onnx_model, full_check=True)
except (onnx.checker.ValidationError, onnx.shape_inference.InferenceError) as e:
warnings.warn(f"ONNX model is invalid: {e}", stacklevel=1)
logging.debug(
"ONNX model:\n%s\n\nTorchScript graph:\n%s",
onnx.printer.to_text(onnx_model),
self.torch_graph,
)
return onnx_model
6 changes: 3 additions & 3 deletions onnxscript/tests/function_libs/torch_lib/ops_test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -506,9 +506,8 @@ def _capture_graph_and_evaluate_torch_script_evaluator(function: Callable, args,
# We need to set the size of the output tensors for the ONNX model to be valid
for output, symbolic_output in zip(outputs, symbolic_outputs):
if isinstance(output, Sequence):
# Output is a sequence, set the type correctly to ListType
symbolic_output.dtype = output[0].dtype
symbolic_output.symbolic_value().setType(torch.ListType.ofTensors())
# Output is a sequence, skip setting the type and leave it
# for ONNX shape_inference to handle
continue
output = (
output
Expand All @@ -521,6 +520,7 @@ def _capture_graph_and_evaluate_torch_script_evaluator(function: Callable, args,
onnxscript_graph.register_outputs(symbolic_outputs)

onnx_model = onnxscript_graph.to_model_proto(TEST_OPSET_VERSION)
onnx_model = onnx.shape_inference.infer_shapes(onnx_model, data_prop=True)
# Make sure the model is valid
try:
onnx.checker.check_model(onnx_model, full_check=True)
Expand Down

0 comments on commit 0c0dc1b

Please sign in to comment.