Skip to content

Commit

Permalink
Merge branch 'abjindal/test_pr_lnf_without_macro' of github.com:micro…
Browse files Browse the repository at this point in the history
…soft/onnxruntime into abjindal/test_pr_lnf_wo_macro_w_skip_lnf
  • Loading branch information
ajindal1 committed Sep 19, 2023
2 parents 730fab3 + dc54590 commit f9f3fae
Showing 1 changed file with 21 additions and 10 deletions.
31 changes: 21 additions & 10 deletions onnxruntime/core/optimizer/layer_norm_fusion.cc
Original file line number Diff line number Diff line change
Expand Up @@ -413,22 +413,33 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
// because SkipLayerNorm kernel, for example, has dependency on single dim size
NodeArg* scale = nullptr;
NodeArg* bias = nullptr;
std::cout << "LNF Start Changes" << std::endl;
for (size_t i = 0; i < mul_node.MutableInputDefs().size(); i++) {
if (graph_utils::NodeArgIsConstant(graph, *(mul_node.MutableInputDefs()[i])) ||
graph_utils::IsGraphInput(graph, mul_node.MutableInputDefs()[i])) {
if (mul_node.MutableInputDefs()[i]->Shape()->dim_size() == static_cast<int>(axes_values.size())) {
scale = mul_node.MutableInputDefs()[i];
}
if (mul_node.MutableInputDefs()[i]->Shape() == nullptr) {
std::cout << "LNF Mul node is Null" << std::endl;
continue;
}
// if (graph_utils::NodeArgIsConstant(graph, *(mul_node.MutableInputDefs()[i])) ||
// graph_utils::IsGraphInput(graph, mul_node.MutableInputDefs()[i])) {
if (mul_node.MutableInputDefs()[i]->Shape()->dim_size() == static_cast<int>(axes_values.size())) {
std::cout << "LNF Scale set" << std::endl;
scale = mul_node.MutableInputDefs()[i];
}
// }
}

for (size_t i = 0; i < last_add_node.MutableInputDefs().size(); i++) {
if (graph_utils::NodeArgIsConstant(graph, *(last_add_node.MutableInputDefs()[i])) ||
graph_utils::IsGraphInput(graph, last_add_node.MutableInputDefs()[i])) {
if (last_add_node.MutableInputDefs()[i]->Shape()->dim_size() == static_cast<int>(axes_values.size())) {
bias = last_add_node.MutableInputDefs()[i];
}
if (last_add_node.MutableInputDefs()[i]->Shape() == nullptr) {
std::cout << "LNF Last add node is Null" << std::endl;
continue;
}
// if (graph_utils::NodeArgIsConstant(graph, *(last_add_node.MutableInputDefs()[i])) ||
// graph_utils::IsGraphInput(graph, last_add_node.MutableInputDefs()[i])) {
if (last_add_node.MutableInputDefs()[i]->Shape()->dim_size() == static_cast<int>(axes_values.size())) {
std::cout << "LNF Bias set" << std::endl;
bias = last_add_node.MutableInputDefs()[i];
}
// }
}
if (scale == nullptr || bias == nullptr) {
continue;
Expand Down

0 comments on commit f9f3fae

Please sign in to comment.