From 749bcc793768207068eab146d5ba5be39380e784 Mon Sep 17 00:00:00 2001 From: Patrice Vignola Date: Tue, 31 Oct 2023 10:51:22 -0700 Subject: [PATCH] [DML EP] Add subgraph fusion support (#18125) --- .../src/DmlGraphFusionTransformer.cpp | 37 +++++++++++++- .../src/DmlGraphFusionTransformer.h | 49 +++++++++++-------- .../src/GraphPartitioner.cpp | 18 +++---- .../src/GraphPartitioner.h | 5 +- 4 files changed, 76 insertions(+), 33 deletions(-) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.cpp index a9d19a022d3e7..4813707cdf50c 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.cpp @@ -38,6 +38,16 @@ namespace Dml bool& modified, int graph_level, const onnxruntime::logging::Logger& logger) const + { + return ApplyImplHelper(graph, modified, graph_level, logger, {}); + } + + onnxruntime::common::Status DmlGraphFusionTransformer::ApplyImplHelper( + onnxruntime::Graph& graph, + bool& modified, + int graph_level, + const onnxruntime::logging::Logger& logger, + const std::unordered_map& implicitInputDefs) const { onnxruntime::ProviderType provider_type = onnxruntime::kDmlExecutionProvider; const gsl::not_null registry = m_providerImpl->GetKernelRegistry().get(); @@ -49,6 +59,30 @@ namespace Dml std::vector> compiledPartitionInfos; std::vector additionalSplittingNodes; + onnxruntime::GraphViewer graph_viewer(graph); + const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder(); + + for (auto node_index : node_topology_list) + { + auto* node = graph.GetNode(node_index); + if (!node) + { + continue; // node was removed + } + + std::unordered_map subgraphImplicitInputDefs; + for (const onnxruntime::NodeArg* inputDef : node->ImplicitInputDefs()) + { + subgraphImplicitInputDefs[inputDef->Name()] = inputDef; + } + + for (auto& entry : node->GetAttributeNameToMutableSubgraphMap()) + { + auto& subgraph = *entry.second; + ORT_RETURN_IF_ERROR(ApplyImplHelper(subgraph, modified, graph_level + 1, logger, subgraphImplicitInputDefs)); + } + } + do { // Initializers needed by any graph partition @@ -62,7 +96,8 @@ namespace Dml m_providerImpl->GetSupportedDeviceDataTypeMask(), graphNodePropertyMap, requiredInitializerMap, - additionalSplittingNodes); + additionalSplittingNodes, + implicitInputDefs); // Reset the splitting nodes for the current iteration additionalSplittingNodes.clear(); diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.h index b546f29f59719..19dab0c89943c 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.h @@ -2,32 +2,41 @@ // Licensed under the MIT License. #pragma once - +#include +#include #include "core/optimizer/graph_transformer.h" #include "core/framework/execution_providers.h" namespace Dml { - class ExecutionProviderImpl; +class ExecutionProviderImpl; + +class DmlGraphFusionTransformer : public onnxruntime::GraphTransformer +{ +public: + DmlGraphFusionTransformer( + const std::string& name, + const onnxruntime::IExecutionProvider* provider + ); + +public: + static inline const char* const DML_GRAPH_FUSION_NODE_NAME_PREFIX = "DmlFusedNode_"; + static inline const char* const DML_GRAPH_FUSION_NODE_DOMAIN = "DmlFusedNodeDomain"; - class DmlGraphFusionTransformer : public onnxruntime::GraphTransformer - { - public: - DmlGraphFusionTransformer( - const std::string& name, - const onnxruntime::IExecutionProvider* provider - ); +private: + onnxruntime::common::Status ApplyImpl(onnxruntime::Graph& graph, + bool& modified, + int graph_level, + const onnxruntime::logging::Logger& logger) const final; - public: - inline const static char* const DML_GRAPH_FUSION_NODE_NAME_PREFIX = "DmlFusedNode_"; - inline const static char* const DML_GRAPH_FUSION_NODE_DOMAIN = "DmlFusedNodeDomain"; + onnxruntime::common::Status ApplyImplHelper( + onnxruntime::Graph& graph, + bool& modified, + int graph_level, + const onnxruntime::logging::Logger& logger, + const std::unordered_map& implicitInputDefs) const; - private: - onnxruntime::common::Status ApplyImpl(onnxruntime::Graph& graph, - bool& modified, - int graph_level, - const onnxruntime::logging::Logger& logger) const final; - private: - const ExecutionProviderImpl* m_providerImpl = nullptr; - }; +private: + const ExecutionProviderImpl* m_providerImpl = nullptr; +}; } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.cpp index 2c8d4e4459f7f..18943878ccedc 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.cpp @@ -345,13 +345,8 @@ namespace Dml // Whether any operator in the model contains a subgraph. This is true // if the graph being partitioned is itself within a subgraph, or contains // an operator with a subgraph. - bool ModelUsesSubgraph(const onnxruntime::GraphViewer& graph) + bool ContainsSubgraph(const onnxruntime::GraphViewer& graph) { - if (graph.IsSubgraph()) - { - return true; - } - const std::vector& toplogicalOrder = graph.GetNodesInTopologicalOrder(); for (size_t nodeIndex : toplogicalOrder) @@ -384,7 +379,8 @@ namespace Dml uint32_t supportedDeviceDataTypeMask, // Each bit corresponds to each DML_TENSOR_DATA_TYPE. std::unordered_map& graphNodePropertyMap, std::unordered_set& requiredInitializerMap, - gsl::span additionalSplittingNodes) + gsl::span additionalSplittingNodes, + const std::unordered_map& implicitInputs) { // Nodes are uniquely identified by the name of their first output argument std::vector> partitions; @@ -419,7 +415,7 @@ namespace Dml } // Check whether this graph is a subgraph, or contains any node with a subgraph. - bool modelUsesSubgraph = ModelUsesSubgraph(graph); + bool containsSubgraph = ContainsSubgraph(graph); uint32_t splittingNodeIndex = 0; @@ -454,10 +450,10 @@ namespace Dml // Add a unique partition if graph node usage is not supported. // // Partitioning is disabled in models with subgraphs to work around issues with implicit inputs. - // The partitioning algorithm does not currently consider such inputs. Transfering shared initializers + // The partitioning algorithm does not currently consider such inputs. Transferring shared initializers // for partitions could also cause problems. Note, operators with subgraphs are currently not efficient // anyhow due to CPU/GPU copies. - if (modelUsesSubgraph || !isDmlGraphNode) + if (containsSubgraph || !isDmlGraphNode) { partitions.push_back(CreatePartitionAndFinalizeInputs(node, isDmlNode, false, nodeNameToPartitionMap)); continue; @@ -505,7 +501,7 @@ namespace Dml firstNonFinalInputPartition->AddInput(arg->Name()); } - if (graphInputs.find(arg->Name()) != graphInputs.end()) + if (graphInputs.find(arg->Name()) != graphInputs.end() || implicitInputs.find(arg->Name()) != implicitInputs.end()) { firstNonFinalInputPartition->AddInput(arg->Name()); } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.h index 990ba00fc4672..37d577f647fb5 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.h @@ -3,6 +3,8 @@ #pragma once +#include +#include #include "core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.h" namespace Dml @@ -48,5 +50,6 @@ namespace Dml uint32_t supportedDeviceDataTypeMask, // Each bit corresponds to each DML_TENSOR_DATA_TYPE. std::unordered_map& graphNodePropertyMap, std::unordered_set& requiredInitializerMap, - gsl::span additionalSplittingNodes); + gsl::span additionalSplittingNodes, + const std::unordered_map& implicitInputs); } // namespace Dml