From 01acc25d9dd695255ae5c44ec377576408e1081c Mon Sep 17 00:00:00 2001 From: Patrice Vignola Date: Fri, 12 Apr 2024 16:34:49 -0700 Subject: [PATCH] [DML EP] Fix the output shapes of nodes with multiple outputs in the graph builder (#20289) The graph builder currently doesn't assign the correct shapes for subgraphs that have more than 1 output, and where each output comes from a different node. `nodeOutputShapes` should be a map of shapes (1:1 relationship), and not a map of lists of shapes (1:N relationship) since an output referenced by `arg->Name()` can only have 1 output. Take for example the following example of a subgraph where a node has 2 outputs, then each output feeds into an elementwise op. Both nodes will have a `targetIndex` of 0, and we were using this target index to query their shape, resulting in both outputs querying the same shape. In reality, what we need to do is use the `GraphOutputIndex` ofthe subgraph to query the correct output shape of the subgraph. --- .../src/GraphDescBuilder.cpp | 34 +++++++++---------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp index a346c0c9fb17a..3b0dbd542547c 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp @@ -176,7 +176,7 @@ namespace Dml::GraphDescBuilder it++; } } - + // Erase the mapping if the input Edge is not used by any node for (auto it = serializedGraphLargeConstantNameToSubgraphInputIndex.begin(); it != serializedGraphLargeConstantNameToSubgraphInputIndex.end();) { @@ -236,25 +236,25 @@ namespace Dml::GraphDescBuilder uint32_t targetIndex; // The index of the input/output on the node (e.g. 1 for the second input on a node) }; - std::unordered_map nodeOutputShapes; + std::unordered_map> nodeOutputShapes; // Map from ORT subgraph input names to indices std::unordered_map subgraphInputNameToIndexMap; - + // - Map from ORT node's output names to DmlGraph . // - Once a given ORT node (or operator) will be transformed into a operatorDmlGraph, // then ORT node's output names will become output edges for the operatorDmlGraph. // - This map will be populated for those output edges. std::unordered_map dmlGraphNodeOutputNameToNodeAndIndexMap; - + // This map will be used to re-index an subGraphInputIndex to sequential input index // for DmlGraph std::unordered_map subGraphInputIndexToDmlGraphInputIndex; - + // Iterate through each node and create a corresponding node in the new graph // We can iterate the nodes in any order because the edge connectivity will take care of the topological order std::unordered_map> inferredOutputShapes; - + std::vector dmlGraphNodes; std::vector dmlGraphInputEdges; std::vector dmlGraphIntermediateEdges; @@ -357,8 +357,8 @@ namespace Dml::GraphDescBuilder for (int i = 0; i < node.OutputDefs().size(); ++i) { inferredOutputShapes[node.OutputDefs()[i]->Name()] = outputShapes.GetShape(i); - } - + } + // Algorithm: // 1. Create constant nodes by iterating through operatorDmlGraph's input edges and keep a map of it, // because there would be an intermediate edge from the constantNode and source of the intermediate edge @@ -367,7 +367,7 @@ namespace Dml::GraphDescBuilder // 3. Iterate through operatorDmlGraph's intermediate edges to create mainGraph's intermediate edges. // 4. Iterate through operatorDmlGraph's output edges to populate outputEdgeNameToDmlGraphNodeAndIndex // 5. While performing step 2, 3, and 4, insert operatorDmlGraphNode to the mainDmlGraphNode list. - + for (auto& operatorDmlGraphInputEdge : operatorDmlGraphCreateInfo.inputEdges) { const onnxruntime::NodeArg* arg = node.InputDefs()[operatorDmlGraphInputEdge.GraphInputIndex]; @@ -381,8 +381,8 @@ namespace Dml::GraphDescBuilder DmlSerializedGraphNode constantNode = {}; constantNode.Name = arg->Name(); - // This is a highly inefficient approach to generating constant nodes. It duplicates constant data - // across the graph input as well as every consumer's unique constant node. However it is currently + // This is a highly inefficient approach to generating constant nodes. It duplicates constant data + // across the graph input as well as every consumer's unique constant node. However it is currently // only used for small inputs. auto& operatorDmlGraphInputNode = operatorDmlGraphCreateInfo.nodes[operatorDmlGraphInputEdge.ToNodeIndex]; std::vector toNodeInputTensorDescs = operatorDmlGraphInputNode->GetInputTensors(); @@ -441,10 +441,10 @@ namespace Dml::GraphDescBuilder if (iter != subgraphInputNameToIndexMap.end()) { const uint32_t subgraphInputIndex = iter->second; - + // Either this edge will be - // a constant input, then it will be an intermediate edge and - // set the OWNED_BY_DML flag if it is large constant + // a constant input, then it will be an intermediate edge and + // set the OWNED_BY_DML flag if it is large constant // or, // a non-constant input, then it will be a mainDmlGraphInputEdge. if (subgraphInputIndex < isConstGpuGraphInputCount && @@ -526,7 +526,7 @@ namespace Dml::GraphDescBuilder edge.Name = "nodeIdx:" + std::to_string(shiftedFromNodeIndex) + "-outputIdx:" + std::to_string(operatorGraphIntermediateEdge.FromNodeOutputIndex); dmlGraphIntermediateEdges.push_back(edge); } - + // populate nameToNodeAndIndexMap (which will be used by above loop) for operatorGraphOutputEdges for (auto& operatorGraphOutputEdge : operatorDmlGraphCreateInfo.outputEdges) { @@ -540,7 +540,7 @@ namespace Dml::GraphDescBuilder operatorDmlGraphToDmlGraphNodeIndexMap, dmlGraphNodes); dmlGraphNodeOutputNameToNodeAndIndexMap[arg->Name()] = {shiftedNodeIndex, operatorGraphOutputEdge.FromNodeOutputIndex}; - nodeOutputShapes[arg->Name()] = outputShapes; + nodeOutputShapes[arg->Name()] = outputShapes.GetShape(operatorGraphOutputEdge.GraphOutputIndex); } } } @@ -561,7 +561,7 @@ namespace Dml::GraphDescBuilder edge.GraphOutputIndex = gsl::narrow_cast(outputIndex); edge.Name = graphOutput->Name(); dmlGraphOutputEdges.push_back(edge); - graphOutputShapes.GetMutableShape(outputIndex) = nodeOutputShapes[graphOutput->Name()].GetShape(outputNodeAndIndex.targetIndex); + graphOutputShapes.GetMutableShape(outputIndex) = nodeOutputShapes[graphOutput->Name()]; } RemoveUnconnectedNodes(dmlGraphNodes,