diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index 121240e6e18f9..407e08c96a891 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -27,12 +27,12 @@ Do not modify directly.*
|Affine|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(float)|
|AffineGrid|*in* theta:**T1**
*in* size:**T2**
*out* grid:**T1**|20+|**T1** = tensor(double), tensor(float)
**T2** = tensor(int64)|
|And|*in* A:**T**
*in* B:**T**
*out* C:**T1**|7+|**T** = tensor(bool)
**T1** = tensor(bool)|
-|ArgMax|*in* data:**T**
*out* reduced:**tensor(int64)**|13+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int8), tensor(uint8)|
-|||[11, 12]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int8), tensor(uint8)|
-|||[1, 10]|**T** = tensor(float), tensor(int32), tensor(int8), tensor(uint8)|
-|ArgMin|*in* data:**T**
*out* reduced:**tensor(int64)**|13+|**T** = tensor(double), tensor(float), tensor(int32)|
-|||[11, 12]|**T** = tensor(double), tensor(float), tensor(int32)|
-|||[1, 10]|**T** = tensor(float), tensor(int32)|
+|ArgMax|*in* data:**T**
*out* reduced:**tensor(int64)**|13+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)|
+|||[11, 12]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)|
+|||[1, 10]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)|
+|ArgMin|*in* data:**T**
*out* reduced:**tensor(int64)**|13+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)|
+|||[11, 12]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)|
+|||[1, 10]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)|
|Asin|*in* input:**T**
*out* output:**T**|7+|**T** = tensor(float)|
|Asinh|*in* input:**T**
*out* output:**T**|9+|**T** = tensor(float)|
|Atan|*in* input:**T**
*out* output:**T**|7+|**T** = tensor(float)|
diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc
index 7ed776f1358a5..7b1b136eb091e 100644
--- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc
+++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc
@@ -227,10 +227,16 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOn
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, int64_t,
ReduceSumSquare);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, float, ArgMax);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, double, ArgMax);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, int8_t, ArgMax);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, uint8_t, ArgMax);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, int32_t, ArgMax);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, int64_t, ArgMax);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, float, ArgMin);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, double, ArgMin);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, int8_t, ArgMin);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, uint8_t, ArgMin);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, int64_t, ArgMin);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, int32_t, ArgMin);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 13, GRU);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 13, LSTM);
@@ -408,9 +414,13 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOn
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, int8_t, ArgMax);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, uint8_t, ArgMax);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, int32_t, ArgMax);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, int64_t, ArgMax);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, float, ArgMin);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, double, ArgMin);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, int8_t, ArgMin);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, uint8_t, ArgMin);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, int32_t, ArgMin);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, int64_t, ArgMin);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, float, ReduceL1);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, double, ReduceL1);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, int32_t, ReduceL1);
@@ -636,9 +646,13 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int8_t, ArgMax);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, uint8_t, ArgMax);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int32_t, ArgMax);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int64_t, ArgMax);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, ArgMin);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, double, ArgMin);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int8_t, ArgMin);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, uint8_t, ArgMin);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int32_t, ArgMin);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int64_t, ArgMin);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 13, Reshape);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 14, Shape);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, Concat);
@@ -1443,16 +1457,28 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
int64_t, ReduceSumSquare)>,
BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
@@ -1725,12 +1751,20 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
uint8_t, ArgMax)>,
BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
@@ -2065,11 +2099,19 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
ArgMax)>,
BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
diff --git a/onnxruntime/core/providers/cpu/reduction/reduction_ops.cc b/onnxruntime/core/providers/cpu/reduction/reduction_ops.cc
index 5aac1d9387f57..24fbfbe8d525b 100644
--- a/onnxruntime/core/providers/cpu/reduction/reduction_ops.cc
+++ b/onnxruntime/core/providers/cpu/reduction/reduction_ops.cc
@@ -288,22 +288,36 @@ REGISTER_UNARY_ELEMENTWISE_KERNEL_DOUBLE_ONLY(ReduceSumSquare, 18);
REGISTER_UNARY_ELEMENTWISE_KERNEL_INT64_ONLY(ReduceSumSquare, 18);
REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ArgMax, 1, 10);
+REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_DOUBLE_ONLY(ArgMax, 1, 10)
REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_INT8_ONLY(ArgMax, 1, 10)
+REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_INT64_ONLY(ArgMax, 1, 10)
REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_UINT8_ONLY(ArgMax, 1, 10)
REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ArgMax, 11, 12);
REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_DOUBLE_ONLY(ArgMax, 11, 12)
REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_INT8_ONLY(ArgMax, 11, 12)
+REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_INT64_ONLY(ArgMax, 11, 12)
REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_UINT8_ONLY(ArgMax, 11, 12)
REGISTER_UNARY_ELEMENTWISE_KERNEL(ArgMax, 13);
REGISTER_UNARY_ELEMENTWISE_KERNEL_DOUBLE_ONLY(ArgMax, 13);
+REGISTER_UNARY_ELEMENTWISE_KERNEL_INT64_ONLY(ArgMax, 13);
REGISTER_UNARY_ELEMENTWISE_KERNEL_INT8_ONLY(ArgMax, 13);
REGISTER_UNARY_ELEMENTWISE_KERNEL_UINT8_ONLY(ArgMax, 13);
REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ArgMin, 1, 10);
+REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_DOUBLE_ONLY(ArgMin, 1, 10)
+REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_INT8_ONLY(ArgMin, 1, 10)
+REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_INT64_ONLY(ArgMin, 1, 10)
+REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_UINT8_ONLY(ArgMin, 1, 10)
REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ArgMin, 11, 12);
REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_DOUBLE_ONLY(ArgMin, 11, 12)
+REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_INT8_ONLY(ArgMin, 11, 12)
+REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_INT64_ONLY(ArgMin, 11, 12)
+REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_UINT8_ONLY(ArgMin, 11, 12)
REGISTER_UNARY_ELEMENTWISE_KERNEL(ArgMin, 13);
REGISTER_UNARY_ELEMENTWISE_KERNEL_DOUBLE_ONLY(ArgMin, 13);
+REGISTER_UNARY_ELEMENTWISE_KERNEL_INT64_ONLY(ArgMin, 13);
+REGISTER_UNARY_ELEMENTWISE_KERNEL_INT8_ONLY(ArgMin, 13);
+REGISTER_UNARY_ELEMENTWISE_KERNEL_UINT8_ONLY(ArgMin, 13);
FastReduceKind operator|(FastReduceKind a, FastReduceKind b) {
return static_cast(static_cast(a) | static_cast(b));
diff --git a/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc b/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc
index 0697187a777d6..0968bc32e0de4 100644
--- a/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc
+++ b/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc
@@ -3246,6 +3246,26 @@ TEST(ReductionOpTest, ArgMax_do_not_keepdims_2) {
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); // TensorRT: node1: at least 2 dimensions are required for input
}
+TEST(ReductionOpTest, ArgMax_int64) {
+ OpTester test("ArgMax", 13);
+ test.AddAttribute("axis", (int64_t)1);
+ test.AddAttribute("keepdims", (int64_t)1);
+ test.AddInput("data", {3, 2, 2},
+ {1, 2,
+ 3, 4,
+
+ 5, 6,
+ 7, 8,
+
+ 9, 10,
+ 11, 12});
+ test.AddOutput("reduced", {3, 1, 2},
+ {1, 1,
+ 1, 1,
+ 1, 1});
+ test.Run();
+}
+
TEST(ReductionOpTest, ArgMax_int32) {
OpTester test("ArgMax");
test.AddAttribute("axis", (int64_t)1);
@@ -3511,6 +3531,63 @@ TEST(ReductionOpTest, ArgMin_do_not_keepdims_2_select_last) {
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
}
+TEST(ReductionOpTest, ArgMin_uint8) {
+ OpTester test("ArgMin", 13);
+ test.AddAttribute("axis", (int64_t)0);
+ test.AddAttribute("keepdims", (int64_t)0);
+ test.AddInput("data", {3, 2, 2},
+ {1, 2,
+ 3, 4,
+
+ 5, 6,
+ 7, 8,
+
+ 9, 10,
+ 11, 12});
+ test.AddOutput("reduced", {2, 2},
+ {0, 0,
+ 0, 0});
+ test.Run();
+}
+
+TEST(ReductionOpTest, ArgMin_int8) {
+ OpTester test("ArgMin", 13);
+ test.AddAttribute("axis", (int64_t)0);
+ test.AddAttribute("keepdims", (int64_t)0);
+ test.AddInput("data", {3, 2, 2},
+ {1, 2,
+ 3, 4,
+
+ 5, 6,
+ 7, 8,
+
+ 9, 10,
+ 11, 12});
+ test.AddOutput("reduced", {2, 2},
+ {0, 0,
+ 0, 0});
+ test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
+}
+
+TEST(ReductionOpTest, ArgMin_int64) {
+ OpTester test("ArgMin", 13);
+ test.AddAttribute("axis", (int64_t)0);
+ test.AddAttribute("keepdims", (int64_t)0);
+ test.AddInput("data", {3, 2, 2},
+ {1, 2,
+ 3, 4,
+
+ 5, 6,
+ 7, 8,
+
+ 9, 10,
+ 11, 12});
+ test.AddOutput("reduced", {2, 2},
+ {0, 0,
+ 0, 0});
+ test.Run();
+}
+
TEST(ReductionOpTest, ArgMin_int32) {
OpTester test("ArgMin");
test.AddAttribute("axis", (int64_t)0);