Skip to content

Commit

Permalink
Experimental flag for controlling initializers as inputs | feat(torchlib) (#1112)
Browse files Browse the repository at this point in the history

Create a flag `TORCHLIB_EXPERIMENTAL_INITIALIZERS_AS_INPUTS=1` to mark
initializers as inputs to the model.

**This flag is experimental only for ONNX Runtime training and should
not be assumed to exist in production.**

Reference:
microsoft/onnx-converters-private#182

Tested with the open-llama model with `transformers==4.31.0` with script
https://gist.github.com/abock/2115d34d98df15a77516e8a2899b121c


![image](https://github.com/microsoft/onnxscript/assets/11205048/fe2d565f-bd2f-449d-82ba-9f865b335177)
  • Loading branch information
justinchuby authored Oct 25, 2023
1 parent 67f790b commit 754accc
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 2 deletions.
41 changes: 41 additions & 0 deletions onnxscript/function_libs/torch_lib/_flags.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
"""Experimental flags.

NOTE: These flags are experimental only. Any flag here can be removed at any
time without notice.
"""

import logging
import os

# Module-level logger used to announce which experimental flags are active.
logger = logging.getLogger(__name__)


def _load_boolean_flag(
name: str,
*,
this_will: str,
deprecated: bool = False,
) -> bool:
"""Load a boolean flag from environment variable.
Args:
name: The name of the environment variable.
this_will: A string that describes what this flag will do.
deprecated: Whether this flag is deprecated.
"""
state = os.getenv(name) == "1"
if state:
if deprecated:
logger.error(
"Experimental flag %s is deprecated. Please remove it from your environment.",
name,
)
else:
logger.warning("Experimental flag %s is enabled. This will %s.", name, this_will)
return state


# When TORCHLIB_EXPERIMENTAL_INITIALIZERS_AS_INPUTS=1 is set in the
# environment, exported model graphs keep initializers as graph inputs
# instead of embedding them (used by graph_building.to_model_proto).
# Experimental: may be removed at any time without notice.
EXPERIMENTAL_INITIALIZERS_AS_INPUTS: bool = _load_boolean_flag(
    "TORCHLIB_EXPERIMENTAL_INITIALIZERS_AS_INPUTS",
    this_will="make initializers as inputs to the model graph",
)
7 changes: 5 additions & 2 deletions onnxscript/function_libs/torch_lib/graph_building.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from onnxscript import evaluator
from onnxscript import tensor as onnxscript_tensor
from onnxscript._internal import param_manipulation, runtime_typing
from onnxscript.function_libs.torch_lib import _flags
from onnxscript.function_libs.torch_lib.ops import common as common_ops

__all__ = [
Expand Down Expand Up @@ -750,13 +751,15 @@ def to_model_proto(
large_model = initializers_size > _LARGE_MODEL_SIZE_THRESHOLD

export_kwargs: dict[str, Any] = dict(
initializers=self.initializers if include_initializers else {},
initializers=self.initializers
if include_initializers and not _flags.EXPERIMENTAL_INITIALIZERS_AS_INPUTS
else {},
onnx_opset_version=opset_version,
dynamic_axes={},
defer_weight_export=False,
operator_export_type=torch.onnx.OperatorExportTypes.ONNX,
strip_doc_string=False,
keep_initializers_as_inputs=False,
keep_initializers_as_inputs=_flags.EXPERIMENTAL_INITIALIZERS_AS_INPUTS,
custom_opsets={},
add_node_names=True,
node_attr_to_name={},
Expand Down

0 comments on commit 754accc

Please sign in to comment.