Skip to content

Commit

Permalink
[DML EP] Fix the output shapes of nodes with multiple outputs in the …
Browse files Browse the repository at this point in the history
…graph builder (#20289)

The graph builder currently doesn't assign the correct shapes for
subgraphs that have more than 1 output, and where each output comes from
a different node. `nodeOutputShapes` should be a map of shapes (1:1
relationship), and not a map of lists of shapes (1:N relationship) since
an output referenced by `arg->Name()` can only have 1 output.

Take for example the following example of a subgraph where a node has 2
outputs, then each output feeds into an elementwise op. Both nodes will
have a `targetIndex` of 0, and we were using this target index to query
their shape, resulting in both outputs querying the same shape. In
reality, what we need to do is use the `GraphOutputIndex` ofthe subgraph
to query the correct output shape of the subgraph.
  • Loading branch information
PatriceVignola authored Apr 12, 2024
1 parent b33216b commit 01acc25
Showing 1 changed file with 17 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ namespace Dml::GraphDescBuilder
it++;
}
}

// Erase the mapping if the input Edge is not used by any node
for (auto it = serializedGraphLargeConstantNameToSubgraphInputIndex.begin(); it != serializedGraphLargeConstantNameToSubgraphInputIndex.end();)
{
Expand Down Expand Up @@ -236,25 +236,25 @@ namespace Dml::GraphDescBuilder
uint32_t targetIndex; // The index of the input/output on the node (e.g. 1 for the second input on a node)
};

std::unordered_map<std::string, EdgeShapes> nodeOutputShapes;
std::unordered_map<std::string, std::vector<uint32_t>> nodeOutputShapes;

// Map from ORT subgraph input names to indices
std::unordered_map<std::string_view, uint32_t> subgraphInputNameToIndexMap;

// - Map from ORT node's output names to DmlGraph <NodeAndIndex>.
// - Once a given ORT node (or operator) will be transformed into a operatorDmlGraph,
// then ORT node's output names will become output edges for the operatorDmlGraph.
// - This map will be populated for those output edges.
std::unordered_map<std::string, NodeAndIndex> dmlGraphNodeOutputNameToNodeAndIndexMap;

// This map will be used to re-index an subGraphInputIndex to sequential input index
// for DmlGraph
std::unordered_map<uint32_t, uint32_t> subGraphInputIndexToDmlGraphInputIndex;

// Iterate through each node and create a corresponding node in the new graph
// We can iterate the nodes in any order because the edge connectivity will take care of the topological order
std::unordered_map<std::string, std::vector<uint32_t>> inferredOutputShapes;

std::vector<DmlSerializedGraphNode> dmlGraphNodes;
std::vector<DmlInputSerializedGraphEdge> dmlGraphInputEdges;
std::vector<DmlIntermediateSerializedGraphEdge> dmlGraphIntermediateEdges;
Expand Down Expand Up @@ -357,8 +357,8 @@ namespace Dml::GraphDescBuilder
for (int i = 0; i < node.OutputDefs().size(); ++i)
{
inferredOutputShapes[node.OutputDefs()[i]->Name()] = outputShapes.GetShape(i);
}
}

// Algorithm:
// 1. Create constant nodes by iterating through operatorDmlGraph's input edges and keep a map of it,
// because there would be an intermediate edge from the constantNode and source of the intermediate edge
Expand All @@ -367,7 +367,7 @@ namespace Dml::GraphDescBuilder
// 3. Iterate through operatorDmlGraph's intermediate edges to create mainGraph's intermediate edges.
// 4. Iterate through operatorDmlGraph's output edges to populate outputEdgeNameToDmlGraphNodeAndIndex
// 5. While performing step 2, 3, and 4, insert operatorDmlGraphNode to the mainDmlGraphNode list.

for (auto& operatorDmlGraphInputEdge : operatorDmlGraphCreateInfo.inputEdges)
{
const onnxruntime::NodeArg* arg = node.InputDefs()[operatorDmlGraphInputEdge.GraphInputIndex];
Expand All @@ -381,8 +381,8 @@ namespace Dml::GraphDescBuilder
DmlSerializedGraphNode constantNode = {};
constantNode.Name = arg->Name();

// This is a highly inefficient approach to generating constant nodes. It duplicates constant data
// across the graph input as well as every consumer's unique constant node. However it is currently
// This is a highly inefficient approach to generating constant nodes. It duplicates constant data
// across the graph input as well as every consumer's unique constant node. However it is currently
// only used for small inputs.
auto& operatorDmlGraphInputNode = operatorDmlGraphCreateInfo.nodes[operatorDmlGraphInputEdge.ToNodeIndex];
std::vector<DmlBufferTensorDesc*> toNodeInputTensorDescs = operatorDmlGraphInputNode->GetInputTensors();
Expand Down Expand Up @@ -441,10 +441,10 @@ namespace Dml::GraphDescBuilder
if (iter != subgraphInputNameToIndexMap.end())
{
const uint32_t subgraphInputIndex = iter->second;

// Either this edge will be
// a constant input, then it will be an intermediate edge and
// set the OWNED_BY_DML flag if it is large constant
// a constant input, then it will be an intermediate edge and
// set the OWNED_BY_DML flag if it is large constant
// or,
// a non-constant input, then it will be a mainDmlGraphInputEdge.
if (subgraphInputIndex < isConstGpuGraphInputCount &&
Expand Down Expand Up @@ -526,7 +526,7 @@ namespace Dml::GraphDescBuilder
edge.Name = "nodeIdx:" + std::to_string(shiftedFromNodeIndex) + "-outputIdx:" + std::to_string(operatorGraphIntermediateEdge.FromNodeOutputIndex);
dmlGraphIntermediateEdges.push_back(edge);
}

// populate nameToNodeAndIndexMap (which will be used by above loop) for operatorGraphOutputEdges
for (auto& operatorGraphOutputEdge : operatorDmlGraphCreateInfo.outputEdges)
{
Expand All @@ -540,7 +540,7 @@ namespace Dml::GraphDescBuilder
operatorDmlGraphToDmlGraphNodeIndexMap,
dmlGraphNodes);
dmlGraphNodeOutputNameToNodeAndIndexMap[arg->Name()] = {shiftedNodeIndex, operatorGraphOutputEdge.FromNodeOutputIndex};
nodeOutputShapes[arg->Name()] = outputShapes;
nodeOutputShapes[arg->Name()] = outputShapes.GetShape(operatorGraphOutputEdge.GraphOutputIndex);
}
}
}
Expand All @@ -561,7 +561,7 @@ namespace Dml::GraphDescBuilder
edge.GraphOutputIndex = gsl::narrow_cast<uint32_t>(outputIndex);
edge.Name = graphOutput->Name();
dmlGraphOutputEdges.push_back(edge);
graphOutputShapes.GetMutableShape(outputIndex) = nodeOutputShapes[graphOutput->Name()].GetShape(outputNodeAndIndex.targetIndex);
graphOutputShapes.GetMutableShape(outputIndex) = nodeOutputShapes[graphOutput->Name()];
}

RemoveUnconnectedNodes(dmlGraphNodes,
Expand Down

0 comments on commit 01acc25

Please sign in to comment.