From 46c4d7fe4ad457d517fe92db7681c38849c51beb Mon Sep 17 00:00:00 2001 From: Sheil Kumar Date: Fri, 23 Feb 2024 18:20:22 -0800 Subject: [PATCH] Disable gemm activation for non-float data types (#19612) ### Description Disable gemm activation for non-float data types ### Motivation and Context When a float16 model contains a Gemm+Relu subgraph, the gemm_activation_fusion will kick in and cause the two nodes to be eliminated and replaced with a FusedGemm. This however is only registered for the float data type. This causes model load failures. Disable the fusion for non-float data types. --------- Co-authored-by: Sheil Kumar --- onnxruntime/core/optimizer/gemm_activation_fusion.cc | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/onnxruntime/core/optimizer/gemm_activation_fusion.cc b/onnxruntime/core/optimizer/gemm_activation_fusion.cc index c62887da09fdc..50be2cbd48f7b 100644 --- a/onnxruntime/core/optimizer/gemm_activation_fusion.cc +++ b/onnxruntime/core/optimizer/gemm_activation_fusion.cc @@ -56,6 +56,13 @@ Status GemmActivationFusion::ApplyImpl(Graph& graph, bool& modified, int graph_l continue; } + NodeArg* node_output = node.MutableOutputDefs()[0]; + auto data_type = node_output->TypeAsProto()->tensor_type().elem_type(); + if (data_type != ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { + // FusedGemm is only registered for float data type in fused_gemm.cc! + continue; + } + const Node& next_node = *(node.OutputNodesBegin()); if (!IsFusableActivation(next_node) || next_node.GetExecutionProviderType() != node.GetExecutionProviderType()) { continue;