# Disable gemm activation for non-float data types (#19612)
### Description
Disable the Gemm+activation fusion for non-float data types.


### Motivation and Context
When a float16 model contains a Gemm+Relu subgraph, the gemm_activation_fusion
kicks in, eliminating the two nodes and replacing them with a single FusedGemm
node. FusedGemm, however, is only registered for the float data type, so the
optimized model fails to load.

Disable the fusion for non-float data types.
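
For illustration, a minimal repro sketch of the failing scenario (the model construction and file names here are illustrative, not from the PR; whether the pre-fix failure reproduces can depend on the execution provider and build):

```python
# Build a tiny float16 Gemm+Relu model and load it with onnxruntime.
# Before this fix, the optimizer could fuse Gemm+Relu into FusedGemm,
# which has no float16 kernel, and session creation failed.
import numpy as np
import onnx
import onnxruntime as ort
from onnx import TensorProto, helper, numpy_helper

X = helper.make_tensor_value_info("X", TensorProto.FLOAT16, [1, 4])
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT16, [1, 4])
W = numpy_helper.from_array(np.ones((4, 4), dtype=np.float16), name="W")
B = numpy_helper.from_array(np.zeros((1, 4), dtype=np.float16), name="B")

gemm = helper.make_node("Gemm", ["X", "W", "B"], ["gemm_out"])
relu = helper.make_node("Relu", ["gemm_out"], ["Y"])
graph = helper.make_graph([gemm, relu], "gemm_relu_fp16", [X], [Y], initializer=[W, B])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
onnx.save(model, "gemm_relu_fp16.onnx")

# With the fix, the fusion is skipped for float16 and the model loads.
sess = ort.InferenceSession("gemm_relu_fp16.onnx", providers=["CPUExecutionProvider"])
print(sess.run(None, {"X": np.ones((1, 4), dtype=np.float16)})[0])
```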

---------

Co-authored-by: Sheil Kumar <[email protected]>
2 people authored and maggie1059 committed Mar 5, 2024
1 parent e0d0c6e commit e008e1c
Showing 1 changed file with 7 additions and 0 deletions.
onnxruntime/core/optimizer/gemm_activation_fusion.cc:

```diff
@@ -56,6 +56,13 @@ Status GemmActivationFusion::ApplyImpl(Graph& graph, bool& modified, int graph_l
       continue;
     }
 
+    NodeArg* node_output = node.MutableOutputDefs()[0];
+    auto data_type = node_output->TypeAsProto()->tensor_type().elem_type();
+    if (data_type != ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
+      // FusedGemm is only registered for float data type in fused_gemm.cc!
+      continue;
+    }
+
     const Node& next_node = *(node.OutputNodesBegin());
     if (!IsFusableActivation(next_node) || next_node.GetExecutionProviderType() != node.GetExecutionProviderType()) {
       continue;
```
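
One way to check the guard's effect is to dump the post-optimization graph and inspect the node types; a sketch assuming the repro model from above (file names are illustrative):

```python
import onnx
import onnxruntime as ort

so = ort.SessionOptions()
so.optimized_model_filepath = "gemm_relu_fp16.optimized.onnx"  # dump optimized graph
ort.InferenceSession("gemm_relu_fp16.onnx", so, providers=["CPUExecutionProvider"])

# With the fix, a float16 model should still contain separate Gemm and Relu
# nodes (possibly with inserted Casts), and no FusedGemm node.
print([n.op_type for n in onnx.load("gemm_relu_fp16.optimized.onnx").graph.node])
```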
