diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index 18010960e11c8..28279b937f055 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -596,9 +596,11 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int8_t, ArgMax); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, uint8_t, ArgMax); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int32_t, ArgMax); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int64_t, ArgMax); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, ArgMin); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, double, ArgMin); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int32_t, ArgMin); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int64_t, ArgMin); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 13, Reshape); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 14, Shape); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, Concat); @@ -1849,12 +1851,16 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { uint8_t, ArgMax)>, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/cpu/reduction/reduction_ops.cc b/onnxruntime/core/providers/cpu/reduction/reduction_ops.cc index ce834e371fdef..c3f8d01cd2fa4 100644 --- a/onnxruntime/core/providers/cpu/reduction/reduction_ops.cc +++ b/onnxruntime/core/providers/cpu/reduction/reduction_ops.cc @@ -256,12 +256,14 @@ REGISTER_UNARY_ELEMENTWISE_KERNEL(ArgMax, 13); REGISTER_UNARY_ELEMENTWISE_KERNEL_DOUBLE_ONLY(ArgMax, 13); REGISTER_UNARY_ELEMENTWISE_KERNEL_INT8_ONLY(ArgMax, 13); REGISTER_UNARY_ELEMENTWISE_KERNEL_UINT8_ONLY(ArgMax, 13); +REGISTER_UNARY_ELEMENTWISE_KERNEL_INT64_ONLY(ArgMax, 13); REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ArgMin, 1, 10); REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ArgMin, 11, 12); REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_DOUBLE_ONLY(ArgMin, 11, 12) REGISTER_UNARY_ELEMENTWISE_KERNEL(ArgMin, 13); REGISTER_UNARY_ELEMENTWISE_KERNEL_DOUBLE_ONLY(ArgMin, 13); +REGISTER_UNARY_ELEMENTWISE_KERNEL_INT64_ONLY(ArgMin, 13); FastReduceKind operator|(FastReduceKind a, FastReduceKind b) { return static_cast(static_cast(a) | static_cast(b)); diff --git a/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc b/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc index c9b851e450f9d..0e906e365afeb 100644 --- a/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc +++ b/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc @@ -4,6 +4,7 @@ #include #include #include +#include #include "gtest/gtest.h" #include "test/common/dnnl_op_test_utils.h" #include "test/common/tensor_op_test_utils.h" @@ -2999,6 +3000,46 @@ TEST(ReductionOpTest, ArgMax_uint8) { test.Run(); } +TEST(ReductionOpTest, ArgMax_int64) { + OpTester test("ArgMin"); + test.AddAttribute("axis", (int64_t)0); + test.AddAttribute("keepdims", (int64_t)0); + test.AddInput("data", {3, 2, 2}, + {std::numeric_limits::max(), 2, + 3, 4, + + 5, 6, + 7, 8, + + 9, 10, + 11, 12}); + test.AddOutput("reduced", {2, 2}, + {0, 0, + 0, 0}); + test.Run(); +} + +TEST(ReductionOpTest, ArgMax_int64_select_last) { + OpTester test("ArgMin", 12); + test.AddAttribute("axis", (int64_t)0); + test.AddAttribute("keepdims", (int64_t)0); + test.AddAttribute("select_last_index", (int64_t)1); + test.AddInput("data", {3, 2, 2}, + {std::numeric_limits::max(), 2, + 3, 4, + + std::numeric_limits::min(), 6, + 7, 8, + + 9, 10, + 11, 12}); + test.AddOutput("reduced", {2, 2}, + {1, 0, + 0, 0}); + + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); +} + TEST(ReductionOpTest, ArgMax2D) { OpTester test("ArgMax"); test.AddAttribute("axis", (int64_t)1); @@ -3197,6 +3238,46 @@ TEST(ReductionOpTest, ArgMin_int32_neg_axis) { test.Run(); } +TEST(ReductionOpTest, ArgMin_int64) { + OpTester test("ArgMin"); + test.AddAttribute("axis", (int64_t)0); + test.AddAttribute("keepdims", (int64_t)0); + test.AddInput("data", {3, 2, 2}, + {std::numeric_limits::min(), 2, + 3, 4, + + 5, 6, + 7, 8, + + 9, 10, + 11, 12}); + test.AddOutput("reduced", {2, 2}, + {0, 0, + 0, 0}); + test.Run(); +} + +TEST(ReductionOpTest, ArgMin_int64_select_last) { + OpTester test("ArgMin", 12); + test.AddAttribute("axis", (int64_t)0); + test.AddAttribute("keepdims", (int64_t)0); + test.AddAttribute("select_last_index", (int64_t)1); + test.AddInput("data", {3, 2, 2}, + {std::numeric_limits::min(), 2, + 3, 4, + + std::numeric_limits::min(), 6, + 7, 8, + + 9, 10, + 11, 12}); + test.AddOutput("reduced", {2, 2}, + {1, 0, + 0, 0}); + + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); +} + TEST(ReductionOpTest, OptimizeShapeForFastReduce_ReduceDimWithZero1) { FastReduceKind fast_kind; TensorShapeVector fast_shape, fast_output_shape, fast_axes;