diff --git a/onnxscript/function_libs/torch_lib/graph_building.py b/onnxscript/function_libs/torch_lib/graph_building.py
index 75710e9fd..ad3e72f37 100644
--- a/onnxscript/function_libs/torch_lib/graph_building.py
+++ b/onnxscript/function_libs/torch_lib/graph_building.py
@@ -1,11 +1,9 @@
 """Graph building functions for torchscript graph backend."""
 from __future__ import annotations
 
-import logging
 import os
 import tempfile
 import typing
-import warnings
 from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Union
 
 import numpy as np
@@ -1028,21 +1026,4 @@ def to_model_proto(
                     common_ops.common_opset.domain, common_ops.common_opset.version
                 )
             )
-
-        try:
-            if not cache_model_to_disk:
-                # Only check the model if it is in memory.
-                # Otherwise the checker and shape_inference will fail because
-                # we cannot serialize the model.
-                onnx_model = onnx.shape_inference.infer_shapes(
-                    onnx_model, check_type=True, strict_mode=False, data_prop=True
-                )
-                onnx.checker.check_model(onnx_model, full_check=True)
-        except (onnx.checker.ValidationError, onnx.shape_inference.InferenceError) as e:
-            warnings.warn(f"ONNX model is invalid: {e}", stacklevel=1)
-            logging.debug(
-                "ONNX model:\n%s\n\nTorchScript graph:\n%s",
-                onnx.printer.to_text(onnx_model),
-                self.torch_graph,
-            )
         return onnx_model
diff --git a/onnxscript/tests/function_libs/torch_lib/ops_test_common.py b/onnxscript/tests/function_libs/torch_lib/ops_test_common.py
index 831b8190f..bc0b08e98 100644
--- a/onnxscript/tests/function_libs/torch_lib/ops_test_common.py
+++ b/onnxscript/tests/function_libs/torch_lib/ops_test_common.py
@@ -506,9 +506,8 @@ def _capture_graph_and_evaluate_torch_script_evaluator(function: Callable, args,
     # We need to set the size of the output tensors for the ONNX model to be valid
     for output, symbolic_output in zip(outputs, symbolic_outputs):
         if isinstance(output, Sequence):
-            # Output is a sequence, set the type correctly to ListType
-            symbolic_output.dtype = output[0].dtype
-            symbolic_output.symbolic_value().setType(torch.ListType.ofTensors())
+            # Output is a sequence, skip setting the type and leave it
+            # for ONNX shape_inference to handle
             continue
         output = (
             output
@@ -521,6 +520,7 @@ def _capture_graph_and_evaluate_torch_script_evaluator(function: Callable, args,
     onnxscript_graph.register_outputs(symbolic_outputs)
 
     onnx_model = onnxscript_graph.to_model_proto(TEST_OPSET_VERSION)
+    onnx_model = onnx.shape_inference.infer_shapes(onnx_model, data_prop=True)
     # Make sure the model is valid
     try:
         onnx.checker.check_model(onnx_model, full_check=True)