diff --git a/onnxscript/optimizer/__init__.py b/onnxscript/optimizer/__init__.py index 98c2038d9..8a4326979 100644 --- a/onnxscript/optimizer/__init__.py +++ b/onnxscript/optimizer/__init__.py @@ -56,9 +56,15 @@ def optimize( ) for _ in range(num_iterations): if onnx_shape_inference: - model = onnx.shape_inference.infer_shapes( - model, check_type=True, strict_mode=True, data_prop=True - ) + if model.ByteSize() < 1024 * 1024 * 1024 * 2: + model = onnx.shape_inference.infer_shapes( + model, check_type=True, strict_mode=True, data_prop=True + ) + else: + logger.warning( + "The model size is too large for full model shape inference. " + "Skipping this step." + ) inline_simple_functions(model) modified = fold_constants(