From 9fb0a7d73ad1715b1d40cef3f7bc3412892973fa Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Tue, 24 Oct 2023 12:08:48 -0700
Subject: [PATCH] Create the Rank and IsScalar shared functions | feat(torchlib) (#1105)

This change introduces two shared operators, `Rank` and `IsScalar`. They
replace the `Size(Shape())` pattern for code reuse and readability.

I used a hack to always include these shared functions in the model proto,
because without #834 we cannot dynamically add these functions to the model
as they are used. I added a TODO for this.

The first usage is in `aten_all`. I will update the rest of the functions in
a separate PR.

#1095
---
 .../torch_lib/deduce_type_constraints_test.py |  2 +-
 .../function_libs/torch_lib/_constants.py     |  3 +++
 .../function_libs/torch_lib/graph_building.py | 20 ++++++++++++++-
 .../function_libs/torch_lib/ops/common.py     | 25 +++++++++++++++++++
 .../function_libs/torch_lib/ops/core.py       |  6 +++--
 .../function_libs/torch_lib/registration.py   |  3 ++-
 6 files changed, 54 insertions(+), 5 deletions(-)
 create mode 100644 onnxscript/function_libs/torch_lib/_constants.py
 create mode 100644 onnxscript/function_libs/torch_lib/ops/common.py

diff --git a/onnxscript/function_libs/tools/torch_lib/deduce_type_constraints_test.py b/onnxscript/function_libs/tools/torch_lib/deduce_type_constraints_test.py
index 0d68981b8..a2882d283 100644
--- a/onnxscript/function_libs/tools/torch_lib/deduce_type_constraints_test.py
+++ b/onnxscript/function_libs/tools/torch_lib/deduce_type_constraints_test.py
@@ -30,7 +30,7 @@ class TestDeduceTypeConstraints(unittest.TestCase):
         "_aten_embedding_bag_onnx",
         "_aten_embedding_bag_1d_padding_idx_onnx",
     )
-    _SKIP_FUNCTIONS_WITH_NESTED_FUNCTION = ()
+    _SKIP_FUNCTIONS_WITH_NESTED_FUNCTION = ("aten_all",)
 
     @parameterized.parameterized.expand(
         ((op,) for op in torch_lib_onnx_functions_from_registry()),
diff --git a/onnxscript/function_libs/torch_lib/_constants.py b/onnxscript/function_libs/torch_lib/_constants.py
new file mode 100644
index 000000000..58cc2c068
--- /dev/null
+++ b/onnxscript/function_libs/torch_lib/_constants.py
@@ -0,0 +1,3 @@
+"""Shared constants for the library."""
+
+DOMAIN = "pkg.onnxscript.torch_lib"
diff --git a/onnxscript/function_libs/torch_lib/graph_building.py b/onnxscript/function_libs/torch_lib/graph_building.py
index 438c45603..b873d310f 100644
--- a/onnxscript/function_libs/torch_lib/graph_building.py
+++ b/onnxscript/function_libs/torch_lib/graph_building.py
@@ -21,6 +21,7 @@
 from onnxscript import evaluator
 from onnxscript import tensor as onnxscript_tensor
 from onnxscript._internal import param_manipulation, runtime_typing
+from onnxscript.function_libs.torch_lib.ops import common as common_ops
 
 __all__ = [
     "TorchScriptTensor",
@@ -363,6 +364,16 @@ def _tensor_rawdata_size(tensor: torch.Tensor) -> int:
     return tensor.numel() * tensor.element_size()
 
 
+def _shared_functions() -> list[onnx.FunctionProto]:
+    """Hack to always include the shared ops."""
+
+    # TODO: Remove after https://github.com/microsoft/onnxscript/issues/834 is fixed
+    return [
+        common_ops.Rank.to_function_proto(),
+        common_ops.IsScalar.to_function_proto(),
+    ]
+
+
 class TorchScriptGraph:
     def __init__(
         self,
@@ -717,7 +728,6 @@ def to_function_proto(self, opset_version: int, function_name: str) -> onnx.Func
             opset_imports=onnx_model.opset_import,
             doc_string=onnx_model.doc_string,
         )
-        # TODO: onnx.checker.check_function(onnx_function)?
         return onnx_function
 
     @runtime_typing.checked
@@ -786,6 +796,7 @@ def to_model_proto(
 
         onnx_model = onnx.load_from_string(proto)
         onnx_model.functions.extend(function_proto_dict.values())
+        onnx_model.functions.extend(_shared_functions())
 
         # `_export_onnx` only exports opset_imports that is visible to it. It does not
         # export opset_imports for nested functions, since it does not have access to
@@ -800,6 +811,13 @@
                 for domain, version in unique_custom_domains.items()
             ]
         )
+        # Include the library shared opset domain
+        # TODO: Remove after https://github.com/microsoft/onnxscript/issues/834 is fixed
+        onnx_model.opset_import.append(
+            onnx.helper.make_opsetid(
+                common_ops.common_opset.domain, common_ops.common_opset.version
+            )
+        )
 
         try:
             if not cache_model_to_disk:
diff --git a/onnxscript/function_libs/torch_lib/ops/common.py b/onnxscript/function_libs/torch_lib/ops/common.py
new file mode 100644
index 000000000..ba481b6f1
--- /dev/null
+++ b/onnxscript/function_libs/torch_lib/ops/common.py
@@ -0,0 +1,25 @@
+"""Common operators shared in the torchlib library."""
+
+import onnxscript
+import onnxscript.values
+from onnxscript import BOOL, INT64
+from onnxscript import opset18 as op
+from onnxscript.function_libs.torch_lib import _constants, tensor_typing
+
+DOMAIN = f"{_constants.DOMAIN}.common"
+
+common_opset = onnxscript.values.Opset(domain=DOMAIN, version=1)
+
+
+@onnxscript.script(common_opset)
+def Rank(input: tensor_typing.TTensor) -> INT64:
+    """Take the rank of the input tensor."""
+
+    return op.Size(op.Shape(input))
+
+
+@onnxscript.script(common_opset)
+def IsScalar(input: tensor_typing.TTensor) -> BOOL:
+    """Return whether the input has rank 0, i.e. is a scalar."""
+
+    return op.Equal(op.Size(op.Shape(input)), op.Constant(value_int=0))
diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py
index 47dfbeb10..6d770254e 100644
--- a/onnxscript/function_libs/torch_lib/ops/core.py
+++ b/onnxscript/function_libs/torch_lib/ops/core.py
@@ -30,6 +30,7 @@
     UINT64,
     graph,
 )
+from onnxscript.function_libs.torch_lib.ops import common as common_ops
 from onnxscript.function_libs.torch_lib.registration import torch_op
 from onnxscript.function_libs.torch_lib.tensor_typing import (
     IntType,
@@ -52,6 +53,8 @@
 _INT64_MAX = 9223372036854775807
 _INT64_MIN = -9223372036854775808
 _MATH_PI = math.pi
+IsScalar = common_ops.IsScalar
+Rank = common_ops.Rank
 
 
 @torch_op("aten::_local_scalar_dense")
@@ -320,8 +323,7 @@ def aten_align_to(self: TensorType, names: Sequence[str]) -> TensorType:
 def aten_all(self: TTensor) -> BOOL:
     """all(Tensor self) -> Tensor"""
 
-    self_rank = op.Size(op.Shape(self))
-    if self_rank == 0:
+    if IsScalar(self):
         result = op.Cast(self, to=BOOL.dtype)
     else:
         self_bool = op.Cast(self, to=BOOL.dtype)
diff --git a/onnxscript/function_libs/torch_lib/registration.py b/onnxscript/function_libs/torch_lib/registration.py
index a2267d70f..57c20964c 100644
--- a/onnxscript/function_libs/torch_lib/registration.py
+++ b/onnxscript/function_libs/torch_lib/registration.py
@@ -7,6 +7,7 @@
 from typing import Any, Callable, Generator, Optional
 
 import onnxscript
+from onnxscript.function_libs.torch_lib import _constants
 
 # Regex that will match "<namespace>::<op_name>[.<overload>]"
 _QUALIFIED_OPERATOR_NAME_REGEX = re.compile(
@@ -119,7 +120,7 @@ def wrapper(
         func: FunctionType,
     ) -> onnxscript.OnnxFunction | onnxscript.values.TracedOnnxFunction:
         # Compile the function
-        custom_opset = onnxscript.values.Opset(domain="pkg.onnxscript.torch_lib", version=1)
+        custom_opset = onnxscript.values.Opset(domain=_constants.DOMAIN, version=1)
         processed_func: onnxscript.OnnxFunction | onnxscript.values.TracedOnnxFunction
         if trace_only:
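
For a quick sanity check of the new helpers, here is a minimal sketch that is
not part of the patch: it calls the shared functions in onnxscript eager mode,
which assumes onnxruntime is installed for evaluation. The numpy inputs and
the "expected" outputs in the comments are illustrative assumptions only.

    import numpy as np

    from onnxscript.function_libs.torch_lib.ops import common as common_ops

    scalar = np.array(3.0, dtype=np.float32)    # rank-0 input
    matrix = np.ones((2, 3), dtype=np.float32)  # rank-2 input

    # Eager-mode calls to the shared functions defined in ops/common.py.
    print(common_ops.IsScalar(scalar))  # expected: True
    print(common_ops.Rank(matrix))      # expected: 2

This mirrors the aten_all change above, where `IsScalar(self)` stands in for
the previous `op.Size(op.Shape(self)) == 0` pattern.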