Skip to content

Commit

Permalink
[DML EP] Add subgraph fusion support (#18125)
Browse files Browse the repository at this point in the history
  • Loading branch information
PatriceVignola authored Oct 31, 2023
1 parent 6ae7c51 commit 749bcc7
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 33 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,16 @@ namespace Dml
bool& modified,
int graph_level,
const onnxruntime::logging::Logger& logger) const
{
return ApplyImplHelper(graph, modified, graph_level, logger, {});
}

onnxruntime::common::Status DmlGraphFusionTransformer::ApplyImplHelper(
onnxruntime::Graph& graph,
bool& modified,
int graph_level,
const onnxruntime::logging::Logger& logger,
const std::unordered_map<std::string, const onnxruntime::NodeArg*>& implicitInputDefs) const
{
onnxruntime::ProviderType provider_type = onnxruntime::kDmlExecutionProvider;
const gsl::not_null<const onnxruntime::KernelRegistry*> registry = m_providerImpl->GetKernelRegistry().get();
Expand All @@ -49,6 +59,30 @@ namespace Dml
std::vector<std::shared_ptr<CompiledPartitionInfo>> compiledPartitionInfos;
std::vector<onnxruntime::NodeIndex> additionalSplittingNodes;

onnxruntime::GraphViewer graph_viewer(graph);
const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();

for (auto node_index : node_topology_list)
{
auto* node = graph.GetNode(node_index);
if (!node)
{
continue; // node was removed
}

std::unordered_map<std::string, const onnxruntime::NodeArg*> subgraphImplicitInputDefs;
for (const onnxruntime::NodeArg* inputDef : node->ImplicitInputDefs())
{
subgraphImplicitInputDefs[inputDef->Name()] = inputDef;
}

for (auto& entry : node->GetAttributeNameToMutableSubgraphMap())
{
auto& subgraph = *entry.second;
ORT_RETURN_IF_ERROR(ApplyImplHelper(subgraph, modified, graph_level + 1, logger, subgraphImplicitInputDefs));
}
}

do
{
// Initializers needed by any graph partition
Expand All @@ -62,7 +96,8 @@ namespace Dml
m_providerImpl->GetSupportedDeviceDataTypeMask(),
graphNodePropertyMap,
requiredInitializerMap,
additionalSplittingNodes);
additionalSplittingNodes,
implicitInputDefs);

// Reset the splitting nodes for the current iteration
additionalSplittingNodes.clear();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,32 +2,41 @@
// Licensed under the MIT License.
#pragma once


#include <string>
#include <unordered_map>
#include "core/optimizer/graph_transformer.h"
#include "core/framework/execution_providers.h"

namespace Dml
{
class ExecutionProviderImpl;
class ExecutionProviderImpl;

class DmlGraphFusionTransformer : public onnxruntime::GraphTransformer
{
public:
DmlGraphFusionTransformer(
const std::string& name,
const onnxruntime::IExecutionProvider* provider
);

public:
static inline const char* const DML_GRAPH_FUSION_NODE_NAME_PREFIX = "DmlFusedNode_";
static inline const char* const DML_GRAPH_FUSION_NODE_DOMAIN = "DmlFusedNodeDomain";

class DmlGraphFusionTransformer : public onnxruntime::GraphTransformer
{
public:
DmlGraphFusionTransformer(
const std::string& name,
const onnxruntime::IExecutionProvider* provider
);
private:
onnxruntime::common::Status ApplyImpl(onnxruntime::Graph& graph,
bool& modified,
int graph_level,
const onnxruntime::logging::Logger& logger) const final;

public:
inline const static char* const DML_GRAPH_FUSION_NODE_NAME_PREFIX = "DmlFusedNode_";
inline const static char* const DML_GRAPH_FUSION_NODE_DOMAIN = "DmlFusedNodeDomain";
onnxruntime::common::Status ApplyImplHelper(
onnxruntime::Graph& graph,
bool& modified,
int graph_level,
const onnxruntime::logging::Logger& logger,
const std::unordered_map<std::string, const onnxruntime::NodeArg*>& implicitInputDefs) const;

private:
onnxruntime::common::Status ApplyImpl(onnxruntime::Graph& graph,
bool& modified,
int graph_level,
const onnxruntime::logging::Logger& logger) const final;
private:
const ExecutionProviderImpl* m_providerImpl = nullptr;
};
private:
const ExecutionProviderImpl* m_providerImpl = nullptr;
};
}
Original file line number Diff line number Diff line change
Expand Up @@ -345,13 +345,8 @@ namespace Dml
// Whether any operator in the model contains a subgraph. This is true
// if the graph being partitioned is itself within a subgraph, or contains
// an operator with a subgraph.
bool ModelUsesSubgraph(const onnxruntime::GraphViewer& graph)
bool ContainsSubgraph(const onnxruntime::GraphViewer& graph)
{
if (graph.IsSubgraph())
{
return true;
}

const std::vector<onnxruntime::NodeIndex>& toplogicalOrder = graph.GetNodesInTopologicalOrder();

for (size_t nodeIndex : toplogicalOrder)
Expand Down Expand Up @@ -384,7 +379,8 @@ namespace Dml
uint32_t supportedDeviceDataTypeMask, // Each bit corresponds to each DML_TENSOR_DATA_TYPE.
std::unordered_map<const onnxruntime::Node*, GraphNodeProperties>& graphNodePropertyMap,
std::unordered_set<std::string>& requiredInitializerMap,
gsl::span<const onnxruntime::NodeIndex> additionalSplittingNodes)
gsl::span<const onnxruntime::NodeIndex> additionalSplittingNodes,
const std::unordered_map<std::string, const onnxruntime::NodeArg*>& implicitInputs)
{
// Nodes are uniquely identified by the name of their first output argument
std::vector<std::unique_ptr<GraphPartition>> partitions;
Expand Down Expand Up @@ -419,7 +415,7 @@ namespace Dml
}

// Check whether this graph is a subgraph, or contains any node with a subgraph.
bool modelUsesSubgraph = ModelUsesSubgraph(graph);
bool containsSubgraph = ContainsSubgraph(graph);

uint32_t splittingNodeIndex = 0;

Expand Down Expand Up @@ -454,10 +450,10 @@ namespace Dml
// Add a unique partition if graph node usage is not supported.
//
// Partitioning is disabled in models with subgraphs to work around issues with implicit inputs.
// The partitioning algorithm does not currently consider such inputs. Transfering shared initializers
// The partitioning algorithm does not currently consider such inputs. Transferring shared initializers
// for partitions could also cause problems. Note, operators with subgraphs are currently not efficient
// anyhow due to CPU/GPU copies.
if (modelUsesSubgraph || !isDmlGraphNode)
if (containsSubgraph || !isDmlGraphNode)
{
partitions.push_back(CreatePartitionAndFinalizeInputs(node, isDmlNode, false, nodeNameToPartitionMap));
continue;
Expand Down Expand Up @@ -505,7 +501,7 @@ namespace Dml
firstNonFinalInputPartition->AddInput(arg->Name());
}

if (graphInputs.find(arg->Name()) != graphInputs.end())
if (graphInputs.find(arg->Name()) != graphInputs.end() || implicitInputs.find(arg->Name()) != implicitInputs.end())
{
firstNonFinalInputPartition->AddInput(arg->Name());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

#pragma once

#include <string>
#include <unordered_map>
#include "core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.h"

namespace Dml
Expand Down Expand Up @@ -48,5 +50,6 @@ namespace Dml
uint32_t supportedDeviceDataTypeMask, // Each bit corresponds to each DML_TENSOR_DATA_TYPE.
std::unordered_map<const onnxruntime::Node*, GraphNodeProperties>& graphNodePropertyMap,
std::unordered_set<std::string>& requiredInitializerMap,
gsl::span<const onnxruntime::NodeIndex> additionalSplittingNodes);
gsl::span<const onnxruntime::NodeIndex> additionalSplittingNodes,
const std::unordered_map<std::string, const onnxruntime::NodeArg*>& implicitInputs);
} // namespace Dml

0 comments on commit 749bcc7

Please sign in to comment.