From dc54590fc7cdc9144afa7487c5256df8e504bf25 Mon Sep 17 00:00:00 2001 From: Abhishek Jindal Date: Tue, 19 Sep 2023 09:16:50 -0700 Subject: [PATCH] add changes without macro --- .../core/optimizer/layer_norm_fusion.cc | 31 +++++++++++++------ 1 file changed, 21 insertions(+), 10 deletions(-) diff --git a/onnxruntime/core/optimizer/layer_norm_fusion.cc b/onnxruntime/core/optimizer/layer_norm_fusion.cc index bf36f11521be2..14249413f2f91 100644 --- a/onnxruntime/core/optimizer/layer_norm_fusion.cc +++ b/onnxruntime/core/optimizer/layer_norm_fusion.cc @@ -413,22 +413,33 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, // because SkipLayerNorm kernel, for example, has dependency on single dim size NodeArg* scale = nullptr; NodeArg* bias = nullptr; + std::cout << "LNF Start Changes" << std::endl; for (size_t i = 0; i < mul_node.MutableInputDefs().size(); i++) { - if (graph_utils::NodeArgIsConstant(graph, *(mul_node.MutableInputDefs()[i])) || - graph_utils::IsGraphInput(graph, mul_node.MutableInputDefs()[i])) { - if (mul_node.MutableInputDefs()[i]->Shape()->dim_size() == static_cast(axes_values.size())) { - scale = mul_node.MutableInputDefs()[i]; - } + if (mul_node.MutableInputDefs()[i]->Shape() == nullptr) { + std::cout << "LNF Mul node is Null" << std::endl; + continue; } + // if (graph_utils::NodeArgIsConstant(graph, *(mul_node.MutableInputDefs()[i])) || + // graph_utils::IsGraphInput(graph, mul_node.MutableInputDefs()[i])) { + if (mul_node.MutableInputDefs()[i]->Shape()->dim_size() == static_cast(axes_values.size())) { + std::cout << "LNF Scale set" << std::endl; + scale = mul_node.MutableInputDefs()[i]; + } + // } } for (size_t i = 0; i < last_add_node.MutableInputDefs().size(); i++) { - if (graph_utils::NodeArgIsConstant(graph, *(last_add_node.MutableInputDefs()[i])) || - graph_utils::IsGraphInput(graph, last_add_node.MutableInputDefs()[i])) { - if (last_add_node.MutableInputDefs()[i]->Shape()->dim_size() == static_cast(axes_values.size())) { - bias = last_add_node.MutableInputDefs()[i]; - } + if (last_add_node.MutableInputDefs()[i]->Shape() == nullptr) { + std::cout << "LNF Last add node is Null" << std::endl; + continue; + } + // if (graph_utils::NodeArgIsConstant(graph, *(last_add_node.MutableInputDefs()[i])) || + // graph_utils::IsGraphInput(graph, last_add_node.MutableInputDefs()[i])) { + if (last_add_node.MutableInputDefs()[i]->Shape()->dim_size() == static_cast(axes_values.size())) { + std::cout << "LNF Bias set" << std::endl; + bias = last_add_node.MutableInputDefs()[i]; } + // } } if (scale == nullptr || bias == nullptr) { continue;