diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphDeserialization.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphDeserialization.cpp index 7d8ed17e7d925..013ad949c1c3f 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphDeserialization.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphDeserialization.cpp @@ -415,7 +415,7 @@ template void PopulateEdges( if (edgeToOutgoingNodeIndexMap.find(edgeName->string_view()) == edgeToOutgoingNodeIndexMap.end()) { throw std::range_error("Neither there is any graph input with name " + edgeName->str() + - "nor there is any node which has " + edgeName->str() + " as one of the output."); + " nor there is any node which has " + edgeName->str() + " as one of the output."); } auto& intermediateEdgeNodeIndex = edgeToOutgoingNodeIndexMap[edgeName->string_view()]; DmlIntermediateSerializedGraphEdge intermediateEdge = {}; @@ -475,6 +475,7 @@ DmlSerializedGraphDesc DeserializeDmlGraph( inputEdges, intermediateEdges, edgeToOutgoingNodeIndexMap); + PopulateEdges( nodeIndex, flatbufferNode->outputNames(), @@ -482,7 +483,7 @@ DmlSerializedGraphDesc DeserializeDmlGraph( outputEdges, intermediateEdges, edgeToOutgoingNodeIndexMap); - + DmlSerializedGraphNode node = {}; if (flatbufferNode->name()->size() == 0) { @@ -503,7 +504,7 @@ DmlSerializedGraphDesc DeserializeDmlGraph( ConstantName constantNode = {flatbufferConstantNode->data_as_ConstantName()->name()->c_str()}; node.Desc = constantNode; - // output of this node will part of constantInputs list + // Output of this node will be part of constantInputs list. 
for (uint32_t outputIndex = 0; outputIndex < flatbufferNode->outputNames()->size(); outputIndex++) { constantInputs.insert(flatbufferNode->outputNames()->Get(outputIndex)->c_str()); diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp index 202b762d99e01..27168bc8e9763 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp @@ -596,37 +596,37 @@ namespace DmlGraphFusionHelper const std::unordered_map* serializedGraphInputIndexToSubgraphInputIndex, const std::unordered_map* serializedGraphLargeConstantNameToSubgraphInputIndex) { - if (graphSerializationEnabled) - { - - const std::wstring modelName = GetModelName(graph.ModelPath()); - auto buffer = SerializeDmlGraph(graphDesc); - - const std::wstring partitionName = - L"Partition_" + - std::to_wstring(partitionIndex) + - L".bin"; - WriteToFile(modelName, partitionName, buffer.data(), buffer.size()); - - std::vector> rawData; - DmlSerializedGraphDesc deserializedGraphDesc = DeserializeDmlGraph(buffer.data(), rawData); - GraphDescBuilder::GraphDesc deserializedDmlGraphDesc = {}; - deserializedDmlGraphDesc.InputCount = deserializedGraphDesc.InputCount; - deserializedDmlGraphDesc.InputEdges = std::move(deserializedGraphDesc.InputEdges); - deserializedDmlGraphDesc.IntermediateEdges = std::move(deserializedGraphDesc.IntermediateEdges); - deserializedDmlGraphDesc.Nodes = std::move(deserializedGraphDesc.Nodes); - deserializedDmlGraphDesc.OutputCount = deserializedGraphDesc.OutputCount; - deserializedDmlGraphDesc.OutputEdges = std::move(deserializedGraphDesc.OutputEdges); - deserializedDmlGraphDesc.reuseCommandList = graphDesc.reuseCommandList; - deserializedDmlGraphDesc.outputShapes = graphDesc.outputShapes; - - compiledExecutionPlanOperator = 
DmlGraphFusionHelper::TryCreateCompiledOperator( - deserializedDmlGraphDesc, - indexedSubGraph, - providerImpl, - serializedGraphInputIndexToSubgraphInputIndex, - serializedGraphLargeConstantNameToSubgraphInputIndex); - } + if (graphSerializationEnabled) + { + + const std::wstring modelName = GetModelName(graph.ModelPath()); + auto buffer = SerializeDmlGraph(graphDesc); + + const std::wstring partitionName = + L"Partition_" + + std::to_wstring(partitionIndex) + + L".bin"; + WriteToFile(modelName, partitionName, buffer.data(), buffer.size()); + + std::vector> rawData; + DmlSerializedGraphDesc deserializedGraphDesc = DeserializeDmlGraph(buffer.data(), rawData); + GraphDescBuilder::GraphDesc deserializedDmlGraphDesc = {}; + deserializedDmlGraphDesc.InputCount = deserializedGraphDesc.InputCount; + deserializedDmlGraphDesc.InputEdges = std::move(deserializedGraphDesc.InputEdges); + deserializedDmlGraphDesc.IntermediateEdges = std::move(deserializedGraphDesc.IntermediateEdges); + deserializedDmlGraphDesc.Nodes = std::move(deserializedGraphDesc.Nodes); + deserializedDmlGraphDesc.OutputCount = deserializedGraphDesc.OutputCount; + deserializedDmlGraphDesc.OutputEdges = std::move(deserializedGraphDesc.OutputEdges); + deserializedDmlGraphDesc.reuseCommandList = graphDesc.reuseCommandList; + deserializedDmlGraphDesc.outputShapes = graphDesc.outputShapes; + + compiledExecutionPlanOperator = DmlGraphFusionHelper::TryCreateCompiledOperator( + deserializedDmlGraphDesc, + indexedSubGraph, + providerImpl, + serializedGraphInputIndexToSubgraphInputIndex, + serializedGraphLargeConstantNameToSubgraphInputIndex); + } auto& fusedNode = graph.BeginFuseSubGraph(indexedSubGraph, indexedSubGraph.GetMetaDef()->name); fusedNode.SetExecutionProviderType(onnxruntime::kDmlExecutionProvider); diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphSerialization.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphSerialization.cpp index 
5355964e8db74..ed406fa259fe6 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphSerialization.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphSerialization.cpp @@ -517,6 +517,9 @@ flatbuffers::DetachedBuffer SerializeDmlGraph(const DmlSerializedGraphDesc& grap return builder.Release(); } + std::vector nodesInTopologicalOrder(graphDesc.Nodes.size()); + PerformTopologicalSortAndCheckIsAcyclic(graphDesc, nodesInTopologicalOrder); + // create input/output edge index to name map std::unordered_map> graphInputIndexToNameMap = ConvertToEdgeIndexToNameMap(graphDesc.InputEdges, builder); @@ -548,14 +551,14 @@ flatbuffers::DetachedBuffer SerializeDmlGraph(const DmlSerializedGraphDesc& grap // Create flatbuffer node objects std::vector> nodes(graphDesc.Nodes.size()); - for (uint32_t nodeIndex = 0; nodeIndex < static_cast(graphDesc.Nodes.size()); nodeIndex++) + for (uint32_t nodeIndex = 0; nodeIndex < static_cast(nodesInTopologicalOrder.size()); nodeIndex++) { nodes[nodeIndex] = SerializeNode( builder, - nodeIndex, - graphDesc.Nodes[nodeIndex], - nodeToInputNames[nodeIndex], - nodeToOutputNames[nodeIndex]); + nodesInTopologicalOrder[nodeIndex], + graphDesc.Nodes[nodesInTopologicalOrder[nodeIndex]], + nodeToInputNames[nodesInTopologicalOrder[nodeIndex]], + nodeToOutputNames[nodesInTopologicalOrder[nodeIndex]]); } // Convert to std::vector to create the object. diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DmlGraphHelper.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DmlGraphHelper.h new file mode 100644 index 0000000000000..d2dd7cd8eff1b --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DmlGraphHelper.h @@ -0,0 +1,52 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. 
+
+#pragma once
+#include <queue>
+
+// Writes the node indices of graphDesc into nodesInTopologicalOrder in topological order. The vector must
+// already hold graphDesc.Nodes.size() elements. Throws std::invalid_argument if the graph contains a cycle.
+inline void PerformTopologicalSortAndCheckIsAcyclic(
+    const DmlSerializedGraphDesc& graphDesc,
+    std::vector<uint32_t>& nodesInTopologicalOrder)
+{
+    const uint32_t nodeCount = static_cast<uint32_t>(graphDesc.Nodes.size());
+    std::queue<uint32_t> queue;
+    std::vector<uint32_t> inDegree(nodeCount, 0);
+    std::vector<std::vector<uint32_t>> children(nodeCount);
+
+    // InputEdges are skipped: graph inputs are not nodes, and only an ordering of the nodes is needed.
+    for (const DmlIntermediateSerializedGraphEdge& intermediateEdge : graphDesc.IntermediateEdges)
+    {
+        inDegree[intermediateEdge.ToNodeIndex]++;
+        children[intermediateEdge.FromNodeIndex].push_back(intermediateEdge.ToNodeIndex);
+    }
+
+    for (uint32_t nodeIndex = 0; nodeIndex < nodeCount; nodeIndex++)
+    {
+        if (inDegree[nodeIndex] == 0)
+        {
+            queue.push(nodeIndex);
+        }
+    }
+
+    uint32_t nodeIndex = 0;
+    while (!queue.empty())
+    {
+        const uint32_t currNodeIndex = queue.front();
+        queue.pop();
+        nodesInTopologicalOrder[nodeIndex++] = currNodeIndex;
+        for (uint32_t child : children[currNodeIndex])
+        {
+            if (--inDegree[child] == 0)
+            {
+                queue.push(child);
+            }
+        }
+    }
+
+    // Nodes on a cycle never reach in-degree 0 and are never emitted, so a short count means a cycle.
+    if (nodeIndex != nodeCount)
+    {
+        throw std::invalid_argument("Given graph is not acyclic.");
+    }
+}
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp index e6f008af5c23f..a346c0c9fb17a 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp @@ -36,7 +36,9 @@ namespace Dml::GraphDescBuilder std::vector& graphNodes, std::vector& graphInputEdges, std::vector& graphIntermediateEdges, - std::vector& graphOutputEdges) + std::vector& graphOutputEdges, + std::unordered_map& serializedGraphInputIndexToSubgraphInputIndex, + std::unordered_map& 
serializedGraphLargeConstantNameToSubgraphInputIndex) { enum class NodeState { @@ -124,8 +126,10 @@ namespace Dml::GraphDescBuilder graphNodes.resize(graphNodes.size() - shift); // Adjust the node indices in the input edges + std::unordered_set usedInputEdgeIndex; for (auto& inputEdge : graphInputEdges) { + usedInputEdgeIndex.insert(inputEdge.GraphInputIndex); inputEdge.ToNodeIndex = shiftedIndicesMapping[inputEdge.ToNodeIndex]; } @@ -136,10 +140,54 @@ namespace Dml::GraphDescBuilder } // Adjust the node indices in the intermediate edges + std::unordered_set usedLargeConstantNames; for (auto& intermediateEdge : graphIntermediateEdges) { intermediateEdge.FromNodeIndex = shiftedIndicesMapping[intermediateEdge.FromNodeIndex]; intermediateEdge.ToNodeIndex = shiftedIndicesMapping[intermediateEdge.ToNodeIndex]; + // We need to update the edge name only when the name contains the intermediateEdge.FromNodeIndex + size_t pos = intermediateEdge.Name.find("nodeIdx:"); + if (pos != std::string::npos) + { + if (pos != 0) + { + std::string constantNamePartComingFromModel = intermediateEdge.Name.substr(0, pos - 1); + usedLargeConstantNames.insert(constantNamePartComingFromModel); // need part of name which is coming from the model. 
+ intermediateEdge.Name = constantNamePartComingFromModel; + intermediateEdge.Name += "-nodeIdx:" + std::to_string(intermediateEdge.FromNodeIndex) + "-outputIdx:" + std::to_string(intermediateEdge.FromNodeOutputIndex); + } + else + { + intermediateEdge.Name = "nodeIdx:" + std::to_string(intermediateEdge.FromNodeIndex) + "-outputIdx:" + std::to_string(intermediateEdge.FromNodeOutputIndex); + } + } + } + + + // Erase the mapping if the input Edge is not used by any node + for (auto it = serializedGraphInputIndexToSubgraphInputIndex.begin(); it != serializedGraphInputIndexToSubgraphInputIndex.end();) + { + if (!usedInputEdgeIndex.count(it->first)) + { + it = serializedGraphInputIndexToSubgraphInputIndex.erase(it); + } + else + { + it++; + } + } + + // Erase the mapping if the input Edge is not used by any node + for (auto it = serializedGraphLargeConstantNameToSubgraphInputIndex.begin(); it != serializedGraphLargeConstantNameToSubgraphInputIndex.end();) + { + if (!usedLargeConstantNames.count(std::string(it->first))) + { + it = serializedGraphLargeConstantNameToSubgraphInputIndex.erase(it); + } + else + { + it++; + } } } @@ -516,7 +564,12 @@ namespace Dml::GraphDescBuilder graphOutputShapes.GetMutableShape(outputIndex) = nodeOutputShapes[graphOutput->Name()].GetShape(outputNodeAndIndex.targetIndex); } - RemoveUnconnectedNodes(dmlGraphNodes, dmlGraphInputEdges, dmlGraphIntermediateEdges, dmlGraphOutputEdges); + RemoveUnconnectedNodes(dmlGraphNodes, + dmlGraphInputEdges, + dmlGraphIntermediateEdges, + dmlGraphOutputEdges, + serializedGraphInputIndexToSubgraphInputIndex, + serializedGraphLargeConstantNameToSubgraphInputIndex); GraphDesc graphDesc{}; graphDesc.InputCount = static_cast(dmlGraphInputEdges.size()); diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/precomp.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/precomp.h index 1a796b25c5d1f..5873157272e3c 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/precomp.h +++ 
b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/precomp.h @@ -59,6 +59,7 @@ #include "External/DirectMLHelpers/DmlSerializedGraphDesc.h" #include "External/DirectMLHelpers/DmlGraphSerialization.h" #include "External/DirectMLHelpers/DmlGraphDeserialization.h" +#include "External/DirectMLHelpers/DmlGraphHelper.h" using Microsoft::WRL::ComPtr;