Skip to content

Commit

Permalink
disable gemm activation for non-float data types
Browse files Browse the repository at this point in the history
  • Loading branch information
Sheil Kumar committed Feb 22, 2024
1 parent 76a2a48 commit e8772c7
Showing 1 changed file with 6 additions and 0 deletions.
6 changes: 6 additions & 0 deletions onnxruntime/core/optimizer/gemm_activation_fusion.cc
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,12 @@ Status GemmActivationFusion::ApplyImpl(Graph& graph, bool& modified, int graph_l
continue;
}

NodeArg* node_output = node.MutableOutputDefs()[0];
auto data_type = node_output->TypeAsProto()->tensor_type().elem_type();
if (data_type != ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
continue;
}

Node& gemm_node = node;
Node& act_node = *graph.GetNode(next_node.Index()); // get mutable reference

Expand Down

0 comments on commit e8772c7

Please sign in to comment.