diff --git a/onnxscript/function_libs/torch_lib/graph_building.py b/onnxscript/function_libs/torch_lib/graph_building.py
index f1ace3ea7..c90183744 100644
--- a/onnxscript/function_libs/torch_lib/graph_building.py
+++ b/onnxscript/function_libs/torch_lib/graph_building.py
@@ -349,6 +349,8 @@ def eval_function(  # type: ignore[override]
                 else:
                     # Fall to call add_function_call
                     pass
+            elif isinstance(args[0], Sequence):
+                return False
             else:
                 # Python constants are scalars
                 return True
@@ -363,6 +365,8 @@ def eval_function(  # type: ignore[override]
                 else:
                     # Fall to call add_function_call
                     pass
+            elif isinstance(args[0], Sequence):
+                return False
             else:
                 # Python constants are scalars
                 return 0
diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py
index 8ef9cffdd..c1f244363 100644
--- a/onnxscript/function_libs/torch_lib/ops/core.py
+++ b/onnxscript/function_libs/torch_lib/ops/core.py
@@ -721,7 +721,7 @@ def aten_argmax_dim(self: Union[RealType, UINT8], dim: int, keepdim: bool = Fals
     return result
 
 
-@torch_op("aten::argmin")
+@torch_op("aten::argmin", traceable=True)
 def aten_argmin(self: Union[RealType, UINT8], keepdim: bool = False) -> INT64:
     """argmin(Tensor self, int? dim=None, bool keepdim=False) -> Tensor"""
 
@@ -734,7 +734,7 @@ def aten_argmin(self: Union[RealType, UINT8], keepdim: bool = False) -> INT64:
     return result
 
 
-@torch_op("aten::argmin")
+@torch_op("aten::argmin", traceable=True)
 def aten_argmin_dim(self: Union[RealType, UINT8], dim: int, keepdim: bool = False) -> INT64:
     """argmin(Tensor self, int? dim=None, bool keepdim=False) -> Tensor"""
 
@@ -3238,7 +3238,7 @@ def aten_exp(self: TFloat) -> TFloat:
     return op.Exp(self)
 
 
-@torch_op("aten::exp2")
+@torch_op("aten::exp2", traceable=True)
 def aten_exp2(self: TFloat) -> TFloat:
     """exp2(Tensor self) -> Tensor"""
 
@@ -3257,7 +3257,7 @@ def aten_expand(self: TTensor, size: TInt) -> TTensor:
     return op.Expand(self, size)
 
 
-@torch_op("aten::expand_as")
+@torch_op("aten::expand_as", traceable=True)
 def aten_expand_as(self: TTensor, other: TTensor) -> TTensor:
     """expand_as(Tensor(a) self, Tensor other) -> Tensor(a)"""
 
@@ -4906,7 +4906,10 @@ def aten_margin_ranking_loss(
     raise NotImplementedError()
 
 
-@torch_op(("aten::masked_fill", "aten::masked_fill.Scalar", "aten::masked_fill.Tensor"))
+@torch_op(
+    ("aten::masked_fill", "aten::masked_fill.Scalar", "aten::masked_fill.Tensor"),
+    traceable=True,
+)
 def aten_masked_fill(self: TTensor, mask: BOOL, value: TTensor) -> TTensor:
     """masked_fill.Tensor(Tensor self, Tensor mask, Tensor value) -> Tensor"""
     # NOTE: Do not attempt to cast `mask` to BOOL because mask should not take any other types.
@@ -5037,7 +5040,7 @@ def aten_mean_dim(self: TReal, dim: INT64, keepdim: bool = False) -> TReal:
     else:
         if IsScalar(dim):
             dim = op.Unsqueeze(dim, axes=0)
-        result = op.ReduceMean(self, axes=dim, keepdims=keepdim)
+        result = op.ReduceMean(self, dim, keepdims=keepdim)
 
     return result
 
@@ -7482,7 +7485,7 @@ def aten_softmax(self: TFloatOrBFloat16, dim: int, dtype: int = -1) -> TFloatOrB
     return result
 
 
-@torch_op(("aten::softmax", "aten::softmax.int", "aten::special_softmax"))
+@torch_op(("aten::softmax", "aten::softmax.int", "aten::special_softmax"), traceable=True)
 def aten_softmax_no_dtype(self: TFloatOrBFloat16, dim: int) -> TFloatOrBFloat16:
     """softmax(Tensor self, int dim, ScalarType? dtype=None) -> Tensor"""
 
dtype=None) -> Tensor""" @@ -7887,7 +7890,7 @@ def aten_symeig( raise NotImplementedError() -@torch_op("aten::t") +@torch_op("aten::t", traceable=True) def aten_t(self: TTensor) -> TTensor: """t(Tensor(a) self) -> Tensor(a)""" @@ -8063,7 +8066,7 @@ def aten_to_sparse_csr(self: TensorType) -> TensorType: raise NotImplementedError() -@torch_op("aten::topk") +@torch_op("aten::topk", traceable=True) def aten_topk( self: TReal, k: INT64, dim: int = -1, largest: bool = True, sorted: bool = True ) -> Tuple[TReal, INT64]: diff --git a/onnxscript/function_libs/torch_lib/ops/fft.py b/onnxscript/function_libs/torch_lib/ops/fft.py index 8b71f7cb4..f35b4f611 100644 --- a/onnxscript/function_libs/torch_lib/ops/fft.py +++ b/onnxscript/function_libs/torch_lib/ops/fft.py @@ -25,6 +25,7 @@ ("aten::_fft_c2c", "aten::_fft_c2r", "aten::_fft_r2c"), private=True, complex=True, + traceable=True, ) def _fftn_onnx_normalization( self, @@ -35,7 +36,7 @@ def _fftn_onnx_normalization( ) -> TFloat: # Obtain the total_sample_count (n) for normalization self_shape = op.Shape(self) - total_sample_count = op.ReduceProd(self_shape[dims], keepdims=0) + total_sample_count = op.ReduceProd(op.Gather(self_shape, dims), keepdims=0) total_sample_count = op.CastLike(total_sample_count, transformed) # Normalize the result diff --git a/onnxscript/function_libs/torch_lib/ops/linalg.py b/onnxscript/function_libs/torch_lib/ops/linalg.py index fddf80ac0..7890fb1c0 100644 --- a/onnxscript/function_libs/torch_lib/ops/linalg.py +++ b/onnxscript/function_libs/torch_lib/ops/linalg.py @@ -342,6 +342,7 @@ def _aten_linalg_vector_norm_no_dim_onnx(self: TFloat, ord: float, keepdim: bool self = op.Abs(self) ord = op.Cast(ord, to=FLOAT.dtype) # Must be FLOAT, due to op.IsInf() needs FLOAT + # TODO(justinchuby): Evaluate IsInf in trace mode if op.IsInf(ord, detect_negative=0, detect_positive=1): result = op.ReduceMax(self, keepdims=keepdim) elif op.IsInf(ord, detect_negative=1, detect_positive=0): @@ -373,6 +374,7 @@ def _aten_linalg_vector_norm_onnx( dim = op.Reshape(dim, op.Constant(value_ints=[-1])) self = op.Abs(self) ord = op.Cast(ord, to=FLOAT.dtype) # Must be FLOAT, due to op.IsInf() needs FLOAT + # TODO(justinchuby): Evaluate IsInf in trace mode if op.IsInf(ord, detect_negative=0, detect_positive=1): result = op.ReduceMax(self, dim, keepdims=keepdim) elif op.IsInf(ord, detect_negative=1, detect_positive=0): diff --git a/onnxscript/function_libs/torch_lib/ops/nn.py b/onnxscript/function_libs/torch_lib/ops/nn.py index a6cbb1849..7730008ef 100644 --- a/onnxscript/function_libs/torch_lib/ops/nn.py +++ b/onnxscript/function_libs/torch_lib/ops/nn.py @@ -40,7 +40,7 @@ TFloatUnlessFloat32 = TypeVar("TFloatUnlessFloat32", bound=Union[BFLOAT16, FLOAT16, DOUBLE]) -@torch_op("aten::aten_adaptive_avg_pool1d") +@torch_op("aten::aten_adaptive_avg_pool1d", traceable=True) def aten_adaptive_avg_pool1d(self: TFloat, output_size: INT64[1]) -> TFloat: """adaptive_avg_pool1d(Tensor self, int[1] output_size) -> Tensor""" @@ -58,7 +58,7 @@ def aten_adaptive_avg_pool1d(self: TFloat, output_size: INT64[1]) -> TFloat: return result -@torch_op("aten::aten_adaptive_avg_pool2d") +@torch_op("aten::aten_adaptive_avg_pool2d", traceable=True) def aten_adaptive_avg_pool2d(self: TFloat, output_size: INT64[2]) -> TFloat: """adaptive_avg_pool2d(Tensor self, SymInt[2] output_size) -> Tensor""" @@ -76,7 +76,7 @@ def aten_adaptive_avg_pool2d(self: TFloat, output_size: INT64[2]) -> TFloat: return result -@torch_op("aten::aten_adaptive_avg_pool3d") 
+@torch_op("aten::aten_adaptive_avg_pool3d", traceable=True)
 def aten_adaptive_avg_pool3d(self: TFloat, output_size: INT64[3]) -> TFloat:
     """adaptive_avg_pool3d(Tensor self, SymInt[3] output_size) -> Tensor"""
 
@@ -350,7 +350,7 @@ def aten_celu(self: FLOAT, alpha: float = 1.0) -> FLOAT:
     return op.Celu(self, alpha=alpha)  # op.Celu only support float32
 
 
-@torch_op("aten::celu")
+@torch_op("aten::celu", traceable=True)
 def aten_celu_type_promoted(
     self: TFloatUnlessFloat32, alpha: float = 1.0
 ) -> TFloatUnlessFloat32:
@@ -409,7 +409,7 @@ def aten_conv_depthwise3d(
     raise NotImplementedError()
 
 
-@torch_op("aten::cross_entropy_loss")
+@torch_op("aten::cross_entropy_loss", traceable=True)
 def aten_cross_entropy_loss(
     self: TFloatOrBFloat16,
     target: IntType,
@@ -871,7 +871,7 @@ def aten_max_pool2d(
     return _aten_max_pool_onnx(self, kernel_shape, strides, pads, dilations, ceil_mode, 3)
 
 
-@torch_op("internal::max_pool", private=True)
+@torch_op("internal::max_pool", private=True, traceable=True)
 def _aten_max_pool_onnx(
     self: TFloatOrUInt8,
     kernel_shape: Sequence[int],
@@ -1003,7 +1003,7 @@ def aten_max_pool3d_with_indices(
     )
 
 
-@torch_op("internal::max_pool_with_indices", private=True)
+@torch_op("internal::max_pool_with_indices", private=True, traceable=True)
 def _aten_max_pool_with_indices_onnx(
     self: TFloatOrUInt8,
     kernel_size: Sequence[int],
@@ -1159,7 +1159,7 @@ def aten_mkldnn_reorder_conv3d_weight(
     raise NotImplementedError()
 
 
-@torch_op("aten::mse_loss")
+@torch_op("aten::mse_loss", traceable=True)
 def aten_mse_loss(self: TReal, target: TReal, reduction: int = 1) -> TReal:
     """mse_loss(Tensor self, Tensor target, int reduction=Mean) -> Tensor"""
     # FIXME: When reduction=0, the shape(result) will be different than other case
@@ -1235,7 +1235,7 @@ def aten_multilabel_margin_loss_forward(
     raise NotImplementedError()
 
 
-@torch_op("aten::nll_loss")
+@torch_op("aten::nll_loss", traceable=True)
 def aten_nll_loss(
     self: TFloat,
     target: INT64,
@@ -1248,7 +1248,7 @@ def aten_nll_loss(
     if self_rank_is_1:  # self rank should be at least 2
         self = op.Unsqueeze(self, op.Constant(value_ints=[0]))
 
-    rank_target = op.Size(op.Shape(target))
+    rank_target = Rank(target)
     if rank_target == 0:  # target rank should be at least 1
         target = op.Unsqueeze(target, op.Constant(value_ints=[0]))
 
@@ -1271,7 +1271,7 @@
     return result
 
 
-@torch_op("aten::nll_loss")
+@torch_op("aten::nll_loss", traceable=True)
 def aten_nll_loss_weight(
     self: TFloat,
     target: INT64,
@@ -1282,10 +1282,11 @@
     """nll_loss(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100) -> Tensor"""
 
     self_rank_is_1 = Rank(self) == 1
-    if self_rank_is_1:  # self rank should be at least 2
+    if self_rank_is_1:
+        # self rank should be at least 2
         self = op.Unsqueeze(self, op.Constant(value_ints=[0]))
 
-    rank_target = op.Size(op.Shape(target))
+    rank_target = Rank(target)
     if rank_target == 0:  # target rank should be at least 1
         target = op.Unsqueeze(target, op.Constant(value_ints=[0]))
 
@@ -1490,7 +1491,7 @@ def aten_relu(self: TReal) -> TReal:
     return op.Relu(self)
 
 
-@torch_op("aten::relu6")
+@torch_op("aten::relu6", traceable=True)
 def aten_relu6(self: TReal) -> TReal:
     """relu6(Tensor self) -> Tensor"""
 
@@ -1778,7 +1779,7 @@ def aten__scaled_dot_product_flash_attention(
     )
 
 
-@torch_op("aten::_scaled_dot_product_efficient_attention", private=True)
+@torch_op("aten::_scaled_dot_product_efficient_attention", private=True, traceable=True)
 def _aten_scaled_dot_product_efficient_attention_fillin_empty_outputs(
     query: TFloat,
     compute_log_sumexp: bool,
diff --git a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py
index a8f947cbf..bbe6a7850 100644
--- a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py
+++ b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py
@@ -481,6 +481,7 @@ def _where_input_wrangler(
         enabled_if=version_utils.torch_older_than("2.2"),
     )
     .xfail(
+        enabled_if=version_utils.onnxruntime_older_than("1.17"),
         dtypes=(torch.float16,),
         reason="fixme: ORT failed. https://github.com/microsoft/onnxruntime/issues/16438",
         test_class_name="TestOutputConsistencyFullGraph",
@@ -493,6 +494,7 @@ def _where_input_wrangler(
         enabled_if=version_utils.torch_older_than("2.2"),
     )
     .xfail(
+        enabled_if=version_utils.onnxruntime_older_than("1.17"),
         dtypes=(torch.float16,),
         reason="fixme: ORT failed. https://github.com/microsoft/onnxruntime/issues/16438",
         test_class_name="TestOutputConsistencyFullGraph",
@@ -544,6 +546,13 @@ def _where_input_wrangler(
         "decomposed",
         dtypes=(torch.int16, torch.int32, torch.int64),
         reason="ONNX Runtime does not support int inputs to Gemm",
+    )
+    .xfail(
+        "decomposed",
+        matcher=lambda sample: torch.numel(sample.input) == 0
+        or torch.numel(sample.args[0]) == 0
+        or torch.numel(sample.args[1]) == 0,
+        reason="ONNX Runtime does not support zero sized inputs",
     ),
     TorchLibOpInfo("addmv", core_ops.aten_addmv, tolerance={torch.float16: (1e-3, 1e-2)}),
     TorchLibOpInfo(
@@ -1545,6 +1554,7 @@ def _where_input_wrangler(
         "t",
         core_ops.aten_t,
     ).xfail(
+        enabled_if=not _flags.EXPERIMENTAL_PREFER_TRACING,
         reason="fixme: ORT Graph attribute inferencing failed on rank-1 input. https://github.com/onnx/onnx/issues/4986",
         test_class_name="TestOutputConsistencyFullGraph",
     ),
@@ -1895,7 +1905,11 @@ def _where_input_wrangler(
         "native_layer_norm",
         core_ops.aten_native_layer_norm,
         trace_only=True,
-        tolerance={torch.float32: (3.7e-5, 1.8e-4)},
+        tolerance={torch.float32: (3.7e-5, 1.8e-4), torch.float16: (1e-1, 7e-4)},
+    ).skip(
+        dtypes=(torch.float16,),
+        device_type="cpu",
+        reason="native_layer_norm outputs different dtypes on CPU and CUDA. Our implementation is based on that for CUDA",
     ),
     TorchLibOpInfo(
         "nn.functional.avg_pool1d",
@@ -2091,9 +2105,14 @@ def _where_input_wrangler(
         # Output[0] is OK, but other outputs just have the same shape with zero values
         nondeterministic=True,
         compare_shape_only_for_output=(1, 2, 3, 4, 5, 6, 7, 8),
-    ).skip(
+    )
+    .skip(
         enabled_if=version_utils.torch_older_than("2.1"),
         reason="The operator is not supported in older version.",
+    )
+    .skip(
+        device_type="cpu",
+        reason="_scaled_dot_product_flash_attention only supports CUDA",
     ),
     TorchLibOpInfo(
         "ops.aten._scaled_dot_product_efficient_attention",
diff --git a/requirements-dev.txt b/requirements-dev.txt
index 4e11cb2f3..f1a41e266 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -1,7 +1,7 @@
 setuptools>=61.0.0
 numpy
 onnx-weekly>=1.16.0.dev20231204
-onnxruntime>=1.15.1
+onnxruntime>=1.17.0
 typing_extensions
 
 # Docs site