diff --git a/.lintrunner.toml b/.lintrunner.toml index 1db450f7f..a179f8a49 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -70,7 +70,7 @@ init_command = [ ] [[linter]] -code = 'BLACK-ISORT' +code = 'RUFF-FORMAT' include_patterns = [ '**/*.py', ] @@ -82,7 +82,7 @@ command = [ '-m', 'lintrunner_adapters', 'run', - 'black_isort_linter', + 'ruff_format_linter', '--', '@{{PATHSFILE}}' ] diff --git a/docs/examples/04_plot_eager_mode_evaluation.py b/docs/examples/04_plot_eager_mode_evaluation.py index 05331d75e..c9f31455c 100644 --- a/docs/examples/04_plot_eager_mode_evaluation.py +++ b/docs/examples/04_plot_eager_mode_evaluation.py @@ -18,9 +18,7 @@ @script() -def linear( - A: FLOAT["N", "K"], W: FLOAT["K", "M"], Bias: FLOAT["M"] -) -> FLOAT["N", "M"]: # noqa: F821 +def linear(A: FLOAT["N", "K"], W: FLOAT["K", "M"], Bias: FLOAT["M"]) -> FLOAT["N", "M"]: # noqa: F821 T1 = op.MatMul(A, W) T2 = op.Add(T1, Bias) Y = op.Relu(T2) diff --git a/noxfile.py b/noxfile.py index 1fe69025c..d141f0fec 100644 --- a/noxfile.py +++ b/noxfile.py @@ -27,8 +27,8 @@ "pyyaml", ) ONNX = "onnx==1.14.1" -ONNX_RUNTIME = "onnxruntime==1.16.0" -PYTORCH = "torch==2.0.1" +ONNX_RUNTIME = "onnxruntime==1.16.1" +PYTORCH = "torch==2.1.0" ONNX_RUNTIME_NIGHTLY_DEPENDENCIES = ( "flatbuffers", "coloredlogs", diff --git a/onnxscript/_internal/param_manipulation.py b/onnxscript/_internal/param_manipulation.py index dd3c7bf59..ae625c3ae 100644 --- a/onnxscript/_internal/param_manipulation.py +++ b/onnxscript/_internal/param_manipulation.py @@ -61,8 +61,7 @@ def separate_input_attributes_from_arguments( else: onnx_attributes[param.name] = kwargs[param.name] elif ( - param.is_attribute - and param.default is not values._EmptyDefault # pylint: disable=protected-access + param.is_attribute and param.default is not values._EmptyDefault # pylint: disable=protected-access ): # User did not provide the attribute if fill_defaults: diff --git a/onnxscript/converter.py b/onnxscript/converter.py index d683082c8..b030a273e 100644 --- a/onnxscript/converter.py +++ b/onnxscript/converter.py @@ -21,9 +21,8 @@ import onnx import onnxscript -from onnxscript import irbuilder, onnx_types, sourceinfo +from onnxscript import irbuilder, onnx_types, sourceinfo, values from onnxscript import type_annotation as ta -from onnxscript import values from onnxscript._internal import analysis, ast_utils, autocast, param_manipulation PY_VERSION_GE_39 = ast_utils.PY_VERSION_GE_39 diff --git a/onnxscript/converter_test.py b/onnxscript/converter_test.py index 2e12222ee..6036641ff 100644 --- a/onnxscript/converter_test.py +++ b/onnxscript/converter_test.py @@ -17,6 +17,7 @@ import numpy as np import onnx import onnxruntime as ort +import pytest from onnxruntime.capi.onnxruntime_pybind11_state import ( Fail, InvalidArgument, @@ -270,7 +271,10 @@ def test_renaming(self): self.validate_save(renaming, shape_inference=False) - @unittest.skip(reason="TypeError: val must be numeric not ") + @pytest.mark.xfail( + strict=True, + reason="default_opset must be specified in script for functions that do not contain any use of an ONNX op", + ) def test_opt_output(self): from onnxscript.tests.models import opt_output diff --git a/onnxscript/function_libs/tools/torch_lib/deduce_type_constraints_test.py b/onnxscript/function_libs/tools/torch_lib/deduce_type_constraints_test.py index 0d68981b8..4e01d37ac 100644 --- a/onnxscript/function_libs/tools/torch_lib/deduce_type_constraints_test.py +++ b/onnxscript/function_libs/tools/torch_lib/deduce_type_constraints_test.py @@ -30,7 +30,6 
@@ class TestDeduceTypeConstraints(unittest.TestCase): "_aten_embedding_bag_onnx", "_aten_embedding_bag_1d_padding_idx_onnx", ) - _SKIP_FUNCTIONS_WITH_NESTED_FUNCTION = () @parameterized.parameterized.expand( ((op,) for op in torch_lib_onnx_functions_from_registry()), @@ -41,11 +40,13 @@ def test_deduce_type_constraints_does_not_crash_for_onnx_function( ): if onnx_function.name in self._SKIP_FUNCTIONS_WITH_LOOP_OR_SCAN: self.skipTest("Unimplemented: function contains loop or scan node.") - if onnx_function.name in self._SKIP_FUNCTIONS_WITH_NESTED_FUNCTION: - self.skipTest("Unimplemented: function contains nested function.") - signature_type_constraint = deduce_type_constraints.deduce_type_constraints( - onnx_function - ) + try: + signature_type_constraint = deduce_type_constraints.deduce_type_constraints( + onnx_function + ) + except NotImplementedError as e: + if "Nested function" in str(e): + self.skipTest("Unimplemented: function contains nested function.") logger.info( "Original signature: %s%s", onnx_function.name, diff --git a/onnxscript/function_libs/tools/torch_lib/generate_prims_signatures.py b/onnxscript/function_libs/tools/torch_lib/generate_prims_signatures.py index ae0d45a04..14afeff36 100644 --- a/onnxscript/function_libs/tools/torch_lib/generate_prims_signatures.py +++ b/onnxscript/function_libs/tools/torch_lib/generate_prims_signatures.py @@ -12,11 +12,8 @@ import os import re import textwrap -from pathlib import Path from typing import Any, Dict, List, Sequence -import black -import isort import torch import torchgen.gen import torchgen.model @@ -319,15 +316,6 @@ def main(args: argparse.Namespace) -> None: ) py_module.accept(cg.PythonWriter(f)) - # Format the generated files so that they pass linting. - # line_length=95 is to match the lintrunner rules. - isort.file(output_path) - black.format_file_in_place( - Path(output_path), - fast=True, - mode=black.Mode(line_length=95), - write_back=black.WriteBack.YES, - ) print("Done.") diff --git a/onnxscript/function_libs/torch_lib/_constants.py b/onnxscript/function_libs/torch_lib/_constants.py new file mode 100644 index 000000000..58cc2c068 --- /dev/null +++ b/onnxscript/function_libs/torch_lib/_constants.py @@ -0,0 +1,3 @@ +"""Shared constants for the library.""" + +DOMAIN = "pkg.onnxscript.torch_lib" diff --git a/onnxscript/function_libs/torch_lib/_flags.py b/onnxscript/function_libs/torch_lib/_flags.py new file mode 100644 index 000000000..209b34fa4 --- /dev/null +++ b/onnxscript/function_libs/torch_lib/_flags.py @@ -0,0 +1,41 @@ +"""Experimental flags. + +NOTE: These flags are experimental only. Any flag here can be removed at any +time without notice. +""" + +import logging +import os + +logger = logging.getLogger(__name__) + + +def _load_boolean_flag( + name: str, + *, + this_will: str, + deprecated: bool = False, +) -> bool: + """Load a boolean flag from environment variable. + + Args: + name: The name of the environment variable. + this_will: A string that describes what this flag will do. + deprecated: Whether this flag is deprecated. + """ + state = os.getenv(name) == "1" + if state: + if deprecated: + logger.error( + "Experimental flag %s is deprecated. Please remove it from your environment.", + name, + ) + else: + logger.warning("Experimental flag %s is enabled. 
This will %s.", name, this_will) + return state + + +EXPERIMENTAL_INITIALIZERS_AS_INPUTS: bool = _load_boolean_flag( + "TORCHLIB_EXPERIMENTAL_INITIALIZERS_AS_INPUTS", + this_will="make initializers as inputs to the model graph", +) diff --git a/onnxscript/function_libs/torch_lib/graph_building.py b/onnxscript/function_libs/torch_lib/graph_building.py index 438c45603..5d82eb044 100644 --- a/onnxscript/function_libs/torch_lib/graph_building.py +++ b/onnxscript/function_libs/torch_lib/graph_building.py @@ -21,6 +21,8 @@ from onnxscript import evaluator from onnxscript import tensor as onnxscript_tensor from onnxscript._internal import param_manipulation, runtime_typing +from onnxscript.function_libs.torch_lib import _flags +from onnxscript.function_libs.torch_lib.ops import common as common_ops __all__ = [ "TorchScriptTensor", @@ -198,7 +200,7 @@ def symbolic_value(self) -> torch.Value: def _unwrap_tensor_to_torch_value( value: Union[ ValidArgumentType, Mapping[str, ValidArgumentType], Sequence[ValidArgumentType] - ] + ], ) -> Union[ ValidTorchValueType, Dict[str, ValidTorchValueType], @@ -363,6 +365,16 @@ def _tensor_rawdata_size(tensor: torch.Tensor) -> int: return tensor.numel() * tensor.element_size() +def _shared_functions() -> list[onnx.FunctionProto]: + """Hack to always include the share ops.""" + + # TODO: Remove after https://github.com/microsoft/onnxscript/issues/834 is fixed + return [ + common_ops.Rank.to_function_proto(), + common_ops.IsScalar.to_function_proto(), + ] + + class TorchScriptGraph: def __init__( self, @@ -717,7 +729,6 @@ def to_function_proto(self, opset_version: int, function_name: str) -> onnx.Func opset_imports=onnx_model.opset_import, doc_string=onnx_model.doc_string, ) - # TODO: onnx.checker.check_function(onnx_function)? return onnx_function @runtime_typing.checked @@ -740,13 +751,15 @@ def to_model_proto( large_model = initializers_size > _LARGE_MODEL_SIZE_THRESHOLD export_kwargs: dict[str, Any] = dict( - initializers=self.initializers if include_initializers else {}, + initializers=self.initializers + if include_initializers and not _flags.EXPERIMENTAL_INITIALIZERS_AS_INPUTS + else {}, onnx_opset_version=opset_version, dynamic_axes={}, defer_weight_export=False, operator_export_type=torch.onnx.OperatorExportTypes.ONNX, strip_doc_string=False, - keep_initializers_as_inputs=False, + keep_initializers_as_inputs=_flags.EXPERIMENTAL_INITIALIZERS_AS_INPUTS, custom_opsets={}, add_node_names=True, node_attr_to_name={}, @@ -786,6 +799,7 @@ def to_model_proto( onnx_model = onnx.load_from_string(proto) onnx_model.functions.extend(function_proto_dict.values()) + onnx_model.functions.extend(_shared_functions()) # `_export_onnx` only exports opset_imports that is visible to it. 
It does not # export opset_imports for nested functions, since it does not have access to @@ -800,6 +814,13 @@ def to_model_proto( for domain, version in unique_custom_domains.items() ] ) + # Include the library shared opset domain + # TODO: Remove after https://github.com/microsoft/onnxscript/issues/834 is fixed + onnx_model.opset_import.append( + onnx.helper.make_opsetid( + common_ops.common_opset.domain, common_ops.common_opset.version + ) + ) try: if not cache_model_to_disk: diff --git a/onnxscript/function_libs/torch_lib/graph_building_test.py b/onnxscript/function_libs/torch_lib/graph_building_test.py index d38bb52d5..ab8b4bc8f 100644 --- a/onnxscript/function_libs/torch_lib/graph_building_test.py +++ b/onnxscript/function_libs/torch_lib/graph_building_test.py @@ -11,11 +11,9 @@ import onnxscript.testing from onnxscript import FLOAT, evaluator from onnxscript import opset18 as op -from onnxscript._internal import version_utils from onnxscript.function_libs.torch_lib import graph_building, ops -@unittest.skipIf(version_utils.torch_older_than("2.0"), "torchscript in 1.13 not supported") class TestTorchScriptTracingEvaluator(unittest.TestCase): def setUp(self): self.opset_version = 18 diff --git a/onnxscript/function_libs/torch_lib/ops/common.py b/onnxscript/function_libs/torch_lib/ops/common.py new file mode 100644 index 000000000..ecef6852b --- /dev/null +++ b/onnxscript/function_libs/torch_lib/ops/common.py @@ -0,0 +1,56 @@ +"""Common operators shared in the torchlib library.""" + +import onnxscript +import onnxscript.values +from onnxscript import BOOL, INT64 +from onnxscript import opset18 as op +from onnxscript.function_libs.torch_lib import _constants, tensor_typing +from onnxscript.function_libs.torch_lib.tensor_typing import RealType +from onnxscript.onnx_types import COMPLEX64, COMPLEX128, DOUBLE, FLOAT + +COMPLEX64_TYPE = COMPLEX64.dtype +COMPLEX128_TYPE = COMPLEX128.dtype + +DOMAIN = f"{_constants.DOMAIN}.common" + +common_opset = onnxscript.values.Opset(domain=DOMAIN, version=1) + + +@onnxscript.script(common_opset) +def Rank(input: tensor_typing.TTensor) -> INT64: + """Take the rank of the input tensor.""" + + return op.Size(op.Shape(input)) + + +@onnxscript.script(common_opset) +def IsScalar(input: tensor_typing.TTensor) -> BOOL: + """Return whether the input has rank 0, or is a scalar.""" + + return op.Equal(op.Size(op.Shape(input)), op.Constant(value_int=0)) + + +def cast_to(a: RealType, dtype: int) -> RealType: + """Cast input to dtype while handling complex types.""" + + # Traced function because different if branches return different dtypes + # which is not supported in an ONNX function + if dtype == COMPLEX128_TYPE: + # Cast to the real representation of the complex type + casted = op.Cast(a, to=DOUBLE.dtype) + # Create a complex number + real_part = op.Unsqueeze(casted, axes=[-1]) + imag_part = op.Expand(op.Cast(0.0, to=DOUBLE.dtype), op.Shape(real_part)) + result = op.Concat(real_part, imag_part, axis=-1) + elif dtype == COMPLEX64_TYPE: + # Cast to the real representation of the complex type + casted = op.Cast(a, to=FLOAT.dtype) + # Create a complex number + real_part = op.Unsqueeze(casted, axes=[-1]) + imag_part = op.Expand(0.0, op.Shape(real_part)) + result = op.Concat(real_part, imag_part, axis=-1) + else: + # Cast to real numbers + result = op.Cast(a, to=dtype) + + return result diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 751ec5de2..1620efe0d 100644 --- 
a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -30,6 +30,7 @@ UINT64, graph, ) +from onnxscript.function_libs.torch_lib.ops import common as common_ops from onnxscript.function_libs.torch_lib.registration import torch_op from onnxscript.function_libs.torch_lib.tensor_typing import ( IntType, @@ -52,6 +53,8 @@ _INT64_MAX = 9223372036854775807 _INT64_MIN = -9223372036854775808 _MATH_PI = math.pi +IsScalar = common_ops.IsScalar +Rank = common_ops.Rank @torch_op("aten::_local_scalar_dense") @@ -85,11 +88,13 @@ def aten__log_softmax_half( @torch_op("aten::_log_softmax") def aten__log_softmax( - self: TFloatHighPrecision, dim: int, half_to_float: bool # pylint: disable=unused-argument + self: TFloatHighPrecision, + dim: int, + half_to_float: bool, # pylint: disable=unused-argument ) -> TFloatHighPrecision: """_log_softmax(Tensor self, int dim, bool half_to_float) -> Tensor""" - self_is_scalar = op.Size(op.Shape(self)) == 0 + self_is_scalar = IsScalar(self) if self_is_scalar: self = op.Unsqueeze(self, op.Constant(value_ints=[0])) result = op.LogSoftmax(self, axis=dim) @@ -121,7 +126,7 @@ def aten__softmax( return aten_softmax_no_dtype(self, dim) -@torch_op("aten::abs") +@torch_op(("aten::abs", "_operator::abs")) def aten_abs(self: TRealOrUInt8) -> TRealOrUInt8: """abs(Tensor self) -> Tensor""" @@ -132,13 +137,13 @@ def aten_abs(self: TRealOrUInt8) -> TRealOrUInt8: def aten_abs_complex(self: TRealOrUInt8) -> TRealOrUInt8: """abs(Tensor self) -> Tensor""" # self_real = self[..., 0] - self_real = op.Gather(self, 0, axis=-1) + self_real = op.Slice(self, [0], [1], axes=[-1]) # self_imag = self[..., 1] - self_imag = op.Gather(self, 1, axis=-1) + self_imag = op.Slice(self, [1], [2], axes=[-1]) real_pow = op.Pow(self_real, 2) imag_pow = op.Pow(self_imag, 2) real_plus_imag = op.Add(real_pow, imag_pow) - return op.Sqrt(real_plus_imag) + return op.Squeeze(op.Sqrt(real_plus_imag), axes=[-1]) @torch_op("aten::acos") @@ -155,7 +160,7 @@ def aten_acosh(self: TFloat) -> TFloat: return op.Acosh(self) -@torch_op(("aten::add", "aten::add.Tensor")) +@torch_op(("aten::add", "aten::add.Tensor", "_operator::add")) def aten_add(self: TReal, other: TReal, alpha: float = 1.0) -> TReal: """add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor""" # TODO(microsoft/onnxruntime#15977): Improve fp16 precision @@ -164,6 +169,13 @@ def aten_add(self: TReal, other: TReal, alpha: float = 1.0) -> TReal: return op.Add(self, other) +@torch_op(("aten::add", "aten::add.Tensor", "_operator::add"), trace_only=True, complex=True) +def aten_add_complex(self: TReal, other: TReal, alpha: float = 1.0) -> TReal: + """add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor""" + + return aten_add(self, other, alpha=alpha) + + @torch_op("aten::addbmm") def aten_addbmm( self: TReal, @@ -223,10 +235,11 @@ def aten_addmm( ) -> TReal: """addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor""" - mat1_mat2 = op.MatMul(mat1, mat2) - scaled_mat1_mat2 = op.Mul(mat1_mat2, alpha) - scaled_self = op.Mul(self, beta) - return op.Add(scaled_self, scaled_mat1_mat2) + # NOTE: ONNX Runtime does not support int inputs to Gemm as of 1.16. + # To support int inputs, consider an overriding implementation that casts to float and back. 
+ + # addmm only accepts 2d tensors: https://pytorch.org/docs/stable/generated/torch.addmm.html + return op.Gemm(mat1, mat2, self, alpha=alpha, beta=beta) @torch_op("aten::addmv") @@ -320,8 +333,7 @@ def aten_align_to(self: TensorType, names: Sequence[str]) -> TensorType: def aten_all(self: TTensor) -> BOOL: """all(Tensor self) -> Tensor""" - self_rank = op.Size(op.Shape(self)) - if self_rank == 0: + if IsScalar(self): result = op.Cast(self, to=BOOL.dtype) else: self_bool = op.Cast(self, to=BOOL.dtype) @@ -335,8 +347,7 @@ def aten_all(self: TTensor) -> BOOL: def aten_all_dim(self: TTensor, dim: int, keepdim: bool = False) -> BOOL: """all.dim(Tensor self, int dim, bool keepdim=False) -> Tensor""" - self_rank = op.Size(op.Shape(self)) - if self_rank == 0: + if IsScalar(self): result = op.Cast(self, to=BOOL.dtype) else: self_bool = op.Cast(self, to=BOOL.dtype) @@ -364,8 +375,7 @@ def aten_all_dims_no_dim(self: TTensor, keepdims: bool) -> BOOL: # dim is None and thus not supplied - self_rank = op.Size(op.Shape(self)) - if self_rank == 0: + if IsScalar(self): result = op.Cast(self, to=BOOL.dtype) else: self_bool = op.Cast(self, to=BOOL.dtype) @@ -443,8 +453,7 @@ def aten_angle(self: TensorType) -> TensorType: def aten_any(self: TTensor) -> BOOL: """any(Tensor self) -> Tensor""" - self_rank = op.Size(op.Shape(self)) - if self_rank == 0: + if IsScalar(self): result = op.Cast(self, to=BOOL.dtype) else: self_bool = op.Cast(self, to=BOOL.dtype) @@ -459,8 +468,7 @@ def aten_any(self: TTensor) -> BOOL: def aten_any_dim(self: TTensor, dim: int, keepdim: bool = False) -> BOOL: """any.dim(Tensor self, int dim, bool keepdim=False) -> Tensor""" - self_rank = op.Size(op.Shape(self)) - if self_rank == 0: + if IsScalar(self): result = op.Cast(self, to=BOOL.dtype) else: self_bool = op.Cast(self, to=BOOL.dtype) @@ -490,8 +498,7 @@ def aten_any_dims_no_dim(self: TTensor, keepdims: bool) -> BOOL: # dim is None and thus not supplied - self_rank = op.Size(op.Shape(self)) - if self_rank == 0: + if IsScalar(self): result = op.Cast(self, to=BOOL.dtype) else: self_bool = op.Cast(self, to=BOOL.dtype) @@ -681,7 +688,7 @@ def aten_arctanh(self: TensorType) -> TensorType: def aten_argmax(self: Union[RealType, UINT8], keepdim: bool = False) -> INT64: """argmax(Tensor self, int? dim=None, bool keepdim=False) -> Tensor""" - self_is_scaler = op.Size(op.Shape(self)) == 0 + self_is_scaler = IsScalar(self) self = op.Reshape(self, op.Constant(value_ints=[-1])) result = op.ArgMax(self, keepdims=keepdim) if self_is_scaler: @@ -694,7 +701,7 @@ def aten_argmax(self: Union[RealType, UINT8], keepdim: bool = False) -> INT64: def aten_argmax_dim(self: Union[RealType, UINT8], dim: int, keepdim: bool = False) -> INT64: """argmax(Tensor self, int? dim=None, bool keepdim=False) -> Tensor""" - self_is_scaler = op.Size(op.Shape(self)) == 0 + self_is_scaler = IsScalar(self) if self_is_scaler: self = op.Reshape(self, op.Constant(value_ints=[-1])) @@ -709,7 +716,7 @@ def aten_argmax_dim(self: Union[RealType, UINT8], dim: int, keepdim: bool = Fals def aten_argmin(self: Union[RealType, UINT8], keepdim: bool = False) -> INT64: """argmin(Tensor self, int? 
dim=None, bool keepdim=False) -> Tensor""" - self_is_scaler = op.Size(op.Shape(self)) == 0 + self_is_scaler = IsScalar(self) self = op.Reshape(self, op.Constant(value_ints=[-1])) result = op.ArgMin(self, keepdims=keepdim) if self_is_scaler: @@ -722,7 +729,7 @@ def aten_argmin(self: Union[RealType, UINT8], keepdim: bool = False) -> INT64: def aten_argmin_dim(self: Union[RealType, UINT8], dim: int, keepdim: bool = False) -> INT64: """argmin(Tensor self, int? dim=None, bool keepdim=False) -> Tensor""" - self_is_scaler = op.Size(op.Shape(self)) == 0 + self_is_scaler = IsScalar(self) if self_is_scaler: self = op.Reshape(self, op.Constant(value_ints=[-1])) @@ -881,9 +888,7 @@ def aten_atanh(self: TFloat) -> TFloat: def aten_atleast_1d(self: TTensor) -> TTensor: """atleast_1d(Tensor self) -> Tensor""" - shape = op.Shape(self) - rank = op.Size(shape) - if rank == 0: + if IsScalar(self): self = op.Reshape(self, op.Constant(value_ints=[1])) return self @@ -907,9 +912,7 @@ def reshape_to_1d(tensor): def aten_atleast_2d(self: TTensor) -> TTensor: """atleast_2d(Tensor self) -> Tensor""" - shape = op.Shape(self) - rank = op.Size(shape) - if rank <= 1: + if Rank(self) <= 1: self = op.Reshape(self, op.Constant(value_ints=[1, -1])) return self @@ -933,8 +936,7 @@ def reshape_to_2d(tensor): def aten_atleast_3d(self: TTensor) -> TTensor: """atleast_3d(Tensor self) -> Tensor""" - shape = op.Shape(self) - rank = op.Size(shape) + rank = Rank(self) if rank <= 1: self = op.Reshape(self, op.Constant(value_ints=[1, -1, 1])) elif rank == 2: @@ -1160,6 +1162,7 @@ def aten_binomial( "aten::bitwise_and.Tensor", "aten::bitwise_and.Scalar", "aten::bitwise_and.Scalar_Tensor", + "_operator::and_", ) ) def aten_bitwise_and(self: TInt, other: TInt) -> TInt: @@ -1231,6 +1234,7 @@ def aten_bitwise_not(self: TInt) -> TInt: "aten::bitwise_or.Tensor", "aten::bitwise_or.Scalar", "aten::bitwise_or.Scalar_Tensor", + "_operator::or_", ) ) def aten_bitwise_or(self: TInt, other: TInt) -> TInt: @@ -1440,6 +1444,13 @@ def aten_ceil(self: TFloat) -> TFloat: return op.Ceil(self) +@torch_op("math::ceil") +def python_math_ceil(self: TFloat) -> TInt: + """ceil(Tensor self) -> Tensor""" + ceil = op.Ceil(self) + return op.Cast(ceil, to=INT64.dtype) + + def aten_chain_matmul(matrices: Sequence[TensorType]) -> TensorType: """chain_matmul(Tensor[] matrices) -> Tensor""" @@ -1573,7 +1584,8 @@ def aten_clamp_min(self: TReal, min_: TReal) -> TReal: @torch_op("aten::clone") def aten_clone( - self: TTensor, memory_format: str = "" # pylint: disable=unused-argument + self: TTensor, + memory_format: str = "", # pylint: disable=unused-argument ) -> TTensor: """clone(Tensor self, *, MemoryFormat? 
memory_format=None) -> Tensor""" @@ -1686,8 +1698,7 @@ def aten_constant_pad_nd(self: TTensor, pad: INT64, value: float = 0.0) -> TTens neg_1 = op.Constant(value_ints=[-1]) - rank = op.Size(op.Shape(self)) - zero_count = op.Sub(op.Mul(rank, 2), op.Size(pad)) + zero_count = op.Sub(op.Mul(Rank(self), 2), op.Size(pad)) zero_count = op.Reshape(zero_count, neg_1) zero = op.Constant(value_ints=[0]) zeros = op.Expand(zero, zero_count) @@ -1709,7 +1720,8 @@ def aten_constant_pad_nd(self: TTensor, pad: INT64, value: float = 0.0) -> TTens @torch_op("aten::contiguous") def aten_contiguous( - self: TTensor, memory_format: str = "contiguous_format" # pylint: disable=unused-argument + self: TTensor, + memory_format: str = "contiguous_format", # pylint: disable=unused-argument ) -> TTensor: """contiguous(Tensor(a) self, *, MemoryFormat memory_format=contiguous_format) -> Tensor(a)""" @@ -1949,8 +1961,7 @@ def _aten_convolution_onnx( # Alternatively we could cast transposed to BOOL. # E.g. `if op.Cast(transposed, BOOL.dtype): ...` - weight_size = op.Size(op.Shape(weight)) - no_batch = op.Size(op.Shape(input)) != weight_size + no_batch = Rank(input) != Rank(weight) if no_batch: input = op.Unsqueeze(input, op.Constant(value_ints=[0])) @@ -2036,12 +2047,27 @@ def aten_convolution_overrideable( @torch_op("aten::copy") def aten_copy( - self: TTensor, src: TTensor, non_blocking: bool = False # pylint: disable=unused-argument + self: TTensor, + src: TTensor2, + non_blocking: bool = False, # pylint: disable=unused-argument ) -> TTensor: """copy(Tensor self, Tensor src, bool non_blocking=False) -> Tensor""" - self = op.Identity(src) - return self + return op.CastLike(src, self) + + +@torch_op("aten::_to_copy", trace_only=True) +def aten__to_copy( + self: TTensor, + dtype: int = -1, + non_blocking: bool = False, # pylint: disable=unused-argument +) -> TTensor: + """_to_copy(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool non_blocking=False, MemoryFormat? memory_format=None) -> Tensor""" + + if dtype == -1: + return op.Identity(self) + else: + return common_ops.cast_to(self, dtype=dtype) def aten_copysign(self: TensorType, other: TensorType) -> TensorType: @@ -2340,7 +2366,7 @@ def aten_cumsum( def _aten_cumsum_onnx( self: TRealUnlessInt16OrInt8, dim: Union[INT32, INT64] ) -> TRealUnlessInt16OrInt8: - if op.Size(op.Shape(self)) == 0: + if IsScalar(self): # A scalar result = op.Identity(self) else: @@ -2366,12 +2392,6 @@ def aten_dense_dim(self: TensorType) -> int: raise NotImplementedError() -def aten_det(self: TensorType) -> TensorType: - """det(Tensor self) -> Tensor""" - - raise NotImplementedError() - - @torch_op("aten::detach") def aten_detach(self: TensorType) -> TensorType: """detach(Tensor(a) self) -> Tensor(a)""" @@ -2614,6 +2634,7 @@ def aten_dist(self: TensorType, other: TensorType, p: float = 2.0) -> TensorType "aten::div.Scalar_mode", "aten::divide", "aten::true_divide", + "_operator::truediv", ) ) def aten_div(self: TFloat, other: TFloat) -> TFloat: @@ -2623,6 +2644,40 @@ def aten_div(self: TFloat, other: TFloat) -> TFloat: return op.Div(self, other) +@torch_op( + ( + "aten::div", + "aten::div.Tensor", + "aten::div.Scalar", + "aten::divide", + "aten::true_divide", + "_operator::truediv", + ), + complex=True, +) +def aten_div_complex(self: TFloat, other: TFloat) -> TFloat: + """div.Tensor(Tensor self, Tensor other) -> Tensor""" + + # Complex division. 
PyTorch type promotion ensures both arguments are complex numbers + self_real = op.Slice(self, [0], [1], axes=[-1]) + self_imag = op.Slice(self, [1], [2], axes=[-1]) + other_real = op.Slice(other, [0], [1], axes=[-1]) + other_imag = op.Slice(other, [1], [2], axes=[-1]) + + # Complex division + # (a + bi) / (c + di) = (ac + bd) / (c^2 + d^2) + (bc - ad) / (c^2 + d^2)i + # https://mathworld.wolfram.com/ComplexDivision.html + ac = op.Mul(self_real, other_real) + bd = op.Mul(self_imag, other_imag) + bc = op.Mul(self_imag, other_real) + ad = op.Mul(self_real, other_imag) + denominator = op.Add(op.Mul(other_real, other_real), op.Mul(other_imag, other_imag)) + real = op.Div(ac + bd, denominator) + imag = op.Div(bc - ad, denominator) + + return op.Concat(real, imag, axis=-1) + + @torch_op(("aten::div.Tensor_mode", "aten::div.Scalar_mode"), trace_only=True) def aten_div_mode(self: TFloat, other: TFloat, rounding_mode: str) -> TFloat: """div.Tensor_mode(Tensor self, Tensor other, *, str? rounding_mode) -> Tensor""" @@ -2672,8 +2727,7 @@ def aten_dot(self: TFloat, tensor: TFloat) -> TFloat: def aten_dropout(input: TFloat, p: FLOAT, train: BOOL) -> TFloat: """dropout(Tensor input, float p, bool train) -> Tensor""" - input_is_scalar = op.Size(op.Shape(input)) == 0 - if input_is_scalar: + if IsScalar(input): input = op.Reshape(input, op.Constant(value_ints=[-1])) result, _ = op.Dropout(input, p, train) result = op.Squeeze(result) @@ -2689,12 +2743,16 @@ def aten_dstack(tensors: Sequence[TensorType]) -> TensorType: raise NotImplementedError() +@torch_op("aten::einsum", trace_only=True) def aten_einsum( - equation: str, tensors: Sequence[TensorType], path: Optional[int] = None -) -> TensorType: + equation: str, + tensors: Sequence[TReal], + path: Optional[int] = None, # pylint: disable=unused-argument +) -> TReal: """einsum(str equation, Tensor[] tensors, *, int[]? 
path=None) -> Tensor""" - raise NotImplementedError() + # Use trace_only to unpack the `tensors` sequence + return op.Einsum(*tensors, equation=equation) @torch_op("aten::embedding") @@ -3098,6 +3156,43 @@ def aten_embedding_dense_backward( raise NotImplementedError() +@torch_op("aten::embedding_renorm") +def aten_embedding_renorm( + weight: TFloat, indices: INT64, max_norm: float, norm_type: float = 2.0 +) -> TFloat: + """embedding_renorm(Tensor weight, Tensor indices, float max_norm, float norm_type) -> Tensor""" + + unique_indices, _, _, _ = op.Unique(indices) + partial_weight = op.Gather(weight, unique_indices) + # partial_weight_norm = sum(|w|^p)^(1/p) + if norm_type == 1.0: + # This is not necessary, but op.ReduceL1 is faster than function list in 'else' + partial_weight_norm = op.ReduceL1(partial_weight, axes=[1], keepdims=True) + elif norm_type == 2.0: + # This is not necessary, but op.ReduceL2 is faster than function list in 'else' + partial_weight_norm = op.ReduceL2(partial_weight, axes=[1], keepdims=True) + else: + # Abs -> Pow -> ReduceSum -> Pow -> Pow + partial_weight_abs = op.Abs(partial_weight) + partial_weight_pow = op.Pow(partial_weight_abs, op.Constant(value_float=norm_type)) + partial_weight_norm = op.ReduceSum(partial_weight_pow, axes=[1], keepdims=True) + pow_value = op.CastLike(1.0 / norm_type, weight) + partial_weight_norm = op.Pow(partial_weight_norm, pow_value) + + max_norm = op.CastLike(op.Constant(value_float=max_norm), weight) + # This is to avoid weight is zero + err = op.CastLike(op.Constant(value_float=1e-7), weight) + partial_weight_norm_ = op.Add(partial_weight_norm, err) + scales = op.Div(max_norm, partial_weight_norm_) + partial_weight_renorm = op.Mul(partial_weight, scales) + # Set values to renormed values where weight_norm > max_norm, but keep the original values where weight_norm <= max_norm + partial_weight_renorm = op.Where( + op.Greater(partial_weight_norm, max_norm), partial_weight_renorm, partial_weight + ) + value = op.ScatterND(weight, op.Unsqueeze(unique_indices, [1]), partial_weight_renorm) + return value + + def aten_embedding_sparse_backward( grad: TensorType, indices: TensorType, @@ -3153,7 +3248,8 @@ def aten_empty_quantized( @torch_op("aten::empty_strided") def aten_empty_strided( - size: INT64, stride: INT64 # pylint: disable=unused-argument + size: INT64, + stride: INT64, # pylint: disable=unused-argument ) -> TTensor: # type: ignore[type-var] # empty_strided(SymInt[] size, SymInt[] stride, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? 
pin_memory=None) -> Tensor @@ -3426,7 +3522,14 @@ def aten_floor(self: TFloatOrBFloat16) -> TFloatOrBFloat16: return op.Floor(self) -@torch_op("aten::floor_divide") +@torch_op("math::floor") +def python_math_floor(self: TFloatOrBFloat16) -> TInt: + """floor(Tensor self) -> Tensor""" + floor = op.Floor(self) + return op.Cast(floor, to=INT64.dtype) + + +@torch_op(("aten::floor_divide", "_operator::floordiv")) def aten_floor_divide(self: TFloat, other: TFloat) -> TFloat: """floor_divide(Tensor self, Tensor other) -> Tensor""" @@ -3541,10 +3644,10 @@ def aten_gather( ) -> TReal: """gather(Tensor self, int dim, Tensor index, *, bool sparse_grad=False) -> Tensor""" - if op.Size(op.Shape(index)) == 0: # When (index) is empty, return (self) + if IsScalar(index): # When (index) is empty, return (self) result = self else: - if op.Size(op.Shape(self)) == 0: # Unsqueeze for GatherElements op + if IsScalar(self): # Unsqueeze for GatherElements op self = op.Reshape(self, op.Constant(value_ints=[-1])) if op.Size(index) == 0: # Return empty array result = op.CastLike(index, self) @@ -3568,7 +3671,9 @@ def aten_gcd(self: TensorType, other: TensorType) -> TensorType: raise NotImplementedError() -@torch_op(("aten::ge", "aten::ge.Tensor", "aten::ge.Scalar", "aten::greater_equal")) +@torch_op( + ("aten::ge", "aten::ge.Tensor", "aten::ge.Scalar", "aten::greater_equal", "_operator::ge") +) def aten_ge(self: TReal, other: TReal) -> BOOL: """ge.Tensor(Tensor self, Tensor other) -> Tensor""" @@ -3724,7 +3829,7 @@ def aten_gru_cell( raise NotImplementedError() -@torch_op(("aten::gt", "aten::gt.Scalar", "aten::greater")) +@torch_op(("aten::gt", "aten::gt.Scalar", "aten::greater", "_operator::gt")) def aten_gt(self: TReal, other: TReal) -> BOOL: """gt.Tensor(Tensor self, Tensor other) -> Tensor""" @@ -3831,7 +3936,7 @@ def reshape_to_atleast_2d(tensor): result = op.ConcatFromSequence(tensors_atleast_2d, axis=1, new_axis=0) # hstack expects a non-empty sequence of tensors. 
So we don't need to check for length - rank_1d_or_less = op.Less(op.Size(op.Shape(op.SequenceAt(tensors, 0))), 2) + rank_1d_or_less = op.Less(Rank(op.SequenceAt(tensors, 0)), 2) if rank_1d_or_less: result = op.Reshape(result, op.Constant(value_ints=[-1])) return result @@ -4051,12 +4156,9 @@ def aten_index_put_bool( # change array([F,F,T,F,F]) to array([2]) index = op.ArgMax(index_int) # assume index only have 1 True # change array([2]) to array([2,2,2,2,2]) - self_dim_1 = op.Gather(op.Shape(self), 1) - index_dim_0 = op.Gather(op.Shape(index), 0) - neg_1 = op.Constant(value_ints=[-1]) - shape = op.Concat( - op.Reshape(self_dim_1, neg_1), op.Reshape(index_dim_0, neg_1), axis=0 - ) + self_dim_1 = op.Shape(self, start=1, end=2) + index_dim_0 = op.Shape(index, start=0, end=1) + shape = op.Concat(self_dim_1, index_dim_0, axis=0) new_ind = op.Expand(index, shape) new_ind_t = op.Transpose(new_ind) @@ -4094,7 +4196,7 @@ def aten_index_reduce( def aten_index_select(self: TTensor, dim: int, index: IntType) -> TTensor: """index_select(Tensor self, int dim, Tensor index) -> Tensor""" - self_is_scalar = op.Size(op.Shape(self)) == 0 + self_is_scalar = IsScalar(self) if self_is_scalar: self = op.Reshape(self, op.Constant(value_ints=[-1])) @@ -4235,9 +4337,7 @@ def aten_is_same_size(self: TTensor, other: TTensor) -> BOOL: # Cannot compare different shape of two tensors using op.Equal() # So we need to compare the rank first, if rank is same, then compare shape - self_rank = op.Size(op.Shape(self)) - other_rank = op.Size(op.Shape(other)) - result = op.Equal(self_rank, other_rank) + result = op.Equal(Rank(self), Rank(other)) if result: # Same rank, then compare shape self_shape = op.Shape(self) other_shape = op.Shape(other) @@ -4436,7 +4536,7 @@ def aten_ldexp(self: TensorType, other: TensorType) -> TensorType: raise NotImplementedError() -@torch_op(("aten::le", "aten::le.Tensor")) +@torch_op(("aten::le", "aten::le.Tensor", "_operator::le")) def aten_le(self: TReal, other: TReal) -> BOOL: """le.Tensor(Tensor self, Tensor other) -> Tensor""" @@ -4579,8 +4679,7 @@ def aten_logaddexp2(self: TFloatOrBFloat16, other: TFloatOrBFloat16) -> TFloatOr def aten_logcumsumexp(self: TFloatOrBFloat16, dim: int) -> TFloatOrBFloat16: """logcumsumexp(Tensor self, int dim) -> Tensor""" - self_rank = op.Size(op.Shape(self)) - if self_rank == 0: + if IsScalar(self): result = self else: # Make dim 1-d @@ -4692,7 +4791,7 @@ def aten_logspace(start: float, end: float, steps: int, base: float = 10.0) -> T def aten_logsumexp(self: TFloat, dim: INT64, keepdim: int = False) -> TFloat: """logsumexp(Tensor self, int[1] dim, bool keepdim=False) -> Tensor""" - if op.Size(op.Shape(self)) == 0: + if IsScalar(self): # A scalar result = self else: @@ -4740,7 +4839,7 @@ def aten_lstm_mps_backward( raise NotImplementedError() -@torch_op(("aten::lt", "aten::lt.Scalar", "aten::less")) +@torch_op(("aten::lt", "aten::lt.Scalar", "aten::less", "_operator::lt")) def aten_lt(self: TReal, other: TReal) -> BOOL: """lt.Tensor(Tensor self, Tensor other) -> Tensor""" @@ -4898,13 +4997,13 @@ def aten_matrix_power(self: TensorType, n: int) -> TensorType: def aten_max(self: TReal) -> TReal: """max(Tensor self) -> Tensor""" - self_rank = op.Size(op.Shape(self)) - if self_rank == 0: + self_is_scalar = IsScalar(self) + if self_is_scalar: self = op.Reshape(self, op.Constant(value_ints=[-1])) result = op.ReduceMax(self, keepdims=False) - if self_rank == 0: + if self_is_scalar: result = op.Squeeze(result) return result @@ -4914,7 +5013,7 @@ def aten_max(self: 
TReal) -> TReal: def aten_max_dim(self: TReal, dim: int, keepdim: bool = False) -> Tuple[TReal, INT64]: """max.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)""" - if op.Size(op.Shape(self)) == 0: + if IsScalar(self): result = self indices = op.Constant(value_int=0) else: @@ -4950,10 +5049,10 @@ def aten_mean(self: TReal) -> TReal: def aten_mean_dim(self: TReal, dim: INT64, keepdim: bool = False) -> TReal: """mean.dim(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor""" - if op.Size(op.Shape(self)) == 0: + if IsScalar(self): result = self else: - if op.Size(op.Shape(dim)) == 0: + if IsScalar(dim): dim = op.Unsqueeze(dim, axes=0) result = op.ReduceMean(self, axes=dim, keepdims=keepdim) return result @@ -4981,7 +5080,7 @@ def aten_min(self: TReal) -> TReal: @torch_op("aten::min.dim") def aten_min_dim(self: TReal, dim: int, keepdim: bool = False) -> Tuple[TReal, TInt]: """min.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)""" - if op.Size(op.Shape(self)) == 0: + if IsScalar(self): result = self indices = op.Constant(value_int=0) else: @@ -5280,7 +5379,6 @@ def aten_mm( ) -> TRealUnlessInt16OrInt8: """mm(Tensor self, Tensor mat2) -> Tensor""" - # TODO(justinchuby): Specify type conversion for uint8/int8/int16 return op.MatMul(self, mat2) @@ -5343,11 +5441,10 @@ def aten_msort(self: TensorType) -> TensorType: raise NotImplementedError() -@torch_op(("aten::mul", "aten::mul.Tensor")) +@torch_op(("aten::mul", "aten::mul.Tensor", "_operator::mul")) def aten_mul(self: TReal, other: TReal) -> TReal: """mul.Tensor(Tensor self, Tensor other) -> Tensor""" - # FIXME(titaiwang): get rid of this when we have type_promotion - other = op.CastLike(other, self) + return op.Mul(self, other) @@ -5361,6 +5458,29 @@ def aten_mul_bool(self: BOOL, other: BOOL) -> BOOL: return op.And(self, other) +@torch_op(("aten::mul", "aten::mul.Tensor", "_operator::mul"), complex=True) +def aten_mul_complex(self: TReal, other: TReal) -> TReal: + """mul.Tensor(Tensor self, Tensor other) -> Tensor""" + + self_real = op.Slice(self, [0], [1], axes=[-1]) + self_imag = op.Slice(self, [1], [2], axes=[-1]) + other_real = op.Slice(other, [0], [1], axes=[-1]) + other_imag = op.Slice(other, [1], [2], axes=[-1]) + + # Complex multiplication + # (a + bi) * (c + di) = (ac - bd) + (ad + bc)i + + ac = op.Mul(self_real, other_real) + bd = op.Mul(self_imag, other_imag) + ad = op.Mul(self_real, other_imag) + bc = op.Mul(self_imag, other_real) + + real = op.Sub(ac, bd) + imag = op.Add(ad, bc) + + return op.Concat(real, imag, axis=-1) + + @torch_op("aten::multinomial") def aten_multinomial( self: TFloat, @@ -5369,14 +5489,14 @@ def aten_multinomial( ) -> TInt: """multinomial(Tensor self, int num_samples, bool replacement=False, *, Generator? 
generator=None) -> Tensor""" # ONNX Multinomial doesn't support 1D input - if op.Size(op.Shape(self)) == 1: + if Rank(self) == 1: unsqueezed_input = op.Unsqueeze(self, axes=0) else: unsqueezed_input = self # ONNX multinomial expects log probability log_input = op.Log(unsqueezed_input) result = op.Multinomial(log_input, dtype=INT64.dtype, sample_size=num_samples) - if op.Size(op.Shape(self)) == 1: + if Rank(self) == 1: result = op.Squeeze(result) return result @@ -5454,21 +5574,17 @@ def aten_nansum( def aten_narrow(self: TTensor, dim: INT64, start: INT64, length: INT64) -> TTensor: """narrow(Tensor(a) self, int dim, SymInt start, SymInt length) -> Tensor(a)""" - dim_rank = op.Size(op.Shape(dim)) - if dim_rank == 0: + if IsScalar(dim): dim = op.Reshape(dim, op.Constant(value_ints=[-1])) - start_rank = op.Size(op.Shape(start)) - if start_rank == 0: + if IsScalar(start): start = op.Reshape(start, op.Constant(value_ints=[-1])) - length_rank = op.Size(op.Shape(length)) - if length_rank == 0: + if IsScalar(length): length = op.Reshape(length, op.Constant(value_ints=[-1])) end = op.Add(start, length) - result = op.Slice(self, start, end, dim) - return result + return op.Slice(self, start, end, dim) def aten_narrow_copy(self: TensorType, dim: int, start: INT64, length: INT64) -> TensorType: @@ -5477,7 +5593,42 @@ def aten_narrow_copy(self: TensorType, dim: int, start: INT64, length: INT64) -> raise NotImplementedError() -@torch_op("aten::native_batch_norm", trace_only=True) +# NOTE: https://github.com/pytorch/pytorch/blob/a44f8894fa6d973693aab44a3dda079a168b05c1/torch/_decomp/decompositions.py#L1501-L1510 +# _native_batch_norm_legit_no_training and _native_batch_norm_legit are meant to +# replace native_batch_norm within unknown time period. +# TODO: Refactor this after native_batch_norm is deprecated. +@torch_op("aten::_native_batch_norm_legit_no_training", trace_only=True) +def aten__native_batch_norm_no_training( + input: TFloat, + weight: Optional[TFloat] = None, + bias: Optional[TFloat] = None, + running_mean: Optional[TFloat] = None, + running_var: Optional[TFloat] = None, + momentum: float = 0.9, + eps: float = 1e-05, +) -> Tuple[TFloat, TFloat, TFloat]: + """_native_batch_norm_legit_no_training(Tensor input, Tensor? weight, Tensor? bias, Tensor running_mean, Tensor running_var, float momentum, float eps) -> (Tensor, Tensor, Tensor)""" + + return aten_native_batch_norm( + input, weight, bias, running_mean, running_var, False, momentum, eps + ) + + +@torch_op("aten::_native_batch_norm_legit.no_stats", trace_only=True) +def aten__native_batch_norm_no_stats( + input: TFloat, + weight: Optional[TFloat] = None, + bias: Optional[TFloat] = None, + training: bool = False, + momentum: float = 0.9, + eps: float = 1e-05, +) -> Tuple[TFloat, TFloat, TFloat]: + """_native_batch_norm_legit.no_stats(Tensor input, Tensor? weight, Tensor? 
bias, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)""" + + return aten_native_batch_norm(input, weight, bias, None, None, training, momentum, eps) + + +@torch_op(("aten::native_batch_norm", "aten::_native_batch_norm_legit"), trace_only=True) def aten_native_batch_norm( input: TFloat, weight: Optional[TFloat] = None, @@ -5577,12 +5728,131 @@ def _aten_native_batch_norm_inference_onnx( momentum=momentum, training_mode=training, ) + # NOTE: mean and var are omitted in inference mode # Cannot return 2 dup output, so have to do twice with different variable name - empty_mean = op.Cast(op.Shape(input, start=0, end=0), to=FLOAT.dtype) - empty_var = op.Cast(op.Shape(input, start=0, end=0), to=FLOAT.dtype) + empty_mean = op.CastLike(op.Shape(input, start=0, end=0), norm) + empty_var = op.CastLike(op.Shape(input, start=0, end=0), norm) return norm, empty_mean, empty_var +# TODO: This op is using duplicated code from aten_native_batch_norm, +# need to refactor it later. https://github.com/microsoft/onnxscript/issues/1125 +# NOTE: This op is invoked by PyTorch Functionalization, and not in +# native_functions.yaml, It can be found in torch/_decomp/decompositions.py +@torch_op("aten::_native_batch_norm_legit_functional", trace_only=True) +def aten__native_batch_norm_legit_functional( + input: TFloat, + weight: Optional[TFloat] = None, + bias: Optional[TFloat] = None, + running_mean: Optional[TFloat] = None, + running_var: Optional[TFloat] = None, + training: bool = False, + momentum: float = 0.9, + eps: float = 1e-05, +) -> Tuple[TFloat, TFloat, TFloat, TFloat, TFloat]: + if weight is None: # Set to 1.0 as default + weight = op.Expand(op.Constant(value_floats=[1.0]), op.Shape(input, start=1, end=2)) + + if bias is None: # Set to 0.0 as default + bias = op.Expand(op.Constant(value_floats=[0.0]), op.Shape(input, start=1, end=2)) + + axes = list(range(len(input.shape))) + axes.pop(1) + axes = op.Constant(value_ints=axes) + if running_mean is None: # Using input mean + running_mean = op.Squeeze(op.ReduceMean(input, axes)) + + if running_var is None: # Using input var + mean = op.ReduceMean(input, axes) + input_sub_mean = op.Sub(input, mean) + sqr_input_sub_mean = op.Mul(input_sub_mean, input_sub_mean) + running_var = op.Squeeze(op.ReduceMean(sqr_input_sub_mean, axes)) + + # Have to split to 2 private functions, because training_function return 3 outputs + # While inference_function return 1 output + if training is True: + norm, mean, var, new_mean, new_var = _aten__native_batch_norm_training_functional_onnx( + input, weight, bias, running_mean, running_var, axes, training, momentum, eps + ) + else: + ( + norm, + mean, + var, + new_mean, + new_var, + ) = _aten__native_batch_norm_inference_functional_onnx( + input, weight, bias, running_mean, running_var, training, momentum, eps + ) + return norm, mean, var, new_mean, new_var + + +@torch_op("aten::_native_batch_norm_legit_functional", private=True) +def _aten__native_batch_norm_training_functional_onnx( + input: TFloat, + weight: TFloat, + bias: TFloat, + running_mean: TFloat, + running_var: TFloat, + axes: INT64, + training: bool, + momentum: float, + eps: float, +) -> Tuple[TFloat, TFloat, TFloat, TFloat, TFloat]: + # Assert(training is True) + norm, running_mean, running_var = op.BatchNormalization( + input, + weight, + bias, + running_mean, + running_var, + epsilon=eps, + momentum=momentum, + training_mode=training, + ) + # Compute var and rstd + mean = op.ReduceMean(input, axes) + input_sub_mean = op.Sub(input, mean) + sqr = 
op.Mul(input_sub_mean, input_sub_mean) + var = op.ReduceMean(sqr, axes, keepdims=False) + rstd = op.Div(1.0, op.Sqrt(var + eps)) + # Get mean again with size = [1, C] + mean = op.ReduceMean(input, axes, keepdims=False) + # NOTE: Fixed to be FLOAT dtype + running_mean = op.Cast(running_mean, to=FLOAT.dtype) + running_var = op.Cast(running_var, to=FLOAT.dtype) + return norm, mean, rstd, running_mean, running_var + + +@torch_op("aten::_native_batch_norm_legit_functional", private=True) +def _aten__native_batch_norm_inference_functional_onnx( + input: TFloat, + weight: TFloat, + bias: TFloat, + running_mean: TFloat, + running_var: TFloat, + training: bool, + momentum: float, + eps: float, +) -> Tuple[TFloat, TFloat, TFloat, TFloat, TFloat]: + # Assert(training is False) + norm = op.BatchNormalization( + input, + weight, + bias, + running_mean, + running_var, + epsilon=eps, + momentum=momentum, + training_mode=training, + ) + # NOTE: mean and var are ommited in inference mode + # Cannot return 2 dup output, so have to do twice with different variable name + empty_mean = op.CastLike(op.Shape(input, start=0, end=0), norm) + empty_var = op.CastLike(op.Shape(input, start=0, end=0), norm) + return norm, empty_mean, empty_var, running_mean, running_var + + def aten_native_batch_norm_backward( grad_out: TensorType, input: TensorType, @@ -5679,7 +5949,7 @@ def _aten_native_group_norm_onnx( norm = op.Reshape(norm, op.Shape(input)) # Using the input weight and bias to do affine # But need to unsqueeze to the target shape for broading cast easy - input_rank = op.Size(op.Shape(input)) + input_rank = Rank(input) axes_unsqueeze = op.Range(1, input_rank - 1, 1) weight_full_shape = op.Unsqueeze(weight, axes_unsqueeze) bias_full_shape = op.Unsqueeze(bias, axes_unsqueeze) @@ -5773,14 +6043,14 @@ def aten_native_norm(self: TensorType, p: float = 2.0) -> TensorType: raise NotImplementedError() -@torch_op(("aten::ne", "aten::ne.Scalar", "aten::ne.Tensor")) +@torch_op(("aten::ne", "aten::ne.Scalar", "aten::ne.Tensor", "_operator::ne")) def aten_ne(self: TReal, other: TReal) -> BOOL: """ne.Tensor(Tensor self, Tensor other) -> Tensor""" return op.Not(op.Equal(self, other)) -@torch_op("aten::neg") +@torch_op(("aten::neg", "_operator::neg")) def aten_neg(self: TReal) -> TReal: """neg(Tensor self) -> Tensor""" @@ -5935,8 +6205,7 @@ def aten_normal( ) -> TFloat: # type: ignore[type-var] """normal_functional(Tensor self, float mean=0, float std=1, *, Generator? 
generator=None) -> Tensor""" - self_rank = op.Size(op.Shape(self)) - if self_rank == 0: + if IsScalar(self): self = op.Reshape(self, op.Constant(value_ints=[-1])) result = op.RandomNormalLike(self, mean=mean, scale=std) @@ -6160,7 +6429,9 @@ def aten_positive(self: TensorType) -> TensorType: raise NotImplementedError() -@torch_op(("aten::pow", "aten::pow.Tensor_Tensor", "aten::pow.Tensor_Scalar")) +@torch_op( + ("aten::pow", "aten::pow.Tensor_Tensor", "aten::pow.Tensor_Scalar", "_operator::pow") +) def aten_pow(self: TReal, exponent: TTensor) -> TReal: """pow(Tensor self, Tensor exponent) -> Tensor""" @@ -6853,26 +7124,33 @@ def aten_rsqrt(self: TFloatOrBFloat16) -> TFloatOrBFloat16: @torch_op(("aten::rsub", "aten::rsub.Scalar")) def aten_rsub(self: TReal, other: TReal, alpha: float = 1.0) -> TReal: """rsub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor""" - # FIXME(titaiwang): get rid of this when we have type_promotion - other = op.CastLike(other, self) - alpha = op.CastLike(alpha, self) + return op.Sub(other, op.Mul(self, alpha)) -@torch_op("aten::scalar_tensor") -def aten_scalar_tensor(s: float, dtype: int = FLOAT.dtype) -> TTensor: # type: ignore[type-var] +@torch_op(("aten::rsub", "aten::rsub.Scalar"), trace_only=True, complex=True) +def aten_rsub_complex(self: TReal, other: TReal, alpha: float = 1.0) -> TReal: + """rsub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor""" + + return aten_rsub(self, other, alpha) + + +@torch_op("aten::scalar_tensor", trace_only=True) +def aten_scalar_tensor(s: float, dtype: int = FLOAT.dtype) -> RealType: """scalar_tensor(Scalar s, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor""" - return op.Cast(s, to=dtype) + # Set trace_only=True because different if branches return different dtypes + # which is not supported in an ONNX function + return common_ops.cast_to(s, dtype=dtype) -@torch_op("aten::scalar_tensor") -def aten_scalar_tensor_sym_number( - s: Union[FLOAT, INT32, BOOL], dtype: int = FLOAT.dtype -) -> TTensor: +@torch_op("aten::scalar_tensor", trace_only=True) +def aten_scalar_tensor_sym_number(s: RealType, dtype: int = FLOAT.dtype) -> RealType: """scalar_tensor(Scalar s, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor""" - return op.Cast(s, to=dtype) + # Set trace_only=True because different if branches return different dtypes + # which is not supported in an ONNX function + return common_ops.cast_to(s, dtype=dtype) @torch_op("aten::scatter_add") @@ -6918,14 +7196,14 @@ def _aten_scatter_reduce_onnx( dim: int, onnx_reduce: str, ): - self_rank = op.Size(op.Shape(self)) - if self_rank == 0: # assert (index_rank == 0 and rank_src == 0) + self_is_scalar = IsScalar(self) + if self_is_scalar: # assert (index_rank == 0 and rank_src == 0) neg_1 = op.Constant(value_ints=[-1]) self = op.Reshape(self, neg_1) index = op.Reshape(index, neg_1) src = op.Reshape(src, neg_1) result = op.ScatterElements(self, index, src, axis=dim, reduction=onnx_reduce) - if self_rank == 0: + if self_is_scalar: result = op.Squeeze(result) return result @@ -7165,7 +7443,7 @@ def aten_softmax( ) -> TFloatOrBFloat16: """softmax(Tensor self, int dim, ScalarType? 
dtype=None) -> Tensor""" - self_is_scalar = op.Size(op.Shape(self)) == 0 + self_is_scalar = IsScalar(self) if self_is_scalar: self = op.Unsqueeze(self, op.Constant(value_ints=[0])) result = op.Softmax(self, axis=dim) @@ -7181,7 +7459,7 @@ def aten_softmax( def aten_softmax_no_dtype(self: TFloatOrBFloat16, dim: int) -> TFloatOrBFloat16: """softmax(Tensor self, int dim, ScalarType? dtype=None) -> Tensor""" - self_is_scalar = op.Size(op.Shape(self)) == 0 + self_is_scalar = IsScalar(self) if self_is_scalar: self = op.Unsqueeze(self, op.Constant(value_ints=[0])) result = op.Softmax(self, axis=dim) @@ -7263,7 +7541,7 @@ def aten_squeeze(self: TTensor) -> TTensor: @torch_op("aten::squeeze.dim") def aten_squeeze_dim(self: TTensor, dim: int) -> TTensor: result = self - if op.Size(op.Shape(self)) > 0: # type: ignore[operator] + if Rank(self) > 0: # type: ignore[operator] # check if specified dimension is 1, do squeeze shape = op.Shape(self) dim_size = op.Gather(shape, dim, axis=0) @@ -7274,6 +7552,15 @@ def aten_squeeze_dim(self: TTensor, dim: int) -> TTensor: return result +@torch_op("aten::squeeze.dim", complex=True, trace_only=True) +def aten_squeeze_dim_complex(self: TTensor, dim: int) -> TTensor: + if dim < 0: + # Account for the complex dimension in ONNX + dim = dim - 1 + + return aten_squeeze_dim(self, dim) + + def aten_squeeze_copy(self: TensorType) -> TensorType: """squeeze_copy(Tensor self) -> Tensor""" @@ -7308,8 +7595,7 @@ def aten_std_mean(self: TensorType, unbiased: bool = True) -> tuple[TensorType, @torch_op("aten::stft", private=True) def _add_batch_dimension(self: TFloatOrBFloat16) -> Tuple[TFloatOrBFloat16, INT64]: - signal_shape = op.Shape(self) - signal_rank = op.Size(signal_shape) + signal_rank = Rank(self) if signal_rank == 1: # Add a batch dimension self = op.Unsqueeze(self, op.Constant(value_ints=[0])) @@ -7321,7 +7607,7 @@ def _center_window_around_zeros_if_needed( window: TFloatOrBFloat16, n_fft: int ) -> TFloatOrBFloat16: # first dimension - n_win = op.Gather(op.Shape(window), 0) + n_win = op.Shape(window, start=0, end=1) # Center window around zeros if needed (required by ONNX's STFT) if n_win < n_fft: left = (n_fft - n_win) / 2 @@ -7440,7 +7726,7 @@ def aten_stft( return result -@torch_op(("aten::sub", "aten::sub.Tensor")) +@torch_op(("aten::sub", "aten::sub.Tensor", "aten::subtract", "_operator::sub")) def aten_sub(self: TReal, other: TReal, alpha: float = 1.0) -> TReal: """sub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor""" alpha = op.CastLike(alpha, other) @@ -7449,10 +7735,15 @@ def aten_sub(self: TReal, other: TReal, alpha: float = 1.0) -> TReal: return op.Sub(self, other) -def aten_subtract(self: TensorType, other: TensorType, alpha: float = 1.0) -> TensorType: - """subtract.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor""" +@torch_op( + ("aten::sub", "aten::sub.Tensor", "aten::subtract", "_operator::sub"), + trace_only=True, + complex=True, +) +def aten_sub_complex(self: TReal, other: TReal, alpha: float = 1.0) -> TReal: + """sub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor""" - raise NotImplementedError() + return aten_sub(self, other, alpha=alpha) @torch_op(("aten::sum", "aten::sum.dim_IntList"), trace_only=True) @@ -7478,11 +7769,11 @@ def aten_sum_dim_IntList( @torch_op("aten::sum", private=True) def _aten_sum_dim_onnx(self: TReal, dim: INT64, keepdim: bool = False) -> TReal: - self_is_scalar = op.Size(op.Shape(self)) == 0 + self_is_scalar = IsScalar(self) if self_is_scalar: self = op.Reshape(self, 
op.Constant(value_ints=[-1])) - if op.Size(op.Shape(dim)) == 0: + if IsScalar(dim): dim = op.Reshape(dim, op.Constant(value_ints=[-1])) dim = op.Cast(dim, to=INT64.dtype) result = op.ReduceSum(self, dim, keepdims=keepdim) @@ -7494,7 +7785,7 @@ def _aten_sum_dim_onnx(self: TReal, dim: INT64, keepdim: bool = False) -> TReal: @torch_op("aten::sum", private=True) def _aten_sum_dim_none(self: TReal, keepdim: bool = False) -> TReal: - self_is_scalar = op.Size(op.Shape(self)) == 0 + self_is_scalar = IsScalar(self) if self_is_scalar: self = op.Reshape(self, op.Constant(value_ints=[-1])) @@ -7564,7 +7855,7 @@ def aten_symeig( def aten_t(self: TTensor) -> TTensor: """t(Tensor(a) self) -> Tensor(a)""" - rank = op.Size(op.Shape(self)) + rank = Rank(self) if rank == 2: result = op.Transpose(self, perm=[1, 0]) else: @@ -7651,8 +7942,7 @@ def aten_threshold_backward( def aten_tile(self: TTensor, dims: INT64) -> TTensor: """tile(Tensor self, int[] dims) -> Tensor""" - self_shape = op.Shape(self) - self_rank = op.Size(self_shape) + self_rank = Rank(self) dims_rank = op.Size(dims) diff = op.Sub(self_rank, dims_rank) @@ -7668,6 +7958,7 @@ def aten_tile(self: TTensor, dims: INT64) -> TTensor: # pad self.shape with 1 diff_1d = op.Reshape(op.Abs(diff), op.Constant(value_ints=[1])) exapnd_ones = op.Expand(op.Constant(value_ints=[1]), diff_1d) + self_shape = op.Shape(self) self_final_shape = op.Concat(exapnd_ones, self_shape, axis=0) self = op.Reshape(self, self_final_shape) @@ -7742,7 +8033,7 @@ def aten_topk( ) -> Tuple[TReal, INT64]: """topk(Tensor self, int k, int dim=-1, bool largest=True, bool sorted=True) -> (Tensor values, Tensor indices)""" - self_is_scalar = op.Size(op.Shape(self)) == 0 + self_is_scalar = IsScalar(self) if self_is_scalar: self = op.Unsqueeze(self, op.Constant(value_ints=[0])) k = op.Reshape(op.Cast(k, to=INT64.dtype), op.Constant(value_ints=[1])) @@ -7785,6 +8076,32 @@ def aten_transpose(self, dim0: int, dim1: int): return result +@torch_op(("aten::transpose", "aten::transpose.int"), trace_only=True, complex=True) +def aten_transpose_complex(self, dim0: int, dim1: int): + """transpose.int(Tensor(a) self, int dim0, int dim1) -> Tensor(a)""" + + # Use trace only to construct the prem attribute in Transpose + self_rank = len(self.shape) # type: ignore[attr-defined] + + if self_rank == 0: + result = self + else: + # Python code, change when onnxscript supports this + # Handle when dim0 or dim1 is negative. ONNX uses the last axis to + # represent to complex axis so we need to move the dim one axis toward the start. 
+ if dim0 < 0: + dim0 = dim0 - 1 + if dim1 < 0: + dim1 = dim1 - 1 + dims = list(range(self_rank)) + dims[dim0], dims[dim1] = dims[dim1], dims[dim0] + # Python code ends + + result = op.Transpose(self, perm=dims) + + return result + + def aten_triangular_solve( self: TensorType, A: TensorType, @@ -7894,6 +8211,9 @@ def aten_unfold(self: TTensor, dimension: int, size: int, step: int) -> TTensor: if self_rank == 0: result = op.Unsqueeze(self, 0) else: + # Handle negative dimension + if dimension < 0: + dimension = dimension + self_rank dim_size = self.shape[dimension] target_end = (dim_size - size) // step + 1 if target_end >= 1: # the rank of final reuslt will be self_rank + 1 @@ -8053,16 +8373,15 @@ def aten_var_mean(self: TReal, unbiased: bool = True) -> Tuple[TReal, TReal]: @torch_op("aten::var_mean.dim", trace_only=True) def aten_var_mean_dim( - self: TReal, dim: Optional[int], unbiased: bool = True, keepdim: bool = False + self: TReal, dim: int, unbiased: bool = True, keepdim: bool = False ) -> Tuple[TReal, TReal]: """var_mean.dim(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False) -> (Tensor, Tensor)""" - # Although dim is Optional in signature, but we assume it must has value for this overload + # Although dim is Optional in signature, but we assume it must have value for this overload # Assert(dim is not None) - if isinstance(dim, Tuple): - dim_tensor = op.Constant(value_ints=dim) - else: - dim_tensor = op.Constant(value_int=dim) + if isinstance(dim, int): + dim = (dim,) + dim_tensor = op.Constant(value_ints=dim) return _aten_var_mean_dim_onnx( self, dim_tensor, correction=float(unbiased), keepdim=keepdim ) @@ -8083,10 +8402,9 @@ def aten_var_mean_correction( if dim is None: var, mean = _aten_var_mean_onnx(self, correction, keepdim) else: - if isinstance(dim, Tuple): - dim_tensor = op.Constant(value_ints=dim) - else: - dim_tensor = op.Constant(value_int=dim) + if isinstance(dim, int): + dim = (dim,) + dim_tensor = op.Constant(value_ints=dim) var, mean = _aten_var_mean_dim_onnx(self, dim_tensor, correction, keepdim) return var, mean @@ -8147,6 +8465,15 @@ def aten_view(self: TTensor, size: IntType) -> TTensor: return op.Reshape(self, size) +@torch_op("aten::view", complex=True) +def aten_view_complex(self: TTensor, size: IntType) -> TTensor: + """view(Tensor(a) self, SymInt[] size) -> Tensor(a)""" + + size = op.Cast(size, to=INT64.dtype) # Reshape only support INT64 as second input + complex_size = op.Concat(size, op.Constant(value_ints=[2]), axis=0) + return op.Reshape(self, complex_size) + + @torch_op("aten::view_as") def aten_view_as(self: TTensor, other: TTensor2) -> TTensor: """view_as(Tensor(a) self, Tensor other) -> Tensor(a)""" diff --git a/onnxscript/function_libs/torch_lib/ops/fft.py b/onnxscript/function_libs/torch_lib/ops/fft.py index 807560ee9..cf8ec866b 100644 --- a/onnxscript/function_libs/torch_lib/ops/fft.py +++ b/onnxscript/function_libs/torch_lib/ops/fft.py @@ -13,9 +13,165 @@ from typing import Optional, Sequence +from onnxscript import INT64 +from onnxscript.function_libs.torch_lib.registration import torch_op +from onnxscript.function_libs.torch_lib.tensor_typing import TFloat +from onnxscript.onnx_opset import opset18 as op from onnxscript.onnx_types import TensorType +@torch_op( + ("aten::_fft_c2c", "aten::_fft_c2r", "aten::_fft_r2c"), + private=True, + complex=True, +) +def _fftn_onnx_normalization( + self, + transformed: TFloat, + normalization: int, + forward: bool, + dims: Sequence[int], +) -> TFloat: + # Obtain the total_sample_count (n) for 
normalization + self_shape = op.Shape(self) + total_sample_count = op.ReduceProd(self_shape[dims], keepdims=0) + total_sample_count = op.CastLike(total_sample_count, transformed) + + # Normalize the result + # Reference https://pytorch.org/docs/stable/generated/torch.fft.fftn.html#torch.fft.fftn + # Reference https://github.com/pytorch/pytorch/blob/d090c18fcaaba6e1b5cb474a89058cf6081c8275/torch/_refs/fft.py#L42 + if normalization == 1: + # "ortho" - normalize by 1/sqrt(n) + if forward: + result = op.Div(transformed, op.Sqrt(total_sample_count)) + else: + result = op.Mul(transformed, op.Sqrt(total_sample_count)) + elif normalization == 2: + # "forward" - normalize by 1/n + if forward: + result = op.Div(transformed, total_sample_count) + else: + result = transformed + else: + # "backward" - no normalization + if forward: + result = transformed + else: + result = op.Mul(transformed, total_sample_count) + + return result + + +@torch_op( + ("aten::_fft_c2c", "aten::_fft_c2r", "aten::_fft_r2c"), + trace_only=True, + private=True, + complex=True, +) +def _fftn_onnx( + self: TFloat, dims: Sequence[int], normalization: int, inverse: bool, onesided: bool +) -> TFloat: + """Standard complex to complex or real to complex FFT (forward or backward). + + This is a private shared function for implementing the various FFT functions. + + Args: + self: The input tensor. + dims: The dimensions over which to apply the FFT. + normalization: The normalization mode. + inverse: Whether to compute the inverse FFT. + onesided: Whether to compute the one-sided FFT, which retains only the + positive frequencies. + + Returns: + The transformed tensor. + """ + + # NOTE: trace_only because we need to process each dimension in a loop + # NOTE: SymInt dim is not supported because DFT-17 needs a static axis + # TODO(justinchuby): Make dim dynamic and remove trace_only when ONNX provides support + + # The 0-th dimension in ONNX DFT-17 is the batch dimension. We need to add a new + # dimension at the beginning to represent the batch dimension. + transformed = op.Unsqueeze(self, axes=[0]) + + for dim_ in dims: + if dim_ >= 0: + # Add 1 to account for the batch dimension when counting axes from the left + dim_ = dim_ + 1 + transformed = op.DFT(transformed, axis=dim_, inverse=inverse, onesided=onesided) + # Remove the batch dimension + transformed = op.Squeeze(transformed, axes=[0]) + + return _fftn_onnx_normalization(self, transformed, normalization, not inverse, dims) + + +@torch_op("aten::_fft_c2c", trace_only=True, complex=True) +def aten__fft_c2c( + self: TFloat, dim: Sequence[int], normalization: int, forward: bool +) -> TFloat: + """_fft_c2c(Tensor self, SymInt[] dim, int normalization, bool forward) -> Tensor + + Standard complex to complex FFT (forward or backward). + """ + + # NOTE: trace_only because we need to negate forward + # NOTE: SymInt dim is not supported because DFT-17 needs a static axis + # TODO(justinchuby): Make dim dynamic and remove trace_only when ONNX provides support + + # ONNX DFT input assumes the last dimension is the complex dimension. + # Thus dim=-1 in PyTorch is dim=-2 in ONNX. 
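+ # For example, a complex tensor of logical rank 3 is stored in ONNX as shape [..., 2], + # so PyTorch dim=-1 (the last data axis) corresponds to ONNX axis -2; the list + # comprehension below shifts negative dims by one to account for that trailing axis.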
+ dim = [d - 1 if d < 0 else d for d in dim] + return _fftn_onnx(self, dim, normalization, inverse=not forward, onesided=False) + + +@torch_op("aten::_fft_c2r", trace_only=True, complex=True) +def aten__fft_c2r( + self: TFloat, + dim: Sequence[int], + normalization: int, + last_dim_size: INT64, # pylint: disable=unused-argument +) -> TFloat: + """_fft_c2r(Tensor self, int[] dim, int normalization, SymInt last_dim_size) -> Tensor + + Complex to real inverse FFT. + """ + + # TODO(justinchuby): Figure out what last_dim_size does + + self_rank = len(self.shape) + # ONNX DFT input assumes the last dimension is the complex dimension. + # Thus dim=-1 in PyTorch is dim=-2 in ONNX. + dim = [(d - 1) + self_rank if d < 0 else d for d in dim] + transformed = _fftn_onnx(self, dim, normalization, inverse=True, onesided=False) + # Take only the real part + real_part = op.Slice(transformed, axes=[-1], starts=[0], ends=[1]) + + return op.Squeeze(real_part, axes=[-1]) + + +@torch_op("aten::_fft_r2c", trace_only=True) +def aten__fft_r2c( + self: TFloat, dim: Sequence[int], normalization: int, onesided: bool +) -> TFloat: + """_fft_r2c(Tensor self, int[] dim, int normalization, bool onesided) -> Tensor + + Real to complex forward FFT. + """ + + # Add a new dimension at the end + signal = op.Unsqueeze(self, axes=[-1]) + # No need to fill the imaginary part because ONNX DFT accepts real inputs + # https://onnx.ai/onnx/operators/onnx__DFT.html#inputs + + self_rank = len(self.shape) + # ONNX DFT input assumes the last dimension is the complex dimension. + # Thus dim=-1 in PyTorch is dim=-2 in ONNX. + dim = [(d - 1) + self_rank if d < 0 else d for d in dim] + + return _fftn_onnx(signal, dim, normalization, inverse=False, onesided=onesided) + + def aten_fft_fft( self: TensorType, n: Optional[int] = None, dim: int = -1, norm: Optional[str] = None ) -> TensorType: diff --git a/onnxscript/function_libs/torch_lib/ops/linalg.py b/onnxscript/function_libs/torch_lib/ops/linalg.py index 81eb5fcd6..48d4f6027 100644 --- a/onnxscript/function_libs/torch_lib/ops/linalg.py +++ b/onnxscript/function_libs/torch_lib/ops/linalg.py @@ -14,11 +14,14 @@ from typing import Optional, Sequence from onnxscript import BOOL, FLOAT, INT64 +from onnxscript.function_libs.torch_lib.ops import common as common_ops from onnxscript.function_libs.torch_lib.registration import torch_op from onnxscript.function_libs.torch_lib.tensor_typing import TFloat from onnxscript.onnx_opset import opset18 as op from onnxscript.onnx_types import TensorType +IsScalar = common_ops.IsScalar + def aten_linalg_cholesky(self: TensorType, upper: bool = False) -> TensorType: """linalg_cholesky(Tensor self, *, bool upper=False) -> Tensor""" @@ -46,10 +49,11 @@ def aten_linalg_cross(self: TensorType, other: TensorType, dim: int = -1) -> Ten raise NotImplementedError() -def aten_linalg_det(A: TensorType) -> TensorType: +@torch_op(("aten::linalg_det", "aten::det")) +def aten_linalg_det(A: TFloat) -> TFloat: """linalg_det(Tensor A) -> Tensor""" - raise NotImplementedError() + return op.Det(A) def aten_linalg_diagonal( @@ -331,8 +335,8 @@ def aten_linalg_vector_norm( @torch_op("aten::linalg_vector_norm", private=True) def _aten_linalg_vector_norm_no_dim_onnx(self: TFloat, ord: float, keepdim: bool) -> TFloat: - self_rank = op.Size(op.Shape(self)) - if self_rank == 0: + self_is_scalar = IsScalar(self) + if self_is_scalar: self = op.Unsqueeze(self, axes=[0]) self = op.Abs(self) @@ -345,12 +349,13 @@ def _aten_linalg_vector_norm_no_dim_onnx(self: TFloat, ord: float, keepdim: 
bool self_bool = op.Cast(self, to=BOOL.dtype) self_0_1 = op.CastLike(self_bool, self) result = op.ReduceSum(self_0_1, keepdims=False) + # TODO(microsoft/onnxruntime#18338): Use ReduceL1/L2 when ONNX Runtime is fixed else: ord_float = op.CastLike(ord, self) self_pow = op.Pow(self, ord_float) result = op.Pow(op.ReduceSum(self_pow, keepdims=keepdim), op.Div(1.0, ord_float)) - if self_rank == 0: + if self_is_scalar: result = op.Squeeze(result) return result @@ -360,8 +365,8 @@ def _aten_linalg_vector_norm_no_dim_onnx(self: TFloat, ord: float, keepdim: bool def _aten_linalg_vector_norm_onnx( self: TFloat, ord: float, dim: INT64, keepdim: bool ) -> TFloat: - self_rank = op.Size(op.Shape(self)) - if self_rank == 0: + self_is_scalar = IsScalar(self) + if self_is_scalar: self = op.Unsqueeze(self, axes=[0]) dim = op.Reshape(dim, op.Constant(value_ints=[-1])) @@ -375,12 +380,16 @@ def _aten_linalg_vector_norm_onnx( self_bool = op.Cast(self, to=BOOL.dtype) self_0_1 = op.CastLike(self_bool, self) result = op.ReduceSum(self_0_1, dim, keepdims=keepdim) + elif ord == 1.0: + result = op.ReduceL1(self, dim, keepdims=keepdim) + elif ord == 2.0: + result = op.ReduceL2(self, dim, keepdims=keepdim) else: ord_float = op.CastLike(ord, self) self_pow = op.Pow(self, ord_float) result = op.Pow(op.ReduceSum(self_pow, dim, keepdims=keepdim), op.Div(1.0, ord_float)) - if self_rank == 0: + if self_is_scalar: result = op.Squeeze(result) return result diff --git a/onnxscript/function_libs/torch_lib/ops/nn.py b/onnxscript/function_libs/torch_lib/ops/nn.py index 9f819482a..d4d28059e 100644 --- a/onnxscript/function_libs/torch_lib/ops/nn.py +++ b/onnxscript/function_libs/torch_lib/ops/nn.py @@ -20,6 +20,7 @@ import onnx from onnxscript import FLOAT, INT64 +from onnxscript.function_libs.torch_lib.ops import common as common_ops from onnxscript.function_libs.torch_lib.registration import torch_op from onnxscript.function_libs.torch_lib.tensor_typing import ( IntType, @@ -33,6 +34,7 @@ from onnxscript.onnx_types import BOOL, TensorType _MATH_PI = math.pi +Rank = common_ops.Rank @torch_op("aten::aten_adaptive_avg_pool1d") @@ -42,7 +44,7 @@ def aten_adaptive_avg_pool1d(self: TFloat, output_size: INT64[1]) -> TFloat: # assert output_size == [1] # TODO(justinchuby): Specify input constraints - if op.Size(op.Shape(self)) == 2: + if Rank(self) == 2: # Unbatched case self = op.Unsqueeze(self, op.Constant(value_ints=[0])) pooled = op.GlobalAveragePool(self) @@ -60,7 +62,7 @@ def aten_adaptive_avg_pool2d(self: TFloat, output_size: INT64[2]) -> TFloat: # assert output_size == [1, 1] # TODO(justinchuby): Specify input constraints - if op.Size(op.Shape(self)) == 3: + if Rank(self) == 3: # Unbatched case self = op.Unsqueeze(self, op.Constant(value_ints=[0])) pooled = op.GlobalAveragePool(self) @@ -78,7 +80,7 @@ def aten_adaptive_avg_pool3d(self: TFloat, output_size: INT64[3]) -> TFloat: # assert output_size == [1, 1, 1] # TODO(justinchuby): Specify input constraints - if op.Size(op.Shape(self)) == 4: + if Rank(self) == 4: # Unbatched case self = op.Unsqueeze(self, op.Constant(value_ints=[0])) pooled = op.GlobalAveragePool(self) @@ -864,8 +866,8 @@ def _aten_max_pool_onnx( ceil_mode: bool, unbatched_rank: int, ) -> TFloatOrUInt8: - self_rank = op.Size(op.Shape(self)) - if self_rank == unbatched_rank: # C,H,W -> N,C,H,W and N=1 + self_rank_is_unbatched_rank = Rank(self) == unbatched_rank + if self_rank_is_unbatched_rank: # C,H,W -> N,C,H,W and N=1 self = op.Unsqueeze(self, op.Constant(value_ints=[0])) pool_result, _ = op.MaxPool( @@ -877,7 
+879,7 @@ def _aten_max_pool_onnx( strides=strides, ) - if self_rank == unbatched_rank: + if self_rank_is_unbatched_rank: pool_result = op.Squeeze(pool_result, op.Constant(value_ints=[0])) return pool_result @@ -999,8 +1001,8 @@ def _aten_max_pool_with_indices_onnx( n_dims_zero: Sequence[int], n_dims_axes: Sequence[int], ) -> Tuple[TFloatOrUInt8, INT64]: - self_rank = op.Size(op.Shape(self)) - if self_rank == unbatched_rank: + self_rank_is_unbatched_rank = Rank(self) == unbatched_rank + if self_rank_is_unbatched_rank: self = op.Unsqueeze(self, axes=0) pool_result, indices = op.MaxPool( @@ -1055,7 +1057,7 @@ def _aten_max_pool_with_indices_onnx( delta = op.Slice(flatten_indices, axes=axes, starts=starts, ends=ends) indices = op.Sub(indices, delta) - if self_rank == unbatched_rank: + if self_rank_is_unbatched_rank: pool_result = op.Squeeze(pool_result, op.Constant(value_ints=[0])) indices = op.Squeeze(indices, op.Constant(value_ints=[0])) @@ -1227,8 +1229,8 @@ def aten_nll_loss( ) -> TFloat: """nll_loss(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100) -> Tensor""" - rank_self = op.Size(op.Shape(self)) - if rank_self == 1: # self rank should be at least 2 + self_rank_is_1 = Rank(self) == 1 + if self_rank_is_1: # self rank should be at least 2 self = op.Unsqueeze(self, op.Constant(value_ints=[0])) rank_target = op.Size(op.Shape(target)) @@ -1248,7 +1250,7 @@ def aten_nll_loss( self, target, ignore_index=ignore_index, reduction="sum" ) - if rank_self == 1: + if self_rank_is_1: result = op.Squeeze(result) return result @@ -1264,8 +1266,8 @@ def aten_nll_loss_weight( ) -> TFloat: """nll_loss(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100) -> Tensor""" - rank_self = op.Size(op.Shape(self)) - if rank_self == 1: # self rank should be at least 2 + self_rank_is_1 = Rank(self) == 1 + if self_rank_is_1: # self rank should be at least 2 self = op.Unsqueeze(self, op.Constant(value_ints=[0])) rank_target = op.Size(op.Shape(target)) @@ -1285,7 +1287,7 @@ def aten_nll_loss_weight( self, target, weight, ignore_index=ignore_index, reduction="sum" ) - if rank_self == 1: + if self_rank_is_1: result = op.Squeeze(result) return result @@ -1415,7 +1417,7 @@ def aten_reflection_pad2d(self: TTensor, padding: INT64) -> TTensor: neg_1 = op.Constant(value_ints=[-1]) zero = op.Constant(value_ints=[0]) # [0] * (rank * 2 - len(padding)) - rank = op.Size(op.Shape(self)) + rank = Rank(self) zero_count = op.Reshape(op.Sub(op.Mul(rank, 2), op.Size(padding)), neg_1) zeros = op.Expand(zero, zero_count) # list(padding[:]) + [0] * (dim * 2 - len(padding)) @@ -1494,7 +1496,7 @@ def aten_replication_pad2d(self: TTensor, padding: INT64) -> TTensor: neg_1 = op.Constant(value_ints=[-1]) zero = op.Constant(value_ints=[0]) # [0] * (rank * 2 - len(padding)) - rank = op.Size(op.Shape(self)) + rank = Rank(self) zero_count = op.Reshape(op.Sub(op.Mul(rank, 2), op.Size(padding)), neg_1) zeros = op.Expand(zero, zero_count) # list(padding[:]) + [0] * (dim * 2 - len(padding)) @@ -1530,7 +1532,7 @@ def aten_replication_pad3d(self: TTensor, padding: INT64) -> TTensor: neg_1 = op.Constant(value_ints=[-1]) zero = op.Constant(value_ints=[0]) # [0] * (rank * 2 - len(padding)) - rank = op.Size(op.Shape(self)) + rank = Rank(self) zero_count = op.Reshape(op.Sub(op.Mul(rank, 2), op.Size(padding)), neg_1) zeros = op.Expand(zero, zero_count) # list(padding[:]) + [0] * (dim * 2 - len(padding)) @@ -2036,10 +2038,13 @@ def aten_soft_margin_loss_backward( raise 
NotImplementedError() -def aten_softplus(self: TensorType, beta: float = 1.0, threshold: float = 20.0) -> TensorType: +@torch_op("aten::softplus") +def aten_softplus(self: TFloat, beta: float = 1.0, threshold: float = 20.0) -> TFloat: """softplus(Tensor self, Scalar beta=1, Scalar threshold=20) -> Tensor""" - raise NotImplementedError() + self_scaled = self * beta + softplus = op.Softplus(self_scaled) / beta + return op.Where(self_scaled > threshold, self, softplus) def aten_softplus_backward( diff --git a/onnxscript/function_libs/torch_lib/ops/prims.py b/onnxscript/function_libs/torch_lib/ops/prims.py index ec4f74bfd..2b465cacc 100644 --- a/onnxscript/function_libs/torch_lib/ops/prims.py +++ b/onnxscript/function_libs/torch_lib/ops/prims.py @@ -14,8 +14,9 @@ from typing import Optional, Sequence from onnxscript import INT64 +from onnxscript.function_libs.torch_lib.ops import common as common_ops from onnxscript.function_libs.torch_lib.registration import torch_op -from onnxscript.function_libs.torch_lib.tensor_typing import TTensor +from onnxscript.function_libs.torch_lib.tensor_typing import RealType, TTensor from onnxscript.onnx_opset import opset18 as op from onnxscript.onnx_types import TensorType @@ -216,11 +217,13 @@ def prims_conj_physical(self: TensorType) -> TensorType: raise NotImplementedError() -@torch_op("prims::convert_element_type") -def prims_convert_element_type(a: TensorType, dtype: int) -> TensorType: +@torch_op("prims::convert_element_type", trace_only=True) +def prims_convert_element_type(a: RealType, dtype: int) -> RealType: """convert_element_type(Tensor a, ScalarType dtype) -> Tensor""" - return op.Cast(a, to=dtype) + # Set trace_only=True because different if branches return different dtypes + # which is not supported in an ONNX function + return common_ops.cast_to(a, dtype) def prims_copy_strided(a: TensorType, stride: INT64) -> TensorType: diff --git a/onnxscript/function_libs/torch_lib/ops/special.py b/onnxscript/function_libs/torch_lib/ops/special.py index 807969934..1c273e2e0 100644 --- a/onnxscript/function_libs/torch_lib/ops/special.py +++ b/onnxscript/function_libs/torch_lib/ops/special.py @@ -14,11 +14,14 @@ from typing import Optional, Sequence from onnxscript import FLOAT +from onnxscript.function_libs.torch_lib.ops import common as common_ops from onnxscript.function_libs.torch_lib.registration import torch_op from onnxscript.function_libs.torch_lib.tensor_typing import TFloatOrBFloat16 from onnxscript.onnx_opset import opset18 as op from onnxscript.onnx_types import TensorType +IsScalar = common_ops.IsScalar + def aten_special_airy_ai(x: TensorType) -> TensorType: """special_airy_ai(Tensor x) -> Tensor""" @@ -215,7 +218,7 @@ def aten_special_log_softmax( ) -> TFloatOrBFloat16: """special_log_softmax(Tensor self, int dim, *, ScalarType? 
dtype=None) -> Tensor""" - self_is_scalar = op.Size(op.Shape(self)) == 0 + self_is_scalar = IsScalar(self) if self_is_scalar: self = op.Unsqueeze(self, op.Constant(value_ints=[0])) result = op.LogSoftmax(self, axis=dim) diff --git a/onnxscript/function_libs/torch_lib/registration.py b/onnxscript/function_libs/torch_lib/registration.py index a2267d70f..f9c9a9fc7 100644 --- a/onnxscript/function_libs/torch_lib/registration.py +++ b/onnxscript/function_libs/torch_lib/registration.py @@ -7,6 +7,7 @@ from typing import Any, Callable, Generator, Optional import onnxscript +from onnxscript.function_libs.torch_lib import _constants # Regex that will match "::[.]" _QUALIFIED_OPERATOR_NAME_REGEX = re.compile( @@ -110,7 +111,7 @@ def torch_op( trace_only: Whether the function should only be traced and not compiled. private: Whether the function is private (not directly exposed). It should be true for all functions with names starting with "_". - complex: Whether the function supports complex. + complex: Whether the function expects complex-valued inputs. """ if registry is None: registry = default_registry @@ -119,7 +120,7 @@ def wrapper( func: FunctionType, ) -> onnxscript.OnnxFunction | onnxscript.values.TracedOnnxFunction: # Compile the function - custom_opset = onnxscript.values.Opset(domain="pkg.onnxscript.torch_lib", version=1) + custom_opset = onnxscript.values.Opset(domain=_constants.DOMAIN, version=1) processed_func: onnxscript.OnnxFunction | onnxscript.values.TracedOnnxFunction if trace_only: diff --git a/onnxscript/tests/function_libs/torch_lib/error_reproduction.py b/onnxscript/tests/function_libs/torch_lib/error_reproduction.py index 65784cb94..544866646 100644 --- a/onnxscript/tests/function_libs/torch_lib/error_reproduction.py +++ b/onnxscript/tests/function_libs/torch_lib/error_reproduction.py @@ -212,6 +212,8 @@ def create_mismatch_report( expected, error: Exception, ) -> None: + torch.set_printoptions(threshold=sys.maxsize) + error_text = str(error) error_stack = error_text + "\n" + "".join(traceback.format_tb(error.__traceback__)) short_test_name = test_name.split(".")[-1] diff --git a/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py b/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py index 066e87f45..c23f4f4ae 100644 --- a/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py +++ b/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py @@ -190,6 +190,68 @@ def sample_inputs_convolution(op_info, device, dtype, requires_grad, **kwargs): ) +def sample_inputs__fft_c2c(self, device, dtype, requires_grad=False, **_): + del self # Unused + # Adapted from https://github.com/pytorch/pytorch/blob/01069ad4be449f376cf88a56d842b8eb50f6e9b6/torch/testing/_internal/opinfo/core.py#L2448C1-L2541C79 + is_fp16_or_chalf = dtype in (torch.complex32, torch.half) + if not is_fp16_or_chalf: + nd_tensor = functools.partial( + opinfo_core.make_tensor, + (S, S + 1, S + 2), + device=device, + dtype=dtype, + requires_grad=requires_grad, + ) + oned_tensor = functools.partial( + opinfo_core.make_tensor, + (31,), + device=device, + dtype=dtype, + requires_grad=requires_grad, + ) + else: + low = None + high = None + shapes = ((2, 8, 9), (33,)) + + nd_tensor = functools.partial( + opinfo_core.make_tensor, + shapes[0], + device=device, + low=low, + high=high, + dtype=dtype, + requires_grad=requires_grad, + ) + oned_tensor = functools.partial( + opinfo_core.make_tensor, + shapes[1], + device=device, + low=low, + high=high, + dtype=dtype, + requires_grad=requires_grad, + ) + + for normalization, 
forward in itertools.product((0, 1, 2), (True, False)): + # 1-D + yield opinfo_core.SampleInput( + oned_tensor(), dim=(0,), normalization=normalization, forward=forward + ) + # N-D + for dim in [ + (0,), + (1,), + (2,), + (1, 2), + (0, 1), + (0, 1, 2), + ]: + yield opinfo_core.SampleInput( + nd_tensor(), dim=dim, normalization=normalization, forward=forward + ) + + def sample_inputs_layer_norm(op_info, device, dtype, requires_grad, **kwargs): del op_info # unused del kwargs @@ -818,6 +880,36 @@ def sample_inputs_bernoulli_p_deterministic(op_info, device, dtype, requires_gra yield opinfo_core.SampleInput(t, kwargs={"p": p}) +def sample_inputs_embedding_renorm(op_info, device, dtype, requires_grad, **kwargs): + del op_info + del kwargs + + def make_input(shape): + return common_methods_invocations.make_tensor( + shape, device=device, dtype=dtype, requires_grad=requires_grad + ) + + def make_long_input(shape, *, low, high, noncontiguous=False): + return common_methods_invocations.make_tensor( + shape, + device=device, + dtype=torch.long, + low=low, + high=high, + noncontiguous=noncontiguous, + ) + + for max_norm in (0.5, 1.0, 5.0): + for norm_type in (0.8, 1.0, 2.0, 2.5): + idx = make_long_input((6,), low=0, high=S) + weights = make_input((S, S)) * 2 + yield common_methods_invocations.SampleInput( + weights, + args=(idx,), + kwargs={"max_norm": max_norm, "norm_type": norm_type}, + ) + + def sample_inputs_embedding_bag(op_info, device, dtype, requires_grad, **kwargs): del op_info del kwargs @@ -1061,11 +1153,12 @@ def sample_inputs_unfold(op_info, device, dtype, requires_grad, **kwargs): requires_grad=requires_grad, **kwargs, ) - dimension = 1 - size = 2 - step = 2 - # target_end = (3 - 2) // 2 + 1 = 1 - yield opinfo_core.SampleInput(t, args=(dimension, size, step)) + for dimension, size, step in [ + (1, 2, 2), + (-1, 2, 2), + (-2, 2, 2), + ]: + yield opinfo_core.SampleInput(t, args=(dimension, size, step)) def sample_inputs_slice_scatter(op_info, device, dtype, requires_grad, **kwargs): @@ -1197,6 +1290,52 @@ def sample_inputs_scaled_dot_product_flash_attention( yield from samples +# NOTE: In `_native_batch_norm_legit` tests, it generates two kinds of args: +# 1. (input, weight, bias, running_mean, running_var, training, momentum, eps) +# 2. (input, weight, bias, training, momentum, eps) +# which requires two function signatures to take the inputs, that's why we have +# two sample_inputs functions here instead. 
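+# NOTE: The upstream sample_inputs_batch_norm yields positional args ordered as +# (running_mean, running_var, weight, bias), while the _native_batch_norm_legit overloads take +# (weight, bias, running_mean, running_var, ...), so the helpers below reorder `args` accordingly.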
+def sample_inputs__native_batch_norm_legit(op_info, device, dtype, requires_grad, **kwargs): + samples = common_methods_invocations.sample_inputs_batch_norm( + op_info, device, dtype, requires_grad, **kwargs + ) + for sample in samples: + # torch.native_batch_norm does not support 0 numel tensors + # IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1) + if sample.input.numel() == 0: + continue + args = sample.args + training = sample.kwargs.get("training", True) + momentum = sample.kwargs.get("momentum", 0.5) + eps = sample.kwargs.get("eps", 1e-5) + if args[0] is not None and args[1] is not None: + yield opinfo_core.SampleInput( + sample.input, + args=(args[2], args[3], args[0], args[1], training, momentum, eps), + ) + + +def sample_inputs__native_batch_norm_legit_no_stats( + op_info, device, dtype, requires_grad, **kwargs +): + samples = common_methods_invocations.sample_inputs_batch_norm( + op_info, device, dtype, requires_grad, **kwargs + ) + for sample in samples: + # torch.native_batch_norm does not support 0 numel tensors + # IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1) + if sample.input.numel() == 0: + continue + args = sample.args + training = sample.kwargs.get("training", True) + momentum = sample.kwargs.get("momentum", 0.5) + eps = sample.kwargs.get("eps", 1e-5) + if args[0] is not None and args[1] is None: + yield opinfo_core.SampleInput( + sample.input, args=(args[2], args[3], training, momentum, eps) + ) + + # NOTE: How to create an OpInfo: # 1. Create a function that generates sample inputs for the op. # This function should yield SampleInputs. @@ -1212,6 +1351,13 @@ def sample_inputs_scaled_dot_product_flash_attention( # To avoid name duplication, it is possible to rename the OpInfo and specify # the `op` field explicitly. 
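+# Each custom entry below is referenced by its name from TESTED_TORCHLIB_OPS in ops_test_data.py +# (for example, "ops.aten._fft_c2c" is mapped there to fft_ops.aten__fft_c2c).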
OP_DB: List[opinfo_core.OpInfo] = [ + opinfo_core.OpInfo( + "ops.aten._fft_c2c", + aten_name="_fft_c2c", + dtypes=common_dtype.complex_types(), + sample_inputs_func=sample_inputs__fft_c2c, + supports_out=False, + ), opinfo_core.OpInfo( "ops.aten._local_scalar_dense", aten_name="_local_scalar_dense", @@ -1240,6 +1386,13 @@ def sample_inputs_scaled_dot_product_flash_attention( sample_inputs_func=sample_inputs_embedding_bag_padding_idx, supports_out=False, ), + opinfo_core.OpInfo( + "ops.aten.embedding_renorm", + aten_name="embedding_renorm", + dtypes=common_dtype.floating_types_and_half(), + sample_inputs_func=sample_inputs_embedding_renorm, + supports_out=False, + ), opinfo_core.OpInfo( "nn.functional.conv3d", aten_name="conv3d", @@ -1485,8 +1638,7 @@ def sample_inputs_scaled_dot_product_flash_attention( supports_out=False, ), opinfo_core.OpInfo( - "unfold_extra", - op=lambda x, *args: x.unfold(*args), + "ops.aten.unfold", aten_name="unfold", dtypes=common_dtype.all_types(), sample_inputs_func=sample_inputs_unfold, @@ -1527,4 +1679,34 @@ def sample_inputs_scaled_dot_product_flash_attention( supports_fwgrad_bwgrad=True, check_batched_forward_grad=False, ), + opinfo_core.OpInfo( + "ops.aten._native_batch_norm_legit", + aten_name="_native_batch_norm_legit", + dtypes=common_dtype.floating_types_and(torch.bfloat16), + dtypesIfCUDA=common_dtype.floating_types_and(torch.float16, torch.bfloat16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + assert_jit_shape_analysis=True, + sample_inputs_func=sample_inputs__native_batch_norm_legit, + ), + opinfo_core.OpInfo( + "ops.aten._native_batch_norm_legit_functional", + aten_name="_native_batch_norm_legit_functional", + dtypes=common_dtype.floating_types_and(torch.bfloat16), + dtypesIfCUDA=common_dtype.floating_types_and(torch.float16, torch.bfloat16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + assert_jit_shape_analysis=True, + sample_inputs_func=sample_inputs__native_batch_norm_legit, + ), + opinfo_core.OpInfo( + "ops.aten._native_batch_norm_legit.no_stats", + aten_name="_native_batch_norm_legit.no_stats", + dtypes=common_dtype.floating_types_and(torch.bfloat16), + dtypesIfCUDA=common_dtype.floating_types_and(torch.float16, torch.bfloat16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + assert_jit_shape_analysis=True, + sample_inputs_func=sample_inputs__native_batch_norm_legit_no_stats, + ), ] diff --git a/onnxscript/tests/function_libs/torch_lib/ops_test.py b/onnxscript/tests/function_libs/torch_lib/ops_test.py index 33c3602be..47957e1b4 100644 --- a/onnxscript/tests/function_libs/torch_lib/ops_test.py +++ b/onnxscript/tests/function_libs/torch_lib/ops_test.py @@ -184,7 +184,6 @@ def run_test_output_match( # Obtain the tolerance for the op rtol, atol = torchlib_op_info.get_tolerance(dtype) - for i, cpu_sample in enumerate(samples): inputs = (cpu_sample.input, *cpu_sample.args) # Provide the repr to subtest because tensors are not serializable in parallel test runs diff --git a/onnxscript/tests/function_libs/torch_lib/ops_test_common.py b/onnxscript/tests/function_libs/torch_lib/ops_test_common.py index 7a2e5b404..3b6021c5d 100644 --- a/onnxscript/tests/function_libs/torch_lib/ops_test_common.py +++ b/onnxscript/tests/function_libs/torch_lib/ops_test_common.py @@ -460,7 +460,8 @@ def _capture_graph_and_evaluate_torch_script_evaluator(function: Callable, args, input.value = arg onnxscript_args.append(input) ort_inputs[input_name] = arg - elif isinstance(arg, Sequence): + elif isinstance(arg, (list, tuple)): + # 
str is also a sequence but we do not want to treat it as a tensor sequence_input = [] for j, subarg in enumerate(arg): if isinstance(subarg, np.ndarray): diff --git a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py index c23729ee5..c50d75238 100644 --- a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py +++ b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py @@ -47,6 +47,7 @@ from onnxscript._internal import version_utils from onnxscript.function_libs.torch_lib.ops import core as core_ops +from onnxscript.function_libs.torch_lib.ops import fft as fft_ops from onnxscript.function_libs.torch_lib.ops import linalg as linalg_ops from onnxscript.function_libs.torch_lib.ops import nn as nn_ops from onnxscript.function_libs.torch_lib.ops import special as special_ops @@ -79,8 +80,8 @@ class TorchLibOpInfo: nondeterministic: bool = False # Whether to compare the shape only for the output[index] # For example: (1,2) means compare value for output[0] and shape for output[1] and [2] - # We may be able to combine this with the nondeterminstic option - compare_shape_only_for_output: tuple[int] = () + # We may be able to combine this with the nondeterministic option + compare_shape_only_for_output: tuple[int, ...] = () # Whether the function is designed for complex inputs complex: bool = False # The acceptable tolerance of the inference result difference between PyTorch and ORT. @@ -235,6 +236,13 @@ def _dropout_input_wrangler( return args, kwargs +def _einsum_input_wrangler( + args: list[Any], kwargs: dict[str, Any] +) -> tuple[list[Any], dict[str, Any]]: + # Swap the equation and tensors to revert the special handling in the OpInfo + return [args[1], args[0]], kwargs + + def _embedding_input_wrangler( args: list[Any], kwargs: dict[str, Any] ) -> tuple[list[Any], dict[str, Any]]: @@ -450,6 +458,13 @@ def _where_input_wrangler( # Ops to be tested for numerical consistency between onnx and pytorch # Find the names of the OpInfos in torch/testing/_internal/common_methods_invocations.py TESTED_TORCHLIB_OPS: tuple[TorchLibOpInfo, ...] = ( + TorchLibOpInfo( + "ops.aten._fft_c2c", # Custom from extra_opinfo + fft_ops.aten__fft_c2c, + tolerance={torch.complex64: (3e-3, 1.8e-4)}, + trace_only=True, + complex=True, + ), TorchLibOpInfo( "ops.aten._local_scalar_dense", core_ops.aten__local_scalar_dense, @@ -485,10 +500,28 @@ def _where_input_wrangler( TorchLibOpInfo("acos", core_ops.aten_acos), TorchLibOpInfo("acosh", core_ops.aten_acosh), TorchLibOpInfo("add", core_ops.aten_add, tolerance={torch.float16: (1e-3, 1e-3)}), + TorchLibOpInfo("add", core_ops.aten_add_complex, complex=True, trace_only=True), TorchLibOpInfo("addbmm", core_ops.aten_addbmm, tolerance={torch.float32: (2e-5, 2e-5)}), TorchLibOpInfo("addcdiv", core_ops.aten_addcdiv), TorchLibOpInfo("addcmul", core_ops.aten_addcmul, tolerance={torch.float16: (4e-3, 3e-3)}), - TorchLibOpInfo("addmm", core_ops.aten_addmm), + TorchLibOpInfo("addmm", core_ops.aten_addmm) + .xfail( + "decomposed", + reason=( + "The float attributes alpha/beta come in as int in the test cases, which breaks" + "eager mode. 
We don't need to care about this as long as the full graph tests pass" + ), + test_class_name="TestOutputConsistencyEager", + ) + .xfail( + dtypes=(torch.int16, torch.int32, torch.int64), + reason="ONNX Runtime does not support int inputs to Gemm", + ) + .xfail( + "decomposed", + dtypes=(torch.int16, torch.int32, torch.int64), + reason="ONNX Runtime does not support int inputs to Gemm", + ), TorchLibOpInfo("addmv", core_ops.aten_addmv), TorchLibOpInfo( "addr", @@ -683,6 +716,8 @@ def _where_input_wrangler( matcher=lambda sample: sample.kwargs.get("rounding_mode") is not None, reason="this variation does not take the rounding_mode argument", ), + TorchLibOpInfo("true_divide", core_ops.aten_div), + TorchLibOpInfo("true_divide", core_ops.aten_div_complex, complex=True), TorchLibOpInfo("div_mode", core_ops.aten_div_mode, trace_only=True) .skip( variant_name="no_rounding_mode", @@ -711,6 +746,17 @@ def _where_input_wrangler( input_wrangler=_empty_input_wrangler, nondeterministic=True, ), + TorchLibOpInfo( + "einsum", core_ops.aten_einsum, trace_only=True, input_wrangler=_einsum_input_wrangler + ) + .xfail( + reason="fixme: PyTorch produces int64 output with int32 input", + dtypes=(torch.int32,), + ) + .xfail( + reason="fixme: ONNX shape inference fails: https://github.com/onnx/onnx/issues/5739", + matcher=lambda sample: sample.args[0] == "...ik, ...j -> ij", + ), # TorchLibOpInfo("empty_strided", core_ops.aten_empty_strided), # empty_strided is not in OPS_DB TorchLibOpInfo("eq", core_ops.aten_eq), TorchLibOpInfo("equal", core_ops.aten_equal), @@ -778,6 +824,7 @@ def _where_input_wrangler( TorchLibOpInfo("isneginf", core_ops.aten_isneginf), TorchLibOpInfo("isposinf", core_ops.aten_isposinf), TorchLibOpInfo("lift_fresh_copy", core_ops.aten_lift_fresh_copy), + TorchLibOpInfo("linalg.det", linalg_ops.aten_linalg_det), TorchLibOpInfo( "linalg.vector_norm", linalg_ops.aten_linalg_vector_norm, @@ -905,6 +952,7 @@ def _where_input_wrangler( TorchLibOpInfo("mT", core_ops.aten_mT), TorchLibOpInfo("mT", core_ops.aten_mT_complex, complex=True), TorchLibOpInfo("mul", core_ops.aten_mul), + TorchLibOpInfo("mul", core_ops.aten_mul_complex, complex=True), TorchLibOpInfo("narrow", core_ops.aten_narrow), TorchLibOpInfo("ops.aten.native_dropout", core_ops.aten_native_dropout), TorchLibOpInfo("ne", core_ops.aten_ne), @@ -1044,6 +1092,12 @@ def _where_input_wrangler( tolerance={torch.float16: (1e-2, 1e-2)}, compare_shape_only_for_output=(1, 2, 3), ), + TorchLibOpInfo( + "ops.aten.embedding_renorm", + core_ops.aten_embedding_renorm, + tolerance={torch.float16: (1e-2, 1e-2)}, + compare_shape_only_for_output=(1, 2, 3), + ), TorchLibOpInfo( "nn.functional.embedding", core_ops.aten_embedding, @@ -1249,10 +1303,19 @@ def _where_input_wrangler( TorchLibOpInfo("round_decimals", core_ops.aten_round_decimals), TorchLibOpInfo("rsqrt", core_ops.aten_rsqrt), TorchLibOpInfo("rsub", core_ops.aten_rsub), + TorchLibOpInfo("rsub", core_ops.aten_rsub_complex, complex=True, trace_only=True), TorchLibOpInfo( "scalar_tensor", core_ops.aten_scalar_tensor, input_wrangler=_scalar_tensor_input_wrangler, + trace_only=True, + ), + TorchLibOpInfo( + "scalar_tensor", + core_ops.aten_scalar_tensor, + input_wrangler=_scalar_tensor_input_wrangler, + trace_only=True, + complex=True, ), TorchLibOpInfo( "scatter_add", @@ -1282,6 +1345,11 @@ def _where_input_wrangler( reason="fixme: ORT failed. 
https://github.com/microsoft/onnxruntime/issues/16438", test_class_name="TestOutputConsistencyFullGraph", ), + TorchLibOpInfo("nn.functional.softplus", nn_ops.aten_softplus).xfail( + dtypes=(torch.float16,), + reason="fixme: ORT failed. https://github.com/microsoft/onnxruntime/issues/16449", + test_class_name="TestOutputConsistencyEager", + ), TorchLibOpInfo( "split_with_sizes", core_ops.aten_split_with_sizes, @@ -1324,6 +1392,15 @@ def _where_input_wrangler( matcher=lambda sample: not (len(sample.args) > 0 and isinstance(sample.args[0], int)), reason="this Aten overload only support one tensor as input and one int as args by design", ), + TorchLibOpInfo( + "squeeze_dim", + core_ops.aten_squeeze_dim_complex, + complex=True, + trace_only=True, + ).skip( + matcher=lambda sample: not (len(sample.args) > 0 and isinstance(sample.args[0], int)), + reason="this Aten overload only support one tensor as input and one int as args by design", + ), TorchLibOpInfo( "squeeze", core_ops.aten_squeeze, @@ -1333,6 +1410,7 @@ def _where_input_wrangler( ), TorchLibOpInfo("stack", core_ops.aten_stack), TorchLibOpInfo("sub", core_ops.aten_sub), + TorchLibOpInfo("sub", core_ops.aten_sub_complex, complex=True, trace_only=True), # TorchLibOpInfo("sym_size", core_ops.aten_sym_size), # no test case in OPS_DB TorchLibOpInfo( "t", @@ -1386,9 +1464,10 @@ def _where_input_wrangler( reason="fixme: Logic not implemented for size 0 inputs in op.Reshape", ), TorchLibOpInfo("unfold", core_ops.aten_unfold, trace_only=True), - TorchLibOpInfo("unfold_extra", core_ops.aten_unfold, trace_only=True), + TorchLibOpInfo("ops.aten.unfold", core_ops.aten_unfold, trace_only=True), TorchLibOpInfo("unsqueeze", core_ops.aten_unsqueeze), TorchLibOpInfo("view", core_ops.aten_view), + TorchLibOpInfo("view", core_ops.aten_view_complex, complex=True), TorchLibOpInfo("view_as", core_ops.aten_view_as), TorchLibOpInfo("view_as_complex", core_ops.aten_view_as_complex), TorchLibOpInfo("view_as_complex_copy", core_ops.aten_view_as_complex_copy), @@ -1621,6 +1700,20 @@ def _where_input_wrangler( reason="fixme: 'shape' do not match: torch.Size([2, 3, 4, 3]) != torch.Size([2, 3, 4, 2]). 
https://github.com/microsoft/onnxscript/issues/975", ), TorchLibOpInfo("native_batch_norm", core_ops.aten_native_batch_norm, trace_only=True), + TorchLibOpInfo( + "ops.aten._native_batch_norm_legit", core_ops.aten_native_batch_norm, trace_only=True + ), + TorchLibOpInfo( + "ops.aten._native_batch_norm_legit.no_stats", + core_ops.aten__native_batch_norm_no_stats, + trace_only=True, + ), + TorchLibOpInfo( + "ops.aten._native_batch_norm_legit_functional", + core_ops.aten__native_batch_norm_legit_functional, + trace_only=True, + compare_shape_only_for_output=(3, 4), + ), TorchLibOpInfo( "ops.aten.native_group_norm", core_ops.aten_native_group_norm, @@ -1910,12 +2003,16 @@ def _where_input_wrangler( "ops.aten.tensor.bool", core_ops.aten_tensor_bool ), # Custom from extra_opinfo TorchLibOpInfo( - "ops.aten.tensor.float", core_ops.aten_tensor_float # Custom from extra_opinfo + "ops.aten.tensor.float", + core_ops.aten_tensor_float, # Custom from extra_opinfo ), TorchLibOpInfo( "ops.aten.tensor.int", core_ops.aten_tensor_int ), # Custom from extra_opinfo TorchLibOpInfo("transpose", core_ops.aten_transpose, trace_only=True), + TorchLibOpInfo( + "transpose", core_ops.aten_transpose_complex, trace_only=True, complex=True + ), TorchLibOpInfo( "var_mean", core_ops.aten_var_mean, diff --git a/onnxscript/tests/functions/onnxfns1A_test.py b/onnxscript/tests/functions/onnxfns1A_test.py index 2d3b6287f..09302a634 100644 --- a/onnxscript/tests/functions/onnxfns1A_test.py +++ b/onnxscript/tests/functions/onnxfns1A_test.py @@ -1,6 +1,5 @@ import unittest -import onnx import pytest from onnxscript.tests.common import onnx_script_test_case @@ -16,51 +15,27 @@ def setUpClass(cls): def test_onnxfns_relu(self): self.run_onnx_test(onnxfns1A.Relu) - @unittest.skipIf( - not hasattr(onnx.FunctionProto, "attribute_proto"), - reason="ONNX 1.13 does not support default values", - ) def test_onnxfns_selu(self): self.run_onnx_test(onnxfns1A.Selu) - @unittest.skipIf( - not hasattr(onnx.FunctionProto, "attribute_proto"), - reason="current onnx does not support default values", - ) def test_onnxfns_elu(self): self.run_onnx_test(onnxfns1A.Elu) def test_onnxfns_elu05(self): self.run_onnx_test(onnxfns1A.Elu05) - @unittest.skipIf( - not hasattr(onnx.FunctionProto, "attribute_proto"), - reason="ONNX 1.13 does not support default values", - ) def test_onnxfns_thresholded_relu(self): self.run_onnx_test(onnxfns1A.ThresholdedRelu) - @unittest.skipIf( - not hasattr(onnx.FunctionProto, "attribute_proto"), - reason="ONNX 1.13 does not support default values", - ) def test_onnxfns_leaky_relu(self): self.run_onnx_test(onnxfns1A.LeakyRelu) def test_onnxfns_prelu(self): self.run_onnx_test(onnxfns1A.PRelu) - @unittest.skipIf( - not hasattr(onnx.FunctionProto, "attribute_proto"), - reason="current onnx does not support default values", - ) def test_onnxfns_hard_sigmoid(self): self.run_onnx_test(onnxfns1A.HardSigmoid) - @unittest.skipIf( - not hasattr(onnx.FunctionProto, "attribute_proto"), - reason="current onnx does not support default values", - ) def test_onnxfns_shrink(self): self.run_onnx_test(onnxfns1A.Shrink) @@ -71,7 +46,8 @@ def test_onnxfns_hard_softsign(self): self.run_onnx_test(onnxfns1A.Softsign) @pytest.mark.xfail( - reason="Clip has optional input min and max. Need to find out how to pass default min and max to the test case executor." + strict=True, + reason="Clip has optional input min and max. 
Need to find out how to pass default min and max to the test case executor.", ) def test_onnxfns_hard_clip(self): self.run_onnx_test(onnxfns1A.Clip) diff --git a/onnxscript/tests/functions/onnxfns_test.py b/onnxscript/tests/functions/onnxfns_test.py index d9bcd80ec..68ae2e3b8 100644 --- a/onnxscript/tests/functions/onnxfns_test.py +++ b/onnxscript/tests/functions/onnxfns_test.py @@ -5,8 +5,6 @@ import unittest -import onnx - from onnxscript.tests.common import onnx_script_test_case from onnxscript.tests.models import onnxfns1 @@ -57,10 +55,6 @@ def test_onnxfns_hard_softplus(self): def test_onnxfns_hard_softsign(self): self.run_onnx_test(onnxfns1.Softsign) - @unittest.skipIf( - not hasattr(onnx.FunctionProto, "attribute_proto"), - reason="current onnx does not support default values", - ) def test_onnxfns_hard_clip(self): self.run_onnx_test( onnxfns1.Clip, diff --git a/onnxscript/type_annotation.py b/onnxscript/type_annotation.py index b6cc0352e..53b640ab7 100644 --- a/onnxscript/type_annotation.py +++ b/onnxscript/type_annotation.py @@ -225,9 +225,7 @@ def pytype_to_type_strings(pytype: TypeAnnotationValue) -> list[str]: if isinstance(pytype, typing.TypeVar): constraints = pytype.__constraints__ if constraints: - return pytype_to_type_strings( - Union.__getitem__(constraints) - ) # pylint: disable=unnecessary-dunder-call + return pytype_to_type_strings(Union.__getitem__(constraints)) # pylint: disable=unnecessary-dunder-call bound = pytype.__bound__ if bound is None: return list(ALL_TENSOR_TYPE_STRINGS) diff --git a/onnxscript/values.py b/onnxscript/values.py index 36ce179e6..db52e9dc1 100644 --- a/onnxscript/values.py +++ b/onnxscript/values.py @@ -10,8 +10,14 @@ import types import typing from enum import IntFlag -from typing import _GenericAlias # type: ignore[attr-defined] -from typing import Any, ClassVar, Optional, Protocol, Sequence +from typing import ( # type: ignore[attr-defined] + Any, + ClassVar, + Optional, + Protocol, + Sequence, + _GenericAlias, +) import onnx import onnx.defs diff --git a/pyproject.toml b/pyproject.toml index ce600d321..1320277f3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -112,6 +112,7 @@ select = [ "E", # pycodestyle "F", # Pyflakes "G", # flake8-logging-format + "I", # isort "ISC", # flake8-implicit-str-concat "N", # pep8-naming "NPY", # modern numpy @@ -143,7 +144,7 @@ ignore = [ "UP006", # keep-runtime-typing "UP007", # keep-runtime-typing ] -line-length = 120 +line-length = 95 ignore-init-module-imports = true [tool.ruff.per-file-ignores] diff --git a/requirements-dev.txt b/requirements-dev.txt index 2292af13e..218ec9715 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,7 +1,7 @@ setuptools>=61.0.0 numpy onnx-weekly==1.15.0.dev20230807 -onnxruntime +onnxruntime>=1.15.1 typing_extensions # Docs site diff --git a/requirements/ci/requirements-onnx-weekly.txt b/requirements/ci/requirements-onnx-weekly.txt index dc98af02f..8b9ac433b 100644 --- a/requirements/ci/requirements-onnx-weekly.txt +++ b/requirements/ci/requirements-onnx-weekly.txt @@ -1 +1 @@ -onnx-weekly==1.15.0.dev20231002 +onnx-weekly==1.16.0.dev20231106 diff --git a/requirements/lintrunner/requirements.txt b/requirements/lintrunner/requirements.txt index 634b49fb2..5f24b4672 100644 --- a/requirements/lintrunner/requirements.txt +++ b/requirements/lintrunner/requirements.txt @@ -1,13 +1,10 @@ # This file is auto updated by dependabot lintrunner-adapters>=0.8.0 # RUFF, RUFF-FIX -ruff==0.1.0 +ruff==0.1.5 # MYPY -mypy==1.6.0 +mypy==1.7.0 types-PyYAML==6.0.12.11 -# 
BLACK-ISORT -black==23.9.1 -isort==5.12.0 # PYLINT pylint==2.17.6 # EDITORCONFIG-CHECKER