Skip to content

Commit

Permalink
Implement ArgMin and ArgMax for int64
Browse files Browse the repository at this point in the history
  • Loading branch information
neNasko1 committed Sep 7, 2023
1 parent 0a3eb60 commit da91ffc
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 0 deletions.
6 changes: 6 additions & 0 deletions onnxruntime/core/providers/cpu/cpu_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -596,9 +596,11 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int8_t, ArgMax);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, uint8_t, ArgMax);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int32_t, ArgMax);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int64_t, ArgMax);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, ArgMin);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, double, ArgMin);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int32_t, ArgMin);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int64_t, ArgMin);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 13, Reshape);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 14, Shape);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, Concat);
Expand Down Expand Up @@ -1849,12 +1851,16 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
uint8_t, ArgMax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13,
int32_t, ArgMax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13,
int64_t, ArgMax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13,
float, ArgMin)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13,
double, ArgMin)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13,
int32_t, ArgMin)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13,
int64_t, ArgMin)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 13,
Reshape)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 14, Shape)>,
Expand Down
2 changes: 2 additions & 0 deletions onnxruntime/core/providers/cpu/reduction/reduction_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -256,12 +256,14 @@ REGISTER_UNARY_ELEMENTWISE_KERNEL(ArgMax, 13);
REGISTER_UNARY_ELEMENTWISE_KERNEL_DOUBLE_ONLY(ArgMax, 13);
REGISTER_UNARY_ELEMENTWISE_KERNEL_INT8_ONLY(ArgMax, 13);
REGISTER_UNARY_ELEMENTWISE_KERNEL_UINT8_ONLY(ArgMax, 13);
REGISTER_UNARY_ELEMENTWISE_KERNEL_INT64_ONLY(ArgMax, 13);

REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ArgMin, 1, 10);
REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ArgMin, 11, 12);
REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_DOUBLE_ONLY(ArgMin, 11, 12)
REGISTER_UNARY_ELEMENTWISE_KERNEL(ArgMin, 13);
REGISTER_UNARY_ELEMENTWISE_KERNEL_DOUBLE_ONLY(ArgMin, 13);
REGISTER_UNARY_ELEMENTWISE_KERNEL_INT64_ONLY(ArgMin, 13);

FastReduceKind operator|(FastReduceKind a, FastReduceKind b) {
return static_cast<FastReduceKind>(static_cast<uint8_t>(a) | static_cast<uint8_t>(b));
Expand Down
81 changes: 81 additions & 0 deletions onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include <random>
#include <cmath>
#include <type_traits>
#include <limits>
#include "gtest/gtest.h"
#include "test/common/dnnl_op_test_utils.h"
#include "test/common/tensor_op_test_utils.h"
Expand Down Expand Up @@ -2999,6 +3000,46 @@ TEST(ReductionOpTest, ArgMax_uint8) {
test.Run();
}

TEST(ReductionOpTest, ArgMax_int64) {
OpTester test("ArgMin");
test.AddAttribute("axis", (int64_t)0);
test.AddAttribute("keepdims", (int64_t)0);
test.AddInput<int64_t>("data", {3, 2, 2},
{std::numeric_limits<int64_t>::max(), 2,
3, 4,

5, 6,
7, 8,

9, 10,
11, 12});
test.AddOutput<int64_t>("reduced", {2, 2},
{0, 0,
0, 0});
test.Run();
}

TEST(ReductionOpTest, ArgMax_int64_select_last) {
OpTester test("ArgMin", 12);
test.AddAttribute("axis", (int64_t)0);
test.AddAttribute("keepdims", (int64_t)0);
test.AddAttribute("select_last_index", (int64_t)1);
test.AddInput<int64_t>("data", {3, 2, 2},
{std::numeric_limits<int64_t>::max(), 2,
3, 4,

std::numeric_limits<int64_t>::min(), 6,
7, 8,

9, 10,
11, 12});
test.AddOutput<int64_t>("reduced", {2, 2},
{1, 0,
0, 0});

test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
}

TEST(ReductionOpTest, ArgMax2D) {
OpTester test("ArgMax");
test.AddAttribute("axis", (int64_t)1);
Expand Down Expand Up @@ -3197,6 +3238,46 @@ TEST(ReductionOpTest, ArgMin_int32_neg_axis) {
test.Run();
}

TEST(ReductionOpTest, ArgMin_int64) {
OpTester test("ArgMin");
test.AddAttribute("axis", (int64_t)0);
test.AddAttribute("keepdims", (int64_t)0);
test.AddInput<int64_t>("data", {3, 2, 2},
{std::numeric_limits<int64_t>::min(), 2,
3, 4,

5, 6,
7, 8,

9, 10,
11, 12});
test.AddOutput<int64_t>("reduced", {2, 2},
{0, 0,
0, 0});
test.Run();
}

TEST(ReductionOpTest, ArgMin_int64_select_last) {
OpTester test("ArgMin", 12);
test.AddAttribute("axis", (int64_t)0);
test.AddAttribute("keepdims", (int64_t)0);
test.AddAttribute("select_last_index", (int64_t)1);
test.AddInput<int64_t>("data", {3, 2, 2},
{std::numeric_limits<int64_t>::min(), 2,
3, 4,

std::numeric_limits<int64_t>::min(), 6,
7, 8,

9, 10,
11, 12});
test.AddOutput<int64_t>("reduced", {2, 2},
{1, 0,
0, 0});

test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
}

TEST(ReductionOpTest, OptimizeShapeForFastReduce_ReduceDimWithZero1) {
FastReduceKind fast_kind;
TensorShapeVector fast_shape, fast_output_shape, fast_axes;
Expand Down

0 comments on commit da91ffc

Please sign in to comment.