diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphDeserialization.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphDeserialization.cpp index 7d8ed17e7d925..128123fa32f12 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphDeserialization.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphDeserialization.cpp @@ -415,7 +415,7 @@ template void PopulateEdges( if (edgeToOutgoingNodeIndexMap.find(edgeName->string_view()) == edgeToOutgoingNodeIndexMap.end()) { throw std::range_error("Neither there is any graph input with name " + edgeName->str() + - "nor there is any node which has " + edgeName->str() + " as one of the output."); + " nor there is any node which has " + edgeName->str() + " as one of the output."); } auto& intermediateEdgeNodeIndex = edgeToOutgoingNodeIndexMap[edgeName->string_view()]; DmlIntermediateSerializedGraphEdge intermediateEdge = {}; @@ -464,17 +464,11 @@ DmlSerializedGraphDesc DeserializeDmlGraph( std::vector outputEdges; std::vector intermediateEdges; + // Iterate the output edges first because nodes are not in topologically sorted. for (uint32_t nodeIndex = 0; nodeIndex < flatbufferGraphDesc->nodes()->size(); nodeIndex++) { const dml::ir::DmlGraphNode* flatbufferNode = flatbufferGraphDesc->nodes()->Get(nodeIndex); - PopulateEdges( - nodeIndex, - flatbufferNode->inputNames(), - graphInputEdgeToIndexMap, - inputEdges, - intermediateEdges, - edgeToOutgoingNodeIndexMap); PopulateEdges( nodeIndex, flatbufferNode->outputNames(), @@ -482,7 +476,20 @@ DmlSerializedGraphDesc DeserializeDmlGraph( outputEdges, intermediateEdges, edgeToOutgoingNodeIndexMap); + } + + for (uint32_t nodeIndex = 0; nodeIndex < flatbufferGraphDesc->nodes()->size(); nodeIndex++) + { + const dml::ir::DmlGraphNode* flatbufferNode = flatbufferGraphDesc->nodes()->Get(nodeIndex); + PopulateEdges( + nodeIndex, + flatbufferNode->inputNames(), + graphInputEdgeToIndexMap, + inputEdges, + intermediateEdges, + edgeToOutgoingNodeIndexMap); + DmlSerializedGraphNode node = {}; if (flatbufferNode->name()->size() == 0) { diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp index e6f008af5c23f..f8347544a32be 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp @@ -140,6 +140,16 @@ namespace Dml::GraphDescBuilder { intermediateEdge.FromNodeIndex = shiftedIndicesMapping[intermediateEdge.FromNodeIndex]; intermediateEdge.ToNodeIndex = shiftedIndicesMapping[intermediateEdge.ToNodeIndex]; + std::string oldEdgeName = intermediateEdge.Name; + size_t pos = oldEdgeName.find("nodeIdx:"); + if (pos != std::string::npos) + { + if (pos != 0) + { + intermediateEdge.Name = oldEdgeName.substr(0, pos); + } + intermediateEdge.Name += "nodeIdx:" + std::to_string(intermediateEdge.FromNodeIndex) + "-outputIdx:" + std::to_string(intermediateEdge.FromNodeOutputIndex); + } } }