Skip to content

Commit

Permalink
Make argmin/armax support identical data types and add int64 support (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
cbourjau authored Sep 23, 2024
1 parent 80e9df8 commit 1a84f53
Show file tree
Hide file tree
Showing 4 changed files with 139 additions and 6 deletions.
12 changes: 6 additions & 6 deletions docs/OperatorKernels.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,12 @@ Do not modify directly.*
|Affine|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
|AffineGrid|*in* theta:**T1**<br> *in* size:**T2**<br> *out* grid:**T1**|20+|**T1** = tensor(double), tensor(float)<br/> **T2** = tensor(int64)|
|And|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T1**|7+|**T** = tensor(bool)<br/> **T1** = tensor(bool)|
|ArgMax|*in* data:**T**<br> *out* reduced:**tensor(int64)**|13+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int8), tensor(uint8)|
|||[11, 12]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int8), tensor(uint8)|
|||[1, 10]|**T** = tensor(float), tensor(int32), tensor(int8), tensor(uint8)|
|ArgMin|*in* data:**T**<br> *out* reduced:**tensor(int64)**|13+|**T** = tensor(double), tensor(float), tensor(int32)|
|||[11, 12]|**T** = tensor(double), tensor(float), tensor(int32)|
|||[1, 10]|**T** = tensor(float), tensor(int32)|
|ArgMax|*in* data:**T**<br> *out* reduced:**tensor(int64)**|13+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)|
|||[11, 12]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)|
|||[1, 10]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)|
|ArgMin|*in* data:**T**<br> *out* reduced:**tensor(int64)**|13+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)|
|||[11, 12]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)|
|||[1, 10]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)|
|Asin|*in* input:**T**<br> *out* output:**T**|7+|**T** = tensor(float)|
|Asinh|*in* input:**T**<br> *out* output:**T**|9+|**T** = tensor(float)|
|Atan|*in* input:**T**<br> *out* output:**T**|7+|**T** = tensor(float)|
Expand Down
42 changes: 42 additions & 0 deletions onnxruntime/core/providers/cpu/cpu_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -227,10 +227,16 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOn
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, int64_t,
ReduceSumSquare);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, float, ArgMax);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, double, ArgMax);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, int8_t, ArgMax);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, uint8_t, ArgMax);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, int32_t, ArgMax);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, int64_t, ArgMax);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, float, ArgMin);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, double, ArgMin);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, int8_t, ArgMin);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, uint8_t, ArgMin);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, int64_t, ArgMin);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, int32_t, ArgMin);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 13, GRU);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 13, LSTM);
Expand Down Expand Up @@ -408,9 +414,13 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOn
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, int8_t, ArgMax);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, uint8_t, ArgMax);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, int32_t, ArgMax);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, int64_t, ArgMax);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, float, ArgMin);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, double, ArgMin);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, int8_t, ArgMin);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, uint8_t, ArgMin);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, int32_t, ArgMin);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, int64_t, ArgMin);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, float, ReduceL1);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, double, ReduceL1);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, int32_t, ReduceL1);
Expand Down Expand Up @@ -636,9 +646,13 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int8_t, ArgMax);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, uint8_t, ArgMax);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int32_t, ArgMax);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int64_t, ArgMax);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, ArgMin);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, double, ArgMin);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int8_t, ArgMin);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, uint8_t, ArgMin);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int32_t, ArgMin);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int64_t, ArgMin);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 13, Reshape);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 14, Shape);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, Concat);
Expand Down Expand Up @@ -1443,16 +1457,28 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
int64_t, ReduceSumSquare)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10,
float, ArgMax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10,
double, ArgMax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10,
int8_t, ArgMax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10,
uint8_t, ArgMax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10,
int32_t, ArgMax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10,
int64_t, ArgMax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10,
float, ArgMin)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10,
double, ArgMin)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10,
int8_t, ArgMin)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10,
uint8_t, ArgMin)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10,
int32_t, ArgMin)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10,
int64_t, ArgMin)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 13, GRU)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 13, LSTM)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 13, RNN)>,
Expand Down Expand Up @@ -1725,12 +1751,20 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
uint8_t, ArgMax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12,
int32_t, ArgMax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12,
int64_t, ArgMax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12,
float, ArgMin)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12,
double, ArgMin)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12,
int8_t, ArgMin)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12,
uint8_t, ArgMin)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12,
int32_t, ArgMin)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12,
int64_t, ArgMin)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, Loop)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12,
Hardmax)>,
Expand Down Expand Up @@ -2065,11 +2099,19 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
ArgMax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int32_t,
ArgMax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int64_t,
ArgMax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, ArgMin)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, double,
ArgMin)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int32_t,
ArgMin)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int64_t,
ArgMin)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int8_t,
ArgMin)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, uint8_t,
ArgMin)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 13,
Reshape)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 14, Shape)>,
Expand Down
14 changes: 14 additions & 0 deletions onnxruntime/core/providers/cpu/reduction/reduction_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -288,22 +288,36 @@ REGISTER_UNARY_ELEMENTWISE_KERNEL_DOUBLE_ONLY(ReduceSumSquare, 18);
REGISTER_UNARY_ELEMENTWISE_KERNEL_INT64_ONLY(ReduceSumSquare, 18);

REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ArgMax, 1, 10);
REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_DOUBLE_ONLY(ArgMax, 1, 10)
REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_INT8_ONLY(ArgMax, 1, 10)
REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_INT64_ONLY(ArgMax, 1, 10)
REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_UINT8_ONLY(ArgMax, 1, 10)
REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ArgMax, 11, 12);
REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_DOUBLE_ONLY(ArgMax, 11, 12)
REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_INT8_ONLY(ArgMax, 11, 12)
REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_INT64_ONLY(ArgMax, 11, 12)
REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_UINT8_ONLY(ArgMax, 11, 12)
REGISTER_UNARY_ELEMENTWISE_KERNEL(ArgMax, 13);
REGISTER_UNARY_ELEMENTWISE_KERNEL_DOUBLE_ONLY(ArgMax, 13);
REGISTER_UNARY_ELEMENTWISE_KERNEL_INT64_ONLY(ArgMax, 13);
REGISTER_UNARY_ELEMENTWISE_KERNEL_INT8_ONLY(ArgMax, 13);
REGISTER_UNARY_ELEMENTWISE_KERNEL_UINT8_ONLY(ArgMax, 13);

REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ArgMin, 1, 10);
REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_DOUBLE_ONLY(ArgMin, 1, 10)
REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_INT8_ONLY(ArgMin, 1, 10)
REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_INT64_ONLY(ArgMin, 1, 10)
REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_UINT8_ONLY(ArgMin, 1, 10)
REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ArgMin, 11, 12);
REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_DOUBLE_ONLY(ArgMin, 11, 12)
REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_INT8_ONLY(ArgMin, 11, 12)
REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_INT64_ONLY(ArgMin, 11, 12)
REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_UINT8_ONLY(ArgMin, 11, 12)
REGISTER_UNARY_ELEMENTWISE_KERNEL(ArgMin, 13);
REGISTER_UNARY_ELEMENTWISE_KERNEL_DOUBLE_ONLY(ArgMin, 13);
REGISTER_UNARY_ELEMENTWISE_KERNEL_INT64_ONLY(ArgMin, 13);
REGISTER_UNARY_ELEMENTWISE_KERNEL_INT8_ONLY(ArgMin, 13);
REGISTER_UNARY_ELEMENTWISE_KERNEL_UINT8_ONLY(ArgMin, 13);

FastReduceKind operator|(FastReduceKind a, FastReduceKind b) {
return static_cast<FastReduceKind>(static_cast<uint8_t>(a) | static_cast<uint8_t>(b));
Expand Down
77 changes: 77 additions & 0 deletions onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3246,6 +3246,26 @@ TEST(ReductionOpTest, ArgMax_do_not_keepdims_2) {
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); // TensorRT: node1: at least 2 dimensions are required for input
}

TEST(ReductionOpTest, ArgMax_int64) {
OpTester test("ArgMax", 13);
test.AddAttribute("axis", (int64_t)1);
test.AddAttribute("keepdims", (int64_t)1);
test.AddInput<int64_t>("data", {3, 2, 2},
{1, 2,
3, 4,

5, 6,
7, 8,

9, 10,
11, 12});
test.AddOutput<int64_t>("reduced", {3, 1, 2},
{1, 1,
1, 1,
1, 1});
test.Run();
}

TEST(ReductionOpTest, ArgMax_int32) {
OpTester test("ArgMax");
test.AddAttribute("axis", (int64_t)1);
Expand Down Expand Up @@ -3511,6 +3531,63 @@ TEST(ReductionOpTest, ArgMin_do_not_keepdims_2_select_last) {
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
}

TEST(ReductionOpTest, ArgMin_uint8) {
OpTester test("ArgMin", 13);
test.AddAttribute("axis", (int64_t)0);
test.AddAttribute("keepdims", (int64_t)0);
test.AddInput<uint8_t>("data", {3, 2, 2},
{1, 2,
3, 4,

5, 6,
7, 8,

9, 10,
11, 12});
test.AddOutput<int64_t>("reduced", {2, 2},
{0, 0,
0, 0});
test.Run();
}

TEST(ReductionOpTest, ArgMin_int8) {
OpTester test("ArgMin", 13);
test.AddAttribute("axis", (int64_t)0);
test.AddAttribute("keepdims", (int64_t)0);
test.AddInput<int8_t>("data", {3, 2, 2},
{1, 2,
3, 4,

5, 6,
7, 8,

9, 10,
11, 12});
test.AddOutput<int64_t>("reduced", {2, 2},
{0, 0,
0, 0});
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
}

TEST(ReductionOpTest, ArgMin_int64) {
OpTester test("ArgMin", 13);
test.AddAttribute("axis", (int64_t)0);
test.AddAttribute("keepdims", (int64_t)0);
test.AddInput<int64_t>("data", {3, 2, 2},
{1, 2,
3, 4,

5, 6,
7, 8,

9, 10,
11, 12});
test.AddOutput<int64_t>("reduced", {2, 2},
{0, 0,
0, 0});
test.Run();
}

TEST(ReductionOpTest, ArgMin_int32) {
OpTester test("ArgMin");
test.AddAttribute("axis", (int64_t)0);
Expand Down

0 comments on commit 1a84f53

Please sign in to comment.