Create the Rank and IsScalar shared functions | feat(torchlib) (#1105)
This change introduces two shared operators, `Rank` and `IsScalar`. They replace the `Size(Shape())` pattern for code reuse and readability.
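For illustration, here is the pattern being replaced, as it appears in the `aten_all` change below (a minimal sketch; `self` stands for any tensor input):

```python
# Before: each function computed the rank inline.
self_rank = op.Size(op.Shape(self))
if self_rank == 0:
    ...

# After: the shared operator states the intent directly.
if IsScalar(self):
    ...
```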

I used a hack to always include these shared functions in the model proto: without #834 we cannot dynamically add these functions to the model only when they are actually used. I added a TODO for this.

The first usage is in `aten_all`. I will update the rest of the
functions in a separate PR.

#1095
justinchuby authored Oct 24, 2023
1 parent 0035390 commit 9fb0a7d
Showing 6 changed files with 54 additions and 5 deletions.
@@ -30,7 +30,7 @@ class TestDeduceTypeConstraints(unittest.TestCase):
"_aten_embedding_bag_onnx",
"_aten_embedding_bag_1d_padding_idx_onnx",
)
-    _SKIP_FUNCTIONS_WITH_NESTED_FUNCTION = ()
+    _SKIP_FUNCTIONS_WITH_NESTED_FUNCTION = ("aten_all",)

@parameterized.parameterized.expand(
((op,) for op in torch_lib_onnx_functions_from_registry()),
3 changes: 3 additions & 0 deletions onnxscript/function_libs/torch_lib/_constants.py
@@ -0,0 +1,3 @@
"""Shared constants for the library."""

DOMAIN = "pkg.onnxscript.torch_lib"
20 changes: 19 additions & 1 deletion onnxscript/function_libs/torch_lib/graph_building.py
@@ -21,6 +21,7 @@
from onnxscript import evaluator
from onnxscript import tensor as onnxscript_tensor
from onnxscript._internal import param_manipulation, runtime_typing
+from onnxscript.function_libs.torch_lib.ops import common as common_ops

__all__ = [
"TorchScriptTensor",
@@ -363,6 +364,16 @@ def _tensor_rawdata_size(tensor: torch.Tensor) -> int:
return tensor.numel() * tensor.element_size()


+def _shared_functions() -> list[onnx.FunctionProto]:
+    """Hack to always include the shared ops."""
+
+    # TODO: Remove after https://github.com/microsoft/onnxscript/issues/834 is fixed
+    return [
+        common_ops.Rank.to_function_proto(),
+        common_ops.IsScalar.to_function_proto(),
+    ]


class TorchScriptGraph:
def __init__(
self,
@@ -717,7 +728,6 @@ def to_function_proto(self, opset_version: int, function_name: str) -> onnx.Func
opset_imports=onnx_model.opset_import,
doc_string=onnx_model.doc_string,
)
-        # TODO: onnx.checker.check_function(onnx_function)?
return onnx_function

@runtime_typing.checked
@@ -786,6 +796,7 @@ def to_model_proto(
onnx_model = onnx.load_from_string(proto)

onnx_model.functions.extend(function_proto_dict.values())
+        onnx_model.functions.extend(_shared_functions())

# `_export_onnx` only exports opset_imports that is visible to it. It does not
# export opset_imports for nested functions, since it does not have access to
@@ -800,6 +811,13 @@
for domain, version in unique_custom_domains.items()
]
)
+        # Include the library shared opset domain
+        # TODO: Remove after https://github.com/microsoft/onnxscript/issues/834 is fixed
+        onnx_model.opset_import.append(
+            onnx.helper.make_opsetid(
+                common_ops.common_opset.domain, common_ops.common_opset.version
+            )
+        )

try:
if not cache_model_to_disk:
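To see the effect of the hack on an exported model, a quick check like the following should show the shared functions and their opset import (a minimal sketch; the file name `model.onnx` is hypothetical):

```python
import onnx

model = onnx.load("model.onnx")
# The shared functions are always appended, whether or not they are used.
print([f.name for f in model.functions])  # includes "Rank" and "IsScalar"
# Their domain is imported alongside the regular opsets.
print([(o.domain, o.version) for o in model.opset_import])
# includes ("pkg.onnxscript.torch_lib.common", 1)
```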
25 changes: 25 additions & 0 deletions onnxscript/function_libs/torch_lib/ops/common.py
@@ -0,0 +1,25 @@
"""Common operators shared in the torchlib library."""

import onnxscript
import onnxscript.values
from onnxscript import BOOL, INT64
from onnxscript import opset18 as op
from onnxscript.function_libs.torch_lib import _constants, tensor_typing

DOMAIN = f"{_constants.DOMAIN}.common"

common_opset = onnxscript.values.Opset(domain=DOMAIN, version=1)


@onnxscript.script(common_opset)
def Rank(input: tensor_typing.TTensor) -> INT64:
"""Take the rank of the input tensor."""

return op.Size(op.Shape(input))


@onnxscript.script(common_opset)
def IsScalar(input: tensor_typing.TTensor) -> BOOL:
"""Return whether the input has rank 0, or is a scalar."""

return op.Equal(op.Size(op.Shape(input)), op.Constant(value_int=0))
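Since these are `@onnxscript.script` functions, they can also be exercised directly, assuming onnxscript's eager mode evaluates them on numpy inputs (a minimal sketch, not part of this change):

```python
import numpy as np
from onnxscript.function_libs.torch_lib.ops import common as common_ops

# Rank counts dimensions via Size(Shape(x)); IsScalar tests for rank 0.
print(common_ops.Rank(np.ones((2, 3), dtype=np.float32)))    # 2
print(common_ops.IsScalar(np.array(1.0, dtype=np.float32)))  # True
```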
6 changes: 4 additions & 2 deletions onnxscript/function_libs/torch_lib/ops/core.py
@@ -30,6 +30,7 @@
UINT64,
graph,
)
+from onnxscript.function_libs.torch_lib.ops import common as common_ops
from onnxscript.function_libs.torch_lib.registration import torch_op
from onnxscript.function_libs.torch_lib.tensor_typing import (
IntType,
@@ -52,6 +53,8 @@
_INT64_MAX = 9223372036854775807
_INT64_MIN = -9223372036854775808
_MATH_PI = math.pi
+IsScalar = common_ops.IsScalar
+Rank = common_ops.Rank


@torch_op("aten::_local_scalar_dense")
@@ -320,8 +323,7 @@ def aten_align_to(self: TensorType, names: Sequence[str]) -> TensorType:
def aten_all(self: TTensor) -> BOOL:
"""all(Tensor self) -> Tensor"""

-    self_rank = op.Size(op.Shape(self))
-    if self_rank == 0:
+    if IsScalar(self):
result = op.Cast(self, to=BOOL.dtype)
else:
self_bool = op.Cast(self, to=BOOL.dtype)
3 changes: 2 additions & 1 deletion onnxscript/function_libs/torch_lib/registration.py
@@ -7,6 +7,7 @@
from typing import Any, Callable, Generator, Optional

import onnxscript
+from onnxscript.function_libs.torch_lib import _constants

# Regex that will match "<namespace>::<op_name>[.<overload>]"
_QUALIFIED_OPERATOR_NAME_REGEX = re.compile(
@@ -119,7 +120,7 @@ def wrapper(
func: FunctionType,
) -> onnxscript.OnnxFunction | onnxscript.values.TracedOnnxFunction:
# Compile the function
custom_opset = onnxscript.values.Opset(domain="pkg.onnxscript.torch_lib", version=1)
custom_opset = onnxscript.values.Opset(domain=_constants.DOMAIN, version=1)

processed_func: onnxscript.OnnxFunction | onnxscript.values.TracedOnnxFunction
if trace_only:
