diff --git a/onnxscript/function_libs/torch_lib/_flags.py b/onnxscript/function_libs/torch_lib/_flags.py new file mode 100644 index 000000000..209b34fa4 --- /dev/null +++ b/onnxscript/function_libs/torch_lib/_flags.py @@ -0,0 +1,41 @@ +"""Experimental flags. + +NOTE: These flags are experimental only. Any flag here can be removed at any +time without notice. +""" + +import logging +import os + +logger = logging.getLogger(__name__) + + +def _load_boolean_flag( + name: str, + *, + this_will: str, + deprecated: bool = False, +) -> bool: + """Load a boolean flag from environment variable. + + Args: + name: The name of the environment variable. + this_will: A string that describes what this flag will do. + deprecated: Whether this flag is deprecated. + """ + state = os.getenv(name) == "1" + if state: + if deprecated: + logger.error( + "Experimental flag %s is deprecated. Please remove it from your environment.", + name, + ) + else: + logger.warning("Experimental flag %s is enabled. This will %s.", name, this_will) + return state + + +EXPERIMENTAL_INITIALIZERS_AS_INPUTS: bool = _load_boolean_flag( + "TORCHLIB_EXPERIMENTAL_INITIALIZERS_AS_INPUTS", + this_will="make initializers as inputs to the model graph", +) diff --git a/onnxscript/function_libs/torch_lib/graph_building.py b/onnxscript/function_libs/torch_lib/graph_building.py index b873d310f..449f8a7ed 100644 --- a/onnxscript/function_libs/torch_lib/graph_building.py +++ b/onnxscript/function_libs/torch_lib/graph_building.py @@ -21,6 +21,7 @@ from onnxscript import evaluator from onnxscript import tensor as onnxscript_tensor from onnxscript._internal import param_manipulation, runtime_typing +from onnxscript.function_libs.torch_lib import _flags from onnxscript.function_libs.torch_lib.ops import common as common_ops __all__ = [ @@ -750,13 +751,15 @@ def to_model_proto( large_model = initializers_size > _LARGE_MODEL_SIZE_THRESHOLD export_kwargs: dict[str, Any] = dict( - initializers=self.initializers if include_initializers else {}, + initializers=self.initializers + if include_initializers and not _flags.EXPERIMENTAL_INITIALIZERS_AS_INPUTS + else {}, onnx_opset_version=opset_version, dynamic_axes={}, defer_weight_export=False, operator_export_type=torch.onnx.OperatorExportTypes.ONNX, strip_doc_string=False, - keep_initializers_as_inputs=False, + keep_initializers_as_inputs=_flags.EXPERIMENTAL_INITIALIZERS_AS_INPUTS, custom_opsets={}, add_node_names=True, node_attr_to_name={},