diff --git a/include/onnxruntime/core/graph/graph.h b/include/onnxruntime/core/graph/graph.h index 462d410e13769..fe0734c51f807 100644 --- a/include/onnxruntime/core/graph/graph.h +++ b/include/onnxruntime/core/graph/graph.h @@ -397,6 +397,10 @@ class Node { #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) /** Remove the specified attribute from this Node */ bool ClearAttribute(const std::string& attr_name); + + /** Gets the Node's mutable attributes. */ + NodeAttributes& GetMutableAttributes() noexcept { return attributes_; } + #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) /** @@ -406,8 +410,6 @@ class Node { int PruneRemovableAttributes(gsl::span removable_attributes); #if !defined(ORT_MINIMAL_BUILD) - /** Gets the Node's mutable attributes. */ - NodeAttributes& GetMutableAttributes() noexcept { return attributes_; } /** Gets the Graph instance that is instantiated from a GraphProto attribute during Graph::Resolve. @param attr_name Attribute name for the GraphProto attribute. @@ -441,6 +443,13 @@ class Node { return attr_to_subgraph_map_; } + /** Gets a map of attribute name to the mutable Graph instances for all subgraphs of the Node. + * @returns a mutable map of mutable subgraphs. + */ + std::unordered_map>& GetMutableMapOfAttributeNameToSubgraph() { + return attr_to_subgraph_map_; + } + /** Gets a map of attribute name to the const Graph instances for all subgraphs of the Node. @returns Map of the attribute name that defines the subgraph to the subgraph's Graph instance. nullptr if the Node has no subgraphs. 
@@ -586,7 +595,7 @@ class Node { // create a Graph instance for an attribute that contains a GraphProto void CreateSubgraph(const std::string& attr_name); - const std::vector>& MutableSubgraphs() noexcept { return subgraphs_; } + std::vector>& MutableSubgraphs() noexcept { return subgraphs_; } // validate and update the input arg count common::Status UpdateInputArgCount(); @@ -1134,6 +1143,26 @@ class Graph { */ Node& FuseSubGraph(const IndexedSubGraph& sub_graph, const std::string& fused_node_name); + /** + Directly insert one of the If node branches into this Graph. + The `If` node condition must be a constant. The function + renames the nodes of the corresponding subgraph to make sure there is no conflict. + + Explicit and implicit input references stay the same. + + All of the outputs of the subgraph being inlined are renamed + to the outputs of the If node. + + The function will process any subgraphs in each of the nodes being inlined, + and will rename any references to the new names introduced. + + @param condition_value Value of the If condition; true selects then_branch, false selects else_branch + @param if_node The If node that contains the graph to inline. It is not removed by this call; + the caller removes it once the selected branch (either then or else) has been inlined + @param logger Logger that receives the messages about the inlining outcome + */ + Status InlineIfSubgraph(bool condition_value, Node& if_node, const logging::Logger& logger); + /** Directly insert the nodes in the function Node provided into this Graph. The Graph needs to be Resolve()d after this call. 
diff --git a/onnxruntime/core/graph/function_utils.cc b/onnxruntime/core/graph/function_utils.cc index 7b0a834a7ffc0..a266c9ab04a2e 100644 --- a/onnxruntime/core/graph/function_utils.cc +++ b/onnxruntime/core/graph/function_utils.cc @@ -432,7 +432,7 @@ class Inliner { // Process a node: void transform(NodeProto& n) { if (!n.name().empty()) - n.set_name(prefix_ + n.name()); + n.set_name(prefix_ + "_" + n.name()); for (auto& x : *n.mutable_input()) { rename(x, false); diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index 4b3cafcb39b78..3763e0758cc5c 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -984,6 +984,7 @@ bool Node::ClearAttribute(const std::string& attr_name) { graph_->SetGraphProtoSyncNeeded(); return attributes_.erase(attr_name) > 0; } + #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) int Node::PruneRemovableAttributes(gsl::span removable_attributes) { @@ -4047,6 +4048,301 @@ Status Graph::AddConstantProtoAsInitializer(const ONNX_NAMESPACE::NodeProto& nod return Status::OK(); } +static void ReassignSubgraphDependentNodeArgs(const InlinedHashMap& name_to_nodearg, + Graph& graph) { + for (auto& node : graph.Nodes()) { + if (node.ContainsSubgraph()) { + for (auto& [name, subgraph] : node.GetAttributeNameToMutableSubgraphMap()) { + ReassignSubgraphDependentNodeArgs(name_to_nodearg, *subgraph); + } + } + + // NodeArgs need to be updated + for (auto& input_def : node.MutableInputDefs()) { + if (input_def->Exists()) { + auto hit = name_to_nodearg.find(input_def->Name()); + if (hit != name_to_nodearg.cend()) { + input_def = hit->second; + } + } + } + } +} + +Status Graph::InlineIfSubgraph(bool condition_value, Node& if_node, const logging::Logger& logger) { + static const std::string then_branch{"then_branch"}; + static const std::string else_branch{"else_branch"}; + Graph* sub_graph; + if (condition_value) { + sub_graph = 
if_node.GetMutableGraphAttribute(then_branch); + } else { + sub_graph = if_node.GetMutableGraphAttribute(else_branch); + } + + if (sub_graph == nullptr) { + auto str = MakeString("Unable to constant fold If node: '", if_node.Name(), "' Unable to fetch: ", + (condition_value ? then_branch : else_branch)); + LOGS(logger, WARNING) << str; + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, str);  // nothing was inlined, so the caller must not remove the If node + } + + Graph& graph_to_inline = *sub_graph; + + std::string unique_id{if_node.Name()}; + if (condition_value) { + unique_id.append(then_branch); + } else { + unique_id.append(else_branch); + } + + unique_id = GenerateNodeName(unique_id); + + auto make_unique = [&unique_id](const std::string& name) { + return unique_id + '_' + name; + }; + + // Check if the name is an input or implicit input. + // These are not renamed, and we do not need to adjust subgraphs for them. + // Implicit inputs would cover both If node input and implicit inputs. + // Reason: there are no explicit inputs to the subgraphs, and the subgraph's + // implicit inputs must be covered by the implicit inputs of the If node. + InlinedHashMap<std::string, NodeArg*> outer_scope_values; + const auto& if_implicit_inputs = if_node.MutableImplicitInputDefs();  // bind by reference, the vector is only read + outer_scope_values.reserve(if_implicit_inputs.size()); + + for (auto* input : if_implicit_inputs) { + const auto& name = input->Name(); + ORT_IGNORE_RETURN_VALUE(outer_scope_values.emplace(name, input)); + } + + // Name mapping from the graph to inline to the graph we are inlining into + // we also use this to process any subgraphs in the graph we are inlining + InlinedHashMap<std::string, NodeArg*> name_to_nodearg; + + // We are going to map the outputs of the graph to inline to the outputs of the If node. + // They are assumed to be in the same order. 
+ const auto node_output_defs = if_node.MutableOutputDefs(); + const auto graph_output_defs = graph_to_inline.GetOutputs(); + for (size_t i = 0; i < graph_output_defs.size(); ++i) { + name_to_nodearg.emplace(graph_output_defs[i]->Name(), node_output_defs[i]); + } + + // Move initializers from the subgraph to the destination graph. + for (int i = 0, limit = graph_to_inline.graph_proto_->initializer_size(); i < limit; ++i) { + auto* initializer = graph_to_inline.graph_proto_->mutable_initializer(i); + const std::string src_name = initializer->name(); + +#if !defined(DISABLE_SPARSE_TENSORS) + bool has_sparse_origin = false; + if (!graph_to_inline.sparse_tensor_names_.empty()) { + auto hit = graph_to_inline.sparse_tensor_names_.find(src_name); + if (hit != graph_to_inline.sparse_tensor_names_.cend()) { + has_sparse_origin = true; + // Erase the entry that will be invalidated + graph_to_inline.sparse_tensor_names_.erase(hit); + } + } +#endif + + graph_to_inline.name_to_initial_tensor_.erase(src_name); + const gsl::not_null tensor{graph_proto_->add_initializer()}; + *tensor = std::move(*initializer); + + // Check if this is an output of the graph + auto hit = name_to_nodearg.find(src_name); + if (hit != name_to_nodearg.cend()) { + // We rename it to If node output. + tensor->set_name(hit->second->Name()); + } else { + NodeArg* node_arg = graph_to_inline.GetNodeArg(src_name); + assert(node_arg != nullptr); + auto new_name = GenerateNodeArgName(make_unique(src_name)); + NodeArg& new_arg = GetOrCreateNodeArg(new_name, node_arg->TypeAsProto()); + ORT_IGNORE_RETURN_VALUE(name_to_nodearg.emplace(src_name, &new_arg)); + tensor->set_name(std::move(new_name)); + } + + auto insert_result = name_to_initial_tensor_.emplace(tensor->name(), tensor); + ORT_ENFORCE(insert_result.second, "Initializer name: ", tensor->name(), " from graph: ", + graph_to_inline.Name(), " conflicts with graph initializer. 
Check name generation above."); + +#if !defined(DISABLE_SPARSE_TENSORS) + if (has_sparse_origin) { + ORT_IGNORE_RETURN_VALUE(sparse_tensor_names_.emplace(tensor->name())); + } +#endif + } + + // Look up nodes that would be providing input to our nodes (implicit and explicit) + // and any nodes that take the output of our nodes (used to be If output) + // Map of NodeArg name to pair of Node* and input index in the destination node + using NodeAndIndex = std::pair, int>; + using ArgNameToNodeMap = InlinedHashMap; + ArgNameToNodeMap input_args; + // Map of NodeArg name to pair of Node* and output index in the source node. + ArgNameToNodeMap output_args; + + auto map_defs = [](Node& node, ArgNameToNodeMap& map, bool input) { + const auto defs = (input) ? node.InputDefs() : node.OutputDefs(); + map.reserve(map.size() + defs.size()); + int arg_pos = -1; + for (auto* node_arg : defs) { + ++arg_pos; + if (node_arg->Exists()) { + map.emplace(node_arg->Name(), std::make_pair(&node, arg_pos)); + } + } + }; + + const bool is_this_main_graph = (parent_graph_ == nullptr); + // Map the inputs and outputs of the If node to the nodes in the graph to inline. + if (!is_this_main_graph) { + for (auto& node : Nodes()) { + if (node.Index() == if_node.Index()) { + continue; + } + map_defs(node, input_args, true); + map_defs(node, output_args, false); + } + } + + // We want to make sure we get nodes in topological order + // because Constant folding may cause the nodes appear in + // a different order. 
+ InlinedVector<Node*> new_nodes; + GraphViewer graph(graph_to_inline); + for (const auto node_idx : graph.GetNodesInTopologicalOrder()) { + // GraphViewer filters out nullptrs + auto* node = graph_to_inline.GetNode(node_idx); + assert(node->OpType() != kConstant); + + InlinedVector<NodeArg*> new_node_input_defs; + for (const auto* input_def : node->InputDefs()) { + if (input_def->Exists()) { + // Check if this is one of the implicit graph inputs + // then leave the name as is and re-use the NodeArg + const auto& input_name = input_def->Name(); + auto outer_hit = outer_scope_values.find(input_name); + if (outer_hit != outer_scope_values.cend()) { + new_node_input_defs.push_back(outer_hit->second); + } else { + auto hit = name_to_nodearg.find(input_name); + if (hit != name_to_nodearg.cend()) { + // This is other node output, constant node or initializer that was renamed. + new_node_input_defs.push_back(hit->second); + } else { + ORT_THROW("Node's: ", node->Name(), " input: ", input_name, + " is not If node's input or previous node output in this subgraph"); + } + } + } else { new_node_input_defs.push_back(&GetOrCreateNodeArg(std::string(), nullptr)); }  // omitted optional input: keep its slot so later inputs do not shift + } + + InlinedVector<NodeArg*> new_node_output_defs; + for (const auto* output_def : node->OutputDefs()) { + const auto& output_name = output_def->Name(); + auto hit = name_to_nodearg.find(output_name); + if (hit != name_to_nodearg.cend()) { + // This is one of the graph outputs, we rename it to + // If node output. + new_node_output_defs.push_back(hit->second); + } else { + // We generate an output to downstream nodes. 
+ auto new_name = GenerateNodeArgName(make_unique(output_name)); + NodeArg& new_arg = GetOrCreateNodeArg(new_name, output_def->TypeAsProto()); + new_node_output_defs.push_back(&new_arg); + ORT_IGNORE_RETURN_VALUE(name_to_nodearg.emplace(output_name, &new_arg)); + } + } + + const auto new_node_name = GenerateNodeName(make_unique(node->OpType())); + Node& new_node = AddNode(new_node_name, node->OpType(), node->Description(), + new_node_input_defs, + new_node_output_defs, + nullptr, + node->Domain()); + + if (!is_this_main_graph) { + map_defs(new_node, input_args, true); + map_defs(new_node, output_args, false); + new_nodes.push_back(&new_node); + } + + new_node.SetSinceVersion(node->SinceVersion()); + new_node.op_ = node->op_; + + if (node->ContainsSubgraph()) { + auto& subgraphs = node->MutableSubgraphs(); + + // Check if any of this node implicit inputs of this graph is in the renaming map + int renames_subgraph_names = 0; + auto& new_implicit_defs = node->MutableImplicitInputDefs(); + for (auto& input_def : new_implicit_defs) { + auto hit = name_to_nodearg.find(input_def->Name()); + if (hit != name_to_nodearg.cend()) { + input_def = hit->second; + ++renames_subgraph_names; + } + } + + for (auto& subgraph : subgraphs) { + if (renames_subgraph_names > 0) { + // We need to rename the subgraph node names + // because they may refer to the implicit inputs + // that were renamed. + ReassignSubgraphDependentNodeArgs(name_to_nodearg, *subgraph); + } + subgraph->parent_node_ = &new_node; + subgraph->parent_graph_ = this; + } + + new_node.MutableSubgraphs() = std::move(subgraphs); + new_node.GetMutableMapOfAttributeNameToSubgraph() = std::move(node->GetMutableMapOfAttributeNameToSubgraph()); + new_node.MutableImplicitInputDefs() = std::move(new_implicit_defs); + } + + new_node.GetMutableAttributes() = std::move(node->GetMutableAttributes()); + } + + // Let's rebuild local connections, so next time a GraphViewer is able to perform topological sort. 
+ // We only need to do so if this graph is not the main graph, because the main graph is going to resolve + // and it is not possible to inline the same nodes again. + if (!is_this_main_graph) { + for (auto* node : new_nodes) { + int arg_pos = -1; + for (auto* input_def : node->InputDefs()) { + ++arg_pos; + auto hit = output_args.find(input_def->Name()); + if (hit != output_args.cend()) { + // The input to this node is an output from a previous node in this graph. + // Create relationship between this node (node), and the node providing the output (output_node). + const auto& [producer, src_idx] = hit->second; + AddEdge(producer->Index(), node->Index(), src_idx, arg_pos); + } + } + + // Check if any of the outputs for inlined nodes are inputs to other nodes in the graph. + // (outputs of If node) + arg_pos = -1; + for (auto& output_def : node->OutputDefs()) { + ++arg_pos; + auto hit = input_args.find(output_def->Name()); + if (hit != input_args.cend()) { + // The output of this node is an input to another node in this graph. + // Create relationship between this node (node), and the node using the input (input_node). + const auto& [consumer, dst_idx] = hit->second; + AddEdge(node->Index(), consumer->Index(), arg_pos, dst_idx); + } + } + } + } + + LOGS(logger, INFO) << "Constant folded (inlined) " << (condition_value ? 
then_branch : else_branch) + << " for If node: " << if_node.Name(); + + return Status::OK(); +} + Status Graph::InlineFunctionProto(const ONNX_NAMESPACE::FunctionProto& func_to_inline) { auto to_node_arg = [this](const std::string& name) { return &this->GetOrCreateNodeArg(name, nullptr); diff --git a/onnxruntime/core/optimizer/constant_folding.cc b/onnxruntime/core/optimizer/constant_folding.cc index f46273f2680a9..e3a2f2d74c0d4 100644 --- a/onnxruntime/core/optimizer/constant_folding.cc +++ b/onnxruntime/core/optimizer/constant_folding.cc @@ -4,6 +4,7 @@ #include #include "core/optimizer/constant_folding.h" +#include "core/optimizer/initializer.h" #include "core/optimizer/utils.h" #include "core/graph/graph_utils.h" #include "core/optimizer/optimizer_execution_frame.h" @@ -90,6 +91,45 @@ static bool ConstantFoldShapeNode(Graph& graph, Node& node) { return is_concrete_shape; // convert to constant if this is true } +// This function inlines the appropriate subgraph. It does not literally fold it. +static Status ConstantFoldIfNode(Graph& graph, Node& if_node, const logging::Logger& logger, bool& folded) { + folded = false; + // First, find out which subgraph to inline + // We need to fetch the constant argument. + assert(if_node.InputDefs().size() == 1); + const auto* condition_def = if_node.InputDefs()[0]; + + // We need to check if the condition is a constant. + constexpr bool check_outer_scope_true = true; + const ONNX_NAMESPACE::TensorProto* initializer = + graph.GetConstantInitializer(condition_def->Name(), check_outer_scope_true); + if (initializer == nullptr) { + return Status::OK(); + } + + // This is a boolean initializer with a single element. 
+ Initializer condition{*initializer}; + ORT_RETURN_IF_NOT(condition.size() == 1, "If node condition initializer: '", condition_def->Name(), + "' is expected to have a single boolean element"); + + const bool condition_value = *condition.data<bool>(); + + auto status = graph.InlineIfSubgraph(condition_value, if_node, logger); + + if (!status.IsOK()) { + LOGS(logger, WARNING) << "Unable to constant fold. InlineIfSubgraph failed " + << " node '" << if_node.Name() << "': " + << status.ErrorMessage(); + return status; + } + + graph_utils::RemoveNodeOutputEdges(graph, if_node); + graph.RemoveNode(if_node.Index()); + + folded = true; + return status; +} + Status ConstantFolding::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const { bool have_updated_nodes = false; GraphViewer graph_viewer(graph); @@ -118,7 +158,20 @@ Status ConstantFolding::ApplyImpl(Graph& graph, bool& modified, int graph_level, } bool converted_to_constant = false; - if (node->OpType().compare("Shape") == 0) { + if (node->OpType().compare("If") == 0) { + // This process constant folds the If node only, + // but inlines the nodes of the corresponding branch graph. + // It does not convert the node to a constant in a common sense. + // We call it constant folding because the `If` node constant condition + // may enable us to inline the corresponding branch graph. 
+ bool folded = false; + ORT_RETURN_IF_ERROR(ConstantFoldIfNode(graph, *node, logger, folded)); + if (folded) { + // Node removal is done within ConstantFoldIfNode() + modified = true; + have_updated_nodes = true; + } + } else if (node->OpType().compare("Shape") == 0) { converted_to_constant = ConstantFoldShapeNode(graph, *node); } else { InitializedTensorSet constant_inputs; diff --git a/onnxruntime/test/framework/function_test.cc b/onnxruntime/test/framework/function_test.cc index f9a5c7618601c..84d8a9c56df89 100644 --- a/onnxruntime/test/framework/function_test.cc +++ b/onnxruntime/test/framework/function_test.cc @@ -543,14 +543,12 @@ TEST(FunctionTest, TestInlinedLocalFunctionRemoved) { InferenceSessionWrapper session_object{session_options, GetEnvironment()}; std::stringstream sstr(serialized_model); - auto status = session_object.Load(sstr); - ASSERT_TRUE(status.IsOK()) << status.ErrorMessage(); + ASSERT_STATUS_OK(session_object.Load(sstr)); auto model_proto = session_object.GetModel().ToProto(); ASSERT_EQ(1, model_proto.functions_size()); - status = session_object.Initialize(); - ASSERT_TRUE(status.IsOK()) << status.ErrorMessage(); + ASSERT_STATUS_OK(session_object.Initialize()); // All functions removed model_proto = session_object.GetModel().ToProto(); diff --git a/onnxruntime/test/optimizer/ensure_unique_dq_for_node_unit_test.cc b/onnxruntime/test/optimizer/ensure_unique_dq_for_node_unit_test.cc index feff607703341..7a67747f7cf4c 100644 --- a/onnxruntime/test/optimizer/ensure_unique_dq_for_node_unit_test.cc +++ b/onnxruntime/test/optimizer/ensure_unique_dq_for_node_unit_test.cc @@ -63,7 +63,8 @@ std::function GetGraphBuilder(const GraphConfig& config return graph.ToGraphProto(); }; - auto* if_input = builder.MakeInitializerBool({}, {true}); + // Make this an input to prevent If constant folding affecting this test + auto* if_input = builder.MakeInput({1}, {true}); auto* if_output = builder.MakeOutput(); Node& if_node = builder.AddNode("If", {if_input}, 
{if_output}); if_node.AddAttribute("then_branch", create_if_subgraph(true)); diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index e0f63ea58e772..b82f3345dfcd1 100755 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -10,6 +10,8 @@ #include "gtest/gtest.h" #include "gmock/gmock.h" +#include "onnx/defs/parser.h" +#include "onnx/defs/printer.h" #include "asserts.h" #include "core/common/span_utils.h" @@ -1022,6 +1024,158 @@ TEST_F(GraphTransformationTests, ConstantFoldingStringInitializer) { ASSERT_EQ(op_to_count.size(), 0U) << "Identity node should have been removed"; } +TEST_F(GraphTransformationTests, ConstantFoldingIfConstantInlining) { + // This test covers the following necessary cases: + // The input refers to the explicit or implicit inputs of If node. + // The output of the node is the output of the subgraph being inlined. + // Constant nodes and initializers are promoted to the outer graph. + // The initializer or a constant node is the output of the subgraph being inlined. + // Nested subgraphs names are renamed as appropriate. + // In all If node is constant folded twice. The last If node is not constant + // folded because the input is indirectly dependent on the size of the input. + // XXX: Can we constant fold Size() if the graph input shape is fixed? 
+ + const char* code = R"( + < + ir_version: 8, + opset_import: [ "" : 16, "local" : 1 ] + > + agraph (float[128] x, float[128] x1) => (float[N] y) + { + y = local.aten_gather (x, x1) + } + < + opset_import: [ "" : 16, "local" : 1], + domain: "local" + > + aten_gather (self, index) => (result_16) + { + tmp = Shape (index) + tmp_0 = Size (tmp) + int64_0 = Constant () + int64_0_cast = CastLike (int64_0, tmp_0) + cond = Equal (tmp_0, int64_0_cast) + result_16 = If (cond) ( result) { + result = Identity (self) + }, else_branch: graph = elseGraph_10 () => ( result_15) { + tmp_1 = Shape (self) + tmp_2 = Size (tmp_1) + int64_0_3 = Constant () + int64_0_3_cast = CastLike (int64_0_3, tmp_2) + cond_4 = Equal (tmp_2, int64_0_3_cast) + self_8 = If (cond_4) ( self_6) { + tmp_5 = Constant () + self_6 = Reshape (self, tmp_5) + }, else_branch: graph = elseGraph_13 () => ( self_7) { + self_7 = Identity (self) + }> + tmp_9 = Size (index) + int64_0_10 = Constant () + int64_0_10_cast = CastLike (int64_0_10, tmp_9) + cond_11 = Equal (tmp_9, int64_0_10_cast) + result_15 = If (cond_11) ( result_12) { + result_12 = CastLike (index, self_8) + }, else_branch: graph = elseGraph_15 () => ( result_14) { + index_13 = Cast (index) + result_14 = GatherElements (self_8, index_13) + }> + }> + } +)"; + + ONNX_NAMESPACE::OnnxParser parser(code); + ONNX_NAMESPACE::ModelProto model_proto; + auto parse_status = parser.Parse(model_proto); + ASSERT_TRUE(parse_status.IsOK()) << parse_status.ErrorMessage(); + ASSERT_TRUE(parser.EndOfInput()) << "Extra unparsed input unexpected."; + + { + // Test that the model is loadable and check the function call node. 
+ std::shared_ptr p_model; + ASSERT_STATUS_OK(Model::Load(std::move(model_proto), p_model, nullptr, *logger_)); + Graph& graph = p_model->MainGraph(); + std::map op_to_count = CountOpsInGraph(graph); + ASSERT_EQ(op_to_count["local.aten_gather"], 1); + model_proto = p_model->ToProto(); + } + + std::string serialized_model; + const bool serialization_status = model_proto.SerializeToString(&serialized_model); + ASSERT_TRUE(serialization_status) << "Failed to serialize proto to string"; + + // AOT inlining is necessary in this case, so the If nodes within the function + // are brought out to the outer scope. So we load this into a session object. + + SessionOptions session_options; + InferenceSessionWrapper session_object{session_options, GetEnvironment()}; + + std::stringstream sstr(serialized_model); + ASSERT_STATUS_OK(session_object.Load(sstr)); + ASSERT_STATUS_OK(session_object.Initialize()); + + // const auto resulting_model_proto = session_object.GetModel().ToProto(); + // std::string printed_model = ONNX_NAMESPACE::ProtoToString(resulting_model_proto); + // ASSERT_FALSE(printed_model.empty()); + // std::cout << printed_model << std::endl; + + // This is the resulting model proto. + // The remaining If node is not constant foldable because Size() does not constant fold + // although the shape is fixed. 
+ /* + < + ir_version: 8, + opset_import: ["" : 16, "local" : 1, + "com.microsoft.nchwc" : 1, + "ai.onnx.ml" : 4, + "com.ms.internal.nhwc" : 20, + "ai.onnx.training" : 1, + "ai.onnx.preview.training" : 1, + "com.microsoft" : 1, + "com.microsoft.experimental" : 1, + "org.pytorch.aten" : 1] + > + agraph (float[128] x, float[128] x1) => (float[128] y) { + _if_elseGraph_10__inlfunc_aten_gather_tmp_9 = Size (x1) + _if_elseGraph_10__inlfunc_aten_gather_cond_11 = + Equal (_if_elseGraph_10__inlfunc_aten_gather_tmp_9, ortshared_7_0_1_0_token_10) + y = If (_if_elseGraph_10__inlfunc_aten_gather_cond_11) (float[128] _inlfunc_aten_gather_result_12) { + _inlfunc_aten_gather_result_12 = Cast (x1) + }, else_branch: graph = elseGraph_15 () => (float[128] _inlfunc_aten_gather_result_14) { + _inlfunc_aten_gather_index_13 = Cast (x1) + _inlfunc_aten_gather_result_14 = GatherElements (x, _inlfunc_aten_gather_index_13) + }> + } + */ + + auto& graph = session_object.GetModel().MainGraph(); + auto op_to_count = CountOpsInGraph(graph); + ASSERT_EQ(op_to_count["local.aten_gather"], 0); + ASSERT_EQ(op_to_count["If"], 1); +} + +TEST_F(GraphTransformationTests, ConstantFoldingIfConstantInliningRebuildEdges) { + constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "transform_nested_ifs_toplogical_sorted_nodes.onnx"; + + // The session is built from the same options object that carries the test-specific log id. + SessionOptions session_options; + session_options.session_logid = "GraphTransformationTests.ConstantFoldingIfConstantInliningRebuildEdges"; + + InferenceSessionWrapper session_object{session_options, GetEnvironment()}; + ASSERT_STATUS_OK(session_object.Load(model_uri)); + ASSERT_STATUS_OK(session_object.Initialize()); + + auto& graph = session_object.GetModel().MainGraph(); + auto op_to_count = CountOpsInGraph(graph); + ASSERT_EQ(op_to_count["pkg.onnxscript.torch_lib._aten_linalg_vector_norm_no_dim_onnx"], 0); + ASSERT_EQ(op_to_count["If"], 0); + ASSERT_EQ(op_to_count["Reshape"], 1); + ASSERT_EQ(op_to_count["Abs"], 1); + ASSERT_EQ(op_to_count["Mul"], 1); + 
ASSERT_EQ(op_to_count["ReduceSum"], 1); + ASSERT_EQ(op_to_count["Sqrt"], 1); + ASSERT_EQ(op_to_count["Cast"], 2); +} + // Check transformations in the case of a subgraph with constant inputs. TEST_F(GraphTransformationTests, SubgraphWithConstantInputs) { constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "constant-subgraph.onnx"; diff --git a/onnxruntime/test/testdata/transform/transform_nested_ifs_toplogical_sorted_nodes.onnx b/onnxruntime/test/testdata/transform/transform_nested_ifs_toplogical_sorted_nodes.onnx new file mode 100644 index 0000000000000..afb499a347ec7 Binary files /dev/null and b/onnxruntime/test/testdata/transform/transform_nested_ifs_toplogical_sorted_nodes.onnx differ diff --git a/onnxruntime/test/testdata/transform/transform_nested_ifs_toplogical_sorted_nodes.py b/onnxruntime/test/testdata/transform/transform_nested_ifs_toplogical_sorted_nodes.py new file mode 100644 index 0000000000000..ebda865895d02 --- /dev/null +++ b/onnxruntime/test/testdata/transform/transform_nested_ifs_toplogical_sorted_nodes.py @@ -0,0 +1,859 @@ +import google.protobuf.text_format +import onnx +from numpy import array, float16 + +import onnxruntime as ort + +# Run n times +N = 1 + +onnx_model_text = """ +ir_version: 8 +producer_name: "pytorch" +producer_version: "2.2.0" +graph { + node { + output: "_val_1" + name: "Constant_0" + op_type: "Constant" + attribute { + name: "value_ints" + ints: -1 + type: INTS + } + doc_string: "" + } + node { + input: "input_0" + input: "_val_1" + output: "_val_2" + name: "Reshape_1" + op_type: "Reshape" + attribute { + name: "allowzero" + i: 0 + type: INT + } + doc_string: "" + } + node { + input: "_val_2" + output: "_val_3" + name: "_aten_linalg_vector_norm_no_dim_onnx_2" + op_type: "_aten_linalg_vector_norm_no_dim_onnx" + attribute { + name: "keepdim" + i: 0 + type: INT + } + attribute { + name: "ord" + f: 2.0 + type: FLOAT + } + doc_string: "" + domain: "pkg.onnxscript.torch_lib" + } + name: "main_graph" + input { + name: "input_0" + 
type { + tensor_type { + elem_type: 10 + shape { + } + } + } + } + output { + name: "_val_3" + type { + tensor_type { + elem_type: 10 + shape { + } + } + } + } + value_info { + name: "_val_1" + type { + tensor_type { + elem_type: 7 + shape { + dim { + dim_value: 1 + } + } + } + } + } + value_info { + name: "_val_2" + type { + tensor_type { + elem_type: 10 + shape { + dim { + dim_value: 1 + } + } + } + } + } +} +opset_import { + domain: "pkg.onnxscript.torch_lib" + version: 1 +} +opset_import { + domain: "" + version: 18 +} +opset_import { + domain: "pkg.onnxscript.torch_lib.common" + version: 1 +} +functions { + name: "_aten_linalg_vector_norm_no_dim_onnx" + input: "self" + output: "result_29" + attribute: "ord" + attribute: "keepdim" + node { + input: "self" + output: "tmp" + name: "n0" + op_type: "Shape" + domain: "" + } + node { + input: "tmp" + output: "self_rank" + name: "n1" + op_type: "Size" + domain: "" + } + node { + output: "int64_0" + name: "n2" + op_type: "Constant" + attribute { + name: "value" + t { + data_type: 7 + int64_data: 0 + name: "int64_0" + } + type: TENSOR + } + domain: "" + } + node { + input: "int64_0" + input: "self_rank" + output: "int64_0_cast" + name: "n3" + op_type: "CastLike" + domain: "" + } + node { + input: "self_rank" + input: "int64_0_cast" + output: "cond" + name: "n4" + op_type: "Equal" + domain: "" + } + node { + input: "cond" + output: "self_2" + name: "n5" + op_type: "If" + attribute { + name: "then_branch" + g { + node { + output: "int64_0_1d" + name: "n0" + op_type: "Constant" + attribute { + name: "value" + t { + dims: 1 + data_type: 7 + int64_data: 0 + name: "int64_0_1d" + } + type: TENSOR + } + domain: "" + } + node { + input: "self" + input: "int64_0_1d" + output: "self_0" + name: "n1" + op_type: "Unsqueeze" + domain: "" + } + name: "thenGraph_4" + output { + name: "self_0" + type { + } + } + } + type: GRAPH + } + attribute { + name: "else_branch" + g { + node { + input: "self" + output: "self_1" + name: "n0" + 
op_type: "Identity" + domain: "" + } + name: "elseGraph_4" + output { + name: "self_1" + type { + } + } + } + type: GRAPH + } + domain: "" + } + node { + input: "self_2" + output: "self_3" + name: "n6" + op_type: "Abs" + domain: "" + } + node { + output: "ord" + name: "n7" + op_type: "Constant" + attribute { + name: "value_float" + type: FLOAT + ref_attr_name: "ord" + } + domain: "" + } + node { + input: "ord" + output: "ord_4" + name: "n8" + op_type: "Cast" + attribute { + name: "to" + i: 1 + type: INT + } + domain: "" + } + node { + input: "ord_4" + output: "cond_5" + name: "n9" + op_type: "IsInf" + attribute { + name: "detect_negative" + i: 0 + type: INT + } + attribute { + name: "detect_positive" + i: 1 + type: INT + } + domain: "" + } + node { + input: "cond_5" + output: "result_24" + name: "n10" + op_type: "If" + attribute { + name: "then_branch" + g { + node { + input: "self_3" + output: "result" + name: "n0" + op_type: "ReduceMax" + attribute { + name: "keepdims" + type: INT + ref_attr_name: "keepdim" + } + domain: "" + } + name: "thenGraph_9" + output { + name: "result" + type { + } + } + } + type: GRAPH + } + attribute { + name: "else_branch" + g { + node { + input: "ord_4" + output: "cond_6" + name: "n0" + op_type: "IsInf" + attribute { + name: "detect_negative" + i: 1 + type: INT + } + attribute { + name: "detect_positive" + i: 0 + type: INT + } + domain: "" + } + node { + input: "cond_6" + output: "result_23" + name: "n1" + op_type: "If" + attribute { + name: "then_branch" + g { + node { + input: "self_3" + output: "result_7" + name: "n0" + op_type: "ReduceMin" + attribute { + name: "keepdims" + type: INT + ref_attr_name: "keepdim" + } + domain: "" + } + name: "thenGraph_11" + output { + name: "result_7" + type { + } + } + } + type: GRAPH + } + attribute { + name: "else_branch" + g { + node { + output: "const" + name: "n0" + op_type: "Constant" + attribute { + name: "value" + t { + data_type: 1 + float_data: 0.0 + name: "const" + } + type: TENSOR + } + 
domain: "" + } + node { + input: "const" + input: "ord_4" + output: "const_cast" + name: "n1" + op_type: "CastLike" + domain: "" + } + node { + input: "ord_4" + input: "const_cast" + output: "cond_8" + name: "n2" + op_type: "Equal" + domain: "" + } + node { + input: "cond_8" + output: "result_22" + name: "n3" + op_type: "If" + attribute { + name: "then_branch" + g { + node { + input: "self_3" + output: "self_bool" + name: "n0" + op_type: "Cast" + attribute { + name: "to" + i: 9 + type: INT + } + domain: "" + } + node { + input: "self_bool" + input: "self_3" + output: "self_0_1" + name: "n1" + op_type: "CastLike" + domain: "" + } + node { + input: "self_0_1" + output: "result_9" + name: "n2" + op_type: "ReduceSum" + attribute { + name: "keepdims" + i: 0 + type: INT + } + domain: "" + } + name: "thenGraph_13" + output { + name: "result_9" + type { + } + } + } + type: GRAPH + } + attribute { + name: "else_branch" + g { + node { + output: "const_10" + name: "n0" + op_type: "Constant" + attribute { + name: "value" + t { + data_type: 1 + float_data: 1.0 + name: "const_10" + } + type: TENSOR + } + domain: "" + } + node { + input: "const_10" + input: "ord_4" + output: "const_10_cast" + name: "n1" + op_type: "CastLike" + domain: "" + } + node { + input: "ord_4" + input: "const_10_cast" + output: "cond_11" + name: "n2" + op_type: "Equal" + domain: "" + } + node { + input: "cond_11" + output: "result_21" + name: "n3" + op_type: "If" + attribute { + name: "then_branch" + g { + node { + input: "self_3" + output: "result_12" + name: "n0" + op_type: "ReduceL1" + attribute { + name: "keepdims" + type: INT + ref_attr_name: "keepdim" + } + domain: "" + } + name: "thenGraph_18" + output { + name: "result_12" + type { + } + } + } + type: GRAPH + } + attribute { + name: "else_branch" + g { + node { + output: "const_13" + name: "n0" + op_type: "Constant" + attribute { + name: "value" + t { + data_type: 1 + float_data: 2.0 + name: "const_13" + } + type: TENSOR + } + domain: "" + } + node 
{ + input: "const_13" + input: "ord_4" + output: "const_13_cast" + name: "n1" + op_type: "CastLike" + domain: "" + } + node { + input: "ord_4" + input: "const_13_cast" + output: "cond_14" + name: "n2" + op_type: "Equal" + domain: "" + } + node { + input: "cond_14" + output: "result_20" + name: "n3" + op_type: "If" + attribute { + name: "then_branch" + g { + node { + input: "self_3" + output: "result_15" + name: "n0" + op_type: "ReduceL2" + attribute { + name: "keepdims" + type: INT + ref_attr_name: "keepdim" + } + domain: "" + } + name: "thenGraph_20" + output { + name: "result_15" + type { + } + } + } + type: GRAPH + } + attribute { + name: "else_branch" + g { + node { + input: "ord_4" + input: "self_3" + output: "ord_float" + name: "n0" + op_type: "CastLike" + domain: "" + } + node { + input: "self_3" + input: "ord_float" + output: "self_pow" + name: "n1" + op_type: "Pow" + domain: "" + } + node { + input: "self_pow" + output: "tmp_16" + name: "n2" + op_type: "ReduceSum" + attribute { + name: "keepdims" + type: INT + ref_attr_name: "keepdim" + } + domain: "" + } + node { + output: "const_17" + name: "n3" + op_type: "Constant" + attribute { + name: "value" + t { + data_type: 1 + float_data: 1.0 + name: "const_17" + } + type: TENSOR + } + domain: "" + } + node { + input: "const_17" + input: "ord_float" + output: "const_17_cast" + name: "n4" + op_type: "CastLike" + domain: "" + } + node { + input: "const_17_cast" + input: "ord_float" + output: "tmp_18" + name: "n5" + op_type: "Div" + domain: "" + } + node { + input: "tmp_16" + input: "tmp_18" + output: "result_19" + name: "n6" + op_type: "Pow" + domain: "" + } + name: "elseGraph_20" + output { + name: "result_19" + type { + } + } + } + type: GRAPH + } + domain: "" + } + name: "elseGraph_18" + output { + name: "result_20" + type { + } + } + } + type: GRAPH + } + domain: "" + } + name: "elseGraph_13" + output { + name: "result_21" + type { + } + } + } + type: GRAPH + } + domain: "" + } + name: "elseGraph_11" + output 
{ + name: "result_22" + type { + } + } + } + type: GRAPH + } + domain: "" + } + name: "elseGraph_9" + output { + name: "result_23" + type { + } + } + } + type: GRAPH + } + domain: "" + } + node { + output: "int64_0_25" + name: "n11" + op_type: "Constant" + attribute { + name: "value" + t { + data_type: 7 + int64_data: 0 + name: "int64_0_25" + } + type: TENSOR + } + domain: "" + } + node { + input: "int64_0_25" + input: "self_rank" + output: "int64_0_25_cast" + name: "n12" + op_type: "CastLike" + domain: "" + } + node { + input: "self_rank" + input: "int64_0_25_cast" + output: "cond_26" + name: "n13" + op_type: "Equal" + domain: "" + } + node { + input: "cond_26" + output: "result_29" + name: "n14" + op_type: "If" + attribute { + name: "then_branch" + g { + node { + input: "result_24" + output: "result_27" + name: "n0" + op_type: "Squeeze" + domain: "" + } + name: "thenGraph_27" + output { + name: "result_27" + type { + } + } + } + type: GRAPH + } + attribute { + name: "else_branch" + g { + node { + input: "result_24" + output: "result_28" + name: "n0" + op_type: "Identity" + domain: "" + } + name: "elseGraph_27" + output { + name: "result_28" + type { + } + } + } + type: GRAPH + } + domain: "" + } + opset_import { + domain: "" + version: 18 + } + domain: "pkg.onnxscript.torch_lib" +} +functions { + name: "Rank" + input: "input" + output: "return_val" + node { + input: "input" + output: "tmp" + name: "n0" + op_type: "Shape" + domain: "" + } + node { + input: "tmp" + output: "return_val" + name: "n1" + op_type: "Size" + domain: "" + } + doc_string: "Take the rank of the input tensor." 
+ opset_import { + domain: "" + version: 18 + } + domain: "pkg.onnxscript.torch_lib.common" +} +functions { + name: "IsScalar" + input: "input" + output: "return_val" + node { + input: "input" + output: "tmp" + name: "n0" + op_type: "Shape" + domain: "" + } + node { + input: "tmp" + output: "tmp_0" + name: "n1" + op_type: "Size" + domain: "" + } + node { + output: "tmp_1" + name: "n2" + op_type: "Constant" + attribute { + name: "value_int" + i: 0 + type: INT + } + domain: "" + } + node { + input: "tmp_0" + input: "tmp_1" + output: "return_val" + name: "n3" + op_type: "Equal" + domain: "" + } + doc_string: "Return whether the input has rank 0, or is a scalar." + opset_import { + domain: "" + version: 18 + } + domain: "pkg.onnxscript.torch_lib.common" +} + +""" + +ort_inputs = {"input_0": array(0.8965, dtype=float16)} + +# Set up the inference session +session_options = ort.SessionOptions() +session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL +onnx_model = onnx.ModelProto() +google.protobuf.text_format.Parse(onnx_model_text, onnx_model) + +# Uncomment this line to save the model to a file for examination +# onnx.save_model(onnx_model, "transform_nested_ifs_toplogical_sorted_nodes.onnx") + +onnx.checker.check_model(onnx_model) +session = ort.InferenceSession(onnx_model.SerializeToString(), session_options, providers=("CPUExecutionProvider",)) + +# Run the model +for _ in range(N): + ort_outputs = session.run(None, ort_inputs)