diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py
index c66a978e9..0c750da7d 100644
--- a/onnxscript/function_libs/torch_lib/ops/core.py
+++ b/onnxscript/function_libs/torch_lib/ops/core.py
@@ -693,8 +693,21 @@ def aten_arctanh(self: TensorType) -> TensorType:
     raise NotImplementedError()


-@torch_op("aten::argmax", traceable=True)
-def aten_argmax(self: Union[RealType, UINT8], keepdim: bool = False) -> INT64:
+@torch_op("aten::argmax", trace_only=True)
+def aten_argmax(
+    self: Union[RealType, UINT8], dim: Optional[int] = None, keepdim: bool = False
+) -> INT64:
+    """argmax(Tensor self, int? dim=None, bool keepdim=False) -> Tensor"""
+
+    if dim is None:
+        result = _aten_argmax(self, keepdim)
+    else:
+        result = _aten_argmax_dim(self, dim, keepdim)
+    return result
+
+
+@torch_op("aten::argmax", private=True, traceable=True)
+def _aten_argmax(self: Union[RealType, UINT8], keepdim: bool = False) -> INT64:
     """argmax(Tensor self, int? dim=None, bool keepdim=False) -> Tensor"""

     self_is_scaler = IsScalar(self)
@@ -706,8 +719,8 @@ def aten_argmax(self: Union[RealType, UINT8], keepdim: bool = False) -> INT64:
     return result


-@torch_op("aten::argmax", traceable=True)
-def aten_argmax_dim(self: Union[RealType, UINT8], dim: int, keepdim: bool = False) -> INT64:
+@torch_op("aten::argmax", private=True, traceable=True)
+def _aten_argmax_dim(self: Union[RealType, UINT8], dim: int, keepdim: bool = False) -> INT64:
     """argmax(Tensor self, int? dim=None, bool keepdim=False) -> Tensor"""

     self_is_scaler = IsScalar(self)
@@ -721,8 +734,21 @@ def aten_argmax_dim(self: Union[RealType, UINT8], dim: int, keepdim: bool = Fals
     return result


-@torch_op("aten::argmin", traceable=True)
-def aten_argmin(self: Union[RealType, UINT8], keepdim: bool = False) -> INT64:
+@torch_op("aten::argmin", trace_only=True)
+def aten_argmin(
+    self: Union[RealType, UINT8], dim: Optional[int] = None, keepdim: bool = False
+) -> INT64:
+    """argmin(Tensor self, int? dim=None, bool keepdim=False) -> Tensor"""
+
+    if dim is None:
+        result = _aten_argmin(self, keepdim)
+    else:
+        result = _aten_argmin_dim(self, dim, keepdim)
+    return result
+
+
+@torch_op("aten::argmin", private=True, traceable=True)
+def _aten_argmin(self: Union[RealType, UINT8], keepdim: bool = False) -> INT64:
     """argmin(Tensor self, int? dim=None, bool keepdim=False) -> Tensor"""

     self_is_scaler = IsScalar(self)
@@ -734,8 +760,8 @@ def aten_argmin(self: Union[RealType, UINT8], keepdim: bool = False) -> INT64:
     return result


-@torch_op("aten::argmin", traceable=True)
-def aten_argmin_dim(self: Union[RealType, UINT8], dim: int, keepdim: bool = False) -> INT64:
+@torch_op("aten::argmin", private=True, traceable=True)
+def _aten_argmin_dim(self: Union[RealType, UINT8], dim: int, keepdim: bool = False) -> INT64:
     """argmin(Tensor self, int? dim=None, bool keepdim=False) -> Tensor"""

     self_is_scaler = IsScalar(self)
diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py
index 1fac7dd42..e4dec531a 100644
--- a/tests/function_libs/torch_lib/ops_test_data.py
+++ b/tests/function_libs/torch_lib/ops_test_data.py
@@ -1688,25 +1688,7 @@ def _where_input_wrangler(
         matcher=lambda sample: sample.kwargs.get("end") is not None,
         reason="arange overload does not support positional 'end' argument",
     ),
-    TorchLibOpInfo("argmax", core_ops.aten_argmax)
-    .skip(
-        matcher=lambda sample: "dim" in sample.kwargs,
-        reason="this overload does not support the 'dim' attribute by design",
-    )
-    .skip(
-        matcher=lambda sample: len(sample.input.shape) == 0,
-        enabled_if=version_utils.onnxruntime_older_than("1.16"),
-        reason="fixme (core dump): ORT aborts on scalar inputs to Reduce*-18. https://github.com/microsoft/onnxruntime/issues/16492",
-    )
-    .xfail(
-        dtypes=(torch.int64,),
-        reason="fixme: ORT did not implement ArgMax for int64. https://github.com/microsoft/onnxruntime/issues/16654",
-    ),
-    TorchLibOpInfo("argmax_dim", core_ops.aten_argmax_dim)
-    .xfail(
-        matcher=lambda sample: "dim" not in sample.kwargs,
-        reason="this overload requires the 'dim' attribute by design",
-    )
+    TorchLibOpInfo("argmax", core_ops.aten_argmax, trace_only=True)
     .skip(
         matcher=lambda sample: len(sample.input.shape) == 0,
         enabled_if=version_utils.onnxruntime_older_than("1.16"),
@@ -1716,25 +1698,7 @@ def _where_input_wrangler(
         dtypes=(torch.int64,),
         reason="fixme: ORT did not implement ArgMax for int64. https://github.com/microsoft/onnxruntime/issues/16654",
     ),
-    TorchLibOpInfo("argmin", core_ops.aten_argmin)
-    .skip(
-        matcher=lambda sample: "dim" in sample.kwargs,
-        reason="this overload does not support the 'dim' attribute by design",
-    )
-    .skip(
-        matcher=lambda sample: len(sample.input.shape) == 0,
-        enabled_if=version_utils.onnxruntime_older_than("1.16"),
-        reason="fixme (core dump): ORT aborts on scalar inputs to Reduce*-18. https://github.com/microsoft/onnxruntime/issues/16492",
-    )
-    .xfail(
-        dtypes=(torch.int64,),
-        reason="fixme: ORT did not implement ArgMin for int64. https://github.com/microsoft/onnxruntime/issues/16654",
-    ),
-    TorchLibOpInfo("argmin_dim", core_ops.aten_argmin_dim)
-    .xfail(
-        matcher=lambda sample: "dim" not in sample.kwargs,
-        reason="this overload requires the 'dim' attribute by design",
-    )
+    TorchLibOpInfo("argmin", core_ops.aten_argmin, trace_only=True)
     .skip(
         matcher=lambda sample: len(sample.input.shape) == 0,
         enabled_if=version_utils.onnxruntime_older_than("1.16"),
@@ -2399,8 +2363,6 @@ def _where_input_wrangler(
 ops_test_common.duplicate_opinfo(OPS_DB, "all", ("all_dim", "all_dims"))
 ops_test_common.duplicate_opinfo(OPS_DB, "any", ("any_dim", "any_dims"))
 ops_test_common.duplicate_opinfo(OPS_DB, "arange", ("arange_start", "arange_start_step"))
-ops_test_common.duplicate_opinfo(OPS_DB, "argmax", ("argmax_dim",))
-ops_test_common.duplicate_opinfo(OPS_DB, "argmin", ("argmin_dim",))
 ops_test_common.duplicate_opinfo(OPS_DB, "atleast_1d", ("atleast_1d_Sequence",))
 ops_test_common.duplicate_opinfo(OPS_DB, "atleast_2d", ("atleast_2d_Sequence",))
 ops_test_common.duplicate_opinfo(OPS_DB, "atleast_3d", ("atleast_3d_Sequence",))
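
A minimal, self-contained sketch of the dispatch pattern the core.py change introduces, written in plain Python with NumPy stand-ins rather than the real onnxscript API (the names _argmax_flat, _argmax_dim, and argmax below are hypothetical): because aten_argmax is now registered with trace_only=True, the "dim is None" check runs in Python at trace time, so only the selected private overload is recorded into the ONNX graph.

from typing import Optional

import numpy as np


def _argmax_flat(x: np.ndarray, keepdim: bool) -> np.ndarray:
    # Stand-in for _aten_argmax: reshape to 1-D, then reduce over that single axis.
    index = np.argmax(x.reshape(-1))
    return np.array([index]) if keepdim else np.asarray(index)


def _argmax_dim(x: np.ndarray, dim: int, keepdim: bool) -> np.ndarray:
    # Stand-in for _aten_argmax_dim: reduce along the requested axis.
    index = np.argmax(x, axis=dim)
    return np.expand_dims(index, dim) if keepdim else index


def argmax(x: np.ndarray, dim: Optional[int] = None, keepdim: bool = False) -> np.ndarray:
    # Mirrors the new aten_argmax wrapper: under trace_only=True this branch
    # is evaluated once at trace time, picking exactly one overload for the graph.
    if dim is None:
        return _argmax_flat(x, keepdim)
    return _argmax_dim(x, dim, keepdim)


x = np.array([[1, 9, 3], [7, 2, 8]])
print(argmax(x))         # 1  (index into the flattened tensor)
print(argmax(x, dim=0))  # [1 0 1]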