From 0532cfb2b06a3aef4f80f4967e43140b8e05d8b2 Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Tue, 9 Jul 2024 14:30:09 -0700 Subject: [PATCH 01/17] First working version --- .../qdq_transformer/qdq_propagation.cc | 224 +++++++++++++----- .../test/providers/qnn/qnn_basic_test.cc | 53 +++++ 2 files changed, 212 insertions(+), 65 deletions(-) diff --git a/onnxruntime/core/optimizer/qdq_transformer/qdq_propagation.cc b/onnxruntime/core/optimizer/qdq_transformer/qdq_propagation.cc index f0e76312d6e00..915f70de8f0bd 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/qdq_propagation.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/qdq_propagation.cc @@ -3,8 +3,10 @@ #include "core/optimizer/qdq_transformer/qdq_propagation.h" +#include #include +#include "core/common/inlined_containers_fwd.h" #include "core/graph/extended_graph_edge.h" #include "core/graph/graph_utils.h" #include "core/optimizer/initializer.h" @@ -20,26 +22,55 @@ bool CanNodePropagate(const Node& node) { graph_utils::IsSupportedOptypeVersionAndDomain(node, "Reshape", {5, 13, 14, 19}) || graph_utils::IsSupportedOptypeVersionAndDomain(node, "Transpose", {1, 13}) || graph_utils::IsSupportedOptypeVersionAndDomain(node, "Squeeze", {1, 11, 13}) || - graph_utils::IsSupportedOptypeVersionAndDomain(node, "Unsqueeze", {1, 11, 13}); + graph_utils::IsSupportedOptypeVersionAndDomain(node, "Unsqueeze", {1, 11, 13}) || + graph_utils::IsSupportedOptypeVersionAndDomain(node, "Slice", {1, 10, 11, 13}); } -// convert this: src_node -> dst_node -// to this: src_node -> Q -> DQ -> dst_node +// convert this: src_node --+--> dst_node_0 +// | +// +--> dst_node_1 +// | ... +// +--> dst_node_n +// +// to this: src_node -> Q --+--> DQ -> dst_node_0 +// | +// +--> DQ -> dst_node_1 +// | ... +// +--> DQ -> dst_node_n // assumptions: -// 1. insertion_edge is valid - node indexes refer to valid nodes, arg name refers to a valid NodeArg, and it -// corresponds to an actual graph relationship +// 1. insertion_edges are valid - insertion edges have the same source node, node indexes refer to valid nodes, +// arg name refers to a valid NodeArg, and it corresponds to an actual graph relationship // 2. scale_initializer_nodearg and zp_initializer_nodearg_ptr (if not null) are constant initializers -Status InsertQDQPair(Graph& graph, const ExtendedGraphEdge& insertion_edge, - NodeArg& scale_initializer_nodearg, NodeArg* zp_initializer_nodearg_ptr, - const std::string& qdq_domain, const logging::Logger& logger) { - auto* src_node = insertion_edge.GetMutableNodeAtEnd(graph, ExtendedGraphEdge::End::Source); - auto* dst_node = insertion_edge.GetMutableNodeAtEnd(graph, ExtendedGraphEdge::End::Destination); +Status InsertQDQPairs(Graph& graph, const InlinedVector& insertion_edges, + NodeArg& scale_initializer_nodearg, NodeArg* zp_initializer_nodearg_ptr, + const std::string& qdq_domain, const logging::Logger& logger) { + if (insertion_edges.empty()) { + return Status::OK(); + } + + const auto& src_info = insertion_edges[0].GetNodeInfoAtEnd(ExtendedGraphEdge::End::Source); + Node* src_node = src_info.has_value() ? graph.GetNode(src_info->node_idx) : nullptr; + bool has_some_dst_nodes = false; + + for (const auto& insertion_edge : insertion_edges) { + const auto& edge_src_info = insertion_edge.GetNodeInfoAtEnd(ExtendedGraphEdge::End::Source); + + ORT_RETURN_IF_NOT((edge_src_info.has_value() == src_info.has_value()) && + (!src_info.has_value() || + (src_info->node_idx == edge_src_info->node_idx && src_info->arg_idx == edge_src_info->arg_idx)), + "Expect all insertion edges to come from the same source node's output slot."); + + has_some_dst_nodes = insertion_edge.GetMutableNodeAtEnd(graph, ExtendedGraphEdge::End::Destination) != nullptr; + } - ORT_ENFORCE(src_node || dst_node, "At least one graph node must be specified in the propagation edge."); + ORT_RETURN_IF_NOT(src_node || has_some_dst_nodes, + "At least one graph node must be specified in the propagation edge."); - const auto& base_name = insertion_edge.arg_name; + const auto& base_name = insertion_edges[0].arg_name; auto& base_node_arg = *graph.GetNodeArg(base_name); +#if 0 + // TODO: Fix logging for multiple dst nodes LOGS(logger, VERBOSE) << "Inserting Q/DQ pair between " << (src_node ? MakeString("node (\"", src_node->Name(), "\", index: ", src_node->Index(), ")") : "input") @@ -47,9 +78,17 @@ Status InsertQDQPair(Graph& graph, const ExtendedGraphEdge& insertion_edge, << (dst_node ? MakeString("node (\"", dst_node->Name(), "\", index: ", dst_node->Index(), ")") : "output") << " at NodeArg \"" << base_name << "\"."; +#else + ORT_UNUSED_PARAMETER(logger); +#endif - // set up new NodeArgs - auto& pre_q_nodearg = insertion_edge.HasGraphInputOrInitializer() + auto make_q_or_dq_inputs = [](NodeArg& data, NodeArg& scale, NodeArg* zero_point) { + return zero_point ? InlinedVector{&data, &scale, zero_point} + : InlinedVector{&data, &scale}; + }; + + // Create Q node that will be inserted after src_node + auto& pre_q_nodearg = insertion_edges[0].HasGraphInputOrInitializer() ? base_node_arg : graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(base_name + "_pre_q"), nullptr); @@ -57,17 +96,6 @@ Status InsertQDQPair(Graph& graph, const ExtendedGraphEdge& insertion_edge, auto& q_to_dq_nodearg = graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(base_name + "_q_to_dq"), nullptr); - auto& post_dq_nodearg = insertion_edge.HasGraphOutput() - ? base_node_arg - : graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(base_name + "_post_dq"), - nullptr); - - // set up new Nodes - auto make_q_or_dq_inputs = [](NodeArg& data, NodeArg& scale, NodeArg* zero_point) { - return zero_point ? std::vector{&data, &scale, zero_point} - : std::vector{&data, &scale}; - }; - auto& q_node = graph.AddNode(graph.GenerateNodeName(base_name + "_q"), QDQ::QOpName, "Inserted by QDQPropagationTransformer", @@ -81,35 +109,54 @@ Status InsertQDQPair(Graph& graph, const ExtendedGraphEdge& insertion_edge, ORT_RETURN_IF_NOT(graph.SetOpSchemaFromRegistryForNode(q_node), "Failed to set op schema for added Q node."); - auto& dq_node = graph.AddNode(graph.GenerateNodeName(base_name + "_dq"), - QDQ::DQOpName, - "Inserted by QDQPropagationTransformer", - // inputs - make_q_or_dq_inputs(q_to_dq_nodearg, scale_initializer_nodearg, - zp_initializer_nodearg_ptr), - // outputs - {&post_dq_nodearg}, - nullptr, // attributes - qdq_domain); - - ORT_RETURN_IF_NOT(graph.SetOpSchemaFromRegistryForNode(dq_node), "Failed to set op schema for added DQ node."); - - // set up edges - if (src_node && dst_node) { - graph.RemoveEdge(src_node->Index(), dst_node->Index(), - insertion_edge.src->arg_idx, insertion_edge.dst->arg_idx); - } + // Remove original edges between src and dst nodes. + for (const auto& insertion_edge : insertion_edges) { + auto* dst_node = insertion_edge.GetMutableNodeAtEnd(graph, ExtendedGraphEdge::End::Destination); - if (src_node) { - src_node->MutableOutputDefs()[insertion_edge.src->arg_idx] = &pre_q_nodearg; - graph.AddEdge(src_node->Index(), q_node.Index(), insertion_edge.src->arg_idx, 0); + if (src_node && dst_node) { + graph.RemoveEdge(src_node->Index(), dst_node->Index(), + insertion_edge.src->arg_idx, insertion_edge.dst->arg_idx); + } } - graph.AddEdge(q_node.Index(), dq_node.Index(), 0, 0); + // Create a DQ node for each dst node and connect all edges. + for (size_t edge_idx = 0; edge_idx < insertion_edges.size(); ++edge_idx) { + const auto& insertion_edge = insertion_edges[edge_idx]; + const std::string edge_suffix = edge_idx == 0 ? "" : std::to_string(edge_idx); + auto& post_dq_nodearg = insertion_edge.HasGraphOutput() + ? base_node_arg + : graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(base_name + "_post_dq" + edge_suffix), + nullptr); + + auto& dq_node = graph.AddNode(graph.GenerateNodeName(base_name + "_dq" + edge_suffix), + QDQ::DQOpName, + "Inserted by QDQPropagationTransformer", + // inputs + make_q_or_dq_inputs(q_to_dq_nodearg, scale_initializer_nodearg, + zp_initializer_nodearg_ptr), + // outputs + {&post_dq_nodearg}, + nullptr, // attributes + qdq_domain); + + ORT_RETURN_IF_NOT(graph.SetOpSchemaFromRegistryForNode(dq_node), "Failed to set op schema for added DQ node."); + + auto* dst_node = insertion_edge.GetMutableNodeAtEnd(graph, ExtendedGraphEdge::End::Destination); + + // Add edge from src to Q node. Only do this in the first iteration of this loop. + if (src_node && edge_idx == 0) { + src_node->MutableOutputDefs()[insertion_edge.src->arg_idx] = &pre_q_nodearg; + graph.AddEdge(src_node->Index(), q_node.Index(), insertion_edge.src->arg_idx, 0); + } + + // Add edge from Q to DQ + graph.AddEdge(q_node.Index(), dq_node.Index(), 0, 0); - if (dst_node) { - dst_node->MutableInputDefs()[insertion_edge.dst->arg_idx] = &post_dq_nodearg; - graph.AddEdge(dq_node.Index(), dst_node->Index(), 0, insertion_edge.dst->arg_idx); + // Add edge from DQ to dst_node + if (dst_node) { + dst_node->MutableInputDefs()[insertion_edge.dst->arg_idx] = &post_dq_nodearg; + graph.AddEdge(dq_node.Index(), dst_node->Index(), 0, insertion_edge.dst->arg_idx); + } } return Status::OK(); @@ -173,20 +220,55 @@ std::optional GetNextEdge(const Graph& graph, const Node& nod return std::nullopt; } -std::optional GetNextPropagationEdge(const Graph& graph, - const ExtendedGraphEdge& edge) { +InlinedVector GetNextEdges(const Graph& graph, const Node& node) { + // for now we can just consider the first output (index 0) + InlinedVector next_edges; + + const auto output_edges = graph_utils::GraphEdge::GetNodeOutputEdges(node, 0); + if (output_edges.empty()) { + // maybe edge to output + auto edge = ExtendedGraphEdge::TryCreateFromNodeToOutput(graph, node, 0); + if (edge.has_value()) { + next_edges.push_back(edge.value()); + } + } else if (!graph.IsOutput(node.OutputDefs()[0])) { + // edges to next nodes + for (const auto& output_edge : output_edges) { + next_edges.push_back(ExtendedGraphEdge::CreateFromValidGraphEdge(output_edge)); + } + } + + return next_edges; +} + +InlinedVector GetNextPropagationEdges(const Graph& graph, + const ExtendedGraphEdge& edge) { if (edge.HasGraphOutput()) { - return std::nullopt; + return {}; } const auto* dst_node = edge.GetNodeAtEnd(graph, ExtendedGraphEdge::End::Destination); ORT_ENFORCE(dst_node != nullptr); - if (!CanNodePropagate(*dst_node)) { - return std::nullopt; + const bool can_prop = CanNodePropagate(*dst_node); + if (!can_prop) { + return {}; } - return GetNextEdge(graph, *dst_node); + auto all_next_edges = GetNextEdges(graph, *dst_node); + InlinedVector next_prop_edges; + next_prop_edges.reserve(all_next_edges.size()); + + // Filter out edges that end in Q nodes. + // There is no need to insert a Q node in an edge that already ends in a Q node. + std::copy_if(all_next_edges.begin(), all_next_edges.end(), std::back_inserter(next_prop_edges), + [&graph](const ExtendedGraphEdge& e) -> bool { + const auto* dst_node = e.GetNodeAtEnd(graph, ExtendedGraphEdge::End::Destination); + const bool is_q_node = dst_node && QDQ::MatchQNode(*dst_node); + return !is_q_node; + }); + + return next_prop_edges; } class GraphConstantInitializerGetter { @@ -233,16 +315,27 @@ Status PropagateDQForward(Graph& graph, gsl::span node_indices, continue; } - for (auto curr_edge = GetNextPropagationEdge(graph, *edge_after_dq); - curr_edge.has_value(); - curr_edge = GetNextPropagationEdge(graph, *curr_edge)) { - if (const auto* dst_node = curr_edge->GetNodeAtEnd(graph, ExtendedGraphEdge::End::Destination); - dst_node && QDQ::MatchQNode(*dst_node)) { - break; - } + InlinedVector> edge_groups; + InlinedVector first_edge_group = GetNextPropagationEdges(graph, *edge_after_dq); + + if (!first_edge_group.empty()) { + edge_groups.push_back(std::move(first_edge_group)); + } + + while (!edge_groups.empty()) { + InlinedVector prop_edges = edge_groups.back(); + edge_groups.pop_back(); - ORT_RETURN_IF_ERROR(InsertQDQPair(graph, *curr_edge, dq_scale, dq_zero_point, dq_node.Domain(), logger)); + ORT_RETURN_IF_ERROR(InsertQDQPairs(graph, prop_edges, dq_scale, dq_zero_point, dq_node.Domain(), logger)); modified = true; + + for (const auto& prop_edge : prop_edges) { + InlinedVector next_edge_group = GetNextPropagationEdges(graph, prop_edge); + + if (!next_edge_group.empty()) { + edge_groups.push_back(std::move(next_edge_group)); + } + } } } @@ -290,7 +383,8 @@ Status PropagateQBackward(Graph& graph, gsl::span node_indices, break; } - ORT_RETURN_IF_ERROR(InsertQDQPair(graph, *curr_edge, q_scale, q_zero_point, q_node.Domain(), logger)); + ORT_RETURN_IF_ERROR(InsertQDQPairs(graph, InlinedVector{*curr_edge}, q_scale, q_zero_point, + q_node.Domain(), logger)); modified = true; } } diff --git a/onnxruntime/test/providers/qnn/qnn_basic_test.cc b/onnxruntime/test/providers/qnn/qnn_basic_test.cc index 6173f46839a81..d0033dc0ee84c 100644 --- a/onnxruntime/test/providers/qnn/qnn_basic_test.cc +++ b/onnxruntime/test/providers/qnn/qnn_basic_test.cc @@ -903,6 +903,59 @@ TEST_F(QnnHTPBackendTests, Float32ModelWithFP16PrecisionTest) { 0.008f); } +TEST_F(QnnHTPBackendTests, SliceQDQPropagation_MultConsumers) { + Ort::SessionOptions so; + + // Ensure all type/shape inference warnings result in errors! + // so.AddConfigEntry(kOrtSessionOptionsConfigStrictShapeTypeInference, "1"); + // so.AddConfigEntry(kOrtSessionOptionsDisableCPUEPFallback, "1"); // Disable fallback to the CPU EP. + so.AddConfigEntry(kDebugLayoutTransformation, "1"); + so.SetGraphOptimizationLevel(ORT_ENABLE_ALL); + // so.SetLogSeverityLevel(ORT_LOGGING_LEVEL_VERBOSE); + // ort_env->UpdateEnvWithCustomLogLevel(OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE); + onnxruntime::ProviderOptions options; + +#if defined(_WIN32) + options["backend_path"] = "QnnHtp.dll"; +#else + options["backend_path"] = "libQnnHtp.so"; +#endif + + so.AppendExecutionProvider("QNN", options); + + const ORTCHAR_T* ort_model_path = ORT_MODEL_FOLDER "slice_qdq_propagation_mult_consumers.onnx"; + Ort::Session session(*ort_env, ort_model_path, so); + + // image: 1,3,640,640 + std::vector input0_data(1 * 3 * 640 * 640); + + auto memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); + std::vector ort_inputs; + std::vector ort_input_names; + + // Add input0 + std::array inputs_shape{1, 3, 640, 640}; + ort_inputs.emplace_back(Ort::Value::CreateTensor( + memory_info, input0_data.data(), input0_data.size(), inputs_shape.data(), inputs_shape.size())); + ort_input_names.push_back("image"); + + // Run session and get outputs + std::array output_names{"output_0", "output_1", "output_2"}; + std::vector ort_outputs = session.Run(Ort::RunOptions{nullptr}, ort_input_names.data(), ort_inputs.data(), + ort_inputs.size(), output_names.data(), output_names.size()); + + // Check output shape. + Ort::Value& ort_output = ort_outputs[0]; + auto typeshape = ort_output.GetTensorTypeAndShapeInfo(); + std::vector output_shape = typeshape.GetShape(); + + EXPECT_THAT(output_shape, ::testing::ElementsAre(1, 8400, 4)); + const uint8_t* results = ort_output.GetTensorData(); + + for (size_t i = 0; i < typeshape.GetElementCount() && i < 10; i++) { + std::cout << i << ": " << results[i] << std::endl; + } +} #endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) #endif // !defined(ORT_MINIMAL_BUILD) From 65166b8995128d8600bbec25ebebbfd5c09d99c3 Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Tue, 9 Jul 2024 14:49:41 -0700 Subject: [PATCH 02/17] Simplify edge group loop --- .../qdq_transformer/qdq_propagation.cc | 26 +++++++------------ 1 file changed, 10 insertions(+), 16 deletions(-) diff --git a/onnxruntime/core/optimizer/qdq_transformer/qdq_propagation.cc b/onnxruntime/core/optimizer/qdq_transformer/qdq_propagation.cc index 915f70de8f0bd..e8b751188e883 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/qdq_propagation.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/qdq_propagation.cc @@ -250,8 +250,7 @@ InlinedVector GetNextPropagationEdges(const Graph& graph, const auto* dst_node = edge.GetNodeAtEnd(graph, ExtendedGraphEdge::End::Destination); ORT_ENFORCE(dst_node != nullptr); - const bool can_prop = CanNodePropagate(*dst_node); - if (!can_prop) { + if (!CanNodePropagate(*dst_node)) { return {}; } @@ -315,26 +314,21 @@ Status PropagateDQForward(Graph& graph, gsl::span node_indices, continue; } - InlinedVector> edge_groups; - InlinedVector first_edge_group = GetNextPropagationEdges(graph, *edge_after_dq); - - if (!first_edge_group.empty()) { - edge_groups.push_back(std::move(first_edge_group)); - } + InlinedVector> edge_groups = {GetNextPropagationEdges(graph, *edge_after_dq)}; while (!edge_groups.empty()) { - InlinedVector prop_edges = edge_groups.back(); + InlinedVector edges = edge_groups.back(); edge_groups.pop_back(); - ORT_RETURN_IF_ERROR(InsertQDQPairs(graph, prop_edges, dq_scale, dq_zero_point, dq_node.Domain(), logger)); - modified = true; + if (edges.empty()) { + continue; + } - for (const auto& prop_edge : prop_edges) { - InlinedVector next_edge_group = GetNextPropagationEdges(graph, prop_edge); + ORT_RETURN_IF_ERROR(InsertQDQPairs(graph, edges, dq_scale, dq_zero_point, dq_node.Domain(), logger)); + modified = true; - if (!next_edge_group.empty()) { - edge_groups.push_back(std::move(next_edge_group)); - } + for (const auto& edge : edges) { + edge_groups.push_back(GetNextPropagationEdges(graph, edge)); } } } From d9d76fc58790260e440e94d987ed668eda2c57a9 Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Tue, 9 Jul 2024 14:59:45 -0700 Subject: [PATCH 03/17] Clean up edge case handling --- .../core/optimizer/qdq_transformer/qdq_propagation.cc | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/onnxruntime/core/optimizer/qdq_transformer/qdq_propagation.cc b/onnxruntime/core/optimizer/qdq_transformer/qdq_propagation.cc index e8b751188e883..f83de96b2c080 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/qdq_propagation.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/qdq_propagation.cc @@ -44,9 +44,7 @@ bool CanNodePropagate(const Node& node) { Status InsertQDQPairs(Graph& graph, const InlinedVector& insertion_edges, NodeArg& scale_initializer_nodearg, NodeArg* zp_initializer_nodearg_ptr, const std::string& qdq_domain, const logging::Logger& logger) { - if (insertion_edges.empty()) { - return Status::OK(); - } + ORT_RETURN_IF(insertion_edges.empty(), "Expected at least one edge into which to insert QDQ pair."); const auto& src_info = insertion_edges[0].GetNodeInfoAtEnd(ExtendedGraphEdge::End::Source); Node* src_node = src_info.has_value() ? graph.GetNode(src_info->node_idx) : nullptr; @@ -60,7 +58,7 @@ Status InsertQDQPairs(Graph& graph, const InlinedVector& inse (src_info->node_idx == edge_src_info->node_idx && src_info->arg_idx == edge_src_info->arg_idx)), "Expect all insertion edges to come from the same source node's output slot."); - has_some_dst_nodes = insertion_edge.GetMutableNodeAtEnd(graph, ExtendedGraphEdge::End::Destination) != nullptr; + has_some_dst_nodes = insertion_edge.GetNodeInfoAtEnd(ExtendedGraphEdge::End::Destination).has_value(); } ORT_RETURN_IF_NOT(src_node || has_some_dst_nodes, From 66d7a9aad864aca1d70a08ac9170b8e76f3fa98e Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Tue, 9 Jul 2024 17:13:01 -0700 Subject: [PATCH 04/17] Use a queue to traverse groups of edges --- .../qdq_transformer/qdq_propagation.cc | 45 +++++++------------ 1 file changed, 15 insertions(+), 30 deletions(-) diff --git a/onnxruntime/core/optimizer/qdq_transformer/qdq_propagation.cc b/onnxruntime/core/optimizer/qdq_transformer/qdq_propagation.cc index f83de96b2c080..f49df89028802 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/qdq_propagation.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/qdq_propagation.cc @@ -5,6 +5,7 @@ #include #include +#include #include "core/common/inlined_containers_fwd.h" #include "core/graph/extended_graph_edge.h" @@ -117,7 +118,13 @@ Status InsertQDQPairs(Graph& graph, const InlinedVector& inse } } - // Create a DQ node for each dst node and connect all edges. + // Add edge from src to Q node. + if (src_node) { + src_node->MutableOutputDefs()[insertion_edges[0].src->arg_idx] = &pre_q_nodearg; + graph.AddEdge(src_node->Index(), q_node.Index(), insertion_edges[0].src->arg_idx, 0); + } + + // Create a DQ node for each dst node and connect remaining edges. for (size_t edge_idx = 0; edge_idx < insertion_edges.size(); ++edge_idx) { const auto& insertion_edge = insertion_edges[edge_idx]; const std::string edge_suffix = edge_idx == 0 ? "" : std::to_string(edge_idx); @@ -141,12 +148,6 @@ Status InsertQDQPairs(Graph& graph, const InlinedVector& inse auto* dst_node = insertion_edge.GetMutableNodeAtEnd(graph, ExtendedGraphEdge::End::Destination); - // Add edge from src to Q node. Only do this in the first iteration of this loop. - if (src_node && edge_idx == 0) { - src_node->MutableOutputDefs()[insertion_edge.src->arg_idx] = &pre_q_nodearg; - graph.AddEdge(src_node->Index(), q_node.Index(), insertion_edge.src->arg_idx, 0); - } - // Add edge from Q to DQ graph.AddEdge(q_node.Index(), dq_node.Index(), 0, 0); @@ -201,23 +202,6 @@ std::optional GetPreviousPropagationEdge(const Graph& graph, return GetPreviousEdge(graph, *src_node); } -std::optional GetNextEdge(const Graph& graph, const Node& node) { - // for now we can just consider the first output (index 0) - - const auto output_edges = graph_utils::GraphEdge::GetNodeOutputEdges(node, 0); - if (output_edges.empty()) { - // maybe edge to output - return ExtendedGraphEdge::TryCreateFromNodeToOutput(graph, node, 0); - } - - if (!graph.IsOutput(node.OutputDefs()[0]) && output_edges.size() == 1) { - // single edge to next node - return ExtendedGraphEdge::CreateFromValidGraphEdge(output_edges.front()); - } - - return std::nullopt; -} - InlinedVector GetNextEdges(const Graph& graph, const Node& node) { // for now we can just consider the first output (index 0) InlinedVector next_edges; @@ -307,16 +291,17 @@ Status PropagateDQForward(Graph& graph, gsl::span node_indices, ? dq_node.MutableInputDefs()[QDQ::InputIndex::ZERO_POINT_ID] : nullptr; - const auto edge_after_dq = GetNextEdge(graph, dq_node); - if (!edge_after_dq) { + InlinedVector edges_after_dq = GetNextEdges(graph, dq_node); + if (edges_after_dq.size() != 1) { continue; } - InlinedVector> edge_groups = {GetNextPropagationEdges(graph, *edge_after_dq)}; + std::queue> edge_groups; + edge_groups.push(GetNextPropagationEdges(graph, edges_after_dq[0])); while (!edge_groups.empty()) { - InlinedVector edges = edge_groups.back(); - edge_groups.pop_back(); + const InlinedVector edges = std::move(edge_groups.front()); + edge_groups.pop(); if (edges.empty()) { continue; @@ -326,7 +311,7 @@ Status PropagateDQForward(Graph& graph, gsl::span node_indices, modified = true; for (const auto& edge : edges) { - edge_groups.push_back(GetNextPropagationEdges(graph, edge)); + edge_groups.push(GetNextPropagationEdges(graph, edge)); } } } From ecddb01e1f4d5b0ea9066e1a59718092bd7854cd Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Wed, 10 Jul 2024 15:35:23 -0700 Subject: [PATCH 05/17] Move validation and logging to separate functions --- .../qdq_transformer/qdq_propagation.cc | 107 +++++++++++------- 1 file changed, 69 insertions(+), 38 deletions(-) diff --git a/onnxruntime/core/optimizer/qdq_transformer/qdq_propagation.cc b/onnxruntime/core/optimizer/qdq_transformer/qdq_propagation.cc index f49df89028802..1aae9c900821e 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/qdq_propagation.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/qdq_propagation.cc @@ -6,6 +6,7 @@ #include #include #include +#include #include "core/common/inlined_containers_fwd.h" #include "core/graph/extended_graph_edge.h" @@ -27,6 +28,62 @@ bool CanNodePropagate(const Node& node) { graph_utils::IsSupportedOptypeVersionAndDomain(node, "Slice", {1, 10, 11, 13}); } +// Validates edges into which to insert Q -> DQ ops. +// - Must have at least one edge. +// - All edges with a source node must originate from the same source node's output. +// - All edges must be attached to either a source node or a destination node. +Status ValidateQDQInsertionEdges(Graph& graph, const InlinedVector& insertion_edges) { + ORT_RETURN_IF(insertion_edges.empty(), "Expected at least one edge into which to insert QDQ pair."); + + const auto& src_info = insertion_edges[0].GetNodeInfoAtEnd(ExtendedGraphEdge::End::Source); + const Node* src_node = src_info.has_value() ? graph.GetNode(src_info->node_idx) : nullptr; + + for (const auto& insertion_edge : insertion_edges) { + const auto& edge_src_info = insertion_edge.GetNodeInfoAtEnd(ExtendedGraphEdge::End::Source); + + ORT_RETURN_IF_NOT((edge_src_info.has_value() == src_info.has_value()) && + (!src_info.has_value() || + (src_info->node_idx == edge_src_info->node_idx && src_info->arg_idx == edge_src_info->arg_idx)), + "Expect all insertion edges to come from the same source node's output slot."); + + const Node* edge_dst_node = insertion_edge.GetNodeAtEnd(graph, ExtendedGraphEdge::End::Destination); + ORT_RETURN_IF_NOT(src_node != nullptr || edge_dst_node != nullptr, + "At least one graph node must be specified in the propagation edges."); + } + + return Status::OK(); +} + +// Logs information about the edges into which Q/DQ nodes will be inserted in InsertQDQPairs(). +// Assumes the edges have already been validated. +void LogQDQInsertion(const logging::Logger& logger, logging::Severity severity, + const Graph& graph, const InlinedVector& edges) { + if (!logger.OutputIsEnabled(severity, logging::DataType::SYSTEM)) { + return; + } + + const Node* src_node = edges[0].GetNodeAtEnd(graph, ExtendedGraphEdge::End::Source); + const auto& node_arg_name = edges[0].arg_name; + std::string src_label = src_node ? MakeString("node (\"", src_node->Name(), "\", index: ", src_node->Index(), ")") + : "input"; + std::ostringstream dst_labels; + const size_t num_edges = edges.size(); + + for (size_t i = 0; i < num_edges; ++i) { + const ExtendedGraphEdge& edge = edges[i]; + const Node* dst_node = edge.GetNodeAtEnd(graph, ExtendedGraphEdge::End::Destination); + dst_labels << (dst_node ? MakeString("dst node (\"", dst_node->Name(), "\", index: ", dst_node->Index(), ")") + : "output") + << (i == num_edges - 1 ? "" : ","); + } + + LOGS(logger, VERBOSE) << "Inserting Q/DQ pair between " + << (src_node ? MakeString("src node (\"", src_node->Name(), "\", index: ", src_node->Index(), ")") + : "input") + << " and " << dst_labels.str() + << " at NodeArg \"" << node_arg_name << "\"."; +} + // convert this: src_node --+--> dst_node_0 // | // +--> dst_node_1 @@ -39,47 +96,21 @@ bool CanNodePropagate(const Node& node) { // | ... // +--> DQ -> dst_node_n // assumptions: -// 1. insertion_edges are valid - insertion edges have the same source node, node indexes refer to valid nodes, -// arg name refers to a valid NodeArg, and it corresponds to an actual graph relationship -// 2. scale_initializer_nodearg and zp_initializer_nodearg_ptr (if not null) are constant initializers +// 1. All insertion edges have the same source node and the same source node output index. +// 2. Insertion_edges are valid: node indices refer to valid nodes, and arg names refer to valid NodeArgs in the graph. +// 3. scale_initializer_nodearg and zp_initializer_nodearg_ptr (if not null) are constant initializers Status InsertQDQPairs(Graph& graph, const InlinedVector& insertion_edges, NodeArg& scale_initializer_nodearg, NodeArg* zp_initializer_nodearg_ptr, const std::string& qdq_domain, const logging::Logger& logger) { - ORT_RETURN_IF(insertion_edges.empty(), "Expected at least one edge into which to insert QDQ pair."); - - const auto& src_info = insertion_edges[0].GetNodeInfoAtEnd(ExtendedGraphEdge::End::Source); - Node* src_node = src_info.has_value() ? graph.GetNode(src_info->node_idx) : nullptr; - bool has_some_dst_nodes = false; - - for (const auto& insertion_edge : insertion_edges) { - const auto& edge_src_info = insertion_edge.GetNodeInfoAtEnd(ExtendedGraphEdge::End::Source); + ORT_RETURN_IF_ERROR(ValidateQDQInsertionEdges(graph, insertion_edges)); - ORT_RETURN_IF_NOT((edge_src_info.has_value() == src_info.has_value()) && - (!src_info.has_value() || - (src_info->node_idx == edge_src_info->node_idx && src_info->arg_idx == edge_src_info->arg_idx)), - "Expect all insertion edges to come from the same source node's output slot."); - - has_some_dst_nodes = insertion_edge.GetNodeInfoAtEnd(ExtendedGraphEdge::End::Destination).has_value(); - } - - ORT_RETURN_IF_NOT(src_node || has_some_dst_nodes, - "At least one graph node must be specified in the propagation edge."); + const ExtendedGraphEdge& first_edge = insertion_edges[0]; // ValidateQDQInsertionEdges() guarantees at least one edge - const auto& base_name = insertion_edges[0].arg_name; + Node* src_node = first_edge.GetMutableNodeAtEnd(graph, ExtendedGraphEdge::End::Source); // nullptr for graph input + const auto& base_name = first_edge.arg_name; auto& base_node_arg = *graph.GetNodeArg(base_name); -#if 0 - // TODO: Fix logging for multiple dst nodes - LOGS(logger, VERBOSE) << "Inserting Q/DQ pair between " - << (src_node ? MakeString("node (\"", src_node->Name(), "\", index: ", src_node->Index(), ")") - : "input") - << " and " - << (dst_node ? MakeString("node (\"", dst_node->Name(), "\", index: ", dst_node->Index(), ")") - : "output") - << " at NodeArg \"" << base_name << "\"."; -#else - ORT_UNUSED_PARAMETER(logger); -#endif + LogQDQInsertion(logger, logging::Severity::kVERBOSE, graph, insertion_edges); auto make_q_or_dq_inputs = [](NodeArg& data, NodeArg& scale, NodeArg* zero_point) { return zero_point ? InlinedVector{&data, &scale, zero_point} @@ -87,7 +118,7 @@ Status InsertQDQPairs(Graph& graph, const InlinedVector& inse }; // Create Q node that will be inserted after src_node - auto& pre_q_nodearg = insertion_edges[0].HasGraphInputOrInitializer() + auto& pre_q_nodearg = first_edge.HasGraphInputOrInitializer() ? base_node_arg : graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(base_name + "_pre_q"), nullptr); @@ -120,8 +151,8 @@ Status InsertQDQPairs(Graph& graph, const InlinedVector& inse // Add edge from src to Q node. if (src_node) { - src_node->MutableOutputDefs()[insertion_edges[0].src->arg_idx] = &pre_q_nodearg; - graph.AddEdge(src_node->Index(), q_node.Index(), insertion_edges[0].src->arg_idx, 0); + src_node->MutableOutputDefs()[first_edge.src->arg_idx] = &pre_q_nodearg; + graph.AddEdge(src_node->Index(), q_node.Index(), first_edge.src->arg_idx, 0); } // Create a DQ node for each dst node and connect remaining edges. @@ -146,7 +177,7 @@ Status InsertQDQPairs(Graph& graph, const InlinedVector& inse ORT_RETURN_IF_NOT(graph.SetOpSchemaFromRegistryForNode(dq_node), "Failed to set op schema for added DQ node."); - auto* dst_node = insertion_edge.GetMutableNodeAtEnd(graph, ExtendedGraphEdge::End::Destination); + Node* dst_node = insertion_edge.GetMutableNodeAtEnd(graph, ExtendedGraphEdge::End::Destination); // Add edge from Q to DQ graph.AddEdge(q_node.Index(), dq_node.Index(), 0, 0); From e347c1173edb1c402cea4b315d31f0b0ab35aea2 Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Wed, 10 Jul 2024 18:34:25 -0700 Subject: [PATCH 06/17] Add QDQTransformerTest --- .../test/optimizer/qdq_transformer_test.cc | 97 +++++++++++++++++++ .../test/providers/qnn/qnn_basic_test.cc | 53 ---------- 2 files changed, 97 insertions(+), 53 deletions(-) diff --git a/onnxruntime/test/optimizer/qdq_transformer_test.cc b/onnxruntime/test/optimizer/qdq_transformer_test.cc index 1c77121ba9df1..1eb99f5f4e153 100644 --- a/onnxruntime/test/optimizer/qdq_transformer_test.cc +++ b/onnxruntime/test/optimizer/qdq_transformer_test.cc @@ -12,6 +12,7 @@ #include "core/mlas/inc/mlas.h" #include "core/optimizer/double_qdq_pairs_remover.h" #include "core/optimizer/qdq_transformer/qdq_final_cleanup.h" +#include "core/optimizer/qdq_transformer/qdq_propagation.h" #include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h" #include "core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h" #include "core/optimizer/qdq_transformer/selectors_actions/shared/utils.h" @@ -3369,6 +3370,102 @@ TEST(QDQTransformerTests, QDQPropagation_DQ_Q) { #endif } +// Test propagating a DQ forward through a chain of Slice and Transpose operators that have multiple consumers. +// original model: +// in0 -> DQ -> Slice --+--> Add -> out0 +// | +// +--> TP --+--> Pow -> out1 +// | | +// | +--> Pow -> out2 +// | +// +--> TP --+--> Pow -> out3 +// | +// +--> Pow -> out4 +// expected model: +// in0 -> DQ -> Slice -> Q --+--> DQ -> Add -> out0 +// | +// +--> DQ -> TP -> Q --+--> DQ -> Pow -> out1 +// | | +// | +--> DQ -> Pow -> out2 +// | +// +--> DQ -> TP -> Q --+--> DQ -> Pow -> out3 +// | +// +--> DQ -> Pow -> out4 +TEST(QDQTransformerTests, QDQPropagation_DQForward_SliceMultipleConsumers) { + auto build_test_case = [&](ModelTestBuilder& builder) { + std::vector input0_shape = {1, 2, 2, 2}; + std::vector input1_shape = {1, 1, 1, 1}; + auto* input0_arg = builder.MakeInput(input0_shape, + std::numeric_limits::min(), + std::numeric_limits::max()); + auto* input1_arg = builder.MakeInput(input1_shape, {0.0f}); + auto* output0_arg = builder.MakeOutput(); + auto* output1_arg = builder.MakeOutput(); + auto* output2_arg = builder.MakeOutput(); + auto* output3_arg = builder.MakeOutput(); + auto* output4_arg = builder.MakeOutput(); + + // DQ + constexpr float qdq_scale = 1.0f; + constexpr uint8_t qdq_zero_point = 128; + auto* dq_output = builder.MakeIntermediate(); + builder.AddDequantizeLinearNode(input0_arg, qdq_scale, qdq_zero_point, dq_output); + + // Slice + auto* slice_output = builder.MakeIntermediate(); + auto* slice_starts = builder.Make1DInitializer(std::vector{0, 0, 0, 0}); + auto* slice_ends = builder.Make1DInitializer(std::vector{1, 1, 1, 1}); + builder.AddNode("Slice", {dq_output, slice_starts, slice_ends}, {slice_output}); + + // Add + builder.AddNode("Add", {slice_output, input1_arg}, {output0_arg}); + + // Transpose + auto* transpose0_output = builder.MakeIntermediate(); + builder.AddNode("Transpose", {slice_output}, {transpose0_output}); + + // Transpose + auto* transpose1_output = builder.MakeIntermediate(); + builder.AddNode("Transpose", {slice_output}, {transpose1_output}); + + // Pows + auto* pow_exp = builder.MakeScalarInitializer(2.0f); + builder.AddNode("Pow", {transpose0_output, pow_exp}, {output1_arg}); + builder.AddNode("Pow", {transpose0_output, pow_exp}, {output2_arg}); + builder.AddNode("Pow", {transpose1_output, pow_exp}, {output3_arg}); + builder.AddNode("Pow", {transpose1_output, pow_exp}, {output4_arg}); + }; + + auto check_graph = [&](InferenceSessionWrapper& session) { + const QDQOpKeys qdq_keys = GetQDQOpKeys(false); + std::vector expected_op_types_in_order{ + qdq_keys.dequantize_linear, + "Slice", + qdq_keys.quantize_linear, qdq_keys.dequantize_linear, + "Add", + qdq_keys.dequantize_linear, + "Transpose", + qdq_keys.quantize_linear, qdq_keys.dequantize_linear, + "Pow", + qdq_keys.dequantize_linear, + "Pow", + qdq_keys.dequantize_linear, + "Transpose", + qdq_keys.quantize_linear, qdq_keys.dequantize_linear, + "Pow", + qdq_keys.dequantize_linear, + "Pow"}; + const auto op_types_in_order = GetNodeOpTypesInTopologicalOrder(session.GetGraph(), true); + EXPECT_EQ(op_types_in_order, expected_op_types_in_order); + }; + + TransformerTester(build_test_case, + check_graph, + TransformerLevel::Default, + TransformerLevel::Level1, + 18, 0.0, 0.0, std::make_unique()); +} + TEST(QDQTransformerTests, QDQ_Selector_Test) { const ORTCHAR_T* model_file_name = ORT_TSTR("testdata/transform/qdq_conv.onnx"); diff --git a/onnxruntime/test/providers/qnn/qnn_basic_test.cc b/onnxruntime/test/providers/qnn/qnn_basic_test.cc index 986f36b49d114..9489d354755e4 100644 --- a/onnxruntime/test/providers/qnn/qnn_basic_test.cc +++ b/onnxruntime/test/providers/qnn/qnn_basic_test.cc @@ -948,59 +948,6 @@ TEST_F(QnnHTPBackendTests, Float32ModelWithFP16PrecisionTest) { 0.008f); } -TEST_F(QnnHTPBackendTests, SliceQDQPropagation_MultConsumers) { - Ort::SessionOptions so; - - // Ensure all type/shape inference warnings result in errors! - // so.AddConfigEntry(kOrtSessionOptionsConfigStrictShapeTypeInference, "1"); - // so.AddConfigEntry(kOrtSessionOptionsDisableCPUEPFallback, "1"); // Disable fallback to the CPU EP. - so.AddConfigEntry(kDebugLayoutTransformation, "1"); - so.SetGraphOptimizationLevel(ORT_ENABLE_ALL); - // so.SetLogSeverityLevel(ORT_LOGGING_LEVEL_VERBOSE); - // ort_env->UpdateEnvWithCustomLogLevel(OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE); - onnxruntime::ProviderOptions options; - -#if defined(_WIN32) - options["backend_path"] = "QnnHtp.dll"; -#else - options["backend_path"] = "libQnnHtp.so"; -#endif - - so.AppendExecutionProvider("QNN", options); - - const ORTCHAR_T* ort_model_path = ORT_MODEL_FOLDER "slice_qdq_propagation_mult_consumers.onnx"; - Ort::Session session(*ort_env, ort_model_path, so); - - // image: 1,3,640,640 - std::vector input0_data(1 * 3 * 640 * 640); - - auto memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); - std::vector ort_inputs; - std::vector ort_input_names; - - // Add input0 - std::array inputs_shape{1, 3, 640, 640}; - ort_inputs.emplace_back(Ort::Value::CreateTensor( - memory_info, input0_data.data(), input0_data.size(), inputs_shape.data(), inputs_shape.size())); - ort_input_names.push_back("image"); - - // Run session and get outputs - std::array output_names{"output_0", "output_1", "output_2"}; - std::vector ort_outputs = session.Run(Ort::RunOptions{nullptr}, ort_input_names.data(), ort_inputs.data(), - ort_inputs.size(), output_names.data(), output_names.size()); - - // Check output shape. - Ort::Value& ort_output = ort_outputs[0]; - auto typeshape = ort_output.GetTensorTypeAndShapeInfo(); - std::vector output_shape = typeshape.GetShape(); - - EXPECT_THAT(output_shape, ::testing::ElementsAre(1, 8400, 4)); - const uint8_t* results = ort_output.GetTensorData(); - - for (size_t i = 0; i < typeshape.GetElementCount() && i < 10; i++) { - std::cout << i << ": " << results[i] << std::endl; - } -} #endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) #endif // !defined(ORT_MINIMAL_BUILD) From 865d3f397ddcc5cc2fef092821792123d0c999d4 Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Thu, 11 Jul 2024 00:47:59 -0700 Subject: [PATCH 07/17] Support propagating DQ when op with multiple consumers also generates a graph output --- .../qdq_transformer/qdq_propagation.cc | 50 ++--- .../optimizer/graph_transform_test_builder.cc | 4 +- .../test/optimizer/qdq_transformer_test.cc | 174 ++++++++++-------- 3 files changed, 124 insertions(+), 104 deletions(-) diff --git a/onnxruntime/core/optimizer/qdq_transformer/qdq_propagation.cc b/onnxruntime/core/optimizer/qdq_transformer/qdq_propagation.cc index 1aae9c900821e..8e0b6d0fc83b5 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/qdq_propagation.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/qdq_propagation.cc @@ -139,18 +139,18 @@ Status InsertQDQPairs(Graph& graph, const InlinedVector& inse ORT_RETURN_IF_NOT(graph.SetOpSchemaFromRegistryForNode(q_node), "Failed to set op schema for added Q node."); - // Remove original edges between src and dst nodes. - for (const auto& insertion_edge : insertion_edges) { - auto* dst_node = insertion_edge.GetMutableNodeAtEnd(graph, ExtendedGraphEdge::End::Destination); + if (src_node) { + // Remove original edges between src and dst nodes. + for (const auto& insertion_edge : insertion_edges) { + auto* dst_node = insertion_edge.GetMutableNodeAtEnd(graph, ExtendedGraphEdge::End::Destination); - if (src_node && dst_node) { - graph.RemoveEdge(src_node->Index(), dst_node->Index(), - insertion_edge.src->arg_idx, insertion_edge.dst->arg_idx); + if (dst_node) { + graph.RemoveEdge(src_node->Index(), dst_node->Index(), + insertion_edge.src->arg_idx, insertion_edge.dst->arg_idx); + } } - } - // Add edge from src to Q node. - if (src_node) { + // Add edge from src to Q node. src_node->MutableOutputDefs()[first_edge.src->arg_idx] = &pre_q_nodearg; graph.AddEdge(src_node->Index(), q_node.Index(), first_edge.src->arg_idx, 0); } @@ -161,10 +161,12 @@ Status InsertQDQPairs(Graph& graph, const InlinedVector& inse const std::string edge_suffix = edge_idx == 0 ? "" : std::to_string(edge_idx); auto& post_dq_nodearg = insertion_edge.HasGraphOutput() ? base_node_arg - : graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(base_name + "_post_dq" + edge_suffix), + : graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(MakeString(base_name, + "_post_dq", + edge_suffix)), nullptr); - auto& dq_node = graph.AddNode(graph.GenerateNodeName(base_name + "_dq" + edge_suffix), + auto& dq_node = graph.AddNode(graph.GenerateNodeName(MakeString(base_name, "_dq", edge_suffix)), QDQ::DQOpName, "Inserted by QDQPropagationTransformer", // inputs @@ -234,21 +236,19 @@ std::optional GetPreviousPropagationEdge(const Graph& graph, } InlinedVector GetNextEdges(const Graph& graph, const Node& node) { - // for now we can just consider the first output (index 0) + constexpr int node_output_index = 0; // for now we can just consider the first output (index 0) InlinedVector next_edges; + const auto output_edges = graph_utils::GraphEdge::GetNodeOutputEdges(node, static_cast(node_output_index)); - const auto output_edges = graph_utils::GraphEdge::GetNodeOutputEdges(node, 0); - if (output_edges.empty()) { - // maybe edge to output - auto edge = ExtendedGraphEdge::TryCreateFromNodeToOutput(graph, node, 0); - if (edge.has_value()) { - next_edges.push_back(edge.value()); - } - } else if (!graph.IsOutput(node.OutputDefs()[0])) { - // edges to next nodes - for (const auto& output_edge : output_edges) { - next_edges.push_back(ExtendedGraphEdge::CreateFromValidGraphEdge(output_edge)); - } + // edges to next nodes + for (const auto& output_edge : output_edges) { + next_edges.push_back(ExtendedGraphEdge::CreateFromValidGraphEdge(output_edge)); + } + + // maybe edge to graph output + auto edge_to_output = ExtendedGraphEdge::TryCreateFromNodeToOutput(graph, node, node_output_index); + if (edge_to_output.has_value()) { + next_edges.push_back(edge_to_output.value()); } return next_edges; @@ -322,7 +322,7 @@ Status PropagateDQForward(Graph& graph, gsl::span node_indices, ? dq_node.MutableInputDefs()[QDQ::InputIndex::ZERO_POINT_ID] : nullptr; - InlinedVector edges_after_dq = GetNextEdges(graph, dq_node); + const InlinedVector edges_after_dq = GetNextEdges(graph, dq_node); if (edges_after_dq.size() != 1) { continue; } diff --git a/onnxruntime/test/optimizer/graph_transform_test_builder.cc b/onnxruntime/test/optimizer/graph_transform_test_builder.cc index 73c8b3f119103..b8d6aeab996d8 100644 --- a/onnxruntime/test/optimizer/graph_transform_test_builder.cc +++ b/onnxruntime/test/optimizer/graph_transform_test_builder.cc @@ -246,14 +246,14 @@ Status TestGraphTransformer(const std::function& ORT_RETURN_IF_ERROR(pre_graph_checker(graph)); } #if SAVE_TEST_GRAPH - ORT_RETURN_IF_ERROR(Model::Save(model, "model_original.onnx")); + ORT_RETURN_IF_ERROR(Model::Save(model, ToPathString("model_original.onnx"))); #endif ORT_RETURN_IF_ERROR(graph_transformation_mgr.ApplyTransformers(graph, level, logger)); if (post_graph_checker) { ORT_RETURN_IF_ERROR(post_graph_checker(graph)); } #if SAVE_TEST_GRAPH - ORT_RETURN_IF_ERROR(Model::Save(model, "model_optimized.onnx")); + ORT_RETURN_IF_ERROR(Model::Save(model, ToPathString("model_optimized.onnx"))); #endif }; diff --git a/onnxruntime/test/optimizer/qdq_transformer_test.cc b/onnxruntime/test/optimizer/qdq_transformer_test.cc index 1eb99f5f4e153..a58b12e1f975e 100644 --- a/onnxruntime/test/optimizer/qdq_transformer_test.cc +++ b/onnxruntime/test/optimizer/qdq_transformer_test.cc @@ -3372,17 +3372,21 @@ TEST(QDQTransformerTests, QDQPropagation_DQ_Q) { // Test propagating a DQ forward through a chain of Slice and Transpose operators that have multiple consumers. // original model: -// in0 -> DQ -> Slice --+--> Add -> out0 +// in0 -> DQ -> Slice --+--> slice_out // | -// +--> TP --+--> Pow -> out1 -// | | -// | +--> Pow -> out2 +// +--> Add -> out0 // | -// +--> TP --+--> Pow -> out3 -// | -// +--> Pow -> out4 +// +--> Transpose --+--> Pow -> out1 +// | | +// | +--> Pow -> out2 +// | +// +--> Transpose --+--> Pow -> out3 +// | +// +--> Pow -> out4 // expected model: -// in0 -> DQ -> Slice -> Q --+--> DQ -> Add -> out0 +// in0 -> DQ -> Slice -> Q --+--> DQ -> slice_out +// | +// +--> DQ -> Add -> out0 // | // +--> DQ -> TP -> Q --+--> DQ -> Pow -> out1 // | | @@ -3392,78 +3396,94 @@ TEST(QDQTransformerTests, QDQPropagation_DQ_Q) { // | // +--> DQ -> Pow -> out4 TEST(QDQTransformerTests, QDQPropagation_DQForward_SliceMultipleConsumers) { - auto build_test_case = [&](ModelTestBuilder& builder) { - std::vector input0_shape = {1, 2, 2, 2}; - std::vector input1_shape = {1, 1, 1, 1}; - auto* input0_arg = builder.MakeInput(input0_shape, - std::numeric_limits::min(), - std::numeric_limits::max()); - auto* input1_arg = builder.MakeInput(input1_shape, {0.0f}); - auto* output0_arg = builder.MakeOutput(); - auto* output1_arg = builder.MakeOutput(); - auto* output2_arg = builder.MakeOutput(); - auto* output3_arg = builder.MakeOutput(); - auto* output4_arg = builder.MakeOutput(); - - // DQ - constexpr float qdq_scale = 1.0f; - constexpr uint8_t qdq_zero_point = 128; - auto* dq_output = builder.MakeIntermediate(); - builder.AddDequantizeLinearNode(input0_arg, qdq_scale, qdq_zero_point, dq_output); - - // Slice - auto* slice_output = builder.MakeIntermediate(); - auto* slice_starts = builder.Make1DInitializer(std::vector{0, 0, 0, 0}); - auto* slice_ends = builder.Make1DInitializer(std::vector{1, 1, 1, 1}); - builder.AddNode("Slice", {dq_output, slice_starts, slice_ends}, {slice_output}); - - // Add - builder.AddNode("Add", {slice_output, input1_arg}, {output0_arg}); - - // Transpose - auto* transpose0_output = builder.MakeIntermediate(); - builder.AddNode("Transpose", {slice_output}, {transpose0_output}); - - // Transpose - auto* transpose1_output = builder.MakeIntermediate(); - builder.AddNode("Transpose", {slice_output}, {transpose1_output}); - - // Pows - auto* pow_exp = builder.MakeScalarInitializer(2.0f); - builder.AddNode("Pow", {transpose0_output, pow_exp}, {output1_arg}); - builder.AddNode("Pow", {transpose0_output, pow_exp}, {output2_arg}); - builder.AddNode("Pow", {transpose1_output, pow_exp}, {output3_arg}); - builder.AddNode("Pow", {transpose1_output, pow_exp}, {output4_arg}); - }; + auto run_test_case = [&](bool slice_has_graph_output) { + auto build_test_case = [&](ModelTestBuilder& builder) { + std::vector input0_shape = {1, 2, 2, 2}; + std::vector input1_shape = {1, 1, 1, 1}; + auto* input0_arg = builder.MakeInput(input0_shape, + std::numeric_limits::min(), + std::numeric_limits::max()); + auto* input1_arg = builder.MakeInput(input1_shape, {0.0f}); + auto* output0_arg = builder.MakeOutput(); + auto* output1_arg = builder.MakeOutput(); + auto* output2_arg = builder.MakeOutput(); + auto* output3_arg = builder.MakeOutput(); + auto* output4_arg = builder.MakeOutput(); + + // DQ + constexpr float qdq_scale = 1.0f; + constexpr uint8_t qdq_zero_point = 128; + auto* dq_output = builder.MakeIntermediate(); + builder.AddDequantizeLinearNode(input0_arg, qdq_scale, qdq_zero_point, dq_output); + + // Slice + auto* slice_output = slice_has_graph_output ? builder.MakeOutput() : builder.MakeIntermediate(); + auto* slice_starts = builder.Make1DInitializer(std::vector{0, 0, 0, 0}); + auto* slice_ends = builder.Make1DInitializer(std::vector{1, 1, 1, 1}); + builder.AddNode("Slice", {dq_output, slice_starts, slice_ends}, {slice_output}); + + // Add + builder.AddNode("Add", {slice_output, input1_arg}, {output0_arg}); + + // Transpose + auto* transpose0_output = builder.MakeIntermediate(); + builder.AddNode("Transpose", {slice_output}, {transpose0_output}); + + // Transpose + auto* transpose1_output = builder.MakeIntermediate(); + builder.AddNode("Transpose", {slice_output}, {transpose1_output}); + + // Pows + auto* pow_exp = builder.MakeScalarInitializer(2.0f); + builder.AddNode("Pow", {transpose0_output, pow_exp}, {output1_arg}); + builder.AddNode("Pow", {transpose0_output, pow_exp}, {output2_arg}); + builder.AddNode("Pow", {transpose1_output, pow_exp}, {output3_arg}); + builder.AddNode("Pow", {transpose1_output, pow_exp}, {output4_arg}); + }; - auto check_graph = [&](InferenceSessionWrapper& session) { - const QDQOpKeys qdq_keys = GetQDQOpKeys(false); - std::vector expected_op_types_in_order{ - qdq_keys.dequantize_linear, - "Slice", - qdq_keys.quantize_linear, qdq_keys.dequantize_linear, - "Add", - qdq_keys.dequantize_linear, - "Transpose", - qdq_keys.quantize_linear, qdq_keys.dequantize_linear, - "Pow", - qdq_keys.dequantize_linear, - "Pow", - qdq_keys.dequantize_linear, - "Transpose", - qdq_keys.quantize_linear, qdq_keys.dequantize_linear, - "Pow", - qdq_keys.dequantize_linear, - "Pow"}; - const auto op_types_in_order = GetNodeOpTypesInTopologicalOrder(session.GetGraph(), true); - EXPECT_EQ(op_types_in_order, expected_op_types_in_order); + auto check_graph = [&](InferenceSessionWrapper& session) { + const QDQOpKeys qdq_keys = GetQDQOpKeys(false); + std::vector expected_op_types_in_order; + expected_op_types_in_order.reserve(20); + expected_op_types_in_order.insert(expected_op_types_in_order.end(), + {qdq_keys.dequantize_linear, + "Slice", + qdq_keys.quantize_linear}); + + if (slice_has_graph_output) { + // Should have a DQ before the graph output generated by the Slice. + expected_op_types_in_order.push_back(qdq_keys.dequantize_linear); + } + + expected_op_types_in_order.insert(expected_op_types_in_order.end(), + {qdq_keys.dequantize_linear, + "Add", + qdq_keys.dequantize_linear, + "Transpose", + qdq_keys.quantize_linear, qdq_keys.dequantize_linear, + "Pow", + qdq_keys.dequantize_linear, + "Pow", + qdq_keys.dequantize_linear, + "Transpose", + qdq_keys.quantize_linear, qdq_keys.dequantize_linear, + "Pow", + qdq_keys.dequantize_linear, + "Pow"}); + + const auto op_types_in_order = GetNodeOpTypesInTopologicalOrder(session.GetGraph(), true); + EXPECT_EQ(op_types_in_order, expected_op_types_in_order); + }; + + TransformerTester(build_test_case, + check_graph, + TransformerLevel::Default, + TransformerLevel::Level1, + 18, 0.0, 0.0, std::make_unique()); }; - TransformerTester(build_test_case, - check_graph, - TransformerLevel::Default, - TransformerLevel::Level1, - 18, 0.0, 0.0, std::make_unique()); + run_test_case(/*slice_has_graph_output*/ false); + run_test_case(/*slice_has_graph_output*/ true); } TEST(QDQTransformerTests, QDQ_Selector_Test) { From c72e374f77f0c88d75a5764e22c51690d667ddbf Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Thu, 11 Jul 2024 01:52:36 -0700 Subject: [PATCH 08/17] Don't propagate DQ forward if any edge after the data-movement op ends in a Q --- .../qdq_transformer/qdq_propagation.cc | 24 +++++++++---------- 1 file changed, 11 insertions(+), 13 deletions(-) diff --git a/onnxruntime/core/optimizer/qdq_transformer/qdq_propagation.cc b/onnxruntime/core/optimizer/qdq_transformer/qdq_propagation.cc index 8e0b6d0fc83b5..1e603d5f39224 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/qdq_propagation.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/qdq_propagation.cc @@ -3,7 +3,6 @@ #include "core/optimizer/qdq_transformer/qdq_propagation.h" -#include #include #include #include @@ -267,20 +266,19 @@ InlinedVector GetNextPropagationEdges(const Graph& graph, return {}; } - auto all_next_edges = GetNextEdges(graph, *dst_node); - InlinedVector next_prop_edges; - next_prop_edges.reserve(all_next_edges.size()); + auto next_edges = GetNextEdges(graph, *dst_node); + bool any_edge_to_q = false; - // Filter out edges that end in Q nodes. - // There is no need to insert a Q node in an edge that already ends in a Q node. - std::copy_if(all_next_edges.begin(), all_next_edges.end(), std::back_inserter(next_prop_edges), - [&graph](const ExtendedGraphEdge& e) -> bool { - const auto* dst_node = e.GetNodeAtEnd(graph, ExtendedGraphEdge::End::Destination); - const bool is_q_node = dst_node && QDQ::MatchQNode(*dst_node); - return !is_q_node; - }); + // Check if any edge ends a Q node. If so, we don't propagate. + for (const auto& next_edge : next_edges) { + const auto* edge_dst_node = next_edge.GetNodeAtEnd(graph, ExtendedGraphEdge::End::Destination); + if (edge_dst_node && QDQ::MatchQNode(*edge_dst_node)) { + any_edge_to_q = true; + break; + } + } - return next_prop_edges; + return any_edge_to_q ? InlinedVector{} : next_edges; } class GraphConstantInitializerGetter { From 1853f28a44ea571aaebf45220876141867559434 Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Thu, 11 Jul 2024 02:37:43 -0700 Subject: [PATCH 09/17] Move check for edge that ends in Q out of function --- .../qdq_transformer/qdq_propagation.cc | 31 ++++++++++--------- 1 file changed, 17 insertions(+), 14 deletions(-) diff --git a/onnxruntime/core/optimizer/qdq_transformer/qdq_propagation.cc b/onnxruntime/core/optimizer/qdq_transformer/qdq_propagation.cc index 1e603d5f39224..5b302087da79c 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/qdq_propagation.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/qdq_propagation.cc @@ -266,19 +266,7 @@ InlinedVector GetNextPropagationEdges(const Graph& graph, return {}; } - auto next_edges = GetNextEdges(graph, *dst_node); - bool any_edge_to_q = false; - - // Check if any edge ends a Q node. If so, we don't propagate. - for (const auto& next_edge : next_edges) { - const auto* edge_dst_node = next_edge.GetNodeAtEnd(graph, ExtendedGraphEdge::End::Destination); - if (edge_dst_node && QDQ::MatchQNode(*edge_dst_node)) { - any_edge_to_q = true; - break; - } - } - - return any_edge_to_q ? InlinedVector{} : next_edges; + return GetNextEdges(graph, *dst_node); } class GraphConstantInitializerGetter { @@ -325,6 +313,21 @@ Status PropagateDQForward(Graph& graph, gsl::span node_indices, continue; } + // Utility function to check if any edge out of a node (e.g., Transpose) ends a Q node. + // If so, we don't propagate. + auto any_edge_ends_in_q = [](Graph& graph, const InlinedVector& edges) -> bool { + bool any_edge_to_q = false; + + for (const auto& edge : edges) { + const auto* edge_dst_node = edge.GetNodeAtEnd(graph, ExtendedGraphEdge::End::Destination); + if (edge_dst_node && QDQ::MatchQNode(*edge_dst_node)) { + any_edge_to_q = true; + break; + } + } + return any_edge_to_q; + }; + std::queue> edge_groups; edge_groups.push(GetNextPropagationEdges(graph, edges_after_dq[0])); @@ -332,7 +335,7 @@ Status PropagateDQForward(Graph& graph, gsl::span node_indices, const InlinedVector edges = std::move(edge_groups.front()); edge_groups.pop(); - if (edges.empty()) { + if (edges.empty() || any_edge_ends_in_q(graph, edges)) { continue; } From 1490f5c5c8a4a1ecc84c26fa101bb4d47f3d2d8c Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Thu, 11 Jul 2024 09:24:15 -0700 Subject: [PATCH 10/17] Add comment for BFS traversal of edge groups --- .../qdq_transformer/qdq_propagation.cc | 27 +++++++++++-------- 1 file changed, 16 insertions(+), 11 deletions(-) diff --git a/onnxruntime/core/optimizer/qdq_transformer/qdq_propagation.cc b/onnxruntime/core/optimizer/qdq_transformer/qdq_propagation.cc index 5b302087da79c..c728a6d0c048a 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/qdq_propagation.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/qdq_propagation.cc @@ -313,36 +313,41 @@ Status PropagateDQForward(Graph& graph, gsl::span node_indices, continue; } - // Utility function to check if any edge out of a node (e.g., Transpose) ends a Q node. - // If so, we don't propagate. + // Utility function to check if any edge out of a node (e.g., Transpose) ends in a Q node. auto any_edge_ends_in_q = [](Graph& graph, const InlinedVector& edges) -> bool { - bool any_edge_to_q = false; - for (const auto& edge : edges) { const auto* edge_dst_node = edge.GetNodeAtEnd(graph, ExtendedGraphEdge::End::Destination); if (edge_dst_node && QDQ::MatchQNode(*edge_dst_node)) { - any_edge_to_q = true; - break; + return true; } } - return any_edge_to_q; + return false; }; + // Propagate DQ forward in a BFS traversal of "edge groups". A single edge group consists of one or more edges + // that all begin at a unique source node and end at one or more destination nodes. Ex: The subgraph below shows + // an edge group (containing 3 edges) that begins at a Transpose, ends at two destination nodes, and produces a + // graph output. + // DQ -> Transpose --+--> Sigmoid -> ... + // | + // +--> Slice -> ... + // | + // +--> graph_output std::queue> edge_groups; edge_groups.push(GetNextPropagationEdges(graph, edges_after_dq[0])); while (!edge_groups.empty()) { - const InlinedVector edges = std::move(edge_groups.front()); + const InlinedVector curr_edge_group = std::move(edge_groups.front()); edge_groups.pop(); - if (edges.empty() || any_edge_ends_in_q(graph, edges)) { + if (curr_edge_group.empty() || any_edge_ends_in_q(graph, curr_edge_group)) { continue; } - ORT_RETURN_IF_ERROR(InsertQDQPairs(graph, edges, dq_scale, dq_zero_point, dq_node.Domain(), logger)); + ORT_RETURN_IF_ERROR(InsertQDQPairs(graph, curr_edge_group, dq_scale, dq_zero_point, dq_node.Domain(), logger)); modified = true; - for (const auto& edge : edges) { + for (const auto& edge : curr_edge_group) { edge_groups.push(GetNextPropagationEdges(graph, edge)); } } From aa0733d20988ab6580e4df5324de2dd55a2719b9 Mon Sep 17 00:00:00 2001 From: Adrian Lizarraga Date: Thu, 11 Jul 2024 16:58:08 -0700 Subject: [PATCH 11/17] Update onnxruntime/core/optimizer/qdq_transformer/qdq_propagation.cc Co-authored-by: Edward Chen <18449977+edgchen1@users.noreply.github.com> --- onnxruntime/core/optimizer/qdq_transformer/qdq_propagation.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/core/optimizer/qdq_transformer/qdq_propagation.cc b/onnxruntime/core/optimizer/qdq_transformer/qdq_propagation.cc index c728a6d0c048a..cd8bb0c0af3e7 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/qdq_propagation.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/qdq_propagation.cc @@ -76,7 +76,7 @@ void LogQDQInsertion(const logging::Logger& logger, logging::Severity severity, << (i == num_edges - 1 ? "" : ","); } - LOGS(logger, VERBOSE) << "Inserting Q/DQ pair between " + LOGS(logger, severity) << "Inserting Q/DQ pair between " << (src_node ? MakeString("src node (\"", src_node->Name(), "\", index: ", src_node->Index(), ")") : "input") << " and " << dst_labels.str() From 871779660c37147c28934d79bc1a8dd85b389af0 Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Mon, 22 Jul 2024 15:00:27 -0700 Subject: [PATCH 12/17] Copy node attributes when propagating Q or DQ ops. --- .../qdq_transformer/qdq_propagation.cc | 90 ++++++++++++++----- .../test/optimizer/qdq_transformer_test.cc | 51 +++++++++++ 2 files changed, 117 insertions(+), 24 deletions(-) diff --git a/onnxruntime/core/optimizer/qdq_transformer/qdq_propagation.cc b/onnxruntime/core/optimizer/qdq_transformer/qdq_propagation.cc index cd8bb0c0af3e7..3efd96f4e13a9 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/qdq_propagation.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/qdq_propagation.cc @@ -3,6 +3,7 @@ #include "core/optimizer/qdq_transformer/qdq_propagation.h" +#include #include #include #include @@ -20,18 +21,50 @@ namespace onnxruntime { namespace { bool CanNodePropagate(const Node& node) { return graph_utils::IsSupportedOptypeVersionAndDomain(node, "MaxPool", {12}) || - graph_utils::IsSupportedOptypeVersionAndDomain(node, "Reshape", {5, 13, 14, 19}) || - graph_utils::IsSupportedOptypeVersionAndDomain(node, "Transpose", {1, 13}) || - graph_utils::IsSupportedOptypeVersionAndDomain(node, "Squeeze", {1, 11, 13}) || - graph_utils::IsSupportedOptypeVersionAndDomain(node, "Unsqueeze", {1, 11, 13}) || + graph_utils::IsSupportedOptypeVersionAndDomain(node, "Reshape", {5, 13, 14, 19, 21}) || + graph_utils::IsSupportedOptypeVersionAndDomain(node, "Transpose", {1, 13, 21}) || + graph_utils::IsSupportedOptypeVersionAndDomain(node, "Squeeze", {1, 11, 13, 21}) || + graph_utils::IsSupportedOptypeVersionAndDomain(node, "Unsqueeze", {1, 11, 13, 21}) || graph_utils::IsSupportedOptypeVersionAndDomain(node, "Slice", {1, 10, 11, 13}); } +// Makes matching attributes for new QuantizeLinear nodes from an existing DequantizeLinear node. +NodeAttributes MakeQAttrsFromDQ(const Node& dq_node) { + assert(dq_node.SinceVersion() <= 21); // Checked by previous call to QDQ::MatchDQNode(). + // In opset <= 21, all DQ attributes (i.e., axis and block_size) are also Q attributes. + // So, set a copy of the DQ attributes. + return dq_node.GetAttributes(); +} + +// Makes matching attributes for new DequantizeLinear nodes from an existing QuantizeLinear node. +NodeAttributes MakeDQAttrsFromQ(const Node& q_node) { + assert(q_node.SinceVersion() <= 21); // Checked by previous call to QDQ::MatchQNode(). + const NodeAttributes& q_attrs = q_node.GetAttributes(); + if (q_attrs.empty()) { + return {}; + } + + // In opset <= 21, only the "axis" and "block_size" attributes for Q are also DQ attributes. + NodeAttributes dq_attrs; + + auto axis_attr_it = q_attrs.find("axis"); + if (axis_attr_it != q_attrs.end()) { + dq_attrs.insert({axis_attr_it->first, axis_attr_it->second}); + } + + auto block_size_attr_it = q_attrs.find("block_size"); + if (block_size_attr_it != q_attrs.end()) { + dq_attrs.insert({block_size_attr_it->first, block_size_attr_it->second}); + } + + return dq_attrs; +} + // Validates edges into which to insert Q -> DQ ops. // - Must have at least one edge. // - All edges with a source node must originate from the same source node's output. // - All edges must be attached to either a source node or a destination node. -Status ValidateQDQInsertionEdges(Graph& graph, const InlinedVector& insertion_edges) { +Status ValidateQDQInsertionEdges(Graph& graph, gsl::span insertion_edges) { ORT_RETURN_IF(insertion_edges.empty(), "Expected at least one edge into which to insert QDQ pair."); const auto& src_info = insertion_edges[0].GetNodeInfoAtEnd(ExtendedGraphEdge::End::Source); @@ -55,9 +88,10 @@ Status ValidateQDQInsertionEdges(Graph& graph, const InlinedVector& edges) { - if (!logger.OutputIsEnabled(severity, logging::DataType::SYSTEM)) { +void LogQDQInsertion(const logging::Logger& logger, logging::Severity severity, const CodeLocation& code_location, + const Graph& graph, gsl::span edges) { + auto logging_data_type = logging::DataType::SYSTEM; + if (!logger.OutputIsEnabled(severity, logging_data_type)) { return; } @@ -76,11 +110,12 @@ void LogQDQInsertion(const logging::Logger& logger, logging::Severity severity, << (i == num_edges - 1 ? "" : ","); } - LOGS(logger, severity) << "Inserting Q/DQ pair between " - << (src_node ? MakeString("src node (\"", src_node->Name(), "\", index: ", src_node->Index(), ")") - : "input") - << " and " << dst_labels.str() - << " at NodeArg \"" << node_arg_name << "\"."; + logging::Capture(logger, severity, logging::Category::onnxruntime, logging_data_type, code_location).Stream() + << "Inserted Q/DQ pair between " + << (src_node ? MakeString("src node (\"", src_node->Name(), "\", index: ", src_node->Index(), ")") + : "input") + << " and " << dst_labels.str() + << " at NodeArg \"" << node_arg_name << "\"."; } // convert this: src_node --+--> dst_node_0 @@ -98,9 +133,10 @@ void LogQDQInsertion(const logging::Logger& logger, logging::Severity severity, // 1. All insertion edges have the same source node and the same source node output index. // 2. Insertion_edges are valid: node indices refer to valid nodes, and arg names refer to valid NodeArgs in the graph. // 3. scale_initializer_nodearg and zp_initializer_nodearg_ptr (if not null) are constant initializers -Status InsertQDQPairs(Graph& graph, const InlinedVector& insertion_edges, +Status InsertQDQPairs(Graph& graph, gsl::span insertion_edges, NodeArg& scale_initializer_nodearg, NodeArg* zp_initializer_nodearg_ptr, - const std::string& qdq_domain, const logging::Logger& logger) { + const std::string& qdq_domain, const NodeAttributes& q_attrs, const NodeAttributes& dq_attrs, + const logging::Logger& logger) { ORT_RETURN_IF_ERROR(ValidateQDQInsertionEdges(graph, insertion_edges)); const ExtendedGraphEdge& first_edge = insertion_edges[0]; // ValidateQDQInsertionEdges() guarantees at least one edge @@ -109,7 +145,7 @@ Status InsertQDQPairs(Graph& graph, const InlinedVector& inse const auto& base_name = first_edge.arg_name; auto& base_node_arg = *graph.GetNodeArg(base_name); - LogQDQInsertion(logger, logging::Severity::kVERBOSE, graph, insertion_edges); + LogQDQInsertion(logger, logging::Severity::kVERBOSE, ORT_WHERE, graph, insertion_edges); auto make_q_or_dq_inputs = [](NodeArg& data, NodeArg& scale, NodeArg* zero_point) { return zero_point ? InlinedVector{&data, &scale, zero_point} @@ -133,7 +169,7 @@ Status InsertQDQPairs(Graph& graph, const InlinedVector& inse zp_initializer_nodearg_ptr), // outputs {&q_to_dq_nodearg}, - nullptr, // attributes + &q_attrs, // attributes qdq_domain); ORT_RETURN_IF_NOT(graph.SetOpSchemaFromRegistryForNode(q_node), "Failed to set op schema for added Q node."); @@ -173,7 +209,7 @@ Status InsertQDQPairs(Graph& graph, const InlinedVector& inse zp_initializer_nodearg_ptr), // outputs {&post_dq_nodearg}, - nullptr, // attributes + &dq_attrs, // attributes qdq_domain); ORT_RETURN_IF_NOT(graph.SetOpSchemaFromRegistryForNode(dq_node), "Failed to set op schema for added DQ node."); @@ -324,10 +360,10 @@ Status PropagateDQForward(Graph& graph, gsl::span node_indices, return false; }; - // Propagate DQ forward in a BFS traversal of "edge groups". A single edge group consists of one or more edges - // that all begin at a unique source node and end at one or more destination nodes. Ex: The subgraph below shows - // an edge group (containing 3 edges) that begins at a Transpose, ends at two destination nodes, and produces a - // graph output. + // Propagate DQ forward in a BFS traversal of "edge groups". An "edge group" consists of one or more edges + // that all begin at the same source node's output slot and end at a graph output or a destination node. + // Ex: The subgraph below shows an edge group (containing 3 edges) that begins at a + // Transpose, ends at two destination nodes, and produces a graph output. // DQ -> Transpose --+--> Sigmoid -> ... // | // +--> Slice -> ... @@ -340,11 +376,17 @@ Status PropagateDQForward(Graph& graph, gsl::span node_indices, const InlinedVector curr_edge_group = std::move(edge_groups.front()); edge_groups.pop(); + // Continue loop if edge group is empty. Also, to keep things simple, we do not yet handle edge groups in which + // one of the destination nodes is already a QuantizeLinear node. Ex: + // DQ -> Transpose --+--> QuantizeLinear -> ... + // | + // +--> Slice -> ... if (curr_edge_group.empty() || any_edge_ends_in_q(graph, curr_edge_group)) { continue; } - ORT_RETURN_IF_ERROR(InsertQDQPairs(graph, curr_edge_group, dq_scale, dq_zero_point, dq_node.Domain(), logger)); + ORT_RETURN_IF_ERROR(InsertQDQPairs(graph, curr_edge_group, dq_scale, dq_zero_point, dq_node.Domain(), + MakeQAttrsFromDQ(dq_node), dq_node.GetAttributes(), logger)); modified = true; for (const auto& edge : curr_edge_group) { @@ -398,7 +440,7 @@ Status PropagateQBackward(Graph& graph, gsl::span node_indices, } ORT_RETURN_IF_ERROR(InsertQDQPairs(graph, InlinedVector{*curr_edge}, q_scale, q_zero_point, - q_node.Domain(), logger)); + q_node.Domain(), q_node.GetAttributes(), MakeDQAttrsFromQ(q_node), logger)); modified = true; } } diff --git a/onnxruntime/test/optimizer/qdq_transformer_test.cc b/onnxruntime/test/optimizer/qdq_transformer_test.cc index 7d7d930c43f94..56d2bc15501de 100644 --- a/onnxruntime/test/optimizer/qdq_transformer_test.cc +++ b/onnxruntime/test/optimizer/qdq_transformer_test.cc @@ -3085,6 +3085,57 @@ TEST(QDQTransformerTests, QDQPropagation_QBackward) { #endif } +// Test backwards propagation of a QuantizeLinear node that uses the "output_dtype" attribute +// to set the quantization type (i.e., does not have an explicit zero-point input). This tests +// the copying of attributes for QDQ propagation. +TEST(QDQTransformerTests, QDQPropagation_QBackward_NoZP_OutputDtypeAttribute) { + auto test_case = [&](ONNX_NAMESPACE::TensorProto_DataType q_output_type) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* input_arg = builder.MakeInput({1, 2, 2}, {-2.0f, 0.0f, 1.0f, 2.0f}); + auto* output_arg = builder.MakeOutput(); + + // add Add + auto* const_1_input = builder.MakeScalarInitializer(1.0f); + auto* add_output = builder.MakeIntermediate(); + builder.AddNode("Add", {input_arg, const_1_input}, {add_output}); + + // add Transpose + auto* transpose_output = builder.MakeIntermediate(); + builder.AddNode("Transpose", {add_output}, {transpose_output}); + + // add Q with a "output_dtype" attribute. Omit the zero-point input (defaults to 0). + constexpr float qdq_scale = 1.0f; + Node& q_node = builder.AddQuantizeLinearNode(transpose_output, qdq_scale, output_arg); + q_node.AddAttribute("output_dtype", static_cast(q_output_type)); + }; + + auto check_graph = [&](InferenceSessionWrapper& session) { + const QDQOpKeys qdq_keys = GetQDQOpKeys(false); + std::vector expected_op_types_in_order = { + "Add", + qdq_keys.quantize_linear, + qdq_keys.dequantize_linear, + "Transpose", + qdq_keys.quantize_linear, + }; + + const auto op_types_in_order = GetNodeOpTypesInTopologicalOrder(session.GetGraph(), true); + EXPECT_EQ(op_types_in_order, expected_op_types_in_order); + }; + + TransformerTester(build_test_case, + check_graph, + TransformerLevel::Default, + TransformerLevel::Level1, + 21); // Opset >= 21 supports the "output_dtype" attribute + }; + + test_case(ONNX_NAMESPACE::TensorProto_DataType_UINT8); + test_case(ONNX_NAMESPACE::TensorProto_DataType_INT8); + test_case(ONNX_NAMESPACE::TensorProto_DataType_UINT16); + test_case(ONNX_NAMESPACE::TensorProto_DataType_INT16); +} + TEST(QDQTransformerTests, QDQPropagation_DQForward) { auto test_case = [&](const std::vector& input_shape, size_t maxpool_dim, From f4390e79f441375c39af2ba695e110118a89408f Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Mon, 22 Jul 2024 16:16:42 -0700 Subject: [PATCH 13/17] Revise comments to explain that edge groups share the same NodeArg --- .../qdq_transformer/qdq_propagation.cc | 26 +++++++++---------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/onnxruntime/core/optimizer/qdq_transformer/qdq_propagation.cc b/onnxruntime/core/optimizer/qdq_transformer/qdq_propagation.cc index 3efd96f4e13a9..093e02d24205c 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/qdq_propagation.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/qdq_propagation.cc @@ -129,10 +129,10 @@ void LogQDQInsertion(const logging::Logger& logger, logging::Severity severity, // +--> DQ -> dst_node_1 // | ... // +--> DQ -> dst_node_n -// assumptions: -// 1. All insertion edges have the same source node and the same source node output index. -// 2. Insertion_edges are valid: node indices refer to valid nodes, and arg names refer to valid NodeArgs in the graph. -// 3. scale_initializer_nodearg and zp_initializer_nodearg_ptr (if not null) are constant initializers +// Checks that all insertion edges share the same NodeArg. That is, the edges have the same source node and the +// same source node output index. This function returns an error status if edges are invalid. +// +// Assumes that scale_initializer_nodearg and zp_initializer_nodearg_ptr (if not null) are constant initializers. Status InsertQDQPairs(Graph& graph, gsl::span insertion_edges, NodeArg& scale_initializer_nodearg, NodeArg* zp_initializer_nodearg_ptr, const std::string& qdq_domain, const NodeAttributes& q_attrs, const NodeAttributes& dq_attrs, @@ -360,23 +360,23 @@ Status PropagateDQForward(Graph& graph, gsl::span node_indices, return false; }; - // Propagate DQ forward in a BFS traversal of "edge groups". An "edge group" consists of one or more edges + // Propagate DQ forward in a BFS traversal of NodeArg edges. A NodeArg "edge group" consists of one or more edges // that all begin at the same source node's output slot and end at a graph output or a destination node. - // Ex: The subgraph below shows an edge group (containing 3 edges) that begins at a + // Ex: The subgraph below shows a NodeArg edge group (containing 3 edges) that begins at a // Transpose, ends at two destination nodes, and produces a graph output. // DQ -> Transpose --+--> Sigmoid -> ... // | // +--> Slice -> ... // | // +--> graph_output - std::queue> edge_groups; - edge_groups.push(GetNextPropagationEdges(graph, edges_after_dq[0])); + std::queue> node_arg_edges; + node_arg_edges.push(GetNextPropagationEdges(graph, edges_after_dq[0])); - while (!edge_groups.empty()) { - const InlinedVector curr_edge_group = std::move(edge_groups.front()); - edge_groups.pop(); + while (!node_arg_edges.empty()) { + const InlinedVector curr_edge_group = std::move(node_arg_edges.front()); + node_arg_edges.pop(); - // Continue loop if edge group is empty. Also, to keep things simple, we do not yet handle edge groups in which + // Skip if edge group is empty. Also, to keep things simple, we do not yet handle edge groups in which // one of the destination nodes is already a QuantizeLinear node. Ex: // DQ -> Transpose --+--> QuantizeLinear -> ... // | @@ -390,7 +390,7 @@ Status PropagateDQForward(Graph& graph, gsl::span node_indices, modified = true; for (const auto& edge : curr_edge_group) { - edge_groups.push(GetNextPropagationEdges(graph, edge)); + node_arg_edges.push(GetNextPropagationEdges(graph, edge)); } } } From 51ace3aee17f3fea9a7cac3e13874edb86e2e851 Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Mon, 22 Jul 2024 16:49:09 -0700 Subject: [PATCH 14/17] Update comment --- .../qdq_transformer/qdq_propagation.cc | 25 ++++++++++--------- 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/onnxruntime/core/optimizer/qdq_transformer/qdq_propagation.cc b/onnxruntime/core/optimizer/qdq_transformer/qdq_propagation.cc index 093e02d24205c..dbfa2e86830c8 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/qdq_propagation.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/qdq_propagation.cc @@ -118,19 +118,20 @@ void LogQDQInsertion(const logging::Logger& logger, logging::Severity severity, << " at NodeArg \"" << node_arg_name << "\"."; } -// convert this: src_node --+--> dst_node_0 -// | -// +--> dst_node_1 -// | ... -// +--> dst_node_n +// convert this: src_node (or graph input) --+--> dst_node_0 (or graph output) +// | +// +--> dst_node_1 +// | ... +// +--> dst_node_n // -// to this: src_node -> Q --+--> DQ -> dst_node_0 -// | -// +--> DQ -> dst_node_1 -// | ... -// +--> DQ -> dst_node_n -// Checks that all insertion edges share the same NodeArg. That is, the edges have the same source node and the -// same source node output index. This function returns an error status if edges are invalid. +// to this: src_node (or graph input) -> Q --+--> DQ -> dst_node_0 (or graph output) +// | +// +--> DQ -> dst_node_1 +// | ... +// +--> DQ -> dst_node_n +// Checks that all insertion edges share the same NodeArg. That is, the edges originate from the same source node +// output. If there is no src_node, then all edges should come from the same graph input. +// This function returns an error status if edges are invalid. // // Assumes that scale_initializer_nodearg and zp_initializer_nodearg_ptr (if not null) are constant initializers. Status InsertQDQPairs(Graph& graph, gsl::span insertion_edges, From 8fbe4e35f4b0ec8264f0be8be7e079ca9778cab4 Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Mon, 22 Jul 2024 17:29:16 -0700 Subject: [PATCH 15/17] Fix GitHub cpplint warnings --- onnxruntime/core/optimizer/qdq_transformer/qdq_propagation.cc | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/onnxruntime/core/optimizer/qdq_transformer/qdq_propagation.cc b/onnxruntime/core/optimizer/qdq_transformer/qdq_propagation.cc index dbfa2e86830c8..d676bd8170eb3 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/qdq_propagation.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/qdq_propagation.cc @@ -7,6 +7,7 @@ #include #include #include +#include #include "core/common/inlined_containers_fwd.h" #include "core/graph/extended_graph_edge.h" @@ -75,7 +76,8 @@ Status ValidateQDQInsertionEdges(Graph& graph, gsl::spannode_idx == edge_src_info->node_idx && src_info->arg_idx == edge_src_info->arg_idx)), + (src_info->node_idx == edge_src_info->node_idx && + src_info->arg_idx == edge_src_info->arg_idx)), "Expect all insertion edges to come from the same source node's output slot."); const Node* edge_dst_node = insertion_edge.GetNodeAtEnd(graph, ExtendedGraphEdge::End::Destination); From 0d58b261aaadfbb504a4a04387df7629e5e6c7c4 Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Tue, 23 Jul 2024 23:21:18 -0700 Subject: [PATCH 16/17] Review suggestion: simplify edge validation by checking node args --- .../qdq_transformer/qdq_propagation.cc | 35 +++++++++++-------- 1 file changed, 20 insertions(+), 15 deletions(-) diff --git a/onnxruntime/core/optimizer/qdq_transformer/qdq_propagation.cc b/onnxruntime/core/optimizer/qdq_transformer/qdq_propagation.cc index d676bd8170eb3..48d450e487a4e 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/qdq_propagation.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/qdq_propagation.cc @@ -63,26 +63,31 @@ NodeAttributes MakeDQAttrsFromQ(const Node& q_node) { // Validates edges into which to insert Q -> DQ ops. // - Must have at least one edge. -// - All edges with a source node must originate from the same source node's output. +// - All edges must correspond to the same graph NodeArg (i.e., same source but potentially different destination). // - All edges must be attached to either a source node or a destination node. Status ValidateQDQInsertionEdges(Graph& graph, gsl::span insertion_edges) { - ORT_RETURN_IF(insertion_edges.empty(), "Expected at least one edge into which to insert QDQ pair."); - - const auto& src_info = insertion_edges[0].GetNodeInfoAtEnd(ExtendedGraphEdge::End::Source); - const Node* src_node = src_info.has_value() ? graph.GetNode(src_info->node_idx) : nullptr; - - for (const auto& insertion_edge : insertion_edges) { - const auto& edge_src_info = insertion_edge.GetNodeInfoAtEnd(ExtendedGraphEdge::End::Source); - - ORT_RETURN_IF_NOT((edge_src_info.has_value() == src_info.has_value()) && - (!src_info.has_value() || - (src_info->node_idx == edge_src_info->node_idx && - src_info->arg_idx == edge_src_info->arg_idx)), - "Expect all insertion edges to come from the same source node's output slot."); + const size_t num_edges = insertion_edges.size(); + ORT_RETURN_IF(num_edges == 0, "Expected at least one edge into which to insert QDQ pair."); + + const ExtendedGraphEdge& first_edge = insertion_edges[0]; + const Node* src_node = first_edge.GetNodeAtEnd(graph, ExtendedGraphEdge::End::Source); + const Node* first_dst_node = first_edge.GetNodeAtEnd(graph, ExtendedGraphEdge::End::Destination); + const std::string& node_arg_name = first_edge.arg_name; + ORT_RETURN_IF_NOT(graph.GetNodeArg(node_arg_name) != nullptr, + "QDQ insertion edge does not have a valid graph NodeArg for ", node_arg_name); + ORT_RETURN_IF_NOT(src_node != nullptr || first_dst_node != nullptr, + "NodeArg ", node_arg_name, " must have a source or a destination node"); + + for (size_t i = 1; i < num_edges; i++) { + const ExtendedGraphEdge& insertion_edge = insertion_edges[i]; + ORT_RETURN_IF_NOT(insertion_edge.arg_name == node_arg_name, + "QDQ insertion edge [", i, "] has NodeArg ", insertion_edge.arg_name, + " but expected NodeArg ", node_arg_name); const Node* edge_dst_node = insertion_edge.GetNodeAtEnd(graph, ExtendedGraphEdge::End::Destination); ORT_RETURN_IF_NOT(src_node != nullptr || edge_dst_node != nullptr, - "At least one graph node must be specified in the propagation edges."); + "QDQ insertion edge [", i, "] for NodeArg ", node_arg_name, + " must have a source or a destination node"); } return Status::OK(); From a3a36ed7d890930fb75bb97257710795840599eb Mon Sep 17 00:00:00 2001 From: Adrian Lizarraga Date: Wed, 24 Jul 2024 09:21:09 -0700 Subject: [PATCH 17/17] Update onnxruntime/core/optimizer/qdq_transformer/qdq_propagation.cc Co-authored-by: Edward Chen <18449977+edgchen1@users.noreply.github.com> --- onnxruntime/core/optimizer/qdq_transformer/qdq_propagation.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/onnxruntime/core/optimizer/qdq_transformer/qdq_propagation.cc b/onnxruntime/core/optimizer/qdq_transformer/qdq_propagation.cc index 48d450e487a4e..7b518947138a5 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/qdq_propagation.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/qdq_propagation.cc @@ -76,7 +76,8 @@ Status ValidateQDQInsertionEdges(Graph& graph, gsl::span