From 1a84f53c35049192b1d380cc374a9be9f6cf8f0a Mon Sep 17 00:00:00 2001 From: Christian Bourjau Date: Mon, 23 Sep 2024 22:02:29 +0200 Subject: [PATCH] Make argmin/argmax support identical data types and add int64 support (#21641) --- docs/OperatorKernels.md | 12 +-- .../providers/cpu/cpu_execution_provider.cc | 42 ++++++++++ .../providers/cpu/reduction/reduction_ops.cc | 14 ++++ .../cpu/reduction/reduction_ops_test.cc | 77 +++++++++++++++++++ 4 files changed, 139 insertions(+), 6 deletions(-) diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 121240e6e18f9..407e08c96a891 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -27,12 +27,12 @@ Do not modify directly.* |Affine|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(float)| |AffineGrid|*in* theta:**T1**
*in* size:**T2**
*out* grid:**T1**|20+|**T1** = tensor(double), tensor(float)
**T2** = tensor(int64)| |And|*in* A:**T**
*in* B:**T**
*out* C:**T1**|7+|**T** = tensor(bool)
**T1** = tensor(bool)| -|ArgMax|*in* data:**T**
*out* reduced:**tensor(int64)**|13+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int8), tensor(uint8)| -|||[11, 12]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int8), tensor(uint8)| -|||[1, 10]|**T** = tensor(float), tensor(int32), tensor(int8), tensor(uint8)| -|ArgMin|*in* data:**T**
*out* reduced:**tensor(int64)**|13+|**T** = tensor(double), tensor(float), tensor(int32)| -|||[11, 12]|**T** = tensor(double), tensor(float), tensor(int32)| -|||[1, 10]|**T** = tensor(float), tensor(int32)| +|ArgMax|*in* data:**T**
*out* reduced:**tensor(int64)**|13+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)| +|||[11, 12]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)| +|||[1, 10]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)| +|ArgMin|*in* data:**T**
*out* reduced:**tensor(int64)**|13+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)| +|||[11, 12]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)| +|||[1, 10]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)| |Asin|*in* input:**T**
*out* output:**T**|7+|**T** = tensor(float)| |Asinh|*in* input:**T**
*out* output:**T**|9+|**T** = tensor(float)| |Atan|*in* input:**T**
*out* output:**T**|7+|**T** = tensor(float)| diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index 7ed776f1358a5..7b1b136eb091e 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -227,10 +227,16 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOn class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, int64_t, ReduceSumSquare); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, float, ArgMax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, double, ArgMax); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, int8_t, ArgMax); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, uint8_t, ArgMax); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, int32_t, ArgMax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, int64_t, ArgMax); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, float, ArgMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, double, ArgMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, int8_t, ArgMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, uint8_t, ArgMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, int64_t, ArgMin); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, int32_t, ArgMin); class 
ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 13, GRU); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 13, LSTM); @@ -408,9 +414,13 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOn class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, int8_t, ArgMax); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, uint8_t, ArgMax); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, int32_t, ArgMax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, int64_t, ArgMax); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, float, ArgMin); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, double, ArgMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, int8_t, ArgMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, uint8_t, ArgMin); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, int32_t, ArgMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, int64_t, ArgMin); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, float, ReduceL1); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, double, ReduceL1); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, int32_t, ReduceL1); @@ -636,9 +646,13 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int8_t, 
ArgMax); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, uint8_t, ArgMax); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int32_t, ArgMax); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int64_t, ArgMax); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, ArgMin); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, double, ArgMin); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int8_t, ArgMin); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, uint8_t, ArgMin); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int32_t, ArgMin); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int64_t, ArgMin); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 13, Reshape); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 14, Shape); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, Concat); @@ -1443,16 +1457,28 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { int64_t, ReduceSumSquare)>, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -1725,12 +1751,20 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { uint8_t, ArgMax)>, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, 
BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2065,11 +2099,19 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { ArgMax)>, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/cpu/reduction/reduction_ops.cc b/onnxruntime/core/providers/cpu/reduction/reduction_ops.cc index 5aac1d9387f57..24fbfbe8d525b 100644 --- a/onnxruntime/core/providers/cpu/reduction/reduction_ops.cc +++ b/onnxruntime/core/providers/cpu/reduction/reduction_ops.cc @@ -288,22 +288,36 @@ REGISTER_UNARY_ELEMENTWISE_KERNEL_DOUBLE_ONLY(ReduceSumSquare, 18); REGISTER_UNARY_ELEMENTWISE_KERNEL_INT64_ONLY(ReduceSumSquare, 18); REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ArgMax, 1, 10); +REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_DOUBLE_ONLY(ArgMax, 1, 10) REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_INT8_ONLY(ArgMax, 1, 10) +REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_INT64_ONLY(ArgMax, 1, 10) REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_UINT8_ONLY(ArgMax, 1, 10) REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ArgMax, 11, 12); REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_DOUBLE_ONLY(ArgMax, 11, 12) REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_INT8_ONLY(ArgMax, 11, 12) +REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_INT64_ONLY(ArgMax, 11, 12) REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_UINT8_ONLY(ArgMax, 11, 12) REGISTER_UNARY_ELEMENTWISE_KERNEL(ArgMax, 13); REGISTER_UNARY_ELEMENTWISE_KERNEL_DOUBLE_ONLY(ArgMax, 13); +REGISTER_UNARY_ELEMENTWISE_KERNEL_INT64_ONLY(ArgMax, 13); REGISTER_UNARY_ELEMENTWISE_KERNEL_INT8_ONLY(ArgMax, 13); REGISTER_UNARY_ELEMENTWISE_KERNEL_UINT8_ONLY(ArgMax, 13); REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ArgMin, 1, 10); +REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_DOUBLE_ONLY(ArgMin, 1, 10) 
+REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_INT8_ONLY(ArgMin, 1, 10) +REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_INT64_ONLY(ArgMin, 1, 10) +REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_UINT8_ONLY(ArgMin, 1, 10) REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ArgMin, 11, 12); REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_DOUBLE_ONLY(ArgMin, 11, 12) +REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_INT8_ONLY(ArgMin, 11, 12) +REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_INT64_ONLY(ArgMin, 11, 12) +REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_UINT8_ONLY(ArgMin, 11, 12) REGISTER_UNARY_ELEMENTWISE_KERNEL(ArgMin, 13); REGISTER_UNARY_ELEMENTWISE_KERNEL_DOUBLE_ONLY(ArgMin, 13); +REGISTER_UNARY_ELEMENTWISE_KERNEL_INT64_ONLY(ArgMin, 13); +REGISTER_UNARY_ELEMENTWISE_KERNEL_INT8_ONLY(ArgMin, 13); +REGISTER_UNARY_ELEMENTWISE_KERNEL_UINT8_ONLY(ArgMin, 13); FastReduceKind operator|(FastReduceKind a, FastReduceKind b) { return static_cast(static_cast(a) | static_cast(b)); diff --git a/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc b/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc index 0697187a777d6..0968bc32e0de4 100644 --- a/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc +++ b/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc @@ -3246,6 +3246,26 @@ TEST(ReductionOpTest, ArgMax_do_not_keepdims_2) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); // TensorRT: node1: at least 2 dimensions are required for input } +TEST(ReductionOpTest, ArgMax_int64) { + OpTester test("ArgMax", 13); + test.AddAttribute("axis", (int64_t)1); + test.AddAttribute("keepdims", (int64_t)1); + test.AddInput("data", {3, 2, 2}, + {1, 2, + 3, 4, + + 5, 6, + 7, 8, + + 9, 10, + 11, 12}); + test.AddOutput("reduced", {3, 1, 2}, + {1, 1, + 1, 1, + 1, 1}); + test.Run(); +} + TEST(ReductionOpTest, ArgMax_int32) { OpTester test("ArgMax"); test.AddAttribute("axis", (int64_t)1); @@ -3511,6 +3531,63 @@ TEST(ReductionOpTest, 
ArgMin_do_not_keepdims_2_select_last) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); } +TEST(ReductionOpTest, ArgMin_uint8) { + OpTester test("ArgMin", 13); + test.AddAttribute("axis", (int64_t)0); + test.AddAttribute("keepdims", (int64_t)0); + test.AddInput("data", {3, 2, 2}, + {1, 2, + 3, 4, + + 5, 6, + 7, 8, + + 9, 10, + 11, 12}); + test.AddOutput("reduced", {2, 2}, + {0, 0, + 0, 0}); + test.Run(); +} + +TEST(ReductionOpTest, ArgMin_int8) { + OpTester test("ArgMin", 13); + test.AddAttribute("axis", (int64_t)0); + test.AddAttribute("keepdims", (int64_t)0); + test.AddInput("data", {3, 2, 2}, + {1, 2, + 3, 4, + + 5, 6, + 7, 8, + + 9, 10, + 11, 12}); + test.AddOutput("reduced", {2, 2}, + {0, 0, + 0, 0}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); +} + +TEST(ReductionOpTest, ArgMin_int64) { + OpTester test("ArgMin", 13); + test.AddAttribute("axis", (int64_t)0); + test.AddAttribute("keepdims", (int64_t)0); + test.AddInput("data", {3, 2, 2}, + {1, 2, + 3, 4, + + 5, 6, + 7, 8, + + 9, 10, + 11, 12}); + test.AddOutput("reduced", {2, 2}, + {0, 0, + 0, 0}); + test.Run(); +} + TEST(ReductionOpTest, ArgMin_int32) { OpTester test("ArgMin"); test.AddAttribute("axis", (int64_t)0);