Commit

add skip layer norm fusion changes
ajindal1 committed Sep 19, 2023
1 parent f9f3fae commit e2c5055
Showing 1 changed file with 18 additions and 11 deletions.
29 changes: 18 additions & 11 deletions onnxruntime/core/optimizer/layer_norm_fusion.cc
@@ -677,21 +677,28 @@ Status SimplifiedLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int gr
   // scale and bias could be multi-dims; we only support it for training at the moment
   // because SkipLayerNorm kernel, for example, has dependency on single dim size
   NodeArg* scale = nullptr;
+  std::cout << "SLNF Start Changes" << std::endl;
   for (size_t i = 0; i < mul_node.MutableInputDefs().size(); i++) {
-    if (graph_utils::NodeArgIsConstant(graph, *(mul_node.MutableInputDefs()[i])) ||
-        graph_utils::IsGraphInput(graph, mul_node.MutableInputDefs()[i])) {
+    // if (graph_utils::NodeArgIsConstant(graph, *(mul_node.MutableInputDefs()[i])) ||
+    //     graph_utils::IsGraphInput(graph, mul_node.MutableInputDefs()[i])) {
       if (mul_node.MutableInputDefs()[i]->Shape() == nullptr) {
+        std::cout << "SLNF Mul Node Nullptr" << std::endl;
         continue;
       }
 #ifdef ENABLE_TRAINING_CORE
-      if (axes_values.empty() ||
-          mul_node.MutableInputDefs()[i]->Shape()->dim_size() == static_cast<int>(axes_values.size())) {
-        scale = mul_node.MutableInputDefs()[i];
-      }
+      std::cout << "SLNF ENABLE_TRAINING_CORE ON" << std::endl;
+      if (axes_values.empty() ||
+          mul_node.MutableInputDefs()[i]->Shape()->dim_size() == static_cast<int>(axes_values.size())) {
+        scale = mul_node.MutableInputDefs()[i];
+      }
 #else
-      // Scale must be 1d.
-      if (mul_node.MutableInputDefs()[i]->Shape()->dim_size() == 1) {
-        scale = mul_node.MutableInputDefs()[i];
-      }
-#endif
+      std::cout << "SLNF ENABLE_TRAINING_CORE OFF" << std::endl;
+      // Scale must be 1d.
+      if (mul_node.MutableInputDefs()[i]->Shape()->dim_size() == 1) {
+        scale = mul_node.MutableInputDefs()[i];
+      }
+#endif
+    // }
   }
 
   if (scale == nullptr) {
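The hunk leaves the build-dependent scale check intact: with ENABLE_TRAINING_CORE the scale may be multi-dimensional as long as its rank matches the recorded reduction axes, otherwise it must be 1-D because the SkipLayerNorm kernel depends on a single dim size. A minimal standalone sketch of that predicate (IsSupportedScale is a hypothetical name; rank and num_axes stand in for Shape()->dim_size() and axes_values.size()):

#include <cstddef>

// Sketch of the scale-rank check performed in the hunk above.
bool IsSupportedScale(int rank, size_t num_axes) {
#ifdef ENABLE_TRAINING_CORE
  // Training builds: a multi-dim scale is accepted as long as its rank
  // matches the number of reduction axes (or no axes were recorded).
  return num_axes == 0 || rank == static_cast<int>(num_axes);
#else
  // Inference builds: SkipLayerNorm expects a 1-D scale.
  return rank == 1;
#endif
}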

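The new traces go to std::cout; for temporary instrumentation like this, routing through onnxruntime's logging macros would respect the configured severity and sinks. A sketch, assuming the LOGS_DEFAULT macro from core/common/logging/logging.h:

#include "core/common/logging/logging.h"

// Sketch: emit the trace via the default logger instead of stdout.
void TraceSlnfStart() {
  LOGS_DEFAULT(VERBOSE) << "SLNF Start Changes";
}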