From e8772c76cc6865aeffd0ab460fcf3232c8ed7f71 Mon Sep 17 00:00:00 2001
From: Sheil Kumar
Date: Thu, 22 Feb 2024 15:28:06 -0800
Subject: [PATCH 01/11] disable gemm activation for non-float data types

---
 onnxruntime/core/optimizer/gemm_activation_fusion.cc | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/onnxruntime/core/optimizer/gemm_activation_fusion.cc b/onnxruntime/core/optimizer/gemm_activation_fusion.cc
index c62887da09fdc..4c8a63f8eac21 100644
--- a/onnxruntime/core/optimizer/gemm_activation_fusion.cc
+++ b/onnxruntime/core/optimizer/gemm_activation_fusion.cc
@@ -65,6 +65,12 @@ Status GemmActivationFusion::ApplyImpl(Graph& graph, bool& modified, int graph_l
       continue;
     }
 
+    NodeArg* node_output = node.MutableOutputDefs()[0];
+    auto data_type = node_output->TypeAsProto()->tensor_type().elem_type();
+    if (data_type != ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
+      continue;
+    }
+
     Node& gemm_node = node;
     Node& act_node = *graph.GetNode(next_node.Index());  // get mutable reference
 

From e794caf540914348e1986a7324adcc0d4402e13e Mon Sep 17 00:00:00 2001
From: Sheil Kumar
Date: Fri, 23 Feb 2024 07:57:35 -0800
Subject: [PATCH 02/11] Register fp16 fused gemm when mlas fp16vec intrinsics supported

---
 onnxruntime/contrib_ops/cpu/fused_gemm.cc     | 11 ++++++
 .../core/optimizer/gemm_activation_fusion.cc  | 32 +++++++++++++---
 .../providers/cpu/activation/activations.cc   | 37 ++++++++++++-------
 3 files changed, 62 insertions(+), 18 deletions(-)

diff --git a/onnxruntime/contrib_ops/cpu/fused_gemm.cc b/onnxruntime/contrib_ops/cpu/fused_gemm.cc
index 33571e74f5763..bd3f213312c9a 100644
--- a/onnxruntime/contrib_ops/cpu/fused_gemm.cc
+++ b/onnxruntime/contrib_ops/cpu/fused_gemm.cc
@@ -31,5 +31,16 @@ ONNX_CPU_OPERATOR_TYPED_MS_KERNEL(
     KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
     FusedGemm<float>);
 
+#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED
+
+ONNX_CPU_OPERATOR_TYPED_MS_KERNEL(
+    FusedGemm,
+    1,
+    MLFloat16,
+    KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<MLFloat16>()),
+    FusedGemm<MLFloat16>);
+
+#endif
+
 }  // namespace contrib
 }  // namespace onnxruntime

diff --git a/onnxruntime/core/optimizer/gemm_activation_fusion.cc b/onnxruntime/core/optimizer/gemm_activation_fusion.cc
index 4c8a63f8eac21..c4eb8ff1bbfbd 100644
--- a/onnxruntime/core/optimizer/gemm_activation_fusion.cc
+++ b/onnxruntime/core/optimizer/gemm_activation_fusion.cc
@@ -57,6 +57,33 @@ Status GemmActivationFusion::ApplyImpl(Graph& graph, bool& modified, int graph_l
     }
 
     const Node& next_node = *(node.OutputNodesBegin());
+
+    NodeArg* node_output = node.MutableOutputDefs()[0];
+    auto data_type = node_output->TypeAsProto()->tensor_type().elem_type();
+    if (data_type != ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
+      // While there are already fp16 kernels for Relu and LeakyRelu that could be fused with the fp16 Gemm,
+      // Gemm relies on ElementWiseRangedTransform to define activation functions.
+      // ElementWiseRangedTransform is an abstract templated structure, and so itself has no dependencies.
+      // However, its static Create method will create concrete implementations, i.e. Relu, Softplus, etc.
+      // Concrete templated activations exist for any type parameter (in activations.cc) so long as
+      // ElementWiseRangedTransform<T>::Create has a type specialization for T (currently only float is defined).
+      // However, the parameterized implementation for MLFloat16 does not work, since ElementWiseRangedTransform
+      // will call into EigenVectorArrayMap, which does not have specializations for MLFloat16.
+      // That being said, some float16 specializations are implemented for *only* Relu and LeakyRelu in
+      // fp16_activations.h, and they depend on MLAS_F16VEC_INTRINSICS_SUPPORTED. In this case we can reliably
+      // turn on fp16 FusedGemm.
+#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED
+      const bool is_fp16_activation_supported =
+          next_node.OpType() == "Relu" ||
+          next_node.OpType() == "LeakyRelu";
+#else
+      const bool is_fp16_activation_supported = false;
+#endif
+      if (!is_fp16_activation_supported) {
+        continue;
+      }
+    }
+
     if (!IsFusableActivation(next_node) || next_node.GetExecutionProviderType() != node.GetExecutionProviderType()) {
       continue;
     }
@@ -65,11 +92,6 @@ Status GemmActivationFusion::ApplyImpl(Graph& graph, bool& modified, int graph_l
       continue;
     }
 
-    NodeArg* node_output = node.MutableOutputDefs()[0];
-    auto data_type = node_output->TypeAsProto()->tensor_type().elem_type();
-    if (data_type != ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
-      continue;
-    }
 
     Node& gemm_node = node;
     Node& act_node = *graph.GetNode(next_node.Index());  // get mutable reference

diff --git a/onnxruntime/core/providers/cpu/activation/activations.cc b/onnxruntime/core/providers/cpu/activation/activations.cc
index 049fee4b95308..98916fd492228 100644
--- a/onnxruntime/core/providers/cpu/activation/activations.cc
+++ b/onnxruntime/core/providers/cpu/activation/activations.cc
@@ -81,21 +81,29 @@ Status ElementWiseRangedTransform<T>::Create(const std::string& type, const Node
     return Status::OK(); \
   }
 
-  CREATE_ELE_KERNEL(Celu);
-  CREATE_ELE_KERNEL(Elu);
-  CREATE_ELE_KERNEL(HardSigmoid);
-  CREATE_ELE_KERNEL(LeakyRelu);
-  CREATE_ELE_KERNEL(Softplus);
-  CREATE_ELE_KERNEL(Relu);
-  CREATE_ELE_KERNEL(Sigmoid);
-  CREATE_ELE_KERNEL(Softsign);
-  CREATE_ELE_KERNEL(Tanh);
-  CREATE_ELE_KERNEL(ThresholdedRelu);
-  CREATE_ELE_KERNEL(Selu);
+  if constexpr (std::is_same<T, MLFloat16>::value) {
+#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED
+    CREATE_ELE_KERNEL(Relu);
+    CREATE_ELE_KERNEL(LeakyRelu);
+#endif
+  }
+  else {
+    CREATE_ELE_KERNEL(Celu);
+    CREATE_ELE_KERNEL(Elu);
+    CREATE_ELE_KERNEL(HardSigmoid);
+    CREATE_ELE_KERNEL(LeakyRelu);
+    CREATE_ELE_KERNEL(Softplus);
+    CREATE_ELE_KERNEL(Relu);
+    CREATE_ELE_KERNEL(Sigmoid);
+    CREATE_ELE_KERNEL(Softsign);
+    CREATE_ELE_KERNEL(Tanh);
+    CREATE_ELE_KERNEL(ThresholdedRelu);
+    CREATE_ELE_KERNEL(Selu);
 #ifndef DISABLE_CONTRIB_OPS
-  CREATE_ELE_KERNEL(ParametricSoftplus);
-  CREATE_ELE_KERNEL(ScaledTanh);
+    CREATE_ELE_KERNEL(ParametricSoftplus);
+    CREATE_ELE_KERNEL(ScaledTanh);
 #endif
+  }
 
 #undef CREATE_ELE_KERNEL
 
@@ -104,6 +112,9 @@ Status ElementWiseRangedTransform<T>::Create(const std::string& type, const Node
 
 template Status ElementWiseRangedTransform<float>::Create(const std::string& type, const NodeAttributes& attributes,
                                                           std::unique_ptr<ElementWiseRangedTransform<float>>& out);
+
+template Status ElementWiseRangedTransform<MLFloat16>::Create(const std::string& type, const NodeAttributes& attributes,
+                                                              std::unique_ptr<ElementWiseRangedTransform<MLFloat16>>& out);
 }  // namespace functors
 
 namespace functors {

From ede1b6748efbd6cfb65748e4c6a9b0eaa850450a Mon Sep 17 00:00:00 2001
From: Sheil Kumar
Date: Fri, 23 Feb 2024 08:20:26 -0800
Subject: [PATCH 03/11] change logic to separate fp16 logic

---
 .../core/optimizer/gemm_activation_fusion.cc | 24 +++++++------------
 .../providers/cpu/activation/activations.cc  |  4 ++++
 2 files changed, 13 insertions(+), 15 deletions(-)

diff --git a/onnxruntime/core/optimizer/gemm_activation_fusion.cc b/onnxruntime/core/optimizer/gemm_activation_fusion.cc
index c4eb8ff1bbfbd..2805e7ca04cbb 100644
--- a/onnxruntime/core/optimizer/gemm_activation_fusion.cc
+++ b/onnxruntime/core/optimizer/gemm_activation_fusion.cc
@@ -60,22 +60,12 @@ Status GemmActivationFusion::ApplyImpl(Graph& graph, bool& modified, int graph_l
 
     NodeArg* node_output = node.MutableOutputDefs()[0];
     auto data_type = node_output->TypeAsProto()->tensor_type().elem_type();
-    if (data_type != ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
-      // While there are already fp16 kernels for Relu and LeakyRelu that could be fused with the fp16 Gemm,
-      // Gemm relies on ElementWiseRangedTransform to define activation functions.
-      // ElementWiseRangedTransform is an abstract templated structure, and so itself has no dependencies.
-      // However, its static Create method will create concrete implementations, i.e. Relu, Softplus, etc.
-      // Concrete templated activations exist for any type parameter (in activations.cc) so long as
-      // ElementWiseRangedTransform<T>::Create has a type specialization for T (currently only float is defined).
-      // However, the parameterized implementation for MLFloat16 does not work, since ElementWiseRangedTransform
-      // will call into EigenVectorArrayMap, which does not have specializations for MLFloat16.
-      // That being said, some float16 specializations are implemented for *only* Relu and LeakyRelu in
-      // fp16_activations.h, and they depend on MLAS_F16VEC_INTRINSICS_SUPPORTED. In this case we can reliably
-      // turn on fp16 FusedGemm.
+    if (data_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) {
+      // MLFloat16 specializations are implemented for *only* Relu and LeakyRelu in
+      // fp16_activations.h, and they depend on MLAS_F16VEC_INTRINSICS_SUPPORTED.
+      // In this case we can reliably turn on fp16 FusedGemm.
 #ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED
-      const bool is_fp16_activation_supported =
-          next_node.OpType() == "Relu" ||
-          next_node.OpType() == "LeakyRelu";
+      const bool is_fp16_activation_supported = next_node.OpType() == "Relu" || next_node.OpType() == "LeakyRelu";
 #else
       const bool is_fp16_activation_supported = false;
 #endif
@@ -83,6 +73,10 @@ Status GemmActivationFusion::ApplyImpl(Graph& graph, bool& modified, int graph_l
       if (!is_fp16_activation_supported) {
         continue;
       }
     }
+    else if (data_type != ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
+      // No other ElementWiseRangedTransform<T>::Create methods are defined!
+      continue;
+    }
 
     if (!IsFusableActivation(next_node) || next_node.GetExecutionProviderType() != node.GetExecutionProviderType()) {

diff --git a/onnxruntime/core/providers/cpu/activation/activations.cc b/onnxruntime/core/providers/cpu/activation/activations.cc
index 98916fd492228..3330457dc38f5 100644
--- a/onnxruntime/core/providers/cpu/activation/activations.cc
+++ b/onnxruntime/core/providers/cpu/activation/activations.cc
@@ -81,6 +81,10 @@ Status ElementWiseRangedTransform<T>::Create(const std::string& type, const Node
     return Status::OK(); \
   }
 
+  // ElementWiseRangedTransform<T>::Create instantiates concrete implementations, i.e. Relu, Softplus, etc.
+  // Concrete classes exist for any type parameter, but MLFloat16 does not work because
+  // ElementWiseRangedTransform calls EigenVectorArrayMap, which does not have
+  // specializations for MLFloat16.
   if constexpr (std::is_same<T, MLFloat16>::value) {
 #ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED
     CREATE_ELE_KERNEL(Relu);

From 110ce4a12a12f88a5832a59e01d90b23d2bc7fb1 Mon Sep 17 00:00:00 2001
From: Sheil Kumar
Date: Fri, 23 Feb 2024 08:42:43 -0800
Subject: [PATCH 04/11] update comments

---
 onnxruntime/core/optimizer/gemm_activation_fusion.cc | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/onnxruntime/core/optimizer/gemm_activation_fusion.cc b/onnxruntime/core/optimizer/gemm_activation_fusion.cc
index 2805e7ca04cbb..c1e305f9806a5 100644
--- a/onnxruntime/core/optimizer/gemm_activation_fusion.cc
+++ b/onnxruntime/core/optimizer/gemm_activation_fusion.cc
@@ -56,12 +56,13 @@ Status GemmActivationFusion::ApplyImpl(Graph& graph, bool& modified, int graph_l
       continue;
     }
 
+    NodeArg* node_output = node.MutableOutputDefs()[0];
     const Node& next_node = *(node.OutputNodesBegin());
 
-    NodeArg* node_output = node.MutableOutputDefs()[0];
     auto data_type = node_output->TypeAsProto()->tensor_type().elem_type();
     if (data_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) {
-      // MLFloat16 specializations are implemented for *only* Relu and LeakyRelu in
+      // FusedGemm is registered in fused_gemm.cc, but the underlying
+      // MLFloat16 specializations are implemented for *only* Relu and LeakyRelu in
       // fp16_activations.h, and they depend on MLAS_F16VEC_INTRINSICS_SUPPORTED.
       // In this case we can reliably turn on fp16 FusedGemm.
 #ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED
@@ -74,7 +75,7 @@ Status GemmActivationFusion::ApplyImpl(Graph& graph, bool& modified, int graph_l
       }
     }
     else if (data_type != ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
-      // No other ElementWiseRangedTransform<T>::Create methods are defined!
+      // FusedGemm only registers float and MLFloat16 kernels in fused_gemm.cc.
       continue;
     }

From 7676f3d74043cfe9c4f398215477395c9e31fc28 Mon Sep 17 00:00:00 2001
From: Sheil Kumar
Date: Fri, 23 Feb 2024 09:14:12 -0800
Subject: [PATCH 05/11] lint

---
 .../core/optimizer/gemm_activation_fusion.cc     |  3 +--
 .../core/providers/cpu/activation/activations.cc | 13 +++++++------
 2 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/onnxruntime/core/optimizer/gemm_activation_fusion.cc b/onnxruntime/core/optimizer/gemm_activation_fusion.cc
index c1e305f9806a5..f53e2b0c86606 100644
--- a/onnxruntime/core/optimizer/gemm_activation_fusion.cc
+++ b/onnxruntime/core/optimizer/gemm_activation_fusion.cc
@@ -73,8 +73,7 @@ Status GemmActivationFusion::ApplyImpl(Graph& graph, bool& modified, int graph_l
       if (!is_fp16_activation_supported) {
         continue;
       }
-    }
-    else if (data_type != ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
+    } else if (data_type != ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
       // FusedGemm only registers float and MLFloat16 kernels in fused_gemm.cc.
       continue;
     }

diff --git a/onnxruntime/core/providers/cpu/activation/activations.cc b/onnxruntime/core/providers/cpu/activation/activations.cc
index 3330457dc38f5..f72f3ec31802e 100644
--- a/onnxruntime/core/providers/cpu/activation/activations.cc
+++ b/onnxruntime/core/providers/cpu/activation/activations.cc
@@ -90,8 +90,7 @@ Status ElementWiseRangedTransform<T>::Create(const std::string& type, const Node
     CREATE_ELE_KERNEL(Relu);
     CREATE_ELE_KERNEL(LeakyRelu);
 #endif
-  }
-  else {
+  } else {
     CREATE_ELE_KERNEL(Celu);
     CREATE_ELE_KERNEL(Elu);
     CREATE_ELE_KERNEL(HardSigmoid);
@@ -114,11 +113,13 @@ Status ElementWiseRangedTransform<T>::Create(const std::string& type, const Node
   return Status(ONNXRUNTIME, FAIL, "unknown kernel type");
 }
 
-template Status ElementWiseRangedTransform<float>::Create(const std::string& type, const NodeAttributes& attributes,
-                                                          std::unique_ptr<ElementWiseRangedTransform<float>>& out);
+template
+Status ElementWiseRangedTransform<float>::Create(const std::string& type, const NodeAttributes& attributes,
+                                                 std::unique_ptr<ElementWiseRangedTransform<float>>& out);
 
-template Status ElementWiseRangedTransform<MLFloat16>::Create(const std::string& type, const NodeAttributes& attributes,
-                                                              std::unique_ptr<ElementWiseRangedTransform<MLFloat16>>& out);
+template
+Status ElementWiseRangedTransform<MLFloat16>::Create(const std::string& type, const NodeAttributes& attributes,
+                                                     std::unique_ptr<ElementWiseRangedTransform<MLFloat16>>& out);
 }  // namespace functors
 
 namespace functors {

From 3af2256f29e45c309ed6eb30123be7ce17ad8a30 Mon Sep 17 00:00:00 2001
From: Sheil Kumar
Date: Fri, 23 Feb 2024 09:22:47 -0800
Subject: [PATCH 06/11] lint

---
 onnxruntime/core/optimizer/gemm_activation_fusion.cc     | 1 -
 onnxruntime/core/providers/cpu/activation/activations.cc | 4 ++--
 2 files changed, 2 insertions(+), 3 deletions(-)

diff --git a/onnxruntime/core/optimizer/gemm_activation_fusion.cc b/onnxruntime/core/optimizer/gemm_activation_fusion.cc
index f53e2b0c86606..5190fa5fe597d 100644
--- a/onnxruntime/core/optimizer/gemm_activation_fusion.cc
+++ b/onnxruntime/core/optimizer/gemm_activation_fusion.cc
@@ -86,7 +86,6 @@ Status GemmActivationFusion::ApplyImpl(Graph& graph, bool& modified, int graph_l
       continue;
     }
 
-
     Node& gemm_node = node;
     Node& act_node = *graph.GetNode(next_node.Index());  // get mutable reference
 
diff --git a/onnxruntime/core/providers/cpu/activation/activations.cc b/onnxruntime/core/providers/cpu/activation/activations.cc
index f72f3ec31802e..22ae0db347c7a 100644
--- a/onnxruntime/core/providers/cpu/activation/activations.cc
+++ b/onnxruntime/core/providers/cpu/activation/activations.cc
@@ -115,11 +115,11 @@ Status ElementWiseRangedTransform<T>::Create(const std::string& type, const Node
 
 template
 Status ElementWiseRangedTransform<float>::Create(const std::string& type, const NodeAttributes& attributes,
-                                                 std::unique_ptr<ElementWiseRangedTransform<float>>& out);
+std::unique_ptr<ElementWiseRangedTransform<float>>& out);
 
 template
 Status ElementWiseRangedTransform<MLFloat16>::Create(const std::string& type, const NodeAttributes& attributes,
-                                                     std::unique_ptr<ElementWiseRangedTransform<MLFloat16>>& out);
+std::unique_ptr<ElementWiseRangedTransform<MLFloat16>>& out);
 }  // namespace functors
 
 namespace functors {

From e4b0fe3f61d38660b6b995a3240f26b89808d4a7 Mon Sep 17 00:00:00 2001
From: Sheil Kumar
Date: Fri, 23 Feb 2024 09:31:24 -0800
Subject: [PATCH 07/11] lint

---
 .../core/providers/cpu/activation/activations.cc | 10 ++++------
 1 file changed, 4 insertions(+), 6 deletions(-)

diff --git a/onnxruntime/core/providers/cpu/activation/activations.cc b/onnxruntime/core/providers/cpu/activation/activations.cc
index 22ae0db347c7a..1f5424d734030 100644
--- a/onnxruntime/core/providers/cpu/activation/activations.cc
+++ b/onnxruntime/core/providers/cpu/activation/activations.cc
@@ -113,13 +113,11 @@ Status ElementWiseRangedTransform<T>::Create(const std::string& type, const Node
   return Status(ONNXRUNTIME, FAIL, "unknown kernel type");
 }
 
-template
-Status ElementWiseRangedTransform<float>::Create(const std::string& type, const NodeAttributes& attributes,
-std::unique_ptr<ElementWiseRangedTransform<float>>& out);
+template Status ElementWiseRangedTransform<float>::Create(const std::string& type, const NodeAttributes& attributes,
+                                                          std::unique_ptr<ElementWiseRangedTransform<float>>& out);
 
-template
-Status ElementWiseRangedTransform<MLFloat16>::Create(const std::string& type, const NodeAttributes& attributes,
-std::unique_ptr<ElementWiseRangedTransform<MLFloat16>>& out);
+template Status ElementWiseRangedTransform<MLFloat16>::Create(const std::string& type, const NodeAttributes& attributes,
+                                                              std::unique_ptr<ElementWiseRangedTransform<MLFloat16>>& out);
 }  // namespace functors
 
 namespace functors {

From a7eacde0d736e1a6d7df58fafa3300a3ab1a188f Mon Sep 17 00:00:00 2001
From: Sheil Kumar
Date: Fri, 23 Feb 2024 16:08:52 -0800
Subject: [PATCH 08/11] revert float16 registration

---
 onnxruntime/contrib_ops/cpu/fused_gemm.cc    | 11 ------
 .../core/optimizer/gemm_activation_fusion.cc | 21 ++--------
 .../providers/cpu/activation/activations.cc  | 39 +++++++------------
 3 files changed, 17 insertions(+), 54 deletions(-)

diff --git a/onnxruntime/contrib_ops/cpu/fused_gemm.cc b/onnxruntime/contrib_ops/cpu/fused_gemm.cc
index bd3f213312c9a..33571e74f5763 100644
--- a/onnxruntime/contrib_ops/cpu/fused_gemm.cc
+++ b/onnxruntime/contrib_ops/cpu/fused_gemm.cc
@@ -31,16 +31,5 @@ ONNX_CPU_OPERATOR_TYPED_MS_KERNEL(
     KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
     FusedGemm<float>);
 
-#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED
-
-ONNX_CPU_OPERATOR_TYPED_MS_KERNEL(
-    FusedGemm,
-    1,
-    MLFloat16,
-    KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<MLFloat16>()),
-    FusedGemm<MLFloat16>);
-
-#endif
-
 }  // namespace contrib
 }  // namespace onnxruntime

diff --git a/onnxruntime/core/optimizer/gemm_activation_fusion.cc b/onnxruntime/core/optimizer/gemm_activation_fusion.cc
index 5190fa5fe597d..f741eac0e7797 100644
--- a/onnxruntime/core/optimizer/gemm_activation_fusion.cc
+++ b/onnxruntime/core/optimizer/gemm_activation_fusion.cc
@@ -57,27 +57,14 @@ Status GemmActivationFusion::ApplyImpl(Graph& graph, bool& modified, int graph_l
     }
 
     NodeArg* node_output = node.MutableOutputDefs()[0];
-    const Node& next_node = *(node.OutputNodesBegin());
-
-    auto data_type = node_output->TypeAsProto()->tensor_type().elem_type();
-    if (data_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) {
-      // FusedGemm is registered in fused_gemm.cc, but the underlying
-      // MLFloat16 specializations are implemented for *only* Relu and LeakyRelu in
-      // fp16_activations.h, and they depend on MLAS_F16VEC_INTRINSICS_SUPPORTED.
-      // In this case we can reliably turn on fp16 FusedGemm.
-#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED
-      const bool is_fp16_activation_supported = next_node.OpType() == "Relu" || next_node.OpType() == "LeakyRelu";
-#else
-      const bool is_fp16_activation_supported = false;
-#endif
-      if (!is_fp16_activation_supported) {
-        continue;
-      }
-    } else if (data_type != ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
-      // FusedGemm only registers float and MLFloat16 kernels in fused_gemm.cc.
+    auto data_type = node_output->TypeAsProto()->tensor_type().elem_type();
+    if (data_type != ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) {
+      // FusedGemm is only registered for the float data type in fused_gemm.cc!
+
       continue;
     }
 
+    const Node& next_node = *(node.OutputNodesBegin());
     if (!IsFusableActivation(next_node) || next_node.GetExecutionProviderType() != node.GetExecutionProviderType()) {
       continue;
     }

diff --git a/onnxruntime/core/providers/cpu/activation/activations.cc b/onnxruntime/core/providers/cpu/activation/activations.cc
index 1f5424d734030..b1f8ce65761fe 100644
--- a/onnxruntime/core/providers/cpu/activation/activations.cc
+++ b/onnxruntime/core/providers/cpu/activation/activations.cc
@@ -81,30 +81,20 @@ Status ElementWiseRangedTransform<T>::Create(const std::string& type, const Node
     return Status::OK(); \
   }
 
-  // ElementWiseRangedTransform<T>::Create instantiates concrete implementations, i.e. Relu, Softplus, etc.
-  // Concrete classes exist for any type parameter, but MLFloat16 does not work because
-  // ElementWiseRangedTransform calls EigenVectorArrayMap, which does not have
-  // specializations for MLFloat16.
-  if constexpr (std::is_same<T, MLFloat16>::value) {
-#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED
-    CREATE_ELE_KERNEL(Relu);
-    CREATE_ELE_KERNEL(LeakyRelu);
-#endif
-  } else {
-    CREATE_ELE_KERNEL(Celu);
-    CREATE_ELE_KERNEL(Elu);
-    CREATE_ELE_KERNEL(HardSigmoid);
-    CREATE_ELE_KERNEL(LeakyRelu);
-    CREATE_ELE_KERNEL(Softplus);
-    CREATE_ELE_KERNEL(Relu);
-    CREATE_ELE_KERNEL(Sigmoid);
-    CREATE_ELE_KERNEL(Softsign);
-    CREATE_ELE_KERNEL(Tanh);
-    CREATE_ELE_KERNEL(ThresholdedRelu);
-    CREATE_ELE_KERNEL(Selu);
+  CREATE_ELE_KERNEL(Celu);
+  CREATE_ELE_KERNEL(Elu);
+  CREATE_ELE_KERNEL(HardSigmoid);
+  CREATE_ELE_KERNEL(LeakyRelu);
+  CREATE_ELE_KERNEL(Softplus);
+  CREATE_ELE_KERNEL(Relu);
+  CREATE_ELE_KERNEL(Sigmoid);
+  CREATE_ELE_KERNEL(Softsign);
+  CREATE_ELE_KERNEL(Tanh);
+  CREATE_ELE_KERNEL(ThresholdedRelu);
+  CREATE_ELE_KERNEL(Selu);
 #ifndef DISABLE_CONTRIB_OPS
-    CREATE_ELE_KERNEL(ParametricSoftplus);
-    CREATE_ELE_KERNEL(ScaledTanh);
+  CREATE_ELE_KERNEL(ParametricSoftplus);
+  CREATE_ELE_KERNEL(ScaledTanh);
 #endif
   }
 
@@ -115,9 +105,6 @@ Status ElementWiseRangedTransform<T>::Create(const std::string& type, const Node
 
 template Status ElementWiseRangedTransform<float>::Create(const std::string& type, const NodeAttributes& attributes,
                                                           std::unique_ptr<ElementWiseRangedTransform<float>>& out);
-
-template Status ElementWiseRangedTransform<MLFloat16>::Create(const std::string& type, const NodeAttributes& attributes,
-                                                              std::unique_ptr<ElementWiseRangedTransform<MLFloat16>>& out);
 }  // namespace functors
 
 namespace functors {

From 310c2d17a8d369223c22bea6046e361b8f898320 Mon Sep 17 00:00:00 2001
From: Sheil Kumar
Date: Fri, 23 Feb 2024 16:09:37 -0800
Subject: [PATCH 09/11] extra brace

---
 onnxruntime/core/providers/cpu/activation/activations.cc | 1 -
 1 file changed, 1 deletion(-)

diff --git a/onnxruntime/core/providers/cpu/activation/activations.cc b/onnxruntime/core/providers/cpu/activation/activations.cc
index b1f8ce65761fe..049fee4b95308 100644
--- a/onnxruntime/core/providers/cpu/activation/activations.cc
+++ b/onnxruntime/core/providers/cpu/activation/activations.cc
@@ -96,7 +96,6 @@ Status ElementWiseRangedTransform<T>::Create(const std::string& type, const Node
   CREATE_ELE_KERNEL(ParametricSoftplus);
   CREATE_ELE_KERNEL(ScaledTanh);
 #endif
-  }
 
 #undef CREATE_ELE_KERNEL
 

From 983fa78deb16c463b556c89bda02bd36e2e0d091 Mon Sep 17 00:00:00 2001
From: Sheil Kumar
Date: Fri, 23 Feb 2024 16:10:19 -0800
Subject: [PATCH 10/11] float

---
 onnxruntime/core/optimizer/gemm_activation_fusion.cc | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/onnxruntime/core/optimizer/gemm_activation_fusion.cc b/onnxruntime/core/optimizer/gemm_activation_fusion.cc
index f741eac0e7797..ab836e77ff14d 100644
--- a/onnxruntime/core/optimizer/gemm_activation_fusion.cc
+++ b/onnxruntime/core/optimizer/gemm_activation_fusion.cc
@@ -58,7 +58,7 @@ Status GemmActivationFusion::ApplyImpl(Graph& graph, bool& modified, int graph_l
 
     NodeArg* node_output = node.MutableOutputDefs()[0];
     auto data_type = node_output->TypeAsProto()->tensor_type().elem_type();
-    if (data_type != ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) {
+    if (data_type != ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
       // FusedGemm is only registered for the float data type in fused_gemm.cc!
 
       continue;

From d93fa21dff0fb2d8f44b2435b9ee85ccd04f9a32 Mon Sep 17 00:00:00 2001
From: Sheil Kumar
Date: Fri, 23 Feb 2024 16:13:30 -0800
Subject: [PATCH 11/11] extra newline

---
 onnxruntime/core/optimizer/gemm_activation_fusion.cc | 1 -
 1 file changed, 1 deletion(-)

diff --git a/onnxruntime/core/optimizer/gemm_activation_fusion.cc b/onnxruntime/core/optimizer/gemm_activation_fusion.cc
index ab836e77ff14d..50be2cbd48f7b 100644
--- a/onnxruntime/core/optimizer/gemm_activation_fusion.cc
+++ b/onnxruntime/core/optimizer/gemm_activation_fusion.cc
@@ -60,7 +60,6 @@ Status GemmActivationFusion::ApplyImpl(Graph& graph, bool& modified, int graph_l
     auto data_type = node_output->TypeAsProto()->tensor_type().elem_type();
    if (data_type != ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
       // FusedGemm is only registered for the float data type in fused_gemm.cc!
-
       continue;
     }
 
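
Appendix: the if constexpr dispatch in context. Patch 02 gates ElementWiseRangedTransform<T>::Create with if constexpr because a discarded constexpr branch is never instantiated, so the MLFloat16 instantiation never references activations that only have float implementations (the EigenVectorArrayMap problem described in the comments above). The following is a minimal, self-contained C++17 sketch of that mechanism. It is illustrative only, not onnxruntime code: MLFloat16Tag, Transform, Relu, Softplus, and Create are stand-ins for MLFloat16, ElementWiseRangedTransform, and its concrete activations.

#include <memory>
#include <string>
#include <type_traits>

struct MLFloat16Tag {};  // stand-in for onnxruntime::MLFloat16

template <typename T>
struct Transform {
  virtual ~Transform() = default;
};

template <typename T>
struct Relu : Transform<T> {};      // has an fp16 path in this model
template <typename T>
struct Softplus : Transform<T> {};  // float-only in this model

template <typename T>
std::unique_ptr<Transform<T>> Create(const std::string& type) {
  if constexpr (std::is_same<T, MLFloat16Tag>::value) {
    // The else branch below is discarded at instantiation time for this T, so
    // Softplus<MLFloat16Tag> is never instantiated; this is why patch 02 can
    // gate the fp16 path without providing fp16 versions of every kernel.
    if (type == "Relu") return std::make_unique<Relu<T>>();
  } else {
    if (type == "Relu") return std::make_unique<Relu<T>>();
    if (type == "Softplus") return std::make_unique<Softplus<T>>();
  }
  return nullptr;  // corresponds to the "unknown kernel type" failure above
}

int main() {
  auto f = Create<float>("Softplus");         // non-null: the float path has Softplus
  auto h = Create<MLFloat16Tag>("Softplus");  // null: the fp16 path never sees Softplus
  return (f != nullptr && h == nullptr) ? 0 : 1;
}

With a plain runtime if in place of if constexpr, instantiating Create<MLFloat16Tag> would also instantiate Softplus<MLFloat16Tag>; that is the analogue of ElementWiseRangedTransform calling into EigenVectorArrayMap, which has no MLFloat16 specialization, and it is why the unguarded template could not simply be instantiated for fp16.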