diff --git a/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc b/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc index 2c11bf144999e..13a2c85073d0d 100644 --- a/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc +++ b/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc @@ -4,6 +4,7 @@ #include "onnx_transpose_optimization.h" #include +#include #include #include #include @@ -93,6 +94,57 @@ static std::unique_ptr MakeSqueezeOrUnsqueeze(int64_t opset, api:: return graph.AddNode(op_type, inputs, /*num_outputs*/ 1); } +/// +/// Return a DequantizeLinear node if it's input is a constant initializer with known consumers. +/// In this case the initializer can be updated in-place by UnsqueezeInput or TransposeInput. +/// +/// Current graph +/// Value to check if produced by a DQ node who's input is a constant initializer +/// NodeRef for DQ node if it meets the requirements. +static std::unique_ptr GetDQWithConstInitializerInput(api::GraphRef& graph, + std::string_view input_name) { + std::unique_ptr dq_node; + auto maybe_dq_node = graph.GetNodeProducingOutput(input_name); + + if (maybe_dq_node && maybe_dq_node->OpType() == "DequantizeLinear") { + do { + auto dq_input = maybe_dq_node->Inputs()[0]; + auto dq_constant = graph.GetConstant(dq_input); + + // input to DQ must be a constant initializer + if (!dq_constant) { + break; + } + + // For now keep it simple and don't support per-axis quantization as that would require updating the + // scale and zero point values in the DQ node to re-order if transposing, or reshape if unsqueezing. + // the rank of the `scale` and `zero point` inputs must match so we only need to check `scale`. + auto dq_scale = graph.GetConstant(maybe_dq_node->Inputs()[1]); + if (!dq_scale || dq_scale->NumElements() != 1) { + break; + } + + // need to know all the initializer consumers as we're potentially going to modify it directly + auto initializer_consumers = graph.GetValueConsumers(dq_input); + if (!initializer_consumers->comprehensive) { + break; + } + + // DQ output is only used by the node we're modifying. consumers should be 0 as we've already removed the node + // being modified as a consumer. See UnsqueezeInput and TransposeInput where we remove the input prior to + // calling this function. + auto dq_consumers = graph.GetValueConsumers(input_name); + if (!dq_consumers->comprehensive || dq_consumers->nodes.size() != 0) { + break; + } + + dq_node = std::move(maybe_dq_node); + } while (false); + } + + return dq_node; +} + // Returns whether perm is a valid permutation (contains each value from 0 to perm.size() - 1 exactly once) static bool IsValidPerm(const std::vector& perm) { size_t rank = perm.size(); @@ -345,6 +397,11 @@ static std::vector SortedAxesForTransposedInput(const std::vectorShape(); + graph.GetValueInfo(dq.Outputs()[0])->SetShape(&new_shape); +} /////// /////// /////// /////// @@ -357,51 +414,115 @@ static std::string_view HelpHandleUnsqueeze(HandlerArgs& args, const std::vector // broadcasting. static void UnsqueezeInput(OptimizerCtx& ctx, api::NodeRef& node, size_t i, const std::vector& axes) { std::string_view input = node.Inputs()[i]; - // Remove this node as a consumer + + // Clear the input, which also removes this node as a consumer of the input node.SetInput(i, ""); std::unique_ptr constant = ctx.graph.GetLocalConstant(input); - auto consumers = ctx.graph.GetValueConsumers(input); + + // allow a constant initializer coming via a DQ node with a single consumer + std::unique_ptr dq_node; + std::string_view constant_dq_input; + + if (!constant) { + // look past a DQ node for a constant initializer. essentially we pretend the DQ node doesn't exist + // to enable directly making changes to the initializer. any nodes added for other consumers of the initializer + // in 'Case 1' are prior to the DQ so we don't break up any QDQ node units. + dq_node = GetDQWithConstInitializerInput(ctx.graph, input); + if (dq_node) { + // underlying string for the input name is in the Node so it's safe to store in string_view constant_dq_input + constant_dq_input = dq_node->Inputs()[0]; + constant = ctx.graph.GetLocalConstant(constant_dq_input); + // remove the DQ node as a consumer of the initializer while we modify things + dq_node->SetInput(0, ""); + } + } + + auto value_to_modify = dq_node ? constant_dq_input : input; + auto consumers = ctx.graph.GetValueConsumers(value_to_modify); // Case 1: input is a constant with a known list of consumer nodes if (constant != nullptr && consumers->comprehensive) { - // We will reshape the initializer. If there are existing consumers, still reshape it but add Squeeze nodes + // We will reshape the initializer. If there are existing consumers, reshape it and add Squeeze nodes // to counteract its effect. If they later Unsqueeze the same input, the Squeeze nodes will simply be deleted // (see Case 2). if (consumers->nodes.size() > 0) { - auto squeeze_ptr = MakeSqueezeOrUnsqueeze(ctx.opset, ctx.graph, "Squeeze", input, axes); + auto squeeze_ptr = MakeSqueezeOrUnsqueeze(ctx.opset, ctx.graph, "Squeeze", value_to_modify, axes); api::NodeRef& squeeze = *squeeze_ptr; std::string_view sq_out = squeeze.Outputs()[0]; - ctx.graph.CopyValueInfo(input, sq_out); - ReplaceValueReferences(consumers->nodes, input, sq_out); + ctx.graph.CopyValueInfo(value_to_modify, sq_out); + ReplaceValueReferences(consumers->nodes, value_to_modify, sq_out); } + auto new_shape = UnsqueezeShape(constant->Shape(), axes); - ctx.graph.ReshapeInitializer(input, new_shape); - node.SetInput(i, input); + ctx.graph.ReshapeInitializer(value_to_modify, new_shape); + + if (dq_node) { + // store the ids of any DQ nodes where we inserted a Squeeze for matching in Case 2 later + for (auto& consumer : consumers->nodes) { + if (consumer->OpType() == "DequantizeLinear") { + ctx.special_cased_dq_nodes.insert(consumer->Id()); + } + } + + UpdateDQNodeInputAndShape(ctx.graph, *dq_node, constant_dq_input); + } + + node.SetInput(i, input); // restore the original connection return; } // Case 2: input is a Squeeze node with matching axes std::unique_ptr inp_node = ctx.graph.GetNodeProducingOutput(input); + + // check if this is a special-cased DQ node where we put the Squeeze prior to it in 'Case 1' above + if (inp_node && inp_node->OpType() == "DequantizeLinear" && + std::find_if(ctx.special_cased_dq_nodes.begin(), ctx.special_cased_dq_nodes.end(), + [&inp_node](int64_t id) { + return id == inp_node->Id(); + }) != ctx.special_cased_dq_nodes.end()) { + // set things up so we can look past the DQ node to the Squeeze that was inserted in front of the reshaped + // constant initializer that was shared with this node. + dq_node = std::move(inp_node); + auto dq_input = dq_node->Inputs()[0]; + inp_node = ctx.graph.GetNodeProducingOutput(dq_input); + consumers = ctx.graph.GetValueConsumers(dq_input); + } + if (inp_node != nullptr && inp_node->IsOp("Squeeze")) { const std::vector& inp_node_inputs = inp_node->Inputs(); std::optional> squeeze_axes = std::nullopt; squeeze_axes = ReadFromAttrOrInput(ctx, *inp_node, "axes", /*inp_index*/ 1, /*opset*/ 13); if (squeeze_axes != std::nullopt && *squeeze_axes == axes) { + if (dq_node) { + UpdateDQNodeInputAndShape(ctx.graph, *dq_node, inp_node_inputs[0]); + node.SetInput(i, dq_node->Outputs()[0]); + } else { + node.SetInput(i, inp_node_inputs[0]); + } + // Remove the Squeeze node if possible - if (consumers->comprehensive && consumers->nodes.size() == 0) { + // if there's a DQ node the `consumers` list still includes it so allow for that. + // in that case UpdateDQNodeInputAndShape already updated the input of the DQ node so it's safe to remove it. + if (consumers->comprehensive && consumers->nodes.size() == (dq_node ? 1 : 0)) { ctx.graph.RemoveNode(*inp_node); + if (ctx.opset >= 13 && !ctx.graph.HasValueConsumers(inp_node_inputs[1])) { ctx.graph.RemoveInitializer(inp_node_inputs[1]); } } - node.SetInput(i, inp_node_inputs[0]); + return; } // Axes don't match. Fall through to Case 3. } + // any DQ node special casing doesn't apply anymore, so go back to the original inp_node + if (dq_node) { + inp_node = std::move(dq_node); + } + // Case 3: Add an Unsqueeze node. auto unsqueeze_ptr = MakeSqueezeOrUnsqueeze(ctx.opset, ctx.graph, "Unsqueeze", input, axes); api::NodeRef& unsqueeze = *unsqueeze_ptr; @@ -453,55 +574,124 @@ static void Permute1DConstant(api::GraphRef& graph, api::NodeRef& node, api::Ten // Replaces ith input to node with transposed value. Might create a new Transpose node, find an existing one, // or transpose an initializer. -void TransposeInput(api::GraphRef& graph, api::NodeRef& node, size_t i, - const std::vector& perm, const std::vector& perm_inv) { +static void TransposeInputImpl(api::GraphRef& graph, + std::unordered_set* special_cased_dq_nodes, + api::NodeRef& node, size_t i, const std::vector& perm, + const std::vector& perm_inv) { std::string_view input = node.Inputs()[i]; - // Remove this node as a consumer + + // Clear the input which removes this node as a consumer node.SetInput(i, ""); + // Only local constants are editable std::unique_ptr constant = graph.GetLocalConstant(input); - auto consumers = graph.GetValueConsumers(input); + + // allow a constant initializer coming via a DQ node with a single consumer + std::unique_ptr dq_node; + std::string_view constant_dq_input; + + if (!constant) { + // look past a DQ node for a constant initializer. essentially we pretend the DQ node doesn't exist + // to enable directly making changes to the initializer. any nodes added for other consumers of the initializer + // in 'Case 1' are prior to the DQ so we don't break up any QDQ node units. + dq_node = GetDQWithConstInitializerInput(graph, input); + if (dq_node) { + // underlying string for the input name is in the Node so it's safe to store in string_view constant_dq_input + constant_dq_input = dq_node->Inputs()[0]; + constant = graph.GetLocalConstant(constant_dq_input); + // remove the DQ node as a consumer of the initializer while we modify things + dq_node->SetInput(0, ""); + } + } + + auto value_to_modify = dq_node ? constant_dq_input : input; + auto consumers = graph.GetValueConsumers(value_to_modify); // Case 1: input is a constant with a known list of consumer nodes if (constant != nullptr && consumers->comprehensive) { - // Input is scalar, return early. - if (constant->Shape().size() == 1 && constant->Shape()[0] == 0) { + // If there is only one element return early as the transpose won't change the data + if (constant->NumElements() == 1) { return; } + // This is a special case where the constant is 1D with length == perm. - // TODO: TransposeInitializer should be updated to handle this case. + // e.g. it provides a set of values that are relative to the input axes like the `sizes` input for Resize + // TODO: TransposeInitializer could be updated to handle this case. // Permute1DConstant permutes the constant and adds a new initializer. The old initializer is removed only if // there are no other consumers. + // NOTE: As this returns after calling Permute1DConstant we're implicitly assuming there are no + // other consumers that need updating. This would typically be the case for this sort of input though. if (constant->Shape().size() == 1 && constant->Shape()[0] == gsl::narrow_cast(perm.size())) { - Permute1DConstant(graph, node, *constant, i, input, perm); + assert(consumers->nodes.size() == 0); + Permute1DConstant(graph, node, *constant, i, value_to_modify, perm); return; } + if (consumers->nodes.size() > 0) { // Transpose the initializer. If there are existing consumers, add Transpose nodes to them using perm_inv // to counteract the effect. These Transposes will hopefully be optimized out later. - auto transpose_inv_ptr = MakeTranspose(graph, input, perm_inv); + auto transpose_inv_ptr = MakeTranspose(graph, value_to_modify, perm_inv); api::NodeRef& transpose_inv = *transpose_inv_ptr; std::string_view transpose_out = transpose_inv.Outputs()[0]; - graph.CopyValueInfo(input, transpose_out); - ReplaceValueReferences(consumers->nodes, input, transpose_out); + graph.CopyValueInfo(value_to_modify, transpose_out); + ReplaceValueReferences(consumers->nodes, value_to_modify, transpose_out); } - graph.TransposeInitializer(input, perm); - node.SetInput(i, input); + + graph.TransposeInitializer(value_to_modify, perm); + + if (dq_node) { + // store the ids of any DQ nodes where we inserted a Transpose for matching in Case 2 later + for (auto& consumer : consumers->nodes) { + if (consumer->OpType() == "DequantizeLinear") { + special_cased_dq_nodes->insert(consumer->Id()); + } + } + + UpdateDQNodeInputAndShape(graph, *dq_node, value_to_modify); + } + + node.SetInput(i, input); // restore the original connection return; } // Case 2: input is a Transpose node std::unique_ptr inp_node = graph.GetNodeProducingOutput(input); + + // check if this is a special-cased DQ node where we put the Transpose prior to it in 'Case 1' above + if (inp_node && inp_node->OpType() == "DequantizeLinear" && special_cased_dq_nodes && + std::find_if(special_cased_dq_nodes->begin(), special_cased_dq_nodes->end(), + [&inp_node](int64_t id) { + return id == inp_node->Id(); + }) != special_cased_dq_nodes->end()) { + // set things up so we can look past the DQ node to the Squeeze that was inserted in front of the reshaped + // constant initializer that was shared with this node. + dq_node = std::move(inp_node); + auto dq_input = dq_node->Inputs()[0]; + inp_node = graph.GetNodeProducingOutput(dq_input); + consumers = graph.GetValueConsumers(dq_input); + } + if (inp_node != nullptr && inp_node->IsOp("Transpose")) { std::optional> perm2 = GetPermAttrIfValid(*inp_node); if (perm2 != std::nullopt && perm2->size() == perm.size()) { // If they cancel, use pre_transpose_value and remove Transpose if possible. if (*perm2 == perm_inv) { std::string_view pre_transpose_value = inp_node->Inputs()[0]; - if (consumers->comprehensive && consumers->nodes.size() == 0) { + + if (dq_node) { + UpdateDQNodeInputAndShape(graph, *dq_node, pre_transpose_value); + node.SetInput(i, dq_node->Outputs()[0]); + } else { + node.SetInput(i, pre_transpose_value); + } + + // Remove the Transpose node if possible + // if there's a DQ node the `consumers` list still includes it so allow for that. + // in that case UpdateDQNodeInputAndShape already updated the input of the DQ node so it's safe to remove it. + if (consumers->comprehensive && consumers->nodes.size() == (dq_node ? 1 : 0)) { graph.RemoveNode(*inp_node); } - node.SetInput(i, pre_transpose_value); + return; } else if (*perm2 == perm) { // we are trying to add a duplicate transpose. @@ -509,6 +699,14 @@ void TransposeInput(api::GraphRef& graph, api::NodeRef& node, size_t i, return; } + // NOTE: We expect the Transpose to cancel out when handling a special-cased DQ node that was originally + // connected to a shared constant initializer, so we don't expect to get here if dq_node is not nullptr. + // Assert in a debug build so we can investigate if it ever happens, and bail in a release build. + assert(!dq_node); + if (dq_node) { + return; + } + // Otherwise, compose the perm and Transpose pre_transpose_value. Cost is the same and we may be able to remove // the other Transpose. const std::vector& perm_combined = ComposePerm(*perm2, perm); @@ -517,14 +715,22 @@ void TransposeInput(api::GraphRef& graph, api::NodeRef& node, size_t i, std::string_view transpose_out = transpose.Outputs()[0]; graph.CopyValueInfo(input, transpose_out); graph.GetValueInfo(transpose_out)->PermuteDims(perm); + if (consumers->comprehensive && consumers->nodes.size() == 0) { graph.RemoveNode(*inp_node); } + node.SetInput(i, transpose_out); + return; } } + // any DQ node special casing doesn't apply anymore, so go back to the original inp_node + if (dq_node) { + inp_node = std::move(dq_node); + } + // Case 3: A Transpose op might already exist for (size_t j = 0; j < consumers->nodes.size(); ++j) { api::NodeRef& consumer = *consumers->nodes[j]; @@ -543,8 +749,21 @@ void TransposeInput(api::GraphRef& graph, api::NodeRef& node, size_t i, node.SetInput(i, transpose_out); } +void TransposeInput(api::GraphRef& graph, api::NodeRef& node, size_t i, + const std::vector& perm, + const std::vector& perm_inv) { + // this TransposeInput is used by the layout transformer to wrap a node in Transpose ops. there's no OptimizerCtx + // in that scenario and we're not tracking special-cased DQ nodes as we only do that when pushing Transpose nodes. + TransposeInputImpl(graph, /* special_cased_dq_nodes */ nullptr, node, i, perm, perm_inv); +} + +void TransposeInput(OptimizerCtx& ctx, api::NodeRef& node, size_t i, const std::vector& perm, + const std::vector& perm_inv) { + TransposeInputImpl(ctx.graph, &ctx.special_cased_dq_nodes, node, i, perm, perm_inv); +} + // Unsqueezes inputs of node to have uniform rank. Returns false if input ranks are unknown or exceed the target rank. -static bool NormalizeInputRanks(OptimizerCtx ctx, api::NodeRef& node, size_t target_rank, +static bool NormalizeInputRanks(OptimizerCtx& ctx, api::NodeRef& node, size_t target_rank, const std::vector& input_indices) { auto inputs = node.Inputs(); @@ -579,7 +798,7 @@ void TransposeInputs(OptimizerCtx& ctx, api::NodeRef& node, const std::vector& input_indices) { auto perm_inv = InvertPerm(perm); for (size_t j : input_indices) { - TransposeInput(ctx.graph, node, j, perm, perm_inv); + TransposeInput(ctx, node, j, perm, perm_inv); } } @@ -734,8 +953,10 @@ static bool HandleSimpleNodeBase(HandlerArgs& args, bool broadcast_inputs) { if (broadcast_inputs && !NormalizeInputRanks(args.ctx, args.node, rank, args.transposible_inputs)) { return false; } + TransposeInputs(args.ctx, args.node, args.perm_inv, args.transposible_inputs); TransposeOutputs(args.ctx, args.node, args.perm); + return true; } @@ -1889,7 +2110,7 @@ std::optional MakeOptimizerContext(api::GraphRef& graph, return std::nullopt; } - OptimizerCtx ctx{*opset, graph, provider_type, cost_check_fn, extended_handlers}; + OptimizerCtx ctx{*opset, graph, provider_type, cost_check_fn, extended_handlers, {}}; return ctx; } diff --git a/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.h b/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.h index 1a54e7834a4ae..a1c6ae28d49c9 100644 --- a/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.h +++ b/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.h @@ -48,6 +48,12 @@ struct OptimizerCtx { // Handlers for ops that are not in the ONNX opset, or for ONNX ops where special handling is required. // If a handler is not found in this map, the default handlers will be used. const HandlerMap& extended_handlers; + + // DQs nodes which had a shared constant initializer as input where we updated the initializer in-place and + // inserted a Squeeze and/or Transpose on the other usages. Nodes in this set had the Squeeze/Transpose inserted. + // If we attempt to push a Transpose through them we need to look past the DQ node to try and cancel + // out the Squeeze/Transpose. + std::unordered_set special_cased_dq_nodes; }; /// diff --git a/onnxruntime/core/optimizer/transpose_optimization/optimizer_api.h b/onnxruntime/core/optimizer/transpose_optimization/optimizer_api.h index 40a03f24f7648..76596c8b65846 100644 --- a/onnxruntime/core/optimizer/transpose_optimization/optimizer_api.h +++ b/onnxruntime/core/optimizer/transpose_optimization/optimizer_api.h @@ -242,6 +242,12 @@ class NodeRef { /// since version or default value -1 virtual int SinceVersion() const = 0; + /// + /// Get the unique id of the node. + /// + /// Id + virtual int64_t Id() const = 0; + virtual ~NodeRef(){}; }; diff --git a/onnxruntime/core/optimizer/transpose_optimization/ort_optimizer_api_impl.cc b/onnxruntime/core/optimizer/transpose_optimization/ort_optimizer_api_impl.cc index b30c94d7b3e40..2fcb88cb0b9ba 100644 --- a/onnxruntime/core/optimizer/transpose_optimization/ort_optimizer_api_impl.cc +++ b/onnxruntime/core/optimizer/transpose_optimization/ort_optimizer_api_impl.cc @@ -95,7 +95,8 @@ class ApiNode final : public api::NodeRef { void ClearAttribute(std::string_view name) override; void SetInput(size_t i, std::string_view name) override; std::string_view GetExecutionProviderType() const override; - virtual int SinceVersion() const override; + int SinceVersion() const override; + int64_t Id() const override; private: ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(ApiNode); @@ -417,6 +418,10 @@ int ApiNode::SinceVersion() const { return node_.SinceVersion(); } +int64_t ApiNode::Id() const { + return node_.Index(); +} + // std::optional ApiGraph::Opset(std::string_view domain) const { diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 5a2a6efb6df4b..21c8fbe0cd2c9 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -1005,18 +1005,22 @@ common::Status InferenceSession::TransformGraph(onnxruntime::Graph& graph, bool layout_transformation::TransformLayoutForEP(graph_to_transform, modified, execution_provider, std::move(cpu_allocator), debug_graph_fn)); - if (modified) { - ORT_RETURN_IF_ERROR_SESSIONID_( - graph_transformer_mgr_.ApplyTransformers(graph_to_transform, TransformerLevel::Level1, *session_logger_)); - - // debug the graph after the L1 transformers have run against any layout transformation changes. - // this is prior to GraphPartitioner::GetCapabilityForEP calling IExecutionProvider::GetCapability the second - // time to validate the EP that requested the layout transformation can take all nodes using the new layout. - // if that fails, this allows debugging the graph used in that GetCapability call. - if (debug_graph_fn) { - debug_graph_fn(graph_to_transform); - } - } + // Previously we ran the L1 transformers to handle constant folding of any initializers that were transposed in + // a QDQ format model. The transpose optimizer can now look past DQ nodes to directly update initializers which + // takes care of most models without needing this. + // + // if (modified) { + // ORT_RETURN_IF_ERROR_SESSIONID_( + // graph_transformer_mgr_.ApplyTransformers(graph_to_transform, TransformerLevel::Level1, *session_logger_)); + // + // debug the graph after the L1 transformers have run against any layout transformation changes. + // this is prior to GraphPartitioner::GetCapabilityForEP calling IExecutionProvider::GetCapability the second + // time to validate the EP that requested the layout transformation can take all nodes using the new layout. + // if that fails, this allows debugging the graph used in that GetCapability call. + // if (debug_graph_fn) { + // debug_graph_fn(graph_to_transform); + //} + //} return Status::OK(); }; diff --git a/onnxruntime/test/optimizer/transpose_optimizer_test.cc b/onnxruntime/test/optimizer/transpose_optimizer_test.cc index 1f4c499985ad0..fe3b783176eaf 100644 --- a/onnxruntime/test/optimizer/transpose_optimizer_test.cc +++ b/onnxruntime/test/optimizer/transpose_optimizer_test.cc @@ -12,6 +12,9 @@ #include "core/graph/node_attr_utils.h" #include "core/framework/op_node_proto_helper.h" #include "core/framework/utils.h" +#include "core/optimizer/transpose_optimization/onnx_transpose_optimization.h" +#include "core/optimizer/transpose_optimization/optimizer_api.h" +#include "core/optimizer/transpose_optimization/ort_optimizer_utils.h" #include "core/session/onnxruntime_session_options_config_keys.h" #include "test/test_environment.h" @@ -19,6 +22,7 @@ #include "test/providers/internal_testing/internal_testing_execution_provider.h" #include "test/util/include/asserts.h" #include "test/util/include/inference_session_wrapper.h" +#include "test/util/include/test_utils.h" namespace onnxruntime { namespace test { @@ -4395,9 +4399,9 @@ TEST(TransposeOptimizerTests, RegressionTest_GitHubIssue9671) { SessionOptions so; so.session_logid = "TransposeOptimizerTests.RegressionTest_GitHubIssue9671"; - InferenceSession session_object{so, GetEnvironment()}; - ASSERT_STATUS_OK(session_object.Load(model_uri)); - ASSERT_STATUS_OK(session_object.Initialize()); // optimizers run during initialization + InferenceSession session{so, GetEnvironment()}; + ASSERT_STATUS_OK(session.Load(model_uri)); + ASSERT_STATUS_OK(session.Initialize()); // optimizers run during initialization } // regression test for a model where the transpose optimizations incorrectly removed a node providing an implicit @@ -4409,9 +4413,9 @@ TEST(TransposeOptimizerTests, RegressionTest_GitHubIssue10305) { SessionOptions so; so.session_logid = "TransposeOptimizerTests.RegressionTest_GitHubIssue10305"; - InferenceSession session_object{so, GetEnvironment()}; - ASSERT_STATUS_OK(session_object.Load(model_uri)); - ASSERT_STATUS_OK(session_object.Initialize()); // optimizers run during initialization + InferenceSession session{so, GetEnvironment()}; + ASSERT_STATUS_OK(session.Load(model_uri)); + ASSERT_STATUS_OK(session.Initialize()); // optimizers run during initialization } // regression test for a model with DQ node with per-axis dequantization followed by a Transpose. @@ -4432,18 +4436,18 @@ TEST(TransposeOptimizerTests, RegressionTest_GitHubIssue12151) { { so.graph_optimization_level = TransformerLevel::Default; // off - InferenceSession session_object{so, GetEnvironment()}; - ASSERT_STATUS_OK(session_object.Load(model_uri)); - ASSERT_STATUS_OK(session_object.Initialize()); - ASSERT_STATUS_OK(session_object.Run(feeds, output_names, &fetches_orig)); + InferenceSession session{so, GetEnvironment()}; + ASSERT_STATUS_OK(session.Load(model_uri)); + ASSERT_STATUS_OK(session.Initialize()); + ASSERT_STATUS_OK(session.Run(feeds, output_names, &fetches_orig)); } { so.graph_optimization_level = TransformerLevel::Level1; // enable transpose optimizer - InferenceSession session_object{so, GetEnvironment()}; - ASSERT_STATUS_OK(session_object.Load(model_uri)); - ASSERT_STATUS_OK(session_object.Initialize()); - ASSERT_STATUS_OK(session_object.Run(feeds, output_names, &fetches)); + InferenceSession session{so, GetEnvironment()}; + ASSERT_STATUS_OK(session.Load(model_uri)); + ASSERT_STATUS_OK(session.Initialize()); + ASSERT_STATUS_OK(session.Run(feeds, output_names, &fetches)); } ASSERT_THAT(fetches_orig[0].Get().DataAsSpan(), @@ -4543,5 +4547,93 @@ TEST(TransposeOptimizerTests, QnnTransposeReshapeQDQ) { } #endif } + +using namespace onnx_transpose_optimization; +static CostCheckResult AlwaysPushTranspose(const api::GraphRef& /*graph*/, + const api::NodeRef& /*node*/, + const std::vector& /*perm*/, + const std::unordered_set& /*outputs_leading_to_transpose*/) { + return onnx_transpose_optimization::CostCheckResult::kPushTranspose; +} + +static void CheckSharedInitializerHandling(bool broadcast) { + auto model_uri = broadcast ? ORT_TSTR("testdata/transpose_optimizer_shared_initializers_broadcast.onnx") + : ORT_TSTR("testdata/transpose_optimizer_shared_initializers.onnx"); + + RandomValueGenerator random{123}; + std::vector input_dims{1, 2, 2, 3}; + std::vector input_data = random.Gaussian(input_dims, 0.0f, 1.0f); + + OrtValue input; + CreateMLValue(TestCPUExecutionProvider()->CreatePreferredAllocators()[0], input_dims, input_data, &input); + + NameMLValMap feeds{{"input0", input}}; + + std::vector output_names{"output0"}; + std::vector fetches_orig; + std::vector fetches; + + SessionOptions so; + ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kDebugLayoutTransformation, "1")); + ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kOrtSessionOptionsDisableQuantQDQ, "1")); + + // get results with no modifications to the model + { + so.graph_optimization_level = TransformerLevel::Default; // off + InferenceSessionWrapper session{so, GetEnvironment()}; + ASSERT_STATUS_OK(session.Load(model_uri)); + ASSERT_STATUS_OK(session.Initialize()); + ASSERT_STATUS_OK(session.Run(feeds, output_names, &fetches_orig)); + } + + { + InferenceSessionWrapper session{so, GetEnvironment()}; + ASSERT_STATUS_OK(session.Load(model_uri)); + + // we call the ONNX transpose optimizer directly as we want to plug in the AlwaysPushTranspose cost check. + // this is to simplify the model required to exercise the shared initializer handling. + // it also means we don't need to disable optimizers that might alter the graph before the transpose optimizer + // runs (ConstantFolding, CommonSubexpressionElimination, ConstantSharing) + Graph& graph = session.GetMutableGraph(); + CPUAllocator allocator; + + auto api_graph = MakeApiGraph(graph, TestCPUExecutionProvider()->CreatePreferredAllocators()[0], + /*new_node_ep*/ nullptr); + + OptimizeResult result = Optimize(*api_graph, "", AlwaysPushTranspose); + + ASSERT_EQ(result.error_msg, std::nullopt); + ASSERT_TRUE(result.graph_modified); + ASSERT_TRUE(graph.GraphResolveNeeded()); + + std::map op_to_count = CountOpsInGraph(graph); + EXPECT_EQ(op_to_count["Transpose"], 0) << "The Transpose nodes should have been pushed through and canceled out."; + + ASSERT_STATUS_OK(graph.Resolve()); + + ASSERT_STATUS_OK(session.Initialize()); + ASSERT_STATUS_OK(session.Run(feeds, output_names, &fetches)); + } + + ASSERT_THAT(fetches_orig[0].Get().DataAsSpan(), + testing::ContainerEq(fetches[0].Get().DataAsSpan())); +} + +// test we re-use a modified shared initializer wherever possible. model has one initializer that is used by 3 DQ nodes +// and one initializer that is used by 2 Add nodes. both cases should be handled with the initializer being +// modified in-place for the first usage, and the Transpose added to the second usage being cancelled out when the +// original Transpose at the start of the model is pushed down. +TEST(TransposeOptimizerTests, SharedInitializerHandling) { + CheckSharedInitializerHandling(/*broadcast*/ false); +} + +// same setup as the above test, however the initializer is broadcast to bring UnsqueezeInput into play. +// the in-place modification of the initializer for the first usage results in +// -> Transpose -> Squeeze -> {DQ | Add} +// the later usages of the initializer should attempt to cancel out the Squeeze in UnsqueezeInput, +// followed by cancelling out the Transpose in TransposeInput. +TEST(TransposeOptimizerTests, SharedInitializerHandlingBroadcast) { + CheckSharedInitializerHandling(/*broadcast*/ true); +} } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/testdata/transpose_optimizer_shared_initializers.onnx b/onnxruntime/test/testdata/transpose_optimizer_shared_initializers.onnx new file mode 100644 index 0000000000000..9d82d68a40098 Binary files /dev/null and b/onnxruntime/test/testdata/transpose_optimizer_shared_initializers.onnx differ diff --git a/onnxruntime/test/testdata/transpose_optimizer_shared_initializers.py b/onnxruntime/test/testdata/transpose_optimizer_shared_initializers.py new file mode 100644 index 0000000000000..d4e3f7e8cbab6 --- /dev/null +++ b/onnxruntime/test/testdata/transpose_optimizer_shared_initializers.py @@ -0,0 +1,60 @@ +import numpy as np +import onnx +from onnx import TensorProto, helper + + +# Create a model with shared initializers that can be updated in-place by the transpose optimizer, +# including ones behind a DQ node. The transpose optimizer updates the first usage and inserts +# Transpose/Unsqueeze ops on the others (see UnsqueezeInput and TransposeInput). +# When we push the Transpose past other usages we should be able to cancel out those Transpose/Unsqueeze ops. +# We need 3 DQ nodes to ensure the Transpose or Unsqueeze added by the transpose optimizer is not +# removed prematurely. +def create_model(broadcast_weights: bool): + if broadcast_weights: + bias_shape = [2, 2] + bias_values = np.random.randn(2, 2) + else: + bias_shape = [1, 3, 2, 2] + bias_values = np.random.randn(1, 3, 2, 2) + + graph = helper.make_graph( + name="graph", + inputs=[ + helper.make_tensor_value_info("input0", TensorProto.FLOAT, [1, 2, 2, 3]), + ], + initializer=[ + helper.make_tensor("bias_quant", TensorProto.UINT8, bias_shape, bias_values.astype(np.uint8)), + helper.make_tensor("bias_fp32", TensorProto.FLOAT, bias_shape, bias_values.astype(np.float32)), + helper.make_tensor("dq_scale0", TensorProto.FLOAT, [], [1.5]), + helper.make_tensor("dq_zp0", TensorProto.UINT8, [], [5]), + helper.make_tensor("dq_scale1", TensorProto.FLOAT, [], [0.5]), + ], + nodes=[ + # Transpose input from channels last to channels first + helper.make_node("Transpose", ["input0"], ["input_T"], perm=[0, 3, 1, 2]), + helper.make_node("DequantizeLinear", ["bias_quant", "dq_scale0", "dq_zp0"], ["DQ0"], "DQ0"), + helper.make_node("Add", ["input_T", "DQ0"], ["A0"], "A0"), + helper.make_node("DequantizeLinear", ["bias_quant", "dq_scale1"], ["DQ1"], "DQ1"), + helper.make_node("Add", ["A0", "DQ1"], ["A1"], "A1"), + helper.make_node("DequantizeLinear", ["bias_quant", "dq_scale0"], ["DQ2"], "DQ2"), + helper.make_node("Add", ["A1", "DQ2"], ["A2"], "A2"), + helper.make_node("Add", ["A2", "bias_fp32"], ["A3"], "A3"), + helper.make_node("Add", ["A3", "bias_fp32"], ["A4"], "A4"), + # NCHW to NHWC + helper.make_node("Transpose", ["A4"], ["output0"], perm=[0, 2, 3, 1]), + ], + outputs=[ + helper.make_tensor_value_info("output0", TensorProto.FLOAT, [1, 2, 2, 3]), + ], + ) + + model = helper.make_model(graph) + onnx.checker.check_model(model, full_check=True) + return model + + +if __name__ == "__main__": + model = create_model(broadcast_weights=False) + onnx.save(model, "transpose_optimizer_shared_initializers.onnx") + model = create_model(broadcast_weights=True) + onnx.save(model, "transpose_optimizer_shared_initializers_broadcast.onnx") diff --git a/onnxruntime/test/testdata/transpose_optimizer_shared_initializers_broadcast.onnx b/onnxruntime/test/testdata/transpose_optimizer_shared_initializers_broadcast.onnx new file mode 100644 index 0000000000000..8bb2c6fd4a8b5 Binary files /dev/null and b/onnxruntime/test/testdata/transpose_optimizer_shared_initializers_broadcast.onnx differ