Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DML EP] DML Graph Serialization Bug #19748

Merged
merged 21 commits into from
Mar 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
4dbe6c9
Fix crash in PadFusion transform (#18544)
jeffbloo Nov 21, 2023
8c12d30
Merge branch 'main' of https://github.com/microsoft/onnxruntime
sumitsays Feb 7, 2024
2df2f57
Merge branch 'main' of https://github.com/microsoft/onnxruntime
sumitsays Feb 14, 2024
0b32f3e
Merge branch 'main' of https://github.com/microsoft/onnxruntime
sumitsays Feb 15, 2024
520e3e8
Merge branch 'main' of https://github.com/microsoft/onnxruntime
sumitsays Feb 17, 2024
4b7df02
Merge branch 'main' of https://github.com/microsoft/onnxruntime
sumitsays Feb 23, 2024
ca03fdb
Merge branch 'main' of https://github.com/microsoft/onnxruntime
sumitsays Feb 26, 2024
0c797cb
Fix serialization bug
sumitsays Mar 1, 2024
9f8cf28
Typo in the comment
sumitsays Mar 1, 2024
23affee
Added comment
sumitsays Mar 1, 2024
b660772
remove redundant copy of the edge name
sumitsays Mar 1, 2024
efa5741
Merge branch 'main' of https://github.com/microsoft/onnxruntime
sumitsays Mar 8, 2024
4bc0ab0
Merge branch 'main' of https://github.com/microsoft/onnxruntime
sumitsays Mar 12, 2024
491bf66
Fix naming of the edge.
sumitsays Mar 12, 2024
2ad427f
Merge branch 'main' of https://github.com/microsoft/onnxruntime
sumitsays Mar 28, 2024
bc9f2c2
Merge branch 'main' into user/sumita/fixSerializationbug
sumitsays Mar 28, 2024
335853c
Add topological sort
sumitsays Mar 28, 2024
ef37d2d
Merge branch 'main' of https://github.com/microsoft/onnxruntime
sumitsays Mar 28, 2024
ea21071
Merge branch 'main' into user/sumita/fixSerializationbug
sumitsays Mar 28, 2024
89b6870
Remove un-used input edge index and name from the corresponding map
sumitsays Mar 29, 2024
69d9fd6
Fixed string matching
sumitsays Mar 29, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -415,7 +415,7 @@
if (edgeToOutgoingNodeIndexMap.find(edgeName->string_view()) == edgeToOutgoingNodeIndexMap.end())
{
throw std::range_error("Neither there is any graph input with name " + edgeName->str() +
"nor there is any node which has " + edgeName->str() + " as one of the output.");
" nor there is any node which has " + edgeName->str() + " as one of the output.");
}
auto& intermediateEdgeNodeIndex = edgeToOutgoingNodeIndexMap[edgeName->string_view()];
DmlIntermediateSerializedGraphEdge intermediateEdge = {};
Expand Down Expand Up @@ -475,14 +475,15 @@
inputEdges,
intermediateEdges,
edgeToOutgoingNodeIndexMap);

PopulateEdges<DmlOutputSerializedGraphEdge>(
nodeIndex,
flatbufferNode->outputNames(),
graphOutputEdgeToIndexMap,
outputEdges,
intermediateEdges,
edgeToOutgoingNodeIndexMap);

Check warning on line 486 in onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphDeserialization.cpp

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Line ends in whitespace. Consider deleting these extra spaces. [whitespace/end_of_line] [4] Raw Output: onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphDeserialization.cpp:486: Line ends in whitespace. Consider deleting these extra spaces. [whitespace/end_of_line] [4]
DmlSerializedGraphNode node = {};
if (flatbufferNode->name()->size() == 0)
{
Expand All @@ -503,7 +504,7 @@

ConstantName constantNode = {flatbufferConstantNode->data_as_ConstantName()->name()->c_str()};
node.Desc = constantNode;
// output of this node will part of constantInputs list
// Output of this node will be part of constantInputs list.
for (uint32_t outputIndex = 0; outputIndex < flatbufferNode->outputNames()->size(); outputIndex++)
{
constantInputs.insert(flatbufferNode->outputNames()->Get(outputIndex)->c_str());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -596,37 +596,37 @@
const std::unordered_map<uint32_t, uint32_t>* serializedGraphInputIndexToSubgraphInputIndex,
const std::unordered_map<std::string_view, uint32_t>* serializedGraphLargeConstantNameToSubgraphInputIndex)
{
if (graphSerializationEnabled)
{

const std::wstring modelName = GetModelName(graph.ModelPath());
auto buffer = SerializeDmlGraph(graphDesc);

const std::wstring partitionName =
L"Partition_" +
std::to_wstring(partitionIndex) +
L".bin";
WriteToFile(modelName, partitionName, buffer.data(), buffer.size());

std::vector<std::unique_ptr<std::byte[]>> rawData;
DmlSerializedGraphDesc deserializedGraphDesc = DeserializeDmlGraph(buffer.data(), rawData);
GraphDescBuilder::GraphDesc deserializedDmlGraphDesc = {};
deserializedDmlGraphDesc.InputCount = deserializedGraphDesc.InputCount;
deserializedDmlGraphDesc.InputEdges = std::move(deserializedGraphDesc.InputEdges);
deserializedDmlGraphDesc.IntermediateEdges = std::move(deserializedGraphDesc.IntermediateEdges);
deserializedDmlGraphDesc.Nodes = std::move(deserializedGraphDesc.Nodes);
deserializedDmlGraphDesc.OutputCount = deserializedGraphDesc.OutputCount;
deserializedDmlGraphDesc.OutputEdges = std::move(deserializedGraphDesc.OutputEdges);
deserializedDmlGraphDesc.reuseCommandList = graphDesc.reuseCommandList;
deserializedDmlGraphDesc.outputShapes = graphDesc.outputShapes;

compiledExecutionPlanOperator = DmlGraphFusionHelper::TryCreateCompiledOperator(
deserializedDmlGraphDesc,
indexedSubGraph,
providerImpl,
serializedGraphInputIndexToSubgraphInputIndex,
serializedGraphLargeConstantNameToSubgraphInputIndex);
}
if (graphSerializationEnabled)
{

Check warning on line 601 in onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Line ends in whitespace. Consider deleting these extra spaces. [whitespace/end_of_line] [4] Raw Output: onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp:601: Line ends in whitespace. Consider deleting these extra spaces. [whitespace/end_of_line] [4]

Check warning on line 601 in onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Redundant blank line at the start of a code block should be deleted. [whitespace/blank_line] [2] Raw Output: onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp:601: Redundant blank line at the start of a code block should be deleted. [whitespace/blank_line] [2]
const std::wstring modelName = GetModelName(graph.ModelPath());
auto buffer = SerializeDmlGraph(graphDesc);

Check warning on line 604 in onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Line ends in whitespace. Consider deleting these extra spaces. [whitespace/end_of_line] [4] Raw Output: onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp:604: Line ends in whitespace. Consider deleting these extra spaces. [whitespace/end_of_line] [4]
const std::wstring partitionName =
L"Partition_" +
std::to_wstring(partitionIndex) +
L".bin";
WriteToFile(modelName, partitionName, buffer.data(), buffer.size());

Check warning on line 610 in onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Line ends in whitespace. Consider deleting these extra spaces. [whitespace/end_of_line] [4] Raw Output: onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp:610: Line ends in whitespace. Consider deleting these extra spaces. [whitespace/end_of_line] [4]
std::vector<std::unique_ptr<std::byte[]>> rawData;
DmlSerializedGraphDesc deserializedGraphDesc = DeserializeDmlGraph(buffer.data(), rawData);
GraphDescBuilder::GraphDesc deserializedDmlGraphDesc = {};
deserializedDmlGraphDesc.InputCount = deserializedGraphDesc.InputCount;
deserializedDmlGraphDesc.InputEdges = std::move(deserializedGraphDesc.InputEdges);
deserializedDmlGraphDesc.IntermediateEdges = std::move(deserializedGraphDesc.IntermediateEdges);
deserializedDmlGraphDesc.Nodes = std::move(deserializedGraphDesc.Nodes);
deserializedDmlGraphDesc.OutputCount = deserializedGraphDesc.OutputCount;
deserializedDmlGraphDesc.OutputEdges = std::move(deserializedGraphDesc.OutputEdges);
deserializedDmlGraphDesc.reuseCommandList = graphDesc.reuseCommandList;
deserializedDmlGraphDesc.outputShapes = graphDesc.outputShapes;

Check warning on line 622 in onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Line ends in whitespace. Consider deleting these extra spaces. [whitespace/end_of_line] [4] Raw Output: onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp:622: Line ends in whitespace. Consider deleting these extra spaces. [whitespace/end_of_line] [4]
compiledExecutionPlanOperator = DmlGraphFusionHelper::TryCreateCompiledOperator(
deserializedDmlGraphDesc,
indexedSubGraph,
providerImpl,
serializedGraphInputIndexToSubgraphInputIndex,
serializedGraphLargeConstantNameToSubgraphInputIndex);
}

auto& fusedNode = graph.BeginFuseSubGraph(indexedSubGraph, indexedSubGraph.GetMetaDef()->name);
fusedNode.SetExecutionProviderType(onnxruntime::kDmlExecutionProvider);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -517,6 +517,9 @@ flatbuffers::DetachedBuffer SerializeDmlGraph(const DmlSerializedGraphDesc& grap
return builder.Release();
}

std::vector<uint32_t> nodesInTopologicalOrder(graphDesc.Nodes.size());
PerformTopologicalSortAndCheckIsAcyclic(graphDesc, nodesInTopologicalOrder);

// create input/output edge index to name map
std::unordered_map<uint32_t, flatbuffers::Offset<flatbuffers::String>> graphInputIndexToNameMap =
ConvertToEdgeIndexToNameMap<DmlInputSerializedGraphEdge>(graphDesc.InputEdges, builder);
Expand Down Expand Up @@ -548,14 +551,14 @@ flatbuffers::DetachedBuffer SerializeDmlGraph(const DmlSerializedGraphDesc& grap

// Create flatbuffer node objects
std::vector<flatbuffers::Offset<dml::ir::DmlGraphNode>> nodes(graphDesc.Nodes.size());
for (uint32_t nodeIndex = 0; nodeIndex < static_cast<uint32_t>(graphDesc.Nodes.size()); nodeIndex++)
for (uint32_t nodeIndex = 0; nodeIndex < static_cast<uint32_t>(nodesInTopologicalOrder.size()); nodeIndex++)
{
nodes[nodeIndex] = SerializeNode(
builder,
nodeIndex,
graphDesc.Nodes[nodeIndex],
nodeToInputNames[nodeIndex],
nodeToOutputNames[nodeIndex]);
nodesInTopologicalOrder[nodeIndex],
graphDesc.Nodes[nodesInTopologicalOrder[nodeIndex]],
nodeToInputNames[nodesInTopologicalOrder[nodeIndex]],
nodeToOutputNames[nodesInTopologicalOrder[nodeIndex]]);
}

// Convert to std::vector to create the <dml::ir::DmlGraphDesc> object.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
// Copyright (c) Microsoft Corporation. All rights reserved.

#pragma once
#include <queue>

inline void PerformTopologicalSortAndCheckIsAcyclic(
const DmlSerializedGraphDesc& graphDesc,
std::vector<uint32_t>& nodesInTopologicalOrder)
{
uint32_t nodeCount = static_cast<uint32_t>(graphDesc.Nodes.size());
std::queue<uint32_t> queue;
std::vector<uint32_t> inDegree(nodeCount, 0);
std::vector<std::vector<uint32_t>> children(nodeCount);

Check warning on line 13 in onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DmlGraphHelper.h

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <vector> for vector<> [build/include_what_you_use] [4] Raw Output: onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DmlGraphHelper.h:13: Add #include <vector> for vector<> [build/include_what_you_use] [4]

// Don't need to iterate through InputEdges because those inputs don't represent any node
// and the purpose of this topological sort is to come up with a order to correctly iterate
// through nodes .
for (const DmlIntermediateSerializedGraphEdge& intermediateEdge : graphDesc.IntermediateEdges)
{
inDegree[intermediateEdge.ToNodeIndex]++;
children[intermediateEdge.FromNodeIndex].push_back(intermediateEdge.ToNodeIndex);
}

for (uint32_t nodeIndex = 0; nodeIndex < nodeCount; nodeIndex++)
{
if (inDegree[nodeIndex] == 0)
{
queue.push(nodeIndex);
}
}

uint32_t nodeIndex = 0;
while (!queue.empty())
{
if (nodeIndex >= nodeCount)
{
throw std::invalid_argument("Given graph is not acyclic.");
}

uint32_t currNodeIndex = queue.front();
queue.pop();
nodesInTopologicalOrder[nodeIndex++] = currNodeIndex;

for (uint32_t child : children[currNodeIndex])
{
if (--inDegree[child] == 0)
{
queue.push(child);
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,9 @@
std::vector<DmlSerializedGraphNode>& graphNodes,
std::vector<DmlInputSerializedGraphEdge>& graphInputEdges,
std::vector<DmlIntermediateSerializedGraphEdge>& graphIntermediateEdges,
std::vector<DmlOutputSerializedGraphEdge>& graphOutputEdges)
std::vector<DmlOutputSerializedGraphEdge>& graphOutputEdges,
std::unordered_map<uint32_t, uint32_t>& serializedGraphInputIndexToSubgraphInputIndex,
std::unordered_map<std::string_view, uint32_t>& serializedGraphLargeConstantNameToSubgraphInputIndex)
{
enum class NodeState
{
Expand Down Expand Up @@ -124,8 +126,10 @@
graphNodes.resize(graphNodes.size() - shift);

// Adjust the node indices in the input edges
std::unordered_set<uint32_t> usedInputEdgeIndex;
for (auto& inputEdge : graphInputEdges)
{
usedInputEdgeIndex.insert(inputEdge.GraphInputIndex);
inputEdge.ToNodeIndex = shiftedIndicesMapping[inputEdge.ToNodeIndex];
}

Expand All @@ -136,10 +140,54 @@
}

// Adjust the node indices in the intermediate edges
std::unordered_set<std::string> usedLargeConstantNames;

Check warning on line 143 in onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <unordered_set> for unordered_set<> [build/include_what_you_use] [4] Raw Output: onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp:143: Add #include <unordered_set> for unordered_set<> [build/include_what_you_use] [4]
for (auto& intermediateEdge : graphIntermediateEdges)
{
intermediateEdge.FromNodeIndex = shiftedIndicesMapping[intermediateEdge.FromNodeIndex];
intermediateEdge.ToNodeIndex = shiftedIndicesMapping[intermediateEdge.ToNodeIndex];
// We need to update the edge name only when the name contains the intermediateEdge.FromNodeIndex
size_t pos = intermediateEdge.Name.find("nodeIdx:");
if (pos != std::string::npos)
{
if (pos != 0)
{
std::string constantNamePartComingFromModel = intermediateEdge.Name.substr(0, pos - 1);
usedLargeConstantNames.insert(constantNamePartComingFromModel); // need part of name which is coming from the model.

Check warning on line 155 in onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 At least two spaces is best between code and comments [whitespace/comments] [2] Raw Output: onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp:155: At least two spaces is best between code and comments [whitespace/comments] [2]
intermediateEdge.Name = constantNamePartComingFromModel;
intermediateEdge.Name += "-nodeIdx:" + std::to_string(intermediateEdge.FromNodeIndex) + "-outputIdx:" + std::to_string(intermediateEdge.FromNodeOutputIndex);
}
else
{
intermediateEdge.Name = "nodeIdx:" + std::to_string(intermediateEdge.FromNodeIndex) + "-outputIdx:" + std::to_string(intermediateEdge.FromNodeOutputIndex);
}
}
}


// Erase the mapping if the input Edge is not used by any node
for (auto it = serializedGraphInputIndexToSubgraphInputIndex.begin(); it != serializedGraphInputIndexToSubgraphInputIndex.end();)
sumitsays marked this conversation as resolved.
Show resolved Hide resolved
{
if (!usedInputEdgeIndex.count(it->first))
{
it = serializedGraphInputIndexToSubgraphInputIndex.erase(it);
}
else
{
it++;
}
}

Check warning on line 179 in onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Line ends in whitespace. Consider deleting these extra spaces. [whitespace/end_of_line] [4] Raw Output: onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp:179: Line ends in whitespace. Consider deleting these extra spaces. [whitespace/end_of_line] [4]
// Erase the mapping if the input Edge is not used by any node
for (auto it = serializedGraphLargeConstantNameToSubgraphInputIndex.begin(); it != serializedGraphLargeConstantNameToSubgraphInputIndex.end();)
{
if (!usedLargeConstantNames.count(std::string(it->first)))
{
it = serializedGraphLargeConstantNameToSubgraphInputIndex.erase(it);
}
else
{
it++;
}
}
}

Expand Down Expand Up @@ -516,7 +564,12 @@
graphOutputShapes.GetMutableShape(outputIndex) = nodeOutputShapes[graphOutput->Name()].GetShape(outputNodeAndIndex.targetIndex);
}

RemoveUnconnectedNodes(dmlGraphNodes, dmlGraphInputEdges, dmlGraphIntermediateEdges, dmlGraphOutputEdges);
RemoveUnconnectedNodes(dmlGraphNodes,
dmlGraphInputEdges,
dmlGraphIntermediateEdges,
dmlGraphOutputEdges,
serializedGraphInputIndexToSubgraphInputIndex,
serializedGraphLargeConstantNameToSubgraphInputIndex);

GraphDesc graphDesc{};
graphDesc.InputCount = static_cast<uint32_t>(dmlGraphInputEdges.size());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
#include "External/DirectMLHelpers/DmlSerializedGraphDesc.h"
#include "External/DirectMLHelpers/DmlGraphSerialization.h"
#include "External/DirectMLHelpers/DmlGraphDeserialization.h"
#include "External/DirectMLHelpers/DmlGraphHelper.h"

using Microsoft::WRL::ComPtr;

Expand Down
Loading