Abjindal/test pr lnf wo macro w skip lnf #17610

Closed
wants to merge 3 commits
Changes from all commits
60 changes: 39 additions & 21 deletions onnxruntime/core/optimizer/layer_norm_fusion.cc
@@ -413,22 +413,33 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
   // because SkipLayerNorm kernel, for example, has dependency on single dim size
   NodeArg* scale = nullptr;
   NodeArg* bias = nullptr;
+  std::cout << "LNF Start Changes" << std::endl;
   for (size_t i = 0; i < mul_node.MutableInputDefs().size(); i++) {
-    if (graph_utils::NodeArgIsConstant(graph, *(mul_node.MutableInputDefs()[i])) ||
-        graph_utils::IsGraphInput(graph, mul_node.MutableInputDefs()[i])) {
-      if (mul_node.MutableInputDefs()[i]->Shape()->dim_size() == static_cast<int>(axes_values.size())) {
-        scale = mul_node.MutableInputDefs()[i];
-      }
-    }
+    if (mul_node.MutableInputDefs()[i]->Shape() == nullptr) {
+      std::cout << "LNF Mul node is Null" << std::endl;
+      continue;
+    }
+    // if (graph_utils::NodeArgIsConstant(graph, *(mul_node.MutableInputDefs()[i])) ||
+    //     graph_utils::IsGraphInput(graph, mul_node.MutableInputDefs()[i])) {
+    if (mul_node.MutableInputDefs()[i]->Shape()->dim_size() == static_cast<int>(axes_values.size())) {
+      std::cout << "LNF Scale set" << std::endl;
+      scale = mul_node.MutableInputDefs()[i];
+    }
+    // }
   }
 
   for (size_t i = 0; i < last_add_node.MutableInputDefs().size(); i++) {
-    if (graph_utils::NodeArgIsConstant(graph, *(last_add_node.MutableInputDefs()[i])) ||
-        graph_utils::IsGraphInput(graph, last_add_node.MutableInputDefs()[i])) {
-      if (last_add_node.MutableInputDefs()[i]->Shape()->dim_size() == static_cast<int>(axes_values.size())) {
-        bias = last_add_node.MutableInputDefs()[i];
-      }
-    }
+    if (last_add_node.MutableInputDefs()[i]->Shape() == nullptr) {
+      std::cout << "LNF Last add node is Null" << std::endl;
+      continue;
+    }
+    // if (graph_utils::NodeArgIsConstant(graph, *(last_add_node.MutableInputDefs()[i])) ||
+    //     graph_utils::IsGraphInput(graph, last_add_node.MutableInputDefs()[i])) {
+    if (last_add_node.MutableInputDefs()[i]->Shape()->dim_size() == static_cast<int>(axes_values.size())) {
+      std::cout << "LNF Bias set" << std::endl;
+      bias = last_add_node.MutableInputDefs()[i];
+    }
+    // }
   }
   if (scale == nullptr || bias == nullptr) {
     continue;
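
Note on the hunk above: it swaps the constant/graph-input requirement for a null-shape guard, so an input whose NodeArg carries no shape is skipped (with a debug print) instead of having Shape() dereferenced unconditionally. A minimal standalone sketch of that selection pattern follows; InputDef, pick_scale, and the use of std::optional are illustrative stand-ins for onnxruntime's NodeArg and shape types, not the real API.

#include <iostream>
#include <optional>
#include <vector>

// Hypothetical stand-in for a NodeArg: a name plus an optional rank.
// An empty rank models NodeArg::Shape() == nullptr.
struct InputDef {
  const char* name;
  std::optional<int> rank;
};

// Sketch of the selection loop: skip inputs with no shape info, then
// accept the first input whose rank matches the number of reduction
// axes (axes_rank plays the role of axes_values.size()).
const InputDef* pick_scale(const std::vector<InputDef>& inputs, int axes_rank) {
  for (const InputDef& def : inputs) {
    if (!def.rank.has_value()) {
      std::cout << "skipping " << def.name << ": no shape" << std::endl;
      continue;  // the guard replaces the old unconditional dereference
    }
    if (*def.rank == axes_rank) {
      return &def;
    }
  }
  return nullptr;  // caller bails out, mirroring the scale == nullptr check
}

int main() {
  std::vector<InputDef> inputs = {{"data", 3}, {"no_shape", std::nullopt}, {"gamma", 1}};
  const InputDef* scale = pick_scale(inputs, /*axes_rank=*/1);
  std::cout << (scale ? scale->name : "none") << std::endl;  // prints "gamma"
  return 0;
}

Under this sketch, an input with missing shape metadata can no longer crash the pass; it simply never becomes the scale candidate.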
@@ -666,21 +677,28 @@ Status SimplifiedLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
   // scale and bias could be multi-dims; we only support it for training at the moment
   // because SkipLayerNorm kernel, for example, has dependency on single dim size
   NodeArg* scale = nullptr;
+  std::cout << "SLNF Start Changes" << std::endl;
   for (size_t i = 0; i < mul_node.MutableInputDefs().size(); i++) {
-    if (graph_utils::NodeArgIsConstant(graph, *(mul_node.MutableInputDefs()[i])) ||
-        graph_utils::IsGraphInput(graph, mul_node.MutableInputDefs()[i])) {
+    // if (graph_utils::NodeArgIsConstant(graph, *(mul_node.MutableInputDefs()[i])) ||
+    //     graph_utils::IsGraphInput(graph, mul_node.MutableInputDefs()[i])) {
+    if (mul_node.MutableInputDefs()[i]->Shape() == nullptr) {
+      std::cout << "SLNF Mul Node Nullptr" << std::endl;
+      continue;
+    }
 #ifdef ENABLE_TRAINING_CORE
-      if (axes_values.empty() ||
-          mul_node.MutableInputDefs()[i]->Shape()->dim_size() == static_cast<int>(axes_values.size())) {
-        scale = mul_node.MutableInputDefs()[i];
-      }
+    std::cout << "SLNF ENABLE_TRAINING_CORE ON" << std::endl;
+    if (axes_values.empty() ||
+        mul_node.MutableInputDefs()[i]->Shape()->dim_size() == static_cast<int>(axes_values.size())) {
+      scale = mul_node.MutableInputDefs()[i];
+    }
 #else
-      // Scale must be 1d.
-      if (mul_node.MutableInputDefs()[i]->Shape()->dim_size() == 1) {
-        scale = mul_node.MutableInputDefs()[i];
-      }
+    std::cout << "SLNF ENABLE_TRAINING_CORE OFF" << std::endl;
+    // Scale must be 1d.
+    if (mul_node.MutableInputDefs()[i]->Shape()->dim_size() == 1) {
+      scale = mul_node.MutableInputDefs()[i];
+    }
 #endif
-    }
+    // }
   }
 
   if (scale == nullptr) {
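
Note on the second hunk: it keeps the existing ENABLE_TRAINING_CORE split while adding the same null-shape guard and debug prints. Training builds accept a scale whose rank matches the reduction axes (or an empty axes list), whereas default builds still require a 1-D scale because the SkipLayerNorm kernel depends on a single dim size. The sketch below isolates that compile-time predicate; accepts_scale_rank is a hypothetical helper written for illustration, not an onnxruntime function.

#include <iostream>

// Illustrative predicate mirroring the #ifdef ENABLE_TRAINING_CORE branch
// above. axes_rank stands in for axes_values.size(); pass 0 to model an
// empty axes list.
bool accepts_scale_rank(int scale_rank, int axes_rank) {
#ifdef ENABLE_TRAINING_CORE
  // Training builds: empty axes, or rank equal to the number of reduced axes.
  return axes_rank == 0 || scale_rank == axes_rank;
#else
  (void)axes_rank;  // unused in default builds
  // Default builds: scale must be 1-D (SkipLayerNorm expects a single dim).
  return scale_rank == 1;
#endif
}

int main() {
  std::cout << accepts_scale_rank(2, 2) << std::endl;  // 1 in training builds, 0 otherwise
  std::cout << accepts_scale_rank(1, 1) << std::endl;  // 1 in both builds
  return 0;
}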