Skip to content

Commit

Permalink
Fix serialization bug
Browse files Browse the repository at this point in the history
  • Loading branch information
sumitsays committed Mar 1, 2024
1 parent ca03fdb commit 0c797cb
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -415,7 +415,7 @@ template <typename EdgeType> void PopulateEdges(
if (edgeToOutgoingNodeIndexMap.find(edgeName->string_view()) == edgeToOutgoingNodeIndexMap.end())
{
throw std::range_error("Neither there is any graph input with name " + edgeName->str() +
"nor there is any node which has " + edgeName->str() + " as one of the output.");
" nor there is any node which has " + edgeName->str() + " as one of the output.");
}
auto& intermediateEdgeNodeIndex = edgeToOutgoingNodeIndexMap[edgeName->string_view()];
DmlIntermediateSerializedGraphEdge intermediateEdge = {};
Expand Down Expand Up @@ -464,25 +464,32 @@ DmlSerializedGraphDesc DeserializeDmlGraph(
std::vector<DmlOutputSerializedGraphEdge> outputEdges;
std::vector<DmlIntermediateSerializedGraphEdge> intermediateEdges;

// Iterate the output edges first because nodes are not in topologically sorted.
for (uint32_t nodeIndex = 0; nodeIndex < flatbufferGraphDesc->nodes()->size(); nodeIndex++)
{
const dml::ir::DmlGraphNode* flatbufferNode = flatbufferGraphDesc->nodes()->Get(nodeIndex);

PopulateEdges<DmlInputSerializedGraphEdge>(
nodeIndex,
flatbufferNode->inputNames(),
graphInputEdgeToIndexMap,
inputEdges,
intermediateEdges,
edgeToOutgoingNodeIndexMap);
PopulateEdges<DmlOutputSerializedGraphEdge>(
nodeIndex,
flatbufferNode->outputNames(),
graphOutputEdgeToIndexMap,
outputEdges,
intermediateEdges,
edgeToOutgoingNodeIndexMap);
}

for (uint32_t nodeIndex = 0; nodeIndex < flatbufferGraphDesc->nodes()->size(); nodeIndex++)
{
const dml::ir::DmlGraphNode* flatbufferNode = flatbufferGraphDesc->nodes()->Get(nodeIndex);

PopulateEdges<DmlInputSerializedGraphEdge>(
nodeIndex,
flatbufferNode->inputNames(),
graphInputEdgeToIndexMap,
inputEdges,
intermediateEdges,
edgeToOutgoingNodeIndexMap);

Check warning on line 492 in onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphDeserialization.cpp

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Line ends in whitespace. Consider deleting these extra spaces. [whitespace/end_of_line] [4] Raw Output: onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphDeserialization.cpp:492: Line ends in whitespace. Consider deleting these extra spaces. [whitespace/end_of_line] [4]
DmlSerializedGraphNode node = {};
if (flatbufferNode->name()->size() == 0)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,16 @@ namespace Dml::GraphDescBuilder
{
intermediateEdge.FromNodeIndex = shiftedIndicesMapping[intermediateEdge.FromNodeIndex];
intermediateEdge.ToNodeIndex = shiftedIndicesMapping[intermediateEdge.ToNodeIndex];
std::string oldEdgeName = intermediateEdge.Name;
size_t pos = oldEdgeName.find("nodeIdx:");
if (pos != std::string::npos)
{
if (pos != 0)
{
intermediateEdge.Name = oldEdgeName.substr(0, pos);
}
intermediateEdge.Name += "nodeIdx:" + std::to_string(intermediateEdge.FromNodeIndex) + "-outputIdx:" + std::to_string(intermediateEdge.FromNodeOutputIndex);
}
}
}

Expand Down

0 comments on commit 0c797cb

Please sign in to comment.