From 927ffe48ea11dbab4b393988b20699855c6fa785 Mon Sep 17 00:00:00 2001 From: Aditya Goel <48102515+adityagoel4512@users.noreply.github.com> Date: Fri, 12 Jan 2024 20:43:44 +0000 Subject: [PATCH] Label encoder opset4 (#17977) ### Description Implements LabelEncoder as per `ai.onnx.ml` opset 4 for the upcoming ONNX 1.15 release. ~~This currently depends on a new ONNX release candidate and so is marked as draft in the meantime.~~ ### Motivation and Context Closes https://github.com/microsoft/onnxruntime/issues/17602 --- docs/OperatorKernels.md | 3 +- .../providers/cpu/cpu_execution_provider.cc | 1490 ++++++++++------- .../core/providers/cpu/ml/label_encoder.cc | 430 +++-- .../core/providers/cpu/ml/label_encoder.h | 191 ++- .../providers/cpu/ml/label_encoder_test.cc | 282 +++- .../onnx_backend_test_series_filters.jsonc | 4 - 6 files changed, 1587 insertions(+), 813 deletions(-) diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index a2bb39da76235..394bd7ad2abae 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -425,7 +425,8 @@ Do not modify directly.* |DictVectorizer|*in* X:**T1**
*out* Y:**T2**|1+|**T1** = map(int64,tensor(double)), map(int64,tensor(float)), map(int64,tensor(string)), map(string,tensor(double)), map(string,tensor(float)), map(string,tensor(int64))
**T2** = tensor(double), tensor(float), tensor(int64), tensor(string)| |FeatureVectorizer|*in* X:**T1**
*out* Y:**tensor(float)**|1+|**T1** = tensor(double), tensor(float), tensor(int32), tensor(int64)| |Imputer|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(int64)| -|LabelEncoder|*in* X:**T1**
*out* Y:**T2**|2+|**T1** = tensor(float), tensor(int64), tensor(string)
**T2** = tensor(float), tensor(int64), tensor(string)| +|LabelEncoder|*in* X:**T1**
*out* Y:**T2**|4+|**T1** = tensor(double), tensor(float), tensor(int64), tensor(string)
**T2** = tensor(double), tensor(float), tensor(int16), tensor(int64), tensor(string)| +|||[2, 3]|**T1** = tensor(float), tensor(int64), tensor(string)
**T2** = tensor(float), tensor(int64), tensor(string)| |||1|**T1** = tensor(int64), tensor(string)
**T2** = tensor(int64), tensor(string)| |LinearClassifier|*in* X:**T1**
*out* Y:**T2**
*out* Z:**tensor(float)**|1+|**T1** = tensor(double), tensor(float), tensor(int32), tensor(int64)
**T2** = tensor(int64), tensor(string)| |LinearRegressor|*in* X:**T**
*out* Y:**tensor(float)**|1+|**T** = tensor(float)| diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index 6aef03a32db09..cbdf79caf3afd 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -25,8 +25,7 @@ struct KernelRegistryAndStatus { namespace onnxruntime { CPUExecutionProvider::CPUExecutionProvider(const CPUExecutionProviderInfo& info) - : IExecutionProvider{onnxruntime::kCpuExecutionProvider}, info_{info} { -} + : IExecutionProvider{onnxruntime::kCpuExecutionProvider}, info_{info} {} std::vector CPUExecutionProvider::CreatePreferredAllocators() { bool create_arena = info_.create_arena; @@ -155,8 +154,10 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOn class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, double, Softmax); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 9, float, TopK); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 9, double, TopK); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 8, float, BatchNormalization); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 8, double, BatchNormalization); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 8, float, + BatchNormalization); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 8, double, + BatchNormalization); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, Conv); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, ConvTranspose); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 8, Flatten); @@ -185,10 +186,8 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOn class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, int32_t, ReduceL2); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, int64_t, ReduceL2); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, float, ReduceLogSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, int32_t, - ReduceLogSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, int64_t, - ReduceLogSum); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, int32_t, ReduceLogSum); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, int64_t, ReduceLogSum); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, float, ReduceLogSumExp); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, double, @@ -290,17 +289,28 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOn class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 12, Sign); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Shrink); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 12, float, Erf); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, int64_t_int64_t_int64_t, OneHot); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, float_int64_t_int64_t, OneHot); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, int64_t_string_int64_t, OneHot); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, float_string_int64_t, OneHot); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, float_float_float, OneHot); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, int64_t_int32_t_float, OneHot); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, int64_t_float_int64_t, OneHot); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, int32_t_float_int32_t, OneHot); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, int32_t_float_float, OneHot); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, int64_t_float_float, OneHot); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, int64_t_float_int32_t, OneHot); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, + int64_t_int64_t_int64_t, OneHot); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, float_int64_t_int64_t, + OneHot); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, int64_t_string_int64_t, + OneHot); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, float_string_int64_t, + OneHot); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, float_float_float, + OneHot); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, int64_t_int32_t_float, + OneHot); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, int64_t_float_int64_t, + OneHot); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, int32_t_float_int32_t, + OneHot); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, int32_t_float_float, + OneHot); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, int64_t_float_float, + OneHot); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, int64_t_float_int32_t, + OneHot); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, MaxUnpool); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Sinh); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Cosh); @@ -331,8 +341,10 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOn class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 12, double, MatMul); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 12, int32_t, MatMul); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 12, int64_t, MatMul); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 13, float, BatchNormalization); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 13, double, BatchNormalization); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 13, float, + BatchNormalization); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 13, double, + BatchNormalization); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 15, PRelu); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 9, float, Upsample); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 9, int32_t, Upsample); @@ -350,11 +362,16 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOn class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 10, int8_t, Resize); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 10, uint8_t, Resize); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, ThresholdedRelu); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 12, uint8_t, DequantizeLinear); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 12, int8_t, DequantizeLinear); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 12, int32_t, DequantizeLinear); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 12, uint8_t, QuantizeLinear); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 12, int8_t, QuantizeLinear); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 12, uint8_t, + DequantizeLinear); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 12, int8_t, + DequantizeLinear); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 12, int32_t, + DequantizeLinear); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 12, uint8_t, + QuantizeLinear); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 12, int8_t, + QuantizeLinear); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, uint8_t, QLinearMatMul); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, int8_t, QLinearMatMul); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, uint8_t, MatMulInteger); @@ -400,12 +417,18 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOn class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, int32_t, ReduceL2); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, int64_t, ReduceL2); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, float, ReduceLogSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, int32_t, ReduceLogSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, int64_t, ReduceLogSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, float, ReduceLogSumExp); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, double, ReduceLogSumExp); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, int32_t, ReduceLogSumExp); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, int64_t, ReduceLogSumExp); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, int32_t, + ReduceLogSum); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, int64_t, + ReduceLogSum); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, float, + ReduceLogSumExp); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, double, + ReduceLogSumExp); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, int32_t, + ReduceLogSumExp); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, int64_t, + ReduceLogSumExp); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 11, float, ReduceMax); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 11, double, ReduceMax); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 11, int32_t, ReduceMax); @@ -424,10 +447,14 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOn class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, double, ReduceSum); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, int32_t, ReduceSum); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, int64_t, ReduceSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, float, ReduceSumSquare); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, double, ReduceSumSquare); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, int32_t, ReduceSumSquare); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, int64_t, ReduceSumSquare); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, float, + ReduceSumSquare); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, double, + ReduceSumSquare); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, int32_t, + ReduceSumSquare); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, int64_t, + ReduceSumSquare); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, Hardmax); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, float, LogSoftmax); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, double, LogSoftmax); @@ -453,7 +480,8 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDoma class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, Conv); #ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, MLFloat16, Conv); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 18, MLFloat16, AveragePool); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 18, MLFloat16, + AveragePool); #endif class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, ConvTranspose); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, If); @@ -531,15 +559,22 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, Ei // class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, MLFloat16_float, Dropout); // class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, MLFloat16_double, Dropout); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, 12, float_float, Dropout); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, 12, float_double, Dropout); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, 12, double_float, Dropout); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, 12, double_double, Dropout); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, 12, float_double, + Dropout); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, 12, double_float, + Dropout); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, 12, double_double, + Dropout); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, Celu); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, 15, float, GreaterOrEqual); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, 15, double, GreaterOrEqual); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, 15, int32_t, GreaterOrEqual); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, 15, int64_t, GreaterOrEqual); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, 15, float, + GreaterOrEqual); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, 15, double, + GreaterOrEqual); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, 15, int32_t, + GreaterOrEqual); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, 15, int64_t, + GreaterOrEqual); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, 15, float, LessOrEqual); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, 15, double, LessOrEqual); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, 15, int32_t, LessOrEqual); @@ -549,9 +584,12 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOn class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, Erf); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 18, Cast); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, Clip); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 18, uint8_t, DequantizeLinear); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 18, int8_t, DequantizeLinear); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 18, int32_t, DequantizeLinear); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 18, uint8_t, + DequantizeLinear); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 18, int8_t, + DequantizeLinear); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 18, int32_t, + DequantizeLinear); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, Expand); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, double, Expand); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int8_t, Expand); @@ -577,8 +615,10 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, Min); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, Max); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, Mean); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 18, uint8_t, QuantizeLinear); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 18, int8_t, QuantizeLinear); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 18, uint8_t, + QuantizeLinear); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 18, int8_t, + QuantizeLinear); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, Sigmoid); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, Sign); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 18, Size); @@ -699,12 +739,18 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOn class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 17, int32_t, ReduceL2); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 17, int64_t, ReduceL2); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 17, float, ReduceLogSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 17, int32_t, ReduceLogSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 17, int64_t, ReduceLogSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 17, float, ReduceLogSumExp); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 17, double, ReduceLogSumExp); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 17, int32_t, ReduceLogSumExp); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 17, int64_t, ReduceLogSumExp); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 17, int32_t, + ReduceLogSum); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 17, int64_t, + ReduceLogSum); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 17, float, + ReduceLogSumExp); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 17, double, + ReduceLogSumExp); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 17, int32_t, + ReduceLogSumExp); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 17, int64_t, + ReduceLogSumExp); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 17, float, ReduceMax); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 17, double, ReduceMax); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 17, int32_t, ReduceMax); @@ -723,10 +769,14 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOn class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 17, float, ReduceProd); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 17, int32_t, ReduceProd); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 17, int64_t, ReduceProd); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 17, float, ReduceSumSquare); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 17, double, ReduceSumSquare); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 17, int32_t, ReduceSumSquare); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 17, int64_t, ReduceSumSquare); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 17, float, + ReduceSumSquare); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 17, double, + ReduceSumSquare); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 17, int32_t, + ReduceSumSquare); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 17, int64_t, + ReduceSumSquare); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, ReduceSum); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, double, ReduceSum); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int32_t, ReduceSum); @@ -774,8 +824,10 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, int64_t, Div); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, 18, Reshape); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, 15, Identity); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, 14, float, BatchNormalization); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, 14, double, BatchNormalization); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, 14, float, + BatchNormalization); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, 14, double, + BatchNormalization); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, GRU); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, LSTM); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, RNN); @@ -1035,96 +1087,127 @@ KernelCreateInfo BuildKernelCreateInfo() { Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { static const BuildKernelCreateInfoFn function_table[] = { BuildKernelCreateInfo, // default entry to avoid the list become empty after ops-reducing - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -1155,7 +1238,8 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { double, Equal)>, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -1163,8 +1247,10 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -1298,31 +1380,28 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, - BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, // Opset 9 BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -1393,10 +1472,10 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { double, IsNaN)>, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -1441,36 +1519,36 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { int64_t, NonZero)>, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -1499,40 +1577,42 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, // opset 11 - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -1552,10 +1632,8 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { float, Equal)>, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -1611,15 +1698,17 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, // REVIEW(codemzs): ConstEigenVectorArrayMap.cast, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, // opset 13 BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, // OpSet 14 - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2190,29 +2258,37 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, #if !defined(DISABLE_OPTIONAL_TYPE) - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, #endif // Opset 16 - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2246,14 +2322,14 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { LayerNormalization)>, // Opset 18 - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2388,8 +2496,10 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { DequantizeLinear)>, #endif BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2413,9 +2523,12 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { #endif BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2436,18 +2549,27 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, #if !defined(DISABLE_FLOAT8_TYPES) - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, #endif BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2468,23 +2590,37 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { #ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED Status RegisterFp16Kernels(KernelRegistry& kernel_registry) { static const BuildKernelCreateInfoFn function_table[] = { - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, }; for (auto& function_table_entry : function_table) { @@ -2532,23 +2668,37 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 1, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 1, int32_t, Scaler); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 1, SVMClassifier); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 1, SVMRegressor); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 1, 2, float, TreeEnsembleClassifier); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 1, 2, double, TreeEnsembleClassifier); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 1, 2, int64_t, TreeEnsembleClassifier); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 1, 2, int32_t, TreeEnsembleClassifier); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 1, 2, float, TreeEnsembleRegressor); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 1, 2, double, TreeEnsembleRegressor); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 1, 2, float, + TreeEnsembleClassifier); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 1, 2, double, + TreeEnsembleClassifier); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 1, 2, int64_t, + TreeEnsembleClassifier); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 1, 2, int32_t, + TreeEnsembleClassifier); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 1, 2, float, + TreeEnsembleRegressor); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 1, 2, double, + TreeEnsembleRegressor); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 1, ZipMap); - -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 2, float_string, LabelEncoder); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 2, string_float, LabelEncoder); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 2, int64_float, LabelEncoder); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 2, float_int64, LabelEncoder); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 2, int64_string, LabelEncoder); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 2, string_int64, LabelEncoder); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 2, int64_int64, LabelEncoder); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 2, string_string, LabelEncoder); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 2, float_float, LabelEncoder); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 2, 3, float_string, + LabelEncoder); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 2, 3, string_float, + LabelEncoder); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 2, 3, int64_float, + LabelEncoder); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 2, 3, float_int64, + LabelEncoder); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 2, 3, int64_string, + LabelEncoder); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 2, 3, string_int64, + LabelEncoder); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 2, 3, int64_int64, + LabelEncoder); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 2, 3, string_string, + LabelEncoder); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 2, 3, float_float, + LabelEncoder); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 3, float, TreeEnsembleClassifier); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 3, double, TreeEnsembleClassifier); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 3, int64_t, TreeEnsembleClassifier); @@ -2556,6 +2706,22 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 3, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 3, float, TreeEnsembleRegressor); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 3, double, TreeEnsembleRegressor); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 4, float_string, LabelEncoder); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 4, string_float, LabelEncoder); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 4, int64_float, LabelEncoder); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 4, float_int64, LabelEncoder); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 4, int64_string, LabelEncoder); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 4, string_int64, LabelEncoder); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 4, int64_int64, LabelEncoder); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 4, string_string, LabelEncoder); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 4, float_float, LabelEncoder); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 4, string_int16, LabelEncoder); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 4, double_string, LabelEncoder); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 4, string_double, LabelEncoder); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 4, int64_double, LabelEncoder); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 4, double_int64, LabelEncoder); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 4, double_double, LabelEncoder); + template <> KernelCreateInfo BuildKernelCreateInfo() { KernelCreateInfo info; @@ -2606,46 +2772,45 @@ Status RegisterOnnxMLOperatorKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, }; for (auto& function_table_entry : function_table) { diff --git a/onnxruntime/core/providers/cpu/ml/label_encoder.cc b/onnxruntime/core/providers/cpu/ml/label_encoder.cc index 7f626cfefb0c8..65102b62a963b 100644 --- a/onnxruntime/core/providers/cpu/ml/label_encoder.cc +++ b/onnxruntime/core/providers/cpu/ml/label_encoder.cc @@ -10,14 +10,12 @@ namespace onnxruntime { namespace ml { ONNX_CPU_OPERATOR_VERSIONED_ML_KERNEL( - LabelEncoder, - 1, 1, - KernelDefBuilder().TypeConstraint("T1", - std::vector{DataTypeImpl::GetTensorType(), - DataTypeImpl::GetTensorType()}) - .TypeConstraint("T2", - std::vector{DataTypeImpl::GetTensorType(), - DataTypeImpl::GetTensorType()}) + LabelEncoder, 1, 1, + KernelDefBuilder() + .TypeConstraint("T1", std::vector{DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType()}) + .TypeConstraint("T2", std::vector{DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType()}) .SinceVersion(1, 2), LabelEncoder); @@ -39,12 +37,11 @@ Status LabelEncoder::Compute(OpKernelContext* context) const { // map isn't going to change so get end() once instead of calling inside the for_each loop const auto map_end = string_to_int_map_.end(); - std::for_each(input.begin(), input.end(), - [&out, &map_end, this](const std::string& value) { - auto map_to = string_to_int_map_.find(value); - *out = map_to == map_end ? default_int_ : map_to->second; - ++out; - }); + std::for_each(input.begin(), input.end(), [&out, &map_end, this](const std::string& value) { + auto map_to = string_to_int_map_.find(value); + *out = map_to == map_end ? default_int_ : map_to->second; + ++out; + }); } else { if (!Y.IsDataTypeString()) return Status(ONNXRUNTIME, FAIL, "Input of tensor(int64) must have output of tensor(string)"); @@ -55,169 +52,346 @@ Status LabelEncoder::Compute(OpKernelContext* context) const { const auto map_end = int_to_string_map_.end(); - std::for_each(input.begin(), input.end(), - [&out, &map_end, this](const int64_t& value) { - auto map_to = int_to_string_map_.find(value); - *out = map_to == map_end ? default_string_ : map_to->second; - ++out; - }); + std::for_each(input.begin(), input.end(), [&out, &map_end, this](const int64_t& value) { + auto map_to = int_to_string_map_.find(value); + *out = map_to == map_end ? default_string_ : map_to->second; + ++out; + }); } return Status::OK(); } -ONNX_CPU_OPERATOR_TYPED_ML_KERNEL( - LabelEncoder, - 2, - float_string, - KernelDefBuilder().TypeConstraint("T1", - std::vector{DataTypeImpl::GetTensorType()}) - .TypeConstraint("T2", - std::vector{DataTypeImpl::GetTensorType()}), +ONNX_CPU_OPERATOR_VERSIONED_TYPED_ML_KERNEL( + LabelEncoder, 2, 3, float_string, + KernelDefBuilder() + .TypeConstraint("T1", std::vector{DataTypeImpl::GetTensorType()}) + .TypeConstraint("T2", std::vector{DataTypeImpl::GetTensorType()}), LabelEncoder_2); template <> void LabelEncoder_2::InitializeSomeFields(const OpKernelInfo& info) { - _key_field_name = "keys_floats"; - _value_field_name = "values_strings"; - info.GetAttrOrDefault("default_string", &_default_value, std::string("_Unused")); -}; + key_field_name_ = "keys_floats"; + value_field_name_ = "values_strings"; + info.GetAttrOrDefault("default_string", &default_value_, std::string("_Unused")); +} -ONNX_CPU_OPERATOR_TYPED_ML_KERNEL( - LabelEncoder, - 2, - string_float, - KernelDefBuilder().TypeConstraint("T1", - std::vector{DataTypeImpl::GetTensorType()}) - .TypeConstraint("T2", - std::vector{DataTypeImpl::GetTensorType()}), +ONNX_CPU_OPERATOR_VERSIONED_TYPED_ML_KERNEL( + LabelEncoder, 2, 3, string_float, + KernelDefBuilder() + .TypeConstraint("T1", std::vector{DataTypeImpl::GetTensorType()}) + .TypeConstraint("T2", std::vector{DataTypeImpl::GetTensorType()}), LabelEncoder_2); template <> void LabelEncoder_2::InitializeSomeFields(const OpKernelInfo& info) { - _key_field_name = "keys_strings"; - _value_field_name = "values_floats"; - info.GetAttrOrDefault("default_float", &_default_value, -0.0f); -}; + key_field_name_ = "keys_strings"; + value_field_name_ = "values_floats"; + info.GetAttrOrDefault("default_float", &default_value_, -0.0f); +} -ONNX_CPU_OPERATOR_TYPED_ML_KERNEL( - LabelEncoder, - 2, - int64_float, - KernelDefBuilder().TypeConstraint("T1", - std::vector{DataTypeImpl::GetTensorType()}) - .TypeConstraint("T2", - std::vector{DataTypeImpl::GetTensorType()}), +ONNX_CPU_OPERATOR_VERSIONED_TYPED_ML_KERNEL( + LabelEncoder, 2, 3, int64_float, + KernelDefBuilder() + .TypeConstraint("T1", std::vector{DataTypeImpl::GetTensorType()}) + .TypeConstraint("T2", std::vector{DataTypeImpl::GetTensorType()}), LabelEncoder_2); template <> void LabelEncoder_2::InitializeSomeFields(const OpKernelInfo& info) { - _key_field_name = "keys_int64s"; - _value_field_name = "values_floats"; - info.GetAttrOrDefault("default_float", &_default_value, -0.0f); -}; + key_field_name_ = "keys_int64s"; + value_field_name_ = "values_floats"; + info.GetAttrOrDefault("default_float", &default_value_, -0.0f); +} -ONNX_CPU_OPERATOR_TYPED_ML_KERNEL( - LabelEncoder, - 2, - float_int64, - KernelDefBuilder().TypeConstraint("T1", - std::vector{DataTypeImpl::GetTensorType()}) - .TypeConstraint("T2", - std::vector{DataTypeImpl::GetTensorType()}), +ONNX_CPU_OPERATOR_VERSIONED_TYPED_ML_KERNEL( + LabelEncoder, 2, 3, float_int64, + KernelDefBuilder() + .TypeConstraint("T1", std::vector{DataTypeImpl::GetTensorType()}) + .TypeConstraint("T2", std::vector{DataTypeImpl::GetTensorType()}), LabelEncoder_2); template <> void LabelEncoder_2::InitializeSomeFields(const OpKernelInfo& info) { - _key_field_name = "keys_floats"; - _value_field_name = "values_int64s"; - info.GetAttrOrDefault("default_int64", &_default_value, (std::int64_t)-1); -}; + key_field_name_ = "keys_floats"; + value_field_name_ = "values_int64s"; + info.GetAttrOrDefault("default_int64", &default_value_, (std::int64_t)-1); +} -ONNX_CPU_OPERATOR_TYPED_ML_KERNEL( - LabelEncoder, - 2, - string_string, - KernelDefBuilder().TypeConstraint("T1", - std::vector{DataTypeImpl::GetTensorType()}) - .TypeConstraint("T2", - std::vector{DataTypeImpl::GetTensorType()}), +ONNX_CPU_OPERATOR_VERSIONED_TYPED_ML_KERNEL( + LabelEncoder, 2, 3, string_string, + KernelDefBuilder() + .TypeConstraint("T1", std::vector{DataTypeImpl::GetTensorType()}) + .TypeConstraint("T2", std::vector{DataTypeImpl::GetTensorType()}), LabelEncoder_2) template <> void LabelEncoder_2::InitializeSomeFields(const OpKernelInfo& info) { - _key_field_name = "keys_strings"; - _value_field_name = "values_strings"; - info.GetAttrOrDefault("default_string", &_default_value, std::string("_Unused")); -}; + key_field_name_ = "keys_strings"; + value_field_name_ = "values_strings"; + info.GetAttrOrDefault("default_string", &default_value_, std::string("_Unused")); +} -ONNX_CPU_OPERATOR_TYPED_ML_KERNEL( - LabelEncoder, - 2, - float_float, - KernelDefBuilder().TypeConstraint("T1", - std::vector{DataTypeImpl::GetTensorType()}) - .TypeConstraint("T2", - std::vector{DataTypeImpl::GetTensorType()}), +ONNX_CPU_OPERATOR_VERSIONED_TYPED_ML_KERNEL( + LabelEncoder, 2, 3, float_float, + KernelDefBuilder() + .TypeConstraint("T1", std::vector{DataTypeImpl::GetTensorType()}) + .TypeConstraint("T2", std::vector{DataTypeImpl::GetTensorType()}), LabelEncoder_2) template <> void LabelEncoder_2::InitializeSomeFields(const OpKernelInfo& info) { - _key_field_name = "keys_floats"; - _value_field_name = "values_floats"; - info.GetAttrOrDefault("default_float", &_default_value, -0.0f); -}; + key_field_name_ = "keys_floats"; + value_field_name_ = "values_floats"; + info.GetAttrOrDefault("default_float", &default_value_, -0.0f); +} -ONNX_CPU_OPERATOR_TYPED_ML_KERNEL( - LabelEncoder, - 2, - int64_string, - KernelDefBuilder().TypeConstraint("T1", - std::vector{DataTypeImpl::GetTensorType()}) - .TypeConstraint("T2", - std::vector{DataTypeImpl::GetTensorType()}), +ONNX_CPU_OPERATOR_VERSIONED_TYPED_ML_KERNEL( + LabelEncoder, 2, 3, int64_string, + KernelDefBuilder() + .TypeConstraint("T1", std::vector{DataTypeImpl::GetTensorType()}) + .TypeConstraint("T2", std::vector{DataTypeImpl::GetTensorType()}), LabelEncoder_2) template <> void LabelEncoder_2::InitializeSomeFields(const OpKernelInfo& info) { - _key_field_name = "keys_int64s"; - _value_field_name = "values_strings"; - info.GetAttrOrDefault("default_string", &_default_value, std::string("_Unused")); -}; + key_field_name_ = "keys_int64s"; + value_field_name_ = "values_strings"; + info.GetAttrOrDefault("default_string", &default_value_, std::string("_Unused")); +} -ONNX_CPU_OPERATOR_TYPED_ML_KERNEL( - LabelEncoder, - 2, - string_int64, - KernelDefBuilder().TypeConstraint("T1", - std::vector{DataTypeImpl::GetTensorType()}) - .TypeConstraint("T2", - std::vector{DataTypeImpl::GetTensorType()}), +ONNX_CPU_OPERATOR_VERSIONED_TYPED_ML_KERNEL( + LabelEncoder, 2, 3, string_int64, + KernelDefBuilder() + .TypeConstraint("T1", std::vector{DataTypeImpl::GetTensorType()}) + .TypeConstraint("T2", std::vector{DataTypeImpl::GetTensorType()}), LabelEncoder_2) template <> void LabelEncoder_2::InitializeSomeFields(const OpKernelInfo& info) { - _key_field_name = "keys_strings"; - _value_field_name = "values_int64s"; - info.GetAttrOrDefault("default_int64", &_default_value, (std::int64_t)-1); -}; + key_field_name_ = "keys_strings"; + value_field_name_ = "values_int64s"; + info.GetAttrOrDefault("default_int64", &default_value_, static_cast(-1)); +} -ONNX_CPU_OPERATOR_TYPED_ML_KERNEL( - LabelEncoder, - 2, - int64_int64, - KernelDefBuilder().TypeConstraint("T1", - std::vector{DataTypeImpl::GetTensorType()}) - .TypeConstraint("T2", - std::vector{DataTypeImpl::GetTensorType()}), +ONNX_CPU_OPERATOR_VERSIONED_TYPED_ML_KERNEL( + LabelEncoder, 2, 3, int64_int64, + KernelDefBuilder() + .TypeConstraint("T1", std::vector{DataTypeImpl::GetTensorType()}) + .TypeConstraint("T2", std::vector{DataTypeImpl::GetTensorType()}), LabelEncoder_2) template <> void LabelEncoder_2::InitializeSomeFields(const OpKernelInfo& info) { - _key_field_name = "keys_int64s"; - _value_field_name = "values_int64s"; - info.GetAttrOrDefault("default_int64", &_default_value, (std::int64_t)-1); -}; + key_field_name_ = "keys_int64s"; + value_field_name_ = "values_int64s"; + info.GetAttrOrDefault("default_int64", &default_value_, static_cast(-1)); +} + +ONNX_CPU_OPERATOR_TYPED_ML_KERNEL( + LabelEncoder, 4, int64_int64, + KernelDefBuilder() + .TypeConstraint("T1", std::vector{DataTypeImpl::GetTensorType()}) + .TypeConstraint("T2", std::vector{DataTypeImpl::GetTensorType()}), + LabelEncoder_4) + +template <> +void LabelEncoder_4::InitializeAttrFields(const OpKernelInfo& kernel_info) { + key_field_name_ = "keys_int64s"; + value_field_name_ = "values_int64s"; + default_value_ = GetDefault(kernel_info, "default_int64", static_cast(-1)); +} + +ONNX_CPU_OPERATOR_TYPED_ML_KERNEL( + LabelEncoder, 4, int64_string, + KernelDefBuilder() + .TypeConstraint("T1", std::vector{DataTypeImpl::GetTensorType()}) + .TypeConstraint("T2", std::vector{DataTypeImpl::GetTensorType()}), + LabelEncoder_4) + +template <> +void LabelEncoder_4::InitializeAttrFields(const OpKernelInfo& kernel_info) { + key_field_name_ = "keys_int64s"; + value_field_name_ = "values_strings"; + default_value_ = GetDefault(kernel_info, "default_string", std::string("_Unused")); +} + +ONNX_CPU_OPERATOR_TYPED_ML_KERNEL( + LabelEncoder, 4, int64_float, + KernelDefBuilder() + .TypeConstraint("T1", std::vector{DataTypeImpl::GetTensorType()}) + .TypeConstraint("T2", std::vector{DataTypeImpl::GetTensorType()}), + LabelEncoder_4) + +template <> +void LabelEncoder_4::InitializeAttrFields(const OpKernelInfo& kernel_info) { + key_field_name_ = "keys_int64s"; + value_field_name_ = "values_floats"; + default_value_ = GetDefault(kernel_info, "default_float", 0.f); +} + +ONNX_CPU_OPERATOR_TYPED_ML_KERNEL(LabelEncoder, 4, float_float, + KernelDefBuilder() + .TypeConstraint("T1", + std::vector{DataTypeImpl::GetTensorType()}) + .TypeConstraint("T2", + std::vector{DataTypeImpl::GetTensorType()}), + LabelEncoder_4) + +template <> +void LabelEncoder_4::InitializeAttrFields(const OpKernelInfo& kernel_info) { + key_field_name_ = "keys_floats"; + value_field_name_ = "values_floats"; + default_value_ = GetDefault(kernel_info, "default_float", -0.f); +} + +ONNX_CPU_OPERATOR_TYPED_ML_KERNEL( + LabelEncoder, 4, float_string, + KernelDefBuilder() + .TypeConstraint("T1", std::vector{DataTypeImpl::GetTensorType()}) + .TypeConstraint("T2", std::vector{DataTypeImpl::GetTensorType()}), + LabelEncoder_4) + +template <> +void LabelEncoder_4::InitializeAttrFields(const OpKernelInfo& kernel_info) { + key_field_name_ = "keys_floats"; + value_field_name_ = "values_strings"; + default_value_ = GetDefault(kernel_info, "default_string", std::string("_Unused")); +} + +ONNX_CPU_OPERATOR_TYPED_ML_KERNEL( + LabelEncoder, 4, float_int64, + KernelDefBuilder() + .TypeConstraint("T1", std::vector{DataTypeImpl::GetTensorType()}) + .TypeConstraint("T2", std::vector{DataTypeImpl::GetTensorType()}), + LabelEncoder_4) + +template <> +void LabelEncoder_4::InitializeAttrFields(const OpKernelInfo& kernel_info) { + key_field_name_ = "keys_floats"; + value_field_name_ = "values_int64s"; + default_value_ = GetDefault(kernel_info, "default_int64", static_cast(-1)); +} + +ONNX_CPU_OPERATOR_TYPED_ML_KERNEL( + LabelEncoder, 4, string_int64, + KernelDefBuilder() + .TypeConstraint("T1", std::vector{DataTypeImpl::GetTensorType()}) + .TypeConstraint("T2", std::vector{DataTypeImpl::GetTensorType()}), + LabelEncoder_4) + +template <> +void LabelEncoder_4::InitializeAttrFields(const OpKernelInfo& kernel_info) { + key_field_name_ = "keys_strings"; + value_field_name_ = "values_int64s"; + default_value_ = GetDefault(kernel_info, "default_int64", static_cast(-1)); +} + +ONNX_CPU_OPERATOR_TYPED_ML_KERNEL( + LabelEncoder, 4, string_float, + KernelDefBuilder() + .TypeConstraint("T1", std::vector{DataTypeImpl::GetTensorType()}) + .TypeConstraint("T2", std::vector{DataTypeImpl::GetTensorType()}), + LabelEncoder_4) + +template <> +void LabelEncoder_4::InitializeAttrFields(const OpKernelInfo& kernel_info) { + key_field_name_ = "keys_strings"; + value_field_name_ = "values_floats"; + default_value_ = GetDefault(kernel_info, "default_float", 0.f); +} + +ONNX_CPU_OPERATOR_TYPED_ML_KERNEL( + LabelEncoder, 4, string_string, + KernelDefBuilder() + .TypeConstraint("T1", std::vector{DataTypeImpl::GetTensorType()}) + .TypeConstraint("T2", std::vector{DataTypeImpl::GetTensorType()}), + LabelEncoder_4) + +template <> +void LabelEncoder_4::InitializeAttrFields(const OpKernelInfo& kernel_info) { + key_field_name_ = "keys_strings"; + value_field_name_ = "values_strings"; + default_value_ = GetDefault(kernel_info, "default_string", std::string("_Unused")); +} + +ONNX_CPU_OPERATOR_TYPED_ML_KERNEL( + LabelEncoder, 4, string_int16, + KernelDefBuilder() + .TypeConstraint("T1", std::vector{DataTypeImpl::GetTensorType()}) + .TypeConstraint("T2", std::vector{DataTypeImpl::GetTensorType()}), + LabelEncoder_4) + +template <> +void LabelEncoder_4::InitializeAttrFields(const OpKernelInfo& kernel_info) { + key_field_name_ = "keys_strings"; + default_value_ = static_cast(GetDefault(kernel_info, "", static_cast(-1))); +} + +ONNX_CPU_OPERATOR_TYPED_ML_KERNEL(LabelEncoder, 4, double_double, + KernelDefBuilder() + .TypeConstraint("T1", + std::vector{DataTypeImpl::GetTensorType()}) + .TypeConstraint("T2", + std::vector{DataTypeImpl::GetTensorType()}), + LabelEncoder_4) + +template <> +void LabelEncoder_4::InitializeAttrFields(const OpKernelInfo& kernel_info) { + default_value_ = GetDefault(kernel_info, "default_float", -0.); +} + +ONNX_CPU_OPERATOR_TYPED_ML_KERNEL( + LabelEncoder, 4, double_string, + KernelDefBuilder() + .TypeConstraint("T1", std::vector{DataTypeImpl::GetTensorType()}) + .TypeConstraint("T2", std::vector{DataTypeImpl::GetTensorType()}), + LabelEncoder_4) + +template <> +void LabelEncoder_4::InitializeAttrFields(const OpKernelInfo& kernel_info) { + value_field_name_ = "values_strings"; + default_value_ = GetDefault(kernel_info, "default_string", std::string("_Unused")); +} + +ONNX_CPU_OPERATOR_TYPED_ML_KERNEL( + LabelEncoder, 4, string_double, + KernelDefBuilder() + .TypeConstraint("T1", std::vector{DataTypeImpl::GetTensorType()}) + .TypeConstraint("T2", std::vector{DataTypeImpl::GetTensorType()}), + LabelEncoder_4) + +template <> +void LabelEncoder_4::InitializeAttrFields(const OpKernelInfo& kernel_info) { + key_field_name_ = "keys_strings"; + default_value_ = GetDefault(kernel_info, "default_float", -0.); +} + +ONNX_CPU_OPERATOR_TYPED_ML_KERNEL( + LabelEncoder, 4, double_int64, + KernelDefBuilder() + .TypeConstraint("T1", std::vector{DataTypeImpl::GetTensorType()}) + .TypeConstraint("T2", std::vector{DataTypeImpl::GetTensorType()}), + LabelEncoder_4) + +template <> +void LabelEncoder_4::InitializeAttrFields(const OpKernelInfo& kernel_info) { + value_field_name_ = "values_int64s"; + default_value_ = GetDefault(kernel_info, "default_int64", static_cast(-1)); +} + +ONNX_CPU_OPERATOR_TYPED_ML_KERNEL( + LabelEncoder, 4, int64_double, + KernelDefBuilder() + .TypeConstraint("T1", std::vector{DataTypeImpl::GetTensorType()}) + .TypeConstraint("T2", std::vector{DataTypeImpl::GetTensorType()}), + LabelEncoder_4) + +template <> +void LabelEncoder_4::InitializeAttrFields(const OpKernelInfo& kernel_info) { + key_field_name_ = "keys_int64s"; + default_value_ = GetDefault(kernel_info, "default_float", -0.); +} } // namespace ml } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/ml/label_encoder.h b/onnxruntime/core/providers/cpu/ml/label_encoder.h index 1b4fa01900ae9..0f9f7cfb5dba6 100644 --- a/onnxruntime/core/providers/cpu/ml/label_encoder.h +++ b/onnxruntime/core/providers/cpu/ml/label_encoder.h @@ -6,6 +6,8 @@ #include "core/common/common.h" #include "core/framework/op_kernel.h" #include "core/providers/cpu/ml/ml_common.h" +#include "core/framework/tensorprotoutils.h" +#include "core/common/safeint.h" namespace onnxruntime { namespace ml { @@ -53,57 +55,182 @@ class LabelEncoder_2 final : public OpKernel { std::vector keys; std::vector values; - ORT_THROW_IF_ERROR(info.GetAttrs(_key_field_name, keys)); - ORT_THROW_IF_ERROR(info.GetAttrs(_value_field_name, values)); + ORT_THROW_IF_ERROR(info.GetAttrs(key_field_name_, keys)); + ORT_THROW_IF_ERROR(info.GetAttrs(value_field_name_, values)); auto num_keys = keys.size(); auto num_values = values.size(); - ORT_ENFORCE(num_keys == num_values, - "The ", _key_field_name, " and ", _value_field_name, " attribtues in LabelEncoder ", - "(name: ", info.node().Name(), ") must have the same length. ", - "However, the number of key is ", num_keys, " and the number of ", - "values is ", num_values, "."); - _map.reserve(num_keys); - for (size_t i = 0; i < num_keys; ++i) - _map.emplace(keys[i], values[i]); + ORT_ENFORCE(num_keys == num_values, "The ", key_field_name_, " and ", value_field_name_, + " attributes in LabelEncoder ", "(name: ", info.node().Name(), ") must have the same length. ", + "However, the number of key is ", num_keys, " and the number of ", "values is ", num_values, "."); + map_.reserve(num_keys); + for (size_t i = 0; i < num_keys; ++i) map_.emplace(keys[i], values[i]); } Status Compute(OpKernelContext* context) const override { - const auto* tensor_pointer = context->Input(0); - if (tensor_pointer == nullptr) return Status(common::ONNXRUNTIME, common::FAIL, "input count mismatch"); - const Tensor& X = *tensor_pointer; - const TensorShape& shape = X.Shape(); - Tensor& Y = *context->Output(0, shape); - - auto input = X.template DataAsSpan(); - auto output = Y.template MutableDataAsSpan(); - - for (int64_t i = 0; i < shape.Size(); ++i) { - const auto found = _map.find(input[onnxruntime::narrow(i)]); - if (found == _map.end()) - output[onnxruntime::narrow(i)] = _default_value; - else - output[onnxruntime::narrow(i)] = found->second; + const auto* X = context->Input(0); + const TensorShape& shape = X->Shape(); + auto* Y = context->Output(0, shape); + + auto input = X->template DataAsSpan(); + auto output = Y->template MutableDataAsSpan(); + auto input_iter = input.begin(); + auto output_iter = output.begin(); + while (input_iter != input.end()) { + const auto found = map_.find(*input_iter); + *output_iter = found == map_.end() ? default_value_ : found->second; + ++output_iter; + ++input_iter; } - return Status::OK(); } private: // Specialize this method to set attribute names. For example, if keys' type - // is 64-bit integer, _key_field_name should be "keys_int64s". Field names + // is 64-bit integer, key_field_name_ should be "keys_int64s". Field names // for other types can be found in ONNX spec. void InitializeSomeFields(const OpKernelInfo& info); // A collection of key-value pairs. Each (a_key, a_value) pair // means that the "a_key" in the input would be mapped to "a_value". - // If _map doesn't contain "a_key", we use _default_value as its output. - InlinedHashMap _map; - TValue _default_value; + // If map_ doesn't contain "a_key", we use default_value_ as its output. + InlinedHashMap map_; + TValue default_value_; // ONNX attribute name to load keys. - std::string _key_field_name; + std::string key_field_name_; // ONNX attribute name to load values. - std::string _value_field_name; + std::string value_field_name_; +}; + +template +std::vector GetAttribute(const OpKernelInfo& info, const std::string& name, const std::string& tensor_name) { + if constexpr (std::is_same_v || std::is_same_v || std::is_same_v) { + std::vector attrs; + if (info.GetAttrs(name, attrs).IsOK()) { + return attrs; + } + } + ONNX_NAMESPACE::TensorProto attr_tensor_proto; + auto result = info.GetAttr(tensor_name, &attr_tensor_proto); + if (name.empty()) { + ORT_ENFORCE(result.IsOK(), "LabelEncoder is missing attribute ", tensor_name); + } else { + ORT_ENFORCE(result.IsOK(), "LabelEncoder is missing attribute ", tensor_name, " or ", name); + } + SafeInt element_count(1); + for (auto dim : attr_tensor_proto.dims()) { + element_count *= dim; + } + const SafeInt tensor_size(element_count); + std::vector out(tensor_size); + result = utils::UnpackTensor(attr_tensor_proto, Path(), out.data(), tensor_size); + ORT_ENFORCE(result.IsOK(), "LabelEncoder could not unpack tensor attribute ", name); + return out; +} + +template +T GetDefault(const OpKernelInfo& info, const std::string& attr_name, const T& backup) { + ONNX_NAMESPACE::TensorProto attr_tensor_proto; + auto result = info.GetAttr("default_tensor", &attr_tensor_proto); + if (result.IsOK() && utils::HasDataType(attr_tensor_proto)) { + T default_value; + result = utils::UnpackTensor(attr_tensor_proto, Path(), &default_value, 1); + ORT_ENFORCE(result.IsOK(), "LabelEncoder could not unpack default tensor ", attr_name); + return default_value; + } else if constexpr (std::is_same_v || std::is_same_v || std::is_same_v) { + T default_value; + result = info.GetAttr(attr_name, &default_value); + if (result.IsOK()) { + return default_value; + } + } + return backup; +} + +// We don't make use of InlinedHashMap since we make use of a custom hash and equality function. +// Introducing new template parameters in inlined_containers_fwd.h creates compilation errors +// (see https://github.com/microsoft/onnxruntime/pull/17977#discussion_r1446510961). +#ifndef DISABLE_ABSEIL +template +using HashFunc = absl::container_internal::hash_default_hash; + +template +using EqualFunc = absl::container_internal::hash_default_eq; + +template +using HashMap = absl::flat_hash_map; +#else +template +using HashFunc = std::hash; + +template +using EqualFunc = std::equal_to; + +template +using HashMap = std::unordered_map; +#endif // DISABLE_ABSEIL + +template +struct NaNHash { + size_t operator()(const T& value) const { + if constexpr (std::is_floating_point_v) { + if (std::isnan(value)) { + return 0; + } + } + return HashFunc{}(value); + } +}; + +template +struct NaNEqual { + bool operator()(const T& lhs, const T& rhs) const { + if constexpr (std::is_floating_point_v) { + if (std::isnan(lhs) && std::isnan(rhs)) { + return true; + } + } + return EqualFunc{}(lhs, rhs); + } +}; + +template +class LabelEncoder_4 final : public OpKernel { + public: + LabelEncoder_4(const OpKernelInfo& kernel_info) : OpKernel(kernel_info) { + InitializeAttrFields(kernel_info); + auto keys = GetAttribute(kernel_info, key_field_name_, "keys_tensor"); + auto values = GetAttribute(kernel_info, value_field_name_, "values_tensor"); + ORT_ENFORCE(keys.size() == values.size(), "Keys and values must have the same length."); + for (size_t i = 0; i < keys.size(); ++i) { + map_.emplace(keys[i], values[i]); + } + } + Status Compute(OpKernelContext* context) const override { + const auto* X = context->Input(0); + const TensorShape& shape = X->Shape(); + auto* Y = context->Output(0, shape); + + auto input = X->template DataAsSpan(); + auto output = Y->template MutableDataAsSpan(); + auto input_iter = input.begin(); + auto output_iter = output.begin(); + while (input_iter != input.end()) { + const auto found = map_.find(*input_iter); + *output_iter = found == map_.end() ? default_value_ : found->second; + ++output_iter; + ++input_iter; + } + return Status::OK(); + } + + private: + void InitializeAttrFields(const OpKernelInfo& kernel_info); + HashMap, NaNEqual> map_; + TValue default_value_; + std::string key_field_name_; + std::string value_field_name_; }; + } // namespace ml } // namespace onnxruntime diff --git a/onnxruntime/test/providers/cpu/ml/label_encoder_test.cc b/onnxruntime/test/providers/cpu/ml/label_encoder_test.cc index 2ce652e833717..63001dd1063ce 100644 --- a/onnxruntime/test/providers/cpu/ml/label_encoder_test.cc +++ b/onnxruntime/test/providers/cpu/ml/label_encoder_test.cc @@ -8,7 +8,8 @@ namespace onnxruntime { namespace test { template -static void RunTest(const std::vector& dims, const std::vector& input, const std::vector& output) { +static void RunTest(const std::vector& dims, const std::vector& input, + const std::vector& output) { OpTester test("LabelEncoder", 1, onnxruntime::kMLDomain); static const std::vector labels = {"Beer", "Wine", "Tequila"}; @@ -231,5 +232,284 @@ TEST(LabelEncoder, FloatToFloatOpset2) { test.Run(); } +TEST(LabelEncoder, Int64toInt64Opset4) { + std::vector dims{1, 5}; + + std::vector input{1, 2, 3, 4, 5}; + std::vector output{12, 13, 14, 15, 42}; + std::vector key_data{1, 2, 3, 4}; + std::vector value_data{12, 13, 14, 15}; + + OpTester test("LabelEncoder", 4, onnxruntime::kMLDomain); + + test.AddAttribute("keys_int64s", key_data); + test.AddAttribute("values_int64s", value_data); + + ONNX_NAMESPACE::TensorProto default_proto; + default_proto.set_name("default_tensor"); + default_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); + default_proto.add_dims(1); + default_proto.add_int64_data(42); + test.AddAttribute("default_tensor", default_proto); + + test.AddInput("X", dims, input); + test.AddOutput("Y", dims, output); + test.Run(); +} + +TEST(LabelEncoder, StringtoInt16Opset4) { + std::vector dims{1, 5}; + + const std::vector input{"a", "b", "d", "c", "g"}; + const std::vector output{0, 1, 42, 2, 42}; + const std::vector key_data{"a", "b", "c"}; + const std::vector value_data{0, 1, 2}; + + OpTester test("LabelEncoder", 4, onnxruntime::kMLDomain); + + test.AddAttribute("keys_strings", key_data); + + ONNX_NAMESPACE::TensorProto values_proto; + values_proto.set_name("values_tensor"); + values_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT16); + values_proto.add_dims(value_data.size()); + for (const auto value : value_data) { + values_proto.add_int32_data(value); + } + + test.AddAttribute("values_tensor", values_proto); + + ONNX_NAMESPACE::TensorProto default_proto; + default_proto.set_name("default_tensor"); + default_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT16); + default_proto.add_dims(1); + default_proto.add_int32_data(42); + test.AddAttribute("default_tensor", default_proto); + + test.AddInput("X", dims, input); + test.AddOutput("Y", dims, output); + test.Run(); +} + +TEST(LabelEncoder, Int64toStringOpset4) { + std::vector dims{1, 5}; + + std::vector input{1, 2, 3, 4, 5}; + std::vector output{"Hello", "world", "_Unused", "onnxruntime", "!"}; + std::vector key_data{1, 2, 4, 5}; + std::vector value_data{"Hello", "world", "onnxruntime", "!"}; + + OpTester test("LabelEncoder", 4, onnxruntime::kMLDomain); + + ONNX_NAMESPACE::TensorProto keys_proto; + keys_proto.set_name("keys_tensor"); + keys_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); + keys_proto.add_dims(key_data.size()); + for (const auto key : key_data) { + keys_proto.add_int64_data(key); + } + test.AddAttribute("keys_tensor", keys_proto); + + ONNX_NAMESPACE::TensorProto values_proto; + values_proto.set_name("values_tensor"); + values_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_STRING); + values_proto.add_dims(value_data.size()); + for (const auto& value : value_data) { + values_proto.add_string_data(value); + } + test.AddAttribute("values_tensor", values_proto); + + ONNX_NAMESPACE::TensorProto default_proto; + default_proto.set_name("default_tensor"); + default_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_STRING); + default_proto.add_dims(1); + default_proto.add_string_data("_Unused"); + test.AddAttribute("default_tensor", default_proto); + + test.AddInput("X", dims, input); + test.AddOutput("Y", dims, output); + + test.Run(); +} + +TEST(LabelEncoder, StringToFloatOpset4) { + std::vector dims{1, 5}; + + std::vector input{"Hello", "world", "Random", "onnxruntime", "!"}; + std::vector output{3.14f, 2.0f, -0.0f, 2.718f, 5.0f}; + std::vector key_data{"Hello", "world", "onnxruntime", "!"}; + std::vector value_data{3.14f, 2.0f, 2.718f, 5.0f}; + + OpTester test("LabelEncoder", 4, onnxruntime::kMLDomain); + + ONNX_NAMESPACE::TensorProto keys_proto; + keys_proto.set_name("keys_tensor"); + keys_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_STRING); + keys_proto.add_dims(key_data.size()); + for (const auto& key : key_data) { + keys_proto.add_string_data(key); + } + test.AddAttribute("keys_tensor", keys_proto); + + ONNX_NAMESPACE::TensorProto values_proto; + values_proto.set_name("values_tensor"); + values_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + values_proto.add_dims(value_data.size()); + for (const auto& value : value_data) { + values_proto.add_float_data(value); + } + test.AddAttribute("values_tensor", values_proto); + + ONNX_NAMESPACE::TensorProto default_proto; + default_proto.set_name("default_tensor"); + default_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + default_proto.add_dims(1); + default_proto.add_float_data(-0.0f); + test.AddAttribute("default_tensor", default_proto); + test.AddInput("X", dims, input); + test.AddOutput("Y", dims, output); + + test.Run(); +} + +TEST(LabelEncoder, StringToDoubleOpset4) { + std::vector dims{1, 5}; + + std::vector input{"Hello", "world", "Random", "onnxruntime", "!"}; + std::vector output{0.1, 1.1231e30, -0.0, 2.718, 5.0}; + std::vector key_data{"Hello", "world", "onnxruntime", "!"}; + std::vector value_data{0.1, 1.1231e30, 2.718, 5.0}; + + OpTester test("LabelEncoder", 4, onnxruntime::kMLDomain); + + ONNX_NAMESPACE::TensorProto keys_proto; + keys_proto.set_name("keys_tensor"); + keys_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_STRING); + keys_proto.add_dims(key_data.size()); + for (const auto& key : key_data) { + keys_proto.add_string_data(key); + } + test.AddAttribute("keys_tensor", keys_proto); + + ONNX_NAMESPACE::TensorProto values_proto; + values_proto.set_name("values_tensor"); + values_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_DOUBLE); + values_proto.add_dims(value_data.size()); + for (const auto& value : value_data) { + values_proto.add_double_data(value); + } + test.AddAttribute("values_tensor", values_proto); + + ONNX_NAMESPACE::TensorProto default_proto; + default_proto.set_name("default_tensor"); + default_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_DOUBLE); + default_proto.add_dims(1); + default_proto.add_double_data(-0.0); + test.AddAttribute("default_tensor", default_proto); + test.AddInput("X", dims, input); + test.AddOutput("Y", dims, output); + + test.Run(); +} + +TEST(LabelEncoder, TensorBasedAttributesOpset4) { + std::vector dims{1, 5}; + + std::vector input{1, 2, 3, 4, 5}; + std::vector output{12, 13, 14, 15, 42}; + std::vector key_data{1, 2, 3, 4}; + std::vector value_data{12, 13, 14, 15}; + + OpTester test("LabelEncoder", 4, onnxruntime::kMLDomain); + + ONNX_NAMESPACE::TensorProto keys_proto; + keys_proto.set_name("keys_tensor"); + keys_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); + keys_proto.add_dims(key_data.size()); + for (const auto key : key_data) { + keys_proto.add_int64_data(key); + } + test.AddAttribute("keys_tensor", keys_proto); + + ONNX_NAMESPACE::TensorProto values_proto; + values_proto.set_name("values_tensor"); + values_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); + values_proto.add_dims(value_data.size()); + for (const auto value : value_data) { + values_proto.add_int64_data(value); + } + test.AddAttribute("values_tensor", values_proto); + + ONNX_NAMESPACE::TensorProto default_proto; + default_proto.set_name("default_tensor"); + default_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); + default_proto.add_dims(1); + default_proto.add_int64_data(42); + test.AddAttribute("default_tensor", default_proto); + + test.AddInput("X", dims, input); + test.AddOutput("Y", dims, output); + + test.Run(); +} + +TEST(LabelEncoder, NaNsMappedTogetherOpset4) { + std::vector dims{1, 6}; + std::vector input{3.14f, std::nanf("1"), 2.718f, std::nanf("2"), 5.f, -1.f}; + std::vector output{"a", "ONNX", "b", "ONNX", "c", "onnxruntime"}; + std::vector key_data{3.14f, 2.718f, 5.0f, std::nanf("3")}; + std::vector value_data{"a", "b", "c", "ONNX"}; + + OpTester test("LabelEncoder", 4, onnxruntime::kMLDomain); + + test.AddAttribute("keys_floats", key_data); + test.AddAttribute("values_strings", value_data); + + ONNX_NAMESPACE::TensorProto default_proto; + default_proto.set_name("default_tensor"); + default_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_STRING); + default_proto.add_dims(1); + default_proto.add_string_data("onnxruntime"); + test.AddAttribute("default_tensor", default_proto); + + test.AddInput("X", dims, input); + test.AddOutput("Y", dims, output); + + test.Run(); +} + +TEST(LabelEncoder, DoubleNaNsMappedTogetherOpset4) { + std::vector dims{1, 6}; + std::vector input{3.14, std::nan("1"), 2.718, std::nan("2"), 5.0, -1}; + std::vector output{"a", "ONNX", "b", "ONNX", "c", "onnxruntime"}; + std::vector key_data{3.14, 2.718, 5.0, std::nan("3")}; + std::vector value_data{"a", "b", "c", "ONNX"}; + + OpTester test("LabelEncoder", 4, onnxruntime::kMLDomain); + + ONNX_NAMESPACE::TensorProto keys_proto; + keys_proto.set_name("keys_tensor"); + keys_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_DOUBLE); + keys_proto.add_dims(key_data.size()); + for (const auto key : key_data) { + keys_proto.add_double_data(key); + } + test.AddAttribute("keys_tensor", keys_proto); + + test.AddAttribute("values_strings", value_data); + + ONNX_NAMESPACE::TensorProto default_proto; + default_proto.set_name("default_tensor"); + default_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_STRING); + default_proto.add_dims(1); + default_proto.add_string_data("onnxruntime"); + test.AddAttribute("default_tensor", default_proto); + + test.AddInput("X", dims, input); + test.AddOutput("Y", dims, output); + + test.Run(); +} + } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc b/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc index ed263515d6dd6..ca089c42032b1 100644 --- a/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc +++ b/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc @@ -235,10 +235,6 @@ "^test_resize_upsample_sizes_nearest_not_larger_cuda", "^test_resize_upsample_sizes_nearest_round_prefer_ceil_asymmetric_cuda", // onnx 1.15 (opset 20) new and updated op tests - "^test_ai_onnx_ml_label_encoder_string_int", - "^test_ai_onnx_ml_label_encoder_string_int_no_default", - "^test_ai_onnx_ml_label_encoder_tensor_mapping", - "^test_ai_onnx_ml_label_encoder_tensor_value_only_mapping", "^test_image_decoder_decode_bmp_rgb", "^test_image_decoder_decode_jpeg2k_rgb", "^test_image_decoder_decode_jpeg_bgr",